## Library installation

In [None]:
!pip install -q "flwr[simulation]" flwr-datasets

In [None]:
!pip install transformers

In [None]:
!pip install torch==1.13.1+cpu torchvision==0.14.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu

In [None]:
!pip install matplotlib

In [None]:
!pip install scikit-learn

## Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random

import torch
import torch.nn as nn

import tqdm

from datasets import Dataset, DatasetDict

from sklearn.metrics import accuracy_score, precision_score, recall_score

from torch.utils.data import DataLoader, Subset

from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer

## Data preprocessing

In [None]:
# Sentiment name -> class index used as the model's target. "mixed" has no entry
# because preprocess_data removes those rows.
label_mapping = dict(
    negative=0,
    neutral=1,
    positive=2,
)

In [None]:
def preprocess_data(df: pd.DataFrame, mapping: "dict[str, int] | None" = None) -> pd.DataFrame:
    """Reduce a raw DynaSent frame to ``text`` / ``label`` columns with integer labels.

    Dropped rows:
    - rows whose gold label is ``"mixed"``;
    - rows with a missing sentence or gold label;
    - rows whose label has no entry in ``mapping``.

    Args:
        df: Raw split with at least ``sentence`` and ``gold_label`` columns. It is not modified.
        mapping: Label name -> class index. Defaults to the notebook-level ``label_mapping``.

    Returns:
        A new DataFrame with columns ``text`` and ``label`` (int64). The original row index is kept.
    """
    if mapping is None:
        mapping = label_mapping

    renamed = df[["sentence", "gold_label"]].rename(columns={"sentence": "text", "gold_label": "label"})
    kept = renamed[renamed["label"] != "mixed"].dropna()

    # assign() builds a new frame instead of writing into a filtered slice, which avoids
    # pandas' SettingWithCopyWarning. Labels absent from the mapping map to NaN and are
    # dropped here, so the label column never silently turns into floats.
    mapped = kept.assign(label=kept["label"].map(mapping)).dropna(subset=["label"])

    return mapped.astype({"label": "int64"})

In [None]:
def tokenize_data(ds: Dataset, tokenizer: PreTrainedTokenizer) -> Dataset:
    """Tokenize the ``"text"`` column and add ``"ids"`` and ``"attention_mask"`` columns.

    Examples are encoded one at a time (``Dataset.map`` without ``batched=True``). Texts
    longer than the tokenizer's maximum length are truncated. Sequences are left unpadded;
    ``get_data_loader``'s collate function pads each batch to its longest sequence.

    Args:
        ds: Dataset with a ``"text"`` column.
        tokenizer: Tokenizer used to encode each text.

    Returns:
        A new ``Dataset`` with the two added columns (``ds`` itself is not modified).
    """
    def encode(example, tok):
        # No padding=True here: in a non-batched map it would pad a single sequence to
        # its own length, i.e. do nothing.
        encoded = tok(example["text"], truncation=True)
        return {
            "ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
        }

    return ds.map(encode, fn_kwargs={"tok": tokenizer})

In [None]:
training_data = preprocess_data(
    pd.read_json("data/dynasent-v1.1-round01-yelp-train.jsonl", lines=True)
)

# Validation pool: the test and dev splits, preprocessed, merged and de-duplicated.
validation_frames = [
    preprocess_data(pd.read_json(path, lines=True))
    for path in (
        "data/dynasent-v1.1-round01-yelp-test.jsonl",
        "data/dynasent-v1.1-round01-yelp-dev.jsonl",
    )
]
validation_data = pd.concat(validation_frames, ignore_index=True).drop_duplicates()

In [None]:
#joint_data = pd.concat([training_data, validation_data], ignore_index=True)

In [None]:
#training_data, validation_data = train_test_split(joint_data, test_size=0.15)

In [None]:
# Wrap both DataFrames as Hugging Face datasets without carrying over the pandas index.
training_ds, validation_ds = (
    Dataset.from_pandas(frame, preserve_index=False)
    for frame in (training_data, validation_data)
)

In [None]:
# Tokenizer for the same "distilbert-base-uncased" checkpoint the model is built from.
distilbert_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="distilbert-base-uncased",
)

In [None]:
# Tokenize both splits with the DistilBERT tokenizer.
training_ds, validation_ds = (
    tokenize_data(split, distilbert_tokenizer)
    for split in (training_ds, validation_ds)
)

