In [None]:
import random
from collections import Counter, defaultdict, namedtuple
from typing import Tuple, List, Dict, Any

import torch
import numpy as np

from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModel, PreTrainedTokenizer, AutoModelForSequenceClassification
from bertviz import head_view, model_view
from transformers.tokenization_utils_base import BatchEncoding
from transformers.trainer_callback import dataclass
from sklearn.metrics import f1_score, accuracy_score
import pandas as pd
from torch.utils.data import random_split

In [None]:
def set_global_seed(seed: int) -> None:
    """Seed every random number generator this notebook touches.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator plus all CUDA devices), then disables cuDNN
    benchmarking and requests deterministic cuDNN behaviour.

    Args:
        seed: Value handed unchanged to each seeding call.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN flags: favour repeatable results over autotuned kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Fix the seed once, before any splitting, shuffling or weight initialisation.
set_global_seed(42)

In [None]:
# Checkpoint name handed to the tokenizer/model ``from_pretrained`` calls.
model_name = "DeepPavlov/rubert-base-cased"

# Prefer Apple's MPS backend (the value this notebook originally hard-coded),
# then CUDA, and finally fall back to CPU so the notebook still runs on
# machines without an accelerator instead of failing on ``.to("mps")``.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def get_label2idx(label_set: List[str]) -> Dict[str, int]:
    """Map each label to its position in ``label_set``.

    Args:
        label_set: Labels in the order their indices should be assigned.

    Returns:
        Dictionary from label to zero-based index. When a label occurs more
        than once, the index of its last occurrence is kept.
    """
    return {name: position for position, name in enumerate(label_set)}


# Labels are keyed by their string form.
label2idx = get_label2idx(['-1', '0', '1'])

In [None]:
def parse_dataset(
    path: str = "sample_data/train_data.csv",
    sep: str = "\t",
) -> Tuple[List[str], List[Any]]:
    """Load sentences and their labels from a delimited text file.

    Args:
        path: File to read. Defaults to the training file this notebook
            reads when called without arguments.
        sep: Field separator passed to ``pandas.read_csv`` (tab by default).

    Returns:
        ``(data, labels)``: the ``sentence`` column and the ``label`` column
        as two plain lists of equal length, in file order.

    Raises:
        KeyError: If the file has no ``sentence`` or no ``label`` column.
    """
    df = pd.read_csv(path, sep=sep)

    # Pull whole columns instead of looping with ``iterrows``: the lists come
    # out in file order without materialising a Series for every row.
    data = df["sentence"].tolist()
    labels = df["label"].tolist()

    return data, labels

class TransformersDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each sentence with its integer class index.

    Raw labels are converted to indices once, at construction time, using
    the module-level ``label2idx`` mapping.
    """

    def __init__(
        self,
        data: List[str],
        labels: List[Any],
    ):
        """Store the sentences and pre-compute their label indices.

        Args:
            data: Raw sentences, one per example.
            labels: Raw labels aligned with ``data``; each one is looked up
                in ``label2idx`` by its ``str()`` form.

        Raises:
            ValueError: If ``data`` and ``labels`` differ in length.
            KeyError: If a label is missing from ``label2idx``.
        """
        self.data = data
        self.labels = self.process_labels(labels, label2idx)

        # A mismatch would otherwise only surface later as an IndexError in
        # ``__getitem__`` (or as silently ignored trailing labels).
        if len(self.labels) != len(self.data):
            raise ValueError(
                "data and labels must have the same length, "
                f"got {len(self.data)} and {len(self.labels)}"
            )

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(
        self,
        idx: int,
    ) -> Tuple[str, int]:
        """Return the ``(sentence, label_index)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]

    @staticmethod
    def process_labels(
        labels: List[Any],
        label2idx: Dict[str, int],
    ) -> List[int]:
        """Translate raw labels into class indices.

        Args:
            labels: Raw labels; each is converted with ``str()`` before the
                lookup, so integer labels match string keys.
            label2idx: Mapping from label string to class index.

        Returns:
            Class indices in the same order as ``labels``.

        Raises:
            KeyError: If a label has no entry in ``label2idx``; the message
                names the offending label and the known labels.
        """
        l_indices = []
        for label in labels:
            try:
                l_indices.append(label2idx[str(label)])
            except KeyError:
                raise KeyError(
                    f"Unknown label {label!r}; known labels: {list(label2idx)}"
                ) from None

        return l_indices

In [None]:
class TransformersCollator:
    """Turn a list of ``(text, label_index)`` pairs into model-ready inputs.

    The texts are tokenized together in one call and the label indices are
    packed into a single tensor.
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizer",
        tokenizer_kwargs: Dict[str, Any],
    ):
        """Remember the tokenizer and the keyword arguments to call it with.

        Args:
            tokenizer: Callable tokenizer applied to the list of texts.
            tokenizer_kwargs: Extra keyword arguments forwarded to every
                tokenizer call.
        """
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs

    def __call__(
        self,
        batch: List[Tuple[str, int]],
    ) -> "Tuple[BatchEncoding, torch.Tensor]":
        """Collate one batch.

        Args:
            batch: ``(text, label_index)`` pairs as produced by the dataset.

        Returns:
            ``(encoded, labels)``: the tokenizer output without its
            ``offset_mapping`` entry, and a tensor built from the label
            indices.
        """
        texts, labels = zip(*batch)

        encoded = self.tokenizer(list(texts), **self.tokenizer_kwargs)

        # The result is unpacked straight into the model call, so the offsets
        # must not be in it. The default keeps this from raising KeyError
        # when the tokenizer was not asked to return offsets.
        encoded.pop("offset_mapping", None)

        return encoded, torch.tensor(labels)

