In [1]:
import argparse
import os
import csv
import json
import pandas as pd
import torch
from torch.distributed import destroy_process_group, init_process_group
from torch.nn import BCEWithLogitsLoss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset
from torch.utils.data.distributed import DistributedSampler
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    balanced_accuracy_score,
)
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import AutoTokenizer
from torch.utils.data import Dataset

In [13]:
def evaluate(model, dataloader):
    model.eval()
    preds, targets = [], []
    total_loss = 0.0
    loss_fn = BCEWithLogitsLoss()

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]

            outputs = model(input_ids, attention_mask=attention_mask)
            
            mask = labels.view(-1) != -100
            labels = labels.view(-1)[mask].float()
            outputs = outputs.view(-1)[mask]

            loss = loss_fn(outputs, labels)
            total_loss += loss.item()

            logits = torch.sigmoid(outputs).squeeze().cpu().numpy()
            labels = labels.squeeze().cpu().numpy()

            preds.extend(logits)
            targets.extend(labels)

    bin_preds = [1 if p >= 0.5 else 0 for p in preds]

    metrics = {
        "loss": total_loss / len(dataloader),
        "accuracy": accuracy_score(targets, bin_preds),
        "balanced_accuracy": balanced_accuracy_score(targets, bin_preds),
        "precision": precision_score(targets, bin_preds),
        "recall": recall_score(targets, bin_preds),
        "f1": f1_score(targets, bin_preds),
        #"auc": roc_auc_score(targets, preds),
    }

    return metrics

In [3]:
from typing import Dict, List, Union
def collate_fn(
    batch: List[Dict[str, torch.tensor]], tokenizer: AutoTokenizer
) -> Dict[str, torch.tensor]:
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]
    encodings = tokenizer(
        texts, truncation=True, padding="longest", return_tensors="pt"
    )

    labels_padded = [
        torch.where(t == 0, torch.tensor(-100), torch.tensor(label))
        for t, label in zip(encodings["attention_mask"], labels)
    ]
    labels_padded = torch.cat(labels_padded)
    encodings["labels"] = labels_padded

    return encodings

In [4]:
from typing import Dict, List, Tuple, Union
class TextDataset(Dataset):
    def __init__(
        self, texts: List[str], labels: List[int]) -> None:
        """
        texts: list of texts.
        labels: list of labels for all samples.
        """
        self.texts = texts
        self.labels = labels

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, Union[str, int]]:
        text = self.texts[idx]
        label = self.labels[idx]

        return {"text": text, "label": label}

In [5]:
import torch
import torch.nn as nn
from transformers import AutoModel


class FineTuneClassifier(nn.Module):
    def __init__(self, base_model_path: str, num_labels: int) -> None:
        super(FineTuneClassifier, self).__init__()
        self.base_model = AutoModel.from_pretrained(base_model_path)

        for param in self.base_model.parameters():
            param.requires_grad = False
            
        self.classifier = nn.Linear(self.base_model.config.hidden_size * 2, num_labels)

    @classmethod
    def from_classifier_head(cls, base_model_path: str, path: str, num_labels: int) -> nn.Module:
        model = cls(base_model_path, num_labels)
        model.classifier.load_state_dict(torch.load(path))
        return model

    def forward(
        self, input_ids: torch.tensor, attention_mask: torch.tensor
    ) -> torch.tensor:
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        B, T, C = outputs.last_hidden_state.shape

        all_tokens_hidden = outputs.last_hidden_state  # (B, T, C)
        last_token_hidden = outputs.last_hidden_state[:, -1, :]  # (B, C)
        last_token_hidden = last_token_hidden.unsqueeze(1).expand(B, T, C)

        combined_representation = torch.cat(
            (all_tokens_hidden, last_token_hidden), dim=-1
        )
        logits = self.classifier(combined_representation)
        return logits

In [6]:
batch_size = 32

In [7]:
ds_path = "../../data/datasets/master_mini/train.csv"

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.pad_token = "<|finetune_right_pad_id|>"

df_data = pd.read_csv(ds_path)
train_dataset = TextDataset(df_data["text"].tolist()[:100], df_data["label"].tolist()[:100])
val_dataset = TextDataset(df_data["text"].tolist()[:100], df_data["label"].tolist()[:100])



train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, tokenizer),
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: collate_fn(batch, tokenizer),
)

