In [None]:
import os
import sys

import numpy as np
import torch
import wandb
import yaml
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed, DummyOptim, DummyScheduler
from evaluate import load
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
from yaml import CLoader

# The project-local `dataset` module is not installed; make it importable.
sys.path.append("/content/CQA_RLHF/sft/dataset")
from dataset import create_dataloaders

In [None]:
trainer_config = dict(
    model_name="EleutherAI/gpt-neo-125M",
    data=dict(
        data_path="",  # TODO: set to the dataset location read by create_dataloaders
        batch_size=16,
        max_length=512,
    ),
    train=dict(
        n_epoches=3,
        seed=42,
        learning_rate=5e-5,
        mixed_precision="bf16",
        freeze=True,  # freeze most transformer blocks (applied in the model cell)
        gradient_accumulation_steps=1,
        max_grad_norm=None,  # None / 0 disables gradient clipping
        warmup_steps=100,
        resume_from_checkpoint="",  # empty string -> train from scratch
        eval_every=1000,  # evaluate + checkpoint every N training batches
        output_dir="",
        log_with="wandb",
    ),
    wandb_kwargs=dict(entity="myashka", job_type="train", group="sft"),
    use_cache=False,
    is_tpu=True,
    # Never put the W&B API key here: this dict is written to disk below and
    # logged as run config. With None, wandb.login() falls back to the
    # WANDB_API_KEY environment variable or previously stored credentials.
    wandb_api=None,
)

with open("trainer_config.yaml", "w") as outfile:
    yaml.dump(trainer_config, outfile, default_flow_style=False)

In [None]:
def save_checkpoint(
    model, accelerator, optimizer, scheduler, output_dir, epoch, global_step
):
    """Write a resumable checkpoint named ``step_{global_step}.ckpt`` under ``output_dir``.

    Call this from every process: it starts with ``wait_for_everyone()`` and,
    under DeepSpeed, the engine's ``save_checkpoint`` must run on all ranks.

    Args:
        model: Prepared model (the DeepSpeed engine when a DeepSpeed plugin is active).
        accelerator: The ``Accelerator`` driving training.
        optimizer: Optimizer whose ``state_dict`` is stored (non-DeepSpeed path).
        scheduler: LR scheduler whose ``state_dict`` is stored (non-DeepSpeed path).
        output_dir: Directory to write into; created if missing. An empty value
            means the current working directory.
        epoch: Epoch stored in the checkpoint (also used as the DeepSpeed tag).
        global_step: Step stored in the checkpoint and used in its name.
    """
    accelerator.wait_for_everyone()

    output_dir = str(output_dir)
    if output_dir:
        # exist_ok makes concurrent calls from several processes harmless.
        os.makedirs(output_dir, exist_ok=True)
    # os.path.join keeps an empty output_dir relative to the CWD instead of
    # producing "/step_N.ckpt" at the filesystem root.
    ckpt_path = os.path.join(output_dir, f"step_{global_step}.ckpt")

    if accelerator.state.deepspeed_plugin is not None:
        # DeepSpeed writes a checkpoint directory and requires every rank to
        # take part in this call, so it is deliberately not gated on rank 0.
        checkpoint_state_dict = {
            "epoch": epoch,
            "last_global_step": global_step,
        }
        success = model.save_checkpoint(ckpt_path, epoch, checkpoint_state_dict)
        accelerator.print(f"Saved checkpoint to: {ckpt_path}: {success}")
        return

    unwrapped_model = accelerator.unwrap_model(model)
    save_obj = {
        "model": unwrapped_model.state_dict(),
        "global_step": global_step,
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
    }
    # accelerator.save already writes only once per machine, so every process
    # may call it. On TPU some accelerate versions save through a collective
    # (xm.save), which would hang if only the main process called it.
    accelerator.save(save_obj, ckpt_path)
    accelerator.print(f"Saved checkpoint to: {ckpt_path}")

