In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Set the device      
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

PyTorch version: 2.2.1
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Using device: mps


In [3]:
import numpy as np
import torch
import random
import copy
import pandas as pd


from arabert.preprocess import ArabertPreprocessor
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix, f1_score, precision_score,
                             recall_score, roc_auc_score)
from torch.utils.data import DataLoader, Dataset
from transformers import (AutoConfig, AutoModel, AutoModelForSequenceClassification,
                          AutoTokenizer, BertTokenizer, Trainer,
                          TrainingArguments)
from transformers.data.processors.utils import InputFeatures

from utils import *

from arabert.preprocess import ArabertPreprocessor
import numpy as np

from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer
from transformers import Trainer , TrainingArguments
from transformers.trainer_utils import EvaluationStrategy
from transformers.data.processors.utils import InputFeatures
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.utils import resample
import logging
import torch
from utils import *

In [4]:
train_data = load_json('../data/subtask1/split70.json')
dev_data = load_json('../data/subtask1/split10.json')

def normalize_data(items):
    normalized_data = []
    text_lists = []
    label_list = []
    for item in items:
        sentence_id = item['global_sentence_id']
        sentence = []
        labels = []
        for token_info in item['tokens']:
            word = token_info['token']
            tv = token_info['tags'][0]
            label = tv['value']
            sentence.append(word)
            labels.append(label)
        #sentence = ' '.join(sentence)
        #labels = ' '.join(labels)
        text_lists.append(sentence)
        label_list.append(labels)
    return text_lists, label_list


# Normalize the data
text_train, labels_train = normalize_data(train_data)
text_dev, labels_dev = normalize_data(dev_data)

data loaded from path ../data/subtask1/split70.json
data loaded from path ../data/subtask1/split10.json


In [5]:
all_labels = list(set([label for sublist in labels_train for label in sublist]))
label_map = { v:index for index, v in enumerate(all_labels)}
inv_label_map = {i: label for i, label in enumerate(all_labels)}

In [6]:
from data import NERDataset

In [7]:
model_name = 'aubmindlab/bert-base-arabertv02'
task_name = 'tokenclassification'

In [8]:
train_dataset = NERDataset(
    texts=text_train,
    tags=labels_train,
    label_list=all_labels,
    model_name=model_name,
    max_length=512
    )

dev_dataset = NERDataset(
    texts=text_dev,
    tags=labels_dev,
    label_list=all_labels,
    model_name=model_name,
    max_length=512
    )

In [9]:
datastore_path = './'
datastore_keys = torch.from_numpy(load_npy(datastore_path + 'datastore_keys.npy'))
datastore_values = torch.from_numpy(load_npy(datastore_path + 'datastore_values.npy'))

datastore_keys.shape, datastore_values.shape

(torch.Size([390900, 768]), torch.Size([390900]))

In [10]:
datastore_values.unique()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41])

In [11]:
frequency = torch.bincount(datastore_values)

In [12]:
frequency

tensor([253952,     53,    309,    899,    141,   2922,   3586,    412,  10590,
          7999,   1365,   3717,   4107,    560,   4714,    997,    172,   1291,
            96,    139,     61,    747,  16337,      3,   1850,   8052,    310,
           350,  10705,   4519,    346,     92,      4,     43,    368,    250,
          2739,     15,      6,   5758,    974,  39350])

In [13]:
indices = [i for i in range(42)]

# Define the maximum number of samples for each index
max_samples = 500  # or any other number

# Create a mask for these indices
mask = torch.zeros(datastore_values.size(), dtype=torch.bool)
for index in indices:
    # Create a mask for the current index
    index_mask = (datastore_values == index)

    # Calculate the number of samples to remove from the current index
    reduction_amount = max(0, index_mask.sum() - max_samples)

    # If reduction_amount is more than the total number of samples for the current index, set it to the total number
    reduction_amount = min(reduction_amount, index_mask.sum())

    # Create a random mask for the current index
    random_mask = torch.rand(datastore_values.size()) > (1 - reduction_amount / index_mask.sum())

    # Combine the masks
    mask |= index_mask & random_mask

# Apply the mask to the tensor
new_datastore_values = datastore_values[~mask]
new_datastore_keys = datastore_keys[~mask]

In [14]:
new_datastore_values.shape, new_datastore_keys.shape

