# Импорты

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import get_linear_schedule_with_warmup

import jsonlines

from tqdm.auto import tqdm
from tqdm.contrib import tzip

import wandb

import torch
from torch.utils.data import Dataset, DataLoader

import torch.nn as nn
from torch.optim import AdamW

from torch.amp import autocast

import gc
import os

# Turn off the Hugging Face `tokenizers` library's internal parallelism
# (commonly done to avoid its fork warning once DataLoader workers fork).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Датасет

In [None]:
def jsonl_reader(file_name, data_dir="./data/"):
    """Read a JSON Lines file into parallel lists of inputs and targets.

    Every line must hold a JSON object with an ``"inputs"`` and a
    ``"target"`` key.

    Args:
        file_name: Name of the ``.jsonl`` file inside ``data_dir``.
        data_dir: Directory containing the file (default ``./data/``).

    Returns:
        A ``(inputs, targets)`` tuple of lists, in file order.

    Raises:
        KeyError: if a record lacks ``"inputs"`` or ``"target"``.
        Malformed JSON lines raise from ``jsonlines`` itself.
    """
    inputs = []
    targets = []

    path = os.path.join(data_dir, file_name)
    # Explicit UTF-8: the platform default encoding is not guaranteed to be
    # UTF-8, and a translation corpus presumably contains non-ASCII text.
    # The reader is closed too; it does not close `file` (the outer `with` does).
    with open(path, "r", encoding="utf-8") as file, jsonlines.Reader(file) as reader:
        for record in reader.iter():
            inputs.append(record["inputs"])
            targets.append(record["target"])
    return inputs, targets

In [None]:
class FlanDataset(Dataset):
    """Pre-tokenized prompt/answer pairs for causal-LM fine-tuning.

    Each record of the JSONL file becomes the string
    ``f"{inputs}, {target}, {eos_token}"``, tokenized to exactly
    ``max_length`` ids (padded and truncated). Samples are stored as the
    tensors returned by ``tokenizer.encode(..., return_tensors='pt')``
    (presumably shape ``(1, max_length)``; the training loop flattens the
    extra dimension with ``view``).

    Args:
        tokenizer: Hugging Face tokenizer; needs a pad token because
            ``padding='max_length'`` is used.
        file_name: JSONL file (looked up by ``jsonl_reader``) with
            ``inputs``/``target`` fields.
        max_length: Fixed sequence length of every sample.
    """

    def __init__(self, tokenizer, file_name="flan_traslation_v22.jsonl", max_length=1024):
        inputs, targets = jsonl_reader(file_name)

        # NOTE(review): the ", " separators (including one right before EOS)
        # look accidental; kept as-is because changing them alters the data.
        self.tokenized = [
            self._encode(
                text=f"{inp}, {ans}, {tokenizer.eos_token}",
                tokenizer=tokenizer,
                max_length=max_length,
            )
            for inp, ans in tzip(inputs, targets)
        ]

    def __len__(self):
        return len(self.tokenized)

    def __getitem__(self, item):
        return self.tokenized[item]

    def _encode(self, text, tokenizer, max_length=1024):
        """Tokenize ``text`` to exactly ``max_length`` ids as a PyTorch tensor."""
        return tokenizer.encode(
            text,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors='pt',
        )

# Готовим всё вместе

In [None]:
MODEL_NAME = 'AlexWortega/wortegaLM-1b'

# Right-side padding: real tokens first, padding at the end of each sample.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='right')
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

In [None]:
# Reuse EOS as the padding token (the checkpoint's tokenizer presumably ships
# without one), so that padding='max_length' in FlanDataset can pad.
# Side effect: pad and EOS now share a single token id.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Wrapped in a ConcatDataset, presumably so more datasets can be added to the list.
flan_dataset = torch.utils.data.ConcatDataset(datasets=[FlanDataset(tokenizer)])

In [None]:
# Notebook display: number of training samples.
len(flan_dataset)

In [None]:
train_loader = DataLoader(
    flan_dataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,  # discard a final partial batch; every batch holds 16 samples
)

In [None]:
# Size the model's token embedding matrix to the tokenizer's vocabulary.
# NOTE(review): no tokens are added in this notebook (pad reuses EOS), so this
# is likely a no-op; but if the checkpoint's embedding is larger than
# len(tokenizer) (e.g. vocab padded for efficiency) this call shrinks it.
# Confirm that is intended.
model.resize_token_embeddings(len(tokenizer))

# Учим модель

