In [10]:
import argparse
import os
import csv
import json
import pandas as pd
import torch
from torch.distributed import destroy_process_group, init_process_group
from torch.nn import BCEWithLogitsLoss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset
from torch.utils.data.distributed import DistributedSampler
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    balanced_accuracy_score,
)
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import AutoTokenizer
from torch.utils.data import Dataset

In [51]:
def evaluate(model, dataloader):
    model.eval()
    preds, targets = [], []
    total_loss = 0.0
    loss_fn = BCEWithLogitsLoss()

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"]
            labels = batch["labels"]

            outputs = model(input_ids)
            
            mask = labels.view(-1) != -100
            labels = labels.view(-1)[mask].float()
            outputs = outputs.view(-1)[mask]

            loss = loss_fn(outputs, labels)
            total_loss += loss.item()

            logits = torch.sigmoid(outputs).squeeze().cpu().numpy()
            labels = labels.squeeze().cpu().numpy()

            preds.extend(logits)
            targets.extend(labels)

    bin_preds = [1 if p >= 0.5 else 0 for p in preds]

    metrics = {
        "loss": total_loss / len(dataloader),
        "accuracy": accuracy_score(targets, bin_preds),
        "balanced_accuracy": balanced_accuracy_score(targets, bin_preds),
        "precision": precision_score(targets, bin_preds),
        "recall": recall_score(targets, bin_preds),
        "f1": f1_score(targets, bin_preds),
        #"auc": roc_auc_score(targets, preds),
    }

    return metrics

In [52]:
from typing import Dict, List, Union
def collate_fn(
    batch: List[Dict[str, torch.tensor]], tokenizer: AutoTokenizer
) -> Dict[str, torch.tensor]:
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]
    encodings = tokenizer(
        texts, truncation=True, padding="longest", return_tensors="pt"
    )

    labels_padded = [
        torch.where(t == 0, torch.tensor(-100), torch.tensor(label))
        for t, label in zip(encodings["attention_mask"], labels)
    ]
    labels_padded = torch.cat(labels_padded)
    encodings["labels"] = labels_padded

    return encodings

In [53]:
class TextDataset(Dataset):
    def __init__(
        self, texts: List[str], labels: List[int]) -> None:
        """
        texts: list of texts.
        labels: list of labels for all samples.
        """
        self.texts = texts
        self.labels = labels

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, Union[str, int]]:
        text = self.texts[idx]
        label = self.labels[idx]

        return {"text": text, "label": label}

In [54]:
import torch.nn as nn
class BaselineClassifier(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_layers: int,
        nhead: int,
        max_seq_length: int,
        vocab_size: int,
        pad_token_id: int,
        num_labels: int,
    ) -> None:
        super(BaselineClassifier, self).__init__()
        self.pad_token_id = pad_token_id
        self.token_embedding = nn.Embedding(
            vocab_size, d_model, padding_idx=pad_token_id
        )
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(decoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model * 2, num_labels)

    def forward(self, token_ids: torch.tensor) -> torch.tensor:
        batch_size, seq_len = token_ids.shape

        token_emb = self.token_embedding(token_ids)
        pos_ids = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos_ids)
        embeddings = token_emb + pos_emb

        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=token_ids.device, dtype=torch.bool),
            diagonal=1,
        )

        pad_mask = token_ids.eq(self.pad_token_id)  # shape: (batch_size, seq_len)

        output = self.transformer(
            embeddings, mask=causal_mask, src_key_padding_mask=pad_mask
        )

        B, T, C = output.shape
        all_tokens_hidden = output  # (B, T, C)
        last_token_hidden = output[:, -1, :]  # (B, C)
        last_token_hidden = last_token_hidden.unsqueeze(1).expand(B, T, C)

        combined_representation = torch.cat(
            (all_tokens_hidden, last_token_hidden), dim=-1
        )
        logits = self.classifier(combined_representation)
        return logits

In [55]:
batch_size = 32

In [56]:
BASELINE_MODELS = {"mini": {
        "d_model": 512,
        "num_layers": 6,
        "num_heads": 8,
        "max_len": 512,
    }}

In [61]:
ds_path = "../data/datasets/test2.csv"
model_config = BASELINE_MODELS["mini"]

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.pad_token = "<|finetune_right_pad_id|>"

