In [None]:
!pip install -q "flwr[simulation]" flwr-datasets

In [None]:
# you might see a warning after running the command below, this can be ignored
# if you are running this outside Colab, you probably need to adjust the command below
# !pip install torch==1.13.1+cpu torchvision==0.14.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu

In [None]:
!pip install matplotlib

In [10]:
import torch
import pandas as pd

from torch.utils.data import DataLoader
from datasets import Dataset, DatasetDict

label_mapping = {
    "Negative": 0,
    "Neutral": 1,
    "Positive": 2
}

df_training = pd.read_csv("data/twitter_training.csv", names=["tweet_id", "entity", "label", "text"])
df_training = df_training[df_training.label != "Irrelevant"].drop(columns=["tweet_id", "entity"]).dropna()
df_training["label"] = df_training["label"].apply(lambda l: label_mapping[str(l)])
dataset_training = Dataset.from_pandas(df_training, preserve_index=False)

df_validation = pd.read_csv("data/twitter_validation.csv", names=["tweet_id", "entity", "label", "text"])
df_validation = df_validation[df_validation.label != "Irrelevant"].drop(columns=["tweet_id", "entity"]).dropna()
df_validation["label"] = df_validation["label"].apply(lambda l: label_mapping[l])
dataset_validation = Dataset.from_pandas(df_validation, preserve_index=False)


{'label': 0,
 'text': 'Grounded almost was pretty cool even despite the top tier unfunny writing until we became yet another annoying crafting game. I seriously can’t wait on this shitty trend to die'}

In [11]:
import transformers

def tokenize_and_numericalize_example(example, tokenizer):
    ids = tokenizer(example["text"], truncation=True)["input_ids"]
    return {"ids": ids}

tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")

dataset_training = dataset_training.map(tokenize_and_numericalize_example, fn_kwargs={"tokenizer": tokenizer})
dataset_validation = dataset_validation.map(tokenize_and_numericalize_example, fn_kwargs={"tokenizer": tokenizer})

dataset_training = dataset_training.with_format(type="torch", columns=["ids", "label"])
dataset_validation = dataset_validation.with_format(type="torch", columns=["ids", "label"])

dataset_dict = DatasetDict({
    "train": dataset_training,
    "test": dataset_validation
})

dataset_dict

Map:   0%|          | 0/61121 [00:00<?, ? examples/s]

Map:   0%|          | 0/828 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['label', 'text', 'ids'],
        num_rows: 61121
    })
    test: Dataset({
        features: ['label', 'text', 'ids'],
        num_rows: 828
    })
})

In [12]:
def get_collate_fn(pad_index):
    def collate_fn(batch):
        batch_ids = [i["ids"] for i in batch]
        batch_ids = nn.utils.rnn.pad_sequence(
            batch_ids, padding_value=pad_index, batch_first=True
        )
        batch_label = [i["label"] for i in batch]
        batch_label = torch.stack(batch_label)
        batch = {"ids": batch_ids, "label": batch_label}
        return batch

    return collate_fn

def get_dataloader(ds, batch_size, pad_index, shuffle=False):
    collate_fn = get_collate_fn(pad_index)
    dataloader = torch.utils.data.DataLoader(
        dataset=ds,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle,
    )
    
    return dataloader

pad = tokenizer.pad_token_id

trdl = get_dataloader(dataset_dict["train"], 8, pad, shuffle=True)
tsdl = get_dataloader(dataset_dict["test"], 8, pad, shuffle=True)

In [None]:
import matplotlib.pyplot as plt
from collections import Counter


# construct histogram
all_labels = dataset_dict["train"]["label"]
all_label_counts = Counter(all_labels)

# visualise histogram
bar = plt.bar(all_label_counts.keys(), all_label_counts.values())
_ = plt.bar_label(bar)

# plot formatting
_ = plt.xticks([label for label in all_label_counts.keys()])

In [13]:
import torch.nn as nn
#import torch.nn.functional as F

class Transformer(nn.Module):
    def __init__(self, transformer, num_classes, freeze):
        super().__init__()
        self.transformer = transformer
        hidden_dim = transformer.config.hidden_size
        self.fc = nn.Linear(hidden_dim, num_classes)
        if freeze:
            for param in self.transformer.parameters():
                param.requires_grad = False

    def forward(self, ids):
        # ids = [batch size, seq len]
        output = self.transformer(ids, output_attentions=True)
        hidden = output.last_hidden_state
        # hidden = [batch size, seq len, hidden dim]
        attention = output.attentions[-1]
        # attention = [batch size, n heads, seq len, seq len]
        cls_hidden = hidden[:, 0, :]
        prediction = self.fc(torch.tanh(cls_hidden))
        # prediction = [batch size, output dim]
        return prediction
    