(torch.Size([14732]), torch.Size([14732, 768]))

In [15]:
new_datastore_values.unique()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41])

In [16]:
frequency = torch.bincount(new_datastore_values)

In [17]:
frequency

tensor([486,  53, 309, 492, 141, 556, 513, 412, 525, 511, 488, 476, 499, 499,
        481, 502, 172, 475,  96, 139,  61, 488, 534,   3, 477, 548, 310, 350,
        498, 498, 346,  92,   4,  43, 368, 250, 502,  15,   6, 509, 500, 505])

In [18]:
embeddings = torch.from_numpy(load_npy(datastore_path + 'test_embeddings.npy'))
labels = torch.from_numpy(load_npy(datastore_path + 'test_labels.npy'))
logits = torch.from_numpy(load_npy(datastore_path + 'test_logits.npy'))

embeddings.shape, labels.shape, logits.shape

(torch.Size([57547, 768]), torch.Size([57547]), torch.Size([57547, 42]))

In [19]:
embeddings.unsqueeze(0).shape

torch.Size([1, 57547, 768])

In [20]:
embeddings = embeddings.unsqueeze(0)
labels = labels.unsqueeze(0)
logits = logits.unsqueeze(0)

embeddings.shape, labels.shape, logits.shape

(torch.Size([1, 57547, 768]),
 torch.Size([1, 57547]),
 torch.Size([1, 57547, 42]))

In [21]:
predicted_labels = torch.argmax(logits, axis=-1)
predicted_labels.shape

torch.Size([1, 57547])

In [22]:
print_classification_report(labels, predicted_labels)

              precision    recall  f1-score   support

           0       0.99      0.98      0.99     37983
           1       0.33      0.67      0.44         3
           2       0.81      0.88      0.84        33
           3       0.84      0.88      0.86       133
           4       1.00      1.00      1.00        16
           5       0.76      0.77      0.77       414
           6       0.80      0.87      0.83       508
           7       0.75      0.70      0.72        80
           8       0.95      0.96      0.96      1488
           9       0.97      0.94      0.95      1170
          10       0.93      0.96      0.95       204
          11       0.96      0.96      0.96       514
          12       0.78      0.78      0.78       589
          13       0.81      0.85      0.83        86
          14       0.99      0.98      0.98       595
          15       0.86      0.89      0.88       130
          16       0.90      0.86      0.88        22
          17       0.88    

In [23]:
def get_logits(logits, embeddings, datastore_keys, datastore_values, num_labels,K, lambda_, link_temperature=1.0):
    # cosine similarity
    knn_feats = datastore_keys.squeeze(0).transpose(0, 1) # [feature_size=768, datastore_size]
    embeddings = embeddings.view(-1, embeddings.shape[-1])  # [sentences, feature_size=768]
    sim = torch.mm(embeddings, knn_feats) # [sentences, datastore_size]

    sentences = embeddings.shape[0]
    datastore_size = knn_feats.shape[1]

    norm_1 = (knn_feats ** 2).sum(dim=0, keepdim=True).sqrt() # [1, datastore_size]
    norm_2 = (embeddings ** 2).sum(dim=1, keepdim=True).sqrt() # [sentences, 1]
    scores = (sim / (norm_1 + 1e-10) / (norm_2 + 1e-10)).view(1, sentences, -1) # [1, sentences, datastore_size]
    knn_labels = datastore_values.view(1, 1, datastore_size).expand(1, sentences, datastore_size) # [1, sentences, datastore_size]

    # select scores and labels of the top k only
    topk_scores, topk_idxs = torch.topk(scores, dim=-1, k=K)  # [1, sentences, topk]
    scores = topk_scores
    knn_labels = knn_labels.gather(dim=-1, index=topk_idxs)  # [[1, sentences, topk]

    # transform scores to softmax probabilities
    sim_probs = torch.softmax(scores / link_temperature, dim=-1) # [[1, sentences, topk]

    # 1. create zero tensor for probabilites as placeholder
    knn_probabilities = torch.zeros_like(sim_probs[:, :, 0]).unsqueeze(-1).repeat([1, 1, num_labels])  # [1, sentences, num_labels]
    # for each row (dim=2)
    # sum the probabilities from sim softmax probabilities (src=sim_probs) grouped by class (index=knn_labels)
    knn_probabilities = knn_probabilities.scatter_add(dim=2, index=knn_labels, src=sim_probs) # [1, sentences, num_labels]

    # interpolate between logits and knn_probabilites
    probabilities = lambda_*logits + (1-lambda_)*knn_probabilities

    # argmax to get most likely label
    argmax_labels = torch.argmax(probabilities, 2, keepdim=False)

    # return predicted labels
    return argmax_labels

