<a href="https://colab.research.google.com/github/GrigoryBartosh/hse07_nlp/blob/master/2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
! pip install transformers

In [0]:
import os

import numpy as np
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F

from transformers import BertTokenizer, BertModel

from tqdm import tqdm
import matplotlib.pyplot as plt

import logging
# Only CRITICAL records reach the root logger's output; keeps notebook output quiet.
logging.basicConfig(level=logging.CRITICAL)

# --- Data files (all under ./data) ------------------------------------------
_DATA_DIR = 'data'
PATH_DATASET_TEXTS = os.path.join(_DATA_DIR, 'texts_train.txt')
PATH_DATASET_SCORES = os.path.join(_DATA_DIR, 'scores_train.txt')
PATH_DATASET_TEST = os.path.join(_DATA_DIR, 'test.txt')
PATH_RESULTS = os.path.join(_DATA_DIR, 'results.txt')

# --- Model input ------------------------------------------------------------
# Token budget per text, counting the added [CLS] and [SEP] tokens.
MAX_TEXT_LEN = 256

# --- Optimisation -----------------------------------------------------------
# Stage 1 trains only the head (BERT frozen); stage 2 fine-tunes everything.
EPOCHS_1, EPOCHS_2 = 10, 50
BATCH_SIZE = 32
LEARNING_RATE_1 = 1e-4
LEARNING_RATE_2 = 3e-5
W_L2_NORM = 0.0  # Adam weight_decay

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [0]:
# Texts and scores are line-aligned: line i of the scores file labels line i
# of the texts file (they are split and zipped together by index below).
# Read as UTF-8 explicitly instead of relying on the platform default.
with open(PATH_DATASET_TEXTS, 'r', encoding='utf-8') as file:
    dataset_texts = file.readlines()

with open(PATH_DATASET_SCORES, 'r', encoding='utf-8') as file:
    dataset_scores = file.readlines()

# Hold out 10% for validation; a fixed seed keeps the split reproducible
# between runs so validation curves are comparable.
x_train, x_val, y_train, y_val = train_test_split(
    dataset_texts, dataset_scores, test_size=0.1, random_state=42)
train_data = list(zip(x_train, y_train))
val_data = list(zip(x_val, y_val))

In [0]:
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-multilingual-cased',
    do_lower_case=False
)

# Maps each punctuation mark to itself surrounded by spaces, so it becomes a
# separate whitespace-delimited token before WordPiece tokenization.
_PUNCT_SPACING = str.maketrans({c: f' {c} ' for c in '.,!?"()'})


def text_collate_fn(texts):
    """Turn raw text lines into a padded batch of BERT input ids and masks.

    Each text has its trailing newline (if any) removed, punctuation split off
    into standalone tokens and runs of whitespace collapsed. It is then
    WordPiece-tokenized and truncated so that ``[CLS] + tokens + [SEP]`` fits in
    ``MAX_TEXT_LEN``. Finally it is right-padded with ``tokenizer.pad_token_id``
    to the longest sequence in the batch.

    Args:
        texts: non-empty sequence of strings (e.g. lines from ``readlines()``).

    Returns:
        ``(ids, masks)``: two ``LongTensor``s of shape ``(len(texts), max_len)``.
        ``masks`` is 1 at real tokens and 0 at padding.
    """
    batch_ids = []
    for text in texts:
        # Strip only the line terminator. The last line of a file may have no
        # '\n', and unconditionally dropping the final character (text[:-1])
        # would cut off real content.
        text = text.rstrip('\n')
        text = ' '.join(text.translate(_PUNCT_SPACING).split())
        # Reserve two positions for [CLS] and [SEP].
        tokens = tokenizer.tokenize(text)[:MAX_TEXT_LEN - 2]
        ids = tokenizer.convert_tokens_to_ids(tokens)
        batch_ids.append([tokenizer.cls_token_id] + ids + [tokenizer.sep_token_id])

    max_len = max(len(ids) for ids in batch_ids)
    masks = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch_ids]
    padded = [ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids in batch_ids]

    return torch.LongTensor(padded), torch.LongTensor(masks)

