In [38]:
import datasets
import evaluate
import pandas as pd
import numpy as np
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (TrainingArguments, Trainer, T5ForConditionalGeneration)

from ablang import ABtokenizer
from ablang.model import AbLang
import json
import torch
import os
import gzip

from tqdm.auto import tqdm, trange
import random
from torch.nn import CrossEntropyLoss

In [39]:
class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)

In [40]:
def generate_sample():
    with open('./dataset/paired.json', 'r') as file:
        data = json.load(file)

    X, y = zip(*data.items())
    # X = [list(X[i]) for i in range(len(X))]  # heavy
    # y = [list(y[i]) for i in range(len(y))]  # light

    return train_test_split(X, y, test_size=0.8, random_state=1)

In [41]:
X_train, X_test, y_train, y_test = generate_sample()
batch_size = 32  # сколько примеров показывем модели за один шаг
report_steps = 10  # раз в сколько шагов печатаем результат
epochs = 5  # сколько раз мы покажем данные модели

In [42]:
model = torch.load('./start_model/amodel.pt', map_location=torch.device('cpu'))

with open('./start_model/hparams.json', 'r') as file:
    parameters = json.load(file)
hparams = Struct(**parameters)

model = AbLang(hparams)
tokenizer = ABtokenizer('./start_model/vocab.json')

In [43]:
def take_it_more(y):
    array = np.zeros([y.shape[0], 132, 24])
    y = y.numpy()
    for i, y_i in enumerate(y):
        array[i, np.arange(132), y_i] = 1
    return torch.as_tensor(array)

def take_it_less(y):
    array = np.zeros([*y.shape[:2]])
    y = y.numpy()
    array = np.argmax(y, axis=2)
    return array.tolist()

In [44]:
fict_vect = '-' * 130
X_train_2 = [fict_vect, *X_train]
y_train_2 = [fict_vect, *y_train]
x = tokenizer(X_train_2, pad=True).to('cpu')
y = tokenizer(y_train_2, pad=True).to('cpu')
x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=tokenizer.vocab_to_token['-'])
y = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=tokenizer.vocab_to_token['-'])
y = take_it_more(y)

In [45]:
def presentation():
    x_j, y_j = x[1:6], y_train[:5]

    with torch.no_grad():
        h = model(x_j)
    h = take_it_less(h)

    for i in range(len(y_j)):
        print()
        print(tokenizer.decode(h[i])[1:].split('---')[0])
        print(y_j[i])
        print()

In [46]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

loss_fct = CrossEntropyLoss(ignore_index=-100)
model.train()
losses = []
for epoch in range(epochs):
    print('EPOCH', epoch)
    # random.shuffle(pairs)
    for i in trange(0, int(len(x) / batch_size)):
        x_i = x[i * batch_size: (i + 1) * batch_size]
        y_i = y[i * batch_size: (i + 1) * batch_size]
        optimizer.zero_grad()
        hypotheses = model(x_i)
        loss = loss_fct(hypotheses, y_i)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if i % report_steps == 0:
            print('step', i, 'loss', np.mean(losses[-report_steps:]))
            presentation()
    
    torch.save(model, f'./models/model_epoch_{epoch}.pt')

EPOCH 0


  0%|          | 0/493 [00:00<?, ?it/s]

step 0 loss 27.595820090112586

HWA*-A*YVYAN<LAKYYVYYFKQLV-A*KNYL*HVV<WL-AAQN*QHYNLGAYAYVHLE>LPYA<WRAAAQNY<ALRHYTFLA<EYEAYSK-EY-LV
QSMLTQPPSVSGAPGQRVTISCTGSSTNIGADYDVHWYQHVPGTAPKLLIYGNKNRPSGVPDRFSGSKSGTSASLAITGLQAEDEADYYCQSYDDSL


*WHYAHWYP<AYV<-KRYE<*LYNKALYCVT<H-HYNAAWWYAP*VYHYSLEAP>>*QECCVLRAYW>YVWNN<<A<YAYY>HA>AANNYYNAELYKA
SSELTQPPSVSVSPGQTARITCSGDGFPDQYAYWYQQKPGQAPVLVIYKHSERPSGIPERFSGSSSGTTVTLTISGVQAEDEADYYCQSSDHGGTY


