In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time, json
import sys
sys.path.append("../")
sys.path.append("../chess_llm_interpretability")
import os
import torch
import numpy as np
from tqdm import tqdm

import logging
from src.utils import logging_utils
from src.utils import env_utils, experiment_utils
from src import functional
import wandb

# Notebook-level logger; when run from a notebook `__name__` is "__main__",
# which is the logger name that shows up in the cell output.
logger = logging.getLogger(__name__)

# Route INFO-and-above records to stdout (so they render in the cell output)
# using the project's shared message/date formats.
logging.basicConfig(
    level=logging.INFO,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

# Record the torch build and its CUDA version alongside the results.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-10-22 12:00:07 __main__ INFO     torch.__version__='2.4.1+cu121', torch.version.cuda='12.1'


In [3]:
from src.models import ModelandTokenizer

# Model selection: exactly one MODEL_KEY assignment should be active.
# The commented-out lines are alternative checkpoints that can be swapped in.
# MODEL_KEY = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_KEY = "meta-llama/Llama-3.1-8B-Instruct"

MODEL_KEY = "meta-llama/Llama-3.2-1B"
# MODEL_KEY = "google/gemma-2-2b"
# MODEL_KEY = "meta-llama/Llama-3.1-8B"

# Load the model together with its tokenizer in full float32 precision.
# `mt` is used below for `mt.tokenizer`, `mt._model` and `mt.device`.
mt = ModelandTokenizer(
    model_key=MODEL_KEY,
    torch_dtype=torch.float32,
)

2024-10-22 12:00:08 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
2024-10-22 12:00:10 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Llama-3.2-1B> | size: 4714.260 MB | dtype: torch.float32 | device: cuda:0


In [4]:
from datasets import load_dataset

# Chess game transcripts (non-streaming load of the 6 GB archive).
pgn_ds = load_dataset("adamkarvonen/chess_games", data_files="lichess_6gb.zip", streaming=False)
# Smaller archive, handy for quick iteration:
# pgn_ds = load_dataset("adamkarvonen/chess_games", data_files="lichess_100mb.zip")

# Hold out 10% of the games. The split shuffles, and the global seed is only
# set in a later cell, so pass an explicit seed here to get the same
# train/test partition on every Restart & Run All.
pgn_ds = pgn_ds["train"].train_test_split(test_size=0.1, seed=42)
pgn_ds

2024-10-22 12:00:10 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-10-22 12:00:10 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-10-22 12:00:10 datasets INFO     PyTorch version 2.4.1 available.


Repo card metadata block was not found. Setting CardData to empty.




DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'WhiteElo', 'BlackElo', 'Result', 'transcript'],
        num_rows: 14842935
    })
    test: Dataset({
        features: ['Unnamed: 0', 'WhiteElo', 'BlackElo', 'Result', 'transcript'],
        num_rows: 1649216
    })
})