In [None]:
# Return the model inputs and the target as torch tensors when rows are indexed.
torch_columns = ["ids", "label", "attention_mask"]
training_ds = training_ds.with_format(type="torch", columns=torch_columns)
validation_ds = validation_ds.with_format(type="torch", columns=torch_columns)

In [None]:
# Show a summary of both splits as the cell's output.
DatasetDict(
    {
        "train": training_ds,
        "test": validation_ds,
    }
)

In [None]:
def get_data_loader(dataset: "Dataset | Subset", batch_size: int, pad_index, shuffle=False) -> DataLoader:
    """Wrap a tokenized dataset in a DataLoader that pads each batch to its longest sequence.

    Args:
        dataset: Map-style dataset whose items are dicts with ``"ids"``, ``"label"`` and
            ``"attention_mask"`` tensors. This can be a Hugging Face ``Dataset`` in torch
            format or a ``torch.utils.data.Subset`` of one.
        batch_size: Number of examples per batch.
        pad_index: Token id used to pad ``"ids"`` (the tokenizer's ``pad_token_id``).
        shuffle: Whether to reshuffle the examples every epoch.

    Returns:
        A DataLoader that yields dicts with the same three keys, stacked along dim 0.
    """
    def collate_fn(batch):
        batch_ids = nn.utils.rnn.pad_sequence(
            [item["ids"] for item in batch], padding_value=pad_index, batch_first=True
        )
        batch_label = torch.stack([item["label"] for item in batch])
        # Padded mask positions must be 0 ("do not attend"), independent of the token pad
        # id. Padding with pad_index only works by accident when pad_token_id == 0.
        batch_mask = nn.utils.rnn.pad_sequence(
            [item["attention_mask"] for item in batch], padding_value=0, batch_first=True
        )

        return {
            "ids": batch_ids,
            "label": batch_label,
            "attention_mask": batch_mask,
        }

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle,
    )

In [None]:
# Batches of 32; only the training loader reshuffles each epoch.
pad_id = distilbert_tokenizer.pad_token_id
training_dl = get_data_loader(training_ds, batch_size=32, pad_index=pad_id, shuffle=True)
validation_dl = get_data_loader(validation_ds, batch_size=32, pad_index=pad_id)

## Model definition

