In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from nltk.tokenize import sent_tokenize

In [2]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

In [3]:
model.lm_head = nn.Identity() 

In [4]:
# special_tokens = {
#     "additional_special_tokens": ["<|S1|>", "<|S2|>", "<|S3|>", "<|S4|>", "<|S5|>"]
# }

In [5]:
# tokenizer.add_special_tokens(special_tokens)
# model.resize_token_embeddings(len(tokenizer))

In [6]:
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the embeddings
# embedding_layer = model.get_input_embeddings()
# for param in embedding_layer.parameters():
#     param.requires_grad = True

In [7]:
class SentenceClassifier(nn.Module):
    def __init__(self, base_model, num_labels):
        super(SentenceClassifier, self).__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(self.base_model.config.hidden_size * 2, num_labels)


    def forward(self, input_ids, attention_mask):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        B, T, C = outputs.logits.shape
        
        all_tokens_hidden = outputs.logits # (B, T, C)
        last_token_hidden = outputs.logits[:, -1, :] # (B, C)

        last_token_hidden = last_token_hidden.unsqueeze(1).expand(B, T, C)

        combined_representation = torch.cat((all_tokens_hidden, last_token_hidden), dim=-1)
        logits = self.classifier(combined_representation)
        return logits

In [8]:
classifier = SentenceClassifier(model, num_labels=1)

In [9]:
# get number of parameters
sum(p.numel() for p in classifier.parameters() if p.requires_grad)

4097

In [10]:
tokenizer.pad_token = "<|finetune_right_pad_id|>"

In [11]:
class SentenceDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        """
        texts: list of multi-sentence strings.
        labels: list of lists containing one label per sentence.
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)
    

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        return {"text": text, "label": label}

In [18]:
import pandas as pd
df = pd.read_csv("../data/datasets/test2.csv")

dataset = SentenceDataset(df["text"].tolist()[:1000], df["label"].tolist()[:1000], tokenizer)

In [19]:
def pad_labels(labels, max_len):
    """
    Pads the labels to the maximum sentence count.
    """
    padded_labels = []
    if isinstance(labels, list):
        for label in labels:
            padded_labels.append([label] * max_len)
        return padded_labels
    else:
        return [labels] * max_len

In [20]:
def collate_fn(batch):
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]
    encodings = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )

    labels_padded = [torch.where(t == 0, torch.tensor(-100), torch.tensor(label)) for t, label in zip(encodings["attention_mask"], labels)]
    labels_padded = torch.cat(labels_padded)
    encodings["labels"] = labels_padded

    return encodings

In [21]:
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [22]:
from tqdm import tqdm

In [None]:
optimizer = optim.AdamW(classifier.parameters(), lr=5e-5)
criterion = nn.BCEWithLogitsLoss()
device = "cuda"
classifier.to(device)
classifier.train()
num_epochs = 3
for epoch in range(num_epochs):
    total_loss = 0
    for batch in tqdm(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels_batch = batch['labels'].to(device)
        mask = labels_batch != -100

        optimizer.zero_grad()
        logits = classifier(input_ids, attention_mask)
        logits = logits.view(-1)[mask]
        labels_batch = labels_batch.view(-1)[mask].float()

        loss = criterion(logits, labels_batch)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

100%|██████████| 125/125 [04:46<00:00,  2.29s/it]


Epoch 1/3, Loss: 0.0434


100%|██████████| 125/125 [05:22<00:00,  2.58s/it]


Epoch 2/3, Loss: 0.0036


  2%|▏         | 2/125 [00:05<05:20,  2.61s/it]