# TODO
- Проверить на задаче классификации отзывов, что получаемые эмбеддинги текстов адекватны.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

# Masked language modelling (MLM)

In [3]:
%%writefile masking_language_modelling/models.py

import torch

from multihead_attention import MultiHeadAttention

class BaseEncoderModel(torch.nn.Module):
    """Token embedding followed by two stacked MultiHeadAttention layers.

    Args:
        vocab_size: number of rows in the token embedding table.
        n_heads: number of heads given to each MultiHeadAttention layer.
        emb_size: embedding width, also given to each MultiHeadAttention layer.
        vdim, kdim: accepted for API compatibility but currently unused.
        padding_idx: forwarded to ``torch.nn.Embedding``; that row is
            initialised to zeros and receives no gradient.

    NOTE(review): no positional encoding is added here. Unless
    MultiHeadAttention adds one internally, the encoder cannot tell token
    order apart. Confirm.
    NOTE(review): no attention/padding mask is passed to the attention layers.
    """

    def __init__(self, vocab_size, n_heads, emb_size, vdim=None, kdim=None, padding_idx=None):
        super().__init__()
        token_embedding = torch.nn.Embedding(vocab_size, emb_size, padding_idx=padding_idx)
        attention_stack = [
            MultiHeadAttention(n_heads=n_heads, emb_size=emb_size)
            for _ in range(2)
        ]
        # MLMHead reads layers[0] (the embedding) and layers[-1] (the last
        # attention layer), so the layout must stay [Embedding, MHA, MHA].
        self.layers = torch.nn.Sequential(token_embedding, *attention_stack)

    def forward(self, X: torch.Tensor):
        """Embed token ids ``X`` and run them through the attention stack."""
        hidden = self.layers(X)
        return hidden


