In [None]:
# Environment setup: enable Git LFS and install the transformers package.
!git lfs install
!pip install transformers

# NOTE(review): `!` runs the command in a throwaway subshell, so this bare
# assignment does not persist to later cells and appears to have no effect.
# If it were really exported, the later `git clone` would presumably skip the
# LFS-tracked files - confirm the intent before "fixing" it.
!GIT_LFS_SKIP_SMUDGE=1

In [None]:
import requests
import json
import torch
import torch.nn as nn
import os
from tqdm import tqdm
from transformers import BertModel, BertTokenizerFast, AdamW
# AutoTokenizer, AutoModelForQuestionAnswering, BertTokenizer, BertForQuestionAnswering
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR
import matplotlib.pyplot as plt

In [None]:
# Counter initialised to zero; it is not referenced by any other visible cell.
num_questions = 0

In [None]:
def get_data(questions, context, answers, spans):
    """
    Load one question-answering split from four parallel, line-oriented files.

    Args:
        questions: path to a text file with one question per line.
        context: path to a text file with one context passage per line.
        answers: path to a text file with one answer string per line.
        spans: path to the answer-span file. It is still opened and read (so a
            missing file raises exactly as before) but its contents are not
            used by this function.

    Returns:
        A tuple ``(contexts, questions, answers)``. The first two are lists of
        str; ``answers`` is a list of ``{'text': <answer string>}`` dicts, one
        per line of the answers file.
    """

    def _read_lines(path):
        # Decode explicitly as UTF-8: the files hold non-ASCII text, and the
        # default encoding of open() depends on the machine's locale.
        with open(path, encoding="utf-8") as handle:
            return handle.read().splitlines()

    raw_questions = _read_lines(questions)
    raw_context = _read_lines(context)
    raw_answers = _read_lines(answers)
    _read_lines(spans)  # read for parity with the original; currently unused

    # Only the answer text is stored; character offsets ('answer_start' /
    # 'answer_end') are not computed here.
    Answers = [{'text': answer_text} for answer_text in raw_answers]

    return raw_context, raw_questions, Answers
    

In [None]:
# Load the training split; keyword arguments make the file-to-role mapping explicit.
train_context, train_questions, train_answers = get_data(
    questions="/kaggle/input/code-and-data/two_epochs/proj_dataset/train_data/real_que_tel.txt",
    context="/kaggle/input/code-and-data/two_epochs/proj_dataset/train_data/real_con_tel.txt",
    answers="/kaggle/input/code-and-data/two_epochs/proj_dataset/train_data/real_ans_tel.txt",
    spans="/kaggle/input/code-and-data/two_epochs/proj_dataset/train_data/real_span_tel.txt",
)

In [None]:
# Sanity check: show the first context / question / answer triple that was loaded.
print(f"Context : {train_context[0]}")
print(f"Question : {train_questions[0]}")
print(f"Answer: {train_answers[0]}")

In [None]:
# Load the validation split (read from the dataset's test_data directory).
val_context, val_questions, val_answers = get_data(
    questions="/kaggle/input/code-and-data/two_epochs/proj_dataset/test_data/real_que_tel.txt",
    context="/kaggle/input/code-and-data/two_epochs/proj_dataset/test_data/real_con_tel.txt",
    answers="/kaggle/input/code-and-data/two_epochs/proj_dataset/test_data/real_ans_tel_c.txt",
    spans="/kaggle/input/code-and-data/two_epochs/proj_dataset/test_data/real_span_tel.txt",
)

In [None]:
# Space-delimited word count of every training context (after trimming
# leading/trailing whitespace); this is a rough length measure, not a token count.
token_lens = [len(passage.strip().split(' ')) for passage in train_context]

print(max(token_lens))

fig, ax = plt.subplots()
ax.hist(token_lens, bins=20)
ax.set_ylabel('Count')
ax.set_xlabel('Length')
ax.set_title('Distribution of Context Lengths');

In [None]:
# Space-delimited word count of every training question (after trimming
# leading/trailing whitespace).
token_lens2 = [len(question.strip().split(' ')) for question in train_questions]

print(max(token_lens2))
print(len(token_lens2))

fig, ax = plt.subplots()
ax.hist(token_lens2, bins=20)
ax.set_ylabel('Count')
ax.set_xlabel('Length')
ax.set_title('Distribution of Question Lengths');