In [None]:
class EMA(nn.Module):
    """Exponential moving average of a model's trainable parameters.

    Each call updates, for every parameter with ``requires_grad``::

        shadow = decay * shadow + (1 - decay) * param

    The first call initializes the shadow as a copy of the parameter.
    ``forward`` never modifies the model; use :meth:`copy_to` to load the
    averaged weights into a model (e.g. before evaluation or saving).

    The shadow tensors are kept separate from the parameters on purpose:
    assigning a shadow tensor to ``param.data`` aliases the two, so the
    optimizer would overwrite the average in place and the EMA difference
    would always be zero (no averaging at all).

    NOTE: the shadow is a full extra copy of the trainable weights, held on
    the parameters' device.
    """

    def __init__(self, decay):
        super().__init__()
        self.decay = decay
        self.shadow_params = {}

    @torch.no_grad()
    def forward(self, model):
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            shadow = self.shadow_params.get(name)
            if shadow is None:
                self.shadow_params[name] = param.detach().clone()
            else:
                # In place: shadow += (1 - decay) * (param - shadow)
                shadow.lerp_(param.detach(), 1.0 - self.decay)

    @torch.no_grad()
    def copy_to(self, model):
        """Copy the averaged weights into ``model``'s matching parameters."""
        for name, param in model.named_parameters():
            if name in self.shadow_params:
                param.copy_(self.shadow_params[name])


# EMA smoothing factor: the closer to 1, the slower the average moves.
ema = EMA(0.992)

In [None]:
# Number of passes over train_loader; keep in sync with the epoch loop that
# calls one_epoch (currently range(10)).
num_epochs = 10

optimizer = AdamW(model.parameters(), lr=5e-6)
# The scheduler is stepped once per batch, and the linear schedule decays the
# LR to zero at num_training_steps. It must therefore cover all steps of the
# whole run: sizing it to a single epoch leaves every later epoch at LR 0.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=len(train_loader) * num_epochs,
)

# Запускаемся

In [None]:
# Take the API key from the environment instead of hard-coding a secret in the
# notebook. If WANDB_API_KEY is unset, key=None is passed and wandb uses its own
# login flow (see wandb.login documentation).
wandb.login(key=os.environ.get("WANDB_API_KEY"), relogin=True)
wandb.init(sync_tensorboard=True, name='NAME', project="PROJECT", entity="ENTITY")

In [None]:
# Use the first GPU when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
model.to(device)

In [None]:
def _causal_lm_labels(input_ids, pad_token_id):
    """Return ``(labels, attention_mask)`` for right-padded ``input_ids``.

    Positions strictly after the first ``pad_token_id`` in each row are
    treated as padding: their label becomes -100 (ignored by the Hugging Face
    LM loss) and their attention mask 0. The first occurrence itself is kept,
    because in this notebook pad == EOS (``tokenizer.pad_token =
    tokenizer.eos_token``) and that token marks the real end of the sample.
    If ``pad_token_id`` is None, nothing is masked.
    """
    if pad_token_id is None:
        return input_ids, torch.ones_like(input_ids)
    is_pad = (input_ids == pad_token_id).long()
    # Number of pad tokens strictly before each position; > 0 means "after the first pad".
    padding = (is_pad.cumsum(dim=-1) - is_pad) > 0
    labels = input_ids.masked_fill(padding, -100)
    attention_mask = (~padding).long()
    return labels, attention_mask


def one_epoch(model, train_dataloader, epoch):
    """Train ``model`` for one epoch, then save a checkpoint.

    Relies on the notebook-level ``optimizer``, ``scheduler``, ``ema``,
    ``device`` and ``tokenizer``. Each step does the following:
    forward under autocast (CUDA only), backward, gradient-norm clipping to
    1.0, optimizer and scheduler step, EMA update, and a wandb ``loss`` log.
    After the epoch the model is saved to ``lm_saves/lm_{epoch}epoch`` and
    left in eval mode.

    Args:
        model: The causal LM being fine-tuned.
        train_dataloader: Yields batches of token ids, presumably shaped
            ``(batch, 1, seq_len)``; they are reshaped to ``(batch, seq_len)``.
        epoch: Epoch index, used only in the checkpoint directory name.
    """
    # autocast wants a device *type* ("cuda"/"cpu"), not "cuda:0". It is only
    # enabled on CUDA; on CPU the original cuda-autocast was disabled anyway.
    device_type = torch.device(device).type
    use_autocast = device_type == "cuda"

    model.train()
    for batch in tqdm(train_dataloader):
        input_ids = batch.view(batch.shape[0], batch.shape[-1]).to(device)
        # Don't train on padding: mask everything after the first pad/EOS.
        labels, attention_mask = _causal_lm_labels(input_ids, tokenizer.pad_token_id)

        optimizer.zero_grad()

        # NOTE(review): fp16 autocast without a GradScaler can underflow small
        # gradients; consider a GradScaler or bfloat16 autocast.
        with autocast(device_type=device_type, enabled=use_autocast):
            loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)['loss']

        loss.backward()
        # Clip after backward: before it the gradients are zero or stale, so
        # clipping there has no effect on this step's update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        ema(model)
        scheduler.step()

        # Log a plain float rather than a tensor attached to the autograd graph.
        wandb.log({"loss": loss.item()})

    # Save once per epoch; previously the full model was re-saved to the same
    # directory after every step.
    model.save_pretrained(f'lm_saves/lm_{epoch}epoch')

    # Release cached memory once per epoch instead of after every step.
    gc.collect()
    torch.cuda.empty_cache()

    model.eval()

In [None]:
# Ten training epochs; each one_epoch call writes its own checkpoint directory.
epochs = range(10)
for epoch in tqdm(epochs):
    one_epoch(model, train_loader, epoch)