In [1]:
from transformers import MT5ForConditionalGeneration as MT5Model
from transformers import AdamW
from transformers import MT5Tokenizer
import torch
from torch.optim import Adam, lr_scheduler
import pickle

In [3]:
model = MT5Model.from_pretrained("google/mt5-small")
model.train()
tokenizer = MT5Tokenizer.from_pretrained('google/mt5-small')
optimizer = AdamW(model.parameters(), lr=0.00001)
print("Number of parameters: ", model.num_parameters())

Downloading:   0%|          | 0.00/553 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/82.0 [00:00<?, ?B/s]

Number of parameters:  300176768


In [4]:
with open("/kaggle/input/tolokadialogues/prepared_dialogueCR.pickle", 'rb') as file:
    dialogues = pickle.load(file)

# even stands for even number of repliques in one dialog, so each U1 replique has an answer
dialogues_even = list(map(lambda x: x[:-1] if len(x) % 2 == 1 else x, dialogues))
dialogues_even_flatten = [replique for dialogue in dialogues for replique in dialogue]
repliquesU1 = list(filter(lambda x: x.startswith("USER1:"), dialogues_even_flatten))
repliquesU2 = list(filter(lambda x: x.startswith("USER2:"), dialogues_even_flatten))

In [6]:
data = tokenizer.prepare_seq2seq_batch(src_texts=repliquesU1[:500], tgt_texts=repliquesU2[:500], return_tensors="pt", padding=True)

In [7]:
model = model.to("cuda")
data = data.to("cuda")

In [9]:
optimizer = Adam(model.parameters())
scheduler = lr_scheduler.StepLR(optimizer, 160, gamma=0.5, last_epoch=-1, verbose=False)
train_loss = 5
batch = 4
for epoch in range(20):
    model.train()
    for i in range(80):
        output = model(input_ids = data.input_ids[i*batch:(i+1)*batch], labels = data.labels[i*batch:(i+1)*batch])
        train_loss = 0.95 * train_loss + 0.05 * output.loss.item()
        output.loss.backward()
        optimizer.step()
        scheduler.step()
    
    model.eval()
    val_output = model(input_ids = data.input_ids[400:405], labels = data.labels[400:405])
    val_loss = output.loss
        
    print(f"Epoch {epoch}: Train loss {train_loss}  Validation loss {val_loss}")

Epoch 0: Train loss 2.6786779819147175  Validation loss 1.6937686204910278
Epoch 1: Train loss 2.7783465729119756  Validation loss 2.787876844406128
Epoch 2: Train loss 2.414172465153645  Validation loss 1.4324270486831665
Epoch 3: Train loss 2.460768471373204  Validation loss 1.3552175760269165
Epoch 4: Train loss 2.0129902701916014  Validation loss 1.126721978187561
Epoch 5: Train loss 1.8126689059359882  Validation loss 1.1642115116119385
Epoch 6: Train loss 1.6951650629713282  Validation loss 1.0683342218399048
Epoch 7: Train loss 1.672895834421422  Validation loss 0.9744287133216858
Epoch 8: Train loss 1.6667012834370414  Validation loss 0.9239241480827332
Epoch 9: Train loss 1.6130789237413314  Validation loss 0.9025110006332397
Epoch 10: Train loss 1.5773308704127742  Validation loss 0.8895124793052673
Epoch 11: Train loss 1.5408974707502643  Validation loss 0.8862259387969971
Epoch 12: Train loss 1.5257928335701434  Validation loss 0.8606579899787903
Epoch 13: Train loss 1.5243

In [11]:
model = model.to("cpu")

In [18]:
model_input = tokenizer.encode(repliquesU1[1], add_special_tokens=False, return_tensors='pt')
result_ids = model.generate(model_input, max_length=250, do_sample=True, top_p=0.95, top_k=60)
result_tokens = tokenizer.convert_ids_to_tokens(result_ids[0])
result_string = tokenizer.convert_tokens_to_string(result_tokens)
print(result_string)

<pad> ?<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> USER2 не не<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad

In [None]:
model.save_pretrained('models/mt5small/')