df_data = pd.read_csv(ds_path)
train_dataset = TextDataset(df_data["text"].tolist()[:100], df_data["label"].tolist()[:100])
val_dataset = TextDataset(df_data["text"].tolist()[:100], df_data["label"].tolist()[:100])



train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, tokenizer),
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, tokenizer),
)

model = BaselineClassifier(
    d_model=model_config["d_model"],
    num_layers=model_config["num_layers"],
    nhead=model_config["num_heads"],
    max_seq_length=model_config["max_len"],
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    num_labels=1,
)

In [62]:
loss_fn = BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=3e-4)

In [63]:
history_path = "tmp.csv"
best_val_acc = -1
with open(history_path, mode="w", newline="") as f:
    writer = csv.DictWriter(
        f,
        fieldnames=[
            "epoch",
            "train_loss",
            "train_accuracy",
            "train_balanced_accuracy",
            "train_precision",
            "train_recall",
            "train_f1",
            "train_auc",
            "val_loss",
            "val_accuracy",
            "val_balanced_accuracy",
            "val_precision",
            "val_recall",
            "val_f1",
            "val_auc",
        ],
    )
    writer.writeheader()

In [None]:
for epoch in range(3):

    model.train()
    epoch_loss = 0.0
    progress = tqdm(train_loader)

    all_logits = []
    all_labels = []
    all_bin_preds = []

    for batch in progress:
        input_ids = batch["input_ids"]
        labels = batch["labels"]

        optimizer.zero_grad()
        outputs = model(input_ids)

        mask = labels.view(-1) != -100
        labels = labels.view(-1)[mask].float()
        outputs = outputs.view(-1)[mask]

        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        progress.set_description(f"Loss: {loss.item():.4f}")

        # === Collect predictions during training ===
        logits = torch.sigmoid(outputs).squeeze().detach().cpu()
        labels_cpu = labels.squeeze().cpu()
        bin_preds = (logits >= 0.5).long()

        all_logits.extend(logits.tolist())
        all_labels.extend(labels_cpu.tolist())
        all_bin_preds.extend(bin_preds.tolist())

    avg_loss = epoch_loss / len(train_loader)

    train_metrics = {
        "accuracy": accuracy_score(all_labels, all_bin_preds),
        "balanced_accuracy": balanced_accuracy_score(all_labels, all_bin_preds),
        "precision": precision_score(all_labels, all_bin_preds),
        "recall": recall_score(all_labels, all_bin_preds),
        "f1": f1_score(all_labels, all_bin_preds),
        #"auc": roc_auc_score(all_labels, all_logits),
    }

    val_metrics = evaluate(model, val_loader)

    print(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.4f}")
    print("Train Metrics:", train_metrics)
    print("Val Metrics:", val_metrics)

    record = {
        "epoch": epoch + 1,
        "train_loss": avg_loss,
        **{f"train_{k}": v for k, v in train_metrics.items()},
        **{f"val_{k}": v for k, v in val_metrics.items()},
    }

    # Save training history
    with open(history_path, mode="a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=record.keys())
        writer.writerow(record)

    # Save best model
    if val_metrics["accuracy"] > best_val_acc:
        best_val_acc = val_metrics["accuracy"]
        torch.save(
            model.state_dict(),
            "tmp.pt"
        )
        print(f"New best model saved (val accuracy: {best_val_acc:.4f})")

Loss: 0.0000: 100%|██████████| 4/4 [00:12<00:00,  3.15s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1 complete. Avg loss: 0.0000
Train Metrics: {'accuracy': 1.0, 'balanced_accuracy': np.float64(1.0), 'precision': np.float64(0.0), 'recall': np.float64(0.0), 'f1': np.float64(0.0)}
Val Metrics: {'loss': 1.4317501609184546e-05, 'accuracy': 1.0, 'balanced_accuracy': np.float64(1.0), 'precision': np.float64(0.0), 'recall': np.float64(0.0), 'f1': np.float64(0.0)}


Loss: 0.0000: 100%|██████████| 4/4 [00:12<00:00,  3.02s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 2 complete. Avg loss: 0.0000
Train Metrics: {'accuracy': 1.0, 'balanced_accuracy': np.float64(1.0), 'precision': np.float64(0.0), 'recall': np.float64(0.0), 'f1': np.float64(0.0)}
Val Metrics: {'loss': 1.4317501609184546e-05, 'accuracy': 1.0, 'balanced_accuracy': np.float64(1.0), 'precision': np.float64(0.0), 'recall': np.float64(0.0), 'f1': np.float64(0.0)}


Loss: 0.0000: 100%|██████████| 4/4 [00:15<00:00,  3.77s/it]

Epoch 3 complete. Avg loss: 0.0000
Train Metrics: {'accuracy': 1.0, 'balanced_accuracy': np.float64(1.0), 'precision': np.float64(0.0), 'recall': np.float64(0.0), 'f1': np.float64(0.0)}
Val Metrics: {'loss': 1.4317501609184546e-05, 'accuracy': 1.0, 'balanced_accuracy': np.float64(1.0), 'precision': np.float64(0.0), 'recall': np.float64(0.0), 'f1': np.float64(0.0)}



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [70]:
all_bin_preds[5] = 1

In [73]:
    train_metrics = {
        "accuracy": accuracy_score(all_labels, all_bin_preds),
        "balanced_accuracy": balanced_accuracy_score(all_labels, all_bin_preds),
        "precision": precision_score(all_labels, all_bin_preds),
        "recall": recall_score(all_labels, all_bin_preds),
        "f1": f1_score(all_labels, all_bin_preds),
        #"auc": roc_auc_score(all_labels, all_logits),
    }

In [72]:
all_labels[10] = 1

In [74]:
# load best model
model.load_state_dict(torch.load("tmp.pt"))

  model.load_state_dict(torch.load("tmp.pt"))


<All keys matched successfully>

In [75]:
outputs = model(input_ids)

In [78]:
outputs

tensor([[[-10.0458],
         [ -9.9391],
         [ -9.8662],
         [ -9.9228],
         [ -9.9073],
         [ -9.8283],
         [ -9.6905],
         [ -9.8302],
         [ -9.7495],
         [ -9.8175],
         [ -9.8983],
         [ -9.9091],
         [ -9.6634],
         [ -9.7902],
         [ -9.7862],
         [ -9.7688],
         [ -9.8400],
         [ -9.7902],
         [ -9.7406],
         [ -9.9397],
         [ -9.6986],
         [ -9.8392],
         [ -9.6576],
         [ -9.7194],
         [ -9.6786],
         [ -9.6472],
         [ -9.8190],
         [ -9.7704],
         [ -9.6493],
         [ -9.6756],
         [ -9.6884],
         [ -9.6772],
         [ -9.6640],
         [ -9.7262]],

        [[ -9.9769],
         [ -9.8635],
         [ -9.8011],
         [ -9.9941],
         [ -9.8414],
         [ -9.8537],
         [ -9.8505],
         [ -9.7871],
         [ -9.8393],
         [ -9.8396],
         [ -9.7392],
         [ -9.8484],
         [ -9.6124],
         [ 

In [79]:
outputs.view(-1)

tensor([-10.0458,  -9.9391,  -9.8662,  -9.9228,  -9.9073,  -9.8283,  -9.6905,
         -9.8302,  -9.7495,  -9.8175,  -9.8983,  -9.9091,  -9.6634,  -9.7902,
         -9.7862,  -9.7688,  -9.8400,  -9.7902,  -9.7406,  -9.9397,  -9.6986,
         -9.8392,  -9.6576,  -9.7194,  -9.6786,  -9.6472,  -9.8190,  -9.7704,
         -9.6493,  -9.6756,  -9.6884,  -9.6772,  -9.6640,  -9.7262,  -9.9769,
         -9.8635,  -9.8011,  -9.9941,  -9.8414,  -9.8537,  -9.8505,  -9.7871,
         -9.8393,  -9.8396,  -9.7392,  -9.8484,  -9.6124,  -9.6642,  -9.7654,
         -9.8171,  -9.8225,  -9.7837,  -9.8247,  -9.8032,  -9.7592,  -9.8131,
         -9.6025,  -9.6536,  -9.7511,  -9.6052,  -9.7243,  -9.7242,  -9.7492,
         -9.6570,  -9.8308,  -9.5790,  -9.7349,  -9.6418,  -9.9956,  -9.7355,
         -9.7828,  -9.7315,  -9.7822,  -9.6789,  -9.7625,  -9.8726,  -9.7537,
         -9.7736,  -9.6110,  -9.6347,  -9.6780,  -9.7623,  -9.6477,  -9.6809,
         -9.6089,  -9.5506,  -9.6651,  -9.6870,  -9.7623,  -9.57