# **Imports**

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import get_linear_schedule_with_warmup

from tqdm import tqdm

import wandb

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader

from datasets import load_dataset

from torch.amp import autocast

# **Model, tokenizer and dataset preparation**

In [None]:
checkpoint = "EleutherAI/pythia-1b"
# Tokenizer and causal-LM weights are loaded from the same pretrained checkpoint.
tok, mod = (
    AutoTokenizer.from_pretrained(checkpoint),
    AutoModelForCausalLM.from_pretrained(checkpoint),
)

In [None]:
# Load only the training split of the SirNeural/flan_v2 dataset.
train_data = load_dataset(path="SirNeural/flan_v2", split="train")

In [None]:
# Padding to max_length below requires a pad token. The tokenizer may not define
# one, so fall back to EOS *before* encoding (otherwise encode() raises).
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Tokenize every training row into a (1, 1024) tensor of token ids: prompt and
# target are joined into one string terminated by EOS, then truncated/padded to
# exactly 1024 tokens.
# NOTE(review): this keeps len(train_data) x 1024 ids in memory at once
# (~8 KiB per row if the ids are int64) — confirm that fits for this split.
emb = []
# Iterate rows directly rather than indexing train_data[i] twice per row.
for row in tqdm(train_data, total=len(train_data)):
    inputs = row["inputs"]
    targets = row["targets"]

    # NOTE(review): ", " also precedes the EOS token, so the model is trained to
    # emit ", " right before stopping — confirm this separator is intended.
    train_row = f"{inputs}, {targets}, {tok.eos_token}"
    embeded_row = tok.encode(
        train_row,
        padding="max_length",
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    )

    emb.append(embeded_row)

In [None]:
# Use EOS as the padding token (padding="max_length" encoding needs one).
tok.pad_token = tok.eos_token
# Pointing pad_token at the existing EOS token adds no vocabulary entries, so
# only grow the embedding matrix if the tokenizer really outgrew it.
# An unconditional resize_token_embeddings(len(tok)) would *shrink* a model
# whose embedding table is larger than the tokenizer's vocabulary.
if len(tok) > mod.get_input_embeddings().num_embeddings:
    mod.resize_token_embeddings(len(tok))

In [None]:
# Expose the pre-tokenized rows as a Dataset and batch them in shuffled order.
# Each stored row is (1, 1024), so batches come out as (batch, 1, 1024).
emb_dataset = torch.utils.data.ConcatDataset(datasets=[emb])
train_dataset = DataLoader(dataset=emb_dataset, shuffle=True, batch_size=16)

# **LLM fine-tuning: EMA, parameter freezing and the training loop**

In [None]:
class EMA(nn.Module):
    """Exponential moving average of a model's trainable parameters.

    On each call the shadow copy of every trainable parameter is updated as
    ``shadow <- decay * shadow + (1 - decay) * param`` and then copied back into
    the parameter, so the model is left holding the averaged weights.

    NOTE(review): because the shadow is written back on every call, an
    optimizer update ``delta`` made between two calls survives only as
    ``(1 - decay) * delta`` — confirm this damping is the intended behavior
    (a conventional EMA keeps the shadow separate from the training weights).

    Args:
        decay: weight kept by the previous shadow value on each update.
            ``decay=0`` tracks the current weights exactly; ``decay=1`` keeps
            the first snapshot forever.
    """

    def __init__(self, decay: float):
        super().__init__()
        self.decay = decay
        # Parameter name -> shadow tensor. This is a plain dict (not registered
        # buffers), so it is not part of state_dict() and is not moved by .to().
        # Each shadow lives on the device its parameter was on when first seen.
        self.shadow_params: dict[str, torch.Tensor] = {}

    @torch.no_grad()  # the average is bookkeeping and must not join the autograd graph
    def forward(self, model: nn.Module) -> None:
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            shadow = self.shadow_params.get(name)
            if shadow is None:
                # First sighting: snapshot an independent copy of the weights.
                self.shadow_params[name] = param.detach().clone()
                continue  # param already equals its fresh snapshot
            # shadow += (1 - decay) * (param - shadow)
            shadow.lerp_(param, 1.0 - self.decay)
            # Copy the shadow into the parameter; never alias the two. Rebinding
            # ``param.data = shadow`` would make them share storage. The
            # optimizer's in-place steps would then also overwrite the shadow,
            # and the average would silently stop doing anything.
            param.copy_(shadow)


# decay=0.5: each call moves the shadow halfway toward the current weights.
ema = EMA(0.5)

In [None]:
def freeze(model: nn.Module):
    """Turn off gradient tracking for every parameter of *model*, in place."""
    model.requires_grad_(False)

In [None]:
import os

# len() of a DataLoader is its number of batches, i.e. one epoch of steps.
training_steps = len(train_dataset)
# NOTE(review): AdamW's default learning rate (1e-3) is used. Full fine-tuning
# of a pretrained LLM usually needs a much smaller rate — confirm this is intended.
optimizer = AdamW(mod.parameters())
# Linear warmup over the first 10% of steps, then linear decay to 0.
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=training_steps // 10, num_training_steps=training_steps
)

# Take the API key from the environment instead of hard-coding a credential in
# the notebook. If WANDB_API_KEY is unset, wandb falls back to its own login flow.
wandb.login(key=os.environ.get("WANDB_API_KEY"), relogin=True)
wandb.init(sync_tensorboard=True, name="NAME", project="PROJECT", entity="ENTITY")

device = "cuda:0" if torch.cuda.is_available() else "cpu"
mod.to(device)

# freeze(mod)


def one_epoch(model, data):
    """Train *model* for one pass over *data*, then switch it to eval mode.

    For every batch: zero the gradients and run the forward pass with
    ``labels == input_ids`` (the model computes the loss). Log the loss to
    wandb, backpropagate, step the optimizer and LR scheduler, then apply the
    module-level ``ema`` to the model's weights.

    Uses the module-level ``optimizer``, ``scheduler``, ``ema`` and ``device``.

    Args:
        model: the causal LM being trained. Its forward must accept
            ``input_ids``/``labels`` and return a mapping with ``"loss"``.
        data: iterable of id batches. Presumably shaped (batch, 1, seq_len) as
            produced by the DataLoader built earlier — each is viewed as
            (batch, seq_len).
    """
    model.train()

    # Autocast must target the device we actually train on. A hard-coded
    # "cuda" on a CPU-only machine just warns and disables itself, so mixed
    # precision is enabled only on CUDA (same effective behavior as before).
    device_type = torch.device(device).type
    use_amp = device_type == "cuda"
    # NOTE(review): CUDA autocast defaults to float16. PyTorch's AMP recipe pairs
    # fp16 with a GradScaler to avoid gradient underflow — consider adding one
    # (or autocasting to bfloat16 on hardware that supports it).

    for batch in data:
        # (batch, 1, seq_len) -> (batch, seq_len), then move to the device.
        t = batch.view(batch.shape[0], batch.shape[-1]).to(device)

        optimizer.zero_grad()

        with autocast(device_type=device_type, enabled=use_amp):
            # NOTE(review): padding positions are not masked out of the labels
            # (no -100), so they contribute to the loss — confirm intended.
            loss = model(input_ids=t, labels=t)["loss"]

        # Log a plain Python float rather than the graph-attached loss tensor.
        wandb.log({"loss": loss.item()})

        loss.backward()
        optimizer.step()
        scheduler.step()
        ema(model)

    model.eval()