class MLMHead(torch.nn.Module):
    """Masked-language-modelling head: an encoder plus a bias-free vocabulary projection.

    The wrapped ``model`` must expose ``layers``:
    - ``layers[0].num_embeddings`` gives the vocabulary size.
    - ``layers[-1].ffn.inplace.out_features`` gives the encoder output width.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        encoder_layers = self.model.layers
        hidden_size = encoder_layers[-1].ffn.inplace.out_features
        n_tokens = encoder_layers[0].num_embeddings
        self.mlm_layer = torch.nn.Linear(hidden_size, n_tokens, bias=False)

    def forward(self, X: torch.Tensor):
        """Return unnormalised vocabulary logits for every position of ``X``.

        Logits are produced for all positions, not only the masked ones, so the
        batch keeps a rectangular shape. Softmax is left to the loss function.
        """
        return self.mlm_layer(self.model(X))

Overwriting masking_language_modelling/models.py


In [4]:
# Sanity check: an Embedding maps integer ids of shape (batch_size, seq_len)
# to vectors of shape (batch_size, seq_len, emb_size) -- see the output below.
batch_size, seq_len, emb_size = 2, 3, 7
vocab_size = 3
torch.nn.Embedding(vocab_size, emb_size)(torch.randint(vocab_size, size=(batch_size, seq_len))).shape

torch.Size([2, 3, 7])

In [5]:
from masking_language_modelling.models import BaseEncoderModel, MLMHead
from masking_language_modelling.dataproc import MLMDataset
# Leftover from an earlier experiment with a float input tensor:
# batch_size, seq_len, emb_size = 11, 30, 36
# X = torch.Tensor(batch_size, seq_len, emb_size).random_()



# Smoke test: run the MLM head on random token ids and check the output shape.
batch_size, seq_len, emb_size = 11, 47, 36
vocab_size = 1000
base_model = BaseEncoderModel(vocab_size=vocab_size, n_heads=12, emb_size=emb_size)
model = MLMHead(model=base_model)

X = torch.randint(vocab_size, size=(batch_size, seq_len))
# NOTE: despite the name, `probs` holds raw logits (MLMHead applies no softmax),
# shape (batch_size, seq_len, vocab_size).
probs = model(X)
probs.shape

torch.Size([11, 47, 1000])

# Prepare dataset

In [6]:
%%writefile masking_language_modelling/dataproc.py

import torch
from torch.utils.data import Dataset
from collections import Counter
from itertools import chain
from typing import List, Optional

class MLMDataset(Dataset):
    """Line-per-example text dataset.

    Each item is one raw line of ``text_fpath``, including its trailing
    newline if present. Tokenization and masking happen elsewhere, in
    ``Tokenizer`` and ``spawn_collate_fn``.

    Args:
        text_fpath: path to a text file with one example per line.
        max_seq_len: stored on the instance; not used by this class.
        mask_ratio: stored on the instance; not used by this class.
        encoding: text encoding used to read the file. Defaults to UTF-8
            instead of the platform locale, since the corpus contains
            non-ASCII text.
    """

    def __init__(self, text_fpath: str,
                 max_seq_len: int,
                 mask_ratio: float = 0.15,
                 encoding: str = 'utf-8',
                ):
        self.max_seq_len = max_seq_len
        self.mask_ratio = mask_ratio
        with open(text_fpath, 'r', encoding=encoding) as f:
            self.lines = f.readlines()

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        return self.lines[idx]
    
class Tokenizer:
    """Whitespace tokenizer with a frequency-capped vocabulary.

    ``fit`` builds ``self.vocab`` (token -> id) from a corpus. ``apply`` turns
    a string into ``[<CLS> id] + word ids (+ <PAD> ids)``.

    Args:
        max_vocab_size: total vocabulary size, special tokens included.
        truncation: if True, keep at most ``max_seq_len`` word ids.
        max_seq_len: target number of word ids. The leading <CLS> is extra,
            so with both flags on ``apply`` returns ``max_seq_len + 1`` ids.
            Required when ``truncation`` or ``padding`` is enabled.
        padding: if True, right-pad the word ids with <PAD> up to ``max_seq_len``.

    Raises:
        ValueError: if ``truncation`` or ``padding`` is enabled without ``max_seq_len``.
    """

    def __init__(self,
                 max_vocab_size: int,
                 truncation: bool = True,
                 max_seq_len: Optional[int] = None,
                 padding: bool = True,
                ):
        self.pad_token = '<PAD>'
        self.mask_token = '<MASK>'
        self.cls_token = '<CLS>'
        self.sep_token = '<SEP>'
        # This order fixes the ids: <CLS>=0, <SEP>=1, <PAD>=2, <MASK>=3.
        self.special_tokens = [self.cls_token, self.sep_token, self.pad_token, self.mask_token]
        self.max_vocab_size = max_vocab_size
        self.truncation = truncation
        self.padding = padding

        if (self.padding or self.truncation) and max_seq_len is None:
            raise ValueError('max_seq_len is required when padding or truncation is enabled')
        self.max_seq_len = max_seq_len

    def fit(self, dataset):
        """Build ``self.vocab`` from the most frequent words of ``dataset``.

        Args:
            dataset: iterable of strings.

        Returns:
            self, so construction and fitting can be chained.
        """
        counts = Counter(chain.from_iterable(map(self._preproc, dataset)))
        # A corpus word spelled like a special token must not overwrite its id.
        for token in self.special_tokens:
            counts.pop(token, None)
        n_words = max(self.max_vocab_size - len(self.special_tokens), 0)
        words = [word for word, _ in counts.most_common(n_words)]
        self.vocab = {token: idx for idx, token in enumerate(self.special_tokens + words)}
        return self

    def apply(self, text: str):
        """Convert ``text`` into a list of token ids starting with the <CLS> id.

        Out-of-vocabulary words are mapped to the <PAD> id; there is no <UNK>
        token.
        NOTE(review): if the embedding uses the PAD id as ``padding_idx``,
        such words become indistinguishable from padding.
        """
        pad_token_idx = self.vocab[self.pad_token]
        words = self._preproc(text)
        if self.truncation:
            words = words[:self.max_seq_len]
        payload_tokens = [self.vocab.get(word, pad_token_idx) for word in words]
        padding_tokens = []
        if self.padding:
            # Negative counts (long input with truncation off) yield an empty list.
            padding_tokens = [pad_token_idx] * (self.max_seq_len - len(payload_tokens))
        return [self.vocab[self.cls_token]] + payload_tokens + padding_tokens

    def _preproc(self, text: str) -> List[str]:
        """Split ``text`` on whitespace."""
        return text.split()
    
    
def spawn_collate_fn(tokenizer, mask_ratio=0.15):
    """Build a DataLoader ``collate_fn`` that applies random MLM masking.

    Args:
        tokenizer: fitted Tokenizer whose ``vocab`` contains the CLS, SEP, PAD
            and MASK tokens.
        mask_ratio: probability of masking each eligible (non-special) position.

    Returns:
        ``collate_fn(batch)``, which maps a list of equal-length token-id lists
        to a dict with these keys:
        - ``'input_tokens'``: original ids, long tensor.
        - ``'masked_input_tokens'``: the same ids with selected positions
          replaced by the MASK id.
        - ``'mlm_mask'``: bool tensor, True where a position was masked.
    """
    cls_token_id = tokenizer.vocab[tokenizer.cls_token]
    sep_token_id = tokenizer.vocab[tokenizer.sep_token]
    pad_token_id = tokenizer.vocab[tokenizer.pad_token]
    mask_token_id = tokenizer.vocab[tokenizer.mask_token]

    def mask_objective(batch_token_ids, mask_ratio):
        """Return a bool mask selecting ~mask_ratio of the non-special positions."""
        sampled = torch.rand(batch_token_ids.shape) < mask_ratio
        # Special tokens are never masked. Padding in particular must be
        # excluded, otherwise the model is trained to predict <PAD> at padded
        # positions.
        eligible = ((batch_token_ids != cls_token_id)
                    & (batch_token_ids != sep_token_id)
                    & (batch_token_ids != pad_token_id))
        return sampled & eligible

    def custom_collate_fn(batch):
        # Build ids directly as int64; going through float32 would lose
        # precision for ids above 2**24.
        input_ids = torch.tensor(batch, dtype=torch.long)
        mlm_mask = mask_objective(input_ids, mask_ratio)
        masked_input_ids = input_ids.masked_fill(mlm_mask, mask_token_id)
        # NOTE(review): no attention mask for padding is produced here.
        return {
            'input_tokens': input_ids,
            'masked_input_tokens': masked_input_ids,
            'mlm_mask': mlm_mask,
        }
    return custom_collate_fn

Overwriting masking_language_modelling/dataproc.py


# Train loop

In [7]:
%%writefile test.txt
kek
mda ok na
aagaaa
kek
mda ok na
aagaaa
kek
mda ok na
aagaaa
kek
mda ok na
aagaaa

Overwriting test.txt


data

In [11]:
from masking_language_modelling.dataproc import MLMDataset, spawn_collate_fn, Tokenizer
from torch.utils.data import DataLoader

# Data pipeline: read lines, fit the vocabulary, pre-tokenize every line once,
# then let the collate_fn apply fresh random masking per batch.
seq_len = 30
batch_size = 50
max_vocab_size = 30000

# NOTE(review): reads 'rt.txt', not the toy 'test.txt' written above. Confirm intended.
training_data = MLMDataset('rt.txt', seq_len)
# The vocabulary is fitted on the same lines used for training.
tokenizer = Tokenizer(max_vocab_size=max_vocab_size, max_seq_len=seq_len)\
                        .fit(training_data)

# print(tokenizer.vocab)
# print(tokenizer.apply(f'kek mda {tokenizer.mask_token}'))
# Each entry is a list of token ids: [<CLS>] + up to seq_len words + <PAD> padding.
proc_train_dataset = list(map(tokenizer.apply, training_data))

collate_fn = spawn_collate_fn(tokenizer)
train_dataloader = DataLoader(proc_train_dataset, batch_size=batch_size, shuffle=True, 
                              collate_fn=collate_fn)
# next(iter(train_dataloader))

model

In [12]:
# Presumably emb_size must be divisible by n_heads (54 = 6 * 9). TODO confirm
# against MultiHeadAttention.
emb_size = 54

# The encoder is sized to the fitted vocabulary. The PAD id is passed as
# padding_idx, so its embedding row starts at zero and gets no gradient.
base_model = BaseEncoderModel(n_heads=6, emb_size=emb_size, 
                              vocab_size=len(tokenizer.vocab),
                              padding_idx=tokenizer.vocab[tokenizer.pad_token])
model = MLMHead(model=base_model)

In [13]:
from tqdm.auto import tqdm
import numpy as np

def train(model, train_dataloader, num_epochs=1, lr=1e-3):
    """Train an MLM model with Adam, scoring cross-entropy on masked positions only.

    Args:
        model: module mapping masked token ids (batch, seq_len) to logits
            (batch, seq_len, vocab_size).
        train_dataloader: iterable of dicts with the keys
            ``'masked_input_tokens'``, ``'input_tokens'`` and ``'mlm_mask'``,
            as built by ``spawn_collate_fn``.
        num_epochs: number of passes over ``train_dataloader``.
        lr: Adam learning rate.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    for epoch in range(num_epochs):
        losses = []
        print(f'Epoch {epoch}')
        pbar = tqdm(train_dataloader)
        for batch in pbar:
            # x: ids with masked positions replaced by <MASK> [batch_size, seq_len]
            x = batch['masked_input_tokens']
            # y: original ids, i.e. the prediction targets [batch_size, seq_len]
            y = batch['input_tokens']
            # masked_tokens: bool, True where a position was masked [batch_size, seq_len]
            masked_tokens = batch['mlm_mask']

            optimizer.zero_grad()
            logits = model(x)
            token_losses = loss_fn(logits.reshape(-1, logits.shape[-1]), y.reshape(-1))
            mask = masked_tokens.reshape(-1).to(token_losses.dtype)
            # Average over masked positions only. Dividing by all positions
            # would shrink the loss by ~mask_ratio and make it depend on how
            # many tokens a batch happened to mask. clamp avoids 0/0 when a
            # batch has no masked tokens.
            avg_masked_loss = (token_losses * mask).sum() / mask.sum().clamp(min=1.0)
            avg_masked_loss.backward()
            optimizer.step()
            losses.append(avg_masked_loss.item())
            pbar.set_description(f'CEL: {np.mean(losses):.3f}')
        
