In [54]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from nltk.tokenize import sent_tokenize

In [55]:
# Load the tokenizer and the causal LM it belongs to.
# The model load must stay enabled: the following cells replace `model.lm_head`,
# freeze `model.parameters()` and wrap `model` in SentenceClassifier, so leaving it
# commented out makes a fresh-kernel "Run All" fail with NameError.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

In [3]:
# Swap the LM head for a pass-through so that `outputs.logits` is hidden_size wide;
# SentenceClassifier below reads its per-token representations from `.logits`.
model.lm_head = nn.Identity()

In [4]:
# special_tokens = {
#     "additional_special_tokens": ["<|S1|>", "<|S2|>", "<|S3|>", "<|S4|>", "<|S5|>"]
# }

In [5]:
# tokenizer.add_special_tokens(special_tokens)
# model.resize_token_embeddings(len(tokenizer))

In [6]:
# Freeze every weight of the base model; only modules created afterwards stay trainable.
for weight in model.parameters():
    weight.requires_grad = False

# Unfreeze only the embeddings
# embedding_layer = model.get_input_embeddings()
# for param in embedding_layer.parameters():
#     param.requires_grad = True

In [None]:
class SentenceClassifier(nn.Module):
    """Per-token classifier on top of a (frozen) base language model.

    Every token representation is concatenated with the representation of the
    last *real* token of its sequence and passed through a single linear layer.

    The token representations are read from ``outputs.logits``; this assumes the
    base model's ``lm_head`` has been replaced by ``nn.Identity`` so that
    ``logits`` has ``config.hidden_size`` features (the linear layer is sized
    for ``hidden_size * 2`` inputs).

    Args:
        base_model: Model called as ``base_model(input_ids=..., attention_mask=...)``
            that returns an object with a ``logits`` attribute of shape (B, T, C)
            and exposes ``config.hidden_size``.
        num_labels: Number of output logits per token.
    """

    def __init__(self, base_model, num_labels):
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(self.base_model.config.hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return per-token logits of shape (B, T, num_labels).

        Args:
            input_ids: Token ids, shape (B, T).
            attention_mask: Optional (B, T) mask, non-zero for real tokens. When
                given, the sequence summary is the last attended token of each
                row instead of position -1, which is a pad token whenever the
                batch is right-padded. When omitted, position -1 is used.
        """
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        all_tokens_hidden = outputs.logits  # (B, T, C)
        B, T, C = all_tokens_hidden.shape
        device = all_tokens_hidden.device

        if attention_mask is None:
            last_index = torch.full((B,), T - 1, dtype=torch.long, device=device)
        else:
            # Largest position whose mask is non-zero; works for left and right
            # padding, and falls back to position 0 for an all-padding row.
            positions = torch.arange(T, device=device).unsqueeze(0)  # (1, T)
            attended = attention_mask.to(device).ne(0)  # (B, T)
            last_index = (positions * attended).max(dim=1).values  # (B,)

        rows = torch.arange(B, device=device)
        last_token_hidden = all_tokens_hidden[rows, last_index]  # (B, C)
        last_token_hidden = last_token_hidden.unsqueeze(1).expand(B, T, C)

        combined_representation = torch.cat((all_tokens_hidden, last_token_hidden), dim=-1)
        return self.classifier(combined_representation)

In [8]:
# Wrap the frozen base model with a single-logit (binary) per-token head.
classifier = SentenceClassifier(base_model=model, num_labels=1)

In [9]:
# Total number of trainable scalars (parameters that still require gradients).
sum(weight.numel() for weight in classifier.parameters() if weight.requires_grad)

4097

In [47]:
# Pad with the "<|finetune_right_pad_id|>" token; collate_fn pads every batch,
# so the tokenizer needs a pad token to be set before batching.
tokenizer.pad_token = "<|finetune_right_pad_id|>"

In [53]:
# Show the integer id the tokenizer assigned to the pad token
# (later passed to BaselineClassifier as pad_token_id).
tokenizer.pad_token_id

128004

In [56]:
class SentenceDataset(Dataset):
    """Index-aligned pairs of raw text and label; tokenization happens later in the collate step."""

    def __init__(self, texts, labels, tokenizer):
        """
        texts: list of multi-sentence strings.
        labels: list of lists containing one label per sentence.
        tokenizer: stored on the instance; not used by the dataset itself.
        """
        self.texts, self.labels, self.tokenizer = texts, labels, tokenizer

    def __len__(self):
        # One example per text entry.
        return len(self.texts)

    def __getitem__(self, idx):
        # Return the untokenized example; batching code reads the "text" and "label" keys.
        return {"text": self.texts[idx], "label": self.labels[idx]}

In [57]:
import pandas as pd

# Read the labelled examples and keep only the first 1000 rows for this run.
df = pd.read_csv("../data/datasets/test2.csv")

dataset = SentenceDataset(
    texts=df["text"].head(1000).tolist(),
    labels=df["label"].head(1000).tolist(),
    tokenizer=tokenizer,
)

In [58]:
def collate_fn(batch, tokenizer, max_length=512):
    """Tokenize a batch of {"text", "label"} items and attach per-token labels.

    Args:
        batch: Iterable of dicts with a "text" string and a "label" value.
        tokenizer: Callable invoked as ``tokenizer(texts, truncation=True,
            padding="max_length", max_length=..., return_tensors="pt")`` whose
            result supports item access and assignment and has an
            "attention_mask" entry with one row per text.
        max_length: Length every sequence is padded/truncated to (default 512,
            the value that was previously hard-coded).

    Returns:
        The tokenizer output with an extra "labels" entry: a 1-D tensor holding
        all rows concatenated back to back (length = batch size * sequence
        length). Each token carries its example's label; padding positions
        (attention mask == 0) carry -100. The flat layout matches how the
        training/evaluation code indexes ``logits.view(-1)`` with the label mask.
    """
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]
    encodings = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

    # -100 marks positions that the loss and metrics must skip.
    ignore_value = torch.tensor(-100)
    per_example_labels = [
        torch.where(mask_row == 0, ignore_value, torch.tensor(label))
        for mask_row, label in zip(encodings["attention_mask"], labels)
    ]
    encodings["labels"] = torch.cat(per_example_labels)

    return encodings

In [59]:
from functools import partial

# Bind the tokenizer with functools.partial rather than a lambda: this dataloader is
# later handed to mp.spawn, which pickles its arguments, and a lambda cannot be pickled.
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=partial(collate_fn, tokenizer=tokenizer),
)

In [60]:
class BaselineClassifier(nn.Module):
    """Small causal Transformer, trained from scratch, that emits logits for every token.

    Token and learned position embeddings are summed and passed through an
    ``nn.TransformerEncoder`` under a causal mask. Each token's output is
    concatenated with the output of the last real (non-padding) token of its
    sequence and projected to ``num_labels`` logits.

    Args:
        d_model: Embedding / hidden width.
        num_layers: Number of encoder layers.
        nhead: Attention heads per layer.
        max_seq_length: Size of the position-embedding table; inputs must not
            be longer than this.
        vocab_size: Size of the token-embedding table.
        pad_token_id: Id of the padding token (also used as ``padding_idx``).
        num_labels: Number of output logits per token.
    """

    def __init__(
        self,
        d_model: int,
        num_layers: int,
        nhead: int,
        max_seq_length: int,
        vocab_size: int,
        pad_token_id: int,
        num_labels: int,
    ) -> None:
        super().__init__()
        self.pad_token_id = pad_token_id
        self.token_embedding = nn.Embedding(
            vocab_size, d_model, padding_idx=pad_token_id
        )
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Input is [token output ; last-real-token output], hence 2 * d_model.
        self.classifier = nn.Linear(d_model * 2, num_labels)

    def forward(
        self,
        token_ids: torch.Tensor,
        attention_mask: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Return per-token logits of shape (batch, seq_len, num_labels).

        Args:
            token_ids: Integer token ids, shape (batch, seq_len).
            attention_mask: Optional (batch, seq_len) mask, non-zero for real
                tokens. Accepted so the model can be called as
                ``model(input_ids, attention_mask)`` like the training and
                evaluation loops do. When omitted, padding is derived from
                ``token_ids == pad_token_id``.

        Note:
            Right padding is assumed. With left padding the first query rows
            would have every key masked (causal mask plus padding mask).
        """
        batch_size, seq_len = token_ids.shape
        device = token_ids.device

        positions = torch.arange(seq_len, device=device).unsqueeze(0)  # (1, T)
        embeddings = self.token_embedding(token_ids) + self.pos_embedding(positions)

        # True above the diagonal: a position may not attend to later positions.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
            diagonal=1,
        )

        if attention_mask is None:
            pad_mask = token_ids.eq(self.pad_token_id)  # (B, T), True at padding
        else:
            pad_mask = attention_mask.to(device).eq(0)

        output = self.transformer(
            embeddings, mask=causal_mask, src_key_padding_mask=pad_mask
        )  # (B, T, C)

        # Under the causal mask only the last real token has seen the whole
        # sequence. Index -1 would be a padding position for right-padded rows,
        # so take the largest non-padding position instead (0 if the row is all padding).
        last_index = (positions * (~pad_mask)).max(dim=1).values  # (B,)
        rows = torch.arange(batch_size, device=device)
        last_token_hidden = output[rows, last_index]  # (B, C)
        last_token_hidden = last_token_hidden.unsqueeze(1).expand(-1, seq_len, -1)

        combined_representation = torch.cat((output, last_token_hidden), dim=-1)
        return self.classifier(combined_representation)