In [5]:
import random
class PGNDataset(torch.utils.data.Dataset):
    """Causal-LM dataset over chess game transcripts.

    Each item is a randomly chosen paraphrase prefix followed by the record's
    ``"transcript"`` string, tokenized and padded/truncated to 512 tokens.

    Args:
        pgn_ds: Indexable, sized collection of records that expose a
            ``"transcript"`` field.
        tokenizer: Callable tokenizer returning a mapping of tensors with at
            least ``input_ids`` and ``attention_mask`` (assumed to be a
            Hugging Face style tokenizer with a pad token configured, since
            ``padding="max_length"`` is requested).
    """

    def __init__(self, pgn_ds, tokenizer):
        self.pgn_ds = pgn_ds
        self.tokenizer = tokenizer
        # Prefix strings; one is sampled per item in __getitem__.
        with open(os.path.join(env_utils.DEFAULT_DATA_DIR, "pgn_paraphrases.json")) as f:
            self.pgn_paraphrases = json.load(f)

    def __len__(self):
        """Number of games in the wrapped dataset."""
        return len(self.pgn_ds)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for game ``idx``.

        ``labels`` equals ``input_ids`` on real tokens and ``-100`` on padding,
        so padded positions are ignored by the language-modeling loss.
        """
        item = self.pgn_ds[idx]
        text = random.choice(self.pgn_paraphrases) + item["transcript"]
        inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}  # Remove batch dimension

        # Every sequence is padded to max_length; without masking, the model
        # would be trained to predict pad tokens and the reported loss would be
        # dominated by them for short games.
        labels = inputs["input_ids"].clone()
        labels[inputs["attention_mask"] == 0] = -100
        inputs["labels"] = labels
        return inputs

In [6]:
from torch.utils.data import DataLoader

# Wrap both chess splits so they yield tokenized causal-LM examples;
# the same tokenizer as the model being fine-tuned is used for both.
train_dataset = PGNDataset(pgn_ds["train"], tokenizer=mt.tokenizer)
test_dataset = PGNDataset(pgn_ds["test"], tokenizer=mt.tokenizer)

In [7]:
# General-text corpus trained alongside the chess data.
# NOTE: the active corpus is TinyStories; the `wiki_*` names are kept because
# other cells refer to them. The Wikipedia alternative is left disabled:
# wiki_ds = load_dataset("wikimedia/wikipedia", "20231101.en")
wiki_ds = load_dataset("roneneldan/TinyStories")

# Hold out 10% of the documents. The split shuffles, and the global seed is
# only set in a later cell, so pass an explicit seed for a reproducible split.
wiki_ds = wiki_ds["train"].train_test_split(test_size=0.1, seed=42)
wiki_ds

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1907747
    })
    test: Dataset({
        features: ['text'],
        num_rows: 211972
    })
})

In [8]:
class WikiDataset(torch.utils.data.Dataset):
    """Causal-LM dataset over plain-text documents.

    Each item is the record's ``"text"`` string, tokenized and
    padded/truncated to 512 tokens.

    Args:
        pgn_ds: Indexable, sized collection of records that expose a ``"text"``
            field. (The parameter and attribute keep the ``pgn_ds`` name for
            backward compatibility; no chess data is involved here.)
        tokenizer: Callable tokenizer returning a mapping of tensors with at
            least ``input_ids`` and ``attention_mask`` (assumed to be a
            Hugging Face style tokenizer with a pad token configured, since
            ``padding="max_length"`` is requested).
    """

    def __init__(self, pgn_ds, tokenizer):
        self.pgn_ds = pgn_ds
        self.tokenizer = tokenizer

    def __len__(self):
        """Number of documents in the wrapped dataset."""
        return len(self.pgn_ds)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for document ``idx``.

        ``labels`` equals ``input_ids`` on real tokens and ``-100`` on padding,
        so padded positions are ignored by the language-modeling loss.
        """
        item = self.pgn_ds[idx]
        text = item["text"]
        inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}  # Remove batch dimension

        # Every sequence is padded to max_length; without masking, the model
        # would be trained to predict pad tokens and short documents would
        # report a loss dominated by padding.
        labels = inputs["input_ids"].clone()
        labels[inputs["attention_mask"] == 0] = -100
        inputs["labels"] = labels
        return inputs
    
# Tokenized views over the two text-corpus splits, sharing the model's tokenizer.
train_wiki = WikiDataset(wiki_ds["train"], tokenizer=mt.tokenizer)
test_wiki = WikiDataset(wiki_ds["test"], tokenizer=mt.tokenizer)

In [9]:
# Seed the RNGs before any shuffling/sampling done by the training cells.
experiment_utils.set_seed(42)
# The raw model wrapped by `mt`; put it in training mode for fine-tuning.
model = mt._model
model.train()
# Batches are moved to this device inside the training loop.
device = mt.device

# Training parameters
learning_rate = 5e-5
# Chess batch size; the text-corpus loaders use batch_size // 2.
batch_size = 6

# Checkpoints are written below this directory (created if missing).
model_save_dir = os.path.join(env_utils.DEFAULT_RESULTS_DIR, "chess_model_finetuned")
os.makedirs(model_save_dir, exist_ok=True)
# Intervals are counted in optimizer steps.
wandb_log_interval = 10
checkpoint_interval = 100
# Number of warm-up steps handed to the learning-rate scheduler.
num_warmup_steps = 30
# Upper bound on optimizer steps; clamped to the loader lengths before training.
limit_training_steps = 1000
##############################################################################

2024-10-22 12:00:15 src.utils.experiment_utils INFO     setting all seeds to 42