# Long training run; the run recorded below was stopped manually (KeyboardInterrupt).
train(model, train_dataloader, num_epochs=100, lr=1e-3)

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

  0%|          | 0/151 [00:00<?, ?it/s]

KeyboardInterrupt: 

sanity check

In [75]:
import pandas as pd
from scipy import spatial


def get_embedding(query, tokenizer, base_model):
    """Return the encoder output at the first (<CLS>) position for ``query``.

    Args:
        query: raw text passed to ``tokenizer.apply``.
        tokenizer: object whose ``apply`` returns a list of token ids.
        base_model: module applied to a (1, seq_len) id tensor. Its output is
            presumably (1, seq_len, emb_size); the code indexes ``[0, 0]``.

    Returns:
        The vector at batch 0, position 0, as a NumPy array.
    """
    token_ids = torch.tensor(tokenizer.apply(query)).view(1, -1)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        hidden = base_model(token_ids)
    return hidden[0, 0].numpy()

# Sanity check of sentence embeddings: rank the first 1000 training lines by
# cosine similarity between their first-position (<CLS>) vectors and a
# hand-written query.
queries = list(training_data)[:1000]
embs = []  # unused
base_model.eval()

query = 'fairytale for yound ladies. miracle fantasy'
query_emb = get_embedding(query, tokenizer, base_model)