In [None]:
class Transformer(nn.Module):
    """Sequence classifier: a pretrained encoder, masked mean pooling, dropout and a linear head.

    Args:
        transformer: Pretrained encoder, e.g. from ``AutoModel.from_pretrained``. It must
            expose ``config.hidden_size`` and return an output with ``last_hidden_state`` and
            ``attentions`` when called with ``output_attentions=True``.
        num_classes: Number of output logits.
        freeze: If True, the encoder's parameters get ``requires_grad = False``, so only the
            dropout/linear head is trained.
    """

    def __init__(self, transformer, num_classes: int, freeze: bool):
        super().__init__()

        self.transformer = transformer
        self.fc = nn.Linear(transformer.config.hidden_size, num_classes)
        self.dropout = nn.Dropout(0.3)

        if freeze:
            for param in self.transformer.parameters():
                param.requires_grad = False

    def forward(
            self,
            ids: torch.Tensor,
            attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Return ``(logits, attentions)`` for a padded batch.

        Args:
            ids: Token ids, presumably shaped (batch, seq_len).
            attention_mask: 1 for real tokens and 0 for padding, same shape as ``ids``.
                ``None`` falls back to an unmasked mean.

        Returns:
            The logits (last dimension ``num_classes``) and the encoder's ``attentions`` output.
        """
        output = self.transformer(ids, attention_mask=attention_mask, output_attentions=True)

        pooled = self._masked_mean(output.last_hidden_state, attention_mask)
        prediction = self.fc(self.dropout(pooled))

        return prediction, output.attentions

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, attention_mask) -> torch.Tensor:
        """Average ``hidden`` over dim 1, counting only positions where the mask is non-zero."""
        if attention_mask is None:
            return hidden.mean(dim=1)

        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # Padding contributes nothing to the sum and is excluded from the count, so an
        # example's embedding no longer depends on how much its batch was padded.
        # NOTE: weights trained with the previous unmasked mean may score slightly differently.
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

In [None]:
# Pretrained DistilBERT encoder; freeze=False keeps its weights trainable alongside the head.
distilbert_tf = AutoModel.from_pretrained(pretrained_model_name_or_path="distilbert-base-uncased")
model = Transformer(transformer=distilbert_tf, num_classes=3, freeze=False)

## Loading from backup
### Checkpoint loading

In [None]:
# 6th epoch (checkpoint files are numbered by 0-based epoch index).
# map_location="cpu" lets a checkpoint written on a GPU machine load on a CPU-only host.
# The model is moved to `device` later, and Optimizer.load_state_dict casts its state to
# the parameters' device.
# weights_only=False because the file also stores metric history, not just tensors.
# This unpickles arbitrary objects: only load checkpoints you produced yourself.
checkpoint = torch.load("model/checkpoint5.pth", map_location="cpu", weights_only=False)

In [None]:
# Copy the weights saved in the checkpoint dictionary into the model.
model.load_state_dict(state_dict=checkpoint["model_state_dict"])

### Trained model loading

In [None]:
# Load the final weights. map_location="cpu" lets a file saved from a GPU run load on a
# CPU-only machine; the model is moved to `device` in a later cell.
model.load_state_dict(torch.load("model/trained_model.pth", map_location="cpu", weights_only=True))

## Model configuration

In [None]:
# Count each parameter tensor once. state_dict() would also include registered buffers
# and would list tied/shared weights under every name that refers to them.
num_parameters = sum(param.numel() for param in model.parameters())
num_trainable_parameters = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"{num_parameters = }")
print(f"{num_trainable_parameters = }")

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cross-entropy over the classifier's logits; Module.to() returns the module itself.
criterion = torch.nn.CrossEntropyLoss().to(device)
model = model.to(device)

print(model)

## Training and evaluation

In [None]:
def _to_class_arrays(prediction, label) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(predicted_classes, actual_labels)`` as NumPy arrays from logits and gold labels."""
    predicted_classes = prediction.argmax(dim=-1).cpu().numpy()
    actual_labels = label.cpu().numpy()
    return predicted_classes, actual_labels


def get_accuracy(prediction, label) -> np.float64:
    """Fraction of rows whose argmax over the last dimension equals ``label``."""
    predicted_classes, actual_labels = _to_class_arrays(prediction, label)
    return accuracy_score(actual_labels, predicted_classes)


def get_precision(prediction, label) -> np.float64:
    """Macro-averaged precision of the argmax predictions (0 for classes never predicted)."""
    predicted_classes, actual_labels = _to_class_arrays(prediction, label)
    return precision_score(actual_labels, predicted_classes, average="macro", zero_division=0)


def get_recall(prediction, label) -> np.float64:
    """Macro-averaged recall of the argmax predictions (0 for classes absent from ``label``)."""
    predicted_classes, actual_labels = _to_class_arrays(prediction, label)
    return recall_score(actual_labels, predicted_classes, average="macro", zero_division=0)


def get_f1_score(precision: np.float64, recall: np.float64) -> np.float64:
    """Harmonic mean of ``precision`` and ``recall``.

    Returns 0.0 when both are 0, matching the ``zero_division=0`` convention used for
    precision and recall above, instead of producing NaN and a RuntimeWarning.
    """
    denominator = precision + recall
    if denominator == 0:
        return np.float64(0.0)
    return np.float64(2.0) * (precision * recall) / denominator

In [None]:

# Metric names, in the order train()/test() return them.
_METRIC_NAMES = ("loss", "accuracy", "precision", "recall", "f1_score")


def _run_epoch(
        net: Transformer,
        data_loader: DataLoader,
        desc: str,
        optimizer: "torch.optim.Optimizer | None" = None,
) -> tuple[np.float64, np.float64, np.float64, np.float64, np.float64]:
    """Make one pass over ``data_loader`` and return batch-averaged metrics.

    If ``optimizer`` is given, each batch also takes a gradient step. The caller sets
    ``net.train()`` / ``net.eval()`` and any ``torch.no_grad()`` context.

    Returns:
        ``(loss, accuracy, precision, recall, f1_score)``. The first four are means over
        batches; ``f1_score`` is computed from the averaged precision and recall.
    """
    batch_losses = []
    batch_accuracies = []
    batch_precisions = []
    batch_recalls = []

    for batch in tqdm.tqdm(data_loader, desc=desc):
        ids = batch["ids"].to(device)
        label = batch["label"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        prediction, _ = net(ids, attention_mask)
        loss = criterion(prediction, label)

        batch_accuracies.append(get_accuracy(prediction, label))
        batch_precisions.append(get_precision(prediction, label))
        batch_recalls.append(get_recall(prediction, label))

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Store a plain float. Keeping the tensor would hold every batch's autograd graph
        # (and its device memory) until the epoch ends, and np.mean cannot convert
        # CUDA or grad-requiring tensors.
        batch_losses.append(loss.item())

    avg_precision = np.mean(batch_precisions)
    avg_recall = np.mean(batch_recalls)

    return (
        np.mean(batch_losses),
        np.mean(batch_accuracies),
        avg_precision,
        avg_recall,
        get_f1_score(avg_precision, avg_recall),
    )


def train(
        net: Transformer,
        data_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
) -> tuple[np.float64, np.float64, np.float64, np.float64, np.float64]:
    """Train ``net`` for one epoch.

    Returns:
        ``(loss, accuracy, precision, recall, f1_score)`` averaged over the epoch's batches.
    """
    net.train()
    return _run_epoch(net, data_loader, "Training...", optimizer)


def test(
        net: Transformer,
        data_loader: DataLoader,
) -> tuple[np.float64, np.float64, np.float64, np.float64, np.float64]:
    """Evaluate ``net`` without gradient tracking.

    Returns:
        ``(loss, accuracy, precision, recall, f1_score)`` averaged over the loader's batches.
    """
    net.eval()
    with torch.no_grad():
        return _run_epoch(net, data_loader, "Evaluating...")


def _record_epoch(split_history: dict[str, list[float]], prefix: str, values) -> None:
    """Append one epoch's metric values to ``split_history`` and print them as ``<prefix>_<name> = value``."""
    for name, value in zip(_METRIC_NAMES, values):
        value = float(value)
        split_history[name].append(value)
        print(f"{prefix}_{name} = {value!r}")


def run_centralized(
        training_loader: DataLoader,
        validation_loader: DataLoader,
        epochs: int,
        learning_rate: float,
        save_checkpoints: bool,
        first_epoch: int = 0,
        optimizer_state_dict = None,
        net: "Transformer | None" = None,
) -> dict[str, dict[str, list[float]]]:
    """Train with AdamW for epochs ``first_epoch .. epochs - 1``, evaluating after each one.

    Args:
        training_loader: Batches used for gradient updates.
        validation_loader: Batches evaluated after every epoch.
        epochs: Index one past the last epoch to run (the planned total, not the number
            of epochs run in this call).
        learning_rate: AdamW learning rate.
        save_checkpoints: If True, write ``model/checkpoint{epoch}.pth`` after every
            epoch. The directory is created if needed.
        first_epoch: 0-based index of the first epoch to run. When resuming, pass the
            saved checkpoint's ``"epoch"`` + 1.
        optimizer_state_dict: Optional AdamW state to restore, e.g. from a checkpoint.
        net: Model to train. Defaults to the notebook-level ``model``.

    Returns:
        ``{"train": {...}, "test": {...}}``, mapping each metric name to one float per
        epoch run in this call.
    """
    if net is None:
        net = model

    history = {split: {name: [] for name in _METRIC_NAMES} for split in ("train", "test")}

    optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate)

    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    if save_checkpoints:
        from pathlib import Path

        # torch.save does not create missing directories.
        Path("model").mkdir(parents=True, exist_ok=True)

    for epoch in range(first_epoch, epochs):
        print(f"Training epoch #{epoch + 1}:")

        _record_epoch(history["train"], "train", train(net, training_loader, optimizer))
        _record_epoch(history["test"], "test", test(net, validation_loader))

        if save_checkpoints:
            # Metrics are stored as builtin floats, so only tensors and plain Python
            # objects end up in the file.
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": net.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "train": history["train"],
                    "test": history["test"],
                },
                f"model/checkpoint{epoch}.pth",
            )

    return history

