In [None]:
import torch
import pytorch_lightning as pl
import transformers
import torchmetrics

import pandas as pd
import os
import json

from model import T5MultiTask
from data_module import TolokaDataModule

In [None]:
# proxy
# Point the HTTP, HTTPS and FTP proxy environment variables at the same proxy
# endpoint in a single update.
os.environ.update(
    dict.fromkeys(
        ("http_proxy", "https_proxy", "ftp_proxy"),
        "http://proxy.ad.speechpro.com:3128",
    )
)

In [None]:
# Training batch size, defined once so later cells can share it: it is passed
# to T5MultiTask and to the evaluation TolokaDataModule further down.
train_batch_size = 64

In [None]:
# Load the pretrained Russian T5 checkpoint and its tokenizer.
t5 = transformers.T5ForConditionalGeneration.from_pretrained("cointegrated/rut5-base-multitask", resume_download=True)
# Over-long inputs are truncated from the left; padding is appended on the right.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "cointegrated/rut5-base-multitask",
    truncation_side='left',
    padding_side='right',
)

# Read the special-token mapping. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the locale's default encoding.
with open('/home/stc/persona/data/preprocessing/spec_tokens.json', encoding='utf-8') as spec_tokens_config:
    spec_tokens = json.load(spec_tokens_config)

# Register every value of the mapping as an additional special token.
# NOTE(review): this grows the tokenizer vocabulary; presumably the model's
# embeddings are resized elsewhere (e.g. inside T5MultiTask) — confirm.
tokenizer.add_special_tokens(
    {"additional_special_tokens": list(spec_tokens.values())}
)

In [None]:
# Data module used for training on the 'current_gk' and 'next_answer' tasks.
# Other dataset names used in this notebook: 'next_answer', 'current_gk', 'next_gk'.
datamodule = TolokaDataModule(
    data_dir='/home/stc/persona/data',
    datasets=['current_gk', 'next_answer'],
    tokenizer=tokenizer,
    spec_tokens=spec_tokens,
    # Use the shared notebook-level setting so the data module and the model
    # are told the same training batch size (a literal 128 here used to
    # disagree with the value handed to T5MultiTask). Change train_batch_size
    # in its defining cell if a different size is wanted.
    train_batch_size=train_batch_size,
    val_batch_size=256,
    test_batch_size=256,
)

In [None]:
# Wrap the pretrained T5 network and the data module in the multi-task module.
# NOTE(review): the meaning of pooling / distance / scale is defined in
# model.T5MultiTask, which is not visible from this notebook.
model = T5MultiTask(
    # wrapped components
    model=t5,
    datamodule=datamodule,
    # optimisation settings
    lr=5e-5,
    num_warmup_steps=1000,
    # sentence-representation / similarity settings
    pooling="mean",
    distance="cosine",
    scale=20,
    # batch sizes
    train_batch_size=train_batch_size,
    val_batch_size=256,
    test_batch_size=256,
)

In [None]:
# logger
# The Comet API key is a credential and must not live in the notebook: read it
# from the COMET_API_KEY environment variable instead. Indexing os.environ
# directly raises KeyError when the variable is missing, so a forgotten key
# fails loudly rather than silently changing how the run is logged.
# Any key that was previously committed in this cell should be treated as
# leaked and rotated.
logger = pl.loggers.comet.CometLogger(
    api_key=os.environ["COMET_API_KEY"],
    save_dir='/home/stc/persona/logs',
    project_name='chaT5',
    experiment_name='current_gk+next_answer base',
    log_code=True,
)

In [None]:
# trainer
# Single-GPU run for at most 15 epochs with gradient clipping at 1; ten
# validation batches are run as a sanity check before training starts.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=15,
    gradient_clip_val=1,
    num_sanity_val_steps=10,
    logger=logger,
)

# Train the multi-task model on the loaders supplied by the data module.
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# Separate data module restricted to the 'current_gk' task; it is only used to
# obtain a validation loader for the qualitative generation checks below.
# Other dataset names used in this notebook: 'next_answer', 'current_gk', 'next_gk'.
datamodule_test = TolokaDataModule(
    data_dir='/home/stc/persona/data',
    tokenizer=tokenizer,
    spec_tokens=spec_tokens,
    datasets=['current_gk'],
    train_batch_size=train_batch_size,
    val_batch_size=256,
    test_batch_size=256,
)

# NOTE(review): val_dataloader() is called without an explicit setup();
# presumably TolokaDataModule prepares its datasets on construction — confirm.
val_set = datamodule_test.val_dataloader()

In [None]:
# Qualitative check of the 'current_gk' task: generate an output for every
# validation batch and print input / model output / target side by side.
model.to('cuda')
# generate() does not change the module's mode, and the module may still be in
# train mode after trainer.fit; switch to eval so dropout is disabled.
model.eval()
for val_batch in val_set:
    input_ids = val_batch['current_gk']['query']['input_ids']
    true_ids = val_batch['current_gk']['candidate']['input_ids']
    # Beam search (4 beams) with sampling, so outputs differ between runs.
    generated_ids = model.transformer.generate(
        input_ids.to('cuda'),
        do_sample=True,
        num_beams=4,
    )
    out_text = model.datamodule.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    inp_text = model.datamodule.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    true_text = model.datamodule.tokenizer.batch_decode(true_ids, skip_special_tokens=True)
    # Distinct loop names so the generated tensor is not shadowed by a string.
    for inp, pred, true in zip(inp_text, out_text, true_text):
        print('input:', inp)
        print('model output:', pred)
        print('target output:', true)
        print()

In [None]:
# Qualitative check of the 'next_answer' task.
# The shared val_set is built with datasets=['current_gk'] only, while this
# cell reads val_batch['next_answer']. Batches appear to be keyed by dataset
# name (TODO confirm), so build a dedicated 'next_answer' validation loader
# here to keep the cell working on a fresh Restart & Run All.
next_answer_val_set = TolokaDataModule(
    data_dir='/home/stc/persona/data',
    datasets=['next_answer'],
    tokenizer=tokenizer,
    spec_tokens=spec_tokens,
    train_batch_size=train_batch_size,
    val_batch_size=256,
    test_batch_size=256,
).val_dataloader()

model.to('cuda')
# generate() does not change the module's mode, and the module may still be in
# train mode after trainer.fit; switch to eval so dropout is disabled.
model.eval()
for val_batch in next_answer_val_set:
    input_ids = val_batch['next_answer']['query']['input_ids']
    true_ids = val_batch['next_answer']['candidate']['input_ids']
    # Beam search (4 beams) with sampling, so outputs differ between runs.
    generated_ids = model.transformer.generate(
        input_ids.to('cuda'),
        do_sample=True,
        num_beams=4,
    )
    # Special tokens are kept so the dialogue markers stay visible; padding is
    # stripped manually when printing.
    inp_text = model.datamodule.tokenizer.batch_decode(input_ids, skip_special_tokens=False)
    out_text = model.datamodule.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
    true_text = model.datamodule.tokenizer.batch_decode(true_ids, skip_special_tokens=False)
    # Distinct loop names so the generated tensor is not shadowed by a string.
    for inp, pred, true in zip(inp_text, out_text, true_text):
        # Put each [Model] / [User] turn on its own line for readability.
        print('input:', inp.replace("<pad>", "").replace("[Model]", "\n[Model]").replace("[User]", "\n[User]"))
        print('model output:', pred.replace("<pad>", ""))
        print('target output:', true.replace("<pad>", ""))
        print()