In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from typing import List, Optional, Tuple, Union

import math
import random

import time
import json


In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

model_name = 'distilroberta-base'

# Loading a QA head on top of the base checkpoint: the load warning reports that
# `qa_outputs` is newly initialised, so it has to be trained further down.
M1 = AutoModelForQuestionAnswering.from_pretrained(model_name)
# `max_new_tokens` is a text-generation option, not a tokenizer option, so it is not passed here.
T1 = AutoTokenizer.from_pretrained(model_name)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForQuestionAnswering: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForQuestionAnswering were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be 

In [None]:
class LinAttention(nn.Module):
    """Linformer-style replacement for a HuggingFace (Ro)BERTa self-attention module.

    The query/key/value projections, dropout and configuration are taken over from
    ``old_attention_layer``. Two new learnable matrices ``E`` (keys) and ``D`` (values), each of
    shape ``(k_dims, max_input_len)``, project the sequence axis of the keys and values down to
    ``k_dims`` rows, so every query attends over ``k_dims`` projected positions instead of the
    whole sequence.

    Unsupported (raise ``NotImplementedError``): relative position embeddings, and attention
    masks that are not broadcast over the query axis (``attention_mask.shape[-2] != 1``).
    """

    def __init__(self, old_attention_layer, k_dims=128, max_input_len=512):
        """
        Args:
            old_attention_layer: module being replaced; its ``query``/``key``/``value``/``dropout``
                sub-modules are shared with the new module, not copied.
            k_dims: number of rows of the projected keys/values.
            max_input_len: longest key sequence the projections support.
        """
        super().__init__()

        self.num_attention_heads = old_attention_layer.num_attention_heads
        self.attention_head_size = old_attention_layer.attention_head_size
        self.all_head_size       = old_attention_layer.all_head_size

        self.query = old_attention_layer.query
        self.key   = old_attention_layer.key
        self.value = old_attention_layer.value

        self.dropout = old_attention_layer.dropout
        self.position_embedding_type = old_attention_layer.position_embedding_type

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = old_attention_layer.max_position_embeddings
            self.distance_embedding      = old_attention_layer.distance_embedding

        self.is_decoder = old_attention_layer.is_decoder

        # Random Gaussian initialisation scaled by 1/sqrt(k_dims).
        E = torch.randn(k_dims, max_input_len)/math.sqrt(k_dims)
        D = torch.randn(k_dims, max_input_len)/math.sqrt(k_dims)
        self.E = nn.Parameter(E)
        self.D = nn.Parameter(D)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into heads: ``(batch, seq, all_head_size) -> (batch, heads, seq, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Linformer attention with the same call signature as the replaced module.

        Returns:
            ``(context_layer,)``, plus the projected attention probabilities when
            ``output_attentions`` is set, plus ``past_key_value`` when ``is_decoder``.

        Raises:
            NotImplementedError: for relative position embeddings or 2-D (per-query) masks.
            ValueError: when more key positions are used than ``E``/``D`` have columns.
        """
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            raise NotImplementedError(" Linformer not compatible with relative keys")

        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # Cache the (unmasked) keys/values for later calls, as the original module does.
            past_key_value = (key_layer, value_layer)

        n_input = key_layer.shape[-2]

        if attention_mask is not None:
            if attention_mask.shape[-2] != 1:
                raise NotImplementedError(" Linformer not compatible with attention 2dim attention masks")
            # The extended mask is 0 at positions to keep and a large negative value at padding.
            keep = attention_mask == 0
            # Zero padded keys/values so they contribute nothing to the E/D projection; the
            # softmax runs over projected rows, so the additive mask alone cannot exclude them.
            token_keep = keep[:, :, 0, :].unsqueeze(-1).to(key_layer.dtype)
            key_layer = key_layer * token_keep
            value_layer = value_layer * token_keep
            # Longest unpadded sample in the batch. The slice below assumes right padding
            # (kept tokens first) — NOTE(review): confirm for left-padding tokenizers.
            n_input = int(keep.sum(dim=(1, 2, 3)).max())

        if n_input > self.E.shape[1]:
            raise ValueError(
                f"LinAttention supports at most {self.E.shape[1]} key positions, got {n_input}"
            )

        # (k_dims, n) @ (batch, heads, n, head_size) -> (batch, heads, k_dims, head_size)
        projected_keys = torch.matmul(self.E[:, :n_input], key_layer[:, :, :n_input])
        projected_values = torch.matmul(self.D[:, :n_input], value_layer[:, :, :n_input])

        # Scaled dot-product attention over the k_dims projected positions.
        attention_scores = torch.matmul(query_layer, projected_keys.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout on the attention probabilities, as in the original Transformer.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, projected_values)

        # Merge heads: (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

In [None]:
# Inspect the module tree to locate the self-attention modules
# (roberta.encoder.layer[i].attention.self) that get replaced by LinAttention below.
M1

RobertaForQuestionAnswering(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm):

In [None]:
# Swap every encoder self-attention module for the Linformer variant (64 projected rows).
# The isinstance guard keeps the cell idempotent: re-running it must not wrap an existing
# LinAttention again, which would replace its learned E/D projections with fresh random ones.
for layer in M1.roberta.encoder.layer:
    if not isinstance(layer.attention.self, LinAttention):
        layer.attention.self = LinAttention(layer.attention.self, k_dims=64)

# Test modified model

In [None]:
import os
import urllib.request
from tqdm import tqdm

class DownloadProgressBar(tqdm):
    """tqdm progress bar that can be driven by ``urllib.request.urlretrieve``'s ``reporthook``."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to ``b * bsize`` units, adopting ``tsize`` as the total when it is known.

        Args:
            b: number of blocks transferred so far.
            bsize: size of one block.
            tsize: total size, or None when the server did not report it.
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        self.update(transferred - self.n)
        
def download_url(url, output_path):
    """Fetch ``url`` into ``output_path`` with a byte-level progress bar labelled by the file name."""
    file_name = url.split('/')[-1]
    progress = DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=file_name)
    with progress as bar:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=bar.update_to)