In [None]:
# Read the raw sentences and labels, then wrap them so each item is a
# (sentence, label_index) pair.
data, labels = parse_dataset()
dataset = TransformersDataset(data=data, labels=labels)

In [None]:
# Hold out a fixed number of examples for validation and train on the rest.
# The train size is derived from the dataset so the split stays valid when
# the input file grows or shrinks (random_split requires the lengths to sum
# to len(dataset)).
VAL_SIZE = 1000
if len(dataset) <= VAL_SIZE:
    raise ValueError(
        f"dataset has {len(dataset)} examples; need more than {VAL_SIZE} "
        "to leave a non-empty training split"
    )
train_data, val_data = random_split(dataset, [len(dataset) - VAL_SIZE, VAL_SIZE])

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Keyword arguments forwarded to every tokenizer call made by the collator.
# NOTE(review): add_special_tokens=False means no special tokens are added
# around each sentence — confirm this is intended for a sequence classifier.
tokenizer_kwargs = dict(
    return_offsets_mapping=True,  # the collator removes "offset_mapping" again
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
    add_special_tokens=False,
)

collator = TransformersCollator(tokenizer=tokenizer, tokenizer_kwargs=tokenizer_kwargs)

In [None]:
# Training batches are shuffled; validation is served in order, one example
# per batch. Both loaders tokenize through the same collator.
train_dataloader = torch.utils.data.DataLoader(
    train_data, batch_size=12, shuffle=True, collate_fn=collator
)
valid_dataloader = torch.utils.data.DataLoader(
    val_data, batch_size=1, shuffle=False, collate_fn=collator
)

# Pull a single training batch to check that collation works end to end.
# Note that this rebinds `labels`, which previously held the raw label list.
tokens, labels = next(iter(train_dataloader))

In [None]:
# Load the pretrained checkpoint with a 3-class classification head and ask
# for attention weights to be returned by every forward pass.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=3, output_attentions=True
)
# Move the weights to the configured device.
model = model.to(device)

In [None]:
# Cross-entropy between the model's logits and the integer class targets.
criterion = torch.nn.CrossEntropyLoss()
# Adam over every model parameter.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

In [None]:
def compute_metrics(y_true, y_pred):
    """Score predictions with macro-averaged F1 and accuracy.

    Args:
        y_true: Tensor of gold class indices.
        y_pred: Tensor of per-class scores; the predicted class is the
            argmax along axis 1.

    Returns:
        Tuple ``(f1_macro, accuracy)``.
    """
    # Both metrics consume the same flattened CPU tensors, so build them once.
    gold = y_true.flatten().cpu()
    predicted = torch.argmax(y_pred, axis=1).flatten().cpu()

    accuracy = accuracy_score(y_true=gold, y_pred=predicted)
    f1_macro = f1_score(
        y_true=gold,
        y_pred=predicted,
        average="macro",
        zero_division=0,
    )
    return f1_macro, accuracy