def collate_fn(data):
    """Collate ``(text, score)`` pairs into a training batch.

    Texts go through ``text_collate_fn``. Each score string is parsed as an
    integer and shifted down by one (score 1 becomes 0.0), giving float32
    regression targets.

    Returns:
        ``(texts, masks, scores)``, where ``scores`` has shape ``(len(data),)``.
    """
    raw_texts = [pair[0] for pair in data]
    raw_scores = [pair[1] for pair in data]

    token_ids, attention = text_collate_fn(raw_texts)

    targets = torch.tensor([int(s) - 1 for s in raw_scores], dtype=torch.float32)
    return token_ids, attention, targets

# Both loaders share batch size and collation; only shuffling differs
# (training batches are reshuffled every epoch, validation order is fixed).
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)

train_data_loader = data.DataLoader(train_data, shuffle=True, **_loader_kwargs)
val_data_loader = data.DataLoader(val_data, shuffle=False, **_loader_kwargs)

In [0]:
class TextClassifier(nn.Module):
    """Score regressor: multilingual BERT encoder + small MLP head.

    The second output of ``BertModel`` (the pooled representation, 768-d)
    feeds an MLP whose final sigmoid is scaled by 9. Predictions therefore lie
    in ``(0, 9)``. The BERT weights are frozen at construction, so only the head
    trains until ``requires_grad`` is re-enabled on them.
    """

    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-multilingual-cased')
        # Freeze the encoder: the first training stage fits only the head.
        self.bert.requires_grad_(False)

        # (in_features, out_features, dropout probability after the ReLU)
        hidden_spec = ((768, 128, 0.5), (128, 32, 0.3), (32, 16, 0.2))
        head = []
        for n_in, n_out, drop in hidden_spec:
            head += [nn.Linear(n_in, n_out), nn.ReLU(), nn.Dropout(p=drop)]
        head += [nn.Linear(16, 1), nn.Sigmoid()]
        self.layers = nn.Sequential(*head)

    def forward(self, text, mask):
        """Return one prediction per batch row as a 1-D tensor."""
        pooled = self.bert(text, attention_mask=mask)[1]
        scaled = 9 * self.layers(pooled)
        return scaled.flatten()

In [0]:
class RMSELoss(nn.Module):
    """Root-mean-squared error measured on *rounded* predictions.

    Predictions are rounded with ``torch.round`` (ties to even) before being
    compared with the targets, so the value reflects the error of the final
    discrete scores. Rounding has zero gradient, so this serves as the
    reporting metric rather than a training objective.
    """

    def forward(self, x, y):
        residual = x.round() - y
        return residual.pow(2).mean().sqrt()

In [0]:
class MyLoss(nn.Module):
    """Training loss that flattens the penalty inside the rounding margin.

    Per element, with ``d = x - y``::

        s = min(d**2, |d|**k / margin**(k - 2))

    The two branches meet exactly at ``|d| == margin``. For ``|d| < margin``
    (with ``k > 2``) the high-power branch is the smaller one, so small errors
    are penalised very little. With the default margin of 0.5 these are the
    errors that still round to the correct integer. Outside the margin the
    penalty is plain squared error. The returned loss is ``sqrt(mean(s))``.

    Args:
        k: exponent of the inner branch; expected to be > 2 (default 8).
        margin: ``|d|`` at which the branches cross (default 0.5).
    """

    def __init__(self, k=8, margin=0.5):
        super().__init__()
        self.k = k
        self.margin = margin

    def forward(self, x, y):
        d = x - y
        squared = d ** 2
        # abs() keeps this branch non-negative for odd k. Otherwise min() would
        # select a negative value and sqrt() would return NaN. For even k it is
        # a no-op.
        high_power = d.abs() ** self.k / self.margin ** (self.k - 2)
        return torch.sqrt(torch.min(squared, high_power).mean())

In [0]:
def _batch_to_device(texts, masks, scores):
    """Move one collated ``(texts, masks, scores)`` batch onto ``device``."""
    return texts.to(device), masks.to(device), scores.to(device)


