In [None]:
import pandas as pd

# Read the raw training table from disk and preview its first rows.
train_main_df = pd.read_parquet(
    "data/train.parquet",
    engine="pyarrow",
)
train_main_df.head()

In [None]:
# Restrict the frame to the three text fields and the preference label.
required_columns = [
    "prompt",
    "response_a",
    "response_b",
    "winner",
]
train_df = train_main_df[required_columns]
train_df.head()

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 25% of the rows for validation; the fixed seed keeps the split repeatable.
train_frame, validation_frame = train_test_split(
    train_df,
    test_size=0.25,
    random_state=2024,
)

In [None]:
# Flatten the training texts into one list, ordered row by row as
# prompt, response_a, response_b. Iterating plain tuples avoids building a
# pandas Series for every `.iloc[idx]` lookup (three per row previously).
corpus = [
    text
    for row in train_frame[["prompt", "response_a", "response_b"]].itertuples(
        index=False, name=None
    )
    for text in row
]

In [None]:
from tqdm.auto import tqdm

vocabulary = set()
token_lens = list()

# Tokenise the same way ResponseDataset.__encode does: newlines become spaces,
# then split on single spaces. Runs of spaces therefore yield empty strings.
for sentence in tqdm(corpus):
    tokens = sentence.replace("\n", " ").split(" ")
    # Length is recorded before filtering, i.e. it counts the empty strings
    # too, matching the token list the encoder truncates to max_len.
    token_lens.append(len(tokens))
    # Keep only non-empty tokens. The previous test
    # (`token != " " or token != ""`) was always true and let "" through.
    vocabulary.update(token for token in tokens if token != "")

In [None]:
# Enumerate the vocabulary in sorted order so every token gets the same id on
# every kernel run; iterating the raw set would depend on string hashing,
# which varies between interpreter sessions.
word_to_idx = {
    word: idx for idx, word in enumerate(sorted(vocabulary))
}
len(word_to_idx)

In [None]:
import pickle

# Persist the token -> id mapping so it can be reloaded without rebuilding it.
with open("vocabulary.pkl", "wb") as vocab_file:
    pickle.dump(word_to_idx, vocab_file)

In [None]:
import numpy as np

# Longest text in the corpus, measured in tokens.
np.asarray(token_lens).max()

In [None]:
from torch.utils.data import Dataset
import torch


class ResponseDataset(Dataset):
    """Dataset that turns each row into a (prompt, positive, negative) triplet.

    Every text is split on spaces (newlines are treated as spaces), mapped to
    integer ids through ``word_to_idx`` and truncated / right-padded to exactly
    ``max_len`` ids.

    Args:
        df: frame with "prompt", "response_a", "response_b" and "winner"
            columns; rows are addressed positionally.
        word_to_idx: token -> id mapping. It is copied, so the caller's dict
            is never modified.
        max_len: number of ids in every encoded tensor.
        pad_token: token whose id fills the unused tail of a tensor.
        oov_token: token whose id replaces tokens missing from the mapping.
    """

    def __init__(self, df, word_to_idx=word_to_idx, max_len=2048, pad_token="[PAD]", oov_token="[OOV]"):
        self.df = df
        # Private copy: the mapping is shared between datasets (it is also the
        # default argument), and writing the special tokens into the shared
        # dict made a second construction reassign both of them to the same id.
        self.word_to_idx = dict(word_to_idx)
        self.max_len = max_len
        self.pad_token = pad_token
        self.oov_token = oov_token

        # Give each special token the next free id, unless the mapping already
        # carries one for it; this keeps construction idempotent.
        for special_token in (pad_token, oov_token):
            if special_token not in self.word_to_idx:
                next_id = max(self.word_to_idx.values(), default=-1) + 1
                self.word_to_idx[special_token] = next_id

        # Maps the "winner" column value to the index of the preferred response.
        self.label_dict = {
            "model_a": 0,
            "model_b": 1
        }

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.df)

    def __encode(self, text):
        """Encode ``text`` as a 1-D long tensor of exactly ``max_len`` ids."""
        pad_idx = self.word_to_idx[self.pad_token]
        oov_idx = self.word_to_idx[self.oov_token]

        # Same tokenisation as the vocabulary pass; anything past max_len is dropped.
        tokens = text.replace("\n", " ").split(" ")[:self.max_len]
        token_ids = [self.word_to_idx.get(token, oov_idx) for token in tokens]

        # Start from an all-padding tensor and overwrite the leading positions
        # in one assignment instead of one tensor write per token.
        encoded = torch.full((self.max_len,), pad_idx, dtype=torch.long)
        encoded[:len(token_ids)] = torch.tensor(token_ids, dtype=torch.long)
        return encoded

    def __getitem__(self, idx):
        """Return the encoded triplet for the row at position ``idx``.

        Returns:
            dict with keys "prompt", "positive" and "negative", each a long
            tensor of length ``max_len``. "positive" is the winning response.
        """
        # Fetch the row once rather than once per column.
        row = self.df.iloc[idx]
        prompt = self.__encode(row["prompt"])
        response_a = self.__encode(row["response_a"])
        response_b = self.__encode(row["response_b"])

        label = self.label_dict.get(row["winner"])
        # NOTE(review): a "winner" value outside label_dict yields None here and
        # is then treated like "model_b" -- confirm no such rows exist.
        a_is_preferred = label == 0

        return {
            "prompt": prompt,
            "positive": response_a if a_is_preferred else response_b,
            "negative": response_b if a_is_preferred else response_a,
        }
        
        
