In [1]:
from data.Nystromformer.LRA.code import lra_config
from data.Nystromformer.LRA.code.dataset import LRADataset
#from Nystromformer.LRA.code.run_tasks import training_config
from torch.utils.data import DataLoader, RandomSampler
from torch import nn
import torch.nn.functional as F

piaynTaskDataDir = "data/"
piaynTaskModelDir = "data/"
task = "retrieval"

#get training config
training_config = lra_config.config[task]["training"]

#Check Train Config
print('Training Config: ', training_config)

#get pre-defined model config
model_config = lra_config.config[task]['model']

#Check model Config
print('Model Config: ', model_config)

#Get the dataset
train_dataset = LRADataset(piaynTaskDataDir + f"/{task}/{task}.train.pickle", True)
val_dataset = LRADataset(piaynTaskDataDir + f"/{task}/{task}.dev.pickle", False)
test_dataset = LRADataset(piaynTaskDataDir + f"/{task}/{task}.test.pickle", False)

#Create DataLoader iterators
ds_iter = {
    "train":enumerate(DataLoader(train_dataset, 
                                 batch_size = training_config["batch_size"], 
                                 drop_last = True)),
    "dev":enumerate(DataLoader(val_dataset, batch_size = 1, drop_last = True)),
    "test":enumerate(DataLoader(test_dataset, batch_size = 1, drop_last = True)),
}



Training Config:  {'batch_size': 32, 'learning_rate': 0.0001, 'warmup': 800, 'lr_decay': 'linear', 'weight_decay': 0, 'eval_frequency': 300, 'num_train_steps': 30000, 'num_eval_steps': 565}
Model Config:  {'learn_pos_emb': True, 'tied_weights': False, 'embedding_dim': 64, 'transformer_dim': 64, 'transformer_hidden_dim': 128, 'head_dim': 32, 'num_head': 2, 'num_layers': 2, 'vocab_size': 512, 'max_seq_len': 4000, 'dropout_prob': 0.1, 'attention_dropout': 0.1, 'pooling_mode': 'MEAN', 'num_classes': 2}
Loaded data//retrieval/retrieval.train.pickle... size=147086
Loaded data//retrieval/retrieval.dev.pickle... size=18090
Loaded data//retrieval/retrieval.test.pickle... size=17437


In [2]:
#Check sizes of batches
batch = next((ds_iter['train']))
for k,v in batch[1].items():
  print(k,v.shape)

input_ids_0 torch.Size([32, 4096])
mask_0 torch.Size([32, 4096])
input_ids_1 torch.Size([32, 4096])
mask_1 torch.Size([32, 4096])
label torch.Size([32])


In [3]:
from transformers import PerceiverForSequenceClassification, PerceiverForMaskedLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device('cpu')

In [4]:
from transformers import PerceiverConfig
#get default perceiver config
configuration = PerceiverConfig()

#Print Updated Perceiver Configuration
print(configuration)

PerceiverConfig {
  "attention_probs_dropout_prob": 0.1,
  "audio_samples_per_frame": 1920,
  "cross_attention_shape_for_attention": "kv",
  "cross_attention_widening_factor": 1,
  "d_latents": 1280,
  "d_model": 768,
  "hidden_act": "gelu",
  "image_size": 56,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 2048,
  "model_type": "perceiver",
  "num_blocks": 1,
  "num_cross_attention_heads": 8,
  "num_frames": 16,
  "num_latents": 256,
  "num_self_attends_per_block": 26,
  "num_self_attention_heads": 8,
  "output_shape": [
    1,
    16,
    224,
    224
  ],
  "qk_channels": null,
  "samples_per_patch": 16,
  "self_attention_widening_factor": 1,
  "train_size": [
    368,
    496
  ],
  "transformers_version": "4.16.2",
  "use_query_residual": true,
  "v_channels": null,
  "vocab_size": 262
}



In [5]:
configuration.num_labels = 2
configuration.num_self_attends_per_block = 5
configuration.d_latents = 512
configuration.d_model = 512
configuration.max_position_embeddings = 4097*2
configuration.num_cross_attention_heads = 4
configuration.num_self_attention_heads = 4
configuration.num_latents = 512
configuration.vocab_size = 512

In [6]:
def append_cls(inp, mask, vocab_size):
    batch_size = inp.size(0)
    cls_id = ((vocab_size - 1) * torch.ones(batch_size, dtype=inp.dtype, device=inp.device))#.long()
    cls_mask = torch.ones(batch_size, dtype=mask.dtype, device=mask.device)
    # inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1)
    # mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1)
    inp = torch.cat([cls_id[:, None], inp], dim=-1)
    mask = torch.cat([cls_mask[:, None], mask], dim=-1)
    return inp, mask

