In [None]:
! pip install transformers

In [None]:
import os
import re
import pandas as pd

import numpy as np
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F

from transformers import BertTokenizer, BertModel

from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt

import logging
logging.basicConfig(level=logging.CRITICAL)

PATH_DATASET = os.path.join('data', 'train_qa.csv')

# Full BERT input length: [CLS] + paragraph window + [SEP] + question + [SEP].
MAX_TEXT_LEN = 512

# Stage 1 (EPOCHS_1 / LEARNING_RATE_1) trains with BERT frozen;
# stage 2 (EPOCHS_2 / LEARNING_RATE_2) fine-tunes after unfreezing.
EPOCHS_1 = 40
EPOCHS_2 = 5
BATCH_SIZE = 32
LEARNING_RATE_1 = 0.0001
LEARNING_RATE_2 = 0.0001
W_L2_NORM = 0.0  # passed to Adam as weight_decay

# Seed every RNG the notebook draws from so Restart & Run All is reproducible:
# np.random drives the random window lengths in prepare_sample and, since no
# random_state is passed, train_test_split; torch drives DataLoader shuffling,
# dropout and the head's weight initialisation.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
dataset = pd.read_csv(PATH_DATASET)
# Three parallel columns: context paragraph, question text, gold answer text.
dataset_texts, dataset_questions, dataset_answers = (
    dataset[column] for column in ('paragraph', 'question', 'answer')
)

In [None]:
# WordPiece tokenizer of the multilingual cased checkpoint; do_lower_case=False
# keeps the original casing of the input text.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=False)

In [None]:
def _window_label(pos, window_len, part_length, is_first, is_last):
    """Label one answer-boundary token with respect to one window.

    Args:
        pos: token index relative to the window start (may be negative or
            past the end of the window).
        window_len: number of text tokens in this window.
        part_length: size of one third of a full window.
        is_first: whether this is the first window.
        is_last: whether this is the last window.

    Returns:
        ``(index, valid)``. ``index`` is ``pos + 1`` (shifted past the leading
        [CLS]) when the token is inside the window, else 0 (the [CLS] slot).
        ``valid`` is 1 when this window "owns" the token: it sits in the
        middle third, in the first third of the first window, or in the last
        third of the last window. Otherwise ``valid`` is 0.
    """
    inside = 0 <= pos < window_len
    owned = (
        part_length <= pos < 2 * part_length
        or (is_first and pos < part_length)
        or (is_last and pos >= 2 * part_length)
    )
    return (pos + 1 if inside else 0), int(owned)