# Smoke test: build a dataset over the training split and inspect one sample.
ds = ResponseDataset(df=train_frame)
ds[0]

In [None]:
# One dataset per split; both rely on the default vocabulary mapping.
trainset = ResponseDataset(df=train_frame)
valset = ResponseDataset(df=validation_frame)

In [None]:
from torch.utils.data import DataLoader

# Batch size shared by both loaders.
bs = 128

# Shuffle the training batches only; validation keeps the frame order.
train_loader = DataLoader(
    dataset=trainset,
    batch_size=bs,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=valset,
    batch_size=bs,
    shuffle=False,
)

In [None]:
# Sanity check: pull a single batch from the training loader and print it.
for batch in train_loader:
    print(batch)
    # One batch is enough for inspection; stop before iterating the whole split.
    break

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import lightning.pytorch as L
import torch.optim as optim

class Classifier(L.LightningModule):
    """Shared embedding + MLP encoder trained with a triplet margin loss.

    The prompt is the anchor, the preferred response the positive and the
    other response the negative. All three pass through the same layers.

    Args:
        input_dim: stored and saved as a hyperparameter; no layer in this
            class reads it.
        hidden_dim: width of the first linear layer; the encoder output has
            ``hidden_dim // 2`` features.
        embedding_dim: size of each token embedding vector.
        embedding_size: number of rows in the embedding table.
        dropout: dropout probability applied after the first linear layer.
        lr: learning rate handed to Adam.
        batch_size: stored and saved as a hyperparameter. Logging uses the
            size of the batch actually received rather than this value.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        embedding_dim,
        embedding_size,
        dropout,
        lr,
        batch_size
    ):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.embedding_size = embedding_size
        self.dropout = dropout
        self.lr = lr
        self.batch_size = batch_size

        self.save_hyperparameters()

        # modules
        self.embedding = nn.Embedding(
            self.embedding_size,
            self.embedding_dim
        )

        self.mlp = nn.Sequential(
            nn.Linear(self.embedding_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
            nn.Tanh()
        )

        # criterion
        self.criterion = nn.TripletMarginLoss()

    def forward(self, prompt, positive, negative):
        """Encode the three id tensors with the shared embedding and MLP.

        Returns:
            tuple ``(prompt, positive, negative)`` of encoded tensors.
        """
        # NOTE(review): nothing pools over the token axis, so (assuming the
        # inputs are (batch, seq) id tensors) the loss compares the texts
        # position by position, padding included -- confirm this is intended.
        prompt = self.mlp(self.embedding(prompt))
        positive = self.mlp(self.embedding(positive))
        negative = self.mlp(self.embedding(negative))

        return prompt, positive, negative

    def compute_loss(self, batch):
        """Return the triplet margin loss for a batch dict.

        ``batch`` is unpacked into ``forward``, so it must hold exactly the
        keys "prompt", "positive" and "negative".
        """
        prompt, positive, negative = self(**batch)
        loss = self.criterion(prompt, positive, negative)

        return loss

    # TODO: later
    def compute_metrics(self, batch):
        """Placeholder; no metrics are computed yet."""
        pass

    def configure_optimizers(self):
        """Optimise every parameter with Adam at the configured learning rate."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self.compute_loss(batch)
        # Report the real number of samples: the last batch of an epoch can be
        # smaller than the configured batch size, which would skew aggregation.
        self.log("loss/train", loss, prog_bar=True,
                 batch_size=batch["prompt"].size(0))

        # NOTE(review): the nested "log" entry duplicates the self.log call
        # above; kept so the returned structure is unchanged.
        return {
            "loss": loss,
            "log": {
                "loss/train": loss
            }
        }

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self.compute_loss(batch)
        # Same reasoning as in training_step: weight by the actual batch size.
        self.log("loss/validation", loss, prog_bar=True,
                 batch_size=batch["prompt"].size(0))
        return {
            "val_loss": loss,
            "log": {
                "loss/validation": loss
            }
        }

    


# model = Classifier(2048, 256, 512, len(word_to_idx) + 2, 0.1, 1e-3, bs)
# with torch.no_grad():
#     for batch in train_loader:
#         out = model(**batch)
#         p, a, b = out
        
#         loss = F.triplet_margin_loss(p, a, b)
#         print(loss)
        
#         break

In [None]:
# Allow reduced-precision float32 matmuls where the backend supports them.
torch.set_float32_matmul_precision('high')

# The embedding table is sized from the vocabulary plus two extra rows.
model = Classifier(
    input_dim=2048,
    hidden_dim=256,
    embedding_dim=512,
    embedding_size=len(word_to_idx) + 2,
    dropout=0.1,
    lr=1e-3,
    batch_size=bs,
)

# Two epochs on a single GPU, logging every 50 steps.
trainer = L.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=2,
    log_every_n_steps=50,
)

trainer.fit(model, train_loader, val_loader)