In [None]:
def load_checkpoint(
    ckpt_path,
    accelerator,
    model,
    optimizer,
    scheduler,
    global_step,
    strict=True,
    model_only=False,
    resume_global_step=True,
    **kwargs,
):
    """Restore training state written by ``save_checkpoint``.

    Args:
        ckpt_path: Checkpoint location (a DeepSpeed checkpoint directory when a
            DeepSpeed plugin is active, otherwise a file written by
            ``accelerator.save``).
        accelerator: The ``Accelerator`` driving training.
        model: Model to restore. Under DeepSpeed this must be the engine;
            otherwise it must match the saved ``state_dict`` keys.
        optimizer: Optimizer restored unless ``model_only`` (non-DeepSpeed path).
        scheduler: LR scheduler restored unless ``model_only`` (non-DeepSpeed path).
        global_step: Step returned when the checkpoint's step is not used.
        strict: Passed to the model's ``load_state_dict`` (non-DeepSpeed path).
        model_only: Load only model weights (non-DeepSpeed path); optimizer,
            scheduler, step and epoch are then not restored.
        resume_global_step: Return the checkpoint's step instead of ``global_step``.
        **kwargs: Forwarded to DeepSpeed's ``load_checkpoint``; ignored otherwise.

    Returns:
        ``(global_step, epoch)``. ``epoch`` is 0 when it was not restored.

    Raises:
        FileNotFoundError: DeepSpeed could not load a checkpoint from ``ckpt_path``.
    """
    if accelerator.state.deepspeed_plugin is not None:
        _, checkpoint_state_dict = model.load_checkpoint(ckpt_path, **kwargs)
        # DeepSpeed returns (None, None) instead of raising when nothing loads.
        if checkpoint_state_dict is None:
            raise FileNotFoundError(
                f"No DeepSpeed checkpoint could be loaded from {ckpt_path}"
            )
        epoch = checkpoint_state_dict["epoch"]
        if resume_global_step:
            global_step = checkpoint_state_dict["last_global_step"]

        del checkpoint_state_dict
        accelerator.print(f"Loaded checkpoint {ckpt_path}")
        return global_step, epoch

    # torch.load unpickles its input: only load checkpoints you produced.
    loaded_obj = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(loaded_obj["model"], strict=strict)

    # In model-only mode the epoch is not restored; report 0 rather than
    # failing on an unbound name.
    epoch = 0
    if not model_only:
        optimizer.load_state_dict(loaded_obj["optimizer"])
        scheduler.load_state_dict(loaded_obj["scheduler"])
        if resume_global_step:
            global_step = loaded_obj["global_step"]
        epoch = loaded_obj["epoch"]

    accelerator.print(f"Loaded checkpoint {ckpt_path}")
    return global_step, epoch

