In [1]:
import pickle
import json
import sys
import re
import os
import psutil
from functools import lru_cache
import time

from pathlib import Path
from tqdm import tqdm

from collections import Counter, OrderedDict, namedtuple
from functools import reduce

import pandas as pd
import numpy as np

import einops
from einops.layers.torch import Rearrange, Reduce

import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchtext.data import get_tokenizer
from torchtext.vocab import vocab

from random import randint
import random


from sklearn.model_selection import train_test_split
import seaborn as sns
from matplotlib import pyplot as plt

%matplotlib inline

PROJECT_DIR = Path().absolute().parent
WIKI_PATH = PROJECT_DIR / 'InputData' / 'wikitext-103'
DATA_PATH = PROJECT_DIR / 'Data'

sys.path.append(str(PROJECT_DIR))

In [2]:
def process_memory():
    process = psutil.Process(os.getpid())
    print(f'{round(process.memory_info().rss * 10**(-6))} MB')

def objects_memory(*args):
    print(f'{round(sum([sys.getsizeof(obj) for obj in args]) * 10**(-6))} MB')

In [3]:
def save_pkl(obj, pth):
    with open(pth, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_pkl(pth):
    with open(pth, 'rb') as f:
        return pickle.load(f)

In [4]:
from tokenizers import (
    decoders,
    models,
    normalizers,
    Regex,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)
from transformers import PreTrainedTokenizerFast

# Load Data

In [4]:
def load_data(tp='train'):
    if tp not in ['train', 'test', 'valid']:
        raise Exception('ERROR: Wrong type of data.')
    
    pth = WIKI_PATH / f'wiki.{tp}.raw'
    heading_pattern = '\n (= ){1,}[^=]*[^=] (= ){1,}\n \n'
    with open(pth, 'r') as f:
        raw_text = f.read()
    
    raw_text = re.split(heading_pattern, raw_text)
    raw_text = [x.strip().strip('\n').strip() for x in raw_text if x and x not in [' ', '= ']]
    return raw_text

In [5]:
%%time
train_data = load_data('train')
test_data = load_data('test')
valid_data = load_data('valid')

print(f'{len(train_data)}/{len(test_data)}/{len(valid_data)}')

271821/623/552
CPU times: user 4.41 s, sys: 2.55 s, total: 6.96 s
Wall time: 7.01 s


In [6]:
def get_training_corpus():
    for i in range(0, len(train_data), 1000):
        yield train_data[i : i + 1000]

# Build Alphabet

In [7]:
tmp_tokenizer = Tokenizer(models.Unigram())
tmp_tokenizer.normalizer = normalizers.Sequence(
    [
        normalizers.Replace("``", '"'),
        normalizers.Replace("''", '"'),
        normalizers.Replace("”", '"'),
        normalizers.Replace("“", '"'),
        normalizers.Replace('ˈ', "'"),
        normalizers.Replace('’',"'"),
        normalizers.Replace('–',"-"),
        normalizers.Replace('—',"-"),
        normalizers.Replace('−',"-"),
        normalizers.Replace('′',"'"),
        normalizers.Replace('⁄',"/"),
        normalizers.NFKD(),
        normalizers.StripAccents(),
        normalizers.Replace(Regex(" {2,}"), " "),
    ]
)
tmp_tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
    [
        pre_tokenizers.BertPreTokenizer(), 
        # pre_tokenizers.Metaspace(replacement = '_', add_prefix_space = True),
        # pre_tokenizers.Punctuation(),
        pre_tokenizers.Digits(individual_digits=True)
    ]
)

In [8]:
char_counter = Counter()
for i in tqdm(range(len(train_data))):
    art_counter = Counter(train_data[i])
    char_counter.update(art_counter)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 271821/271821 [00:26<00:00, 10241.82it/s]


In [9]:
char_counter_df = pd.DataFrame(char_counter.most_common(), columns=['Symbol', 'Count'])
char_counter_df.shape

(4978, 2)

In [10]:
def get_base_char(txt):
    tokens = tmp_tokenizer.encode(txt).tokens
    return tokens
char_counter_df['Base_Symbol'] = char_counter_df['Symbol'].apply(get_base_char)
char_counter_df['Base_Symbol'].apply(len).value_counts()

1    4735
0     229
3      11
2       2
4       1
Name: Base_Symbol, dtype: int64

In [11]:
char_counter_df[char_counter_df['Base_Symbol'].apply(len) > 1]

Unnamed: 0,Symbol,Count,Base_Symbol
122,…,1961,"[., ., .]"
172,½,757,"[1, ⁄, 2]"
202,″,535,"[′, ′]"
468,⅓,86,"[1, ⁄, 3]"
568,⅔,58,"[2, ⁄, 3]"
591,¼,55,"[1, ⁄, 4]"
592,¾,55,"[3, ⁄, 4]"
1310,⅜,11,"[3, ⁄, 8]"
1468,⅛,9,"[1, ⁄, 8]"
1482,⅝,9,"[5, ⁄, 8]"


In [12]:
char_counter_df['Single_Base_Symbol'] = char_counter_df['Base_Symbol'].apply(lambda x: x[0] if x else '')
char_counter_df

Unnamed: 0,Symbol,Count,Base_Symbol,Single_Base_Symbol
0,,99530965,[],
1,e,48657548,[e],e
2,t,33788437,[t],t
3,a,33364371,[a],a
4,n,28965321,[n],n
...,...,...,...,...
4973,課,1,[課],課
4974,純,1,[純],純
4975,丽,1,[丽],丽
4976,치,1,[치],치


In [13]:
char_counter_df = char_counter_df.groupby('Single_Base_Symbol')['Count'].sum().reset_index()
char_counter_df = char_counter_df.sort_values('Count', ascending=False).reset_index(drop=True)
char_counter_df

Unnamed: 0,Single_Base_Symbol,Count
0,,100128792
1,e,48715319
2,t,33789131
3,a,33401452
4,n,28970107
...,...,...
4195,恢,1
4196,恒,1
4197,恆,1
4198,怨,1


In [14]:
char_counter_df['Cum_Prc'] = (char_counter_df['Count'] / char_counter_df['Count'].sum()).cumsum()

In [65]:
char_counter_df.head(110).tail(60)

Unnamed: 0,Single_Base_Symbol,Count,Cum_Prc
50,),572467,0.984419
51,(,572111,0.985501
52,3,541764,0.986526
53,5,538413,0.987544
54,8,532060,0.98855
55,E,514159,0.989523
56,J,503445,0.990475
57,4,495231,0.991412
58,O,462139,0.992286
59,6,445885,0.993129


In [15]:
alphabet = list(char_counter_df['Single_Base_Symbol'][:90])
[x for x in range(10) if str(x) not in alphabet]

[]

# Build Tokenizer
https://colab.research.google.com/github/tenexcoder/huggingface-tutorials/blob/main/BERT_tokenizer_from_scratch.ipynb
https://huggingface.co/transformers/v3.5.1/main_classes/tokenizer.html
https://huggingface.co/course/chapter6/8?fw=tf

Steps
1) Normalization
2) Pre_tokenization
3) Model
4) Post-processor

## 1) Normalizer

In [16]:
normlzr = normalizers.Sequence(
    [
        normalizers.Replace("``", '"'),
        normalizers.Replace("''", '"'),
        normalizers.Replace("”", '"'),
        normalizers.Replace("“", '"'),
        normalizers.Replace('ˈ', "'"),
        normalizers.Replace('’',"'"),
        normalizers.Replace('–',"-"),
        normalizers.Replace('—',"-"),
        normalizers.Replace('−',"-"),
        normalizers.Replace('′',"'"),
        normalizers.Replace('⁄',"/"),
        normalizers.NFKD(),
        normalizers.StripAccents(),
        normalizers.Replace(Regex(" {2,}"), " "),
    ]
)

## 2) Pre-tokenizer

In [17]:
pretknzr = pre_tokenizers.Sequence(
    [
        pre_tokenizers.BertPreTokenizer(), 
        # pre_tokenizers.Metaspace(replacement = '_', add_prefix_space = True),
        # pre_tokenizers.Punctuation(),
        pre_tokenizers.Digits(individual_digits=True)
    ]
)

## 3) Model Type

In [None]:
# ?trainers.BpeTrainer
# ?trainers.WordPieceTrainer

In [122]:
model_type = 'BPE'
# model_type = 'WordPiece'

SPEC_TOKENS = ["[UNK]", "[PAD]"]

if model_type == 'WordPiece':
    tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))
    
    trainer = trainers.WordPieceTrainer(
        vocab_size=50000, 
        min_frequency=0, 
        special_tokens=SPEC_TOKENS, 
        limit_alphabet=len(alphabet),
        initial_alphabet=alphabet,
        continuing_subword_prefix='##',
        end_of_word_suffix='__'
        
    )
elif model_type == 'BPE':
    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
    
    trainer = trainers.BpeTrainer(
        vocab_size=100000, 
        min_frequency=0, 
        special_tokens=SPEC_TOKENS, 
        limit_alphabet=len(alphabet),
        initial_alphabet=alphabet,
        continuing_subword_prefix='##',
        end_of_word_suffix='__'
    )
    
tokenizer.normalizer = normlzr
tokenizer.pre_tokenizer = pretknzr

In [121]:
tokenizer = Tokenizer(models.Unigram())

In [48]:
tokenizer.normalizer = normalizers.Sequence(
    [
        normalizers.Replace("``", '"'),
        normalizers.Replace("''", '"'),
        normalizers.Replace("”", '"'),
        normalizers.Replace("“", '"'),
        normalizers.Replace('ˈ', "'"),
        normalizers.Replace('’',"'"),
        normalizers.Replace('–',"-"),
        normalizers.Replace('—',"-"),
        normalizers.Replace('−',"-"),
        normalizers.Replace('′',"'"),
        normalizers.Replace('⁄',"/"),
        normalizers.NFKD(),
        normalizers.StripAccents(),
        normalizers.Replace(Regex(" {2,}"), " "),
    ]
)

In [49]:
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()

In [None]:
%%time
special_tokens = ["<unk>", "<pad>", "<s>", "</s>"]
trainer = trainers.UnigramTrainer(
    vocab_size=25000, special_tokens=special_tokens, unk_token="<unk>"
)
tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)




In [63]:
encoding = tokenizer.encode("Let's test this tokenizer.")
print(encoding.tokens)

['▁let', "'", 's', '▁test', '▁this', '▁to', 'ken', 'i', 'zer', '.']


In [64]:
encoding = tokenizer.encode("Let's test this tokenizer...", "on a pair of sentences!")
print(encoding.tokens)

['▁let', "'", 's', '▁test', '▁this', '▁to', 'ken', 'i', 'zer', '.', '.', '.', '▁on', '▁', 'a', '▁pair', '▁of', '▁sentence', 's', '!']


In [65]:
tokenizer.decoder = decoders.Metaspace()

In [66]:
tokenizer.save(str(DATA_PATH / "unigram_tokenizer.json"))

In [67]:
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    padding_side="left"
)

In [68]:
encoding = tokenizer.encode("Let's test this tokenizer. I want you.")
print(encoding.tokens)
print(encoding.ids)

['▁let', "'", 's', '▁test', '▁this', '▁to', 'ken', 'i', 'zer', '.', '▁', 'i', '▁want', '▁you', '.']
[1575, 72, 8, 778, 49, 15, 2883, 30, 3483, 7, 4, 30, 1163, 226, 7]


In [76]:
train_data[0]

'Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . \n The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving for series newcomers . Character designer Ra

In [74]:
encoding = tokenizer.encode(train_data[0])
print(encoding.tokens)
print(len(encoding.tokens))
# print(encoding.ids)