In [None]:
# Upper bound on encoded sequence length; passed as `max_length` to the tokenizer calls below.
MAX_LENGTH = 300

In [None]:
# NOTE(review): this import lives outside the notebook's main import cell.
from transformers import AutoModel, AutoTokenizer

# Tokenizer for the ai4bharat/indic-bert checkpoint, resolved by its hub name.
tokenizer = AutoTokenizer.from_pretrained('ai4bharat/indic-bert')

In [None]:
# Display whether the loaded tokenizer reports itself as a "fast" implementation.
tokenizer.is_fast

In [None]:
# Encode each (question, context) pair as one sequence, truncated to at most
# MAX_LENGTH tokens and padded (padding=True) so rows share a common length.
train_encodings = tokenizer(train_questions, train_context,  max_length = MAX_LENGTH, truncation=True, padding=True)
valid_encodings = tokenizer(val_questions, val_context,  max_length = MAX_LENGTH, truncation=True, padding = True)

In [None]:
# Inspect the container type returned by the tokenizer call.
type(train_encodings)

In [None]:
# Inspect the encoded fields, the number of training examples, and the
# length of the first encoded training sequence.
print(train_encodings.keys())
print(valid_encodings.keys())
print(len(train_encodings['input_ids']))
print(len(train_encodings['input_ids'][0]))

In [None]:
# Show the raw token ids of the second training example.
print(train_encodings['input_ids'][1])

In [None]:
def ret_Answer_start_and_end_train(idx):
    """
    Locate the answer of training example ``idx`` inside its encoded input.

    The answer text is tokenized on its own, the first and last ids of that
    encoding are dropped (assumed to be the special tokens the tokenizer wraps
    around a single sequence), and the remaining ids are searched for as a
    contiguous run inside ``train_encodings['input_ids'][idx]``.

    Args:
        idx: index of the example in ``train_answers`` / ``train_encodings``.

    Returns:
        ``(start, end)`` for the first match, where ``start`` is the index of
        the first answer token and ``end`` is one past the last, so
        ``input_ids[start:end]`` is the answer. ``(0, 0)`` when no match is found.
    """
    answer_ids = tokenizer(train_answers[idx]['text'], max_length=MAX_LENGTH,
                           truncation=True, padding=True)['input_ids']
    answer_tokens = answer_ids[1:-1]
    width = len(answer_tokens)
    sequence_ids = train_encodings['input_ids'][idx]

    # Candidate starts run from 1 (position 0 is never a candidate) up to the
    # last start whose window ends just before the final position. The previous
    # bound stopped one window early, so an answer ending on the token right
    # before the final position was reported as "not found".
    # NOTE(review): the scan also covers the question tokens, so an answer that
    # is repeated in the question matches there first - confirm this is intended.
    for start in range(1, len(sequence_ids) - width):
        if sequence_ids[start:start + width] == answer_tokens:
            return (start, start + width)
    return (0, 0)

In [None]:
# Spot check one training example: locate its answer span, then print the
# decoded span, the reference answer text and the full decoded input.
test_rec=92

z,x = ret_Answer_start_and_end_train(test_rec)
print(z, x)

# The slice z:x treats x as an exclusive end index.
predict_answer_tokens = train_encodings.input_ids[test_rec][z : x]
print(tokenizer.decode(predict_answer_tokens))
print(train_answers[test_rec]['text'])
print(tokenizer.decode(train_encodings['input_ids'][test_rec]))

In [None]:
# Re-inspect the encoding fields and the number of training examples.
print(train_encodings.keys())
print(valid_encodings.keys())
print(len(train_encodings['input_ids']))

In [None]:
# Compute the token-level answer span of every training example.
train_spans = [
    ret_Answer_start_and_end_train(row)
    for row in range(len(train_encodings['input_ids']))
]
start_positions = [span[0] for span in train_spans]
end_positions = [span[1] for span in train_spans]

# A start index of 0 is what the span finder returns when the answer was not found.
ctr = sum(1 for start in start_positions if start == 0)

train_encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
print(ctr)

In [None]:
# Check the training encodings after the span labels were added.
print(train_encodings.keys())
print(valid_encodings.keys())
print(len(train_encodings['input_ids']))

