In [None]:
import os

import numpy as np
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F

from transformers import BertTokenizer, BertModel

from tqdm import tqdm
import matplotlib.pyplot as plt

import logging
# Raise the root logging threshold so that only CRITICAL records are emitted.
logging.basicConfig(level=logging.CRITICAL)

# --- Dataset files (paths are relative to the working directory) ---
PATH_DATASET_TEXTS = os.path.join('data', 'texts_train.txt')
PATH_DATASET_SCORES = os.path.join('data', 'scores_train.txt')

# --- Tokenisation ---
# Upper bound on the encoded length of a text; collate_fn reserves two of
# these positions for the [CLS] and [SEP] ids.
MAX_TEXT_LEN = 200

# --- Optimisation hyper-parameters ---
EPOCHS = 10
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
W_L2_NORM = 0.0  # passed to Adam as weight_decay

# --- Compute device: first CUDA device when available, otherwise the CPU ---
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Load the dataset: one text per line in PATH_DATASET_TEXTS and one integer
# score per line in PATH_DATASET_SCORES, then split it into train/validation.

# Fixed seed so the train/validation split is identical after a kernel restart.
SPLIT_SEED = 42

# Maps each punctuation mark to the same mark surrounded by spaces, so all of
# them are separated from neighbouring words in a single pass over the text.
_PUNCT_TABLE = str.maketrans({c: f' {c} ' for c in '.,!?"()'})

# NOTE(review): the files are opened with the platform default encoding;
# consider passing encoding=... explicitly once the files' encoding is confirmed.
with open(PATH_DATASET_TEXTS, 'r') as file:
    # split()/join collapses every whitespace run to one space and also drops
    # the trailing newline. Slicing off the last character instead would eat a
    # real character whenever the final line has no newline terminator.
    texts = [' '.join(line.translate(_PUNCT_TABLE).split()) for line in file]

with open(PATH_DATASET_SCORES, 'r') as file:
    # Shift the raw scores down by one so they can be used as zero-based class
    # indices (presumably the raw scores start at 1 -- verify against the data).
    scores = [int(line) - 1 for line in file]

# Fail early with a clear message if the two files are not aligned line by line.
if len(texts) != len(scores):
    raise ValueError(
        f'Dataset mismatch: {len(texts)} texts but {len(scores)} scores'
    )

# Hold out 10% of the samples for validation.
x_train, x_val, y_train, y_val = train_test_split(
    texts, scores, test_size=0.1, random_state=SPLIT_SEED)
train_data = list(zip(x_train, y_train))
val_data = list(zip(x_val, y_val))

In [None]:
# Tokenizer for the same 'bert-base-multilingual-cased' checkpoint that the
# classifier loads; do_lower_case=False keeps the original casing of the input.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=False)

def collate_fn(data):
    """Turn a list of ``(text, score)`` pairs into padded batch tensors.

    Each text is tokenized, truncated so that it fits into ``MAX_TEXT_LEN``
    positions together with the leading [CLS] and trailing [SEP] ids, and then
    right-padded with the tokenizer's pad id up to the longest sequence in the
    batch.

    Args:
        data: sequence of ``(text, score)`` pairs.

    Returns:
        Tuple ``(texts, masks, scores)`` of ``torch.LongTensor``:
        ``texts`` holds the padded token ids, ``masks`` holds 1 for real
        tokens and 0 for padding, ``scores`` holds the batch's scores.
    """
    raw_texts, raw_scores = zip(*data)

    # Encode every sample as [CLS] + token ids + [SEP].
    encoded = []
    for raw in raw_texts:
        pieces = tokenizer.tokenize(raw)[:MAX_TEXT_LEN - 2]
        body_ids = tokenizer.convert_tokens_to_ids(pieces)
        encoded.append([tokenizer.cls_token_id] + body_ids + [tokenizer.sep_token_id])

    # Pad to the longest sequence of this batch and build the matching masks.
    longest = max(len(ids) for ids in encoded)
    padded_ids = []
    attention = []
    for ids in encoded:
        n_pad = longest - len(ids)
        attention.append([1] * len(ids) + [0] * n_pad)
        padded_ids.append(ids + [tokenizer.pad_token_id] * n_pad)

    return (
        torch.LongTensor(padded_ids),
        torch.LongTensor(attention),
        torch.LongTensor(raw_scores),
    )