['▁sen', 'jo', '▁no', '▁val', 'k', 'y', 'ria', '▁', '3', '▁', ':', '▁un', 'recorded', '▁chronicle', 's', '▁', '(', '▁japanes', 'e', '▁', ':', '▁', '戦', '場', 'の', 'ウ', 'ァ', 'ル', 'キ', 'ュ', 'リ', 'ア', '3', '▁', ',', '▁lit', '▁', '.', '▁val', 'k', 'y', 'ria', '▁of', '▁the', '▁battlefield', '▁', '3', '▁', ')', '▁', ',', '▁common', 'ly', '▁referre', 'd', '▁to', '▁as', '▁val', 'k', 'y', 'ria', '▁chronicle', 's', '▁iii', '▁outside', '▁japan', '▁', ',', '▁is', '▁', 'a', '▁tactical', '▁role', '▁', '@', '-', '@', '▁playing', '▁video', '▁game', '▁develope', 'd', '▁by', '▁sega', '▁and', '▁media', '.', 'vision', '▁for', '▁the', '▁playstation', '▁portable', '▁', '.', '▁released', '▁in', '▁january', '▁2011', '▁in', '▁japan', '▁', ',', '▁it', '▁is', '▁the', '▁third', '▁game', '▁in', '▁the', '▁val', 'k', 'y', 'ria', '▁series', '▁', '.', '▁employ', 'ing', '▁the', '▁same', '▁fusion', '▁of', '▁tactical', '▁and', '▁real', '▁', '@', '-', '@', '▁time', '▁gameplay', '▁as', '▁its', '▁predecessor', 's', '▁', ',',

## 4) Fit

In [123]:
tokenizer.pre_tokenizer.pre_tokenize_str("Let's test pre-tokenization! 123")

[('Let', (0, 3)),
 ("'", (3, 4)),
 ('s', (4, 5)),
 ('test', (6, 10)),
 ('pre', (11, 14)),
 ('-', (14, 15)),
 ('tokenization', (15, 27)),
 ('!', (27, 28)),
 ('1', (29, 30)),
 ('2', (30, 31)),
 ('3', (31, 32))]

In [124]:
%%time
tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)




CPU times: user 10min 2s, sys: 37.1 s, total: 10min 39s
Wall time: 1min 7s


In [125]:
print(tokenizer.encode("Héllò hôw are ü?").tokens)

['Hello__', 'how__', 'are__', 'u__', '?__']


In [126]:
tokenizer.save(str(DATA_PATH / "Tokenizer_BPE100k.json"))

In [200]:
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=str(DATA_PATH / "Tokenizer_BPE100k.json"),
    unk_token="[UNK]",
    pad_token="[PAD]"
)

# Post Tokenizer

In [None]:
def pre_post_tokenize(encoding):
    tokens = ['']
    tokens_ids = []
    for tk, tkid in zip(encoding.tokens, encoding.ids):
        if tk == '[UNK]' and tokens == '[UNK]':[-1]
            continue
        else:
            tokens.append(tk)
            tokens_ids.append(tkid)
        
    tokens.pop(0)
    
    return tokens, tokens_ids

In [142]:
tokenizer.token_to_id('[PAD]')

1

In [144]:
Token = namedtuple('Token', ['tid', 'value', 'title', 'upper','part', 'w_end'])

In [166]:
ALL_VOCAB = {
    'First': tokenizer.get_vocab(),
    'First_Reverse': {v:k for k,v in tokenizer.get_vocab().items()},
    'First_Second': {
        tokenizer.token_to_id('[UNK]'): Token(tid=tokenizer.token_to_id('[UNK]'), value='[unk]', 
                       title=False, upper=False, part=False, w_end=True),
        tokenizer.token_to_id('[PAD]'): Token(tid=tokenizer.token_to_id('[PAD]'), value='[pad]',
                       title=False, upper=False, part=False, w_end=True)
    },
    'First_Second_Reverse': {
        Token(tid=tokenizer.token_to_id('[UNK]'), value='[unk]',
              title=False, upper=False, part=False, w_end=True): tokenizer.token_to_id('[UNK]'), 
        Token(tid=tokenizer.token_to_id('[PAD]'), value='[pad]',
              title=False, upper=False, part=False, w_end=True): tokenizer.token_to_id('[PAD]'), 
    },
    'Second': {
        '[unk]': tokenizer.token_to_id('[UNK]'),
        '[pad]': tokenizer.token_to_id('[PAD]'),
    },
    'Second_Reverse': {
        tokenizer.token_to_id('[UNK]'): '[unk]',
        tokenizer.token_to_id('[PAD]'): '[pad]',
    }    
}

In [167]:
max_new_id = max(ALL_VOCAB['Second_Reverse'].keys())

for tk, tkid in tqdm(tokenizer.get_vocab().items()):
    if tkid in ALL_VOCAB['First_Second']:
        continue
        
    part = tk[:2] == '##'
    w_end = tk[-2:] == '__'
    tk = tk.replace('##','').replace('__','')
    upper = tk.isupper()
    title = tk[0].isupper()
    value = tk.lower()
    if value in ALL_VOCAB['Second'].keys():
        value_id = ALL_VOCAB['Second'][value]
    else:
        value_id = max(ALL_VOCAB['Second_Reverse'].keys()) + 1
        ALL_VOCAB['Second'][value] = value_id
        ALL_VOCAB['Second_Reverse'][value_id] = value
    
    tk = Token(
        tid=value_id, value=value, 
        title=title, upper=upper, part=part, w_end=w_end
    )
    ALL_VOCAB['First_Second'][tkid] = tk
    ALL_VOCAB['First_Second_Reverse'][tk] = tkid

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100000/100000 [00:53<00:00, 1870.71it/s]


In [168]:
len(ALL_VOCAB['Second_Reverse'])

75463

In [128]:
len(set(x.strip('__').strip('##').lower() for x in  tokenizer.get_vocab().keys()))

75463

In [175]:
save_pkl(ALL_VOCAB, DATA_PATH / "ALLVOCAB_BPE100k.pkl")

In [178]:
def encode_txt(txt, tokenizer, vocab):
    encoding = tokenizer.encode(txt)
    all_ids = [(fid, vocab['First_Second'][fid]) for fid in encoding.ids]
    return all_ids

In [58]:
encoding = tokenizer.encode(train_data[0])
print(encoding.tokens)
print(len(encoding.tokens))
# print(encoding.ids)

['Sen', '##jo__', 'no__', 'Valky', '##ria__', '3__', ':__', 'Un', '##rec', '##orded__', 'Chronicles__', '(__', 'Japanese__', ':__', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '[UNK]', '3__', ',__', 'lit__', '.__', 'Valky', '##ria__', 'of__', 'the__', 'Battlefield__', '3__', ')__', ',__', 'commonly__', 'referred__', 'to__', 'as__', 'Valky', '##ria__', 'Chronicles__', 'III__', 'outside__', 'Japan__', ',__', 'is__', 'a__', 'tactical__', 'role__', '@__', '-__', '@__', 'playing__', 'video__', 'game__', 'developed__', 'by__', 'Sega__', 'and__', 'Media__', '.__', 'Vision__', 'for__', 'the__', 'PlayStation__', 'Portable__', '.__', 'Released__', 'in__', 'January__', '2__', '0__', '1__', '1__', 'in__', 'Japan__', ',__', 'it__', 'is__', 'the__', 'third__', 'game__', 'in__', 'the__', 'Valky', '##ria__', 'series__', '.__', 'Employ', '##ing__', 'the__', 'same__', 'fusion__', 'of__', 'tactical__', 'and__', 'real__', '@__', '-__', '@__', 'time__', 'gameplay__', 'a

In [187]:
len([(k,v) for k,v in ALL_VOCAB['First_Second'].items() if len(v.value) > 1 and v.upper])

2786

In [188]:
len([(k,v) for k,v in ALL_VOCAB['First_Second'].items() if len(v.value) > 1 and v.title])

43762

In [195]:
len([(k,v) for k,v in ALL_VOCAB['First_Second'].items() if v.part])

17739

In [196]:
len([(k,v) for k,v in ALL_VOCAB['First_Second'].items() if v.w_end])

78507

In [198]:
len([(k,v) for k,v in ALL_VOCAB['First_Second'].items() if v.w_end and not v.part])

63572

In [197]:
len(ALL_VOCAB['First_Second'])

100000

In [194]:
ALL_VOCAB['First_Reverse'][20335]

'Dog'

In [193]:
ALL_VOCAB['First']['Dog']

20335

# GLOBAL CONFIGS

In [347]:
# Data and Train Parameters
CURRENT_SEQ_LEN = 30
BATCH_SIZE = 16
AGG_ROUNDS = 10
CNT_NEGATIVE = 5
MAX_SEQ_LEN = 1024

# Architecture parameters

    # Counts
CNT_MEANINGS = 5
CAT_SIZES = [55000, 30000, 15000, 5000]
    
    # Embedings
POS_EMB_SIZE = 5
TITLE_EMB_SIZE = 5
UPPER_EMB_SIZE = 5
PART_EMB_SIZE = 5
END_EMB_SIZE = 5
MEANING_EMB_SIZE = 20
CAT_EMB_SIZE = 10



    # Main Sizes
PREDICT_SIZE = (TITLE_EMB_SIZE + UPPER_EMB_SIZE + PART_EMB_SIZE + END_EMB_SIZE) + CAT_EMB_SIZE + MEANING_EMB_SIZE
HIDDEN_SIZE = int(1.5 * PREDICT_SIZE)
INPUT_SIZE = POS_EMB_SIZE + TITLE_EMB_SIZE + UPPER_EMB_SIZE + PART_EMB_SIZE + END_EMB_SIZE + CAT_EMB_SIZE + MEANING_EMB_SIZE

In [7]:
HIDDEN_SIZE

75

# Make Train Data

In [None]:
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=str(DATA_PATH / "Tokenizer_BPE100k.json"),
    unk_token="[UNK]",
    pad_token="[PAD]"
)

In [9]:
def merge_unks(toks, unk=0):
    return [tk for i, tk in enumerate(toks) if (tk != unk) or (i == 0) or (toks[i-1] != unk)]

In [10]:
def final_tokenizer(txt):
    tokens = wrapped_tokenizer(txt)['input_ids']
    tokens = merge_unks(tokens, wrapped_tokenizer.unk_token_id)
    return np.array(tokens)

In [11]:
def get_all_sub_seq(tokens, max_len, pad=1):
    ln = len(tokens)
    if ln < max_len:
        return np.array([np.concatenate([tokens, [pad]*(max_len - ln)])])
    sub_seqs = []
    for i in range(ln - max_len + 1):
        sub_seqs.append(tokens[i:i+max_len])
    return np.array(sub_seqs)

In [14]:
%%time
train_30T = [
    get_all_sub_seq(final_tokenizer(txt), max_len=MAX_SEQ_LEN, pad=wrapped_tokenizer.pad_token_id) 
    for txt in train_data
]
train_30T = np.concatenate(train_30T, axis=0)
train_30T = np.array(train_30T, np.uintc)
train_30T.shape #(103596449, 30)

CPU times: user 8min 16s, sys: 1min 8s, total: 9min 24s
Wall time: 10min 22s


(103596449, 30)

In [15]:
with open(DATA_PATH / 'train_30T.npy', 'wb') as f:
    np.save(f, train_30T)

In [27]:
process_memory()
objects_memory(train_30T)

585 MB
24863 MB


# Load to Train

In [6]:
process_memory()

484 MB


In [8]:
with open(DATA_PATH / 'train_30T.npy', 'rb') as f:
    train_tokens = np.load(f)

In [9]:
Token = namedtuple('Token', ['tid', 'value', 'title', 'upper','part', 'w_end'])
ALL_VOCAB = load_pkl(DATA_PATH / "ALLVOCAB_BPE100k.pkl")

In [10]:
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=str(DATA_PATH / "Tokenizer_BPE100k.json"),
    unk_token="[UNK]",
    pad_token="[PAD]"
)