In [None]:
# Spot check the stored labels of one example: print the span bounds, the
# decoded span, the reference answer and the full decoded input.
test_rec = 1
span_start = train_encodings['start_positions'][test_rec]
span_end = train_encodings['end_positions'][test_rec]
print(span_start)
print(span_end)
predict_answer_tokens = train_encodings.input_ids[test_rec][span_start:span_end]
print(tokenizer.decode(predict_answer_tokens))
print(train_answers[test_rec]['text'])
print(tokenizer.decode(train_encodings['input_ids'][test_rec]))

In [None]:
def ret_Answer_start_and_end_valid(idx):
    """
    Locate the answer of validation example ``idx`` inside its encoded input.

    The answer text is tokenized on its own, the first and last ids of that
    encoding are dropped (assumed to be the special tokens the tokenizer wraps
    around a single sequence), and the remaining ids are searched for as a
    contiguous run inside ``valid_encodings['input_ids'][idx]``.

    Args:
        idx: index of the example in ``val_answers`` / ``valid_encodings``.

    Returns:
        ``(start, end)`` for the first match, where ``start`` is the index of
        the first answer token and ``end`` is one past the last, so
        ``input_ids[start:end]`` is the answer. ``(0, 0)`` when no match is found.
    """
    answer_ids = tokenizer(val_answers[idx]['text'], max_length=MAX_LENGTH,
                           truncation=True, padding=True)['input_ids']
    answer_tokens = answer_ids[1:-1]
    width = len(answer_tokens)
    sequence_ids = valid_encodings['input_ids'][idx]

    # The end index is derived from the answer width rather than from the inner
    # loop variable, which was unbound whenever the answer produced no tokens.
    # The upper bound also includes the last window that ends just before the
    # final position; the previous bound stopped one window early.
    # NOTE(review): the scan also covers the question tokens, so an answer that
    # is repeated in the question matches there first - confirm this is intended.
    for start in range(1, len(sequence_ids) - width):
        if sequence_ids[start:start + width] == answer_tokens:
            return (start, start + width)
    return (0, 0)

In [None]:
# Compute the token-level answer span of every validation example.
valid_spans = [
    ret_Answer_start_and_end_valid(row)
    for row in range(len(valid_encodings['input_ids']))
]
start_positions = [span[0] for span in valid_spans]
end_positions = [span[1] for span in valid_spans]

# A start index of 0 is what the span finder returns when the answer was not found.
ctr = sum(1 for start in start_positions if start == 0)

valid_encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
print(ctr)

In [None]:
# Spot check one validation example: locate its answer span, then print the
# decoded span, the reference answer text and the full decoded input.
test_rec=2

z,x = ret_Answer_start_and_end_valid(test_rec)

# The slice z:x treats x as an exclusive end index.
predict_answer_tokens = valid_encodings.input_ids[test_rec][z : x]
print(tokenizer.decode(predict_answer_tokens))
print(val_answers[test_rec]['text'])
print(tokenizer.decode(valid_encodings['input_ids'][test_rec]))

In [None]:
# Consistency check: each split should show the same count for input_ids,
# start_positions and end_positions.
print(train_encodings.keys())
print(valid_encodings.keys())
print(len(train_encodings['input_ids']))
print(len(train_encodings['start_positions']))
print(len(train_encodings['end_positions']))
print(len(valid_encodings['input_ids']))
print(len(valid_encodings['start_positions']))
print(len(valid_encodings['end_positions']))

In [None]:
class InputDataset(Dataset):
    """Map-style dataset that serves one encoded example at a time as tensors.

    ``encodings`` must support ``encodings[field][i]`` for each of the fields
    listed in ``_FIELDS``.
    """

    # Fields copied into every item, in the order they appear in the returned dict.
    _FIELDS = (
        'input_ids',
        'token_type_ids',
        'attention_mask',
        'start_positions',
        'end_positions',
    )

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # One item per encoded input sequence.
        return len(self.encodings['input_ids'])

    def __getitem__(self, i):
        # Build a fresh tensor per field for example i.
        return {
            field: torch.tensor(self.encodings[field][i])
            for field in self._FIELDS
        }

In [None]:
# Wrap both sets of encodings (including their span labels) as datasets.
train_dataset = InputDataset(train_encodings)
valid_dataset = InputDataset(valid_encodings)

In [None]:
# Training batches are shuffled; validation batches keep dataset order.
train_data_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_data_loader = DataLoader(valid_dataset, batch_size=32)