In [7]:
class Retrieval(nn.Module):
    def __init__(self,config):
        super(Retrieval, self).__init__()
        self.config = config
        self.perceiver = PerceiverForSequenceClassification(config = config)

    def forward(self,input_ids_0, mask_0, input_ids_1, mask_1):
        inputs = torch.cat((input_ids_0, input_ids_1), axis = -1)
        mask = torch.cat((mask_0, mask_1), axis = -1)
        #print(input_ids_0.shape, mask_0.shape)
        o1 = self.perceiver(inputs = inputs, attention_mask = mask)
        return o1.logits

In [8]:
retriever = Retrieval(configuration)
device

device(type='cuda')

In [9]:
pytorch_total_params = sum(p.numel() for p in retriever.parameters())
pytorch_total_params_Trainable = sum(p.numel() for p in retriever.parameters() if p.requires_grad)
print('Total Parameters: ', pytorch_total_params, '\nTrainable Parameters: ', pytorch_total_params_Trainable)  

Total Parameters:  15769090 
Trainable Parameters:  15769090


In [10]:
input_ids_0 = batch[1]["input_ids_0"].to(device)
mask_0 = batch[1]["mask_0"].to(device)
input_ids_1 = batch[1]["input_ids_0"].to(device)
mask_1 = batch[1]["mask_0"].to(device)
labels = batch[1]["label"].to(device)
retriever = retriever.to(device)

In [11]:
outputs = retriever(input_ids_0,mask_0,input_ids_1,mask_1)

In [15]:
 next(ds_iter['train'])

(738,
 {'input_ids_0': tensor([[99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          ...,
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0]]),
  'mask_0': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.]]),
  'input_ids_1': tensor([[99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 35, 50,  ...,  0,  0,  0],
          ...,
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0]]),
  'mask_1': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          .

In [None]:
from torch.optim import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score
from datasets import load_metric
import pandas as pd
from torch.nn import CrossEntropyLoss

best_score = 0 
prev_score = 0
# maxPatience = 3
# currentPatience = 0

#steps = int(training_config["num_train_steps"]/20000)
steps = 5000
loss_fn = CrossEntropyLoss()

optimizer = AdamW(retriever.parameters(), 
                  lr = 0.5,
                  betas = (0.9, 0.999), 
                  eps = 1e-6, 
                  weight_decay = training_config["weight_decay"])

lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer = optimizer,
    max_lr = training_config["learning_rate"],
    #max_lr = 0.5,
    pct_start = training_config["warmup"] / training_config["num_train_steps"],
    #pct_start = training_config["warmup"] / 5000,
    anneal_strategy = training_config["lr_decay"],
    total_steps = training_config["num_train_steps"],
    #verbose=True
)

#amp_scaler = torch.cuda.amp.GradScaler() if model_config["mixed_precision"] else None

#initialize training summary
trainingSummary = pd.DataFrame(columns=['step', 'train_acc', 'val_acc'])


retriever.to(device)

#initialize training accuracy metric
train_accuracy = load_metric("accuracy")
#batch = next(ds_iter['train'])[1]

for step in tqdm(range(30000)):  # Perform gradient updates for multiple steps
    
    #model.train()
    retriever.train()
    
    #print("Step:", step)
    #for batch in tqdm(train_dataloader):
    batch = next(ds_iter['train'])[1]

    # get the inputs; 
    input_ids_0 = batch["input_ids_0"].to(device)
    mask_0 = batch["mask_0"].to(device)
    input_ids_1 = batch["input_ids_0"].to(device)
    mask_1 = batch["mask_0"].to(device)
    labels = batch["label"].to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = retriever(input_ids_0,mask_0,input_ids_1,mask_1)
    loss = loss_fn(outputs,labels)
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

    # evaluate
    predictions = outputs.argmax(-1).cpu().detach().numpy()
    accuracy = accuracy_score(y_true=batch["label"].numpy(), y_pred=predictions)
    references = batch["label"].numpy()
    train_accuracy.add_batch(predictions=predictions, references=references)
    
    if (step+1)%50  == 0:
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

    #delete intermediate variables to free up GPU space
    del loss, outputs, input_ids_0, mask_0, labels, predictions, accuracy,input_ids_1,mask_1


    #Every 1000 steps validate and save model
    if (step+1)%training_config['eval_frequency']  == 0:
    #if (step+1)%2  == 0:
        #model.eval()
        retriever.eval()
        val_accuracy = load_metric("accuracy")

      #reset dev iterator
        ds_iter['dev'] = enumerate(DataLoader(val_dataset, batch_size = 32, drop_last = True))

        with torch.no_grad():
            for i, batch in tqdm(ds_iter['dev']):
                input_ids_0 = batch["input_ids_0"].to(device)
                mask_0 = batch["mask_0"].to(device)
                input_ids_1 = batch["input_ids_0"].to(device)
                mask_1 = batch["mask_0"].to(device)
                labels = batch["label"].to(device)

              # forward pass
                logits = retriever(input_ids_0,mask_0,input_ids_1,mask_1)
                predictions = logits.argmax(-1).cpu().detach().numpy()
                references = batch["label"].numpy()
                val_accuracy.add_batch(predictions=predictions, references=references)

          #delete intermediate variables to free up GPU space
                del logits, input_ids_0, mask_0, input_ids_1, mask_1, labels, predictions, references
      
      #Compute val accuracy
        final_val_score = val_accuracy.compute()['accuracy']
        print("Validation Accuracy:", final_val_score)

        if final_val_score >= best_score:
            best_score = final_val_score
            torch.save(retriever.to('cpu').state_dict(), piaynTaskModelDir + '/trainedPerceiverClassifierToken'+'.pkl')
            retriever.to(device)
        else:
            pass  

