In [2]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

## Co-Attention Model

In [3]:
class CoAttentionModel(nn.Module):
    """Bilinear co-attention classifier over two 3-D embedding inputs.

    Each input goes through its own Linear -> ReLU -> Dropout projection. A
    learned bilinear matrix ``W_b`` scores every pair of positions between the
    two projected inputs, and each input attends over the other using a
    softmax of those scores. The two attended tensors are summed,
    batch-normalised over the feature axis, averaged over axis 1 and fed to a
    two-layer MLP that emits ``num_labels`` logits.

    Args:
        embed_dim: Size of the last (feature) axis of both inputs.
        num_labels: Number of logits produced per example.
    """

    def __init__(self, embed_dim, num_labels):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_labels = num_labels

        # Bilinear form used to score (x1 position, x2 position) pairs.
        self.W_b = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
        nn.init.xavier_uniform_(self.W_b)

        # One independent projection per input stream.
        self.transform1 = self._projection(embed_dim)
        self.transform2 = self._projection(embed_dim)

        # MLP head mapping the pooled features to label logits.
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_labels),
        )

        self.batch_norm = nn.BatchNorm1d(embed_dim)

    @staticmethod
    def _projection(embed_dim):
        """Build a Linear -> ReLU -> Dropout(0.1) block that keeps ``embed_dim``."""
        return nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x1, x2):
        """Return classification logits for a pair of inputs.

        Args:
            x1: First input; indexed as a 3-D tensor whose last axis is
                ``embed_dim``.
            x2: Second input, same layout. The element-wise sum of the two
                attended tensors requires their axis-1 sizes to be
                broadcast-compatible.

        Returns:
            Tensor of logits with ``num_labels`` entries per example.
        """
        seq_a = self.transform1(x1)
        seq_b = self.transform2(x2)

        # scores[b, i, j] = seq_a[b, i] . W_b . seq_b[b, j]
        scores = torch.matmul(torch.matmul(seq_a, self.W_b), seq_b.transpose(1, 2))

        # Normalise the same score matrix along each direction.
        a_over_b = F.softmax(scores, dim=2)
        b_over_a = F.softmax(scores.transpose(1, 2), dim=2)

        fused = torch.matmul(a_over_b, seq_b) + torch.matmul(b_over_a, seq_a)

        # BatchNorm1d normalises axis 1, so move the feature axis there and back.
        fused = self.batch_norm(fused.permute(0, 2, 1)).permute(0, 2, 1)

        return self.classifier(fused.mean(dim=1))

In [4]:
# Instantiate the co-attention network for 768-dimensional embeddings and 21 labels.
# The following cell rebinds this name to the checkpoint loaded from disk.
co_attention_model = CoAttentionModel(
    num_labels=21,
    embed_dim=768,
)

In [5]:
# The checkpoint is used directly as a model object (validate() calls .eval()
# on it and invokes it), i.e. it is a fully pickled nn.Module rather than a
# state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# refuses to unpickle custom classes, so request full unpickling explicitly.
# NOTE(security): full unpickling can execute arbitrary code - only load
# checkpoints from a trusted source.
co_attention_model = torch.load(
    '/content/drive/MyDrive/NLP/Project/Novelty_models/Co Attention Model New/CoAttentionModel.pt',
    weights_only=False,
)

## Cross-Attention Model

In [6]:
class CrossAttentionModel(nn.Module):
    """Bidirectional multi-head cross-attention classifier over two inputs.

    Each input goes through its own Linear -> ReLU -> Dropout projection. Two
    8-head ``nn.MultiheadAttention`` layers (``batch_first=True``) then let the
    first projected input query the second and vice versa. The two attention
    outputs are summed, batch-normalised over the feature axis, averaged over
    axis 1 and fed to a two-layer MLP that emits ``num_labels`` logits.

    Args:
        embed_dim: Size of the last (feature) axis of both inputs; it is also
            the attention embedding size, split across 8 heads.
        num_labels: Number of logits produced per example.
    """

    def __init__(self, embed_dim, num_labels):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_labels = num_labels

        # Both directions share the same attention hyper-parameters.
        attention_kwargs = dict(num_heads=8, batch_first=True, dropout=0.1)
        self.attention1 = nn.MultiheadAttention(embed_dim, **attention_kwargs)
        self.attention2 = nn.MultiheadAttention(embed_dim, **attention_kwargs)

        # One independent projection per input stream.
        self.transform1 = self._projection(embed_dim)
        self.transform2 = self._projection(embed_dim)

        # MLP head mapping the pooled features to label logits.
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_labels),
        )

        self.batch_norm = nn.BatchNorm1d(embed_dim)

    @staticmethod
    def _projection(embed_dim):
        """Build a Linear -> ReLU -> Dropout(0.1) block that keeps ``embed_dim``."""
        return nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x1, x2):
        """Return classification logits for a pair of inputs.

        Args:
            x1: First input; used as a batch-first 3-D tensor whose last axis
                is ``embed_dim``.
            x2: Second input, same layout. The element-wise sum of the two
                attention outputs requires their axis-1 sizes to be
                broadcast-compatible.

        Returns:
            Tensor of logits with ``num_labels`` entries per example.
        """
        seq_a = self.transform1(x1)
        seq_b = self.transform2(x2)

        # Index [0] keeps the attention output and drops the attention weights.
        a_from_b = self.attention1(seq_a, seq_b, seq_b)[0]
        b_from_a = self.attention2(seq_b, seq_a, seq_a)[0]

        fused = a_from_b + b_from_a

        # BatchNorm1d normalises axis 1, so move the feature axis there and back.
        fused = self.batch_norm(fused.permute(0, 2, 1)).permute(0, 2, 1)

        return self.classifier(fused.mean(dim=1))