# Settings shared by both loaders: same batch size, same batching function.
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)

# Training batches are reshuffled on every pass; validation keeps its order.
train_data_loader = data.DataLoader(dataset=train_data, shuffle=True, **_loader_kwargs)
val_data_loader = data.DataLoader(dataset=val_data, shuffle=False, **_loader_kwargs)

In [None]:
class TextClassifier(nn.Module):
    """Frozen multilingual BERT encoder with a trainable linear head.

    The encoder weights are loaded from 'bert-base-multilingual-cased' and
    excluded from gradient updates; only the linear layer in ``self.layers``
    is trainable.

    Args:
        n_classes: number of output classes (size of the logits vector).
    """

    def __init__(self, n_classes):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-multilingual-cased')

        # Freeze the encoder so that only the classification head is trained.
        # NOTE(review): the encoder still follows model.train()/model.eval(),
        # so any dropout inside it is presumably active while training the
        # head -- confirm this is intended for a frozen feature extractor.
        for param in self.bert.parameters():
            param.requires_grad = False

        # Take the feature width from the loaded checkpoint rather than
        # hard-coding 768, so the head stays correct if the checkpoint changes.
        hidden_size = self.bert.config.hidden_size
        layers = [
            nn.Linear(hidden_size, n_classes)
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, text, mask):
        """Compute class logits for a batch.

        Args:
            text: tensor of token ids, passed to the encoder as its input.
            mask: tensor passed to the encoder as ``attention_mask``.

        Returns:
            Output of the linear head applied to element 1 of the encoder
            output (the pooled output in the Hugging Face BERT API).
        """
        pooled = self.bert(text, attention_mask=mask)[1]
        return self.layers(pooled)

In [None]:
def train(model, criterion, optimizer, epochs):
    """Train ``model`` and plot the per-epoch training and validation loss.

    Every epoch runs one optimisation pass over ``train_data_loader`` followed
    by one gradient-free pass over ``val_data_loader``. The loss recorded for
    an epoch is the mean of its per-batch losses. Both curves are drawn with
    matplotlib once all epochs are done.

    Args:
        model: module called as ``model(texts, masks)``.
        criterion: loss called as ``criterion(predictions, scores)``.
        optimizer: optimizer that updates the model's parameters.
        epochs: number of epochs to run.
    """
    def to_device(batch):
        # Move every tensor of the batch to the configured device.
        return tuple(tensor.to(device) for tensor in batch)

    train_curve, val_curve = [], []

    for _ in tqdm(range(epochs)):
        # ---- optimisation pass ----
        model.train()
        batch_losses = []
        for batch in train_data_loader:
            token_ids, attn_masks, targets = to_device(batch)

            optimizer.zero_grad()
            loss = criterion(model(token_ids, attn_masks), targets)
            batch_losses.append(loss.item())

            loss.backward()
            optimizer.step()
        train_curve.append(np.mean(batch_losses))

        # ---- validation pass (no gradients) ----
        model.eval()
        batch_losses = []
        with torch.no_grad():
            for batch in val_data_loader:
                token_ids, attn_masks, targets = to_device(batch)
                loss = criterion(model(token_ids, attn_masks), targets)
                batch_losses.append(loss.item())
        val_curve.append(np.mean(batch_losses))

    # ---- learning curves ----
    epoch_axis = range(epochs)
    for curve, label in ((train_curve, "train"), (val_curve, "val")):
        plt.plot(epoch_axis, curve, label=label)
    plt.xlabel('epoch num')
    plt.ylabel('loss')
    plt.legend()
    plt.show()

In [None]:
# Build a 10-class classifier and place it on the selected device.
# NOTE(review): 10 presumably matches the number of distinct scores in the
# dataset -- verify against the data.
model = TextClassifier(n_classes=10).to(device)

criterion = nn.CrossEntropyLoss()

# Hand the optimizer only the parameters that still require gradients; the
# frozen encoder weights are left out.
optimizer = optim.Adam(
    [param for param in model.parameters() if param.requires_grad],
    lr=LEARNING_RATE,
    weight_decay=W_L2_NORM,
)

train(model, criterion, optimizer, EPOCHS)