#         if final_val_score <= prev_score:
#             currentPatience += 1
#             if currentPatience >= maxPatience:
#                 print('Patience Limit reached! Stopping early!')
#                 torch.save(retriever.to('cpu').state_dict(), piaynTaskModelDir + '/trainedPerceiverClassifierStep_' + str(step + 1) + 'Token' + '.pkl')
#                 break  
#         else:
#             currentPatience = 0
      
      #Update prev_score
        prev_score = final_val_score

      #Compute training accuracy till now
        final_train_score = train_accuracy.compute()['accuracy']

      #Add to trainingSummary
        trainingSummary.loc[len(trainingSummary.index)] = [step+1, final_train_score, final_val_score]

  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)
  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)


  0%|          | 0/30000 [00:00<?, ?it/s]

Loss: 0.6894859075546265, Accuracy: 0.5625
Loss: 0.6942472457885742, Accuracy: 0.5
Loss: 0.6829642057418823, Accuracy: 0.59375
Loss: 0.6902452707290649, Accuracy: 0.5625
Loss: 0.7284085154533386, Accuracy: 0.3125
Loss: 0.6942605376243591, Accuracy: 0.5


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6912367343902588, Accuracy: 0.53125
Loss: 0.6914761066436768, Accuracy: 0.53125
Loss: 0.6933212280273438, Accuracy: 0.5
Loss: 0.6884312629699707, Accuracy: 0.65625
Loss: 0.6809334754943848, Accuracy: 0.59375
Loss: 0.6923745274543762, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6931503415107727, Accuracy: 0.5
Loss: 0.709743857383728, Accuracy: 0.4375
Loss: 0.69174724817276, Accuracy: 0.53125
Loss: 0.6836007833480835, Accuracy: 0.65625
Loss: 0.6919981241226196, Accuracy: 0.53125
Loss: 0.692919135093689, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.504424778761062
Loss: 0.7228075265884399, Accuracy: 0.46875
Loss: 0.7171154618263245, Accuracy: 0.46875
Loss: 0.6969727873802185, Accuracy: 0.4375
Loss: 0.6959071159362793, Accuracy: 0.46875
Loss: 0.7001201510429382, Accuracy: 0.40625
Loss: 0.695033848285675, Accuracy: 0.46875


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6882789134979248, Accuracy: 0.5625
Loss: 0.6954174637794495, Accuracy: 0.4375
Loss: 0.6823509931564331, Accuracy: 0.625
Loss: 0.6915706396102905, Accuracy: 0.65625
Loss: 0.6902466416358948, Accuracy: 0.59375
Loss: 0.7049723863601685, Accuracy: 0.40625


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.7015002965927124, Accuracy: 0.375
Loss: 0.6903893351554871, Accuracy: 0.59375
Loss: 0.6988027691841125, Accuracy: 0.5
Loss: 0.7059274315834045, Accuracy: 0.46875
Loss: 0.7049849033355713, Accuracy: 0.40625
Loss: 0.6881359815597534, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6863153576850891, Accuracy: 0.5625
Loss: 0.693158745765686, Accuracy: 0.5
Loss: 0.6979196071624756, Accuracy: 0.46875
Loss: 0.693934440612793, Accuracy: 0.46875
Loss: 0.6954904198646545, Accuracy: 0.46875
Loss: 0.6932809352874756, Accuracy: 0.5


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6970162987709045, Accuracy: 0.5
Loss: 0.6942266821861267, Accuracy: 0.4375
Loss: 0.6778258681297302, Accuracy: 0.65625
Loss: 0.7173320651054382, Accuracy: 0.25
Loss: 0.6901180744171143, Accuracy: 0.625
Loss: 0.693069577217102, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.7008235454559326, Accuracy: 0.4375
Loss: 0.6933209896087646, Accuracy: 0.5
Loss: 0.6974654197692871, Accuracy: 0.40625
Loss: 0.6952540874481201, Accuracy: 0.46875
Loss: 0.6857843399047852, Accuracy: 0.5625
Loss: 0.6936303377151489, Accuracy: 0.5


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6921275854110718, Accuracy: 0.53125
Loss: 0.6918521523475647, Accuracy: 0.53125
Loss: 0.6933379769325256, Accuracy: 0.46875
Loss: 0.6895410418510437, Accuracy: 0.59375
Loss: 0.6913453936576843, Accuracy: 0.5625
Loss: 0.6942559480667114, Accuracy: 0.5


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6935173273086548, Accuracy: 0.5
Loss: 0.6897783279418945, Accuracy: 0.6875
Loss: 0.691389799118042, Accuracy: 0.53125
Loss: 0.69449383020401, Accuracy: 0.5
Loss: 0.6953704953193665, Accuracy: 0.40625
Loss: 0.6916288137435913, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6935288310050964, Accuracy: 0.5
Loss: 0.706421434879303, Accuracy: 0.4375
Loss: 0.6861499547958374, Accuracy: 0.59375
Loss: 0.6957389116287231, Accuracy: 0.40625
Loss: 0.693561315536499, Accuracy: 0.5
Loss: 0.6952434778213501, Accuracy: 0.46875


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6931106448173523, Accuracy: 0.5625
Loss: 0.6960961222648621, Accuracy: 0.46875
Loss: 0.6964335441589355, Accuracy: 0.375
Loss: 0.6858085989952087, Accuracy: 0.59375
Loss: 0.6895374059677124, Accuracy: 0.5625
Loss: 0.691241979598999, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6849003434181213, Accuracy: 0.65625
Loss: 0.6934799551963806, Accuracy: 0.5
Loss: 0.6989582180976868, Accuracy: 0.5
Loss: 0.6775001287460327, Accuracy: 0.71875
Loss: 0.6958803534507751, Accuracy: 0.4375
Loss: 0.6989479660987854, Accuracy: 0.40625


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6949908137321472, Accuracy: 0.4375
Loss: 0.6931486129760742, Accuracy: 0.5
Loss: 0.6999830603599548, Accuracy: 0.40625
Loss: 0.6950840353965759, Accuracy: 0.40625
Loss: 0.6958957314491272, Accuracy: 0.40625
Loss: 0.6926968693733215, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.7033199071884155, Accuracy: 0.375
Loss: 0.6916803121566772, Accuracy: 0.5625
Loss: 0.6931514739990234, Accuracy: 0.5
Loss: 0.6924217343330383, Accuracy: 0.625
Loss: 0.6931670904159546, Accuracy: 0.5
Loss: 0.6923587918281555, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6916062235832214, Accuracy: 0.53125
Loss: 0.6981903910636902, Accuracy: 0.40625
Loss: 0.6914511919021606, Accuracy: 0.59375
Loss: 0.6931478381156921, Accuracy: 0.5
Loss: 0.6946742534637451, Accuracy: 0.4375
Loss: 0.6955200433731079, Accuracy: 0.4375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.697178304195404, Accuracy: 0.375
Loss: 0.6971233487129211, Accuracy: 0.46875
Loss: 0.6947638392448425, Accuracy: 0.4375
Loss: 0.6803574562072754, Accuracy: 0.6875
Loss: 0.6921864748001099, Accuracy: 0.53125
Loss: 0.6879284977912903, Accuracy: 0.625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6877775192260742, Accuracy: 0.71875
Loss: 0.7077734470367432, Accuracy: 0.4375
Loss: 0.6931887269020081, Accuracy: 0.5
Loss: 0.7021341919898987, Accuracy: 0.28125
Loss: 0.6884321570396423, Accuracy: 0.625
Loss: 0.6978330016136169, Accuracy: 0.40625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6940386891365051, Accuracy: 0.46875
Loss: 0.6947165727615356, Accuracy: 0.46875
Loss: 0.6967348456382751, Accuracy: 0.34375
Loss: 0.690426230430603, Accuracy: 0.65625
Loss: 0.6925846934318542, Accuracy: 0.59375
Loss: 0.6898856163024902, Accuracy: 0.59375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.678622841835022, Accuracy: 0.6875
Loss: 0.6908315420150757, Accuracy: 0.59375
Loss: 0.6962236762046814, Accuracy: 0.4375
Loss: 0.6939283609390259, Accuracy: 0.34375
Loss: 0.6881601214408875, Accuracy: 0.59375
Loss: 0.6936280727386475, Accuracy: 0.46875


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6939078569412231, Accuracy: 0.4375
Loss: 0.693242609500885, Accuracy: 0.5
Loss: 0.6883066892623901, Accuracy: 0.75
Loss: 0.6914730072021484, Accuracy: 0.53125
Loss: 0.6992104053497314, Accuracy: 0.28125
Loss: 0.6931787729263306, Accuracy: 0.5


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6909839510917664, Accuracy: 0.5625
Loss: 0.6973737478256226, Accuracy: 0.46875
Loss: 0.6687399744987488, Accuracy: 0.625
Loss: 0.6932569742202759, Accuracy: 0.5
Loss: 0.6933566927909851, Accuracy: 0.4375
Loss: 0.6925871968269348, Accuracy: 0.625


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6930398941040039, Accuracy: 0.53125
Loss: 0.6932182312011719, Accuracy: 0.5
Loss: 0.689582109451294, Accuracy: 0.5625
Loss: 0.6889446377754211, Accuracy: 0.59375
Loss: 0.6976940631866455, Accuracy: 0.375
Loss: 0.6918294429779053, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.691098690032959, Accuracy: 0.59375
Loss: 0.6914541125297546, Accuracy: 0.5625
Loss: 0.6867204904556274, Accuracy: 0.65625
Loss: 0.6892856955528259, Accuracy: 0.59375
Loss: 0.694159984588623, Accuracy: 0.40625
Loss: 0.6981711983680725, Accuracy: 0.375


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6948592662811279, Accuracy: 0.46875
Loss: 0.6911290884017944, Accuracy: 0.59375
Loss: 0.6939218044281006, Accuracy: 0.5
Loss: 0.6958126425743103, Accuracy: 0.4375
Loss: 0.6948192715644836, Accuracy: 0.4375
Loss: 0.6927306652069092, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6942331790924072, Accuracy: 0.4375
Loss: 0.6914708018302917, Accuracy: 0.59375
Loss: 0.692964494228363, Accuracy: 0.65625
Loss: 0.6926513314247131, Accuracy: 0.53125
Loss: 0.6931787133216858, Accuracy: 0.5
Loss: 0.696974515914917, Accuracy: 0.40625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6932896375656128, Accuracy: 0.5
Loss: 0.6911075115203857, Accuracy: 0.5625
Loss: 0.6879859566688538, Accuracy: 0.71875
Loss: 0.6919229030609131, Accuracy: 0.53125
Loss: 0.6946833729743958, Accuracy: 0.4375
Loss: 0.6922845840454102, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6920808553695679, Accuracy: 0.5625
Loss: 0.6937074065208435, Accuracy: 0.46875
Loss: 0.6913012266159058, Accuracy: 0.5625
Loss: 0.6891766786575317, Accuracy: 0.6875
Loss: 0.6932665109634399, Accuracy: 0.5
Loss: 0.6911958456039429, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6897904872894287, Accuracy: 0.59375
Loss: 0.6904442310333252, Accuracy: 0.59375
Loss: 0.6910279393196106, Accuracy: 0.59375
Loss: 0.6964081525802612, Accuracy: 0.375
Loss: 0.6923608183860779, Accuracy: 0.53125
Loss: 0.6953025460243225, Accuracy: 0.4375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6944995522499084, Accuracy: 0.4375
Loss: 0.693386435508728, Accuracy: 0.4375
Loss: 0.6936726570129395, Accuracy: 0.46875
Loss: 0.6952540874481201, Accuracy: 0.40625
Loss: 0.6920081973075867, Accuracy: 0.53125
Loss: 0.6922705769538879, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6926254630088806, Accuracy: 0.53125
Loss: 0.6939694285392761, Accuracy: 0.46875
Loss: 0.6965162754058838, Accuracy: 0.34375
Loss: 0.6945412755012512, Accuracy: 0.4375
Loss: 0.6921294927597046, Accuracy: 0.5625
Loss: 0.6923344731330872, Accuracy: 0.59375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6870884299278259, Accuracy: 0.625
Loss: 0.6912479400634766, Accuracy: 0.5625
Loss: 0.6942313313484192, Accuracy: 0.40625
Loss: 0.6954378485679626, Accuracy: 0.34375
Loss: 0.6942157745361328, Accuracy: 0.4375
Loss: 0.6919301748275757, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6942014694213867, Accuracy: 0.46875
Loss: 0.6933043599128723, Accuracy: 0.4375
Loss: 0.695726752281189, Accuracy: 0.46875
Loss: 0.6953563690185547, Accuracy: 0.375
Loss: 0.6914878487586975, Accuracy: 0.5625
Loss: 0.6922663450241089, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6904219388961792, Accuracy: 0.6875
Loss: 0.692076563835144, Accuracy: 0.5625
Loss: 0.6927430033683777, Accuracy: 0.53125
Loss: 0.6931776404380798, Accuracy: 0.5
Loss: 0.694760262966156, Accuracy: 0.40625
Loss: 0.6956315040588379, Accuracy: 0.375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6892072558403015, Accuracy: 0.65625
Loss: 0.6939964890480042, Accuracy: 0.46875
Loss: 0.6914312839508057, Accuracy: 0.59375
Loss: 0.692556619644165, Accuracy: 0.53125
Loss: 0.6910733580589294, Accuracy: 0.59375
Loss: 0.6924347281455994, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6941371560096741, Accuracy: 0.46875
Loss: 0.6959260702133179, Accuracy: 0.4375
Loss: 0.6880304217338562, Accuracy: 0.59375
Loss: 0.6914590001106262, Accuracy: 0.59375
Loss: 0.6932057738304138, Accuracy: 0.5
Loss: 0.6931828260421753, Accuracy: 0.5


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6945599317550659, Accuracy: 0.46875
Loss: 0.6941676139831543, Accuracy: 0.4375
Loss: 0.6965134143829346, Accuracy: 0.4375
Loss: 0.6931743621826172, Accuracy: 0.5
Loss: 0.6974876523017883, Accuracy: 0.46875
Loss: 0.694440484046936, Accuracy: 0.375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6911265850067139, Accuracy: 0.59375
Loss: 0.6912210583686829, Accuracy: 0.625
Loss: 0.6943594217300415, Accuracy: 0.40625
Loss: 0.6924551725387573, Accuracy: 0.53125
Loss: 0.7009800672531128, Accuracy: 0.375
Loss: 0.6939242482185364, Accuracy: 0.46875


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6902979016304016, Accuracy: 0.5625
Loss: 0.6917434334754944, Accuracy: 0.59375
Loss: 0.6927164196968079, Accuracy: 0.53125
Loss: 0.69317227602005, Accuracy: 0.5
Loss: 0.6934713125228882, Accuracy: 0.46875
Loss: 0.6919979453086853, Accuracy: 0.59375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6920130252838135, Accuracy: 0.59375
Loss: 0.6919927000999451, Accuracy: 0.59375
Loss: 0.6924225687980652, Accuracy: 0.5625
Loss: 0.6931713819503784, Accuracy: 0.5
Loss: 0.6960423588752747, Accuracy: 0.4375
Loss: 0.6927924752235413, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.49551991150442476
Loss: 0.6909798383712769, Accuracy: 0.625
Loss: 0.6973752975463867, Accuracy: 0.46875
Loss: 0.6933743953704834, Accuracy: 0.5
Loss: 0.7322176098823547, Accuracy: 0.21875
Loss: 0.6925228834152222, Accuracy: 0.53125
Loss: 0.6944824457168579, Accuracy: 0.4375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6905867457389832, Accuracy: 0.65625
Loss: 0.6920683979988098, Accuracy: 0.5625
Loss: 0.6951798796653748, Accuracy: 0.4375
Loss: 0.6924846172332764, Accuracy: 0.53125
Loss: 0.6922504305839539, Accuracy: 0.53125
Loss: 0.6935672163963318, Accuracy: 0.46875


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6932013034820557, Accuracy: 0.5
Loss: 0.6933108568191528, Accuracy: 0.5
Loss: 0.695416271686554, Accuracy: 0.375
Loss: 0.6908019781112671, Accuracy: 0.625
Loss: 0.6918061375617981, Accuracy: 0.5625
Loss: 0.6926124691963196, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6931886672973633, Accuracy: 0.5
Loss: 0.6936745047569275, Accuracy: 0.46875
Loss: 0.6923848986625671, Accuracy: 0.5625
Loss: 0.6928015947341919, Accuracy: 0.53125
Loss: 0.6946248412132263, Accuracy: 0.40625
Loss: 0.6931747794151306, Accuracy: 0.5


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6931852102279663, Accuracy: 0.5
Loss: 0.6931855082511902, Accuracy: 0.5
Loss: 0.6931942701339722, Accuracy: 0.5
Loss: 0.6885048151016235, Accuracy: 0.75
Loss: 0.6948021650314331, Accuracy: 0.4375
Loss: 0.6963645815849304, Accuracy: 0.375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6958565711975098, Accuracy: 0.375
Loss: 0.6938721537590027, Accuracy: 0.46875
Loss: 0.696726381778717, Accuracy: 0.40625
Loss: 0.6934390068054199, Accuracy: 0.5
Loss: 0.6952453851699829, Accuracy: 0.40625
Loss: 0.6932610869407654, Accuracy: 0.5


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6970258951187134, Accuracy: 0.4375
Loss: 0.6846598982810974, Accuracy: 0.59375
Loss: 0.6913084983825684, Accuracy: 0.5625
Loss: 0.6940262317657471, Accuracy: 0.46875
Loss: 0.692359983921051, Accuracy: 0.53125
Loss: 0.6913326382637024, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6947126984596252, Accuracy: 0.4375
Loss: 0.6907718181610107, Accuracy: 0.59375
Loss: 0.6915462613105774, Accuracy: 0.5625
Loss: 0.6968991756439209, Accuracy: 0.34375
Loss: 0.6968409419059753, Accuracy: 0.34375
Loss: 0.696010172367096, Accuracy: 0.375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6921600103378296, Accuracy: 0.5625
Loss: 0.6928368806838989, Accuracy: 0.5625
Loss: 0.6925917267799377, Accuracy: 0.53125
Loss: 0.6955295205116272, Accuracy: 0.40625
Loss: 0.6915761828422546, Accuracy: 0.59375
Loss: 0.6936566829681396, Accuracy: 0.4375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6968151330947876, Accuracy: 0.34375
Loss: 0.6910475492477417, Accuracy: 0.5625
Loss: 0.6918935775756836, Accuracy: 0.5625
Loss: 0.693190336227417, Accuracy: 0.5
Loss: 0.6953741908073425, Accuracy: 0.375
Loss: 0.6942146420478821, Accuracy: 0.4375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6932033896446228, Accuracy: 0.5
Loss: 0.691954493522644, Accuracy: 0.5625
Loss: 0.6925545334815979, Accuracy: 0.53125
Loss: 0.6897374987602234, Accuracy: 0.6875
Loss: 0.6931894421577454, Accuracy: 0.5
Loss: 0.6925976872444153, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6984431147575378, Accuracy: 0.3125
Loss: 0.6898072957992554, Accuracy: 0.5625
Loss: 0.6962985992431641, Accuracy: 0.40625
Loss: 0.6944954991340637, Accuracy: 0.4375
Loss: 0.6964917778968811, Accuracy: 0.375
Loss: 0.6920443773269653, Accuracy: 0.5625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6905211806297302, Accuracy: 0.625
Loss: 0.6920495629310608, Accuracy: 0.5625
Loss: 0.6942465305328369, Accuracy: 0.4375
Loss: 0.6951526999473572, Accuracy: 0.375
Loss: 0.6904798150062561, Accuracy: 0.6875
Loss: 0.6925609111785889, Accuracy: 0.53125


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6926331520080566, Accuracy: 0.53125
Loss: 0.6942726969718933, Accuracy: 0.4375
Loss: 0.6943366527557373, Accuracy: 0.4375
Loss: 0.6938261985778809, Accuracy: 0.46875
Loss: 0.6914209127426147, Accuracy: 0.59375
Loss: 0.6914165019989014, Accuracy: 0.59375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6938925981521606, Accuracy: 0.46875
Loss: 0.6938526034355164, Accuracy: 0.46875
Loss: 0.6925069093704224, Accuracy: 0.53125
Loss: 0.6900810599327087, Accuracy: 0.625
Loss: 0.6947386860847473, Accuracy: 0.4375
Loss: 0.6887503862380981, Accuracy: 0.65625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6959651708602905, Accuracy: 0.4375
Loss: 0.6932454109191895, Accuracy: 0.5
Loss: 0.6915376782417297, Accuracy: 0.5625
Loss: 0.6935664415359497, Accuracy: 0.46875
Loss: 0.6940842866897583, Accuracy: 0.46875
Loss: 0.7145248055458069, Accuracy: 0.46875


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6933121085166931, Accuracy: 0.5
Loss: 0.6954265832901001, Accuracy: 0.46875
Loss: 0.6914814114570618, Accuracy: 0.5625
Loss: 0.6901326775550842, Accuracy: 0.59375
Loss: 0.689624547958374, Accuracy: 0.625
Loss: 0.6943284273147583, Accuracy: 0.46875


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6908072233200073, Accuracy: 0.59375
Loss: 0.6924492716789246, Accuracy: 0.53125
Loss: 0.6925531625747681, Accuracy: 0.53125
Loss: 0.6919513940811157, Accuracy: 0.5625
Loss: 0.6925630569458008, Accuracy: 0.53125
Loss: 0.6949936151504517, Accuracy: 0.40625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.694333553314209, Accuracy: 0.4375
Loss: 0.6914123296737671, Accuracy: 0.59375
Loss: 0.6931854486465454, Accuracy: 0.5
Loss: 0.6951104402542114, Accuracy: 0.40625
Loss: 0.6936957836151123, Accuracy: 0.46875
Loss: 0.6931861042976379, Accuracy: 0.5


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6899629235267639, Accuracy: 0.6875
Loss: 0.6919894814491272, Accuracy: 0.5625
Loss: 0.6942183971405029, Accuracy: 0.4375
Loss: 0.6916645765304565, Accuracy: 0.59375
Loss: 0.693187415599823, Accuracy: 0.5
Loss: 0.6961514949798584, Accuracy: 0.3125


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.693260669708252, Accuracy: 0.46875
Loss: 0.6923502683639526, Accuracy: 0.5625
Loss: 0.6931598782539368, Accuracy: 0.46875
Loss: 0.6936099529266357, Accuracy: 0.46875
Loss: 0.694216787815094, Accuracy: 0.4375
Loss: 0.6959936618804932, Accuracy: 0.34375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.694134533405304, Accuracy: 0.4375
Loss: 0.6923093795776367, Accuracy: 0.5625
Loss: 0.6936291456222534, Accuracy: 0.46875
Loss: 0.6917206048965454, Accuracy: 0.59375
Loss: 0.6926314830780029, Accuracy: 0.53125
Loss: 0.6948107481002808, Accuracy: 0.40625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6926450133323669, Accuracy: 0.53125
Loss: 0.6959248781204224, Accuracy: 0.34375
Loss: 0.684956431388855, Accuracy: 0.71875
Loss: 0.6931796073913574, Accuracy: 0.5
Loss: 0.6949840188026428, Accuracy: 0.40625
Loss: 0.6936840415000916, Accuracy: 0.46875


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6924760937690735, Accuracy: 0.5625
Loss: 0.6931890249252319, Accuracy: 0.5
Loss: 0.6947885751724243, Accuracy: 0.40625
Loss: 0.6989741325378418, Accuracy: 0.34375
Loss: 0.6916230916976929, Accuracy: 0.5625
Loss: 0.6902644634246826, Accuracy: 0.625


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6921292543411255, Accuracy: 0.5625
Loss: 0.6848663091659546, Accuracy: 0.625
Loss: 0.6914160251617432, Accuracy: 0.5625
Loss: 0.6878023743629456, Accuracy: 0.65625
Loss: 0.6931964159011841, Accuracy: 0.5
Loss: 0.6915140748023987, Accuracy: 0.59375