def prepare_sample(text, question, answer):
    """Tokenize one (paragraph, question, answer) triple into QA training samples.

    The answer is located case-insensitively in ``text``. If the tokenized
    paragraph is longer than a randomly drawn token budget, it is cut into
    overlapping windows of ``3 * part_length`` tokens that advance by
    ``part_length``. Each window becomes its own sample.

    Every sample is ``[CLS] <text tokens> [SEP] <question tokens> [SEP]``.

    Returns:
        ``(tokens, labels)``, two parallel lists. ``labels[i]`` is
        ``((start_pos, start_valid), (end_pos, end_valid))``. The positions
        index into ``tokens[i]`` (they account for the leading [CLS]) and
        point at the first / last answer token, or are 0 when that token lies
        outside the window. A ``*_valid`` flag is 1 only in the window that
        owns the token (see ``_window_label``).

        Both lists are empty when:
        - any field is not a string (e.g. a NaN cell);
        - the answer is empty after stripping punctuation or tokenizes to nothing;
        - the answer does not occur in ``text``;
        - the question leaves no room for a window.
    """
    if not all(isinstance(field, str) for field in (text, question, answer)):
        return [], []

    # Drop leading dots and trailing dots / question marks. Using strip methods
    # instead of answer[0] / answer[-1] avoids IndexError when nothing is left.
    answer = answer.lower().lstrip('.').rstrip('.?')
    text_lower = text.lower()
    if not answer or answer not in text_lower:
        return [], []

    # NOTE(review): offsets are found in text.lower() and applied to text.
    # This assumes lowercasing preserves string length, which fails for a few
    # characters (e.g. 'İ').
    start_char = text_lower.find(answer)
    end_char = start_char + len(answer)

    # Tokenize prefix / answer / suffix separately so the answer's token
    # boundaries are known exactly.
    text_tokens = tokenizer.tokenize(text[:start_char].strip())
    first = len(text_tokens)
    text_tokens += tokenizer.tokenize(text[start_char:end_char].strip())
    last = len(text_tokens) - 1
    text_tokens += tokenizer.tokenize(text[end_char:].strip())
    if last < first:
        return [], []  # the answer span produced no tokens

    question_tokens = tokenizer.tokenize(question)

    # Text-token budget left after [CLS], [SEP], the question and the final [SEP].
    # It is drawn at random (at least 200 when possible) to vary window sizes.
    max_length = MAX_TEXT_LEN - len(question_tokens) - 3
    min_length = min(max_length, 200)
    if min_length < max_length:
        length = np.random.randint(min_length, max_length)
    else:
        length = max_length  # randint(low, high) requires low < high

    if len(text_tokens) <= length:
        tokens = [text_tokens]
        # +1: positions index the final sequence, which starts with [CLS].
        labels = [((first + 1, 1), (last + 1, 1))]
    else:
        part_length = length // 3
        if part_length <= 0:
            return [], []  # the question leaves no room for a window
        window = 3 * part_length
        # ceil(N / part_length) - 2 windows, so the last one reaches token N-1.
        n_windows = -(-len(text_tokens) // part_length) - 2
        tokens, labels = [], []
        for i in range(n_windows):
            offset = i * part_length
            window_tokens = text_tokens[offset:offset + window]
            flags = dict(
                window_len=len(window_tokens),
                part_length=part_length,
                is_first=(i == 0),
                is_last=(i == n_windows - 1),
            )
            tokens.append(window_tokens)
            labels.append((_window_label(first - offset, **flags),
                           _window_label(last - offset, **flags)))

    tokens = [
        [tokenizer.cls_token] + window_tokens + [tokenizer.sep_token]
        + question_tokens + [tokenizer.sep_token]
        for window_tokens in tokens
    ]
    return tokens, labels

In [None]:
def _to_windows(samples, desc):
    """Run prepare_sample over (text, question, answer) triples; return flat (tokens, labels)."""
    all_tokens, all_labels = [], []
    for text, question, answer in tqdm(samples, desc=desc):
        tokens, labels = prepare_sample(text, question, answer)
        all_tokens += tokens
        all_labels += labels
    return all_tokens, all_labels


# Split the raw (paragraph, question, answer) triples *before* windowing.
# prepare_sample can turn one triple into several overlapping windows, and
# splitting at window level would leak near-duplicate windows of the same
# question into both train and val, making the validation loss optimistic.
raw_samples = list(zip(dataset_texts, dataset_questions, dataset_answers))
raw_train, raw_val = train_test_split(raw_samples, test_size=0.1)

x_train, y_train = _to_windows(raw_train, 'train samples')
x_val, y_val = _to_windows(raw_val, 'val samples')

# Kept for any later cell that expects the full windowed dataset.
dataset_tokens = x_train + x_val
dataset_labels = y_train + y_val

train_data = list(zip(x_train, y_train))
val_data = list(zip(x_val, y_val))

In [None]:
def text_collate_fn(texts):
    """Right-pad token lists to the longest one and convert them to tensors.

    Args:
        texts: sequence of token lists (strings understood by ``tokenizer``).

    Returns:
        ``(ids, mask)``, two LongTensors of shape (len(texts), longest).
        ``ids`` holds token ids with ``tokenizer.pad_token`` ids as padding.
        ``mask`` is 1 on real tokens and 0 on padding.
    """
    longest = max(len(seq) for seq in texts)
    id_rows, mask_rows = [], []
    for seq in texts:
        n_pad = longest - len(seq)
        mask_rows.append([1] * len(seq) + [0] * n_pad)
        id_rows.append(tokenizer.convert_tokens_to_ids(seq + [tokenizer.pad_token] * n_pad))
    return torch.LongTensor(id_rows), torch.LongTensor(mask_rows)

def collate_fn(data):
    """DataLoader collate: batch (tokens, label) pairs produced by prepare_sample.

    Each label is ``((start_pos, start_valid), (end_pos, end_valid))``.

    Returns:
        ``(texts, masks, start_pos, start_valid, end_pos, end_valid)``.
        ``texts`` and ``masks`` come from ``text_collate_fn``; the other four
        are 1-D LongTensors with one entry per sample.
    """
    token_lists, label_pairs = zip(*data)
    texts, masks = text_collate_fn(token_lists)

    starts, ends = zip(*label_pairs)
    start_pos, start_valid = zip(*starts)
    end_pos, end_valid = zip(*ends)

    to_long = torch.LongTensor
    return (texts, masks,
            to_long(start_pos), to_long(start_valid),
            to_long(end_pos), to_long(end_valid))

# Only the training windows are shuffled; the validation loader keeps a fixed
# order, so every epoch's validation loss is computed over the same batches.
train_data_loader, val_data_loader = (
    data.DataLoader(
        dataset=split,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )
    for split, shuffle in ((train_data, True), (val_data, False))
)

In [None]:
class TextClassifier(nn.Module):
    """Per-token span scorer: a pretrained BERT encoder followed by a small MLP head.

    For every input token the head emits two sigmoid scores. ``train`` uses
    channel 0 against the answer-start labels and channel 1 against the
    answer-end labels.

    Args:
        model_name: pretrained BERT checkpoint to load.
        freeze_bert: if True (default), BERT parameters start with
            ``requires_grad = False`` so only the head is trained; they can be
            unfrozen later by setting ``requires_grad = True``.
    """

    def __init__(self, model_name='bert-base-multilingual-cased', freeze_bert=True):
        super().__init__()

        self.bert = BertModel.from_pretrained(model_name)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # Read the encoder width from the checkpoint instead of hard-coding 768,
        # so the head always matches the loaded BERT variant.
        hidden_size = self.bert.config.hidden_size

        layers = [
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(16, 2),
            nn.Sigmoid(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, text, mask):
        """Score every token.

        Args:
            text: LongTensor of token ids (as built by ``text_collate_fn``).
            mask: attention mask of the same shape, 1 for real tokens.

        Returns:
            Per-token scores in [0, 1] with a trailing dimension of 2, presumably
            (batch, seq_len, 2); ``train`` indexes it as ``ps[:, :, k]``.
        """
        # Element 0 of the BERT output is the last hidden layer, one vector per token.
        x = self.bert(text, attention_mask=mask)[0]
        return self.layers(x)

In [None]:
class MaskedLoss(nn.Module):
    """Span-boundary loss over per-token scores in [0, 1].

    ``target`` / ``target_mask`` say whether (and where) a window contains the
    boundary token:

    * valid target (mask 1): negative log of the masked-softmax probability
      at ``target``, averaged over valid samples;
    * no valid target (mask 0): ``-log(1 - p)`` on every attended position,
      pushing all scores toward 0, averaged over those positions.
    """
    EPS = 1e-8

    def __init__(self):
        super().__init__()

    def masked_softmax(self, vec, mask, dim=1):
        """Softmax of ``vec`` along ``dim`` over positions where ``mask`` is 1.

        Masked positions get probability 0. A row with no unmasked position
        returns all zeros instead of dividing by zero.
        """
        mask = mask.float()
        masked_vec = vec * mask
        # Subtract the row max for numerical stability.
        max_vec = torch.max(masked_vec, dim=dim, keepdim=True)[0]
        masked_exps = torch.exp(masked_vec - max_vec) * mask
        masked_sums = masked_exps.sum(dim, keepdim=True)
        # Fully-masked rows have a zero sum; add 1 there to avoid 0/0.
        masked_sums = masked_sums + (masked_sums == 0).float()
        return masked_exps / masked_sums

    def forward(self, output, output_mask, target, target_mask):
        """Compute the loss for one boundary (start or end).

        Args:
            output: (batch, seq_len) scores in [0, 1]; one channel of the model output.
            output_mask: (batch, seq_len) attention mask, 1 for real tokens.
            target: (batch,) long index of the boundary token in each sequence.
            target_mask: (batch,) 1 where ``target`` is valid, else 0.

        Returns:
            Scalar loss tensor.
        """
        output_mask = output_mask.float()
        target_mask = target_mask.float()

        # Samples without a valid target: every attended score should be 0.
        neg_mask = output_mask * (1 - target_mask)[:, None]
        neg_loss = -torch.log(1 - output + MaskedLoss.EPS) * neg_mask

        # Samples with a valid target: maximise the masked-softmax probability
        # at the target. Gather from the softmax (not the raw output), take
        # -log so minimising the loss raises that probability, and squeeze to
        # (batch,) so the per-sample mask does not broadcast to (batch, batch).
        probs = self.masked_softmax(output, output_mask)
        target_probs = torch.gather(probs, 1, target[:, None]).squeeze(1)
        pos_loss = -torch.log(target_probs + MaskedLoss.EPS) * target_mask

        # clamp(min=1): an all-valid or all-invalid batch must not produce 0/0 = NaN.
        return (neg_loss.sum() / neg_mask.sum().clamp(min=1)
                + pos_loss.sum() / target_mask.sum().clamp(min=1))

In [None]:
def _batch_loss(model, criterion, batch):
    """Move one collated batch to ``device`` and return the start + end loss."""
    texts, masks, lfp, lfm, llp, llm = (t.to(device) for t in batch)
    ps = model(texts, masks)
    # Channel 0 scores answer starts, channel 1 answer ends.
    return criterion(ps[:, :, 0], masks, lfp, lfm) + criterion(ps[:, :, 1], masks, llp, llm)


def train(model, criterion, optimizer, epochs, train_loader=None, val_loader=None):
    """Train for ``epochs`` epochs, validating after each, then plot the loss curves.

    Args:
        model: network returning per-token start/end scores.
        criterion: loss called as ``criterion(scores, masks, positions, valid)``.
        optimizer: optimizer over the trainable parameters.
        epochs: number of passes over the training data.
        train_loader: batches from ``collate_fn``; defaults to ``train_data_loader``.
        val_loader: batches from ``collate_fn``; defaults to ``val_data_loader``.
    """
    if train_loader is None:
        train_loader = train_data_loader
    if val_loader is None:
        val_loader = val_data_loader

    losses_train, losses_val = [], []
    for _ in trange(epochs):
        model.train()
        batch_losses = []
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _batch_loss(model, criterion, batch)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        losses_train.append(np.array(batch_losses).mean())

        # Validation must unpack the full 6-tensor batch and use its own labels,
        # not those left over from the last training batch.
        model.eval()
        batch_losses = []
        with torch.no_grad():
            for batch in val_loader:
                batch_losses.append(_batch_loss(model, criterion, batch).item())
        losses_val.append(np.array(batch_losses).mean())

    fig, ax = plt.subplots()
    ax.plot(range(epochs), losses_train, label="train")
    ax.plot(range(epochs), losses_val, label="val")
    ax.set_xlabel('epoch num')
    ax.set_ylabel('loss')
    ax.set_title('Mean loss per epoch')
    ax.legend()
    plt.show()

In [None]:
# Stage 1: TextClassifier starts with BERT frozen, so only the parameters that
# still require gradients (the head) are handed to the optimizer.
model = TextClassifier().to(device)
criterion = MaskedLoss()

optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE_1,
    weight_decay=W_L2_NORM,
)

train(model, criterion, optimizer, EPOCHS_1)

In [None]:
# Sanity check for MaskedLoss: a per-sample (B,) vector indexed with [:, None]
# broadcasts row-wise against a (B, L) tensor. The invariants are asserted
# instead of dumping the whole 32x369 tensor into the notebook output.
a = torch.ones([32, 369])
b = torch.zeros([32])[:, None]
c = a * b
assert b.shape == (32, 1)
assert c.shape == (32, 369)
assert not c.any()

In [None]:
# Stage 2: unfreeze everything (BERT included) and fine-tune with a fresh optimizer.
model.requires_grad_(True)

optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE_2,
    weight_decay=W_L2_NORM,
)