def download_data(data_path, url_path, suffix):
    """Download one CoQA split to ``<data_path>/<suffix>.json`` unless it already exists.

    The file is written to a ``.part`` sibling first and renamed into place only after the
    download finishes, so an interrupted download is not mistaken for a complete file on the
    next run.

    Args:
        data_path: directory to store the split in (created if missing).
        url_path: URL of the JSON file.
        suffix: file stem, e.g. ``'train'`` or ``'test'``.

    Returns:
        Path of the JSON file.
    """
    os.makedirs(data_path, exist_ok=True)
    file_path = os.path.join(data_path, f'{suffix}.json')

    if not os.path.exists(file_path):
        print(f"Downloading CoQA {suffix} data split... (it may take a while)")
        partial_path = file_path + '.part'
        try:
            download_url(url=url_path, output_path=partial_path)
        except BaseException:
            # Also covers KeyboardInterrupt: never leave a truncated file behind.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, file_path)
        print("Download completed!")

    return file_path

In [None]:
# Train data
train_url = "https://nlp.stanford.edu/data/coqa/coqa-train-v1.0.json"
download_data(data_path='coqa', url_path=train_url, suffix='train')

# Test data: the CoQA dev split is saved as 'test' and used as the held-out test set
# (presumably because the official CoQA test set is not publicly released). A validation
# split is carved out of the training data further below.
test_url = "https://nlp.stanford.edu/data/coqa/coqa-dev-v1.0.json"
download_data(data_path='coqa', url_path=test_url, suffix='test')

In [None]:
# Parse the raw CoQA JSON for both downloaded splits.
with open(os.path.join('coqa', 'train.json'), 'r') as train_file:
    train = json.load(train_file)

with open(os.path.join('coqa', 'test.json'), 'r') as test_file:
    test = json.load(test_file)

In [None]:
# Keep only the list of dialogues. The isinstance guards make the cell safe to re-run; a second
# run would otherwise fail with a TypeError when indexing the already-unwrapped list by 'data'.
if isinstance(train, dict):
    train = train['data']
if isinstance(test, dict):
    test = test['data']

In [None]:
# Number of questions (dialogue turns) in each training story.
lengths = list(map(lambda story: len(story['questions']), train))

In [None]:
# Document-level split, so that all turns of a dialogue stay in the same split: training keeps
# every document before the first one at which the cumulative question count exceeds 80%.
# `lengths` must describe the current `train`; re-running this cell on its own would otherwise
# split the already-shrunk training list a second time.
assert len(lengths) == len(train), "re-run the cells that build `train` and `lengths` before splitting"
le = np.cumsum(np.array(lengths, dtype=np.float32))
train_end = np.where((le / le[-1]) > 0.8)[0][0]

validation = train[train_end:]
train = train[:train_end]

In [None]:
# Number of documents per split, one per line.
print(len(train), len(validation), sep='\n')

5771
1428


In [None]:
# Questions per split and each split's share of all questions from the training file.
len_train, len_val = (np.sum([len(doc['questions']) for doc in split]) for split in (train, validation))
len_tot = len_train + len_val
for n_split in (len_train, len_val):
    print(n_split, n_split / len_tot)