0it [00:00, ?it/s]

Validation Accuracy: 0.5044800884955752
Loss: 0.6946560144424438, Accuracy: 0.4375
Loss: 0.6937602162361145, Accuracy: 0.46875
Loss: 0.6941531300544739, Accuracy: 0.4375
Loss: 0.6946391463279724, Accuracy: 0.40625
Loss: 0.6921651363372803, Accuracy: 0.5625
Loss: 0.6905391812324524, Accuracy: 0.65625


0it [00:00, ?it/s]

In [None]:
from tqdm.notebook import tqdm
from datasets import load_metric

retriever.to(device)
retriever.eval()
test_accuracy = load_metric("accuracy")

with torch.no_grad():
    for i, batch in tqdm(ds_iter['test']):
        # get the inputs; 
        input_ids_0 = batch["input_ids_0"].to(device)
        mask_0 = batch["mask_0"].to(device)
        input_ids_1 = batch["input_ids_0"].to(device)
        mask_1 = batch["mask_0"].to(device)
        labels = batch["label"].to(device)

        # forward pass
        logits = retriever(input_ids_0,mask_0,input_ids_1,mask_1)
        predictions = logits.argmax(-1).cpu().detach().numpy()
        references = batch["label"].numpy()
        test_accuracy.add_batch(predictions=predictions, references=references)

          #delete intermediate variables to free up GPU space
        del logits, input_ids_0, mask_0, input_ids_1, mask_1, labels, predictions, references