train(model, criterion, optimizer, EPOCHS_2)

In [None]:
def prepare_test(text):
    """Tokenize raw text and remember which character span each token came from.

    The text is first split into pieces:
    - runs of letters;
    - runs of digits (letters right after a digit run are appended to it);
    - every other non-space character as its own piece.
    Spaces are dropped. Each piece is then passed through
    ``tokenizer.tokenize``, and every resulting token records the piece's
    ``(start, end)`` character offsets in ``text``.

    Returns:
        ``(tokens, tokens_poses)``: ``[CLS] ... [SEP]`` tokens and a parallel
        list of spans, with ``(-1, -1)`` for the [CLS] and [SEP] markers.
    """
    pieces = []  # (piece_text, (start, end)) in order of appearance
    current, current_start = '', 0
    for pos, ch in enumerate(text):
        extends_current = ch.isalpha() or (ch.isdigit() and current != '' and current[-1].isdigit())
        if extends_current:
            current += ch
            continue

        if current:
            pieces.append((current, (current_start, pos)))

        if ch.isdigit():
            current, current_start = ch, pos
            continue

        if ch != ' ':
            pieces.append((ch, (pos, pos + 1)))
        current, current_start = '', pos + 1

    # Flush the trailing piece (may be empty; it then yields no tokens).
    pieces.append((current, (current_start, len(text))))

    tokens, tokens_poses = [tokenizer.cls_token], [(-1, -1)]
    for piece, span in pieces:
        sub_tokens = tokenizer.tokenize(piece)
        tokens.extend(sub_tokens)
        tokens_poses.extend([span] * len(sub_tokens))
    tokens.append(tokenizer.sep_token)
    tokens_poses.append((-1, -1))

    return tokens, tokens_poses