In [11]:
process_memory()

1415 MB


In [12]:
objects_memory(train_tokens)

12432 MB


# Arhitecture

In [13]:
# 1) PosEncodind
# 2) PropsEmb
# 3) TokensEmb

## Position Encoding

In [350]:
class PosEncoding(nn.Module):

    def __init__(self, rows, emb_size, cnt_repeats, seed=0):
        super(PosEncoding, self).__init__()
        self.rows = rows
        self.emb_size = emb_size
        self.cnt_repeats = cnt_repeats
        self.embedding = nn.Embedding(rows, self.emb_size)
        self.init_weights(seed)

    def init_weights(self, seed):
        torch.manual_seed(seed)
        nn.init.xavier_uniform_(self.embedding.weight, gain=1.0)
        self.embedding.weight.data = torch.tensor(
            np.array(self.embedding.weight.data), 
            dtype=torch.float32
        ).requires_grad_(True)
        
    def forward(self, batch): # Batch_Sz x SeqLen x SeqLen
        t0 = time.time()
        batch = self.embedding(batch)
        batch = batch.repeat(self.cnt_repeats,1,1,1,1).permute(1,2,3,0,4) # Batch_Sz x SeqLen x SeqLen x (Cnt_Meanings * Cnt_Cats) x PosEmb
        print(f'PosEncoding, forward time: {time.time()-t0}')
        return batch
    
    def forward2(self, batch, ptime=False): 
        # batch = seq_len x seq_len
        t0 = time.time()
        batch = self.embedding(batch)
        if ptime:
            print(f'PosEncoding, forward2 time: {time.time()-t0}')
        return batch
    
    
    def init_optims(self, opt_type, lr):
        self.current_weight_lr = lr
        self.opt= getattr(optim, opt_type)([self.embedding.weight], lr=lr)

    def set_lr(self, lr_weight, lr_bias):
        pass

    def step(self):
        self.opt.step()
        
    def zero_grad(self):
        self.opt.zero_grad()

    def clip_grad(self, maxg=1e-2):
        pass

    def count_params(self):
        return 0

## Properties Embeding

In [351]:
class PropsEmbeding(nn.Module):

    def __init__(self, title, upper, part, w_end, cnt_repeats, seed=0):
        super(PropsEmbeding, self).__init__()
        self.title = title
        self.upper = upper
        self.part = part
        self.w_end = w_end
        self.cnt_repeats = cnt_repeats
        
        self.title_emb = nn.Embedding(2, self.title)
        self.upp_emb = nn.Embedding(2, self.upper)
        self.prt_emb = nn.Embedding(2, self.part)
        self.end_emb = nn.Embedding(2, self.w_end)
        
        # FIX ORDER !!!!!!!
        self.seq_embs = [self.title_emb, self.upp_emb, self.prt_emb ,self.end_emb]
        self.init_weights(seed)

    def init_weights(self, seed):
        torch.manual_seed(seed)
        
        for atr in dir(self):
            if atr[-4:] != '_emb':
                continue
            lay = getattr(self, atr)
            nn.init.xavier_uniform_(lay.weight, gain=1.0)
            lay.weight.data = torch.tensor(np.array(lay.weight.data), dtype=torch.float32).requires_grad_(True)

        
        
    def forward(self, batch): # Extended, Batch_Size x SeqLen x SeqLen x Props (4)
        t0 = time.time()
        embs_vals = []
        
        for i, emb_lay in enumerate(self.seq_embs):
            if self.fited:
                embs_vals.append(emb_lay(batch[:,:,:,i]))
            else:
                with torch.no_grad():
                    embs_vals.append(emb_lay(batch[:,:,:,i]))
        
        embs_vals = torch.cat(embs_vals, dim=-1)
        embs_vals = embs_vals.repeat(self.cnt_repeats,1,1,1,1).permute(1,2,3,0,4)
        print(f'PropsEmbeding, forward time: {time.time()-t0}')
        return embs_vals # Batch_Size x SeqLen x SeqLen x Cnt_Meanings * Cnt_Cats x PropsEmb
    
    def forward4(self, batch, ptime=False): 
        # batch = batch x seq_len x props (=4)
        t0 = time.time()
        embs_vals = []
        seq_len = batch.shape[1]
        
        embs_vals = torch.cat([
            emb_lay(batch[:,:,i])
            for i, emb_lay in enumerate(self.seq_embs)
        ], dim=-1)
        
        embs_vals = [
            emb_lay(batch[:,:,i])
            for i, emb_lay in enumerate(self.seq_embs)
        ]
        
        embs_vals = torch.cat(embs_vals, dim=-1)
        # batch x seq_len x props_emb
        if ptime:
            print(f'PropsEmbeding, forward4 time: {time.time()-t0}')
        return embs_vals # Batch_Size x SeqLen x SeqLen x Cnt_Meanings * Cnt_Cats x PropsEmb
    

    def init_optims(self, opt_type, lr):
        self.current_weight_lr = lr
        self.opt = getattr(optim, opt_type)([x.weight for x in self.seq_embs], lr=lr)

    def set_lr(self, lr_weight, lr_bias):
        pass

    def step(self):
        self.opt.step()
        
    def zero_grad(self):
        self.opt.zero_grad()

    def clip_grad(self, maxg=1e-2):
        pass

    def count_params(self):
        return 0

## Tokens Embeding

In [251]:
class TokensEmbeding(nn.Module):

    def __init__(self, cnt_tokens, categories_sz, cnt_meanings, meaning_emb_sz, cat_emb_sz, seed=0):
        super(TokensEmbeding, self).__init__()
        self.cnt_tokens = cnt_tokens
        self.categories_sz = categories_sz
        self.cnt_categories = len(self.categories_sz)
        self.cnt_meanings = cnt_meanings
        self.meaning_emb_sz = meaning_emb_sz
        self.cat_emb_sz = cat_emb_sz
        
        # Categories Embeding
        self.cat_emb = nn.Embedding(self.cnt_categories, self.cat_emb_sz)
        
        # Tokens Embeding for each Category
        self.all_tokens_embeds = []
        for i, cat_sz in enumerate(self.categories_sz):
            setattr(self, f'token{i}_emb', nn.Embedding(cat_sz, self.cnt_meanings * self.meaning_emb_sz))
            self.all_tokens_embeds.append(getattr(self, f'token{i}_emb'))

        self.init_weights(seed)
        self.init_merges(seed)
    
    def init_merges(self, seed):
        np.random.seed(seed)
        
        all_tokens_ids = np.arange(self.cnt_tokens)
        for i, cat_sz in enumerate(self.categories_sz):
            cat_dict = dict()
            while len(cat_dict) != self.cnt_tokens:
                cats_ids = np.arange(cat_sz)
                all_tokens_ids = [i for i in range(self.cnt_tokens) if i not in cat_dict]
                np.random.shuffle(cats_ids)
                np.random.shuffle(all_tokens_ids)
                indx = slice(0, min(len(all_tokens_ids), cat_sz))
                cat_dict.update(zip(all_tokens_ids[indx], cats_ids[indx], ))

            setattr(self, f'cat{i}_merge', cat_dict)
        
    def init_weights(self, seed):
        torch.manual_seed(seed)
        
        for lay in self.all_tokens_embeds:
            rows = lay.weight.shape[0]
            tmp_lay = lay.weight.data.reshape(rows * self.cnt_meanings, self.meaning_emb_sz).type(torch.float32)
            nn.init.xavier_uniform_(tmp_lay, gain=1.0)
            tmp_lay = tmp_lay.reshape(rows, self.cnt_meanings * self.meaning_emb_sz)
            lay.weight.data = tmp_lay.clone().detach().requires_grad_(True)
        
        
        nn.init.xavier_uniform_(self.cat_emb.weight, gain=1.0)
        self.cat_emb.weight.data = torch.tensor(
            np.array(self.cat_emb.weight.data), 
            dtype=torch.float32
        ).requires_grad_(True)

        
        
    def forward(self, batch): # Batch_Size x SeqLen
        t0 = time.time()
        rows, columns = batch.shape
        all_tk_cat = []
        
        for i, emb_lay in enumerate(self.all_tokens_embeds):
            cat_emb = torch.tensor([i]).\
                repeat(rows, columns, self.cnt_meanings).requires_grad_(False) # Batch_Size x SeqLen x Cnt_Meanings x CatEmb
            cat_emb = self.cat_emb(cat_emb)

            merge_dict = getattr(self, f'cat{i}_merge')

            tokens_i = batch.clone().apply_(merge_dict.get)
            tokens_i = emb_lay(tokens_i).reshape(rows, columns, self.cnt_meanings, self.meaning_emb_sz)

            tk_cat_i = torch.cat([cat_emb, tokens_i], axis=-1) # Batch_Size x SeqLen x Cnt_Meanings x (CatEmb + MeaningEmb)
            all_tk_cat.append(tk_cat_i)
       
        all_tk_cat = torch.cat(all_tk_cat, axis=-2) # Batch_Size x SeqLen x Cnt_Meanings * CntCats x (CatEmb + MeaningEmb)
        all_tk_cat = all_tk_cat.repeat(columns, 1,1,1,1)
        all_tk_cat = all_tk_cat.permute(1,2,0,3,4)
        print(f'TokensEmbeding, forward time: {time.time()-t0}')
        return all_tk_cat
    
    
    def forward5(self, batch, negative_batch=None, ptime=False): # Batch_Size x SeqLen
        # batch = batch x seq_len
        # negative_batch = batch x seq_len x cnt_neg
        
        t0 = time.time()
        rows, columns = batch.shape
        all_tk_cat = []
        
        # cnt_cats - > cnt_cats x cat_emb
        cat_emb = self.cat_emb(torch.tensor([i for i in range(self.cnt_categories)], dtype=torch.int)) # cnt_cats x cat_emb
        
        # cnt_cats x cat_emb -> batch x seq_len x (cnt_cats * cnt_meanings) x cat_emb
        cat_emb_ext = einops.repeat(cat_emb, 'c e -> b s (c m) e', b=rows, s=columns, m=self.cnt_meanings) 
        
        
        # List (cnt_cats) of batch x seq_len x (cnt_meanings * meaning_emb)
        all_tk_cat = [
            emb_lay(batch.clone().apply_(getattr(self, f'cat{i}_merge').get))
            for i, emb_lay in enumerate(self.all_tokens_embeds)
        ]
                
        
        # (List) cnt_cats x batch x seq_len x (cnt_meanings * meaning_emb) -> batch x seq_len x (cnt_cats * cnt_meanings) x meaning_emb
        all_tk_cat = einops.rearrange(all_tk_cat, 'c b s (m e) -> b s (c m) e', m=self.cnt_meanings, e=self.meaning_emb_sz)
        
        # batch x seq_len x (cnt_cats * cnt_meanings) x (cat_emb + meaning_emb)
        all_tk_cat = torch.cat([cat_emb_ext, all_tk_cat], axis=-1) 
        
        if negative_batch is not None:
            cnt_negative = negative_batch.shape[-1]
            
            negative_batch = einops.rearrange([
                emb_lay(negative_batch.clone().apply_(getattr(self, f'cat{i}_merge').get))
                for i, emb_lay in enumerate(self.all_tokens_embeds)], 
                'c b s n (m e) -> b s (n c m) e', 
                m=self.cnt_meanings, e=self.meaning_emb_sz
            )
            
            
            cat_emb_ext2 = einops.repeat(cat_emb, 'c e -> b s (n c m) e', b=rows, s=columns, n=cnt_negative, 
                                         m=self.cnt_meanings) 
            
            
            negative_batch = torch.cat([cat_emb_ext2, negative_batch], axis=-1)
            # negative_batch = batch x seq_len x (cnt_negative * cnt_cats * cnt_meanings) x (cat_emb + meaning_emb)
        
        if ptime:
            print(f'TokensEmbeding, forward5 time: {time.time()-t0}')
        return all_tk_cat, negative_batch
    
    
    def init_optims(self, opt_type, lr):
        self.current_weight_lr = lr
        self.opt = getattr(optim, opt_type)(
            [self.cat_emb.weight, ] + [x.weight for x in self.all_tokens_embeds], lr=lr
        )

    def set_lr(self, lr_weight, lr_bias):
        pass

    def step(self):
        self.opt.step()
        
    def zero_grad(self):
        self.opt.zero_grad()

    def clip_grad(self, maxg=1e-2):
        pass

    def count_params(self):
        return 0

