In [1]:
from data.Nystromformer.LRA.code import lra_config
from data.Nystromformer.LRA.code.dataset import LRADataset
#from Nystromformer.LRA.code.run_tasks import training_config
from torch.utils.data import DataLoader, RandomSampler
from torch import nn
import torch.nn.functional as F

piaynTaskDataDir = "data/"
piaynTaskModelDir = "data/"
task = "retrieval"

#get training config
training_config = lra_config.config[task]["training"]

#Check Train Config
print('Training Config: ', training_config)

#get pre-defined model config
model_config = lra_config.config[task]['model']

#Check model Config
print('Model Config: ', model_config)

#Get the dataset
train_dataset = LRADataset(piaynTaskDataDir + f"/{task}/{task}.train.pickle", True)
val_dataset = LRADataset(piaynTaskDataDir + f"/{task}/{task}.dev.pickle", False)
test_dataset = LRADataset(piaynTaskDataDir + f"/{task}/{task}.test.pickle", False)

#Create DataLoader iterators
ds_iter = {
    "train":enumerate(DataLoader(train_dataset, 
                                 batch_size = training_config["batch_size"], 
                                 drop_last = True)),
    "dev":enumerate(DataLoader(val_dataset, batch_size = 1, drop_last = True)),
    "test":enumerate(DataLoader(test_dataset, batch_size = 1, drop_last = True)),
}



Training Config:  {'batch_size': 32, 'learning_rate': 0.0001, 'warmup': 800, 'lr_decay': 'linear', 'weight_decay': 0, 'eval_frequency': 300, 'num_train_steps': 30000, 'num_eval_steps': 565}
Model Config:  {'learn_pos_emb': True, 'tied_weights': False, 'embedding_dim': 64, 'transformer_dim': 64, 'transformer_hidden_dim': 128, 'head_dim': 32, 'num_head': 2, 'num_layers': 2, 'vocab_size': 512, 'max_seq_len': 4000, 'dropout_prob': 0.1, 'attention_dropout': 0.1, 'pooling_mode': 'MEAN', 'num_classes': 2}
Loaded data//retrieval/retrieval.train.pickle... size=147086
Loaded data//retrieval/retrieval.dev.pickle... size=18090
Loaded data//retrieval/retrieval.test.pickle... size=17437


In [2]:
#Check sizes of batches
batch = next((ds_iter['train']))
for k,v in batch[1].items():
  print(k,v.shape)

input_ids_0 torch.Size([32, 4096])
mask_0 torch.Size([32, 4096])
input_ids_1 torch.Size([32, 4096])
mask_1 torch.Size([32, 4096])
label torch.Size([32])


In [3]:
from transformers import PerceiverForSequenceClassification, PerceiverForMaskedLM
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device('cpu')

In [4]:
from transformers import PerceiverConfig
#get default perceiver config
configuration = PerceiverConfig()

#Print Updated Perceiver Configuration
print(configuration)

PerceiverConfig {
  "attention_probs_dropout_prob": 0.1,
  "audio_samples_per_frame": 1920,
  "cross_attention_shape_for_attention": "kv",
  "cross_attention_widening_factor": 1,
  "d_latents": 1280,
  "d_model": 768,
  "hidden_act": "gelu",
  "image_size": 56,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 2048,
  "model_type": "perceiver",
  "num_blocks": 1,
  "num_cross_attention_heads": 8,
  "num_frames": 16,
  "num_latents": 256,
  "num_self_attends_per_block": 26,
  "num_self_attention_heads": 8,
  "output_shape": [
    1,
    16,
    224,
    224
  ],
  "qk_channels": null,
  "samples_per_patch": 16,
  "self_attention_widening_factor": 1,
  "train_size": [
    368,
    496
  ],
  "transformers_version": "4.16.2",
  "use_query_residual": true,
  "v_channels": null,
  "vocab_size": 262
}



In [5]:
#configuration.num_labels = 2
configuration.num_self_attends_per_block = 3
configuration.d_latents = 512
configuration.d_model = 512
configuration.max_position_embeddings = 4097
configuration.num_cross_attention_heads = 4
configuration.num_self_attention_heads = 4
configuration.num_latents = 512
configuration.vocab_size = 128

In [6]:
def append_cls(inp, mask, vocab_size):
    batch_size = inp.size(0)
    cls_id = ((vocab_size - 1) * torch.ones(batch_size, dtype=inp.dtype, device=inp.device))#.long()
    cls_mask = torch.ones(batch_size, dtype=mask.dtype, device=mask.device)
    # inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1)
    # mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1)
    inp = torch.cat([cls_id[:, None], inp], dim=-1)
    mask = torch.cat([cls_mask[:, None], mask], dim=-1)
    return inp, mask