res_lst = []
for entry in queries:
    entry_emb = get_embedding(entry, tokenizer, base_model)
    # scipy's cosine() returns cosine *distance*; 1 - distance = similarity.
    score = 1 - spatial.distance.cosine(query_emb, entry_emb)
    res_lst.append((query, entry, score))

# NOTE(review): in the output below the top-ranked lines (score ~0.9997) are
# unrelated to the query, including non-English reviews. This suggests the
# embeddings are not yet discriminative.
scores_df = pd.DataFrame(res_lst, columns=['a', 'b', 'score'])
scores_df.sort_values('score', ascending=False)

Unnamed: 0,a,b,score
761,fairytale for yound ladies. miracle fantasy,"a resonant tale of racism , revenge and retrib...",0.999852
946,fairytale for yound ladies. miracle fantasy,it's a remarkably solid and subtly satirical t...,0.999700
764,fairytale for yound ladies. miracle fantasy,uno de los policiales más interesantes de los ...,0.999697
625,fairytale for yound ladies. miracle fantasy,has a shambling charm . . . a cheerfully incon...,0.999643
694,fairytale for yound ladies. miracle fantasy,"desta vez , columbus capturou o pomo de ouro .\n",0.999624
...,...,...,...
944,fairytale for yound ladies. miracle fantasy,"it moves quickly , adroitly , and without fuss...",0.349197
786,fairytale for yound ladies. miracle fantasy,you'll be left with the sensation of having ju...,0.339534
893,fairytale for yound ladies. miracle fantasy,"leave it to rohmer , now 82 , to find a way to...",0.327612
799,fairytale for yound ladies. miracle fantasy,greene delivers a typically solid performance ...,0.305827