HWVEKHWNPTHAVLFKRYER*S>HKVLL>VLAY*VL<<HQWAAV*KHQY<*C<RAHVDY<CH<AAHHEARLNNQAA<YQEYVHL<AANLVYN-ECYKAA
EMVMTQSPATLYVSPGERAILSCRASQAIGSNLAWYQQLPGQAPRLLIYGASNRAADTPSRFSGSGSGTEFTLTISSLQSEDVGIYYCQQYDNWPP


HWVLFARYYYYAHYAYRY<L*YSN*H>ECVAMNCRYDYAYW-AVEVYR<NRGAVHLVDA<CHLAAYYCERINVNVVWYHY*VWT>AAPYAYNH<AYKQA
EIVLTQSPATLSLSPGERATLSCRASQSVSSYLAWYQQKPGQAPRLLIYDASNRATGIPARFSGSGSGTDFTLTISSLEPEDFAVYYCQQRYNWPP


HWA<AAVV<YAEELAYYFEYYSCAVANLWV*<VVVEVAVAKAYQNAY-RFAGAV>YVARA>V<YAAKLE<IAVAVA>AH<YVYAVKWPLAYKAFCANNA
SSELTQDPAVSVALGQTVRITCQGDSLRSFYASWYQQKPRQAPVLVIYSKNNRPSGIPDRFSGSSSGNTASLT

  0%|          | 0/493 [00:00<?, ?it/s]

step 0 loss 19.47814278111849

EIVMTQPPSVLSAPVQERAIICCRASQDVGNYNAWWWQKKWKQPKRLIMIYENREWPVPVDFRFSGSKTEFTLTIIFLQEEEEEEYYCCCCYYW>>
QSMLTQPPSVSGAPGQRVTISCTGSSTNIGADYDVHWYQHVPGTAPKLLIYGNKNRPSGVPDRFSGSKSGTSASLAITGLQAEDEADYYCQSYDDSL


EIVMTQSPSVLGLPVGERATICCRASQNIGNYNAWWWQKWHKAKKPLIMYYENLKRPVPPRFRFNGKKTEFTLTIIILQEEDFDEYCCCCYYCW>>>
SSELTQPPSVSVSPGQTARITCSGDGFPDQYAYWYQQKPGQAPVLVIYKHSERPSGIPERFSGSSSGTTVTLTISGVQAEDEADYYCQSSDHGGTY


EIVMTQPPSVLSVPPQERTIICCRASQNVGNYNWWWWYKWWKAPKLKIMYYENRKWIVPVDFFFSGSKTEFTLTIIILLEEEFEEYCCCCYYYC>>>->
EMVMTQSPATLYVSPGERAILSCRASQAIGSNLAWYQQLPGQAPRLLIYGASNRAADTPSRFSGSGSGTEFTLTISSLQSEDVGIYYCQQYDNWPP


EIVMTQPPSVLSAPVQERAIICCRASQDISYYNWWWWQKPHKAPKLLIMIYENRRWPVPPRFRFSGFKTEFTLTIIILLEEEFEEYYCCCYYCW>>>->
EIVLTQSPATLSLSPGERATLSCRASQSVSSYLAWYQQKPGQAPRLLIYDASNRATGIPARFSGSGSGTDFTLTISSLEPEDFAVYYCQQRYNWPP


EIVMTQPPSVLAVPPGERAIICCRASQDVGNYYNWWWQKKHKQPKLLIMYKENLKWPVIPRFRFSGKKTEFTATIIIRQEEEEDEYYCCCYYCM>>>->
SSELTQDPAVSVALGQTVRITCQGDSLRSFYASWYQQKPRQAPVLVIYSKNNRPSGIPDRFSGSSSGNTASLTITGA

  0%|          | 0/493 [00:00<?, ?it/s]

step 0 loss 19.394681541206623

EIVMTQPPDVLVAPGQERVIICCRASQNIGYYNWWWWQKKHKQKKPLIMYYENLKWIVPVDRRFSGKKTEFTLTIIFLQEEEEDEYYCCCCWYC>>>
QSMLTQPPSVSGAPGQRVTISCTGSSTNIGADYDVHWYQHVPGTAPKLLIYGNKNRPSGVPDRFSGSKSGTSASLAITGLQAEDEADYYCQSYDDSL