In [8]:
model = FineTuneClassifier(base_model_path="meta-llama/Llama-3.2-1B-Instruct", num_labels=1)

In [9]:
loss_fn = BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=3e-4)

In [10]:
history_path = "tmp.csv"
best_val_acc = -1
with open(history_path, mode="w", newline="") as f:
    writer = csv.DictWriter(
        f,
        fieldnames=[
            "epoch",
            "train_loss",
            "train_accuracy",
            "train_balanced_accuracy",
            "train_precision",
            "train_recall",
            "train_f1",
            "train_auc",
            "val_loss",
            "val_accuracy",
            "val_balanced_accuracy",
            "val_precision",
            "val_recall",
            "val_f1",
            "val_auc",
        ],
    )
    writer.writeheader()

In [14]:
for epoch in range(3):

        print(f"\nEpoch {epoch+1}/{3}")

        model.train()
        epoch_loss = 0.0
        progress = tqdm(train_loader)

        all_logits = []
        all_labels = []
        all_bin_preds = []

        for batch in progress:
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)

            mask = labels.view(-1) != -100
            labels = labels.view(-1)[mask].float()
            outputs = outputs.view(-1)[mask]

            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            progress.set_description(f"Loss: {loss.item():.4f}")

            # Collect predictions during training
            logits = torch.sigmoid(outputs).squeeze().detach().cpu()
            labels_cpu = labels.squeeze().cpu()
            bin_preds = (logits >= 0.5).long()

            all_logits.extend(logits.tolist())
            all_labels.extend(labels_cpu.tolist())
            all_bin_preds.extend(bin_preds.tolist())

        avg_loss = epoch_loss / len(train_loader)

        train_metrics = {
            "accuracy": accuracy_score(all_labels, all_bin_preds),
            "balanced_accuracy": balanced_accuracy_score(all_labels, all_bin_preds),
            "precision": precision_score(all_labels, all_bin_preds),
            "recall": recall_score(all_labels, all_bin_preds),
            "f1": f1_score(all_labels, all_bin_preds),
            #"auc": roc_auc_score(all_labels, all_logits),
        }

        val_metrics = evaluate(model, val_loader)

        print(f"Epoch {epoch+1} complete. Avg loss: {avg_loss:.4f}")
        print("Train Metrics:", train_metrics)
        print("Val Metrics:", val_metrics)

        record = {
            "epoch": epoch + 1,
            "train_loss": avg_loss,
            **{f"train_{k}": v for k, v in train_metrics.items()},
            **{f"val_{k}": v for k, v in val_metrics.items()},
        }

        # Save training history
        with open(history_path, mode="a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=record.keys())
            writer.writerow(record)

        # Save best model
        torch.save(
            model.classifier.state_dict(),
            "finetuned_model.pt",
        )
        print(f"New best classifier saved (val accuracy: {best_val_acc:.4f})")


Epoch 1/3


Loss: 0.0014: 100%|██████████| 4/4 [03:05<00:00, 46.41s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 1 complete. Avg loss: 0.0036
Train Metrics: {'accuracy': 0.9996322177271055, 'balanced_accuracy': np.float64(0.9996322177271055), 'precision': np.float64(0.0), 'recall': np.float64(0.0), 'f1': np.float64(0.0)}
Val Metrics: {'loss': 0.0017598293779883534, 'accuracy': 1.0, 'balanced_accuracy': np.float64(1.0), 'precision': np.float64(0.0), 'recall': np.float64(0.0), 'f1': np.float64(0.0)}
New best classifier saved (val accuracy: -1.0000)

Epoch 2/3


Loss: 0.0032:  25%|██▌       | 1/4 [01:40<05:00, 100.23s/it]


KeyboardInterrupt: 

In [15]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

4097

In [17]:
model = model.from_classifier_head(
    base_model_path="meta-llama/Llama-3.2-1B-Instruct",
    path="finetuned_model.pt",
    num_labels=1,
)

  model.classifier.load_state_dict(torch.load(path))


In [16]:
model.classifier.weight.data

tensor([[-0.0107,  0.0107, -0.0053,  ..., -0.0107, -0.0069,  0.0057]])

In [18]:
model.classifier.weight.data

tensor([[-0.0108,  0.0108, -0.0052,  ..., -0.0108, -0.0070,  0.0057]])