In [1]:
from data.datasets import CFGDataset, verify_dataloader
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from transformers import TrainerCallback, TrainerState, TrainerControl

from transformers.trainer_utils import EvalPrediction
import math
from torch.utils.data import IterableDataset
import torch
import wandb
import yaml
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)
print(f"Config loaded successfully.")
import torch
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


  from .autonotebook import tqdm as notebook_tqdm


Config loaded successfully.
Using device: cuda


In [3]:

class CFGDatasetForHF(IterableDataset):
    """
    Thin wrapper around your existing CFGDataset so that each yield is
    a dict accepted by HuggingFace Trainer.
    """
    def __init__(self, cfg_dataset, num_batches = None):
        self.cfg = cfg_dataset        # instance of your original class

        if num_batches is not None:
            self.num_batches = num_batches
        
    def __iter__(self):
        i = 0
        for x, _y in self.cfg:        # ignore the pre-shifted target
            yield {
                "input_ids":      x,                 # shape [B, L]
                "labels":         x.clone(),         # same length; GPT-2 shifts
                "attention_mask": torch.ones_like(x)
            }
            i += 1
            if hasattr(self, "num_batches") and i >= self.num_batches:
                break

    def __len__(self):
        return len(self.cfg)
    
    


In [4]:
train_dataset =  CFGDataset(
    data_file="cfg_sentences_train_cfg3b.npy", 
    batch_size = config["data"]["batch_size"],
    seq_len = config["data"]["seq_len"],
    eos_token = config["data"]["eos_token"],
    sos_token = config["data"]["sos_token"],
    ) 

val_dataset =  CFGDataset(
    data_file="cfg_sentences_val_cfg3b.npy", 
    
    batch_size = config["data"]["batch_size"],
    seq_len = config["data"]["seq_len"],
    eos_token = config["data"]["eos_token"],
    sos_token = config["data"]["sos_token"],
    ) 


train_loader = DataLoader(train_dataset, 
                          batch_size = None, 
                          num_workers=config["data"]["NUM_WORKERS"] if device == "cuda" else 0, 
                          pin_memory=True)

val_loader = DataLoader(val_dataset, 
                        batch_size=None, 
                        num_workers=config["data"]["NUM_WORKERS"] if device == "cuda" else 0,
                        pin_memory=True)

def collate_fn(batch):
    # each item is already a [B, L] tensor → stack on 0
    input_ids      = torch.cat([item["input_ids"]      for item in batch], dim=0)
    attention_mask = torch.cat([item["attention_mask"] for item in batch], dim=0)
    labels         = torch.cat([item["labels"]         for item in batch], dim=0)
    return {"input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels}
    
hf_cfg_train = CFGDatasetForHF(train_dataset)

hf_cfg_val = CFGDatasetForHF(val_dataset, 10)

In [5]:
seq_len = config["data"]["seq_len"]
gpt_config = GPT2Config(
    vocab_size      = 5,             # tokens 0–5
    bos_token_id    = 0,
    eos_token_id    = 4,
    pad_token_id    = 5,             # you never pad, but set it anyway
    n_positions     = seq_len,
)
model = GPT2LMHeadModel(gpt_config)
model.num_parameters()  

85453056

In [6]:
def compute_metrics(eval_pred: EvalPrediction):
    """
    Hugging Face will pass (logits, labels) by default.
    We recompute the cross-entropy on the CPU and return both metrics.
    """
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    logits  = torch.tensor(logits)
    labels  = torch.tensor(labels)

    # Shift so the model predicts token t+1 from t   (same rule GPT-2 uses)
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0)   # 0 = pad token id
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )

    ce  = loss.item()
    ppl = math.exp(ce)
    
    wandb.log({"val/ce_loss_token": ce, "val/perplexity_token": ppl})
    
    return {"val_cross_entropy": ce, "val_perplexity": ppl}


In [7]:
class WandbEvalCallback(TrainerCallback):
    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, metrics=None, **kwargs):
        # filter out the weird “_runtime” keys if you like
        print(metrics.items())
        to_log = {
            "val/ce_loss_token": metrics["eval_val_cross_entropy"] ,
            "val/perplexity_token": metrics["eval_val_perplexity"],
            "step":    state.global_step,
         }
        wandb.log(to_log)