In [None]:
# Clone the ai4bharat/indic-bert repository from the Hugging Face Hub into the current directory.
!git clone 'https://huggingface.co/ai4bharat/indic-bert'

In [None]:
# Local directory the encoder weights are loaded from; presumably the clone
# made above when the notebook runs in /kaggle/working - confirm.
MODEL_PATH = '/kaggle/working/indic-bert'

In [None]:
# ai4bharat/indic-bert is an ALBERT checkpoint. Loading it through BertModel
# cannot map the checkpoint's parameter names onto the BERT architecture, so
# the encoder would start from mostly uninitialised weights. AutoModel picks
# the architecture from the checkpoint's own config instead.
bert_model = AutoModel.from_pretrained(MODEL_PATH)

class QAModel(nn.Module):
    """
    Extractive question-answering head on top of a pretrained encoder.

    For every token, the encoder's last and third-from-last hidden states are
    concatenated and passed through dropout -> linear -> LeakyReLU -> linear.
    That yields two scores per token: a start logit and an end logit. Both
    logits are computed from the same concatenated features.

    Args:
        encoder: optional encoder module to wrap. It must expose
            ``config.hidden_size`` and return an output with a
            ``hidden_states`` sequence. Defaults to the notebook-level
            ``bert_model``, which keeps ``QAModel()`` working as before.
    """

    def __init__(self, encoder=None):
        super(QAModel, self).__init__()
        self.bert = bert_model if encoder is None else encoder
        # Read the width from the encoder config instead of hard-coding 768.
        hidden_size = self.bert.config.hidden_size
        # The layers stay registered as attributes and inside the Sequential,
        # so the state_dict keys are the same as in the earlier version.
        self.drop_out = nn.Dropout(0.1)
        self.l1 = nn.Linear(hidden_size * 2, hidden_size * 2)
        self.l2 = nn.Linear(hidden_size * 2, 2)
        self.linear_relu_stack = nn.Sequential(
            self.drop_out,
            self.l1,
            nn.LeakyReLU(),
            self.l2
        )

    def forward(self, input_ids, attention_mask, token_type_ids):
        """
        Score every token as a possible answer start and answer end.

        Returns:
            ``(start_logits, end_logits)``, each with the trailing size-1
            dimension of the split squeezed away (one score per token).
        """
        model_output = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
            return_dict=True,
        )
        # Read the hidden states by name rather than by tuple position, which
        # depends on which optional outputs the encoder returns.
        hidden_states = model_output.hidden_states
        features = torch.cat((hidden_states[-1], hidden_states[-3]), dim=-1)
        logits = self.linear_relu_stack(features)

        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

In [None]:
# Build the QA model; with no arguments it wraps the module-level `bert_model`.
model = QAModel()

In [None]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
def loss_fn(start_logits, end_logits, start_positions, end_positions):
    """Return the mean of the start and end cross-entropy losses.

    Each pair is (logits, target class indices); both losses use the default
    mean reduction before being averaged with each other.
    """
    cross_entropy = nn.functional.cross_entropy
    start_term = cross_entropy(start_logits, start_positions)
    end_term = cross_entropy(end_logits, end_positions)
    return (start_term + end_term) / 2

In [None]:
def focal_loss_fn(start_logits, end_logits, start_positions, end_positions, gamma):
    """Return the mean of the start and end focal losses.

    For each side the log-probability of every class is scaled by
    ``(1 - p) ** gamma`` and the scaled value at the target index is then
    negated and averaged over the batch by the NLL loss. With ``gamma == 0``
    this reduces to plain cross-entropy.
    """

    def _one_side(logits, targets):
        probabilities = torch.softmax(logits, dim=1)
        log_probabilities = torch.log_softmax(logits, dim=1)  # natural log, as NLL expects
        modulated = torch.pow(1 - probabilities, gamma) * log_probabilities
        return nn.functional.nll_loss(modulated, targets)

    start_term = _one_side(start_logits, start_positions)
    end_term = _one_side(end_logits, end_positions)
    return (start_term + end_term) / 2

In [None]:
# AdamW over all model parameters (learning rate 2e-5, weight decay 2e-2).
optim = AdamW(model.parameters(), lr=2e-5, weight_decay=2e-2)
# Exponential decay with factor 0.9 per scheduler.step(); train_epoch steps it once per call.
scheduler = ExponentialLR(optim, gamma=0.9)
# Running accuracy / loss snapshots that train_epoch appends to.
total_acc = []
total_loss = []


