In [1]:
import ast
import numpy as np
import pandas as pd
import gensim
import scipy
import torch
import torch.nn.functional as F

from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM, AdamW
from torch.utils.data import Dataset, DataLoader

# Seed the random number generators for reproducibility.
# `np.random.seed` is a function and must be *called*; assigning to it would
# replace the function with an integer and seed nothing.
np.random.seed(1)
# The DataLoader shuffle and the linear-layer initialisation below draw from
# torch's RNG, so seed it as well.
torch.manual_seed(1)

paramiko missing, opening SSH/SCP/SFTP paths will be disabled.  `pip install paramiko` to suppress


In [2]:
# Load the WordPiece tokenizer matching the multilingual cased BERT checkpoint.
# Lower-casing is disabled because the checkpoint is a cased one.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-multilingual-cased',
    do_lower_case=False,
)

In [3]:
class DefinitionsDataset(Dataset):
    """Pairs of (tokenised definition, target embedding) read from a CSV file.

    The CSV must provide a ``definition`` column (text) and an ``embedding``
    column holding a Python-literal list of numbers, e.g. ``"[0.1, 0.2]"``.

    After construction ``self.data`` carries three extra columns:

    * ``input``     -- token ids of ``'[CLS] [MASK] - <definition> [SEP]'``,
      right-padded with 0 (or cut) to exactly ``max_len`` entries;
    * ``attention`` -- float mask, 1.0 where the token id is > 0, else 0.0;
    * ``label``     -- the parsed embedding as a tensor.

    Original author's note: the longest training sequence was 103 tokens,
    hence the default ``max_len`` of 128.

    Args:
        path_to_csv: path of the CSV file to load.
        tokenizer: object exposing ``tokenize(str)`` and
            ``convert_tokens_to_ids(list)``.
        max_len: fixed length of every ``input`` / ``attention`` tensor.
    """

    def __init__(self, path_to_csv, tokenizer, max_len=128):
        self.data = pd.read_csv(path_to_csv)

        def encode(definition):
            """Tokenise one definition and return exactly ``max_len`` ids."""
            tokens = tokenizer.tokenize('[CLS] [MASK] - ' + definition + ' [SEP]')
            ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
            # Over-long sequences are cut from the right, which also drops the
            # trailing [SEP]; this matches the previous behaviour, only made
            # explicit instead of relying on a negative pad amount.
            ids = ids[:max_len]
            return F.pad(input=ids, pad=(0, max_len - ids.shape[0]), mode='constant', value=0)

        self.data['input'] = self.data['definition'].apply(encode)
        self.data['embedding'] = self.data['embedding'].apply(ast.literal_eval)
        self.data['label'] = self.data['embedding'].apply(lambda vec: torch.tensor(vec))
        # Vectorised mask: any id > 0 is a real token, id 0 is the padding
        # value used above.
        self.data['attention'] = self.data['input'].apply(lambda ids: (ids > 0).float())

    def __len__(self):
        """Number of rows in the underlying CSV."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return ``{'input', 'attention_mask', 'label'}`` for row ``index``.

        ``index`` may be an int or a tensor (converted with ``tolist()``).
        """
        if torch.is_tensor(index):
            index = index.tolist()
        # Look the row up once instead of once per field.
        row = self.data.iloc[index]
        return {
            'input': row['input'],
            'attention_mask': row['attention'],
            'label': row['label'],
        }

In [4]:
# Build the training and validation datasets; both are tokenised with the
# same tokenizer and use the dataset's default max_len.
train = DefinitionsDataset(path_to_csv='rus/bert/bert_train.csv', tokenizer=tokenizer)
valid = DefinitionsDataset(path_to_csv='rus/bert/bert_valid.csv', tokenizer=tokenizer)

In [5]:
class BertToW2v(torch.nn.Module):
    """BERT encoder followed by a linear projection to a target embedding size.

    Args:
        bert_model_name: name of the pretrained BERT checkpoint to load.
        lin_shape_in: input width of the linear layer (BERT hidden size).
        lin_shape_out: output width of the linear layer (target embedding size).
        emb_layer: index applied along dimension 1 of BERT's first output.

    The BERT weights are not frozen here, and the module is not forced into
    eval mode on construction.

    NOTE(review): with pytorch_transformers, BertModel's first output is
    presumably the last layer's sequence output of shape
    (batch, seq_len, hidden). If so, ``[:, emb_layer]`` selects the token at
    position ``emb_layer`` rather than encoder layer ``emb_layer`` -- confirm
    which one is intended.
    """

    def __init__(self, bert_model_name, lin_shape_in, lin_shape_out, emb_layer):
        super().__init__()
        self.emb_layer = emb_layer
        self.bert_model = BertModel.from_pretrained(bert_model_name)
        # Projection head; weights start uniformly in [-0.1, 0.1].
        # (Open question from the original author: is the bias needed?)
        projection = torch.nn.Linear(in_features=lin_shape_in, out_features=lin_shape_out, bias=True)
        torch.nn.init.uniform_(projection.weight, a=-0.1, b=0.1)
        self.linear_model = projection

    def forward(self, input_sentence, mask):
        """Project the selected BERT output slice.

        ``input_sentence`` is expected to be already tokenised (token ids);
        ``mask`` is passed to BERT as its attention mask.
        """
        encoder_output, _pooled = self.bert_model(input_sentence, attention_mask=mask)
        selected = encoder_output[:, self.emb_layer]
        return self.linear_model(selected)