In [7]:
# Instantiate the cross-attention network for 768-dimensional embeddings and 21 labels.
# The following cell rebinds this name to the checkpoint loaded from disk.
cross_attention_model = CrossAttentionModel(
    num_labels=21,
    embed_dim=768,
)

In [8]:
# The checkpoint is used directly as a model object (validate() calls .eval()
# on it and invokes it), i.e. it is a fully pickled nn.Module rather than a
# state_dict. PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# refuses to unpickle custom classes, so request full unpickling explicitly.
# NOTE(security): full unpickling can execute arbitrary code - only load
# checkpoints from a trusted source.
cross_attention_model = torch.load(
    '/content/drive/MyDrive/NLP/Project/Novelty_models/Cross Attention Model New/CrossAttentionModel.pt',
    weights_only=False,
)

## Testing

In [9]:
# Language codes evaluated below; they are used as dictionary keys and as the
# suffix of the per-language pickle file names.
langs = 'en es de bg hu lv'.split()

In [13]:
class MyDataset(Dataset):
    """Dataset yielding aligned ``(adapter_embed, nmt_embed, label)`` triples.

    The three containers are treated as parallel sequences: item ``idx`` of
    the dataset is item ``idx`` of each container.

    Args:
        adapter_embeds: Indexable, sized container of adapter embeddings.
        nmt_embeds: Indexable, sized container of NMT embeddings.
        labels: Indexable, sized container of labels.

    Raises:
        ValueError: If the three containers do not have the same length.
    """

    def __init__(self, adapter_embeds, nmt_embeds, labels):
        sizes = (len(adapter_embeds), len(nmt_embeds), len(labels))
        # __len__ is driven by ``labels`` alone, so a mismatch would otherwise
        # silently drop examples or raise IndexError in the middle of a pass.
        if len(set(sizes)) != 1:
            raise ValueError(
                "adapter_embeds, nmt_embeds and labels must have the same length, "
                f"got {sizes[0]}, {sizes[1]} and {sizes[2]}"
            )
        self.adapter_embeds = adapter_embeds
        self.nmt_embeds = nmt_embeds
        self.labels = labels

    def __len__(self):
        """Return the number of examples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(adapter_embed, nmt_embed, label)`` triple at ``idx``."""
        return self.adapter_embeds[idx], self.nmt_embeds[idx], self.labels[idx]

In [16]:
# Directory holding the per-language test pickles.
TEST_EMBEDDINGS_DIR = '/content/drive/MyDrive/NLP/Project/New_test_embeddings'


def _load_pickle(path):
    """Unpickle ``path`` and close the file handle afterwards."""
    # NOTE(security): pickle can execute arbitrary code; only read trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Per-language tensors, keyed by language code.
adapter_embeddings = {}
nmt_embeddings = {}
labels = {}

for i in langs:
    try:
        adapter_embedding = _load_pickle(f'{TEST_EMBEDDINGS_DIR}/adapter_embeddings_{i}.pkl')
        nmt_embedding = _load_pickle(f'{TEST_EMBEDDINGS_DIR}/nmt_embeddings_{i}.pkl')
        label = _load_pickle(f'{TEST_EMBEDDINGS_DIR}/cpu_labels_{i}.pkl')

        # Each pickle holds a sequence of tensors; join them along axis 0.
        adapter_cat = torch.cat(adapter_embedding, dim=0)
        nmt_cat = torch.cat(nmt_embedding, dim=0)
        label_cat = torch.cat(label, dim=0)
    except Exception as e:
        # Best effort: report the failing language and continue with the rest.
        print(e)
        print("Language:", i)
    else:
        # Register a language only once all three pieces loaded, so the three
        # dictionaries always share the same keys.
        adapter_embeddings[i] = adapter_cat
        nmt_embeddings[i] = nmt_cat
        labels[i] = label_cat