In [None]:
def train_epoch(model, dataloader, epoch):
    """Train ``model`` for one pass over ``dataloader``.

    Uses the module-level ``optim``, ``scheduler``, ``device`` and
    ``focal_loss_fn`` (gamma fixed at 1; ``loss_fn`` is the cross-entropy
    baseline alternative). When ``epoch == 1``, a running-average snapshot of
    accuracy and loss is appended to the module-level ``total_acc`` /
    ``total_loss`` lists every 250 batches.

    Returns:
        ``(accuracy, loss)``: the mean of the per-batch start/end accuracies
        and the mean per-batch loss over the whole pass.
    """
    model = model.train()
    batch_losses = []
    batch_accs = []
    batches_since_snapshot = 0

    for batch in tqdm(dataloader, desc = 'Running Epoch '):
        optim.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)

        out_start, out_end = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        loss = focal_loss_fn(out_start, out_end, start_true, end_true, 1)
        batch_losses.append(loss.item())
        loss.backward()
        optim.step()

        # Per-batch accuracy, recorded separately for start and end predictions.
        for logits, truth in ((out_start, start_true), (out_end, end_true)):
            predicted = torch.argmax(logits, dim=1)
            batch_accs.append(((predicted == truth).sum()/len(predicted)).item())

        batches_since_snapshot += 1
        if epoch == 1 and batches_since_snapshot == 250:
            total_acc.append(sum(batch_accs)/len(batch_accs))
            total_loss.append(sum(batch_losses)/len(batch_losses))
            batches_since_snapshot = 0

    # One learning-rate decay step per epoch.
    scheduler.step()
    return (sum(batch_accs)/len(batch_accs), sum(batch_losses)/len(batch_losses))

In [None]:
def eval_model(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Uses the module-level ``device``. No loss is computed here.

    Returns:
        The mean of the per-batch start and end accuracies (each batch
        contributes two entries: one for start, one for end predictions).
    """
    model = model.eval()
    batch_accs = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc = 'Running Evaluation'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            start_true = batch['start_positions'].to(device)
            end_true = batch['end_positions'].to(device)

            out_start, out_end = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )

            for logits, truth in ((out_start, start_true), (out_end, end_true)):
                predicted = torch.argmax(logits, dim=1)
                batch_accs.append(((predicted == truth).sum()/len(predicted)).item())

    return sum(batch_accs)/len(batch_accs)

In [None]:
# Number of full passes over the training data.
EPOCHS = 5

model.to(device)

# Train then evaluate once per epoch. Epochs are passed 1-based because
# train_epoch only records its running snapshots when epoch == 1.
for epoch in range(EPOCHS):
    train_acc, train_loss = train_epoch(model, train_data_loader, epoch+1)
    print(f"Train Accuracy: {train_acc}      Train Loss: {train_loss}")
    val_acc = eval_model(model, valid_data_loader)
    print(f"Validation Accuracy: {val_acc}")

# Persist only the weights (state_dict), not the whole module object.
torch.save(model.state_dict(), "QA_finutunemodel.pt")

In [None]:
def get_answer(question, context):
    """
    Predict the answer span for ``question`` within ``context``.

    Uses the module-level ``tokenizer``, ``model``, ``device`` and
    ``MAX_LENGTH``.

    Args:
        question: question text.
        context: passage expected to contain the answer.

    Returns:
        The decoded text of the tokens between the highest-scoring start index
        and the highest-scoring end index (end exclusive, matching how the
        training labels are built). An empty string results when the
        predicted end does not lie after the predicted start.
    """
    # Use the notebook's `tokenizer`: the `tokenizerFast` name referenced
    # before is not defined in any visible cell, so the call raised NameError.
    # Truncate exactly as the training encodings were truncated.
    inputs = tokenizer(
        question,
        context,
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors='pt',
    ).to(device)

    model.eval()  # dropout must be off for inference
    with torch.no_grad():
        output_start, output_end = model(**inputs)

    answer_start = torch.argmax(output_start)
    answer_end = torch.argmax(output_end)

    answer_ids = inputs['input_ids'][0][answer_start:answer_end].tolist()
    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer_ids))
    return(answer)

In [None]:
## Finding start and end span indices