In [None]:
# Training hyperparameters: number of epochs and the AdamW learning rate.
n, lr = 6, 1e-5

In [None]:
metrics = run_centralized(
    training_dl,
    validation_dl,
    epochs=n,
    learning_rate=lr,
    save_checkpoints=True,
    # To resume from `checkpoint`, uncomment both lines below. checkpoint["epoch"] is the
    # index of the last *completed* epoch, so training restarts at the next one (+ 1).
    # Passing it unchanged would run that epoch a second time.
    #first_epoch=checkpoint["epoch"] + 1,
    #optimizer_state_dict=checkpoint["optimizer_state_dict"],
)

In [None]:
# Hand PyTorch's cached, currently-unused GPU memory back to the device (skipped without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

### Predictions on 10 random samples

In [None]:
# Class index -> sentiment name, for printing.
reverse_mapping = {index: name for name, index in label_mapping.items()}

# Up to 10 distinct validation rows (fewer if the split is smaller).
random_indices = random.sample(range(len(validation_ds)), min(10, len(validation_ds)))

# Fetch each column once, rather than re-reading the whole column for every sampled index.
text_column = validation_ds["text"]
label_column = validation_ds["label"]
texts = [text_column[i] for i in random_indices]
labels = [label_column[i].item() for i in random_indices]
predictions = []