## Attention Net

In [252]:
class AttentionNet(nn.Module):

    def __init__(self, input_sz, seed=0):
        super(AttentionNet, self).__init__()
        self.seed=seed
        self.lay0 = nn.Linear(in_features=input_sz, out_features=1)
        # self.activation = nn.LeakyReLU()
        self.init_weights(seed)
        self.softmax = nn.Softmax(dim=-1)
        
    
    def init_weights(self, seed):
        torch.manual_seed(seed)
        nn.init.xavier_uniform_(self.lay0.weight, gain=1.0)
        self.lay0.weight.data = torch.tensor(np.array(self.lay0.weight.data), dtype=torch.float32).requires_grad_(True)
        self.lay0.bias.data = torch.tensor(np.array(self.lay0.bias.data), dtype=torch.float32).requires_grad_(True)
        
        
    def forward(self, batch, mask):
        # batch =  batch x seq_len x seq_len * (cnt_cats * cnt_meanings) x hidden + embs
        seq_len = batch.shape[1]
        cm = batch.shape[3]
        
        att = einops.rearrange(self.lay0(batch), 'b s_w s_h cm e -> b s_w (s_h cm e)')
        # att =  batch x seq_len x (seq_len * cnt_cats * cnt_meanings)
        
        if mask is not None:
            att = att + mask 
        # att = einops.rearrange(self.softmax(att), 'b s_w (s_h cm e) -> b s_w s_h cm e', s_h=seq_len, cm=cm, e=1)
        att = self.softmax(att).unsqueeze(dim=-1)
        # att = batch x seq_len x (seq_len * cnt_cats * cnt_meanings) x 1
        return att

    def init_optims(self, opt_type, lr):
        self.current_weight_lr = lr
        self.opt = getattr(optim, opt_type)([self.lay0.weight, self.lay0.bias], lr=lr)

    def set_lr(self, lr_weight, lr_bias):
        pass

    def step(self):
        self.opt.step()
        
    def zero_grad(self):
        self.opt.zero_grad()

    def clip_grad(self, maxg=1e-2):
        pass

    def count_params(self):
        return 0

## Process Net

In [253]:
class ProceccNet(nn.Module):

    def __init__(self, input_size, output_size, seed=0):
        super(ProceccNet, self).__init__()
        self.seed=seed
        self.input_size = input_size
        self.output_size = output_size
        self.lay0 = nn.Linear(in_features=input_size, out_features=output_size)
        self.activation = torch.tanh
        self.init_weights(seed)
        
    
    def init_weights(self, seed):
        torch.manual_seed(seed)
        nn.init.xavier_uniform_(self.lay0.weight, gain=1.0)
        self.lay0.weight.data = torch.tensor(np.array(self.lay0.weight.data), dtype=torch.float32).requires_grad_(True)
        self.lay0.bias.data = torch.tensor(np.array(self.lay0.bias.data), dtype=torch.float32).requires_grad_(True)
        
        
    def forward(self, batch):
        # batch =  Batch_Sz * SeqLen x Embs + Hidden
        batch = self.activation(self.lay0(batch)) # batch =  Batch_Sz * SeqLen x SeqLen * Cnt_Cats * Cnt_Meanings
        return batch

    def init_optims(self, opt_type, lr):
        self.current_weight_lr = lr
        self.opt = getattr(optim, opt_type)([self.lay0.weight, self.lay0.bias], lr=lr)

    def set_lr(self, lr_weight, lr_bias):
        pass

    def step(self):
        self.opt.step()
        
    def zero_grad(self):
        self.opt.zero_grad()

    def clip_grad(self, maxg=1e-2):
        pass

    def count_params(self):
        return 0

## Aggregation Net

In [336]:
class AggregationNet(nn.Module):

    def __init__(self, hidden_size, full_embeding_size, short_embeding_size, seed=0):
        super(AggregationNet, self).__init__()
        self.hidden_size = hidden_size
        self.full_embeding_size = full_embeding_size
        self.short_embeding_size = short_embeding_size
        self.seed=seed
        self.att_tokens = AttentionNet(self.hidden_size + self.full_embeding_size, seed=seed)
        self.att_hidden = AttentionNet(2 * self.hidden_size, seed=seed)
        self.process_net = ProceccNet(2 * self.hidden_size + self.short_embeding_size, self.hidden_size, seed=seed)
        
    def forward(self, pos_props_tokens, mask, mask_ext, rounds=10):
        # pos_props_tokens = batch x seq_len x seq_len x (cnt_cats * cnt_meanings) x (pos_emb + cat_emb + emb_meaning)
        
        t0 = time.time()
        batch_sz = pos_props_tokens.shape[0]
        seq_len = pos_props_tokens.shape[1]
        cm = pos_props_tokens.shape[3]
        start_emb = self.full_embeding_size - self.short_embeding_size
        
        
        # batch x seq_len x hidden
        h = torch.zeros((batch_sz, seq_len, self.hidden_size), dtype=torch.float32).requires_grad_(False)
        
        all_hidens = []
        for i in range(rounds):
            h_ext = einops.repeat(h, 'b s_h e -> b s_w s_h cm e', s_w=seq_len, cm=cm)
            # h_ext = batch x seq_len x seq_len x (cnt_cats * cnt_meanings) x hidden
            
            
            
            h_ext2 = einops.repeat(h, '(b s1) e -> (b s2) s1 e', b=batch_sz, s1=seq_len, s2=seq_len)
            h_ext3 = einops.repeat(h, 'bs e -> bs s2 e', s2=seq_len)
            
            att_tk = self.att_tokens.forward(torch.cat([h_ext, pos_props_tokens], axis=-1), mask_ext)
            agg_tk = (pos_props_tokens * att_tk).sum(axis=-2)
            # att_tk = batch x seq_len x seq_len x (cnt_cats * cnt_meanings) x 1
            
            
            att_h = self.att_hidden.forward(torch.cat([h_ext2, h_ext3], axis=-1), mask)
            
            agg_tk = (pos_props_tokens * att_tk).sum(axis=-2)
            agg_h = (h_ext2 * att_h).sum(axis=-2)
            
            agg_output = torch.cat([h, agg_h, agg_tk], axis=-1)
            h = self.process_net.forward(agg_output)
            all_hidens.append(h)
        print(f'AggregationNet, forward time: {time.time()-t0}')
        return all_hidens

    
    def init_optims(self, opt_type, lr):
        self.att_tokens.init_optims(opt_type, lr)
        self.att_hidden.init_optims(opt_type, lr)
        self.process_net.init_optims(opt_type, lr)

    def set_lr(self, lr_weight, lr_bias):
        pass

    def step(self):
        self.att_tokens.step()
        self.att_hidden.step()
        self.process_net.step()

    def zero_grad(self):
        self.att_tokens.zero_grad()
        self.att_hidden.zero_grad()
        self.process_net.zero_grad()

    def clip_grad(self, maxg=1e-2):
        pass

    def count_params(self):
        return 0

## Property Prediction Net

In [337]:
class PropertyNet(nn.Module):

    def __init__(self, input_size, seed=0):
        super(PropertyNet, self).__init__()
        self.seed=seed
        self.input_size = input_size
        self.lay0 = nn.Linear(in_features=input_size, out_features=4)
        self.activation = torch.sigmoid
        self.init_weights(seed)
        
    
    def init_weights(self, seed):
        torch.manual_seed(seed)
        nn.init.xavier_uniform_(self.lay0.weight, gain=1.0)
        self.lay0.weight.data = torch.tensor(np.array(self.lay0.weight.data), dtype=torch.float32).requires_grad_(True)
        self.lay0.bias.data = torch.tensor(np.array(self.lay0.bias.data), dtype=torch.float32).requires_grad_(True)
        
        
    def forward(self, batch):
        # batch =  Batch_Sz * SeqLen x Embs + Hidden
        batch = self.lay0(batch) # batch =  Batch_Sz * SeqLen x SeqLen * Cnt_Cats * Cnt_Meanings
        batch = torch.sigmoid(batch)
        return batch

    def init_optims(self, opt_type, lr):
        self.current_weight_lr = lr
        self.opt = getattr(optim, opt_type)([self.lay0.weight, self.lay0.bias], lr=lr)

    def set_lr(self, lr_weight, lr_bias):
        pass

    def step(self):
        self.opt.step()
        
    def zero_grad(self):
        self.opt.zero_grad()

    def clip_grad(self, maxg=1e-2):
        pass

    def count_params(self):
        return 0

## Category + Meaning Prediction Net

In [338]:
class CatMeangNet(nn.Module):

    def __init__(self, input_size, seed=0):
        super(CatMeangNet, self).__init__()
        self.seed=seed
        self.input_size = input_size
        self.lay0 = nn.Linear(in_features=input_size, out_features=1)
        self.activation = torch.sigmoid
        self.init_weights(seed)
        
    
    def init_weights(self, seed):
        torch.manual_seed(seed)
        nn.init.xavier_uniform_(self.lay0.weight, gain=1.0)
        self.lay0.weight.data = torch.tensor(np.array(self.lay0.weight.data), dtype=torch.float32).requires_grad_(True)
        self.lay0.bias.data = torch.tensor(np.array(self.lay0.bias.data), dtype=torch.float32).requires_grad_(True)
        
        
    def forward(self, batch):
        # batch =  Batch_Sz * SeqLen x Embs + Hidden
        batch = self.lay0(batch) # batch =  Batch_Sz * SeqLen x SeqLen * Cnt_Cats * Cnt_Meanings
        batch = torch.sigmoid(batch)
        return batch

    def init_optims(self, opt_type, lr):
        self.current_weight_lr = lr
        self.opt = getattr(optim, opt_type)([self.lay0.weight, self.lay0.bias], lr=lr)

    def set_lr(self, lr_weight, lr_bias):
        pass

    def step(self):
        self.opt.step()
        

    def zero_grad(self):
        self.opt.zero_grad()

    def clip_grad(self, maxg=1e-2):
        pass

    def count_params(self):
        return 0

# Train Procedure

In [339]:
ALL_VOCAB.keys()

dict_keys(['First', 'First_Reverse', 'First_Second', 'First_Second_Reverse', 'Second', 'Second_Reverse'])

In [340]:
ALL_VOCAB['First_Second'][0]

Token(tid=0, value='[unk]', title=False, upper=False, part=False, w_end=True)

In [341]:
# @lru_cache(maxsize=1000)
def change_tokens(tkid):
    new_tk = ALL_VOCAB['First_Second'][tkid]
    # return new_tk.tid
    return (new_tk.tid, int(new_tk.title), int(new_tk.upper), int(new_tk.part), int(new_tk.w_end))

v_change_tokens = np.vectorize(change_tokens)

def process_batch(batch):
    return np.stack(v_change_tokens(batch), axis=2)


def reorder_posistions(matrix, wpos):
    sz, _ = matrix.shape
    indx = [wpos,] + [i for i in range(sz) if i != wpos]
    return matrix[indx,:][:,indx]

def reorder_tokens(marix, w0):
    matrix = np.array([
        marix[i,[w0[i],] + [j for j in range(marix.shape[1]) if j != w0[i]]] 
        for i in range(marix.shape[0])
    ])
    return matrix