In [None]:
def training_loop(model, args):
    """Fine-tune a causal LM with Accelerate, evaluating and checkpointing as it goes.

    Meant to be started through ``notebook_launcher`` so it runs once per
    process; every process must reach the same collective calls (gathers,
    ``wait_for_everyone``, checkpoint saves).

    Args:
        model: Pretrained causal-LM model, created outside the launcher.
        args: Nested trainer configuration with the layout written by the
            config cell (``model_name``, ``data``, ``train``, ``wandb_kwargs``,
            ``is_tpu``, ...).
    """
    train_cfg = args["train"]

    accelerator = Accelerator(
        mixed_precision=train_cfg["mixed_precision"],
        # Without this, ``accelerator.accumulate`` syncs on every batch and the
        # configured accumulation is silently ignored.
        gradient_accumulation_steps=train_cfg["gradient_accumulation_steps"],
        log_with=train_cfg["log_with"],
        logging_dir=train_cfg["output_dir"],
    )

    set_seed(train_cfg["seed"])

    # Read the model name from ``args``, not the notebook-global ``config``.
    tokenizer = AutoTokenizer.from_pretrained(args["model_name"])
    model.resize_token_embeddings(len(tokenizer))
    # Pad with EOS (assumes the tokenizer has no dedicated pad token).
    tokenizer.pad_token = tokenizer.eos_token
    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    model.pad_token_id = tokenizer.eos_token_id

    train_loader, val_loader = create_dataloaders(
        args["data"]["data_path"],
        tokenizer,
        splits=["train", "val"],
        batch_sizes=[args["data"]["batch_size"], args["data"]["batch_size"]],
        max_length=args["data"]["max_length"],
        # Presumably pads every batch to max_length (static shapes for XLA);
        # see dataset.create_dataloaders.
        all_max_length=args["is_tpu"],
    )

    rouge = load("rouge")
    bertscore = load("bertscore")
    bleu = load("bleu")

    def compute_metrics(predictions, references):
        """Decode predicted/reference token ids and score them.

        Returns the ROUGE dict extended with mean BERTScore precision/recall/F1
        and the corpus BLEU score.
        """
        pred_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(references, skip_special_tokens=True)

        result_dict = rouge.compute(predictions=pred_str, references=label_str)
        bertscore_dict = bertscore.compute(
            predictions=pred_str, references=label_str, lang="en"
        )
        bleu_metric = bleu.compute(predictions=pred_str, references=label_str)["bleu"]

        result_dict["bert_precision"] = np.mean(bertscore_dict["precision"])
        result_dict["bert_recall"] = np.mean(bertscore_dict["recall"])
        result_dict["bert_f1"] = np.mean(bertscore_dict["f1"])
        result_dict["bleu"] = bleu_metric
        return result_dict

    def evaluate(model, val_loader, accelerator, epoch, global_step):
        """Log validation loss and teacher-forced next-token text metrics."""
        accelerator.print("\nEvaluating...")
        pad_id = tokenizer.pad_token_id
        losses = []
        all_predictions = []
        all_labels = []

        for batch in val_loader:
            with torch.no_grad():
                output = model(**batch)

            # Causal-LM logits at position t predict token t + 1: drop the last
            # prediction and the first label so the two line up.
            predictions = output.logits.argmax(dim=-1)[:, :-1]
            labels = batch["labels"][:, 1:]

            # Sequence lengths may differ between processes; pad before gathering.
            predictions = accelerator.pad_across_processes(
                predictions, dim=1, pad_index=pad_id
            )
            labels = accelerator.pad_across_processes(labels, dim=1, pad_index=pad_id)
            # gather_for_metrics collects every process's samples and drops the
            # duplicates added to even out the last batch (no manual truncation).
            predictions = accelerator.gather_for_metrics(predictions)
            labels = accelerator.gather_for_metrics(labels)

            # Blank out positions without a target (-100 is the HF ignore index,
            # pad_id marks padding) so neither side decodes garbage there.
            ignored = (labels == -100) | (labels == pad_id)
            all_predictions.extend(predictions.masked_fill(ignored, pad_id).cpu().tolist())
            all_labels.extend(labels.masked_fill(ignored, pad_id).cpu().tolist())

            losses.append(
                accelerator.gather_for_metrics(
                    output.loss.repeat(len(batch["input_ids"]))
                )
            )

        eval_loss = torch.cat(losses).mean()
        accelerator.log({"val_loss": eval_loss.item()}, step=global_step)

        # Only the main process logs, so only it pays for decoding + BERTScore.
        if not accelerator.is_main_process:
            return
        eval_metric = compute_metrics(
            predictions=all_predictions, references=all_labels
        )
        accelerator.print(f"Metrics computed\n{eval_metric}")
        accelerator.log(
            {
                "bleu": eval_metric["bleu"],
                "bert_f1": eval_metric["bert_f1"],
                "rouge1": eval_metric["rouge1"],
                "rougeL": eval_metric["rougeL"],
                "epoch": epoch,
            },
            step=global_step,
        )
        accelerator.print("Metrics loged")

    n_epoches = train_cfg["n_epoches"]
    gradient_accumulation_steps = train_cfg["gradient_accumulation_steps"]
    learning_rate = train_cfg["learning_rate"]
    max_grad_norm = train_cfg["max_grad_norm"]
    resume_from_checkpoint = train_cfg["resume_from_checkpoint"]
    eval_every = train_cfg["eval_every"]
    warmup_steps = train_cfg["warmup_steps"]

    global_step = 0
    starting_epoch = 0
    resume_step = None
    # Scheduler horizon, computed on the un-sharded loader: once prepared, the
    # scheduler is advanced once per process on every optimizer update, so this
    # must not be divided by the accumulation steps a second time.
    max_steps = int(n_epoches * len(train_loader) // gradient_accumulation_steps)

    if (
        accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        optimizer = DummyOptim(model.parameters(), lr=learning_rate)

    if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
    ):
        scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=max_steps,
        )
    else:
        scheduler = DummyScheduler(
            optimizer,
            total_num_steps=max_steps,
            warmup_num_steps=warmup_steps,
        )

    model, optimizer, scheduler, train_loader, val_loader = accelerator.prepare(
        model, optimizer, scheduler, train_loader, val_loader
    )

    if resume_from_checkpoint:
        # DeepSpeed restores through its engine; otherwise load into the bare
        # model so the keys match what save_checkpoint wrote.
        if accelerator.state.deepspeed_plugin is not None:
            target_model = model
        else:
            target_model = accelerator.unwrap_model(model)
        global_step, _ = load_checkpoint(
            resume_from_checkpoint,
            accelerator,
            target_model,
            optimizer,
            scheduler,
            global_step,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
        )
        # global_step counts batches seen by this process, so locate the resume
        # point with the prepared (per-process) loader length.
        starting_epoch = global_step // len(train_loader)
        resume_step = global_step - starting_epoch * len(train_loader)

    if accelerator.is_main_process:
        # Keep the W&B API key out of the uploaded run config.
        tracked_config = {k: v for k, v in args.items() if k != "wandb_api"}
        accelerator.init_trackers(
            "CQA_RLHF",
            config=tracked_config,
            # Accelerate expects init_kwargs keyed by tracker name.
            init_kwargs={"wandb": args["wandb_kwargs"] or {}},
        )
    # global_step counts per-process batches, so size the bar the same way.
    progress_bar = tqdm(
        initial=global_step,
        total=int(n_epoches * len(train_loader)),
        disable=not accelerator.is_main_process,
    )

    epoch = starting_epoch
    # from_pretrained returns the model in eval mode (dropout off).
    model.train()
    for epoch in range(starting_epoch, n_epoches):
        for step, batch in enumerate(train_loader):
            # Skip batches of the resumed epoch already covered by the
            # checkpoint; global_step already includes them.
            if resume_step is not None and epoch == starting_epoch and step < resume_step:
                continue
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                if max_grad_norm and accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)

                optimizer.step()
                if not accelerator.optimizer_step_was_skipped:
                    scheduler.step()
                # Zero after stepping: zeroing before the forward pass would wipe
                # gradients accumulated over earlier micro-batches.
                optimizer.zero_grad()

            loss_value = loss.item()
            accelerator.log(
                {
                    "train_loss": loss_value,
                    "lr": optimizer.param_groups[0]["lr"],
                    "global_step": global_step,
                    "epoch": epoch,
                },
                step=global_step,
            )

            global_step += 1
            if accelerator.is_main_process:
                progress_bar.update(1)
                progress_bar.set_description(f"loss {loss_value:.4f}")

            if global_step % eval_every == 0:
                model.eval()
                evaluate(model, val_loader, accelerator, epoch, global_step)
                save_checkpoint(
                    model,
                    accelerator,
                    optimizer,
                    scheduler,
                    train_cfg["output_dir"],
                    epoch,
                    global_step,
                )
                model.train()

    save_checkpoint(
        model,
        accelerator,
        optimizer,
        scheduler,
        train_cfg["output_dir"],
        epoch,
        global_step,
    )
    accelerator.end_training()