# class CNN(nn.Module):
#     def __init__(
#             self,
#             num_classes: int,
#             vocab_size: int,
#             embed_dim: int,
#             num_filters,
#             filter_sizes,
#             dropout
#     ) -> None:
#         super(CNN, self).__init__()
# 
#         self.embedding = nn.Embedding(vocab_size, embed_dim)
# 
#         self.convs = nn.ModuleList([
#             #nn.Conv2d(1, num_filters, (fs, embed_dim)) for fs in filter_sizes
#             nn.Conv1d(embed_dim, num_filters, fs) for fs in filter_sizes
#         ])
# 
#         self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes)
#         self.dropout = nn.Dropout(dropout)
# 
#     def forward(self, ids):
#         # ids = [batch size, seq len]
#         embedded = self.dropout(self.embedding(ids))
#         # embedded = [batch size, seq len, embedding dim]
#         embedded = embedded.permute(0, 2, 1)
#         # embedded = [batch size, embedding dim, seq len]
#         conved = [torch.relu(conv(embedded)) for conv in self.convs]
#         # conved_n = [batch size, n filters, seq len - filter_sizes[n] + 1]
#         pooled = [conv.max(dim=-1).values for conv in conved]
#         # pooled_n = [batch size, n filters]
#         cat = self.dropout(torch.cat(pooled, dim=-1))
#         # cat = [batch size, n filters * len(filter_sizes)]
#         prediction = self.fc(cat)
#         # prediction = [batch size, output dim]
#         return prediction
#     # def forward(self, x: torch.Tensor) -> torch.Tensor:
#     #     x = self.embedding(x)
#     # 
#     #     return x

In [20]:
tf = transformers.AutoModel.from_pretrained("bert-base-uncased")

model = Transformer(tf, num_classes=3, freeze=False)
num_parameters = sum(value.numel() for value in model.state_dict().values())
print(f"{num_parameters = }")

criterion = torch.nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)
criterion = criterion.to(device)

print(model)

num_parameters = 109484547
Transformer(
  (transformer): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNor

In [18]:
import numpy as np
import tqdm

def get_accuracy(prediction, label):
    batch_size, _ = prediction.shape
    predicted_classes = prediction.argmax(dim=-1)
    correct_predictions = predicted_classes.eq(label).sum()
    accuracy = correct_predictions / batch_size
    return accuracy

def train(net, dataloader, optimizer):
    net.train()
    epoch_losses = []
    epoch_accs = []
    
    for batch in tqdm.tqdm(dataloader, desc="training..."):
        ids = batch["ids"].to(device)
        label = batch["label"].to(device)
        prediction = net(ids)
        loss = criterion(prediction, label)
        accuracy = get_accuracy(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
        epoch_accs.append(accuracy.item())
        
    return np.mean(epoch_losses), np.mean(epoch_accs)

def test(net, dataloader):
    net.eval()
    epoch_losses = []
    epoch_accs = []
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader, desc="evaluating..."):
            ids = batch["ids"].to(device)
            label = batch["label"].to(device)
            prediction = net(ids)
            loss = criterion(prediction, label)
            accuracy = get_accuracy(prediction, label)
            epoch_losses.append(loss.item())
            epoch_accs.append(accuracy.item())
            
    return np.mean(epoch_losses), np.mean(epoch_accs)

def run_centralised(
    trainloader, testloader, epochs: int, lr: float
):
    # define optimiser with hyperparameters supplied
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for e in range(epochs):
        print(f"Training epoch {e} ...")
        train(model, trainloader, optimizer)

    # training is completed, then evaluate model on the test set
    loss, accuracy = test(model, testloader)
    print(f"{loss = }")
    print(f"{accuracy = }")

In [19]:
# Run the centralised training
run_centralised(trdl, tsdl, epochs=5, lr=1e-5)

Training epoch 0 ...


training...:   0%|          | 0/7641 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
training...:   0%|          | 22/7641 [04:57<28:38:43, 13.54s/it]


KeyboardInterrupt: 