def negative_tokens(tkid, cnt):
    neg_set = set()
    while len(neg_set) != cnt:
        tmp_set = set(np.random.randint(0, high=len(ALL_VOCAB['Second']), size=cnt - len(neg_set), dtype=int))
        neg_set = neg_set | tmp_set
    return tuple(neg_set)

v_negative_tokens = np.vectorize(negative_tokens)

In [342]:
def train_old(nets, data, 
          batch_size, max_seq_len, cnt_negative, rounds_agg, 
          mask_coef, hidden_coef, pos_sum_coef, pad,
          avg_info=100, opt_type='SGD', lr=1e-3, seed=0, cnt_epochs=10**6, ptime=False
         ):
    _ = [x.init_optims(opt_type, lr) for x in nets]
    PosEnc, PropsEmb, TokensEmb, AggNet, PropNet, CMNet = nets
    l2_none_loss = torch.nn.MSELoss(reduction='none')
    
    rows, _ = data.shape
    indices = np.arange(rows)
    np.random.seed(seed)
    
    # Определяем дефолтную матрицу позиций
    default_positions = np.array([
        [max_seq_len + j - i - 1 for j in range(max_seq_len) ] 
        for i in range(max_seq_len)
    ])
    
    hist_props_l2 = []
    hist_props_l2_mask = []
    
    hist_neg_l2 = []
    hist_neg_l2_mask = []
    
    hist_pos_max_l2 = []
    hist_pos_max_l2_mask = []
    
    hist_pos_sum_l2 = []
    hist_pos_sum_l2_mask = []

    # Цикл по эпохам
    for epoch_i in range(cnt_epochs):
        np.random.shuffle(indices) # Шафлим индексы, так как шафлить массив намного дольше
        
        for i in range(int(rows / batch_size) + 1): # Цикл по батчам
            batch_ids = indices[i*batch_size:(i+1)*batch_size]
            if len(batch_ids) == 0:
                continue # Прошлись по всей выборке
            
            _ = [n.zero_grad() for n in nets]
            t0 = time.time()
            batch = data[batch_ids,:] # Batch x SeqLen
            batch_sz = batch.shape[0]
            batch_len = batch.shape[1] - (batch == pad).sum(axis=1) # Опрделеям длины последовательностей в батче, (Batch_Sz,)
            seq_len = max(batch_len)
            if seq_len < batch.shape[1]:
                batch = batch[:, :seq_len] # Если максимальная длина меньше дефолтной, ограничиваем массив
                print(f'step 1_1: {batch.shape}')
            
            
            batch_default_positions = default_positions[:seq_len,:seq_len] # max_len x max_len
            center_words = list(map(np.random.randint, batch_len)) # Выбираем токен, который опустим в последовательности, (Batch_Sz,)
            
            batch = reorder_tokens(batch, center_words) # Меняем порядок строк/столбцов в матрице токенов, Batch_Sz x Len_Seq
            
            # Меняем порядок строк/столбцов в позиционной матрице, Batch_Sz x max_len x max_len
            pos_matrix = np.stack([reorder_posistions(default_positions, x) for x in center_words])
            
            # Делаем замену начальных токенов на новые токены + фичи + Негативный сэмплинг
            batch = process_batch(batch) # Batch_Sz x max_len x 5 (id, title, upper, w_end, part)
            negative_batch = torch.tensor(einops.rearrange(
                np.stack(v_negative_tokens(batch[:,:,0], cnt_negative), axis=2),
                'b s n -> (b s) n'
            ), dtype=torch.int).requires_grad_(False)
            batch = torch.tensor(batch, dtype=torch.int).requires_grad_(False) # Batch_Sz x max_len x 5 (id, title, upper, w_end, part)
            
            # Получаем эмбединг позиционной матрицы, Batch_Sz x max_len x max_len x pos_emb_size
            pos_matrix = PosEnc.forward2(torch.tensor(pos_matrix, dtype=torch.int).requires_grad_(False), ptime)
            
            # Получаем эмбединги фичей токенов
            # Batch_Sz x max_len x max_len x Cnt_Cats * Cnt_Meanings x (EmbTitle + EmbUpper + EmbPart + ...)
            props_embs = PropsEmb.forward4(batch[:,:,1:], ptime)
            
            # Получаем эмбединги токенов
            tokens_embs, tokens_embs_ext, negative_batch = TokensEmb.forward5(batch[:,:,0], negative_batch, ptime)
            
            # tokens_embs = Batch_Sz x max_len x x max_len x Cnt_Cats * Cnt_Meanings x (CatEmb + EmbMeaning)
            tokens_embs = einops.rearrange(tokens_embs, 'b s cm e -> (b s) cm e')
            
            all_tk_cat_ext = einops.repeat(all_tk_cat, 'b s_h cm e-> b s_w s_h cm e', s_w=columns)
            
        
            
            
            
            # Объединяем все матрицы в одну
            pos_props_tokens = torch.cat([pos_matrix, props_embs, tokens_embs_ext], axis=-1)
            
            # Создаем маску, чтобы не учитывать центранльные токены при агрегировании
            # TODO: Добавить логигу когда у нас будет PAD в последовательности
            mask = torch.tensor([-np.inf,] + [1.0 for i in range(max_len - 1)], dtype=torch.float32).requires_grad_(False)
            mask_ext = einops.repeat(mask, 's1 -> s2 (s1 m)', s2=batch_sz * max_len, 
                                 m=TokensEmb.cnt_meanings * TokensEmb.cnt_categories).requires_grad_(False)
            
            mask = einops.repeat(mask, 's1 -> s2 s1', s2=batch_sz * max_len, 
                                 m=TokensEmb.cnt_meanings * TokensEmb.cnt_categories).requires_grad_(False)
            # Подаем собранные данные в агрегирующую рекурсивную сеть,
            # На выходе массив скрытых состояний.
            hidden_states = AggNet.forward(pos_props_tokens, mask, rounds_agg)
            hidden_states = einops.rearrange(hidden_states, 'r bs e -> (r bs) e')
            
            hidden_states_ext = einops.repeat(
                hidden_states, 'rbs e -> rbs m e',
                m=TokensEmb.cnt_categories * TokensEmb.cnt_meanings * (cnt_negative + 1)
            )
            tokens_props = einops.repeat(batch[:,:,1:], 'b s f -> (m b s) f', m=rounds_agg).requires_grad_(False)
            props_pred = PropNet.forward(hidden_states)
            
            loss_mask = torch.tensor([1.0, ] + [mask_coef, ] * (max_len - 1), dtype=torch.float32).requires_grad_(False)
            loss_mask = einops.repeat(loss_mask, 's -> (rp b s) f', rp=rounds_agg, b=batch_sz, f=1).requires_grad_(False)
            
            
            loss_mask2 = torch.tensor(
                [(i/rounds_agg)**hidden_coef for i in range(1, rounds_agg + 1)],
                dtype=torch.float32).requires_grad_(False)
            loss_mask2 = einops.repeat(loss_mask2, 'h -> (h bs) f', bs=batch_sz*max_len , f=1).requires_grad_(False)
            
            props_l2 = l2_none_loss(props_pred, tokens_props.type(torch.float32)) # Лосс на свойства токенов
            props_l2_mask = props_l2 * loss_mask * loss_mask2
            
            hist_props_l2.append(props_l2.mean().item())
            hist_props_l2_mask.append(props_l2_mask.mean().item())
            
            if len(hist_props_l2) > avg_info:
                hist_props_l2.pop(0)
                hist_props_l2_mask.pop(0)
                
                
            # Предсказание слов
            true_false_batch = torch.cat([tokens_embs, negative_batch], axis=1)
            true_false_batch = einops.repeat(true_false_batch, 'bs m e -> (r bs) m e', r=rounds_agg)
            
            meanings_pred = CMNet.forward(torch.cat([hidden_states_ext,true_false_batch], axis=-1)).squeeze()
            
            # Негативный лосс
            neg_fact = torch.zeros(
                (rounds_agg * batch_sz * max_len, cnt_negative * TokensEmb.cnt_categories * TokensEmb.cnt_meanings),
                dtype=torch.float32).requires_grad_(False)
            neg_l2 = l2_none_loss(meanings_pred[:, (TokensEmb.cnt_categories * TokensEmb.cnt_meanings):], neg_fact) # Лосс на негативные значения
            neg_l2_mask = neg_l2 * loss_mask * loss_mask2
            
            hist_neg_l2.append(neg_l2.mean().item())
            hist_neg_l2_mask.append(neg_l2_mask.mean().item())
            
            if len(hist_neg_l2) > avg_info:
                hist_neg_l2.pop(0)
                hist_neg_l2_mask.pop(0)
                
            
            # Позитивный лосс
            max_pred = meanings_pred[:, :(TokensEmb.cnt_categories * TokensEmb.cnt_meanings)].max(axis=1)[0]
            pos_fact_1 = torch.ones((rounds_agg * batch_sz * max_len,), dtype=torch.float32).requires_grad_(False)
            pos_max_l2 = l2_none_loss(max_pred, pos_fact_1)
            pos_max_l2_mask = pos_max_l2 * loss_mask.squeeze() * loss_mask2.squeeze() 
            
            sum_pred = meanings_pred[:, :(TokensEmb.cnt_categories * TokensEmb.cnt_meanings)].sum(axis=1)
            pos_sum_l2 = l2_none_loss(sum_pred, pos_fact_1)
            pos_sum_l2_mask = pos_sum_l2 * loss_mask.squeeze() * loss_mask2.squeeze() 
            
            
            hist_pos_max_l2.append(pos_max_l2.mean().item())
            hist_pos_max_l2_mask.append(pos_max_l2_mask.mean().item())
            
            hist_pos_sum_l2.append(pos_sum_l2.mean().item())
            hist_pos_sum_l2_mask.append(pos_sum_l2_mask.mean().item())
            
            if len(hist_pos_max_l2) > avg_info:
                hist_pos_max_l2.pop(0)
                hist_pos_max_l2_mask.pop(0)
                
                hist_pos_sum_l2.pop(0)
                hist_pos_sum_l2_mask.pop(0)
            
            
            
            final_loss = props_l2_mask.mean() + neg_l2_mask.mean() + pos_max_l2_mask.mean() + \
                pos_sum_coef * pos_sum_l2_mask.mean()
            
            # final_loss = props_l2_mask.mean()
            print(f'final_loss: {final_loss.item()}')
            final_loss.backward()
            
            _ = [n.step() for n in nets]
            
            print(f'Batch Time: {time.time() - t0}')
            return None
            # return (hist_props_l2[0], hist_props_l2_mask[0], hist_neg_l2[0], hist_neg_l2_mask[0], 
                    # hist_pos_max_l2[0], hist_pos_max_l2_mask[0], hist_pos_sum_l2[0], hist_pos_sum_l2_mask[0])
            

