In [1]:
# week 3 
# for each language train a classifier 
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import os
from tqdm.asyncio import tqdm as async_tqdm
from tqdm import tqdm
import nest_asyncio
import torch
from torch import nn
from torch.optim import Adam
from torch.nn import BCELoss
from dataclasses import dataclass
import numpy as np
nest_asyncio.apply()

current_dir = os.getcwd()
if current_dir.endswith("code"):
    os.chdir("..")
else:
    print("current dir", current_dir)


In [2]:
ds_train = pd.read_parquet("dataset/train_df.parquet")
ds_val = pd.read_parquet("dataset/val_df.parquet")


In [3]:
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")

In [4]:
from git import Optional
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score, balanced_accuracy_score


class AnswerableClassifier(nn.Module):
    def __init__(
        self,
        d_model: int = 128,
        expansion_factor: int = 2,
        with_context: bool = True,
    ):
        super().__init__()
        
        #simple linear model
        _d_model = d_model * 2 if with_context else d_model
        d_hidden = _d_model * expansion_factor
        self.l_in = nn.Linear(_d_model, d_hidden)
        self.relu = nn.ReLU()
        self.l_out = nn.Linear(d_hidden, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        x = self.l_in(x)
        x = self.relu(x)
        x = self.l_out(x)
        x = self.sigmoid(x)
        return x

@dataclass
class TrainConfig:
    lr: float = 1e-3
    batch_size : int = 32
    n_epochs : int = 10

class Trainer:
    def __init__(
        self,
        config: TrainConfig,
        d_model : int,
        expansion_factor: int = 2,
        with_context: bool = True,
    ):
        self.d_model = d_model
        self.config = config
        self.with_context = with_context
        if with_context:
            self.d_model *= 2
        self.model = AnswerableClassifier(d_model, expansion_factor, with_context)
        self.optimizer = Adam(self.model.parameters(), lr=self.config.lr)
        self.criterion = BCELoss()
        self.device = get_device()

    def get_batch(self, data: pd.DataFrame, with_context : bool = True) -> tuple[torch.Tensor, torch.Tensor]:
        
        # we want to concatenate the question and context embeddings
        # note that we are using matryoshka embeddings from openai 
        # so we can slice the embeddings and still get useful information
        question_embeddings = torch.tensor(np.stack(data['question_embedding']), dtype=torch.float32, device=self.device)
        if with_context:
            context_embeddings = torch.tensor(np.stack(data['context_embedding']), dtype=torch.float32, device=self.device)
            x = torch.cat(
                (question_embeddings[:, :self.d_model], context_embeddings[:, :self.d_model]), 
                dim=1, 
            )
        else:
            x = question_embeddings[:, :self.d_model]
            
        y = torch.tensor(np.stack(data['answer_label']), device=self.device, dtype=torch.float32).unsqueeze(1)
        
        return x, y

    def fit(self, train_df: pd.DataFrame):
        
        total_iterations = self.config.n_epochs * len(train_df) // self.config.batch_size
        pbar = tqdm(total=total_iterations, desc="Training")
        for epoch in range(self.config.n_epochs):
            self.model.train()
            self.model.to(torch.float32)
            self.model.to(self.device)
            
            for batch_idx in range(0, len(train_df), self.config.batch_size):
                data = train_df.iloc[batch_idx:batch_idx+self.config.batch_size]
                x, y = self.get_batch(data, with_context=self.with_context)
                pred = self.model(x)
                loss = self.criterion(pred, y)
                loss.backward()
                #print("grad", self.model.l_in.weight.grad)
                self.optimizer.step()
                pbar.update(1)
                self.model.zero_grad()

        pbar.close()
        
    def evaluate(self, val_df: pd.DataFrame):
        self.model.eval()
        self.model.to(self.device)
        
        preds = []
        true_values = []
        with torch.no_grad():
            for batch_idx in range(0, len(val_df), self.config.batch_size):
                data = val_df.iloc[batch_idx:batch_idx+self.config.batch_size]
                x, y = self.get_batch(data, with_context=self.with_context)
                pred = self.model(x)
                preds.extend(pred.view(-1).cpu().numpy())
                true_values.extend(y.view(-1).cpu().numpy())
          
        preds = np.array(preds).round()
        true_values = np.array(true_values).round()
          
        bce_loss = self.criterion(torch.tensor(preds), torch.tensor(true_values))
        acc = accuracy_score(true_values, preds)
        balanced_acc = balanced_accuracy_score(true_values, preds)
        f1 = f1_score(true_values, preds)
        precision = precision_score(true_values, preds)
        recall = recall_score(true_values, preds)
        conf_matrix = confusion_matrix(true_values, preds)
        normalized_conf_matrix = conf_matrix.astype('float') / conf_matrix.sum(axis=1)[:, np.newaxis]
        
        return {
            'accuracy': acc,
            'bce_loss': bce_loss,
            'balanced_accuracy': balanced_acc,
            'f1': f1,
            'precision': precision,
            'recall': recall,
            'confusion_matrix': conf_matrix,
            'normalized_confusion_matrix': normalized_conf_matrix
        }
        
    
    def save(self):
        os.makedirs("models", exist_ok=True)
        torch.save(self.model.state_dict(), "models/model.pt")


def filter_language(ds: pd.DataFrame, language: str) -> pd.DataFrame:
    return ds[ds['lang'] == language]


train_ru = filter_language(ds_train, 'ru')
val_ru = filter_language(ds_val, 'ru')
train_ja = filter_language(ds_train, 'ja')
val_ja = filter_language(ds_val, 'ja')
train_fi = filter_language(ds_train, 'fi')
val_fi = filter_language(ds_val, 'fi')


In [8]:
#scaling experiments

expansion_factors = [2]


for lang in ['ru', 'ja', 'fi']:
    lst = []
    train_ds = filter_language(ds_train, lang)
    val_ds = filter_language(ds_val, lang)
    for expansion_factor in tqdm(expansion_factors):
        print(f"Training for {lang} with expansion factor {expansion_factor}")
        trainer = Trainer(TrainConfig(), d_model=1536, with_context=True, expansion_factor=expansion_factor)
        trainer.fit(train_ds)
        metrics = trainer.evaluate(val_ds)
        metrics['expansion_factor'] = expansion_factor
        lst.append(metrics)
        print(f"Evaluation for {lang} with expansion factor {expansion_factor}: {metrics}")
""" 
    df = pd.DataFrame(lst)

    fig = px.line(
        df, 
        x='expansion_factor', 
        y='balanced_accuracy', 
        title=f'Accuracy for {lang}', 
        labels={'expansion_factor': 'Expansion Factor', 'normalized_accuracy': 'Accuracy'}
    )
    fig.show() """




  0%|          | 0/1 [00:00<?, ?it/s]

Training for ru with expansion factor 2




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

Training: 620it [00:06, 94.80it/s]                         
100%|██████████| 1/1 [00:06<00:00,  6.63s/it]


Evaluation for ru with expansion factor 2: {'accuracy': 0.9772727272727273, 'bce_loss': tensor(2.2727), 'balanced_accuracy': 0.7954545454545454, 'f1': 0.988110964332893, 'precision': 0.9765013054830287, 'recall': 1.0, 'confusion_matrix': array([[ 13,   9],
       [  0, 374]]), 'normalized_confusion_matrix': array([[0.59090909, 0.40909091],
       [0.        , 1.        ]]), 'expansion_factor': 2}


  0%|          | 0/1 [00:00<?, ?it/s]

Training for ja with expansion factor 2




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

Training: 720it [00:07, 94.05it/s]                         
100%|██████████| 1/1 [00:07<00:00,  7.86s/it]


Evaluation for ja with expansion factor 2: {'accuracy': 0.8289473684210527, 'bce_loss': tensor(17.1053), 'balanced_accuracy': 0.524390243902439, 'f1': 0.9055690072639225, 'precision': 0.827433628318584, 'recall': 1.0, 'confusion_matrix': array([[  4,  78],
       [  0, 374]]), 'normalized_confusion_matrix': array([[0.04878049, 0.95121951],
       [0.        , 1.        ]]), 'expansion_factor': 2}


  0%|          | 0/1 [00:00<?, ?it/s]

Training for fi with expansion factor 2




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A
[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

Training: 670it [00:07, 93.40it/s]
100%|██████████| 1/1 [00:07<00:00,  7.28s/it]

Evaluation for fi with expansion factor 2: {'accuracy': 0.9242424242424242, 'bce_loss': tensor(7.5758), 'balanced_accuracy': 0.713680387409201, 'f1': 0.9585921325051759, 'precision': 0.937246963562753, 'recall': 0.9809322033898306, 'confusion_matrix': array([[ 25,  31],
       [  9, 463]]), 'normalized_confusion_matrix': array([[0.44642857, 0.55357143],
       [0.0190678 , 0.9809322 ]]), 'expansion_factor': 2}





" \n    df = pd.DataFrame(lst)\n\n    fig = px.line(\n        df, \n        x='expansion_factor', \n        y='balanced_accuracy', \n        title=f'Accuracy for {lang}', \n        labels={'expansion_factor': 'Expansion Factor', 'normalized_accuracy': 'Accuracy'}\n    )\n    fig.show() "