<a href="https://colab.research.google.com/github/GrigoryBartosh/hse07_nlp/blob/master/4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Hugging Face `transformers`, which provides BertTokenizer / BertModel used below.
! pip install transformers

In [None]:
import os
import re
import pandas as pd
import datetime

import numpy as np
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from transformers import BertTokenizer, BertModel

from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt

import logging
# Raise the root logger's level: only CRITICAL records are emitted by loggers
# that inherit the root level.
logging.basicConfig(level=logging.CRITICAL)

# ---- file locations (relative to the working directory) ----
PATH_LOGS = os.path.join('data', 'logs_tf')            # TensorBoard runs
PATH_DATASET = os.path.join('data', 'train_qa.csv')    # training CSV
PATH_DATASET_TEST = os.path.join('data', 'test.txt')   # tab-separated test questions
PATH_RESULTS = os.path.join('data', 'results.txt')     # predictions output

# ---- model / training hyper-parameters ----
MAX_TEXT_LEN = 256            # max tokens per BERT input, including [CLS] and both [SEP]
EPOCHS_1, EPOCHS_2 = 1, 3     # phase 1: frozen BERT; phase 2: full fine-tuning
BATCH_SIZE = 16
LEARNING_RATE_1 = LEARNING_RATE_2 = 1e-4
W_L2_NORM = 0.0               # passed to Adam as weight_decay

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Training data: one row per (paragraph, question, answer) triple.
dataset = pd.read_csv(PATH_DATASET)
dataset_texts, dataset_questions, dataset_answers = (
    dataset[column] for column in ('paragraph', 'question', 'answer')
)

In [None]:
# WordPiece tokenizer of the cased multilingual BERT; lower-casing stays off.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=False)

In [None]:
def prepare_sample(text, question, answer):
    """Turn one (paragraph, question, answer) row into BERT QA training samples.

    The answer is located case-insensitively in ``text`` after stripping
    leading '.' and trailing '.'/'?'. The paragraph is WordPiece-tokenized.
    If it does not fit next to the question within MAX_TEXT_LEN tokens, it is
    cut into overlapping windows of ``3 * part_length`` tokens, one starting
    every ``part_length`` tokens.

    Returns:
        (tokens, labels): parallel lists with one entry per window.
        ``tokens[i]`` is ``[CLS] question [SEP] window [SEP]``.
        ``labels[i]`` is ``((start, start_ok), (end, end_ok))``, with positions
        indexing ``tokens[i]``. ``*_ok`` is 1 only if the whole answer lies
        inside that window; otherwise it is 0 and the position is meaningless.
        Both lists are empty when no usable answer span is found.
    """
    if not isinstance(answer, str):
        return [], []  # e.g. a missing CSV cell that pandas read as NaN
    answer = answer.lower().lstrip('.').rstrip('.?')
    if not answer.strip():
        # Nothing left to locate. The old while-loops raised IndexError on ''.
        return [], []

    # NOTE(review): str.lower() can change string length for a few characters
    # (e.g. 'İ'), which would misalign these offsets with ``text``.
    start_char = text.lower().find(answer)
    if start_char == -1:
        return [], []
    end_char = start_char + len(answer)

    # Tokenize prefix / answer / suffix separately so the answer's token
    # boundaries are known exactly.
    text_tokens = tokenizer.tokenize(text[:start_char].strip())
    first = len(text_tokens)
    text_tokens += tokenizer.tokenize(text[start_char:end_char].strip())
    last = len(text_tokens) - 1  # inclusive
    if last < first:
        return [], []  # the answer produced no tokens
    text_tokens += tokenizer.tokenize(text[end_char:].strip())

    question_tokens = tokenizer.tokenize(question)
    length = MAX_TEXT_LEN - len(question_tokens) - 3  # room left after [CLS], [SEP], [SEP]

    if len(text_tokens) <= length:
        windows = [text_tokens]
        spans = [((first, 1), (last, 1))]
    else:
        part_length = length // 3
        if part_length <= 0:
            return [], []  # the question leaves no room for any context
        window_len = 3 * part_length
        n_windows = -(-len(text_tokens) // part_length) - 2  # ceil(n / part_length) - 2
        windows, spans = [], []
        for i in range(n_windows):
            start = i * part_length
            window = text_tokens[start:start + window_len]
            lfirst, llast = first - start, last - start
            ok = int(0 <= lfirst < len(window) and 0 <= llast < len(window))
            windows.append(window)
            spans.append(((lfirst if ok else 0, ok), (llast if ok else 0, ok)))

    offset = 2 + len(question_tokens)  # [CLS] + question + [SEP] precede the context
    tokens = [
        [tokenizer.cls_token] + question_tokens + [tokenizer.sep_token] + window + [tokenizer.sep_token]
        for window in windows
    ]
    labels = [
        ((s + offset, s_ok), (e + offset, e_ok))
        for (s, s_ok), (e, e_ok) in spans
    ]
    return tokens, labels

In [None]:
# Split at the paragraph level *before* windowing. One paragraph yields several
# overlapping windows and may carry several questions. Splitting the windowed
# samples directly would put near-identical context into both train and
# validation and make the validation loss optimistic.
unique_paragraphs = list(dict.fromkeys(dataset_texts))  # de-duplicated, first-seen order
val_paragraphs = set(train_test_split(unique_paragraphs, test_size=0.1, random_state=42)[1])

x_train, y_train, x_val, y_val = [], [], [], []
for text, question, answer in tqdm(list(zip(dataset_texts, dataset_questions, dataset_answers))):
    if not isinstance(answer, str):
        # e.g. an empty CSV cell, which pandas reads as NaN: no span to learn.
        continue
    tokens, labels = prepare_sample(text, question, answer)
    if text in val_paragraphs:
        x_val += tokens
        y_val += labels
    else:
        x_train += tokens
        y_train += labels

dataset_tokens = x_train + x_val
dataset_labels = y_train + y_val
train_data = list(zip(x_train, y_train))
val_data = list(zip(x_val, y_val))

In [None]:
def text_collate_fn(texts):
    """Right-pad token lists to a common length and convert them to id tensors.

    Returns ``(ids, attention_mask)``, both LongTensors of shape
    (len(texts), longest). The mask is 1 on real tokens and 0 on padding.
    """
    longest = max(map(len, texts))
    padded_ids, attention = [], []
    for tokens in texts:
        n_pad = longest - len(tokens)
        attention.append([1] * len(tokens) + [0] * n_pad)
        padded_ids.append(tokenizer.convert_tokens_to_ids(tokens + [tokenizer.pad_token] * n_pad))
    return torch.LongTensor(padded_ids), torch.LongTensor(attention)

def collate_fn(data):
    """Collate ``(tokens, labels)`` samples into padded training tensors.

    Each sample is ``(tokens, ((start, start_ok), (end, end_ok)))``, as built
    by ``prepare_sample``. Returns
    ``(ids, attention_mask, start_pos, start_ok, end_pos, end_ok)``. The last
    four are 1-D LongTensors with one entry per sample.
    """
    token_lists = [sample[0] for sample in data]
    spans = [sample[1] for sample in data]
    ids, attention = text_collate_fn(token_lists)

    start_pos = torch.LongTensor([span[0][0] for span in spans])
    start_ok = torch.LongTensor([span[0][1] for span in spans])
    end_pos = torch.LongTensor([span[1][0] for span in spans])
    end_ok = torch.LongTensor([span[1][1] for span in spans])

    return ids, attention, start_pos, start_ok, end_pos, end_ok

def infinit_data_loader(data_loader):
    """Yield items from ``data_loader`` forever, restarting it after each pass.

    If a full pass yields nothing, the generator ends (``next`` raises
    StopIteration) instead of spinning forever on an empty loader.
    """
    while True:
        empty = True
        for x in data_loader:
            empty = False
            yield x
        if empty:
            return

# Both loaders share the same settings. The validation loader is shuffled too,
# presumably because train() draws single batches from it with next().
train_data_loader, val_data_loader = (
    data.DataLoader(
        dataset=split,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
        collate_fn=collate_fn,
    )
    for split in (train_data, val_data)
)

# Wrap the validation loader so next() keeps producing batches across epochs.
val_data_loader = infinit_data_loader(val_data_loader)

In [None]:
class TextClassifier(nn.Module):
    """BERT encoder followed by a small MLP that scores every token twice.

    ``forward`` returns a (batch, seq_len, 2) tensor. Each channel is a
    softmax over the unmasked positions of the sequence. In this notebook,
    channel 0 is used for the answer start and channel 1 for the answer end.
    BERT's weights start frozen; only the MLP head is trainable initially.
    """

    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-multilingual-cased')

        # Freeze the encoder; the head trains first (it can be unfrozen later).
        for param in self.bert.parameters():
            param.requires_grad = False

        # 768 is assumed to be the hidden size of bert-base.
        layers = [
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(16, 2)
        ]
        self.layers = nn.Sequential(*layers)

    def masked_softmax(self, vec, mask, dim=1):
        """Softmax of ``vec`` along ``dim``, restricted to positions where ``mask`` != 0.

        ``mask`` must broadcast against ``vec``. Masked positions get
        probability 0, and a slice with no valid position is all zeros.
        """
        keep = mask != 0
        # Fill masked entries with the most negative finite value instead of 0.
        # The max-subtraction inside softmax then uses the max over *valid*
        # entries, so very negative valid logits no longer underflow to an
        # all-zero distribution. A finite value (not -inf) keeps fully masked
        # slices NaN-free.
        filled = vec.masked_fill(~keep, torch.finfo(vec.dtype).min)
        return torch.softmax(filled, dim=dim) * keep.to(vec.dtype)

    def forward(self, text, mask):
        """Return per-token start/end probabilities, shape (batch, seq_len, 2).

        ``text`` holds token ids and ``mask`` the attention mask (1 = real token).
        """
        x = self.bert(text, attention_mask=mask)[0]  # last hidden states
        x = self.layers(x)
        x = self.masked_softmax(x, mask[:, :, None])
        return x

In [None]:
class MaskedLoss(nn.Module):
    """Per-position binary cross-entropy for one answer boundary (start or end).

    Expected inputs (as used in ``train``):
      * output:      (B, L) per-position probabilities
      * output_mask: (B, L) 1 on real tokens, 0 on padding
      * target:      (B,)   gold position per sample
      * target_mask: (B,)   1 if the sample actually contains the answer

    Samples with ``target_mask == 0`` contribute nothing to the loss.
    """
    EPS = 1e-8

    def __init__(self):
        super().__init__()

    def forward(self, output, output_mask, target, target_mask):
        # Negatives: every real token of an answer-bearing sample *except* the
        # gold one. Counting the gold token here too would push its
        # probability toward 0 while the positive term pushes it toward 1.
        neg_mask = (output_mask * target_mask[:, None]).float()
        neg_mask = neg_mask.scatter(1, target[:, None], 0.0)
        neg_loss = -torch.log(1 - output + MaskedLoss.EPS) * neg_mask

        # squeeze(1) keeps a (B,) shape even when B == 1.
        pos_prob = torch.gather(output, 1, target[:, None]).squeeze(1)
        pos_loss = -torch.log(pos_prob + MaskedLoss.EPS) * target_mask

        # clamp(min=1): if no sample in the batch contains the answer, the
        # plain division is 0/0 = NaN and would poison the weights.
        return (neg_loss.sum() / neg_mask.sum().clamp(min=1)
                + pos_loss.sum() / target_mask.sum().clamp(min=1))

In [None]:
def make_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist."""
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(path, exist_ok=True)

def get_summary_writer():
    """Return a SummaryWriter that logs into a new timestamped directory under PATH_LOGS."""
    # str(datetime.now()) contains ':' characters, which Windows forbids in
    # file and directory names. Use a filesystem-safe timestamp instead.
    name = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    make_dir(PATH_LOGS)
    logs_path = os.path.join(PATH_LOGS, name)
    return SummaryWriter(logs_path)

In [None]:
def _batch_loss(model, criterion, batch):
    """Move a ``collate_fn`` batch to ``device`` and return start loss + end loss."""
    texts, masks, lfp, lfm, llp, llm = (t.to(device) for t in batch)
    ps = model(texts, masks)
    return criterion(ps[:, :, 0], masks, lfp, lfm) + \
           criterion(ps[:, :, 1], masks, llp, llm)


def train(model, criterion, optimizer, scheduler, epochs):
    """Train ``model`` for ``epochs`` passes over ``train_data_loader``.

    The train loss is logged to TensorBoard for every batch. After roughly
    every ``10 * BATCH_SIZE`` training samples, the loss on one batch drawn
    from ``val_data_loader`` is logged too. ``scheduler`` (if given) is
    stepped once per epoch. The summary writer is always closed, so buffered
    events are flushed even if training is interrupted.
    """
    summary_writer = get_summary_writer()
    step = 0       # number of training samples seen so far (x-axis of the logs)
    last_val = 0   # samples seen since the last validation batch

    try:
        for _ in trange(epochs):
            model.train()
            for batch in train_data_loader:
                optimizer.zero_grad()
                loss = _batch_loss(model, criterion, batch)
                loss.backward()
                optimizer.step()

                batch_size = len(batch[0])
                step += batch_size
                last_val += batch_size
                summary_writer.add_scalar('Train/loss', loss.item(), step)

                if last_val >= 10 * BATCH_SIZE:
                    model.eval()
                    with torch.no_grad():
                        val_loss = _batch_loss(model, criterion, next(val_data_loader))
                    summary_writer.add_scalar('Validation/loss', val_loss.item(), step)
                    model.train()
                    last_val = 0

            if scheduler is not None:
                scheduler.step()
    finally:
        summary_writer.close()

In [None]:
# Phase 1: BERT stays frozen (TextClassifier disables its gradients), so only
# the MLP head is optimised.
model = TextClassifier().to(device)
criterion = MaskedLoss()

optimizer = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=LEARNING_RATE_1,
    weight_decay=W_L2_NORM,
)
# Multiply the learning rate by 0.3 each time train() steps the scheduler (once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.3)

train(model, criterion, optimizer, scheduler, EPOCHS_1)

In [None]:
# Optional: extra frozen-BERT epochs. Note that train() also takes the scheduler:
#train(model, criterion, optimizer, scheduler, 4)

In [None]:
# Phase 2: unfreeze every parameter (BERT included) and fine-tune end to end.
model.requires_grad_(True)

optimizer = optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=LEARNING_RATE_2,
    weight_decay=W_L2_NORM,
)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.3)

train(model, criterion, optimizer, scheduler, EPOCHS_2)

In [None]:
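# Optional extra fine-tuning pass at a lower learning rate. If re-enabled, also
# create a new StepLR for this optimizer: the existing `scheduler` still wraps
# the previous optimizer.
#optimizer = optim.Adam(
#    filter(lambda p: p.requires_grad, model.parameters()),
#    0.00001,
#    weight_decay=W_L2_NORM
#)

#train(model, criterion, optimizer, scheduler, 1)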

In [None]:
def get_best(ps, max_answer_len=None):
    """Pick the span maximising ``ps[first, 0] * ps[last, 1]`` with ``first <= last``.

    Args:
        ps: array-like of shape (n, 2). Column 0 holds start scores and
            column 1 holds end scores.
        max_answer_len: optional cap on the span length in tokens
            (``last - first + 1``). None means no cap.

    Returns:
        ``(first, last, score)``. Ties go to the smallest ``first``, then the
        smallest ``last``, as in a row-major scan. Returns ``(0, 0, 0)`` when
        no allowed span has a positive score, including when ``n == 0``.
    """
    ps = np.asarray(ps)
    n = len(ps)
    if n == 0:
        return 0, 0, 0

    # All pairwise products at once instead of a Python double loop.
    products = np.outer(ps[:, 0], ps[:, 1])
    starts, ends = np.indices((n, n))
    invalid = ends < starts
    if max_answer_len is not None:
        invalid |= ends - starts + 1 > max_answer_len

    # argmax returns the first maximum in row-major order, which matches the
    # strict '<' comparison of a nested i/j loop.
    flat = int(np.argmax(np.where(invalid, -np.inf, products)))
    first, last = divmod(flat, n)
    mx = products[first, last]
    if invalid[first, last] or not mx > 0:
        return 0, 0, 0
    return first, last, mx

def _predict_window(question_tokens, context_tokens):
    """Score one ``[CLS] question [SEP] context [SEP]`` input with the model.

    Returns ``(first, last, score)`` from ``get_best``, where first/last index
    into ``context_tokens``. Only context positions are searched, so the
    answer can never be [CLS], [SEP] or a question token.
    """
    window = ([tokenizer.cls_token] + question_tokens + [tokenizer.sep_token]
              + context_tokens + [tokenizer.sep_token])
    texts, masks = text_collate_fn([window])
    # The model already returns a masked softmax over positions. The previous
    # extra F.softmax on top of it distorted the start*end span scores.
    ps = model(texts.to(device), masks.to(device))[0]
    offset = 2 + len(question_tokens)  # [CLS] + question + [SEP]
    ps = ps[offset:offset + len(context_tokens)].cpu().numpy()
    return get_best(ps)


def _detokenize(tokens):
    """Join WordPiece tokens back into text, dropping [UNK] tokens."""
    s = ''
    for token in tokens:
        if token == tokenizer.unk_token:
            continue
        if token.startswith('##'):
            # Continuation piece: glue it to the previous one without the
            # marker. The old replace('#', '') also deleted genuine '#' chars.
            s += token[2:]
        else:
            s += ' ' + token
    return s.strip()


with open(PATH_DATASET_TEST, 'r', encoding='utf-8') as file:
    dataset_test = file.readlines()[1:]  # skip the header line

res = []
model.eval()
with torch.no_grad():
    for sample in tqdm(dataset_test):
        if not sample.strip():
            continue  # tolerate blank lines, e.g. at the end of the file
        _, question_id, paragraph, question = sample.rstrip('\r\n').split('\t')

        question_tokens = tokenizer.tokenize(question)
        text_tokens = tokenizer.tokenize(paragraph)

        length = MAX_TEXT_LEN - len(question_tokens) - 3  # room left after [CLS], [SEP], [SEP]
        if len(text_tokens) > length:
            # Same sliding windows as prepare_sample: width 3 * part_length,
            # one starting every part_length tokens. Keep the best-scoring span.
            part_length = max(length // 3, 1)
            n_windows = -(-len(text_tokens) // part_length) - 2  # ceil(n / part_length) - 2
            first, last, mx = 0, 0, 0
            for i in range(n_windows):
                start = i * part_length
                window = text_tokens[start:start + 3 * part_length]
                lfirst, llast, lmx = _predict_window(question_tokens, window)
                if mx < lmx:
                    mx, first, last = lmx, lfirst + start, llast + start
        else:
            first, last, _ = _predict_window(question_tokens, text_tokens)

        # first/last index into text_tokens, so no special tokens can leak in.
        res += [question_id + '\t' + _detokenize(text_tokens[first:last + 1])]

res = '\n'.join(res)
with open(PATH_RESULTS, 'w', encoding='utf-8') as file:
    file.write(res)