In [6]:
# Instantiate the model: 768-wide BERT output projected to 500 dimensions,
# slicing BERT's output at index 6, then move it to the GPU.
bw2v = BertToW2v(
    bert_model_name='bert-base-multilingual-cased',
    lin_shape_in=768,
    lin_shape_out=500,
    emb_layer=6,
)
bw2v = bw2v.to('cuda')

In [7]:
batch_size = 32

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dl = DataLoader(dataset=train, batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(dataset=valid, batch_size=batch_size, shuffle=False)

In [8]:
# Optimise every parameter of the model (BERT and the linear head) with the
# library's default AdamW settings.
# NOTE(review): no learning rate is given, so the AdamW default applies --
# presumably far higher than the ~2e-5..5e-5 usually used to fine-tune BERT;
# confirm this is intended.
optimizer = AdamW(params=bw2v.parameters())
# Mean squared error between predicted and target embeddings.
loss_function = torch.nn.MSELoss(reduction='mean')

In [9]:
%%time

max_epochs = 20


def _batch_loss(model, batch, loss_fn, device='cuda'):
    """Move one batch to ``device``, run the model, return (loss, batch size).

    ``batch`` is a dict with the keys 'input', 'attention_mask' and 'label'
    as produced by the DataLoaders above.
    """
    inputs = batch['input'].to(device)
    masks = batch['attention_mask'].to(device)
    labels = batch['label'].to(device)
    return loss_fn(model(inputs, masks), labels), inputs.size(0)


for epoch in range(max_epochs):
    # ---- training pass ----
    bw2v.train()
    train_loss = 0.0
    for batch in train_dl:
        optimizer.zero_grad()
        loss, n_samples = _batch_loss(bw2v, batch, loss_function)
        loss.backward()
        optimizer.step()
        # The loss is a per-batch mean; weight it by the batch size so the
        # epoch figure is a true per-sample average even when the last batch
        # is smaller.
        train_loss += loss.item() * n_samples

    print('TRAINING_LOSS: ', end='')
    print(train_loss / len(train_dl.dataset), end='')

    # ---- validation pass ----
    # No gradients are needed here, so the whole loop runs under no_grad and
    # the optimizer is left untouched.
    bw2v.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for batch in valid_dl:
            loss, n_samples = _batch_loss(bw2v, batch, loss_function)
            valid_loss += loss.item() * n_samples

    print('; VALIDATION_LOSS: ', end='')
    print(valid_loss / len(valid_dl.dataset))

TRAINING_LOSS: 0.04628171315083827; VALIDATION_LOSS: 0.031234915636878764
TRAINING_LOSS: 0.03202891845150297; VALIDATION_LOSS: 0.03099616715726354
TRAINING_LOSS: 0.03131647099156627; VALIDATION_LOSS: 0.03079848895122865
TRAINING_LOSS: 0.030943178694348088; VALIDATION_LOSS: 0.030661838405825106
TRAINING_LOSS: 0.0305505803752127; VALIDATION_LOSS: 0.030604709666833565
TRAINING_LOSS: 0.030430804777980148; VALIDATION_LOSS: 0.030553429858800867
TRAINING_LOSS: 0.030306716916275945; VALIDATION_LOSS: 0.030436578645838975
TRAINING_LOSS: 0.030301890306310718; VALIDATION_LOSS: 0.030421839187336708
TRAINING_LOSS: 0.030186987705766066; VALIDATION_LOSS: 0.030391465930432234
TRAINING_LOSS: 0.030083640303268454; VALIDATION_LOSS: 0.030345965971420672
TRAINING_LOSS: 0.030052811139264318; VALIDATION_LOSS: 0.03032702065695444
TRAINING_LOSS: 0.029998990668985707; VALIDATION_LOSS: 0.030289445490622974
TRAINING_LOSS: 0.029972133539949326; VALIDATION_LOSS: 0.0302778928056488
TRAINING_LOSS: 0.02994098867539221;

In [10]:
#torch.save(bw2v.state_dict(), f'models/batchify_ep20_l6.mdl')