In [70]:
# baseline loss: cross-entropy of random logits (uniform in [0, 1)) against
# random targets over the full vocabulary. Such near-constant logits give a
# loss close to ln(C), the chance level a trained MLM model should beat.
kek_loss = torch.nn.CrossEntropyLoss(reduction='none')
n, C = 5, len(tokenizer.vocab)

a1 = torch.Tensor(np.random.uniform(size=(n, C)))
b1 = torch.Tensor(np.random.randint(C, size=n)).long()
kek_loss(a1, b1).mean()

tensor(9.8794)

In [49]:
# from sklearn.linear_model import LogisticRegression

# LogisticRegression

# # scores_df.sort_values('score', ascending=False)[['b', 'score']].to_dict('records')

[{'b': 'a mostly intelligent , engrossing and psychologically resonant suspenser .\n',
  'score': 0.9985803365707397},
 {'b': 'a pleasant enough movie , held together by skilled ensemble actors .\n',
  'score': 0.9978662729263306},
 {'b': "steve irwin's method is ernest hemmingway at accelerated speed and volume .\n",
  'score': 0.9978216290473938},
 {'b': "this is the best american movie about troubled teens since 1998's whatever .\n",
  'score': 0.9976030588150024},
 {'b': "cantet perfectly captures the hotel lobbies , two-lane highways , and roadside cafes that permeate vincent's days\n",
  'score': 0.9975119233131409},
 {'b': 'what really surprises about wisegirls is its low-key quality and genuine tenderness .\n',
  'score': 0.9969547390937805},
 {'b': 'an idealistic love story that brings out the latent 15-year-old romantic in everyone .\n',
  'score': 0.9969406723976135},
 {'b': 'guaranteed to move anyone who ever shook , rattled , or rolled .\n',
  'score': 0.9968717098236084},