In [1]:
import ast
import numpy as np
import pandas as pd
import gensim
import scipy
import torch
import torch.nn.functional as F

from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM, AdamW
from torch.utils.data import Dataset, DataLoader

# Seed the random number generators so runs are repeatable.
# `np.random.seed` is a function and has to be *called*; assigning to it
# (`np.random.seed = 1`) replaces the function with an int and seeds nothing.
np.random.seed(1)
# The linear-layer init and the shuffling DataLoader used below rely on
# torch's generator rather than numpy's, so seed that one as well.
torch.manual_seed(1)

paramiko missing, opening SSH/SCP/SFTP paths will be disabled.  `pip install paramiko` to suppress


In [2]:
# Tokenizer of the cased multilingual BERT checkpoint; lower-casing is
# switched off explicitly.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-multilingual-cased',
    do_lower_case=False,
)

In [3]:
class DefinitionsDataset(Dataset): # train_maxlen = 103; set_to_128
    """Dataset of (definition token ids, attention mask, target embedding) rows.

    The CSV at ``path_to_csv`` must provide two columns:

    * ``definition`` -- the definition text; it is wrapped as
      ``'[CLS] [MASK] - <definition> [SEP]'`` before tokenization.
    * ``embedding`` -- a Python literal (parsed with ``ast.literal_eval``)
      holding the target vector.

    Args:
        path_to_csv: path of the CSV file to load.
        tokenizer: object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        max_len: length every id sequence is padded to (pad value 0).

    After construction ``self.data`` holds the extra columns ``input``
    (padded id tensor), ``label`` (target tensor) and ``attention``
    (float mask, 1.0 where the id is greater than 0).
    """

    def __init__(self, path_to_csv, tokenizer, max_len=128):
        self.data = pd.read_csv(path_to_csv)

        def encode(definition):
            """Tokenize one definition and right-pad its ids with zeros."""
            tokens = tokenizer.tokenize('[CLS] [MASK] - ' + definition + ' [SEP]')
            ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
            # NOTE(review): for sequences longer than max_len the pad amount
            # becomes negative -- confirm that the resulting behaviour (and the
            # possible loss of the trailing [SEP]) is acceptable.
            return F.pad(input=ids, pad=(0, max_len - ids.shape[0]), mode='constant', value=0)

        self.data['input'] = self.data['definition'].apply(encode)
        self.data['embedding'] = self.data['embedding'].apply(ast.literal_eval)
        self.data['label'] = self.data['embedding'].apply(lambda vec: torch.tensor(vec))
        # Vectorized mask: one tensor comparison per row instead of a Python
        # loop over every element of every tensor.
        self.data['attention'] = self.data['input'].apply(lambda ids: (ids > 0).float())

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Keys: ``'input'``, ``'attention_mask'`` and ``'label'``. A tensor
        index is converted with ``tolist()`` first.
        """
        if torch.is_tensor(index):
            index = index.tolist()
        # Look the row up once instead of once per returned field.
        row = self.data.iloc[index]
        return {'input': row['input'], 'attention_mask': row['attention'], 'label': row['label']}

In [4]:
# Build the training and validation datasets from their CSV files; both use
# the same tokenizer and the default maximum length.
train = DefinitionsDataset(path_to_csv='rus/bert/bert_train.csv', tokenizer=tokenizer)
valid = DefinitionsDataset(path_to_csv='rus/bert/bert_valid.csv', tokenizer=tokenizer)

In [5]:
class BertToW2v(torch.nn.Module):
    """BERT encoder followed by a linear projection into an embedding space.

    The hidden state of the token at position 1, taken from hidden-state
    entry ``emb_layer``, is passed through a single linear layer.

    Args:
        bert_model_name: name handed to ``BertModel.from_pretrained``.
        lin_shape_in: input width of the linear layer.
        lin_shape_out: output width of the linear layer.
        emb_layer: index into the tuple of hidden states returned by BERT.
    """

    def __init__(self, bert_model_name, lin_shape_in, lin_shape_out, emb_layer): # -, 768, 100, 6
        super().__init__()
        self.emb_layer = emb_layer
        # Hidden states of every layer are requested so that `forward` can
        # pick one of them by index.
        self.bert_model = BertModel.from_pretrained(
            bert_model_name,
            output_hidden_states=True,
        )
        #self.bert_model.eval()
        self.linear_model = torch.nn.Linear(
            lin_shape_in,
            lin_shape_out,
            bias=True,  # bias?
        )
        # Projection weights start uniformly in [-0.1, 0.1].
        torch.nn.init.uniform_(self.linear_model.weight, a=-0.1, b=0.1)

    def forward(self, input_sentence, mask):
        """Project the selected hidden state of an already tokenized batch.

        Args:
            input_sentence: batch of token ids (the sentence is expected to be
                tokenized already).
            mask: attention mask forwarded to the BERT model.

        Returns:
            Output of the linear layer for the position-1 hidden state.
        """
        _unused_a, _unused_b, hidden_states = self.bert_model(input_sentence, attention_mask=mask)
        chosen_layer = hidden_states[self.emb_layer]
        token_state = chosen_layer[:, 1]
        return self.linear_model(token_state)

In [6]:
# Multilingual BERT whose hidden-state entry 6 is projected from 768 to 500
# dimensions. # !
bw2v = BertToW2v(
    'bert-base-multilingual-cased',
    lin_shape_in=768,
    lin_shape_out=500,
    emb_layer=6,
)
# `.to` returns the module itself; re-binding it also keeps the cell from
# echoing the module repr.
bw2v = bw2v.to('cuda')

In [7]:
# Mini-batch size shared by both loaders.
batch_size = 32

# Only the training loader shuffles; validation keeps the dataset order.
train_dl = DataLoader(dataset=train, batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(dataset=valid, batch_size=batch_size)

In [8]:
# All parameters of the model (BERT and the linear head) are optimized.
# NOTE(review): no learning rate is passed, so the optimizer's default is
# used -- confirm that it is suitable for fine-tuning.
optimizer = AdamW(params=bw2v.parameters())
# Mean squared error between predicted and target embeddings.
loss_function = torch.nn.MSELoss()

In [9]:
# tmp = next(iter(train_dl))

# bw2v(tmp['input'].to('cuda'), tmp['attention_mask'].to('cuda'))

# tmp = bw2v.bert_model(tmp['input'].to('cuda'), attention_mask=tmp['attention_mask'].to('cuda'))

# tmp[2][0]

In [10]:
%%time

max_epochs = 20


def _batch_loss(batch):
    """Run the model on one batch dict and return ``(loss, batch_size)``.

    The ``'input'``, ``'attention_mask'`` and ``'label'`` tensors are moved to
    the GPU, the model is applied and the loss against the labels is computed.
    """
    inputs = batch['input'].to('cuda')
    masks = batch['attention_mask'].to('cuda')
    labels = batch['label'].to('cuda')
    output = bw2v(inputs, masks)
    return loss_function(output, labels), inputs.size(0)


for epoch in range(max_epochs):
    # ---- training pass ----
    bw2v.train()
    train_loss = 0.0
    for dct in train_dl:
        optimizer.zero_grad()

        loss, n_samples = _batch_loss(dct)

        # The loss is a per-batch mean; weight it by the batch size so the
        # epoch figure is a per-sample average even with a short last batch.
        train_loss += loss.item() * n_samples

        loss.backward()
        optimizer.step()

    print('TRAINING_LOSS: ', end='')
    print(train_loss / len(train_dl.dataset), end='')

    # ---- validation pass ----
    # No gradients are needed here, so the whole pass runs under no_grad and
    # the optimizer is not touched.
    bw2v.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for dct in valid_dl:
            loss, n_samples = _batch_loss(dct)
            valid_loss += loss.item() * n_samples

    print('; VALIDATION_LOSS: ', end='')
    print(valid_loss / len(valid_dl.dataset))

TRAINING_LOSS: 0.07069383466397244; VALIDATION_LOSS: 0.03193657151029312
TRAINING_LOSS: 0.03252904221895009; VALIDATION_LOSS: 0.030987392811809178
TRAINING_LOSS: 0.030968375554902224; VALIDATION_LOSS: 0.030049145453334473
TRAINING_LOSS: 0.03013421285098211; VALIDATION_LOSS: 0.029863462003272844
TRAINING_LOSS: 0.029482792943053133; VALIDATION_LOSS: 0.02943343158959313
TRAINING_LOSS: 0.028900750073117657; VALIDATION_LOSS: 0.029113666397327997
TRAINING_LOSS: 0.028530679349105206; VALIDATION_LOSS: 0.028840845193398638
TRAINING_LOSS: 0.028081303151442454; VALIDATION_LOSS: 0.0288057127237217
TRAINING_LOSS: 0.027679554053111913; VALIDATION_LOSS: 0.028723401463042062
TRAINING_LOSS: 0.02731770935737183; VALIDATION_LOSS: 0.028423616533941348
TRAINING_LOSS: 0.026873421274111898; VALIDATION_LOSS: 0.028780802199697454
TRAINING_LOSS: 0.026547554817226482; VALIDATION_LOSS: 0.028395618969938093
TRAINING_LOSS: 0.02622205555948251; VALIDATION_LOSS: 0.028608240293329243


KeyboardInterrupt: 

In [11]:
# Persist only the model's state dict (not the whole module object).
model_weights = bw2v.state_dict()
torch.save(model_weights, 'models/batchify_ep13_l6.mdl')