prediction_sample = get_data_loader(Subset(validation_ds, random_indices), 1, distilbert_tokenizer.pad_token_id)

In [None]:
model.eval()

# Start from an empty list so re-running this cell does not append duplicate predictions.
predictions = []

with torch.no_grad():
    for batch in prediction_sample:
        ids = batch["ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        prediction, _ = model(ids, attention_mask)

        # One predicted class index per row of the batch.
        predictions.extend(prediction.argmax(dim=-1).tolist())

In [None]:
# Print (text, gold sentiment, predicted sentiment) for each sampled example.
gold_names = [reverse_mapping[label_id] for label_id in labels]
predicted_names = [reverse_mapping[pred_id] for pred_id in predictions]
for row in zip(texts, gold_names, predicted_names):
    print(row)

### Model saving

In [None]:
import os

# torch.save does not create missing directories, so make sure "model/" exists first.
os.makedirs("model", exist_ok=True)
torch.save(model.state_dict(), "model/trained_model.pth")

### Evaluation graph

In [None]:
# X-axis: the 1-based numbers of the epochs actually recorded in `metrics`. When training
# resumes from a checkpoint, run_centralized only returns the epochs from `first_epoch`
# onward, so a fixed range(1, n + 1) would not match the length of the metric lists.
epoch_rng = range(n - len(metrics["train"]["loss"]) + 1, n + 1)

In [None]:
# Loss per epoch: training vs. validation.
for split, legend_label, line_colour in (("train", "Training", "yellow"), ("test", "Validation", "red")):
    plt.plot(epoch_rng, metrics[split]["loss"], label=legend_label, color=line_colour)

plt.gca().set(xlabel="Epoch", ylabel="Value", title="Model Loss Evaluation")
plt.legend()
plt.show()

In [None]:
# Accuracy per epoch: training vs. validation.
for split, legend_label, line_colour in (("train", "Training", "yellow"), ("test", "Validation", "blue")):
    plt.plot(epoch_rng, metrics[split]["accuracy"], label=legend_label, color=line_colour)

plt.gca().set(xlabel="Epoch", ylabel="Value", title="Model Accuracy Evaluation")
plt.legend()
plt.show()

In [None]:
# Macro precision per epoch: training vs. validation.
for split, legend_label, line_colour in (("train", "Training", "yellow"), ("test", "Validation", "green")):
    plt.plot(epoch_rng, metrics[split]["precision"], label=legend_label, color=line_colour)

plt.gca().set(xlabel="Epoch", ylabel="Value", title="Model Precision Evaluation")
plt.legend()
plt.show()

In [None]:
# Macro recall per epoch: training vs. validation.
for split, legend_label, line_colour in (("train", "Training", "yellow"), ("test", "Validation", "purple")):
    plt.plot(epoch_rng, metrics[split]["recall"], label=legend_label, color=line_colour)

plt.gca().set(xlabel="Epoch", ylabel="Value", title="Model Recall Evaluation")
plt.legend()
plt.show()

In [None]:
# F1 score per epoch: training vs. validation.
for split, legend_label, line_colour in (("train", "Training", "yellow"), ("test", "Validation", "orange")):
    plt.plot(epoch_rng, metrics[split]["f1_score"], label=legend_label, color=line_colour)

plt.gca().set(xlabel="Epoch", ylabel="Value", title="Model F1 Score Evaluation")
plt.legend()
plt.show()