In [None]:
# Path of the YAML written by the config cell (trainer_config.yaml); an empty
# path makes open() fail.
config_file = "trainer_config.yaml"

In [None]:
with open(config_file, "r") as f:
    # safe_load builds only plain Python types, which is all this config holds.
    config = yaml.safe_load(f)

# If wandb_api is null, wandb.login falls back to WANDB_API_KEY or stored
# credentials.
wandb.login(key=config['wandb_api'])

# Auto* classes cannot be instantiated directly; from_pretrained resolves the
# concrete architecture and loads the weights.
model = AutoModelForCausalLM.from_pretrained(
    config['model_name'], use_cache=config['use_cache']
)

if config['train']['freeze']:
    # Freeze transformer blocks 1..22, leaving block 0, blocks >= 23 and all
    # LayerNorm ("ln_") parameters trainable; parameters whose name lacks
    # "transformer.h" are untouched. Assumes names like "transformer.h.<idx>.…".
    # NOTE(review): the bound 23 looks tuned for a 24-block model; gpt-neo-125M
    # has 12 blocks, so here every block except block 0 is frozen — confirm.
    for name, param in model.named_parameters():
        if "transformer.h" not in name:
            continue
        layer_num = int(name.split(".")[2])
        if "ln_" not in name and 0 < layer_num < 23:
            param.requires_grad = False

In [None]:
# Spawn one training_loop per device; model and config are built above,
# outside the launched function.
launch_args = (model, config)
notebook_launcher(training_loop, args=launch_args)