EIVMTQPPSVLVAPVQERVIICCRASQNVGNYNWWWWQKWWKAKKLLLMYYENLKWPVPRRFRFSGKKTEFTLLIIFLQEEEEDEYCCCCCWCM>>>
SSELTQPPSVSVSPGQTARITCSGDGFPDQYAYWYQQKPGQAPVLVIYKHSERPSGIPERFSGSSSGTTVTLTISGVQAEDEADYYCQSSDHGGTY


EIVMTQPPSVLGVPPQERVIICCRASQNVGNYNWWWWYKWHKAKKPLLYYYENNRWGVPPRRRFGGKKTEFTLTIIFLQEEEEDEYYCCCYWCW>>>->
EMVMTQSPATLYVSPGERAILSCRASQAIGSNLAWYQQLPGQAPRLLIYGASNRAADTPSRFSGSGSGTEFTLTISSLQSEDVGIYYCQQYDNWPP


DIVMTQPPSVLAVPPGERAIICCRASQNVNNNYYWWWQKPHKAKKPLIMYKNNRKWPNIPRFRFDGKKTEFTLTIIILQEEEFEEYCCCCCYYM>>>->
EIVLTQSPATLSLSPGERATLSCRASQSVSSYLAWYQQKPGQAPRLLIYDASNRATGIPARFSGSGSGTDFTLTISSLEPEDFAVYYCQQRYNWPP


EIVMTQPPSVLGVPPGERAIICCRASQDVGYYYWWWWQKWHKQKKRKIMYYENLRWPIEPRFFFSGKKTENTLTIIIREEEEEDEYYCCCYWYW>>>->
SSELTQDPAVSVALGQTVRITCQGDSLRSFYASWYQQKPRQAPVLVIYSKNNRPSGIPDRFSGSSSGNTASLTIT

  0%|          | 0/493 [00:00<?, ?it/s]

step 0 loss 19.36544970020635

EIVMTQPPSVLGVPGQERVIICCRASQNVKNYNWWWWQKPHKQKKPKIMYYENLKWGVPPRFFFGGKKTEFTLLIIFLQEEEFDEYYCCCCWYC>>>
QSMLTQPPSVSGAPGQRVTISCTGSSTNIGADYDVHWYQHVPGTAPKLLIYGNKNRPSGVPDRFSGSKSGTSASLAITGLQAEDEADYYCQSYDDSL


EIVMTQPPSVLVVPPQERAIICCRASQDIGNYYWWWWQKKHKAPKKKIMYYENRKWPVPPRFRFGGKKTEFFLTIIILLEEDFDEYCCCCYYYW>>>
SSELTQPPSVSVSPGQTARITCSGDGFPDQYAYWYQQKPGQAPVLVIYKHSERPSGIPERFSGSSSGTTVTLTISGVQAEDEADYYCQSSDHGGTY


EIVMTQPPDVLVVPPQERAIICCRASQNIGNWYWWWWQKWHKQKKPKIMYYENNRWPVIPDFFFSGKKTEFTLTIIILQEEDEDEYCCCCCWYW>>>->
EMVMTQSPATLYVSPGERAILSCRASQAIGSNLAWYQQLPGQAPRLLIYGASNRAADTPSRFSGSGSGTEFTLTISSLQSEDVGIYYCQQYDNWPP


EIVMTQPPSVLVVPPGERVIICCRASQNVKNYNWWWWQKWHKQKKPLIMYYENRKWPVPPRFRFDGKKTEFTLTIIILQEEDFEEYYCCCYYCW>>>->
EIVLTQSPATLSLSPGERATLSCRASQSVSSYLAWYQQKPGQAPRLLIYDASNRATGIPARFSGSGSGTDFTLTISSLEPEDFAVYYCQQRYNWPP


EIVMTQPPDVLGAPPGERVIICCRASQDVGNYNWWWWQKPWKQKKPKIMYYENRKWPVPVRFRFSGKKTEFFLIIIILLEEEFDEYYCCCCWCW>>>
SSELTQDPAVSVALGQTVRITCQGDSLRSFYASWYQQKPRQAPVLVIYSKNNRPSGIPDRFSGSSSGNTASLTITGAQ