In [None]:
def train(nets, data, 
          batch_size, cnt_negative, rounds_agg, 
          mask_coef, hidden_coef, pos_sum_coef, pad,
          avg_info=100, opt_type='SGD', lr=1e-3, seed=0, max_seq_len=1024, cnt_epochs=10**6, ptime=False
         ):
    _ = [x.init_optims(opt_type, lr) for x in nets]
    PosEnc, PropsEmb, TokensEmb, AggNet, PropNet, CMNet = nets
    l2_none_loss = torch.nn.MSELoss(reduction='none')
    
    rows, _ = data.shape
    indices = np.arange(rows)
    np.random.seed(seed)
    
    # Определяем дефолтную матрицу позиций
    default_positions = np.array([
        [max_seq_len + j - i - 1 for j in range(max_seq_len) ] 
        for i in range(max_seq_len)
    ]) # Dist_ij = Dist(Token_i, Token_j)
    
    hist_props_l2 = []
    hist_props_l2_mask = []
    
    hist_neg_l2 = []
    hist_neg_l2_mask = []
    
    hist_pos_max_l2 = []
    hist_pos_max_l2_mask = []
    
    hist_pos_sum_l2 = []
    hist_pos_sum_l2_mask = []

    # Цикл по эпохам
    for epoch_i in range(cnt_epochs):
        np.random.shuffle(indices) # Шафлим индексы, так как шафлить массив намного дольше
        
        for i in range(int(rows / batch_size) + 1): # Цикл по батчам
            batch_ids = indices[i*batch_size:(i+1)*batch_size]
            if len(batch_ids) == 0:
                continue # Прошлись по всей выборке
            
            _ = [n.zero_grad() for n in nets]
            t0 = time.time()
            batch = data[batch_ids,:] # Batch x SeqLen
            batch_sz = batch.shape[0]
            batch_len = batch.shape[1] - (batch == pad).sum(axis=1) # Опрделеям длины последовательностей в батче, (Batch,)
            seq_len = max(batch_len)
            if seq_len < batch.shape[1]:
                batch = batch[:, :seq_len] # Если максимальная длина меньше дефолтной, ограничиваем массив
                print(f'step 1_1: {batch.shape}')
            
            
            pos_matrix = default_positions[:seq_len,:seq_len] # seq_len x seq_len
            
            # Делаем замену начальных токенов на новые токены + фичи + Негативный сэмплинг
            batch = process_batch(batch) # batch x seq_len x 5 (id, title, upper, w_end, part)
            batch = torch.tensor(batch, dtype=torch.int).requires_grad_(False) # batch x seq_len x 5 (id, title, upper, w_end, part)
            
            negative_batch = torch.tensor(
                np.stack(v_negative_tokens(batch[:,:,0], cnt_negative), axis=2), 
                dtype=torch.int
            ).requires_grad_(False) # batch x seq_len x cnt_negative
            
            
            # Получаем эмбединг позиционной матрицы, seq_len x seq_len x pos_emb
            pos_matrix = PosEnc.forward2(torch.tensor(pos_matrix, dtype=torch.int).requires_grad_(False), ptime)
            pos_matrix = einops.repeat('s_w s_h e -> b s_w s_h (c m) e', b=batch_sz, 
                                       c=TokensEmb.cnt_categories, m=TokensEmb.cnt_meanings)
            # pos_matrix = batch x seq_len x seq_len x (cnt_cats * cnt_meanings) x pos_emb
            
            
            # Получаем эмбединги фичей токенов
            props_embs = PropsEmb.forward4(batch[:,:,1:], ptime)
            # batch x seq_len x props_embs
            
            props_embs = einops.repeat(props_embs, 'b s_h e -> b s_w s_h cm e', s_w=seq_len, cm=self.cnt_repeats) 
            # props_embs = batch x seq_len (s_w) x seq_len (s_h) x (cnt_cats * cnt_meanings) x props_embs
            
            # Получаем эмбединги токенов
            tokens_embs, negative_batch = TokensEmb.forward5(batch[:,:,0], negative_batch, ptime)
            # tokens_embs = batch x seq_len x (cnt_cats * cnt_meanings) x (cat_emb + emb_meaning)
            # negative_batch = batch x seq_len x (cnt_negative * cnt_cats * cnt_meanings) x (cat_emb + meaning_emb)

            # tokens_embs_ext = batch x seq_len x seq_len x (cnt_cats * cnt_meanings) x (cat_emb + emb_meaning)
            tokens_embs_ext = einops.repeat(tokens_embs, 'b s_h cm e-> b s_w s_h cm e', s_w=seq_len)
            
            
            tokens_embs = einops.rearrange(tokens_embs, 'b s cm e -> (b s) cm e')
            # tokens_embs = (batch * seq_len) x (cnt_cats * cnt_meanings) x (cat_emb + emb_meaning)
            
            
            # Объединяем все матрицы в одну
            pos_props_tokens = torch.cat([pos_matrix, props_embs, tokens_embs_ext], axis=-1)
            # pos_props_tokens = batch x seq_len x seq_len x (cnt_cats * cnt_meanings) x (pos_emb + cat_emb + emb_meaning)
            
            # TODO: Добавить логику маскирования pad токенов.
            # Подаем собранные данные в агрегирующую рекурсивную сеть,
            # На выходе массив скрытых состояний.
            hidden_states = AggNet.forward(pos_props_tokens, None, rounds_agg)
            hidden_states = einops.rearrange(hidden_states, 'r bs e -> (r bs) e')
            
            hidden_states_ext = einops.repeat(
                hidden_states, 'rbs e -> rbs m e',
                m=TokensEmb.cnt_categories * TokensEmb.cnt_meanings * (cnt_negative + 1)
            )
            tokens_props = einops.repeat(batch[:,:,1:], 'b s f -> (m b s) f', m=rounds_agg).requires_grad_(False)
            props_pred = PropNet.forward(hidden_states)
            
            loss_mask = torch.tensor([1.0, ] + [mask_coef, ] * (max_len - 1), dtype=torch.float32).requires_grad_(False)
            loss_mask = einops.repeat(loss_mask, 's -> (rp b s) f', rp=rounds_agg, b=batch_sz, f=1).requires_grad_(False)
            
            
            loss_mask2 = torch.tensor(
                [(i/rounds_agg)**hidden_coef for i in range(1, rounds_agg + 1)],
                dtype=torch.float32).requires_grad_(False)
            loss_mask2 = einops.repeat(loss_mask2, 'h -> (h bs) f', bs=batch_sz*max_len , f=1).requires_grad_(False)
            
            props_l2 = l2_none_loss(props_pred, tokens_props.type(torch.float32)) # Лосс на свойства токенов
            props_l2_mask = props_l2 * loss_mask * loss_mask2
            
            hist_props_l2.append(props_l2.mean().item())
            hist_props_l2_mask.append(props_l2_mask.mean().item())
            
            if len(hist_props_l2) > avg_info:
                hist_props_l2.pop(0)
                hist_props_l2_mask.pop(0)
                
                
            # Предсказание слов
            true_false_batch = torch.cat([tokens_embs, negative_batch], axis=1)
            true_false_batch = einops.repeat(true_false_batch, 'bs m e -> (r bs) m e', r=rounds_agg)
            
            meanings_pred = CMNet.forward(torch.cat([hidden_states_ext,true_false_batch], axis=-1)).squeeze()
            
            # Негативный лосс
            neg_fact = torch.zeros(
                (rounds_agg * batch_sz * max_len, cnt_negative * TokensEmb.cnt_categories * TokensEmb.cnt_meanings),
                dtype=torch.float32).requires_grad_(False)
            neg_l2 = l2_none_loss(meanings_pred[:, (TokensEmb.cnt_categories * TokensEmb.cnt_meanings):], neg_fact) # Лосс на негативные значения
            neg_l2_mask = neg_l2 * loss_mask * loss_mask2
            
            hist_neg_l2.append(neg_l2.mean().item())
            hist_neg_l2_mask.append(neg_l2_mask.mean().item())
            
            if len(hist_neg_l2) > avg_info:
                hist_neg_l2.pop(0)
                hist_neg_l2_mask.pop(0)
                
            
            # Позитивный лосс
            max_pred = meanings_pred[:, :(TokensEmb.cnt_categories * TokensEmb.cnt_meanings)].max(axis=1)[0]
            pos_fact_1 = torch.ones((rounds_agg * batch_sz * max_len,), dtype=torch.float32).requires_grad_(False)
            pos_max_l2 = l2_none_loss(max_pred, pos_fact_1)
            pos_max_l2_mask = pos_max_l2 * loss_mask.squeeze() * loss_mask2.squeeze() 
            
            sum_pred = meanings_pred[:, :(TokensEmb.cnt_categories * TokensEmb.cnt_meanings)].sum(axis=1)
            pos_sum_l2 = l2_none_loss(sum_pred, pos_fact_1)
            pos_sum_l2_mask = pos_sum_l2 * loss_mask.squeeze() * loss_mask2.squeeze() 
            
            
            hist_pos_max_l2.append(pos_max_l2.mean().item())
            hist_pos_max_l2_mask.append(pos_max_l2_mask.mean().item())
            
            hist_pos_sum_l2.append(pos_sum_l2.mean().item())
            hist_pos_sum_l2_mask.append(pos_sum_l2_mask.mean().item())
            
            if len(hist_pos_max_l2) > avg_info:
                hist_pos_max_l2.pop(0)
                hist_pos_max_l2_mask.pop(0)
                
                hist_pos_sum_l2.pop(0)
                hist_pos_sum_l2_mask.pop(0)
            
            
            
            final_loss = props_l2_mask.mean() + neg_l2_mask.mean() + pos_max_l2_mask.mean() + \
                pos_sum_coef * pos_sum_l2_mask.mean()
            
            # final_loss = props_l2_mask.mean()
            print(f'final_loss: {final_loss.item()}')
            final_loss.backward()
            
            _ = [n.step() for n in nets]
            
            print(f'Batch Time: {time.time() - t0}')
            return None
            # return (hist_props_l2[0], hist_props_l2_mask[0], hist_neg_l2[0], hist_neg_l2_mask[0], 
                    # hist_pos_max_l2[0], hist_pos_max_l2_mask[0], hist_pos_sum_l2[0], hist_pos_sum_l2_mask[0])
            

In [348]:
PosEnc = PosEncoding(
    rows=(MAX_SEQ_LEN - 1) * 2 + 1, 
    emb_size=POS_EMB_SIZE,
    cnt_repeats=CNT_MEANINGS * len(CAT_SIZES)
)
PropsEmb = PropsEmbeding(
    title=TITLE_EMB_SIZE, 
    upper=UPPER_EMB_SIZE, 
    part=PART_EMB_SIZE, 
    w_end=END_EMB_SIZE,
    cnt_repeats=CNT_MEANINGS * len(CAT_SIZES),
)

TokensEmb = TokensEmbeding(
    cnt_tokens=len(ALL_VOCAB['Second']), 
    categories_sz=CAT_SIZES, 
    cnt_meanings=CNT_MEANINGS,
    meaning_emb_sz=MEANING_EMB_SIZE,
    cat_emb_sz=CAT_EMB_SIZE
)

AggNet = AggregationNet(
    hidden_size=HIDDEN_SIZE, 
    full_embeding_size=INPUT_SIZE,
    short_embeding_size=PREDICT_SIZE
)

PropNet = PropertyNet(
    input_size=HIDDEN_SIZE
)

CMNet = CatMeangNet(
    input_size=HIDDEN_SIZE+CAT_EMB_SIZE+MEANING_EMB_SIZE
)


MAX_POSITION_MATRIX = np.array([
        [MAX_SEQ_LEN + j - i - 1 for j in range(MAX_SEQ_LEN)] 
        for i in range(MAX_SEQ_LEN)
    ]) # Dist_ij = Dist(Token_i, Token_j)

In [349]:
MAX_POSITION_MATRIX.max()

2046

In [344]:
returned_vals = train(
    nets=[PosEnc, PropsEmb, TokensEmb, AggNet, PropNet, CMNet],
    data=train_tokens, 
    batch_size=BATCH_SIZE, 
    max_seq_len=MAX_SEQ_LEN,
    cnt_negative=CNT_NEGATIVE,
    rounds_agg=AGG_ROUNDS,
    opt_type='SGD', 
    lr=1e-3, 
    mask_coef=0.8,
    hidden_coef=1.15,
    pos_sum_coef=0.05,
    pad=wrapped_tokenizer.pad_token_id,
    ptime=False
)
# max_pred, pos_fact_1, pos_max_l2 = returned_vals
# hist_props_l2, hist_props_l2_mask, hist_neg_l2, hist_neg_l2_mask, hist_pos_max_l2, hist_pos_max_l2_mask, hist_pos_sum_l2, hist_pos_sum_l2_mask = returned_vals
# props_embs = [Batch, SeqLen, SeqLen, CntCats * CntMeanings, PropsEmbs]
# props_embs[Any, i, j, Any, :] == props_embs[Any, t, j, Any, :] 
# props_embs[Any, i, j, Any, :] != props_embs[Any, i, r, Any, :] 

