In [None]:
import numpy as np
import pandas as pd
import os
import torch
import time, sys

In [None]:
from transformers import RobertaTokenizerFast,RobertaForQuestionAnswering,XLMRobertaForQuestionAnswering,XLMRobertaTokenizerFast,BertTokenizerFast,BertForQuestionAnswering

# Configuration: input locations and the random seed, kept in one place.
MODEL_DIR = '../input/huggingface-question-answering-models/multilingual/xlm-roberta-large-squad2'
DATA_DIR = '../input/chaii-hindi-and-tamil-question-answering'
SEED = 42

# Seed torch's global RNG so the shuffled training DataLoader order (and any
# other torch randomness, e.g. dropout) is reproducible across runs.
torch.manual_seed(SEED)

tokenizer = XLMRobertaTokenizerFast.from_pretrained(MODEL_DIR)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

train_data = pd.read_csv(os.path.join(DATA_DIR, "train.csv"))
# The "dev" split is the competition test file; below only its context,
# question and id columns are used (no answer labels).
dev_data = pd.read_csv(os.path.join(DATA_DIR, "test.csv"))

def preprocess_data(data, tok=None):
    """Tokenize (context, question) pairs and attach answer-span token labels.

    Args:
        data: DataFrame with ``context``, ``question``, ``answer_start`` and
            ``answer_text`` columns. Rows are read positionally, so any index
            (e.g. after filtering or shuffling) lines up with the encodings.
        tok: Fast tokenizer to use. Defaults to the module-level ``tokenizer``.

    Returns:
        The tokenizer's encodings with two extra keys, ``start_positions`` and
        ``end_positions`` (one token index per row). When a boundary character
        has no token (e.g. the answer passage was truncated), the label is set
        to ``tok.model_max_length``. That is an out-of-range index, presumably
        ignored by the QA loss; verify against your transformers version.
    """
    if tok is None:
        tok = tokenizer

    # Context is the first sequence, so char_to_token's default sequence_index=0
    # below refers to characters of the context.
    encodings = tok(list(data["context"]), list(data["question"]), truncation=True, padding=True)

    # Plain lists guarantee row i pairs with encoding i. Label-based
    # ``series[i]`` fails or picks the wrong row on a non-RangeIndex.
    answer_starts = data["answer_start"].tolist()
    answer_texts = data["answer_text"].tolist()

    start_positions = []
    end_positions = []
    for i, (start_char, text) in enumerate(zip(answer_starts, answer_texts)):
        # Inclusive index of the answer's last character. Trailing whitespace is
        # dropped first so the end label lands on the last real character rather
        # than on a space (which may map to no token or to the next word's token).
        end_char = start_char + max(len(text.rstrip()), 1) - 1

        start_tok = encodings.char_to_token(i, start_char)
        end_tok = encodings.char_to_token(i, end_char)

        # None: the character is not in the encoded sequence. Typically the
        # context was truncated before the answer.
        start_positions.append(tok.model_max_length if start_tok is None else start_tok)
        end_positions.append(tok.model_max_length if end_tok is None else end_tok)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

    return encodings

# Number of training rows (displayed as the cell output).
train_data.shape[0]

In [None]:
# Training split: token ids plus start/end token labels.
train_encodings = preprocess_data(train_data)

# Test split (named "dev" here): same tokenizer settings as training, no labels.
# NOTE(review): truncation drops any context beyond the model's max length, so
# answers located there cannot be predicted.
dev_encodings = tokenizer(
    list(dev_data["context"]),
    list(dev_data["question"]),
    padding=True,
    truncation=True,
)

In [None]:
from torch.utils.data import DataLoader
from transformers import AdamW

class chaiDataset(torch.utils.data.Dataset):
    """Map-style dataset over column-oriented tokenizer output.

    ``encodings`` maps feature names (``input_ids``, ``attention_mask`` and,
    for training, ``start_positions``/``end_positions``) to sequences that are
    expected to hold one entry per example. Both a transformers
    ``BatchEncoding`` and a plain ``dict`` are accepted.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # One example: the idx-th entry of every feature, converted to a tensor.
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        # Item access instead of ``.input_ids`` so plain mappings work as well.
        return len(self.encodings["input_ids"])
    
# Wrap both tokenized splits; only the training data is shuffled.
train_dataset, dev_dataset = chaiDataset(train_encodings), chaiDataset(dev_encodings)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=2)
# batch_size=1 here: the inference cell reads row 0 of each batch only.
dev_loader = DataLoader(dev_dataset, shuffle=False, batch_size=1)

In [None]:
# Load the pretrained QA checkpoint (xlm-roberta-large-squad2) and freeze the
# encoder, so only the parameters outside ``model.roberta`` are trained.
model = XLMRobertaForQuestionAnswering.from_pretrained(
    '../input/huggingface-question-answering-models/multilingual/xlm-roberta-large-squad2'
)
model.roberta.requires_grad_(False)

model.to(device)

In [None]:
# Only parameters that still require grad (the encoder is frozen above) are
# optimized. torch.optim.AdamW replaces the deprecated transformers.AdamW;
# eps=1e-6 and weight_decay=0.0 keep that class's defaults.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optim = torch.optim.AdamW(trainable_params, lr=1e-4, eps=1e-6, weight_decay=0.0)

NUM_EPOCHS = 40

model.train()

for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    n_batches = 0
    for batch in train_loader:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        loss = outputs[0]
        # When every label in the batch is out of range (answers truncated away),
        # all targets are ignored and the mean loss can come out NaN. Stepping on
        # it would corrupt the weights and the epoch average, so skip the batch.
        if not torch.isfinite(loss):
            continue
        loss.backward()
        optim.step()
        epoch_loss += loss.item()
        n_batches += 1
    # Average over the batches that actually contributed a loss.
    normalized_epoch_loss = epoch_loss / max(n_batches, 1)
    print("Epoch {} ; Epoch loss: {} ".format(epoch + 1, normalized_epoch_loss))

In [None]:
# Pickle the whole module object (architecture + weights) to ./model_2.
# A previous run's file was loaded with torch.load("../input/chaii-torch/model_2");
# torch.load unpickles, so only load files from a trusted source.
torch.save(obj=model, f="model_2")

In [None]:
model.eval()
output_words = []
# Inference only: no_grad avoids building an autograd graph for every example.
with torch.no_grad():
    for batch in dev_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask)
        # dev_loader uses batch_size=1, so row 0 is the only example.
        start_logits = outputs["start_logits"][0]
        end_logits = outputs["end_logits"][0]
        start = int(torch.argmax(start_logits))
        # Choose the end among positions >= start. Independent argmaxes can give
        # end < start, which slices to nothing and yields an empty prediction.
        end = start + int(torch.argmax(end_logits[start:]))
        output_tokens = tokenizer.convert_ids_to_tokens(input_ids[0][start:end + 1])
        output_words.append(tokenizer.convert_tokens_to_string(output_tokens))

In [None]:
# One row per test example: its id and the predicted answer text.
dev_results = pd.DataFrame(data={"id": dev_data["id"], "PredictionString": output_words})
dev_results.to_csv(path_or_buf="submission.csv", index=False)