In [17]:
# Sanity check: show the number of adapter embeddings, NMT embeddings and
# labels loaded for every language.
for lang in langs:
    counts = (
        len(adapter_embeddings[lang]),
        len(nmt_embeddings[lang]),
        len(labels[lang]),
    )
    print(f"{lang}: " + ", ".join(str(count) for count in counts))

en: 500, 500, 500
es: 500, 500, 500
de: 500, 500, 500
bg: 500, 500, 500
hu: 500, 500, 500
lv: 478, 478, 478


In [64]:
# One MyDataset per language, keyed by language code.
test_datasets = {
    lang: MyDataset(adapter_embeddings[lang], nmt_embeddings[lang], labels[lang])
    for lang in langs
}

In [65]:
# One evaluation DataLoader per language, keyed by language code.
# shuffle=False: these loaders are only used for evaluation, where a fixed
# order keeps runs reproducible and lets predictions be matched back to
# examples; shuffling brings no benefit outside of training.
test_dataloaders = {}
for i in langs:
    test_dataloaders[i] = DataLoader(test_datasets[i], batch_size=32, shuffle=False)

In [84]:
def validate(model, data_loader):
    """Evaluate ``model`` on ``data_loader`` with a per-example mean threshold.

    For every example the logits are passed through a sigmoid and a label is
    predicted as 1 when its probability is strictly greater than the mean
    probability of that example, otherwise 0. Predictions and targets of all
    examples are flattened into one long binary vector before scoring.

    Args:
        model: Callable module taking ``(adapter_embeds, nmt_embeds)`` and
            returning 2-D logits (one row per example).
        data_loader: Iterable of ``(adapter_embeds, nmt_embeds, labels)``
            batches.

    Returns:
        Tuple ``(macro_f1, weighted_f1, accuracy)`` computed on the flattened
        predictions and targets.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    batch_preds = []
    batch_targets = []

    model.eval()
    with torch.no_grad():
        for adapter_embeds, nmt_embeds, labels in tqdm(data_loader):
            probs = torch.sigmoid(model(adapter_embeds, nmt_embeds))

            # One threshold per example, computed once and broadcast over the
            # label axis (the threshold is not re-derived per element).
            threshold = probs.mean(dim=1, keepdim=True)
            batch_preds.append((probs > threshold).long().reshape(-1).cpu().numpy())

            # .cpu() makes the conversion work wherever the labels live.
            batch_targets.append(labels.detach().reshape(-1).cpu().numpy())

    if not batch_preds:
        raise ValueError("data_loader yielded no batches; nothing to evaluate")

    preds = np.concatenate(batch_preds)
    true_labels = np.concatenate(batch_targets)

    macro_f1 = f1_score(true_labels, preds, average='macro')
    weighted_f1 = f1_score(true_labels, preds, average='weighted')
    accuracy = accuracy_score(true_labels, preds)

    return macro_f1, weighted_f1, accuracy

In [76]:
# Per-language scores of the co-attention model, keyed by language code.
co_attention_macro_f1s, co_attention_weighted_f1s, co_attention_accuracies = {}, {}, {}

for lang in langs:
    macro_f1, weighted_f1, accuracy = validate(co_attention_model, test_dataloaders[lang])
    co_attention_macro_f1s[lang] = macro_f1
    co_attention_weighted_f1s[lang] = weighted_f1
    co_attention_accuracies[lang] = accuracy

100%|██████████| 16/16 [00:09<00:00,  1.64it/s]
100%|██████████| 16/16 [00:10<00:00,  1.49it/s]
100%|██████████| 16/16 [00:08<00:00,  1.91it/s]
100%|██████████| 16/16 [00:07<00:00,  2.13it/s]
100%|██████████| 16/16 [00:05<00:00,  2.69it/s]
100%|██████████| 15/15 [00:07<00:00,  2.10it/s]


In [85]:
# Per-language scores of the cross-attention model, keyed by language code.
cross_attention_macro_f1s, cross_attention_weighted_f1s, cross_attention_accuracies = {}, {}, {}

for lang in langs:
    macro_f1, weighted_f1, accuracy = validate(cross_attention_model, test_dataloaders[lang])
    cross_attention_macro_f1s[lang] = macro_f1
    cross_attention_weighted_f1s[lang] = weighted_f1
    cross_attention_accuracies[lang] = accuracy

100%|██████████| 16/16 [00:22<00:00,  1.43s/it]
100%|██████████| 16/16 [00:19<00:00,  1.22s/it]
100%|██████████| 16/16 [00:18<00:00,  1.17s/it]
100%|██████████| 16/16 [00:18<00:00,  1.16s/it]
100%|██████████| 16/16 [00:19<00:00,  1.22s/it]
100%|██████████| 15/15 [00:21<00:00,  1.41s/it]


In [78]:
# Persist the co-attention scores.
# The files are named "*_above_avg" because validate() thresholds each example
# at its average probability, and because the results section reads the
# co-attention scores back from exactly these "*_above_avg.pkl" files. Writing
# them under "*_above_median" would mislabel the numbers and leave the results
# section reading stale files.
CO_ATTENTION_DIR = '/content/drive/MyDrive/NLP/Project/Novelty_models/Co Attention Model New'

for file_name, scores in (
    ('macro_f1s_above_avg.pkl', co_attention_macro_f1s),
    ('weighted_f1s_above_avg.pkl', co_attention_weighted_f1s),
    ('accuracies_f1s_above_avg.pkl', co_attention_accuracies),
):
    # The context manager flushes and closes the file even if dump() fails.
    with open(f'{CO_ATTENTION_DIR}/{file_name}', 'wb') as handle:
        pickle.dump(scores, handle)

In [86]:
# Persist the cross-attention scores next to the model checkpoint.
CROSS_ATTENTION_DIR = '/content/drive/MyDrive/NLP/Project/Novelty_models/Cross Attention Model New'

for file_name, scores in (
    ('macro_f1s_above_avg.pkl', cross_attention_macro_f1s),
    ('weighted_f1s_above_avg.pkl', cross_attention_weighted_f1s),
    ('accuracies_f1s_above_avg.pkl', cross_attention_accuracies),
):
    # The context manager flushes and closes the file even if dump() fails.
    with open(f'{CROSS_ATTENTION_DIR}/{file_name}', 'wb') as handle:
        pickle.dump(scores, handle)

## Results as DataFrames

In [82]:
def _load_scores(path):
    """Unpickle the score dictionary stored at ``path`` and close the file."""
    # NOTE(security): pickle can execute arbitrary code; only read trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Reload the saved "above average" scores of both models from disk.
co_attention_macro_f1s = _load_scores('/content/drive/MyDrive/NLP/Project/Novelty_models/Co Attention Model New/macro_f1s_above_avg.pkl')
co_attention_weighted_f1s = _load_scores('/content/drive/MyDrive/NLP/Project/Novelty_models/Co Attention Model New/weighted_f1s_above_avg.pkl')
co_attention_accuracies_f1s = _load_scores('/content/drive/MyDrive/NLP/Project/Novelty_models/Co Attention Model New/accuracies_f1s_above_avg.pkl')

cross_attention_macro_f1s = _load_scores('/content/drive/MyDrive/NLP/Project/Novelty_models/Cross Attention Model New/macro_f1s_above_avg.pkl')
cross_attention_weighted_f1s = _load_scores('/content/drive/MyDrive/NLP/Project/Novelty_models/Cross Attention Model New/weighted_f1s_above_avg.pkl')
cross_attention_accuracies_f1s = _load_scores('/content/drive/MyDrive/NLP/Project/Novelty_models/Cross Attention Model New/accuracies_f1s_above_avg.pkl')

In [90]:
# pandas is imported together with the other dependencies in the first cell,
# so this cell no longer carries its own import.
# Rows are the three metrics, columns are the language codes.
print("CO-ATTENTION ABOVE AVG")
pd.DataFrame(
    [co_attention_macro_f1s, co_attention_weighted_f1s, co_attention_accuracies_f1s],
    index=['macro_F1', 'weighted_F1', 'Accuracy'],
)

CO-ATTENTION ABOVE AVG


Unnamed: 0,en,es,de,bg,hu,lv
macro_F1,0.805888,0.643996,0.567517,0.718923,0.590491,0.608287
weighted_F1,0.880923,0.76931,0.712685,0.815549,0.729709,0.743129
Accuracy,0.873524,0.746857,0.681143,0.796381,0.700857,0.716378


In [91]:
# Rows are the three metrics, columns are the language codes.
print("CROSS-ATTENTION ABOVE AVG")
cross_attention_rows = {
    'macro_F1': cross_attention_macro_f1s,
    'weighted_F1': cross_attention_weighted_f1s,
    'Accuracy': cross_attention_accuracies_f1s,
}
pd.DataFrame(list(cross_attention_rows.values()), index=list(cross_attention_rows.keys()))

CROSS-ATTENTION ABOVE AVG


Unnamed: 0,en,es,de,bg,hu,lv
macro_F1,0.808293,0.632071,0.575192,0.730065,0.577623,0.59881
weighted_F1,0.882374,0.758829,0.714869,0.824055,0.719688,0.735971
Accuracy,0.875048,0.733905,0.682286,0.806381,0.689048,0.70801
