In [None]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, AutoModel, AutoTokenizer, AutoModelForMaskedLM
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim import Adam
import copy

In [None]:
# Dialect label <-> integer class id mappings.
# The class id of a dialect is its position in the tuple below; label2id is the exact inverse.
id2label = dict(enumerate(("MSA", "MGH", "EGY", "LEV", "IRQ", "GLF")))
label2id = {label: class_id for class_id, label in id2label.items()}


In [None]:
# Load the tokenizer and the pretrained checkpoint, configured as a 6-label sequence classifier.
tokenizer = AutoTokenizer.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-mix")
model = AutoModelForSequenceClassification.from_pretrained(
    'CAMeL-Lab/bert-base-arabic-camelbert-mix',
    num_labels=6,
    id2label=id2label,
    label2id=label2id,
)
# One independent deep copy of the freshly loaded weights per experiment,
# so every method starts from the same initial model.
baseline_model, jtt_model, spare_partition_model, spare_model = (
    copy.deepcopy(model) for _ in range(4)
)

In [None]:
#Read the data
# Fixed seed for the sub-sampling below: without it every Restart & Run All draws
# different train/dev/test subsets, so results cannot be reproduced or compared.
SEED = 42
df = pd.read_csv('./full_cleaned_data.tsv', sep='\t')
# Split the table into one DataFrame per value of the `split` column.
grouped_df = df.groupby('split')
dfs = {name: group for name, group in grouped_df}
# Work on fixed-size random subsets of each split.
# NOTE: DataFrame.sample raises ValueError if a split has fewer rows than requested.
train_df = dfs['train'].sample(n=6400, random_state=SEED)
dev_df = dfs['dev'].sample(n=320, random_state=SEED)
test_df = dfs['test'].sample(n=320, random_state=SEED)

In [None]:
# Peek at the first rows of the full table (all splits) as a sanity check.
df.head()

In [None]:
from pprint import pprint

# Give every distinct value of the `country` column an integer id, numbered in
# order of first appearance in `df`; country2id is the inverse of id2country.
unique_countries = df['country'].unique()
id2country = dict(enumerate(unique_countries))
country2id = {country: country_id for country_id, country in id2country.items()}

# Show both mappings, plus the dialect of the first sampled test row as a spot check.
pprint(country2id)
pprint(label2id)
print(test_df.iloc[0]['dialect'])

In [None]:
import torch
from spuco.datasets.base_spuco_compatible_dataset import BaseSpuCoCompatibleDataset
class ArabicDataset(BaseSpuCoCompatibleDataset):
    """Tokenized dialect-classification dataset that also exposes group information.

    The class label of an example is its dialect id and its spurious attribute is
    its country id, so a "group" is a ``(dialect_id, country_id)`` pair.
    Indexing returns ``(item, idx)`` so callers can trace predictions back to rows.
    """

    def __init__(self, dataframe, tokenizer, label2id, country2id):
        """
        :param dataframe: table with ``text``, ``dialect`` and ``country`` columns.
        :param tokenizer: callable applied once to the full list of texts
            (with ``truncation=True, padding=True``).
        :param label2id: mapping from dialect name to class id.
        :param country2id: mapping from country name to country id.
        :raises KeyError: if a dialect in ``dataframe`` is missing from ``label2id``.
        """
        self.df = dataframe
        self.encodings = tokenizer(dataframe['text'].values.tolist(), truncation=True, padding=True)
        # Kept under a private name so the public `labels` property can expose it.
        # Previously `labels` was both an instance attribute and a method of the same
        # name; the attribute shadowed the method, which therefore could never be called.
        self._labels = [label2id[dialect] for dialect in dataframe['dialect']]
        self.label2id = label2id
        self.country2id = country2id

    def __getitem__(self, idx):
        """Return ``(item, idx)`` where ``item`` holds the encoded fields plus ``labels`` as tensors."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self._labels[idx])
        return item, idx

    def __len__(self):
        return len(self._labels)

    @property
    def spurious(self):
        """Country id of every example, in dataset order."""
        return [self.country2id[country] for country in self.df["country"]]

    @property
    def group_partition(self):
        """
        Dictionary mapping every ``(dialect_id, country_id)`` pair to the list of
        positional indices of the examples in that group. Pairs that never occur
        are present with an empty list.
        """
        partition = {
            (label_id, country_id): []
            for label_id in self.label2id.values()
            for country_id in self.country2id.values()
        }
        # Walk both columns once instead of doing two `df.iloc[i]` row lookups per example.
        for position, (dialect, country) in enumerate(zip(self.df['dialect'], self.df['country'])):
            partition[(self.label2id[dialect], self.country2id[country])].append(position)
        return partition

    @property
    def group_weights(self):
        """
        Dictionary containing the fractional weights of each group
        (group size divided by dataset size).
        """
        total = len(self._labels)
        return {key: len(members) / total for key, members in self.group_partition.items()}

    @property
    def labels(self):
        """Class id of every example, in dataset order."""
        # NOTE(review): exposed as a property on the assumption that the SpuCo base
        # class reads `dataset.labels` as an attribute, like the other properties here — confirm.
        return self._labels

    @property
    def num_classes(self):
        """Number of distinct dialect classes."""
        return len(self.label2id)


In [None]:
# Tokenize each sampled split once ...
trainset = ArabicDataset(train_df, tokenizer, label2id, country2id)
devset = ArabicDataset(dev_df, tokenizer, label2id, country2id)
testset = ArabicDataset(test_df, tokenizer, label2id, country2id)

# ... and serve it in shuffled batches of 128 (all three loaders shuffle).
train_loader = DataLoader(trainset, batch_size=128, shuffle=True)
dev_loader = DataLoader(devset, batch_size=128, shuffle=True)
test_loader = DataLoader(testset, batch_size=128, shuffle=True)

In [None]:
def save_model(model,path):
    """Write the model's ``state_dict`` (weights only, not the module object) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)