torch.Size([480, 30, 150]) torch.Size([480, 600])


RuntimeError: The size of tensor a (30) must match the size of tensor b (600) at non-singleton dimension 1

In [346]:
np.array([
        [1024 + j - i - 1 for j in range(1024) ] 
        for i in range(1024)
    ]).shape

(1024, 1024)

In [245]:
neg_fact.dtype

torch.float32

In [248]:
neg_fact.type(torch.int)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.int32)

In [243]:
returned_vals

array([[[29, 25, 26, ..., 52, 53, 54],
        [33, 29, 30, ..., 56, 57, 58],
        [32, 28, 29, ..., 55, 56, 57],
        ...,
        [ 6,  2,  3, ..., 29, 30, 31],
        [ 5,  1,  2, ..., 28, 29, 30],
        [ 4,  0,  1, ..., 27, 28, 29]],

       [[29, 21, 22, ..., 48, 49, 50],
        [37, 29, 30, ..., 56, 57, 58],
        [36, 28, 29, ..., 55, 56, 57],
        ...,
        [10,  2,  3, ..., 29, 30, 31],
        [ 9,  1,  2, ..., 28, 29, 30],
        [ 8,  0,  1, ..., 27, 28, 29]],

       [[29,  1,  2, ..., 27, 28, 30],
        [57, 29, 30, ..., 55, 56, 58],
        [56, 28, 29, ..., 54, 55, 57],
        ...,
        [31,  3,  4, ..., 29, 30, 32],
        [30,  2,  3, ..., 28, 29, 31],
        [28,  0,  1, ..., 26, 27, 29]],

       ...,

       [[29, 19, 20, ..., 46, 47, 48],
        [39, 29, 30, ..., 56, 57, 58],
        [38, 28, 29, ..., 55, 56, 57],
        ...,
        [12,  2,  3, ..., 29, 30, 31],
        [11,  1,  2, ..., 28, 29, 30],
        [10,  0,  1, ..., 27, 28

In [140]:
hist_props_l2, hist_props_l2_mask

(0.25330349802970886, 0.01613856665790081)

In [141]:
hist_neg_l2, hist_neg_l2_mask

(0.21183332800865173, 0.01289786770939827)

In [142]:
hist_pos_max_l2, hist_pos_max_l2_mask

(0.26231759786605835, 0.019070377573370934)

In [143]:
hist_pos_sum_l2, hist_pos_sum_l2_mask

(66.89488220214844, 4.0206427574157715)

In [44]:
neg_fact = torch.zeros((10 * 480, 100))

In [45]:
l2_none_loss = torch.nn.MSELoss(reduction='none')

In [46]:
neg_pred = l2_none_loss(neg_fact, meanings_pred[:, 20:])

In [47]:
neg_pred.shape

torch.Size([4800, 100])

In [53]:
(neg_pred * loss_mask2).shape

torch.Size([4800, 100])

In [48]:
loss_mask2.shape

torch.Size([4800, 1])

In [51]:
loss_mask2[480,:]

tensor([0.1571])

In [721]:
hidden_states_ext.shape
# hidden_states = einops.repeat(hidden_states, 'r bs e -> (r bs) m e',
# m=TokensEmb.cnt_categories * TokensEmb.cnt_meanings * cnt_negative, )

torch.Size([4800, 120, 180])

In [None]:
# [4800, 120, 180] = [10(agg) x 16 (bs) 30(seq), 120, 180] 

In [722]:
true_false_batch.shape

torch.Size([4800, 120, 80])

In [705]:
true_false_batch = torch.cat([tokens_embs, negative_batch], axis=1)
true_false_batch.shape

torch.Size([480, 120, 80])

In [692]:
torch.cat([cat_emb_ext2, negative_batch2], axis=-1)

(torch.Size([480, 100, 30]), torch.Size([480, 100, 50]))

In [688]:
cat_emb_ext2[0,0,:]

tensor([-0.1312, -0.0948,  0.2202,  0.1414,  0.2555,  0.4131,  0.3525, -0.3640,
        -0.0362,  0.1579, -0.0383, -0.3065,  0.1020, -0.2766, -0.3648,  0.1268,
         0.3612,  0.1698,  0.0396, -0.1943, -0.0156,  0.1032, -0.1395,  0.2881,
         0.1752, -0.4159, -0.2856, -0.0416, -0.0387,  0.4081],
       grad_fn=<SliceBackward0>)

In [691]:
cat_emb_ext2[0,25,:]

tensor([-0.0022,  0.0484,  0.1391,  0.3842, -0.0386,  0.2287,  0.2664, -0.2704,
        -0.2611,  0.2052,  0.1788,  0.4159,  0.0611, -0.3091, -0.4118,  0.2832,
        -0.1521, -0.4052, -0.1839,  0.3027, -0.3857, -0.3729,  0.0468, -0.3128,
         0.3849,  0.0253,  0.0616, -0.1326, -0.2529, -0.1535],
       grad_fn=<SliceBackward0>)

In [665]:
negative_batch[0].shape

torch.Size([480, 25, 50])

In [550]:
print(tokens_props.shape)
tokens_props = einops.repeat(tokens_props, 'b s f -> (m b s) f', m=10)
tokens_props.shape

torch.Size([16, 30, 4])


torch.Size([4800, 4])

In [552]:
hidden_states[0].shape

torch.Size([480, 180])

In [553]:
hidden_states2 = einops.rearrange(hidden_states, 'r bs e -> (r bs) e')
hidden_states2.shape

torch.Size([4800, 180])

In [554]:
tmp = PropNet.forward(hidden_states2)

In [556]:
tokens_props.shape

torch.Size([4800, 4])

In [557]:
tmp.shape

torch.Size([4800, 4])

In [560]:
l2_mean_loss = torch.nn.MSELoss(reduction='none')

In [561]:
l2_mean_loss(tmp, tokens_props).shape

torch.Size([4800, 4])

In [566]:
loss_mask = torch.tensor([1.0, ] + [0.05] * 29)

In [641]:
loss_mask2 = torch.tensor([(i/10)**1.15 for i in range(1,11)])
loss_mask2

tensor([0.0708, 0.1571, 0.2504, 0.3486, 0.4506, 0.5557, 0.6635, 0.7737, 0.8859,
        1.0000])

In [638]:
loss_mask2 = torch.tensor([(i/10)**1.1 for i in range(1,11)])
loss_mask2 =einops.repeat(loss_mask2, 'h -> (h bs) f', bs=480, f=1)
loss_mask2, loss_mask2.shape

(tensor([[0.0794],
         [0.0794],
         [0.0794],
         ...,
         [1.0000],
         [1.0000],
         [1.0000]]),
 torch.Size([4800, 1]))

In [599]:
loss_mask2 =einops.repeat(loss_mask2, 'h -> (h bs) f', bs=480, f=1)
loss_mask2.shape

torch.Size([4800, 1])

In [604]:
loss_mask2[479,:]

tensor([0.0631])

In [572]:
loss_mask2 = einops.repeat(loss_mask, 's -> (rp b s) f', rp=10, b=16, f=1)
loss_mask2.shape

torch.Size([4800, 1])

In [573]:
loss_mask2[0,:]

tensor([1.])

In [577]:
loss_mask2[30,:]

tensor([1.])

In [562]:
l2_mean_loss(tmp, tokens_props).mean()

tensor(0.1631, grad_fn=<MeanBackward0>)

In [536]:
pd.DataFrame(tokens_props).groupby([0,1,2,3]).size().to_frame('cnt').reset_index()

Unnamed: 0,0,1,2,3,cnt
0,0,0,0,0,6
1,0,0,0,1,394
2,0,0,1,0,5
3,0,0,1,1,19
4,1,0,0,0,12
5,1,0,0,1,40
6,1,1,0,0,1
7,1,1,0,1,3


In [283]:
%%time
ftmp = torch.cat([pos_matrix, props_embs, tokens_embs], axis=4)
ftmp = einops.rearrange(ftmp, 'b s1 s2 m e -> (b s1) (s2 m) e')

CPU times: user 230 ms, sys: 177 ms, total: 408 ms
Wall time: 84.6 ms


In [282]:
%%time
tmp1 = einops.rearrange(pos_matrix, 'b s1 s2 m e -> (b s1) (s2 m) e')
tmp2 = einops.rearrange(props_embs, 'b s1 s2 m e -> (b s1) (s2 m) e')
tmp3 = einops.rearrange(tokens_embs, 'b s1 s2 m e -> (b s1) (s2 m) e')
ftmp = torch.cat([tmp1, tmp2, tmp3], axis=-1)

CPU times: user 433 ms, sys: 454 ms, total: 887 ms
Wall time: 194 ms


In [281]:
%%time
ftmp = torch.cat([tmp1, tmp2, tmp3], axis=-1)

CPU times: user 287 ms, sys: 203 ms, total: 491 ms
Wall time: 127 ms


In [302]:
ftmp.shape

torch.Size([480, 600, 140])

In [303]:
140 * 1.5

210.0

In [304]:
pos_matrix.shape

torch.Size([16, 30, 30, 20, 20])

In [306]:
props_embs.shape

torch.Size([16, 30, 30, 20, 40])

In [307]:
tokens_embs.shape

torch.Size([16, 30, 30, 20, 80])

In [308]:
POS_EMB_SIZE + TITLE_EMB_SIZE + UPPER_EMB_SIZE + PART_EMB_SIZE + END_EMB_SIZE + CAT_EMB_SIZE + MEANING_EMB_SIZE

140

In [309]:
TITLE_EMB_SIZE + UPPER_EMB_SIZE + PART_EMB_SIZE + END_EMB_SIZE + CAT_EMB_SIZE + MEANING_EMB_SIZE

120

In [261]:
_ = einops.rearrange([a, b, c], 'l a b -> a (l b)')

RuntimeError: stack expects each tensor to be equal size, but got [10, 10] at entry 0 and [10, 20] at entry 1

In [252]:
%%time
_ = einops.rearrange([pos_matrix, props_embs, tokens_embs],
                     'l b s1 s2 c e -> b s1 s2 c e'
                    )

RuntimeError: stack expects each tensor to be equal size, but got [128, 30, 30, 20, 20] at entry 0 and [128, 30, 30, 20, 40] at entry 1

In [240]:
pos_matrix[0,0,0,0,:]

tensor([-0.0785, -0.2282, -0.2500,  0.0691, -0.0209, -0.1392,  0.0557,  0.1047,
         0.2192,  0.2140, -0.0413, -0.2430, -0.2490,  0.2573,  0.1218,  0.1201,
        -0.2384,  0.2552,  0.2611,  0.2488], dtype=torch.float64,
       grad_fn=<SliceBackward0>)

In [242]:
pos_matrix[0,0,0,15,:]

tensor([-0.0785, -0.2282, -0.2500,  0.0691, -0.0209, -0.1392,  0.0557,  0.1047,
         0.2192,  0.2140, -0.0413, -0.2430, -0.2490,  0.2573,  0.1218,  0.1201,
        -0.2384,  0.2552,  0.2611,  0.2488], dtype=torch.float64,
       grad_fn=<SliceBackward0>)

In [243]:
pos_matrix[0,1,0,5,]

tensor([-0.0691, -0.0598,  0.2026, -0.2136,  0.0293,  0.2592, -0.0379,  0.2140,
        -0.0849,  0.2218, -0.2666, -0.0397, -0.0484,  0.0893,  0.1081,  0.2116,
        -0.0411, -0.0109,  0.1887, -0.0746], dtype=torch.float64,
       grad_fn=<SelectBackward0>)

In [206]:
props_embs[0,0,0,0,:] == props_embs[0,0,2,0,:]

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])

In [None]:
props_embs[:,0,0,0,:]

In [25]:
%%time
old_var = TokensEmb.forward(batch)

TokensEmbeding, forward time: 0.4157431125640869
CPU times: user 580 ms, sys: 868 ms, total: 1.45 s
Wall time: 417 ms


In [26]:
%%time
old_var2 = TokensEmb.forward2(extend_batch)

TokensEmbeding, forward2 time: 3.1199347972869873
CPU times: user 2.94 s, sys: 4.99 s, total: 7.93 s
Wall time: 3.16 s


In [27]:
%%time
old_var3 = TokensEmb.forward3(batch)

TokensEmbeding, forward3 time: 0.0433809757232666
CPU times: user 79.1 ms, sys: 115 ms, total: 194 ms
Wall time: 44.4 ms


In [28]:
%%time
old_var4 = TokensEmb.forward4(batch)

TokensEmbeding, forward3 time: 0.032254934310913086
CPU times: user 73.2 ms, sys: 55.6 ms, total: 129 ms
Wall time: 33.1 ms


In [29]:
%%time
old_var5 = TokensEmb.forward5(batch)

cat_emb: torch.Size([128, 30, 20, 30])
cnt_categories: 4
cnt_meanings: 5
cat_emb: torch.Size([128, 30, 20, 50])
TokensEmbeding, forward5 time: 0.03566884994506836
CPU times: user 67.5 ms, sys: 90.3 ms, total: 158 ms
Wall time: 37.1 ms


In [30]:
(old_var == old_var4).sum()

tensor(184320000)

In [797]:
old_var3.shape

torch.Size([128, 30, 30, 20, 80])

In [798]:
128 * 30 * 30 * 20 * 80

184320000

In [704]:
%%time
old_var = PropsEmb.forward(extend_batch[:,:,:,1:])

PropsEmbeding, forward time: 0.24170589447021484
CPU times: user 350 ms, sys: 489 ms, total: 839 ms
Wall time: 262 ms


In [705]:
%%time
old_var2 = PropsEmb.forward2(extend_batch[:,:,:,1:])

PropsEmbeding, forward2 time: 0.6497049331665039
CPU times: user 1.25 s, sys: 1.01 s, total: 2.26 s
Wall time: 650 ms


In [707]:
%%time
old_var3 = PropsEmb.forward3(batch[:,:,1:])

PropsEmbeding, forward3 time: 0.20324993133544922
CPU times: user 224 ms, sys: 410 ms, total: 634 ms
Wall time: 203 ms


In [714]:
%%time
new_var = PropsEmb.forward4(batch[:,:,1:])

PropsEmbeding, forward4 time: 0.0036687850952148438
CPU times: user 2.26 ms, sys: 2.88 ms, total: 5.14 ms
Wall time: 3.81 ms


In [696]:
%%time
old_var = PosEnc.forward(torch.tensor(pos_matrix))

PosEncoding, forward time: 0.06662583351135254
CPU times: user 128 ms, sys: 222 ms, total: 350 ms
Wall time: 69.1 ms


In [697]:
%%time
new_var = PosEnc.forward2(torch.tensor(pos_matrix))

PosEncoding, forward2 time: 0.0031812191009521484
CPU times: user 8.88 ms, sys: 3.81 ms, total: 12.7 ms
Wall time: 3.8 ms


In [599]:
extend_batch.shape, batch.shape

(torch.Size([128, 30, 30]), torch.Size([128, 30]))

In [574]:
a= PropsEmb.forward(extend_batch)

time: 0.13287878036499023


In [575]:
b= PropsEmb.forward2(extend_batch)

time: 0.4492971897125244


In [586]:
c= PropsEmb.forward3(batch[:,:,1:])

time: 0.09531116485595703


In [636]:
a = TokensEmb.forward(batch)

time1: 0.0303800106048584
time2: 0.01830887794494629
time3: 0.19798922538757324
time4: 0.0001437664031982422
time: 0.24683094024658203


In [642]:
b = TokensEmb.forward2(extend_batch)

time1: 1.5422999858856201
time: 2.8998541831970215


In [644]:
torch.randint((128, 30, 5*4, 30+50))

TypeError: randint() received an invalid combination of arguments - got (tuple), but expected one of:
 * (int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)
 * (int low, int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)


In [None]:
Batch_Size x SeqLen x Cnt_Meanings * CntCats x (CatEmb + MeaningEmb)

In [643]:
len(CAT_SIZES)

4

In [665]:
a=torch.rand((1, 128, 30, 5*4, 30+50)).requires_grad_(True)

In [666]:
%%time
c = a.repeat(30, 1,1,1,1)

CPU times: user 270 ms, sys: 404 ms, total: 674 ms
Wall time: 149 ms


In [667]:
%%time
d = einops.repeat(a, 'b h w s k -> (repeat b) h w s k', repeat=30)

CPU times: user 980 µs, sys: 4.01 ms, total: 4.99 ms
Wall time: 4.37 ms


In [668]:
c.grad_fn

<RepeatBackward0 at 0x7f9909b98340>

In [671]:
d.grad_fn

<ReshapeAliasBackward0 at 0x7f9909bd2520>

In [654]:
import einops

In [653]:
! pip install einops

Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
You should consider upgrading via the '/Users/u14510182/Documents/Wiki_103/venv/bin/python -m pip install --upgrade pip' command.[0m


In [481]:
props_tokens_ems.repeat(30,1,1,1,1).permute(1,0,2,3,4)

torch.Size([128, 30, 30, 20, 120])

In [428]:
props_embs.repeat(20,1,1,1).transpose(0,3)

torch.Size([5, 128, 30, 40])

In [None]:
tmp.repeat(30,1,1,1,1).permute(1,2,0,3)

In [424]:
props_embs.unsqueeze(-1).shape

torch.Size([128, 30, 40, 1])

In [435]:
tmp = props_embs.repeat(20,1,1,1).permute(1,2,0,3)

In [436]:
tmp.shape

torch.Size([128, 30, 20, 40])

In [442]:
tmp[1, 20, 2, :]

tensor([ 0.0372, -0.3625,  0.1196, -0.6602, -0.5109, -0.3645,  0.4461,  0.4146,
        -0.3136, -0.0255,  0.3199,  0.2844, -0.4189,  0.2136,  0.3882, -0.0892,
         0.0270,  0.1638,  0.4387,  0.6790,  0.2568,  0.5872, -0.1455,  0.5291,
        -0.1140,  0.0748,  0.6403, -0.6560, -0.4452, -0.1790, -0.2137, -0.1390,
        -0.6755, -0.4683, -0.2915,  0.0262,  0.2795,  0.4243, -0.4794, -0.3079],
       dtype=torch.float64, grad_fn=<SliceBackward0>)

In [454]:
props_embs[1, 8, :] == props_embs[1, 9, :]

tensor([False, False, False, False, False, False, False, False, False, False,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True])

In [450]:
for i in range(29):
    print(i, sum(props_embs[1, i, :] == props_embs[1, i+1, :]))

0 tensor(30)
1 tensor(40)
2 tensor(40)
3 tensor(40)
4 tensor(10)
5 tensor(0)
6 tensor(30)
7 tensor(30)
8 tensor(30)
9 tensor(40)
10 tensor(40)
11 tensor(40)
12 tensor(40)
13 tensor(20)
14 tensor(10)
15 tensor(30)
16 tensor(40)
17 tensor(40)
18 tensor(40)
19 tensor(40)
20 tensor(40)
21 tensor(40)
22 tensor(40)
23 tensor(40)
24 tensor(40)
25 tensor(30)
26 tensor(30)
27 tensor(30)
28 tensor(30)


In [443]:
props_embs[1, 2, :]

tensor([ 0.0372, -0.3625,  0.1196, -0.6602, -0.5109, -0.3645,  0.4461,  0.4146,
        -0.3136, -0.0255,  0.3199,  0.2844, -0.4189,  0.2136,  0.3882, -0.0892,
         0.0270,  0.1638,  0.4387,  0.6790,  0.2568,  0.5872, -0.1455,  0.5291,
        -0.1140,  0.0748,  0.6403, -0.6560, -0.4452, -0.1790, -0.2137, -0.1390,
        -0.6755, -0.4683, -0.2915,  0.0262,  0.2795,  0.4243, -0.4794, -0.3079],
       dtype=torch.float64, grad_fn=<SliceBackward0>)

In [None]:
tmp[0, 0, 0, :]

1) batch[:,:,0] -> Token_Emb


# Arhitecture

#### Примерный план
1) Каким то образом получаем батч из числовых последовательностей токенов = [Batch_Size, MAX_SEQ_LEN]
2) Определяем конец батча
3) Определяем целевой токен для каждой строки батча
    Получаем следующие данные W0 = [Enc_Pos_0, Token_0], W_Other = [Enc_Pos, Token]