In [25]:
tlogits, tembeddings, tlabels = logits[:,:500:], embeddings[:,:500,:], labels[:,:500]

tlogits.shape, tembeddings.shape

(torch.Size([1, 500, 42]), torch.Size([1, 500, 768]))

In [26]:
def get_f1_score(y_true, y_pred, average = 'weighted'):
  predictions = y_pred.squeeze(0).cpu().numpy()
  true_labels = y_true.squeeze(0).cpu().numpy()
  f1 = f1_score(true_labels, predictions, average=average)
  return f1  


In [60]:
f1s = {}

for i in [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]:
    f1s[i] = {}
    for j in [5,10,125,250,500]:
        f1s[i][j] = []

In [61]:
f1s

{0.0: {5: [], 10: [], 125: [], 250: [], 500: []},
 0.1: {5: [], 10: [], 125: [], 250: [], 500: []},
 0.2: {5: [], 10: [], 125: [], 250: [], 500: []},
 0.3: {5: [], 10: [], 125: [], 250: [], 500: []},
 0.4: {5: [], 10: [], 125: [], 250: [], 500: []},
 0.5: {5: [], 10: [], 125: [], 250: [], 500: []},
 0.6: {5: [], 10: [], 125: [], 250: [], 500: []},
 0.7: {5: [], 10: [], 125: [], 250: [], 500: []},
 0.8: {5: [], 10: [], 125: [], 250: [], 500: []},
 0.9: {5: [], 10: [], 125: [], 250: [], 500: []},
 1.0: {5: [], 10: [], 125: [], 250: [], 500: []}}

In [None]:
for i in range(500,57547,500):
    #print(i, 57547-i)
    j = i+500
    tlogits, tembeddings, tlabels = logits[:,i:j:], embeddings[:,i:j,:], labels[:,i:j]
    for lambda_ in [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]:
        for k in [5,10,125,250,500]:
            predicted_labels_protoype_average = get_logits(tlogits, tembeddings, new_datastore_keys, new_datastore_values, 42, k, lambda_)
            f1 = get_f1_score(tlabels, predicted_labels_protoype_average)
            f1s[lambda_][k].append(f1)
            print(f1)

In [69]:
max_f1 = 0
for i in [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]:
    for j in [5,10,125,250,500]:
        f1s_list = f1s[i][j]
        f1s_list_mean = sum(f1s_list) / len(f1s_list)
        max_f1 = max(max_f1, f1s_list_mean)
        print(i,j,f1s_list_mean)


0.0 5 0.9491600862673334
0.0 10 0.9508867567388144
0.0 125 0.9607593724390714
0.0 250 0.9635460499663281
0.0 500 0.9647693098627762
0.1 5 0.959416684771326
0.1 10 0.9615545317073475
0.1 125 0.9660052770217675
0.1 250 0.9674801326695979
0.1 500 0.9688637035498704
0.2 5 0.9668261773295824
0.2 10 0.9674801836318679
0.2 125 0.9690396063305886
0.2 250 0.9697204133280399
0.2 500 0.9705573229783027
0.3 5 0.97001252181818
0.3 10 0.9703821785087027
0.3 125 0.970742436402798
0.3 250 0.9709743161031822
0.3 500 0.9712512074763876
0.4 5 0.9716252037931844
0.4 10 0.9715447559095843
0.4 125 0.9715269120519034
0.4 250 0.971545217983545
0.4 500 0.9719267706027336
0.5 5 0.9721713396593812
0.5 10 0.9720976407729347
0.5 125 0.9719142234231674
0.5 250 0.9720860882415782
0.5 500 0.9722174131897752
0.6 5 0.972421874401129
0.6 10 0.9723379923510994
0.6 125 0.9723192273003461
0.6 250 0.9722895746743239
0.6 500 0.972482064397995
0.7 5 0.972438623724556
0.7 10 0.9724309581264741
0.7 125 0.9723199143673183
0.7 25

In [70]:
f1s_list_mean

0.9724189971691566