def train(model,epochs,train_loader,dev_loader,optimizer,lr_scheduler, device):
    """Fine-tune ``model`` and measure dev accuracy after every epoch.

    Both loaders must yield ``(batch, index)`` pairs where ``batch`` is a dict with
    ``input_ids``, ``attention_mask`` and ``labels`` tensors; ``index`` is ignored here.

    :param model: model called as ``model(input_ids, attention_mask=..., labels=...)``;
        ``outputs[0]`` is used as the training loss and ``.logits`` for prediction.
    :param epochs: number of passes over ``train_loader``.
    :param optimizer: optimizer over ``model``'s parameters.
    :param lr_scheduler: optional scheduler, stepped after every optimizer step; may be None.
    :param device: device the model and every batch are moved to.
    :return: ``(loss_log, accuracy_log)``: ``loss_log`` maps epoch number to the list of
        per-batch losses, ``accuracy_log`` holds one dev accuracy per epoch as a Python
        float (``nan`` if the dev loader produced no examples).
    """
    model.to(device)
    loss_log = {}
    accuracy_log = []
    for epoch in range(epochs):
        # ---- training pass ----
        epoch_loss = []
        model.train()
        print("Start Training:")
        with tqdm(train_loader, unit="batch") as tepoch:
            for batch, _ in tepoch:
                optimizer.zero_grad()
                tepoch.set_description(f"Epoch {epoch}")
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs[0]
                loss.backward()
                optimizer.step()
                if lr_scheduler:
                    lr_scheduler.step()
                batch_loss = loss.item()
                tepoch.set_postfix(loss=batch_loss)
                epoch_loss.append(batch_loss)
        loss_log[epoch] = epoch_loss

        # ---- evaluation pass ----
        model.eval()
        print("Evaluation:")
        num_right = 0
        num_items = 0
        with torch.no_grad(), tqdm(dev_loader, unit="batch") as depoch:
            for batch, _ in depoch:
                depoch.set_description(f"Epoch {epoch}")
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                logits = model(input_ids, attention_mask=attention_mask).logits
                predictions = torch.argmax(logits, dim=-1)
                # .item() keeps the counters as plain ints so the log never holds
                # device tensors (the original appended 0-d, possibly CUDA, tensors).
                num_right += (predictions == labels).sum().item()
                num_items += labels.size(0)
        # An empty dev loader used to raise ZeroDivisionError after a full training epoch.
        accuracy = num_right / num_items if num_items else float("nan")
        print("accuracy= %.3f" %(accuracy))
        accuracy_log.append(accuracy)
    return loss_log, accuracy_log
    