final_score = test_accuracy.compute()
print("Accuracy on test set:", final_score['accuracy'])

In [33]:
import torch
import torch.nn as nn
import math


def pooling(inp, mode):
    if mode == "CLS":
        pooled = inp[:, 0, :]
    elif mode == "MEAN":
        pooled = inp.mean(dim = 1)
    else:
        raise Exception()
    return pooled

def append_cls(inp, mask, vocab_size):
    batch_size = inp.size(0)
    cls_id = ((vocab_size - 1) * torch.ones(batch_size, dtype = torch.long, device = inp.device)).long()
    cls_mask = torch.ones(batch_size, dtype = torch.float, device = mask.device)
    inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim = -1)
    mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim = -1)
    return inp, mask

class SCHeadDual(nn.Module):
    def __init__(self):
        super().__init__()
        self.pooling_mode = "CLS"
        self.mlpblock = nn.Sequential(
            nn.Linear(128 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

    def forward(self, inp_0, inp_1):
        X_0 = pooling(inp_0, self.pooling_mode)
        X_1 = pooling(inp_1, self.pooling_mode)
        seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim = -1))
        return seq_score

class ModelForSCDual(nn.Module):
    def __init__(self, model):
        super().__init__()

#         self.enable_amp = config["mixed_precision"]
#         self.pooling_mode = config["pooling_mode"]
#         self.vocab_size = config["vocab_size"]
        self.model = model

        self.seq_classifer = SCHeadDual()

    def forward(self, input_ids_0, input_ids_1, mask_0, mask_1, label):

        #with torch.cuda.amp.autocast(enabled = self.enable_amp):

        
        input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, 128)
        input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, 128)

        token_out_0 = self.model(input_ids_0, mask_0)
        token_out_1 = self.model(input_ids_1, mask_1)
        seq_scores = self.seq_classifer(token_out_0, token_out_1)

        seq_loss = torch.nn.CrossEntropyLoss(reduction = "none")(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim = -1) == label).to(torch.float32)
        outputs = {}
        outputs["loss"] = seq_loss
        outputs["accu"] = seq_accu

        return outputs