In [61]:
# Rebind `classifier` to the from-scratch baseline (replaces the earlier SentenceClassifier).
classifier = BaselineClassifier(
    d_model=512,
    num_layers=6,
    nhead=8,
    max_seq_length=512,
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    num_labels=1,
)

In [63]:
import os
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

# Detect a DDP (distributed data parallel) launch.
# torchrun exports RANK, LOCAL_RANK and WORLD_SIZE; a missing RANK means a plain run.
ddp = int(os.environ.get("RANK", "-1")) != -1

In [64]:
# Show the flag: False means RANK was unset (or -1), i.e. a plain single-process run.
ddp

False

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import pandas as pd

def setup_process(rank, world_size, backend='nccl'):
    """Join the default process group as worker `rank` out of `world_size`.

    The rendezvous address is fixed to localhost:12355 and written into the
    environment, overwriting any existing MASTER_ADDR / MASTER_PORT values.
    """
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group if one is active.

    Safe to call when no group was initialised (for example from an error path
    or when called twice); in that case it does nothing instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def evaluate(model, dataloader, device):
    """Score `model` on `dataloader` and return (accuracy, precision, recall, f1).

    Batches must provide 'input_ids', 'attention_mask' and 'labels'. Positions
    whose label is -100 are excluded. A token is predicted positive when the
    sigmoid of its logit exceeds 0.5. The model is left in eval mode.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for batch in dataloader:
            token_ids, attn_mask, labels = (
                batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'labels')
            )
            keep = labels != -100

            # Flatten to one logit per token, then drop ignored positions.
            kept_logits = model(token_ids, attn_mask).view(-1)[keep]
            kept_labels = labels.view(-1)[keep].float()

            is_positive = torch.sigmoid(kept_logits) > 0.5
            predictions.extend(is_positive.cpu().numpy())
            targets.extend(kept_labels.cpu().numpy())

    accuracy = accuracy_score(targets, predictions)
    precision, recall, f1, _support = precision_recall_fscore_support(
        targets, predictions, average='binary'
    )
    return accuracy, precision, recall, f1

def train(rank, world_size, dataloader, val_dataset, model, num_epochs=3):
    """Run one DDP worker: train `model`, then validate and checkpoint on rank 0.

    Intended as the target of `mp.spawn`, which passes `rank` as the first argument.

    Args:
        rank: Index of this worker; also used as the CUDA device index.
        world_size: Total number of workers.
        dataloader: DataLoader whose `dataset` and `collate_fn` are reused to
            build the per-rank training loader (its batch size and shuffling
            are not used; batches of 32 are drawn through a DistributedSampler).
        val_dataset: Validation dataset, batched with the same `collate_fn`
            (assumes its items have the same format as the training items —
            TODO confirm against the caller).
        model: Module called as `model(input_ids, attention_mask)` whose output
            flattens to one logit per token.
        num_epochs: Number of passes over the training data.

    Side effects (rank 0 only): writes "best_model.pt" whenever validation F1
    improves and "training_logs/training_history.csv" at the end.
    """
    use_cuda = torch.cuda.is_available()
    # NCCL only works with CUDA devices; use gloo for the CPU fallback.
    setup_process(rank, world_size, backend='nccl' if use_cuda else 'gloo')

    try:
        device = torch.device(f'cuda:{rank}' if use_cuda else 'cpu')
        if use_cuda:
            torch.cuda.set_device(device)
        model.to(device)

        # Wrap model with DDP; device_ids must be None for CPU modules.
        model = DDP(model, device_ids=[rank] if use_cuda else None)

        # Reuse the caller's collate function: the dataset yields raw
        # {"text", "label"} items, so without it the batches would have no
        # 'input_ids' / 'attention_mask' / 'labels' entries.
        batch_collate = dataloader.collate_fn

        # Each rank sees its own shard of the training dataset.
        train_sampler = DistributedSampler(
            dataloader.dataset, num_replicas=world_size, rank=rank
        )
        train_loader = DataLoader(
            dataloader.dataset,
            batch_size=32,
            sampler=train_sampler,
            shuffle=False,
            collate_fn=batch_collate,
        )

        # Validation needs no shuffling or sharding (it runs on rank 0 only).
        val_loader = DataLoader(
            val_dataset, batch_size=32, shuffle=False, collate_fn=batch_collate
        )

        optimizer = optim.AdamW(model.parameters(), lr=5e-5)
        criterion = nn.BCEWithLogitsLoss()
        best_f1 = 0.0
        history_rows = []
        model_save_path = "best_model.pt"
        os.makedirs("training_logs", exist_ok=True)

        for epoch in range(num_epochs):
            model.train()
            total_loss = 0
            all_preds = []
            all_labels = []

            # Re-seed the sampler so every epoch uses a different shuffle.
            train_sampler.set_epoch(epoch)

            for batch in tqdm(train_loader, desc=f"Rank {rank} Epoch {epoch+1}/{num_epochs}"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels_batch = batch['labels'].to(device)
                mask = labels_batch != -100  # drop padding positions

                optimizer.zero_grad()
                logits = model(input_ids, attention_mask)
                logits = logits.view(-1)[mask]
                labels_batch = labels_batch.view(-1)[mask].float()

                loss = criterion(logits, labels_batch)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                preds = torch.sigmoid(logits).detach().cpu() > 0.5
                all_preds.extend(preds.numpy())
                all_labels.extend(labels_batch.cpu().numpy())

            # Training metrics cover this rank's shard only.
            train_acc = accuracy_score(all_labels, all_preds)
            _, _, train_f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
            avg_loss = total_loss / len(train_loader)

            # Evaluate only on rank 0 to avoid redundant validation. Use the
            # unwrapped module so no DDP collective is entered by one rank alone.
            if rank == 0:
                val_acc, _, _, val_f1 = evaluate(model.module, val_loader, device)
                if val_f1 > best_f1:
                    best_f1 = val_f1
                    # Save the model (unwrap DDP)
                    torch.save(model.module.state_dict(), model_save_path)
                    print(f"Saved new best model (F1: {best_f1:.4f})")
                print(
                    f"Epoch {epoch+1}/{num_epochs} "
                    f"| Train Loss: {avg_loss:.4f} "
                    f"| Train F1: {train_f1:.4f} "
                    f"| Val F1: {val_f1:.4f}"
                )
                history_rows.append({
                    "epoch": epoch + 1,
                    "train_loss": avg_loss,
                    "train_accuracy": train_acc,
                    "train_f1": train_f1,
                    "val_f1": val_f1
                })

        if rank == 0:
            df_history = pd.DataFrame(history_rows)
            df_history.to_csv("training_logs/training_history.csv", index=False)
            print("Training history saved to training_logs/training_history.csv")
    finally:
        # Always release the process group, even if training raised.
        cleanup()

if __name__ == '__main__':
    # Launch one training worker per visible CUDA device.
    # `dataloader` (training data) and `classifier` (the model) come from earlier cells.
    # NOTE(review): `val_dataset` is not defined in the visible part of the notebook;
    # it must exist before this cell runs.
    world_size = torch.cuda.device_count()
    mp.spawn(
        fn=train,
        args=(world_size, dataloader, val_dataset, classifier),
        nprocs=world_size,
        join=True,
    )