def get_accuracy(model, loader, device):
    """Accuracy of ``model`` over ``loader``, shown with a tqdm progress bar.

    ``loader`` must yield ``(batch, index)`` pairs where ``batch`` has ``input_ids``,
    ``attention_mask`` and ``labels`` tensors. The model is put in eval mode but is
    not moved; only the batches are sent to ``device``.

    :return: fraction of correctly classified examples as a Python float.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad(), tqdm(loader, unit="batch") as progress:
        for step, (batch, _) in enumerate(progress):
            progress.set_description(f"Evaluating {step}")
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            gold = batch['labels'].to(device)
            predicted = torch.argmax(model(ids, attention_mask=mask).logits, dim=-1)
            # `hits` becomes a 0-d tensor after the first batch; converted once at the end.
            hits += (predicted == gold).sum()
            seen += len(ids)
    return (hits / seen).item()
  
def get_accuracy_no_tqdm(model, loader, device):
    """Accuracy of ``model`` over ``loader`` without any progress bar.

    ``loader`` must yield ``(batch, index)`` pairs where ``batch`` has ``input_ids``,
    ``attention_mask`` and ``labels`` tensors. The model is put in eval mode but is
    not moved; only the batches are sent to ``device``.

    :return: fraction of correctly classified examples as a Python float.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch, _ in loader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            gold = batch['labels'].to(device)
            predicted = torch.argmax(model(ids, attention_mask=mask).logits, dim=-1)
            # `hits` becomes a 0-d tensor after the first batch; converted once at the end.
            hits += (predicted == gold).sum()
            seen += len(ids)
    return (hits / seen).item()

In [None]:
# Train the baseline model: Adam, lr = 1e-5, 3 epochs, no learning-rate scheduler.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epoch = 3
lr = 1e-5
scheduler = None

# Start again from a fresh copy of the initial weights so re-running this cell
# never continues training an already fine-tuned baseline.
baseline_model = copy.deepcopy(model)
optim = Adam(baseline_model.parameters(), lr=lr)
loss_log, accuracy_log = train(
    baseline_model, epoch, train_loader, dev_loader, optim, scheduler, device
)


In [None]:
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
def get_group_accuracy(model, dataset, device, batch_size):
    """Accuracy of ``model`` on every non-empty group of ``dataset``.

    Groups come from ``dataset.group_partition`` (group key -> list of example
    indices). Empty groups are skipped; the remaining ones are evaluated in sorted
    key order and each result is printed.

    :return: dict mapping group key to accuracy, for non-empty groups only.
    """
    model.eval()
    group_partition = dataset.group_partition

    # One loader per non-empty group, restricted to that group's indices.
    testloaders = {
        key: DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=SubsetRandomSampler(members),
            shuffle=False,
        )
        for key, members in group_partition.items()
        if len(members) > 0
    }

    accuracies = {}
    for key in tqdm(sorted(group_partition.keys()), "Evaluating group-wise accuracy"):
        if key not in testloaders:
            continue
        accuracies[key] = get_accuracy_no_tqdm(model, testloaders[key], device)
        print(f"Group {key} Accuracy: {accuracies[key]}")
    return accuracies
    

In [None]:
# Per-(dialect, country) group accuracy of the baseline on the test set, batch size 128.
accuracies = get_group_accuracy(baseline_model,testset,device,128)
# Stale debugging call kept for reference; `loaders` is not defined anywhere in the visible notebook.
#get_accuracy(baseline_model, list(loaders.values())[0], device)


In [None]:
# Overall (ungrouped) test accuracy of the baseline.
print(f"Test accuracy of the baseline: {get_accuracy(baseline_model, test_loader, device)}")

In [None]:
# Persist the baseline weights (state_dict only) in the working directory.
save_model(baseline_model, "baseline.pt")

