In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from nltk.tokenize import sent_tokenize

In [3]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
#model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

In [3]:
model.lm_head = nn.Identity() 

In [4]:
# special_tokens = {
#     "additional_special_tokens": ["<|S1|>", "<|S2|>", "<|S3|>", "<|S4|>", "<|S5|>"]
# }

In [5]:
# tokenizer.add_special_tokens(special_tokens)
# model.resize_token_embeddings(len(tokenizer))

In [6]:
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the embeddings
# embedding_layer = model.get_input_embeddings()
# for param in embedding_layer.parameters():
#     param.requires_grad = True

In [None]:
class SentenceClassifier(nn.Module):
    def __init__(self, base_model, num_labels):
        super(SentenceClassifier, self).__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(self.base_model.config.hidden_size * 2, num_labels)


    def forward(self, input_ids, attention_mask):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        B, T, C = outputs.logits.shape
        
        all_tokens_hidden = outputs.logits # (B, T, C)
        last_token_hidden = outputs.logits[:, -1, :] # (B, C)
        last_token_hidden = last_token_hidden.unsqueeze(1).expand(B, T, C)

        combined_representation = torch.cat((all_tokens_hidden, last_token_hidden), dim=-1)
        logits = self.classifier(combined_representation)
        return logits

In [8]:
classifier = SentenceClassifier(model, num_labels=1)

In [9]:
# get number of parameters
sum(p.numel() for p in classifier.parameters() if p.requires_grad)

4097

In [4]:
tokenizer.pad_token = "<|finetune_right_pad_id|>"

In [5]:
class SentenceDataset(Dataset):
    def __init__(self, texts, labels, tokenizer):
        """
        texts: list of multi-sentence strings.
        labels: list of lists containing one label per sentence.
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)
    

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        return {"text": text, "label": label}

In [6]:
import pandas as pd
df = pd.read_csv("../data/datasets/test2.csv")

dataset = SentenceDataset(df["text"].tolist()[:1000], df["label"].tolist()[:1000], tokenizer)

In [7]:
def collate_fn(batch):
    texts = [item["text"] for item in batch]
    labels = [item["label"] for item in batch]
    encodings = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )

    labels_padded = [torch.where(t == 0, torch.tensor(-100), torch.tensor(label)) for t, label in zip(encodings["attention_mask"], labels)]
    labels_padded = torch.cat(labels_padded)
    encodings["labels"] = labels_padded

    return encodings

In [8]:
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [9]:
from tqdm import tqdm

In [43]:
import torch
import torch.nn as nn

# Model hyperparameters
d_model = 512
vocab_size = 10000
seq_length = 10
batch_size = 32
nhead = 8
num_layers = 6
pad_token_id = 0

class TransformerDecoderWithEmbeddings(nn.Module):
    def __init__(self, vocab_size, d_model, seq_length, num_layers, nhead, pad_token_id):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_embedding = nn.Embedding(seq_length, d_model)
        decoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(decoder_layer, num_layers=num_layers)
        self.pad_token_id = pad_token_id

    def forward(self, token_ids):
        batch_size, seq_len = token_ids.shape
        
        # Compute token and positional embeddings
        token_emb = self.token_embedding(token_ids)
        pos_ids = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        pos_emb = self.pos_embedding(pos_ids)
        embeddings = token_emb + pos_emb

        # Create a boolean causal mask: True for positions that should be masked.
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=token_ids.device, dtype=torch.bool), diagonal=1)

        # Create a padding mask: True at positions where the token is a pad token.
        pad_mask = token_ids.eq(self.pad_token_id)  # shape: (batch_size, seq_len)

        # Pass the embeddings along with both masks.
        output = self.transformer(embeddings, mask=causal_mask, src_key_padding_mask=pad_mask)
        return output

# Instantiate the model.
model = TransformerDecoderWithEmbeddings(vocab_size, d_model, seq_length, num_layers, nhead, pad_token_id)

In [44]:
x = torch.randint(1, vocab_size, (batch_size, seq_length))
x[:, -2:] = pad_token_id  # Adding padding at the last two positions

In [45]:
output = model(x)
output

tensor([[[ 1.6574e+00,  5.1426e-01,  2.1099e+00,  ..., -7.0132e-01,
          -7.6854e-01, -1.8805e+00],
         [ 8.1144e-01, -1.1377e-01,  1.7264e+00,  ..., -8.3019e-01,
           4.0725e-01, -1.3923e+00],
         [ 3.6627e-01,  2.2241e-01,  2.0570e+00,  ...,  4.6055e-01,
          -6.0812e-01, -8.8407e-01],
         ...,
         [-6.5247e-02,  9.3704e-01,  9.5043e-01,  ...,  7.0408e-01,
          -4.8024e-02, -1.0185e+00],
         [ 3.8166e-01,  9.0517e-01,  3.9536e-01,  ...,  3.0520e-01,
          -9.4121e-01, -8.3781e-01],
         [ 1.1688e+00,  5.3544e-01,  8.1944e-01,  ...,  1.3912e+00,
           1.7287e-01, -8.7220e-01]],

        [[ 1.0444e+00, -2.7527e-01,  4.2235e-01,  ...,  1.3844e-01,
          -6.0575e-01, -2.9296e-01],
         [-2.0117e-01, -3.4369e-01,  6.3022e-01,  ...,  6.4732e-01,
          -5.9118e-02, -2.7027e-01],
         [ 1.9697e-01, -1.0001e-01,  5.3323e-01,  ...,  6.9925e-01,
          -5.7076e-01, -7.7109e-01],
         ...,
         [ 7.7518e-01,  3

In [None]:
optimizer = optim.AdamW(classifier.parameters(), lr=5e-5)
criterion = nn.BCEWithLogitsLoss()
device = "cuda"
classifier.to(device)
classifier.train()
num_epochs = 3
for epoch in range(num_epochs):
    total_loss = 0
    for batch in tqdm(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels_batch = batch['labels'].to(device)
        mask = labels_batch != -100

        optimizer.zero_grad()
        logits = classifier(input_ids, attention_mask)
        logits = logits.view(-1)[mask]
        labels_batch = labels_batch.view(-1)[mask].float()

        loss = criterion(logits, labels_batch)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

100%|██████████| 125/125 [04:46<00:00,  2.29s/it]


Epoch 1/3, Loss: 0.0434


100%|██████████| 125/125 [05:22<00:00,  2.58s/it]


Epoch 2/3, Loss: 0.0036


  2%|▏         | 2/125 [00:05<05:20,  2.61s/it]