In [7]:
class Retrieval(nn.Module):
    def __init__(self,config):
        super(Retrieval, self).__init__()
        self.config = config
        self.perceiver = PerceiverForMaskedLM(config = config)
        self.linear1 = nn.Linear(512,128)
        self.linear2 = nn.Linear(128,2)

    def forward(self,input_ids_0, mask_0, input_ids_1, mask_1):
        input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.config.vocab_size)
        input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.config.vocab_size)
        #print(input_ids_0.shape, mask_0.shape)
        o1 = self.perceiver(inputs = input_ids_0, attention_mask = mask_0)
        o1 = o1.logits[:,0,:]
        o2 = self.perceiver(inputs = input_ids_1, attention_mask = mask_1)
        o2 = o2.logits[:,0,:]
        o3 = torch.cat([o1,o2,o1*o2,(o1 - o2)],axis = 1)
        o4 = F.relu(self.linear1(o3))
        o5 = self.linear2(o4)
        return o5

In [8]:
retriever = Retrieval(configuration)
device

device(type='cuda')

In [9]:
pytorch_total_params = sum(p.numel() for p in retriever.parameters())
pytorch_total_params_Trainable = sum(p.numel() for p in retriever.parameters() if p.requires_grad)
print('Total Parameters: ', pytorch_total_params, '\nTrainable Parameters: ', pytorch_total_params_Trainable)  

Total Parameters:  12218370 
Trainable Parameters:  12218370


In [10]:
input_ids_0 = batch[1]["input_ids_0"].to(device)
mask_0 = batch[1]["mask_0"].to(device)
input_ids_1 = batch[1]["input_ids_0"].to(device)
mask_1 = batch[1]["mask_0"].to(device)
labels = batch[1]["label"].to(device)
retriever = retriever.to(device)

In [None]:
outputs = retriever(input_ids_0,mask_0,input_ids_1,mask_1)

torch.Size([32, 4097]) torch.Size([32, 4097])


In [15]:
 next(ds_iter['train'])

(738,
 {'input_ids_0': tensor([[99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          ...,
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0]]),
  'mask_0': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          ...,
          [1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.]]),
  'input_ids_1': tensor([[99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 35, 50,  ...,  0,  0,  0],
          ...,
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0],
          [99, 40, 50,  ...,  0,  0,  0]]),
  'mask_1': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 0., 0.],
          .

In [None]:
from torch.optim import AdamW
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score
from datasets import load_metric
import pandas as pd
from torch.nn import CrossEntropyLoss

best_score = 0 
prev_score = 0
# maxPatience = 3
# currentPatience = 0

#steps = int(training_config["num_train_steps"]/20000)
steps = 5000
loss_fn = CrossEntropyLoss()

optimizer = AdamW(retriever.parameters(), 
                  lr = 0.5,
                  betas = (0.9, 0.999), 
                  eps = 1e-6, 
                  weight_decay = training_config["weight_decay"])

lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer = optimizer,
    max_lr = training_config["learning_rate"],
    #max_lr = 0.5,
    pct_start = training_config["warmup"] / training_config["num_train_steps"],
    #pct_start = training_config["warmup"] / 5000,
    anneal_strategy = training_config["lr_decay"],
    total_steps = training_config["num_train_steps"],
    #verbose=True
)

#amp_scaler = torch.cuda.amp.GradScaler() if model_config["mixed_precision"] else None

#initialize training summary
trainingSummary = pd.DataFrame(columns=['step', 'train_acc', 'val_acc'])


retriever.to(device)

#initialize training accuracy metric
train_accuracy = load_metric("accuracy")
#batch = next(ds_iter['train'])[1]

for step in tqdm(range(30000)):  # Perform gradient updates for multiple steps
    
    #model.train()
    retriever.train()
    
    #print("Step:", step)
    #for batch in tqdm(train_dataloader):
    batch = next(ds_iter['train'])[1]

    # get the inputs; 
    input_ids_0 = batch["input_ids_0"].to(device)
    mask_0 = batch["mask_0"].to(device)
    input_ids_1 = batch["input_ids_0"].to(device)
    mask_1 = batch["mask_0"].to(device)
    labels = batch["label"].to(device)

    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs = retriever(input_ids_0,mask_0,input_ids_1,mask_1)
    loss = loss_fn(outputs,labels)
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

    # evaluate
    predictions = outputs.argmax(-1).cpu().detach().numpy()
    accuracy = accuracy_score(y_true=batch["label"].numpy(), y_pred=predictions)
    references = batch["label"].numpy()
    train_accuracy.add_batch(predictions=predictions, references=references)
    
    if (step+1)%50  == 0:
        print(f"Loss: {loss.item()}, Accuracy: {accuracy}")

    #delete intermediate variables to free up GPU space
    del loss, outputs, input_ids_0, mask_0, labels, predictions, accuracy,input_ids_1,mask_1


    #Every 1000 steps validate and save model
    if (step+1)%training_config['eval_frequency']  == 0:
    #if (step+1)%2  == 0:
        #model.eval()
        retriever.eval()
        val_accuracy = load_metric("accuracy")

      #reset dev iterator
        ds_iter['dev'] = enumerate(DataLoader(val_dataset, batch_size = 32, drop_last = True))

        with torch.no_grad():
            for i, batch in tqdm(ds_iter['dev']):
                input_ids_0 = batch["input_ids_0"].to(device)
                mask_0 = batch["mask_0"].to(device)
                input_ids_1 = batch["input_ids_0"].to(device)
                mask_1 = batch["mask_0"].to(device)
                labels = batch["label"].to(device)

              # forward pass
                logits = retriever(input_ids_0,mask_0,input_ids_1,mask_1)
                predictions = logits.argmax(-1).cpu().detach().numpy()
                references = batch["label"].numpy()
                val_accuracy.add_batch(predictions=predictions, references=references)

          #delete intermediate variables to free up GPU space
                del logits, input_ids_0, mask_0, input_ids_1, mask_1, labels, predictions, references
      
      #Compute val accuracy
        final_val_score = val_accuracy.compute()['accuracy']
        print("Validation Accuracy:", final_val_score)

        if final_val_score >= best_score:
            best_score = final_val_score
            torch.save(retriever.to('cpu').state_dict(), piaynTaskModelDir + '/trainedPerceiverClassifierToken'+'.pkl')
            retriever.to(device)
        else:
            pass  