class WandbTrainCallback(TrainerCallback):
    def on_log(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        # logs contains training metrics like 'loss', 'learning_rate', etc.
        if logs is not None and logs.get("loss") is not None:
            # filter only training‐step metrics (drop eval metrics or epoch)
           
            train_logs = {
                "train/ce_loss_token": logs["loss"],
                "train/perplexity_token": math.exp(logs["loss"]),
                "learning_rate": logs["learning_rate"],  
                "step":    state.global_step,
            }
            wandb.log(train_logs)


In [None]:
#wandb.finish()
wandb.init(project="Thesis", name = "huggingface_gpt", config=config)

for m in [
    "train/ce_loss_token",
    "train/perplexity_token",
    "learning_rate",
    "val/ce_loss_token",
    "val/perplexity_token",
]:
    wandb.define_metric(m, step_metric="step")
    

args = TrainingArguments(
    output_dir="gpt2_cfg",
    num_train_epochs=1,
    per_device_train_batch_size=1,   # 1 “item” from DataLoader == your mini-batch
    gradient_accumulation_steps=1,
    learning_rate=3e-4,
    logging_steps=100,
    logging_strategy="steps",
    save_steps=50,
    eval_steps=100,
    save_total_limit=2,
    run_name="huggingface_gpt",
    eval_strategy="steps",
    report_to=["wandb"],
    fp16=True,
    dataloader_num_workers=config["data"]["NUM_WORKERS"]
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=hf_cfg_train,      # *Trainer* ignores our DataLoader’s batching,
    eval_dataset=hf_cfg_val,        # so we pass the dataset
    data_collator=collate_fn, # so keep the same collate
    compute_metrics = compute_metrics, # and the same metric
    callbacks=[WandbTrainCallback, WandbEvalCallback]
)

[34m[1mwandb[0m: Currently logged in as: [33mlucasfragara[0m ([33mteamlsfr[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
trainer.train(resume_from_checkpoint="gpt2_cfg/checkpoint-20700")

There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


Step,Training Loss,Validation Loss,Val Cross Entropy,Val Perplexity
21000,0.4364,0.436597,0.438132,1.54981
21500,0.4364,0.436697,0.438194,1.549905
22000,0.4363,0.436497,0.438025,1.549644
22500,0.4364,0.436431,0.437952,1.54953
23000,0.4362,0.436475,0.438052,1.549686
23500,0.4362,0.436444,0.437965,1.549551
24000,0.4364,0.436511,0.438032,1.549654
24500,0.4363,0.436451,0.437993,1.549595
25000,0.4363,0.436547,0.438025,1.549644
25500,0.4362,0.436447,0.437966,1.549552


dict_items([('eval_loss', 0.43659746646881104), ('eval_val_cross_entropy', 0.4381323456764221), ('eval_val_perplexity', 1.5498100045007248), ('eval_runtime', 1.4758), ('eval_samples_per_second', 305.605), ('eval_steps_per_second', 38.624), ('epoch', 0.4694206008583691)])
dict_items([('eval_loss', 0.4366971552371979), ('eval_val_cross_entropy', 0.43819355964660645), ('eval_val_perplexity', 1.5499048774278765), ('eval_runtime', 1.4398), ('eval_samples_per_second', 313.23), ('eval_steps_per_second', 39.588), ('epoch', 0.4805972818311874)])
dict_items([('eval_loss', 0.43649721145629883), ('eval_val_cross_entropy', 0.4380249083042145), ('eval_val_perplexity', 1.5496435059306648), ('eval_runtime', 1.437), ('eval_samples_per_second', 313.857), ('eval_steps_per_second', 39.667), ('epoch', 0.4917739628040057)])
dict_items([('eval_loss', 0.4364311099052429), ('eval_val_cross_entropy', 0.43795159459114075), ('eval_val_perplexity', 1.5495298999757925), ('eval_runtime', 1.4139), ('eval_samples_per_

KeyboardInterrupt: 