86909 0.7999208445700295
21738 0.20007915542997046


In [None]:
class CustomImageDataset(torch.utils.data.Dataset):
    """Flattens CoQA dialogues into one item per (story, question) turn.

    Despite the name, this dataset holds text; the name is kept so existing cells keep working.

    Items are ``((passage, question), (span_start, span_end))``, or
    ``((passage, question, history), (span_start, span_end))`` when ``return_history`` is set,
    where ``history`` is a string array alternating previous questions and answers.
    """

    def __init__(self, data, return_history=False):
        """
        Args:
            data: list of CoQA documents, each with ``'story'``, ``'questions'`` and ``'answers'``.
            return_history: also return the preceding question/answer turns of the dialogue.
        """
        self.story = [d['story'] for d in data]
        self.questions = [d['questions'] for d in data]
        self.answers = [d['answers'] for d in data]
        lengths = [len(doc['questions']) for doc in data]
        # lengths[k] = number of questions in documents 0..k (inclusive)
        self.lengths = np.cumsum(np.array(lengths, dtype=np.int32))
        self.R_H = return_history

    def __len__(self):
        """Total number of question turns (0 for an empty dataset)."""
        return int(self.lengths[-1]) if len(self.lengths) else 0

    def __getitem__(self, idx):
        """Return the ``idx``-th turn; negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        n_items = len(self)
        idx = int(idx)
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(f"index out of range for dataset of size {n_items}")

        # First document whose cumulative question count exceeds idx (binary search).
        f_idx = int(np.searchsorted(self.lengths, idx, side='right'))
        q_idx = idx - int(self.lengths[f_idx - 1]) if f_idx > 0 else idx

        passage = self.story[f_idx]
        questions = self.questions[f_idx]
        answers = self.answers[f_idx]
        question = questions[q_idx]['input_text']
        span_start = int(answers[q_idx]['span_start'])
        span_end = int(answers[q_idx]['span_end'])

        if self.R_H:
            # Flat [q0, a0, q1, a1, ...] of the turns before this one; empty for the first turn
            # (np.concatenate on an empty list used to raise here).
            turns = [
                text
                for i in range(q_idx)
                for text in (questions[i]['input_text'], answers[i]['input_text'])
            ]
            history = np.array(turns, dtype=str)
            return (passage, question, history), (span_start, span_end)

        return (passage, question), (span_start, span_end)

In [None]:
from torch.utils.data import DataLoader

batch_size = 16

# Only the training loader is shuffled. Evaluation metrics do not depend on order, and a fixed
# order makes validation/test passes reproducible and comparable across runs.
train_dataloader = DataLoader(CustomImageDataset(train), batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(CustomImageDataset(validation), batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(CustomImageDataset(test), batch_size=batch_size, shuffle=False)

In [None]:
def get_target_indices(inputs, sep_starts, sep_ends):
    """Convert character-level answer spans into token start/end labels, one per window.

    Args:
        inputs: tokenizer output with ``offset_mapping``, ``overflow_to_sample_mapping`` and a
            ``sequence_ids(i)`` method; the passage tokens have sequence id 1.
        sep_starts, sep_ends: answer start / end character offsets, indexed by sample.

    Returns:
        ``(start_positions, end_positions)``: lists with one entry per window. A window whose
        passage slice does not fully contain the answer gets the label ``(0, 0)``.
    """
    start_positions, end_positions = [], []
    sample_mapping = inputs["overflow_to_sample_mapping"]

    for window, offsets in enumerate(inputs["offset_mapping"]):
        sample = sample_mapping[window]
        answer_start, answer_end = sep_starts[sample], sep_ends[sample]
        seq_ids = inputs.sequence_ids(window)

        # First and last token of the passage inside this window.
        first = 0
        while seq_ids[first] != 1:
            first += 1
        last = first
        while seq_ids[last] == 1:
            last += 1
        last -= 1

        answer_outside = offsets[first][0] > answer_start or offsets[last][1] < answer_end
        if answer_outside:
            start_positions.append(0)
            end_positions.append(0)
            continue

        # Last passage token starting at or before the answer start.
        tok = first
        while tok <= last and offsets[tok][0] <= answer_start:
            tok += 1
        start_positions.append(tok - 1)

        # First passage token ending at or after the answer end.
        tok = last
        while tok >= first and offsets[tok][1] >= answer_end:
            tok -= 1
        end_positions.append(tok + 1)

    return start_positions, end_positions

In [None]:
def encode(tokenizer, questions, passages, max_length=512, stride=250):
    """Tokenize question/passage pairs into fixed-length, possibly overlapping windows.

    Only the passage is truncated. Long passages overflow into extra windows that overlap by
    ``stride`` tokens; the ``stride`` argument used to be ignored in favour of a hard-coded 50.

    Args:
        tokenizer: HuggingFace tokenizer (called directly).
        questions, passages: parallel sequences of strings.
        max_length: tokens per window; every window is padded to this length.
        stride: number of overlapping tokens between consecutive windows of one passage.

    Returns:
        The tokenizer output as PyTorch tensors, including ``offset_mapping`` and
        ``overflow_to_sample_mapping`` (window -> sample index).
    """
    return tokenizer(
        questions,
        passages,
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt"
    )

In [None]:
def loss_f(P0, P1, i, j, l1=1, l2=2, gamma=2, eps=1e-6):
    """Span loss for one window: distance-weighted focal penalty plus focal cross-entropy.

    The joint span distribution is ``P[a, b] ∝ exp(P0[a] + P1[b])``, restricted to ``a <= b``
    (start not after end). The loss is

        sum_{a,b} -(W0[a] + W1[b]) * P[a,b]**gamma * log(1 - P[a,b])   # spatial weighting x focal
        - (1 - P[i,j])**gamma * log(P[i,j])                            # focal CE on the target

    ``W0`` / ``W1`` are 0 at the target start ``i`` / end ``j`` and grow linearly with the
    distance from it, relative to the window length. The slope is ``l1`` before and ``l2`` after
    the start, and ``l2`` before and ``l1`` after the end.

    Args:
        P0, P1: 1-D start / end logits of equal length.
        i, j: target start / end token index (ints or 0-dim tensors). ``i <= j`` is expected;
            otherwise ``P[i, j]`` is 0 and the loss is infinite.
        l1, l2: slopes of the distance weights.
        gamma: focal-loss exponent.
        eps: keeps ``log(1 - P)`` finite when one span gets probability ~1.

    Returns:
        0-dim tensor.
    """
    # The end target is used as given (it used to be overwritten with i + 5).
    i, j = int(i), int(j)
    l = P0.shape[0]
    device = P0.device

    # Log-space normalisation: exp(P0 + P1) overflows to inf for large logits, which made the
    # old triu(S) / sum(triu(S)) evaluate to NaN.
    span_logits = P0[:, None] + P1[None, :]
    upper = torch.ones(l, l, dtype=torch.bool, device=device).triu()
    span_logits = span_logits.masked_fill(~upper, float('-inf'))
    log_P = torch.log_softmax(span_logits.reshape(-1), dim=0).reshape(l, l)
    P = log_P.exp()

    positions = torch.arange(l, device=device, dtype=P.dtype)

    def side_weights(target, slope_before, slope_after):
        # 0 at the target and 1 + slope * |distance| / l elsewhere. Every position is covered,
        # including the edges the old `i > 1` / `i < l - 2` guards skipped.
        dist = positions - target
        w = torch.where(dist < 0, 1 + slope_before * (-dist) / l, 1 + slope_after * dist / l)
        return w.masked_fill(dist == 0, 0.0)

    W0 = side_weights(i, l1, l2)
    W1 = side_weights(j, l2, l1)

    penalty = -(W0[:, None] + W1[None, :]) * P.pow(gamma) * torch.log1p(-P.clamp(max=1 - eps))
    target_term = -(1 - P[i, j]).pow(gamma) * log_P[i, j]
    return penalty.sum() + target_term

In [None]:
def _mean_span_loss(outputs, start_positions, end_positions):
    """Average ``loss_f`` (l1=0.5, l2=1) over every tokenized window of a batch."""
    window_losses = [
        loss_f(outputs['start_logits'][k], outputs['end_logits'][k],
               start_positions[k], end_positions[k], 0.5, 1)
        for k in range(len(start_positions))
    ]
    return torch.stack(window_losses).mean()


def train(model, tokenizer, epochs=1, learning_rate=1e-3, dataloader=None, device=None,
          max_length=500, stride=250):
    """Fine-tune ``model`` for extractive QA with the span loss ``loss_f``.

    NOTE: this function shares its name with the CoQA training split ``train`` loaded earlier in
    the notebook. After this cell runs, ``train`` is the function, so the data cells that use
    ``train`` must reload the data before they are re-run.

    Args:
        model: a HuggingFace question-answering model returning ``start_logits``/``end_logits``.
        tokenizer: matching tokenizer, passed to ``encode``.
        epochs: number of passes over ``dataloader``.
        learning_rate: Adam learning rate.
        dataloader: batches of ``((passage, question), (span_start, span_end))``; defaults to
            the global ``train_dataloader``.
        device: training device; defaults to CUDA when available, otherwise CPU.
        max_length, stride: forwarded to ``encode`` (window length / overlap, in tokens).

    Returns:
        List with the loss of every optimisation step, as Python floats.
    """
    if dataloader is None:
        dataloader = train_dataloader
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.to(device)
    # from_pretrained() returns the model in eval mode; enable dropout for training.
    model.train()

    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_history = []

    for epoch in range(epochs):
        start_time = time.time()
        running_loss = 0.0
        steps = 0
        for batch_idx, ((passage, question), (sep_starts, sep_ends)) in enumerate(dataloader):
            try:
                inputs = encode(tokenizer, question, passage, max_length=max_length, stride=stride)
            except Exception as err:
                # The old fallback retried at 2 * max_length. Padded to 1000 tokens, that can fit
                # neither RoBERTa's 514 position embeddings nor LinAttention's 512-column
                # projections, so skip the batch instead.
                print(f"skipping batch {batch_idx + 1}: tokenization failed ({err})")
                continue

            # One label per tokenized window, already aligned with the rows of `inputs`
            # (re-indexing by overflow_to_sample_mapping picked the wrong labels on overflow).
            start_positions, end_positions = get_target_indices(inputs, sep_starts, sep_ends)

            del inputs['overflow_to_sample_mapping']
            del inputs['offset_mapping']

            outputs = model(**inputs.to(device))

            # Mean over windows (the old code divided inside the accumulation loop).
            loss = _mean_span_loss(outputs, start_positions, end_positions)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate a float: summing the tensor kept every step's autograd graph alive.
            loss_value = loss.item()
            running_loss += loss_value
            steps += 1
            loss_history.append(loss_value)

            epoch_time = time.time() - start_time
            batch_time = epoch_time / (batch_idx + 1)

            print(f"epoch: {epoch + 1}/{epochs}, {batch_idx + 1}/{len(dataloader)}, {epoch_time:.0f}s {batch_time*1e3:.0f}ms/step, lr: {optimizer.param_groups[0]['lr']:.3g}, loss: {running_loss/steps:.3g}, {loss_value:.3g}              ")

    return loss_history

In [None]:
R = list()  # one [loss_history, learning_rate] pair per training run

In [None]:
LR = [1e-4]  # learning rates to compare, e.g. [1e-2, 1e-3, 1e-4, 1e-5]
SEED = 0     # re-seeded before every run so data order and dropout match across learning rates

model_name = 'distilroberta-base'
T1 = AutoTokenizer.from_pretrained(model_name)

# Snapshot the current (Linformer-modified) weights on the CPU, so that every learning rate
# starts from the same initialisation instead of continuing from the previous run's weights.
# NOTE: re-running this cell after training snapshots the trained weights instead.
initial_state = {name: tensor.detach().cpu().clone() for name, tensor in M1.state_dict().items()}

for lr in LR:
    M1.load_state_dict(initial_state)
    torch.manual_seed(SEED)
    H = train(M1, T1, epochs=1, learning_rate=lr)
    R.append([H, lr])

epoch: 1/1, 1/5432, 1s 780ms/step, lr: 0.0001, loss: 0.781, 0.781              
epoch: 1/1, 2/5432, 2s 757ms/step, lr: 0.0001, loss: 0.757, 0.733              
epoch: 1/1, 3/5432, 2s 747ms/step, lr: 0.0001, loss: 0.767, 0.788              
epoch: 1/1, 4/5432, 3s 772ms/step, lr: 0.0001, loss: 0.748, 0.691              
epoch: 1/1, 5/5432, 4s 792ms/step, lr: 0.0001, loss: 0.745, 0.734              
epoch: 1/1, 6/5432, 5s 796ms/step, lr: 0.0001, loss: 0.743, 0.73              
epoch: 1/1, 7/5432, 6s 808ms/step, lr: 0.0001, loss: 0.748, 0.781              
epoch: 1/1, 8/5432, 6s 801ms/step, lr: 0.0001, loss: 0.753, 0.786              
epoch: 1/1, 9/5432, 7s 812ms/step, lr: 0.0001, loss: 0.745, 0.684              
epoch: 1/1, 10/5432, 8s 821ms/step, lr: 0.0001, loss: 0.744, 0.733              
epoch: 1/1, 11/5432, 9s 816ms/step, lr: 0.0001, loss: 0.748, 0.784              
epoch: 1/1, 12/5432, 10s 815ms/step, lr: 0.0001, loss: 0.747, 0.734              
epoch: 1/1, 13/5432, 11s 811ms/step, 

KeyboardInterrupt: ignored