#         if final_val_score <= prev_score:
#             currentPatience += 1
#             if currentPatience >= maxPatience:
#                 print('Patience Limit reached! Stopping early!')
#                 torch.save(retriever.to('cpu').state_dict(), piaynTaskModelDir + '/trainedPerceiverClassifierStep_' + str(step + 1) + 'Token' + '.pkl')
#                 break  
#         else:
#             currentPatience = 0
      
      #Update prev_score
        prev_score = final_val_score

      #Compute training accuracy till now
        final_train_score = train_accuracy.compute()['accuracy']

      #Add to trainingSummary
        trainingSummary.loc[len(trainingSummary.index)] = [step+1, final_train_score, final_val_score]

  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)
  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)


  0%|          | 0/30000 [00:00<?, ?it/s]

Loss: 0.6918203830718994, Accuracy: 0.53125
Loss: 0.6933383941650391, Accuracy: 0.5
Loss: 0.6946132779121399, Accuracy: 0.40625


In [None]:
from tqdm.notebook import tqdm
from datasets import load_metric

retriever.to(device)
retriever.eval()
test_accuracy = load_metric("accuracy")

with torch.no_grad():
    for i, batch in tqdm(ds_iter['test']):
        # get the inputs; 
        input_ids_0 = batch["input_ids_0"].to(device)
        mask_0 = batch["mask_0"].to(device)
        input_ids_1 = batch["input_ids_0"].to(device)
        mask_1 = batch["mask_0"].to(device)
        labels = batch["label"].to(device)

        # forward pass
        logits = retriever(input_ids_0,mask_0,input_ids_1,mask_1)
        predictions = logits.argmax(-1).cpu().detach().numpy()
        references = batch["label"].numpy()
        test_accuracy.add_batch(predictions=predictions, references=references)

          #delete intermediate variables to free up GPU space
        del logits, input_ids_0, mask_0, input_ids_1, mask_1, labels, predictions, references

final_score = test_accuracy.compute()
print("Accuracy on test set:", final_score['accuracy'])

In [33]:
import torch
import torch.nn as nn
import math


def pooling(inp, mode):
    if mode == "CLS":
        pooled = inp[:, 0, :]
    elif mode == "MEAN":
        pooled = inp.mean(dim = 1)
    else:
        raise Exception()
    return pooled

def append_cls(inp, mask, vocab_size):
    batch_size = inp.size(0)
    cls_id = ((vocab_size - 1) * torch.ones(batch_size, dtype = torch.long, device = inp.device)).long()
    cls_mask = torch.ones(batch_size, dtype = torch.float, device = mask.device)
    inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim = -1)
    mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim = -1)
    return inp, mask

class SCHeadDual(nn.Module):
    def __init__(self):
        super().__init__()
        self.pooling_mode = "CLS"
        self.mlpblock = nn.Sequential(
            nn.Linear(128 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

    def forward(self, inp_0, inp_1):
        X_0 = pooling(inp_0, self.pooling_mode)
        X_1 = pooling(inp_1, self.pooling_mode)
        seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim = -1))
        return seq_score

class ModelForSCDual(nn.Module):
    def __init__(self, model):
        super().__init__()

#         self.enable_amp = config["mixed_precision"]
#         self.pooling_mode = config["pooling_mode"]
#         self.vocab_size = config["vocab_size"]
        self.model = model

        self.seq_classifer = SCHeadDual()

    def forward(self, input_ids_0, input_ids_1, mask_0, mask_1, label):

        #with torch.cuda.amp.autocast(enabled = self.enable_amp):

        
        input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, 128)
        input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, 128)

        token_out_0 = self.model(input_ids_0, mask_0)
        token_out_1 = self.model(input_ids_1, mask_1)
        seq_scores = self.seq_classifer(token_out_0, token_out_1)

        seq_loss = torch.nn.CrossEntropyLoss(reduction = "none")(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim = -1) == label).to(torch.float32)
        outputs = {}
        outputs["loss"] = seq_loss
        outputs["accu"] = seq_accu

        return outputs