In [10]:
import shutil
def remove_dir(path):
    """Recursively delete ``path`` when it exists; silently do nothing otherwise."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)

# Left disabled on purpose: enabling this wipes everything under model_save_dir,
# including previously saved checkpoints.
# remove_dir(model_save_dir)
# Delete a local ".wandb" directory (relative to the working directory) if one
# exists -- presumably stale wandb run state; TODO confirm this is the intended path.
remove_dir(".wandb")

In [11]:
from transformers import get_linear_schedule_with_warmup

# dataloaders
# Chess batches use `batch_size`; the text-corpus batches are half that size.
# Only the training loaders are shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

train_wiki_loader = DataLoader(train_wiki, batch_size=batch_size//2, shuffle=True)
test_wiki_loader = DataLoader(test_wiki, batch_size=batch_size//2, shuffle=False)

print(f"{len(train_loader)=} | {len(train_wiki_loader)=}")


# Never plan more optimizer steps than either loader can supply in one pass.
limit_training_steps = min(
    limit_training_steps,
    len(train_loader),
    len(train_wiki_loader)
)

print(f"{limit_training_steps=}")

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The schedule horizon must be the number of steps that will actually run.
# Using len(train_loader) (millions of batches) while training stops after
# `limit_training_steps` would leave the learning rate essentially flat after
# warm-up instead of decaying linearly to zero by the end of the run.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=limit_training_steps,
)

len(train_loader)=2473823 | len(train_wiki_loader)=635916
limit_training_steps=1000


In [12]:
# wandb
# Start a tracked run and record the hyper-parameters that define it.
wandb.init(
    entity="dl-homeworks",
    project="talkative_probes",
    name=f"{MODEL_KEY}_finetune",
    config={
        "model_key": MODEL_KEY,
        "learning_rate": learning_rate,
        "wandb_log_interval": wandb_log_interval,
        "checkpoint_interval": checkpoint_interval,
        "num_warmup_steps": num_warmup_steps,
        "batch_size": batch_size,
        "limit_training_steps": limit_training_steps,
    }
)


def _endless_batches(loader):
    """Yield batches from ``loader`` indefinitely, starting a new pass only when one is exhausted."""
    while True:
        yield from loader


# Create the batch iterators ONCE. Calling `next(iter(loader))` inside the loop
# builds a brand-new DataLoader iterator (and a fresh shuffle of the whole
# dataset) on every step and only ever consumes its first batch.
chess_batches = _endless_batches(train_loader)
wiki_batches = _endless_batches(train_wiki_loader)

for step in tqdm(range(limit_training_steps), desc="Training"):
    optimizer.zero_grad()

    # Forward pass on a chess batch.
    chess_batch = next(chess_batches)
    input_ids = chess_batch["input_ids"].to(device)
    attention_mask = chess_batch["attention_mask"].to(device)
    labels = chess_batch["labels"].to(device)

    pgn_outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    chess_loss = pgn_outputs.loss

    # Forward pass on a text-corpus batch.
    wiki_batch = next(wiki_batches)
    wiki_input_ids = wiki_batch["input_ids"].to(device)
    wiki_attention_mask = wiki_batch["attention_mask"].to(device)
    wiki_labels = wiki_batch["labels"].to(device)

    wiki_outputs = model(input_ids=wiki_input_ids, attention_mask=wiki_attention_mask, labels=wiki_labels)
    wiki_loss = wiki_outputs.loss

    # One optimizer update on the unweighted sum of both objectives.
    loss = chess_loss + wiki_loss

    loss.backward()
    optimizer.step()
    scheduler.step()

    if (step + 1) % wandb_log_interval == 0:
        # Pass the optimizer step explicitly so the wandb x-axis matches
        # training steps rather than the number of log calls.
        wandb.log(
            {
                "loss": loss.item(),
                "chess_loss": chess_loss.item(),
                "wiki_loss": wiki_loss.item(),
                "learning_rate": scheduler.get_last_lr()[0],
            },
            step=step + 1,
        )

    if ((step + 1) % checkpoint_interval == 0) or (step + 1) == limit_training_steps:
        new_checkpoint_name = f"checkpoint-{step + 1}"
        new_checkpoint_path = os.path.join(model_save_dir, new_checkpoint_name)
        # Save first, prune afterwards: a failed save must not leave the run
        # without any checkpoint on disk.
        model.save_pretrained(new_checkpoint_path)

        # Keep only the newest checkpoint. os.listdir() returns entries in
        # arbitrary order, so match older checkpoints by name instead of
        # relying on the position of an entry in the listing.
        for entry in os.listdir(model_save_dir):
            entry_path = os.path.join(model_save_dir, entry)
            is_old_checkpoint = entry.startswith("checkpoint-") and entry != new_checkpoint_name
            if is_old_checkpoint and os.path.isdir(entry_path):
                remove_dir(entry_path)

print("Training completed!")

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33marnab-api[0m ([33mdl-homeworks[0m). Use [1m`wandb login --relogin`[0m to force relogin


Training: 100%|██████████| 1000/1000 [43:40<00:00,  2.62s/it] 

Training completed!