4) Рандомом из батча выбираем где мы будем учитывать 


Рассмотрим на примере одного "предложения".
Предложение -> Tokens -> Seq (Seq_length может быть не равна MAX_SEQ_LEN)
Строим матрицу Seq_Tokens = [Seq] * Seq_length.
Так же берем матрицу расстояний = [MAX_SEQ_LEN, MAX_SEQ_LEN, POS_EMB_SIZE]
Из нее извлекаем подматрицу Seq_Pos = [Seq_length, Seq_length, POS_EMB_SIZE]

Получается две матрицы - Seq_Tokens и Seq_Pos
Выбираем целевое слово, точнее его порядковый номер.
Перемещаем

In [None]:
0

Можно так же подавать неправильные токены в последовательность, с целью чтоб они все равно предсказывали верные.


In [None]:
input_seq = [token_0, token_1, token_2, token_3]
all_seq = [
    [token_0, token_1, token_2, token_3],
    [mask, token_1, token_2, token_3],
    [token_0, mask, token_2, token_3],
    [token_0, token_1, mask, token_3],
    [token_0, token_1, token_2, mask],
]

for seq in all_seq:
    seq_emb = []
    mask_pos = get_mask_pos(seq)
    
    
    for token_pos, token in enumerate(seq):
        if token != '[MASK]':
            token_info = []
            pos_emb = get_pos_encoding(token_pos, mask_pos)
            for cat in range(CNT_CATS):
                emb_meanings = get_emb(token, cat) # Cnt_Meanings x Emb_Size
                cat_emb = CATS[cat]
                for m in range(Cnt_Meanings):
                    token_info.append(torch.cat([pos_emb, cat_emb, emb_meaning]))
                    
            token_info = torch.cat(token_info) # Большой тензор, в котором заенкожена вся информация по токену
            seq_emb.append(token_info)
        
    

In [None]:
pos_encod_matrix = gen_pos_encodind(MAX_LEN_SEQ)
input_seq = [token_0, token_1, token_2, token_3]
seq_emb = []

for token_pos, token in enumerate(seq):
    token_info = []
    pos_emb = get_pos_encoding(token_pos, mask_pos)
    for cat in range(CNT_CATS):
        emb_meanings = get_emb(token, cat) # Cnt_Meanings x Emb_Size
        cat_emb = CATS[cat]
        for m in range(Cnt_Meanings):
            token_info.append(torch.cat([pos_emb, cat_emb, emb_meaning]))

    token_info = torch.cat(token_info) # Большой тензор, в котором заенкожена вся информация по токену
    seq_emb.append(token_info)
    

variant = choose_variant() # Выбираеем вариант

# Кличество вариаций обработки одной последовательности:
# 1) Изменяем какое-либо слова:
#      N вариантов выбрать изменяемое слово в последовательности.
#      N вариантов выбрать целевое слово в последовательности.
#      N^2 вариантов, если мы еще добавим опцию опускать(маскировать) ли целевое слово во время агрегации то 2 * N^2
# 2) Не изменяем слова в последовательности:
#      N вариантов выбрать целевое слово.
#      2 - опускаем или нет целевое слово при аггрегации
#