def _evaluate(model, metric, loader):
    """Return the mean per-batch ``metric`` of ``model`` over ``loader`` (no grad)."""
    model.eval()
    batch_metrics = []
    with torch.no_grad():
        for batch in loader:
            texts, masks, scores = _batch_to_device(*batch)
            batch_metrics.append(metric(model(texts, masks), scores).item())
    return np.mean(batch_metrics)


def _plot_history(metrics_train, metrics_val):
    """Plot per-epoch train and validation metric curves."""
    epoch_axis = range(len(metrics_train))
    plt.plot(epoch_axis, metrics_train, label="train")
    plt.plot(epoch_axis, metrics_val, label="val")
    plt.xlabel('epoch num')
    plt.ylabel('metric')
    plt.legend()
    plt.show()


def train(model, criterion, metric, optimizer, epochs, *,
          train_loader=None, val_loader=None):
    """Train ``model`` for ``epochs`` epochs and plot train/val metric curves.

    Each epoch does one optimisation pass over the training loader
    (minimising ``criterion``), then one evaluation pass over the validation
    loader. ``metric`` is reported as its mean over batches.

    Args:
        model: module called as ``model(texts, masks)``.
        criterion: differentiable loss ``criterion(preds, targets)``.
        metric: reporting function ``metric(preds, targets)`` returning a
            scalar tensor.
        optimizer: optimizer over the parameters to be trained.
        epochs: number of passes over the training data.
        train_loader / val_loader: optional overrides. When omitted, the
            module-level ``train_data_loader`` / ``val_data_loader`` are used.
    """
    if train_loader is None:
        train_loader = train_data_loader
    if val_loader is None:
        val_loader = val_data_loader

    metrics_train = []
    metrics_val = []
    for _ in tqdm(range(epochs)):
        model.train()
        batch_metrics = []
        for batch in train_loader:
            texts, masks, scores = _batch_to_device(*batch)

            optimizer.zero_grad()
            ps = model(texts, masks)
            loss = criterion(ps, scores)

            # The metric is only reported, so detach it from the autograd graph.
            batch_metrics.append(metric(ps.detach(), scores).item())

            loss.backward()
            optimizer.step()

        metrics_train.append(np.mean(batch_metrics))
        metrics_val.append(_evaluate(model, metric, val_loader))

    _plot_history(metrics_train, metrics_val)

In [0]:
# Stage 1: BERT is frozen by TextClassifier, so only the MLP head is optimised.
model = TextClassifier().to(device)

criterion, metric = MyLoss(), RMSELoss()

optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE_1,
    weight_decay=W_L2_NORM,
)

train(model, criterion, metric, optimizer, EPOCHS_1)

In [0]:
# Stage 2: unfreeze every parameter (including BERT) and fine-tune end to end
# with the smaller learning rate.
model.requires_grad_(True)

optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE_2,
    weight_decay=W_L2_NORM,
)

train(model, criterion, metric, optimizer, EPOCHS_2)

In [0]:
# Predict a score for every line of the test file and write one integer
# score per line to PATH_RESULTS.
with open(PATH_DATASET_TEST, 'r', encoding='utf-8') as file:
    dataset = file.readlines()

res = []
model.eval()
with torch.no_grad():
    # Walk the test set in BATCH_SIZE slices by index, rather than repeatedly
    # re-slicing (and copying) the remaining list.
    for start in range(0, len(dataset), BATCH_SIZE):
        texts, masks = text_collate_fn(dataset[start:start + BATCH_SIZE])
        texts = texts.to(device)
        # The attention mask must be the 0/1 mask, not the token ids.
        masks = masks.to(device)

        scores = model(texts, masks).cpu().numpy()
        # Round to the nearest integer (numpy: ties to even) and undo the
        # "score - 1" shift applied to the training targets.
        scores = np.around(scores).astype(np.int32) + 1
        res += scores.tolist()

res = '\n'.join(str(x) for x in res)
with open(PATH_RESULTS, 'w', encoding='utf-8') as file:
    file.write(res)