In [1]:
import torch
import json
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration,T5Tokenizer

from torch.utils.tensorboard import SummaryWriter

from tqdm.auto import tqdm

In [2]:
writer = SummaryWriter()

In [44]:
class NamesDataset(Dataset):
    def __init__(
        self,
        tokenizer: T5Tokenizer,
        data_path: str,
        max_length: int = 20,
        truncation: bool = True
    ) -> None:
        self.tokenizer = tokenizer
        self.truncation = truncation
        self.max_length = max_length
        
        with open(data_path, encoding="utf-8") as json_data:
            self.data = json.load(json_data)
        
    def __getitem__(self, index):
        name = self.data[index]["name"]
        body = self.data[index]["body"]
        
        tokens = self.tokenizer(
            name,
            return_tensors='pt',
            padding="max_length",
            max_length=self.max_length,
            truncation=self.truncation
        )
        
        labels = self.tokenizer(
            body,
            return_tensors='pt',
            padding="max_length",
            max_length=self.max_length,
            truncation=self.truncation
        ).input_ids
        
        labels = labels.masked_fill_(labels == 0, -100) 
        
        return name, body, tokens.input_ids[0], tokens.attention_mask[0], labels[0]
    
    def __len__(self):
        return  len(self.data)

In [36]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_TYPE = "base"

if MODEL_TYPE not in ["large", "base", "small"]:
    raise ValueError("Wrong size of model type")

MODEL_NAME = "sberbank-ai/ruT5-" + MODEL_TYPE

In [5]:
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(DEVICE)
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

In [37]:
training_config = {
    "lr": 1e-3,
    "epoch_num": 100,
    "opt": "Adam",
    "batch_size": 2,
    "lr_decay": 0.8
}

In [48]:
dataset = NamesDataset(
    tokenizer=tokenizer,
    data_path="../data/name_slogan.json"
)

dataloader = DataLoader(dataset, batch_size=training_config["batch_size"], shuffle=True)

In [69]:
def test_work():
    name, body, _, _, _ = dataset[5]
    
    name = "Михаил"
    body = "Вообще дебил"
    
    print(name)
    print(body)
    
    tokens = tokenizer(name, return_tensors='pt', padding="max_length", max_length=40, truncation=True).to(DEVICE)
    labels = tokenizer(body, return_tensors='pt', padding="max_length", max_length=40, truncation=True).input_ids.to(DEVICE)
    labels = labels.masked_fill_(labels == 0, -100) 
    
    output=model(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        labels=labels
    )
    
    print(output["loss"])
    print(tokenizer.batch_decode(torch.argmax(output["logits"], dim=2).to("cpu")))

In [70]:
test_work()

Михаил
Вообще дебил
tensor(16.1739, device='cuda:0', grad_fn=<NllLossBackward0>)
['Будь «твор с Вер Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь Будь']


In [52]:
epoch_num = training_config["epoch_num"]
model.train()
opt = torch.optim.Adam(
    model.parameters(),
    lr=training_config["lr"]
)

lambda1 = lambda epoch: training_config["lr_decay"] ** epoch
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda1)

In [53]:
iter_num = 0 
for epoch in range(epoch_num):
    for _, _, tokens_ids, tokens_attention, labels in tqdm(dataloader, total=len(dataset)):
        model.zero_grad()
        
        output=model(
            input_ids=tokens_ids.to(DEVICE),
            attention_mask=tokens_attention.to(DEVICE),
            labels=labels.to(DEVICE)
        )
        
        output["loss"].backward()
        
        loss_value = output.loss.detach().cpu().item()
        writer.add_scalar(f"Loss {training_config=}", loss_value, iter_num)
        opt.step()
        iter_num += 1
    
    scheduler.step()

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

  0%|          | 0/73 [00:00<?, ?it/s]

KeyboardInterrupt: 