In [None]:
def generate_upsample_indices(model, dataloader, device=None):
    """Return the dataset indices of the examples that ``model`` misclassifies.

    ``dataloader`` must yield ``(batch, index)`` pairs where ``batch`` has
    ``input_ids``, ``attention_mask`` and ``labels`` tensors and ``index`` is a
    tensor of the dataset indices of the batch's examples.

    :param device: device the batches are moved to. Defaults to the device of the
        model's parameters (CPU if it has none). The original relied on a
        notebook-global ``device``, so the function broke outside that cell order.
    :return: list of int indices, in the order the loader produced them.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    indices = []
    with torch.no_grad():
        with tqdm(dataloader, unit="batch") as tepoch:
            for step, (batch, index) in enumerate(tepoch):
                tepoch.set_description(f"Evaluating {step}")
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                logits = model(input_ids, attention_mask=attention_mask).logits
                predictions = torch.argmax(logits, dim=-1)
                # `index` is used as produced by the loader; bring the mask to CPU to match it.
                wrong = (predictions != labels).cpu()
                indices.extend(index[wrong].tolist())
    return indices

In [None]:
import random
from typing import Iterator, List
import numpy as np
from torch.utils.data import Sampler
class CustomIndicesSampler(Sampler[int]):
    """
    Samples from the specified indices (pass indices - upsampled, downsampled, group balanced etc. to this class)
    Default is no shuffle.
    """
    def __init__(
        self,
        indices: List[int],
        shuffle: bool = False,
    ):
        """
        Samples elements from the specified indices.

        :param indices: The list of indices to sample from. It is never modified.
        :type indices: list[int]
        :param shuffle: Whether to shuffle the indices. Default is False.
        :type shuffle: bool, optional
        """
        self.indices = indices
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        """
        Returns an iterator over the sampled indices.

        With ``shuffle=True`` every call yields a fresh random permutation.

        :return: An iterator over the sampled indices.
        :rtype: iterator[int]
        """
        if not self.shuffle:
            return iter(self.indices)
        # Shuffle a private copy: random.shuffle works in place, and shuffling
        # self.indices directly reordered the list object the caller passed in.
        order = list(self.indices)
        random.shuffle(order)
        return iter(order)

    def __len__(self) -> int:
        """
        Returns the number of sampled indices.

        :return: The number of sampled indices.
        :rtype: int
        """
        return len(self.indices)

In [None]:
def create_upsample_dataloader(old_dataset, batch_size, error_indices, E):
    """Build a shuffled loader in which the error examples are over-represented.

    Every example of ``old_dataset`` appears once, and each entry of
    ``error_indices`` appears ``E`` additional times. The loader works on a deep
    copy of ``old_dataset``.
    """
    every_index = list(range(len(old_dataset)))
    repeated_errors = E * error_indices
    sampler = CustomIndicesSampler(every_index + repeated_errors, True)
    dataset_copy = copy.deepcopy(old_dataset)
    return DataLoader(dataset_copy, batch_size, sampler=sampler)

In [None]:
# JTT stage 1: the training examples the baseline still gets wrong form the error set.
error_indices = generate_upsample_indices(baseline_model, train_loader)
print(f"Number of train examples wrong after 3 epochs of the base line:{len(error_indices)}")
# JTT stage 2 input: the full training set plus 3 extra copies of every error example.
upsampled_loader = create_upsample_dataloader(
    trainset, batch_size=128, error_indices=error_indices, E=3
)

In [None]:
# JTT stage 2: train the untouched copy of the initial model on the upsampled loader
# for 3 epochs (Adam, lr = 1e-5). `scheduler` and `device` come from the baseline cell.
optim = Adam(jtt_model.parameters(), lr = 1e-5)
# The returned (loss_log, accuracy_log) is not stored; as the cell's last expression it is displayed.
train(jtt_model, 3, upsampled_loader, dev_loader, optim, scheduler, device)

In [None]:
# Group-wise, then overall, test accuracy of the JTT model.
accuracies = get_group_accuracy(jtt_model, testset, device, 128)
print(f"Test accuracy of the jtt : {get_accuracy(jtt_model, test_loader, device)}")

In [None]:
# Persist the JTT weights (state_dict only) in the working directory.
save_model(jtt_model, "jtt.pt")

In [None]:
# SPARE stage 1: train a copy of the initial model for a single epoch; its outputs
# on the training set are clustered in the cells below.
optim = Adam(spare_model.parameters(), lr = 1e-5)
# The returned (loss_log, accuracy_log) is not stored; as the cell's last expression it is displayed.
train(spare_model, 1, train_loader, dev_loader, optim, scheduler,device)

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
spare_dataset = copy.deepcopy(trainset)
# shuffle=False: the loader walks the copy in dataset order, so the rows of Z,
# the entries of Labels and the entries of Indices all line up.
spare_loader = DataLoader(spare_dataset, batch_size = 128, shuffle = False)
spare_model.eval()

# Collect the model's logits (Z), class labels and dataset indices for every training example.
z_chunks = []
Labels = []
Indices = []
with torch.no_grad():
    for batch, index in tqdm(spare_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Labels are not passed to the model: only the logits are needed, so there
        # is no reason to make the forward pass compute a loss.
        z = spare_model(input_ids, attention_mask = attention_mask).logits
        z_chunks.append(z.cpu())
        Labels += batch['labels'].tolist()
        Indices += index.tolist()
# Concatenate once at the end; calling torch.cat inside the loop re-copies every
# previously collected row on each batch. Z stays None if the loader was empty.
Z = torch.cat(z_chunks, dim=0) if z_chunks else None

In [None]:
# Persist the weights of the one-epoch model whose outputs are used for clustering.
save_model(spare_model, "spare_partition.pt")

In [None]:
from spuco.group_inference.spare_inference import SpareInference
from spuco.group_inference.cluster import ClusterAlg
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Set up SPARE group inference on the collected logits (Z) and class labels, using KMeans.
# NOTE(review): presumably max_clusters=20 caps the number of clusters tried — confirm against the SpuCo docs.
inferer = SpareInference(Z= Z, class_labels = Labels, cluster_alg= ClusterAlg.KMEANS, max_clusters = 20, device = device, verbose = False)

In [None]:
# Run the group inference. The next cell reads groups[0] as a mapping whose values are lists of example positions.
# NOTE(review): the exact return structure of infer_groups is not visible here — confirm against the installed SpuCo version.
groups = inferer.infer_groups()

In [None]:
from pprint import pprint
# Per-example sampling weights for SPARE: an example in a cluster of size V gets
# weight (1/V) ** factor, so small clusters are drawn more often.
# The weights are left unnormalised; WeightedRandomSampler does not need them to sum to one.
p = [0] * len(Indices)
factor = 1
# NOTE(review): groups[0] is assumed to map each cluster key to the list of example
# positions in that cluster — confirm against the installed SpuCo version.
for indices in groups[0].values():
    V = len(indices)
    if V == 0:
        # An empty cluster has no examples to weight (and 1/V would divide by zero).
        continue
    # The original computed w / (w**factor * V) with w = 1/V, which equals the
    # intended (1/V) ** factor only when factor == 1; identical for the value used here.
    w_lambda = (1 / V) ** factor
    for indice in indices:
        p[indice] = w_lambda

In [None]:
from collections import Counter
# Count how many examples share each sampling weight.
print(Counter(p))

In [None]:
from torch.utils.data import WeightedRandomSampler
# SPARE stage 2: draw len(trainset) examples per epoch, with replacement, according to the weights p.
spare_sampler = WeightedRandomSampler(p,len(trainset),replacement=True)
# Positional DataLoader arguments: dataset, batch_size=128, shuffle=False, sampler=spare_sampler.
# This rebinds spare_loader, which earlier held the in-order loader used for feature extraction.
spare_loader = DataLoader(trainset,128,False,spare_sampler)

# The string literal below is disabled code kept as a bare expression; it has no effect.
""""original_model = AutoModelForSequenceClassification.from_pretrained(
    'CAMeL-Lab/bert-base-arabic-camelbert-mix', num_labels=6, id2label=id2label, label2id=label2id
)"""
# Train the untouched copy of the initial model on the reweighted loader: 3 epochs, Adam, lr = 1e-5, no scheduler.
optim = Adam(spare_partition_model.parameters(), lr = 1e-5)
# The returned (loss_log, accuracy_log) is not stored; as the cell's last expression it is displayed.
train(spare_partition_model,3,spare_loader,dev_loader,optim, None,device)


In [None]:
# Overall test accuracy of the SPARE-trained model, then persist its weights.
print(f"Spare final accuracy:{get_accuracy(spare_partition_model, test_loader, device)}")

save_model(spare_partition_model, "./spare_final.pt")








In [None]:
# Disabled scratch check: single-example inference with the un-fine-tuned `model`.
# It is wrapped in a string literal, so none of it runs; as the cell's only
# expression, the string itself is what the cell displays.
'''
inputs = tokenizer(test_df['text'].iloc[0], return_tensors="pt")
print(test_df['text'].iloc[0])
import torch
with torch.no_grad():
    logits = model(**inputs).logits
    print(torch.argmax(logits))
print(test_df['dialect'].iloc[0])
'''