def train_epoch(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
    epoch: int,
) -> None:
    """Run one optimisation pass over ``dataloader`` and print epoch metrics.

    For every batch the model is updated once and then re-scored in eval
    mode; those eval-mode logits feed the reported F1 and accuracy.

    Args:
        model: Model to train; called as ``model(**tokens)`` and expected to
            return a mapping with a ``"logits"`` entry.
        dataloader: Yields ``(tokens, labels)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device the batches are moved to.
        epoch: Index of the current epoch (currently unused).
    """
    model.train()

    epoch_loss = []
    epoch_logits = []
    epoch_labels = []

    for tokens, labels in tqdm(
        dataloader,
        total=len(dataloader),
        desc="loop over train batches",
    ):
        tokens, labels = tokens.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(**tokens)
        loss = criterion(outputs["logits"], labels)
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())

        # Re-score the batch after the update, in eval mode and without
        # gradients, then switch back to training mode.
        model.eval()
        with torch.no_grad():
            outputs_inference = model(**tokens)["logits"]
        model.train()

        # Keep predictions on the CPU so they do not pile up on the device.
        epoch_logits.append(outputs_inference.cpu())
        epoch_labels.append(labels.cpu())

    if not epoch_loss:
        # torch.cat cannot concatenate an empty list; report and stop.
        print("Train dataloader produced no batches.\n")
        return

    avg_loss = np.mean(epoch_loss)
    print(f"Train loss: {avg_loss}\n")

    # Metrics are computed once over the whole epoch: macro F1 does not
    # decompose over batches, so averaging per-batch scores misreports it.
    f1_metric, ac_metric = compute_metrics(
        y_true=torch.cat(epoch_labels),
        y_pred=torch.cat(epoch_logits),
    )
    print(f"Train F1-macro: {f1_metric}\n")
    print(f"Train Accuracy: {ac_metric}\n")

def evaluate_epoch(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    device: torch.device,
    epoch: int,
) -> None:
    """Evaluate ``model`` on ``dataloader`` and print loss, F1 and accuracy.

    Args:
        model: Model to evaluate; called as ``model(**tokens)`` and expected
            to return a mapping with a ``"logits"`` entry.
        dataloader: Yields ``(tokens, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        device: Device the batches are moved to.
        epoch: Index of the current epoch (currently unused).
    """
    model.eval()

    epoch_loss = []
    epoch_logits = []
    epoch_labels = []

    with torch.no_grad():

        for tokens, labels in tqdm(
            dataloader,
            total=len(dataloader),
            desc="loop over test batches",
        ):
            tokens, labels = tokens.to(device), labels.to(device)

            outputs = model(**tokens)["logits"]
            loss = criterion(outputs, labels)

            epoch_loss.append(loss.item())

            # Keep predictions on the CPU so they do not pile up on the device.
            epoch_logits.append(outputs.cpu())
            epoch_labels.append(labels.cpu())

    if not epoch_loss:
        # torch.cat cannot concatenate an empty list; report and stop.
        print("Test dataloader produced no batches.\n")
        return

    avg_loss = np.mean(epoch_loss)
    print(f"Test loss:  {avg_loss}\n")

    # Score all predictions at once. Averaging per-batch macro F1 is wrong in
    # general, and with one example per batch each batch score is just 0 or 1.
    f1_metric, ac_metric = compute_metrics(
        y_true=torch.cat(epoch_labels),
        y_pred=torch.cat(epoch_logits),
    )
    print(f"Test F1-macro: {f1_metric}\n")
    print(f"Test Accuracy: {ac_metric}\n")

def train_with_val(
    n_epochs: int,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    test_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
) -> None:
    """Train for ``n_epochs`` epochs, evaluating after each one.

    Args:
        n_epochs: Number of train/evaluate rounds.
        model: Model handed to both the training and the evaluation pass.
        train_dataloader: Batches used by ``train_epoch``.
        test_dataloader: Batches used by ``evaluate_epoch``.
        optimizer: Optimizer used during the training pass.
        criterion: Loss used by both passes.
        device: Device used by both passes.
    """
    # Arguments that the training and evaluation passes have in common.
    shared = dict(model=model, criterion=criterion, device=device)

    for epoch in range(n_epochs):
        print(f"Epoch [{epoch+1} / {n_epochs}]\n")

        train_epoch(
            dataloader=train_dataloader,
            optimizer=optimizer,
            epoch=epoch,
            **shared,
        )
        evaluate_epoch(dataloader=test_dataloader, epoch=epoch, **shared)

def train(
    n_epochs: int,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    device: torch.device,
) -> None:
    """Train for ``n_epochs`` epochs without any evaluation pass.

    Args:
        n_epochs: Number of passes over ``train_dataloader``.
        model: Model to train.
        train_dataloader: Batches used by ``train_epoch``.
        optimizer: Optimizer used during training.
        criterion: Loss used during training.
        device: Device the batches are moved to.
    """
    # Everything except the epoch index is the same on every pass.
    fixed_args = dict(
        model=model,
        dataloader=train_dataloader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
    )

    for epoch in range(n_epochs):
        print(f"Epoch [{epoch+1} / {n_epochs}]\n")
        train_epoch(epoch=epoch, **fixed_args)

In [None]:
# Fine-tune for five epochs, validating on the held-out split after each one.
train_with_val(
    n_epochs=5,
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=valid_dataloader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
)