In [2]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=5

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: CUDA_VISIBLE_DEVICES=5


In [3]:
import tqdm
import os
import sys
import yaml
import time
import random
from omegaconf import OmegaConf
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

sys.path.append(os.path.abspath('/homes/80/anya/Documents/llm_tiny_ideas/super-tiny-lms-outer/super-tiny-lms'))


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# rank = int(os.environ["RANK"])
# print(rank)
# world_size = int(os.environ["WORLD_SIZE"])
# print(world_size)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [4]:
def get_constant_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    last_epoch: int = -1,
):
    def lr_lambda(current_step):
        return min(1, float(current_step) / float(max(1, num_warmup_steps)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

In [5]:
def dataset_loader(dataset: Dataset, per_device_batch_size: int, rank : int, world_size : int, seed: int = 0):
    dataset_length = len(dataset)
    total_batch_size = per_device_batch_size * world_size
    batches_per_dataset = dataset_length // total_batch_size
    i = 0
    epochs = 0
    while True:
        if i >= batches_per_dataset:
            dataset = dataset.shuffle(seed + epochs)
            i = 0
            epochs += 1
        start = i * total_batch_size + rank * per_device_batch_size
        batch = dataset[start:start + per_device_batch_size]
        yield batch

def careful_repeat(data, num_repeats):
    batch_size = data[list(data.keys())[0]].shape[0]
    for k, v in data.items():
        if v.ndim == 1:
            data[k] = v.unsqueeze(1).repeat(1, num_repeats).reshape(batch_size*num_repeats, *v.shape[1:])
        elif v.ndim == 2:
            data[k] = v.unsqueeze(1).repeat(1, num_repeats, 1).reshape(batch_size*num_repeats, *v.shape[1:])
    return data

def get_model_param_stats(model, ref_model):
    model_params = torch.cat([p.view(-1) for p in model.parameters() if p.requires_grad])
    ref_model_params = torch.cat([p.view(-1) for p in ref_model.parameters()])
    assert model_params.shape == ref_model_params.shape, f"{model_params.shape=} {ref_model_params.shape=}"
    return {
        "params_with_grads_mean": model_params.mean().item(),
        "params_with_grads_std": model_params.std().item(),
        "distance_to_ref": torch.nn.functional.mse_loss(model_params, ref_model_params).item(),
    }

In [6]:
def get_next(logitss, temperature=0.0, top_k=None):
    """Get the next token"""
    batch_size, seq_length, vocab_size = logitss.shape
    assert seq_length == 1
    logitss = logitss.squeeze(1)
    if temperature == 0.0:
        token_ids = torch.argmax(logitss, dim=-1)
        probdists = torch.zeros_like(logitss)
        probdists[torch.arange(batch_size), token_ids] = 1.0
    else:
        logitss = logitss / temperature
        if top_k is not None:
            logitss_k, idxs_k = torch.topk(logitss, min(top_k, vocab_size), dim=-1) # (batch_size, top_k)
            probs_k = torch.nn.functional.softmax(logitss_k, dim=-1) # (batch_size, top_k)
            idxs = torch.multinomial(probs_k, num_samples=1).squeeze(1) # (batch_size,)
            token_ids = idxs_k[torch.arange(batch_size), idxs] # (batch_size)
            probdists = torch.zeros_like(logitss)
            # next_probdist_s[torch.arange(batch_size), idx_s_k.squeeze(1)] = probs_k.squeeze(1)
            probdists.scatter_(1, idxs_k, probs_k)
            if top_k == 1:
                token_ids2 = torch.argmax(logitss, dim=-1)
                probdists2 = torch.zeros_like(logitss)
                probdists2[torch.arange(batch_size), token_ids2] = 1.0
                assert (token_ids == token_ids2).all(), f"{token_ids=}, {token_ids2=}"
                assert torch.allclose(probdists, probdists2), f"{probdists2=}, {probdists2=}"
        else:
            probdists = torch.nn.functional.softmax(logitss, dim=-1) # (batch_size, vocab_size)
            token_ids = torch.multinomial(probdists, num_samples=1).squeeze(1) # (batch_size)

    token_ids = token_ids.unsqueeze(1)
    probdists = probdists.unsqueeze(1)
    return token_ids, probdists

In [7]:
def single_step(model, next_input, attention_mask, position_ids, past_key_values, as_full_distribution=False):
    batch_size = next_input.shape[0]
    prev_seq_length = past_key_values[0][0].shape[2]
    assert position_ids.shape == (batch_size, 1), f"{position_ids.shape=}"
    assert attention_mask.shape == (batch_size, prev_seq_length+1), f"{attention_mask.shape=}, {prev_seq_length=}"

    if as_full_distribution:
        vocab_size = model.model.config.vocab_size
        assert next_input.shape == (batch_size, 1, vocab_size)
        all_embeds = model.model.embed_tokens.weight
        hidden_dim = all_embeds.shape[1]
        assert all_embeds.shape[0] == vocab_size
        inputs_embeds = torch.matmul(next_input, all_embeds)
        assert inputs_embeds.shape == (batch_size, 1, hidden_dim)
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
    else:
        assert next_input.shape == (batch_size, 1)
        outputs = model(input_ids=next_input, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
    logits = outputs.logits
    past_key_values = outputs.past_key_values
    return logits, past_key_values



In [16]:
# @torch.no_grad()
def batch_generate_rnn(
        model,
        ref_model,
        questions_inputs,
        answers_inputs,
        max_steps=30,
        step_for_answer=20,
        temperature=1.0,
        top_k=None,
        as_full_distribution=False,
        dot_by_dot=False,
        dot_by_dot_id=None,
        inject_answer_prompt=False,
        answer_prompt_ids=None,
    ):
    metrics = {}
    assert not (as_full_distribution and dot_by_dot), f"{as_full_distribution=}, {dot_by_dot=}"
    device = model.device
    assert device.type == "cuda", f"{model.device=}"
    assert ref_model.device == device, f"{ref_model.device=}, {device=}"
    model.eval()
    ref_model.eval()
    batch_size = questions_inputs["input_ids"].shape[0]
    vocab_size = model.config.vocab_size
    prompt_length = questions_inputs["input_ids"].shape[1]

    #### PROMPT FORWARD PASS
    position_ids = questions_inputs["attention_mask"].cumsum(dim=1) - 1
    outputs = model(**questions_inputs, position_ids=position_ids)
    with torch.no_grad():
        ref_outputs = ref_model(**questions_inputs)
    prompt_attention_mask = questions_inputs["attention_mask"]
    prompt_end_position_ids = position_ids[:, -1:] + 1
    past_key_values = outputs.past_key_values
    ref_past_keys_values = ref_outputs.past_key_values
    logits = outputs.logits[:, -1:]
    ref_logits = ref_outputs.logits[:, -1:]

    def make_attention_mask(t, new_seq_length=1):
        return torch.cat([prompt_attention_mask, torch.ones((batch_size, t+new_seq_length), device=device)], dim=1)
    def make_position_ids(t, new_seq_length=1):
        if new_seq_length == 1:
            return prompt_end_position_ids + t
        else:
            return prompt_end_position_ids + t + torch.arange(new_seq_length).unsqueeze(0).to(device)

    all_gen_logits = torch.zeros((batch_size, 0, vocab_size), device=device)
    all_ref_logits = torch.zeros((batch_size, 0, vocab_size), device=device)
    generations = torch.zeros((batch_size, 0), device=device, dtype=torch.int)
    generations_without_injection = torch.zeros((batch_size, 0), device=device, dtype=torch.int)

    ### REASONING FORWARD PASSES
    for t in range(step_for_answer):
        next_token_ids, next_probdists = get_next(logits, temperature=temperature, top_k=top_k)
        all_gen_logits = torch.cat((all_gen_logits, logits), dim=1)
        all_ref_logits = torch.cat((all_ref_logits, ref_logits), dim=1)
        generations_without_injection = torch.cat((generations_without_injection, next_token_ids), dim=1)
        generations = torch.cat((generations, next_token_ids), dim=1)
        # forward pass of next token
        if as_full_distribution:
            next_input = next_probdists
        elif dot_by_dot:
            next_input = torch.full((batch_size, 1), dot_by_dot_id, dtype=torch.long, device=device)
        else:
            next_input = next_token_ids
        logits, past_key_values = single_step(model,
            next_input=next_input,
            attention_mask=make_attention_mask(t),
            position_ids=make_position_ids(t),
            past_key_values=past_key_values,
            as_full_distribution=as_full_distribution)
        with torch.no_grad():
            ref_logits, ref_past_keys_values = single_step(ref_model,
                next_input=next_input,
                attention_mask=make_attention_mask(t),
                position_ids=make_position_ids(t),
                past_key_values=ref_past_keys_values,
                as_full_distribution=as_full_distribution)
            
    ### INJECT ANSWER PROMPT FORWARD PASS
    t = step_for_answer
    if inject_answer_prompt:
        repeated_answer_prompt_ids = torch.tensor(answer_prompt_ids).unsqueeze(0).repeat(batch_size, 1).to(next_token_ids.dtype).to(device)
        generations = torch.cat((generations, repeated_answer_prompt_ids), dim=1)
        outputs = model(
            input_ids=repeated_answer_prompt_ids,
            attention_mask=make_attention_mask(step_for_answer, new_seq_length=len(answer_prompt_ids)),
            position_ids=make_position_ids(step_for_answer, new_seq_length=len(answer_prompt_ids)),
            past_key_values=past_key_values)
        with torch.no_grad():
            ref_outputs = ref_model(
                input_ids=repeated_answer_prompt_ids,
                attention_mask=make_attention_mask(t, new_seq_length=len(answer_prompt_ids)),
                position_ids=make_position_ids(t, new_seq_length=len(answer_prompt_ids)),
                past_key_values=ref_past_keys_values)
        logits = outputs.logits
        ref_logits = ref_outputs.logits
        past_key_values = outputs.past_key_values
        ref_past_keys_values = ref_outputs.past_key_values
        t += len(answer_prompt_ids)

    ### ANSWER FORWARD PASS
    answer_length = answers_inputs["input_ids"].shape[1]
    answer_logits = model(
        input_ids=answers_inputs["input_ids"][:, :-1],
        attention_mask=make_attention_mask(t, new_seq_length=answer_length-1),
        position_ids=make_position_ids(t, new_seq_length=answer_length-1),
        past_key_values=past_key_values).logits
    with torch.no_grad():
        ref_answer_logits = ref_model(
            input_ids=answers_inputs["input_ids"][:, :-1],
            attention_mask=make_attention_mask(t+1, new_seq_length=answer_length-1),
            position_ids=make_position_ids(t+1, new_seq_length=answer_length-1),
            past_key_values=ref_past_keys_values).logits
    answer_logits = torch.cat((all_gen_logits[:, -1:], answer_logits), dim=1)
    ref_answer_logits = torch.cat((all_ref_logits[:, -1:], ref_answer_logits), dim=1)
    per_token_answer_logps = torch.gather(answer_logits, 2, answers_inputs["input_ids"].to(torch.long).unsqueeze(-1)).squeeze(-1)
    per_token_ref_answer_logps = torch.gather(ref_answer_logits, 2, answers_inputs["input_ids"].to(torch.long).unsqueeze(-1)).squeeze(-1)
    assert per_token_answer_logps.shape == answers_inputs["attention_mask"].shape, f"{per_token_answer_logps.shape=}, {answers_inputs['attention_mask'].shape=}"
    answer_logps = (per_token_answer_logps * answers_inputs["attention_mask"]).sum(dim=-1)
    ref_answer_logps = (per_token_ref_answer_logps * answers_inputs["attention_mask"]).sum(dim=-1)

    answer_perplexity = torch.exp(-answer_logps / answers_inputs["attention_mask"].sum(dim=-1))
    answer_perplexity_ref = torch.exp(-ref_answer_logps / answers_inputs["attention_mask"].sum(dim=-1))
    metrics["answer_logps"] = answer_logps.mean().item()
    metrics["answer_logps_ref"] = ref_answer_logps.mean().item()
    metrics["answer_logps_diff"] = (answer_logps - ref_answer_logps).mean().item()
    metrics["answer_perplexity"] = answer_perplexity.mean().item()
    metrics["answer_perplexity_ref"] = answer_perplexity_ref.mean().item()
    metrics["answer_perplexity_diff"] = (answer_perplexity - answer_perplexity_ref).mean().item()

    ### CONTINUE FORWARD PASS STEPS
    for t in range(t, max_steps):
        next_token_ids, next_probdists = get_next(logits[:, -1:], temperature=temperature, top_k=top_k)
        all_gen_logits = torch.cat((all_gen_logits, logits[:, -1:]), dim=1)
        all_ref_logits = torch.cat((all_ref_logits, ref_logits[:, -1:]), dim=1)
        generations_without_injection = torch.cat((generations_without_injection, next_token_ids), dim=1)
        generations = torch.cat((generations, next_token_ids), dim=1)
        # forward pass of next token
        if as_full_distribution:
            next_input = next_probdists
        elif dot_by_dot:
            next_input = torch.full((batch_size, 1), dot_by_dot_id, dtype=torch.long, device=device)
        else:
            next_input = next_token_ids
        logits, past_key_values = single_step(model,
            next_input=next_input,
            attention_mask=make_attention_mask(t),
            position_ids=make_position_ids(t),
            past_key_values=past_key_values,
            as_full_distribution=as_full_distribution)
        with torch.no_grad():
            ref_logits, ref_past_keys_values = single_step(ref_model,
                next_input=next_input,
                attention_mask=make_attention_mask(t),
                position_ids=make_position_ids(t),
                past_key_values=ref_past_keys_values,
                as_full_distribution=as_full_distribution)
            
    next_token_ids, next_probdists = get_next(logits, temperature=temperature, top_k=top_k)
    all_gen_logits = torch.cat((all_gen_logits, logits), dim=1)
    all_ref_logits = torch.cat((all_ref_logits, ref_logits), dim=1)
    generations_without_injection = torch.cat((generations_without_injection, next_token_ids), dim=1)
    generations = torch.cat((generations, next_token_ids), dim=1)

    gen_per_token_logps = torch.gather(all_gen_logits, 2, generations_without_injection.to(torch.long).unsqueeze(-1)).squeeze(-1)
    ref_per_token_logps = torch.gather(all_ref_logits, 2, generations_without_injection.to(torch.long).unsqueeze(-1)).squeeze(-1)
    x = (gen_per_token_logps, ref_per_token_logps)

    pd = torch.nn.functional.softmax(all_gen_logits, dim=-1)
    entropy = torch.logsumexp(all_gen_logits, dim=-1) - torch.sum(pd * all_gen_logits, dim=-1)
    entropy = entropy.mean(-1)
    metrics["logps"] = gen_per_token_logps.mean().item()
    metrics["logps_ref"] = ref_per_token_logps.mean().item()
    metrics["logps_diff"] = (gen_per_token_logps - ref_per_token_logps).mean().item()
    metrics["entropy"] = entropy.mean().item()
    metrics["entropy_std"] = entropy.std().item()

    x = (answer_logps, ref_answer_logps, gen_per_token_logps, ref_per_token_logps, entropy)
    
    assert x[0].requires_grad == True, f"{x[0].requires_grad=}"
    assert x[1].requires_grad == False, f"{x[1].requires_grad=}"
    assert x[2].requires_grad == True, f"{x[2].requires_grad=}"
    assert x[3].requires_grad == False, f"{x[3].requires_grad=}"
    assert x[4].requires_grad == True, f"{x[4].requires_grad=}"
    return x, generations, metrics, past_key_values

In [None]:
class Trainer:
    def __init__(
        self,
        cfg,
    ) -> None:
        print(f"\nTrainer::-----------------------------------")
        self.world_size = 1
        self.rank = 0
        assert torch.cuda.is_available(), "CUDA must be available for training"
        assert torch.cuda.device_count() == 1
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # self.cfg = cfg
        self.use_wandb = cfg.use_wandb

        ### training
        self.max_iters = cfg.max_iters
        self.total_batch_size = cfg.total_batch_size
        self.per_device_batch_size = cfg.per_device_batch_size
        assert self.total_batch_size % (self.per_device_batch_size * self.world_size) == 0, f"{self.total_batch_size=} {self.per_device_batch_size=}, {self.world_size=}"
        self.gradient_accumulation_steps = self.total_batch_size // (self.per_device_batch_size * self.world_size)
        assert self.per_device_batch_size * self.world_size * self.gradient_accumulation_steps == self.total_batch_size, f"{self.per_device_batch_size=} {self.world_size=} {self.gradient_accumulation_steps=} {self.total_batch_size=}"
        self.generations_per_prompt = cfg.generation.generations_per_prompt
        assert self.per_device_batch_size % self.generations_per_prompt == 0, f"{self.per_device_batch_size=} {self.generations_per_prompt=}"
        self.per_device_prompt_batch_size = self.per_device_batch_size // self.generations_per_prompt
        print(f"---TRAINING CONFIG:")
        print(f"Max iters: {self.max_iters}")
        print(f"Total batch size: {self.total_batch_size}")
        print(f"Per device batch size: {self.per_device_batch_size}")
        print(f"Gradient accumulation steps: {self.gradient_accumulation_steps}")
        print(f"Generations per prompt: {self.generations_per_prompt}")
        print(f"Per device prompt batch size: {self.per_device_prompt_batch_size}")
        print(f"-----------------------------------\n")

        self.model = AutoModelForCausalLM.from_pretrained(cfg.base_model).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.base_model)
        self.ref_model = AutoModelForCausalLM.from_pretrained(cfg.base_model).to(device)
        for param in self.ref_model.parameters():
            param.requires_grad = False


        # optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(cfg.lr),
            betas=cfg.get('betas', (0.9, 0.999)),
            weight_decay=cfg.get('weight_decay', 1e-2))

        # dataset
        if cfg.dataset == "gsm8k":
            train_dataset = Dataset.load_from_disk(f"../data/my_data/gsm8k/train")
            if cfg.dataset_size is not None:
                train_dataset = train_dataset.select(range(cfg.dataset_size))
            val_dataset = Dataset.load_from_disk(f"../data/my_data/gsm8k/test")
        self.train_loader = dataset_loader(train_dataset, self.per_device_prompt_batch_size, 0, 1, seed=cfg.get('seed', 0))
        self.val_loader = dataset_loader(val_dataset, self.per_device_prompt_batch_size, 0, 1, seed=cfg.get('seed', 0))
        print(f"---DATASET CONFIG:")
        print(f"Dataset: {cfg.dataset}")
        if cfg.dataset_size is not None:
            print(f"Careful! Reduced dataset size for testing: {len(train_dataset)}")
        print(f"-----------------------------------\n")

        # generation
        self.loss_type = cfg.loss.loss_type
        if self.loss_type in ["pg", "logp"]:
            self.temperature = cfg.generation.temperature
            self.top_k = cfg.generation.top_k
            self.max_steps = cfg.generation.max_steps
            self.step_for_answer = cfg.generation.step_for_answer
            self.inject_answer_prompt = cfg.generation.inject_answer_prompt
            self.as_full_distribution = cfg.generation.as_full_distribution
            self.dot_by_dot = cfg.generation.dot_by_dot
            self.answer_prompt_text = "Answer:"
            self.answer_prompt_ids = self.tokenizer.encode(self.answer_prompt_text)
            assert len(self.tokenizer.encode("....")) == 1
            self.dot_by_dot_id = self.tokenizer.encode("....")[0]
            print(f"---GENERATION CONFIG:")
            print(f"Temperature: {self.temperature}")
            print(f"Top k: {self.top_k}")
            print(f"Max length: {self.max_steps}")
            print(f"Step for answer: {self.step_for_answer}")
            print(f"Inject answer prompt: {self.inject_answer_prompt}")
            print(f"As full distribution: {self.as_full_distribution}")
            print(f"Answer prompt text: {self.answer_prompt_text}, ids: {self.answer_prompt_ids}")
            print(f"-----------------------------------\n")

        # loss
        if self.loss_type == "sft":
            self.sft_include_cot = cfg.loss.sft_include_cot
            self.sft_predict_cot = cfg.loss.sft_predict_cot
        elif self.loss_type == "pg":
            self.pg_normalization_type = None if self.generations_per_prompt == 1 else cfg.loss.pg_normalization_type
        self.entropy_coef = cfg.loss.entropy_coef
        self.kl_loss_coef = cfg.loss.kl_loss_coef
        print(f"---LOSS CONFIG:")
        print(f"Loss type: {self.loss_type}")
        if self.loss_type == "sft":
            print(f"SFT include cot: {self.sft_include_cot}")
            print(f"SFT predict cot: {self.sft_predict_cot}")
        elif self.loss_type == "pg":
            print(f"PG normalization type: {self.pg_normalization_type}")
        print(f"Entropy coef: {self.entropy_coef}")
        print(f"KL loss coef: {self.kl_loss_coef}")
        print(f"-----------------------------------\n")

        self.ctx = self._setup_ctx()

    def _setup_ctx(self):
        """Get the context manager"""
        dtype = (
            torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            else torch.float16
        )
        self._setup_scaler(dtype)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
        return ctx

    def _setup_scaler(self, dtype=torch.float16):
        """Setup the scaler"""
        self.scaler = torch.amp.GradScaler(enabled=dtype == torch.float16, device='cuda')

    def apply_update(self):
        grad_clip = 1.0
        # once gradients are accumulated, step 
        if grad_clip > 0:
            # Unscale the gradients of the optimizer's assigned params in-place
            self.scaler.unscale_(self.optimizer)
            # Clip the gradients with normalization
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
        # Perform a single optimization step
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()  # Reset gradients after update


    def run_training_loop(self, num_iters=None):
        """Run the training loop"""
        start_time = time.time()
        num_iters = self.total_iters if num_iters is None else num_iters
        for i in tqdm.tqdm(range(num_iters), desc="Training"):
            # GRADIENT ACCUMULATION
            for j in range(self.gradient_accumulation_steps):
            
                # GENERATE ROLLOUTS
                start_time = time.time()
                self.model.eval()
                # with torch.no_grad():
                # get questions (and answers)
                dataset_batch = next(self.train_loader)
                questions_text = dataset_batch["question"]
                # cot_text = dataset_batch["reasoning"]
                answers_text = dataset_batch["answer"]

                if self.loss_type == "sft":
                    raise NotImplementedError
                
                else:
                    questions_inputs = self.tokenizer(questions_text, return_tensors="pt", padding=True, padding_side="left")
                    answers_inputs = self.tokenizer(answers_text, return_tensors="pt", padding=True)
                    questions_inputs = {k: v.to(self.model.device) for k, v in questions_inputs.items()}
                    answers_inputs = {k: v.to(self.model.device) for k, v in answers_inputs.items()}

                    # repeat for generations_per_prompt
                    questions_inputs = careful_repeat(questions_inputs, self.generations_per_prompt)
                    answers_inputs = careful_repeat(answers_inputs, self.generations_per_prompt)

                    # generate
                    x, generations, generation_metrics, _ = batch_generate_rnn(
                        model=self.model,
                        ref_model=self.ref_model,
                        questions_inputs=questions_inputs,
                        answers_inputs=answers_inputs,
                        max_steps=self.max_steps,
                        step_for_answer=self.step_for_answer,
                        temperature=self.temperature,
                        top_k=self.top_k,
                        as_full_distribution=self.as_full_distribution,
                        dot_by_dot=self.dot_by_dot,
                        dot_by_dot_id=self.dot_by_dot_id,
                        inject_answer_prompt=self.inject_answer_prompt,
                        answer_prompt_ids=self.answer_prompt_ids,
                    )
                    generation_metrics = {f"gen/{k}": v for k, v in generation_metrics.items()}

                    ### rewards
                    rewards, normalized_rewards, reward_metrics = self.get_rewards(generations, answers_text)
                    reward_metrics = {k if k == "REWARD" else f"reward/{k}": v for k, v in reward_metrics.items()}
                    
                # COMPUTE LOSS
                with self.ctx: 
                    loss, loss_metrics = self.get_loss(x, normalized_rewards)
                    self.scaler.scale(loss).backward()
                loss_metrics = {f"loss/{k}": v for k, v in loss_metrics.items()}

                # UPDATE METRICS
                if j == 0:
                    metrics_s = {**generation_metrics, **loss_metrics, **reward_metrics}
                else:
                    metrics = {**generation_metrics, **loss_metrics, **reward_metrics}
                    metrics_s = {k: v + metrics[k] for k, v in metrics_s.items()}

            # UPDATE MODEL
            self.apply_update()

            # LOG
            metrics_s = {k: v / self.gradient_accumulation_steps for k, v in metrics_s.items()}
            param_metrics = get_model_param_stats(self.model, self.ref_model)
            log_dict = {"iter": i, "lr": self.optimizer.param_groups[0]["lr"]}
            log_dict.update(metrics_s)
            log_dict.update({f"params/{k}": v for k, v in param_metrics.items()})
            if self.use_wandb:
                wandb.log(log_dict)
            if i % 10 == 0 or i == self.max_iters - 1:
                print(f"({log_dict})\n\niter {i}: REWARD={log_dict['REWARD']:.2f}")
                num_to_print = 3
                decoded = [self.tokenizer.decode(generations[k]) for k in range(num_to_print)]
                for k in range(num_to_print):
                    print(f"  EXAMPLE {k}: (REWARD={rewards[k].item()}):\nQUESTION: {questions_text[k//self.generations_per_prompt]}\nGENERATION: {decoded[k]}\nANSWER: {answers_text[k//self.generations_per_prompt]}\n", "-"*50, "\n")

            
        print(f"Training time: {time.time()-start_time:.1f}s")
            
        # save the final model
        self._save_model(i, final_model=True)

    def train(self, seed=42, num_iters=None):
        """Train the model"""
        # set seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        self.run_training_loop(num_iters=num_iters)

    def get_rewards(self, generations, answers_text):
        """Get rewards"""
        ### check for answer
        contains_answer_prompt = torch.zeros((self.per_device_prompt_batch_size, self.generations_per_prompt), device=self.model.device)
        contains_answer = torch.zeros((self.per_device_prompt_batch_size, self.generations_per_prompt), device=self.model.device)
        generations = generations.reshape(self.per_device_prompt_batch_size, self.generations_per_prompt, -1)
        for i in range(self.per_device_prompt_batch_size):
            answer_text = answers_text[i]
            decoded_generations = self.tokenizer.batch_decode(generations[i])
            for j in range(self.generations_per_prompt):
                decoded = decoded_generations[j]
                contains_answer_prompt_ij = self.answer_prompt_text in decoded
                if contains_answer_prompt_ij:
                    contains_answer_ij = answer_text in decoded.split(self.answer_prompt_text)[1]
                else:
                    contains_answer_ij = False
                contains_answer_prompt[i, j] = contains_answer_prompt_ij
                contains_answer[i, j] = contains_answer_ij
        
        ### caluclate rewards
        rewards = contains_answer.float()
        if not self.inject_answer_prompt:
            rewards += self.answer_prompt_coef * contains_answer_prompt.float()
        if self.pg_normalization_type == "grpo":
            normalized_rewards = (rewards - rewards.mean(1, keepdim=True)) / (rewards.std(1, keepdim=True) + 1e-6)
        elif self.pg_normalization_type == "rloo":
            group_sum = rewards.sum(1, keepdim=True)
            normalized_rewards = (group_sum - rewards) / (self.generations_per_prompt - 1)
        else:
            normalized_rewards = rewards

        metrics = {
            "REWARD": rewards.mean().item(),
            "reward/reward_std": rewards.std().item(),
            "reward/reward_std_within_q": rewards.std(1).mean().item(),
            "reward/reward_std_between_q": rewards.mean(1).std().item(),
        }

        rewards = rewards.reshape(self.per_device_batch_size)
        normalized_rewards = normalized_rewards.reshape(self.per_device_batch_size)
        return rewards, normalized_rewards, metrics
    
    def get_loss(self, x, rewards):
        metrics = {}
        pg_loss, logp_loss = 0, 0
        if self.loss_type == "sft":
            raise NotImplementedError
        else:
            answer_logps, _, gen_per_token_logps, ref_per_token_logps, entropy = x
            if self.loss_type == "pg":
                pg_loss = - torch.exp(gen_per_token_logps - gen_per_token_logps.detach()).mean(-1) * rewards
            elif self.loss_type == "logp":
                logp_loss = - answer_logps
            else:
                raise ValueError(f"{self.loss_type=}")
            kl = (torch.exp(ref_per_token_logps - gen_per_token_logps) - (ref_per_token_logps - gen_per_token_logps) - 1).mean(-1)
            loss = pg_loss + logp_loss + self.kl_loss_coef * kl - self.entropy_coef * entropy
            loss = loss.mean()

            metrics["loss"] = loss.item()
            metrics["pg_loss"] = pg_loss.mean().item() if self.loss_type == "pg" else 0
            metrics["logp_loss"] = logp_loss.mean().item() if self.loss_type == "logp" else 0
            metrics["kl"] = kl.mean().item()
            metrics["entropy"] = entropy.mean().item()
        return loss / self.gradient_accumulation_steps, metrics


In [10]:
config_file = "../args/full_test.yaml"

### Load config
with open(config_file) as f:
        config_dict = yaml.safe_load(f)
print("Config:", config_dict)
config = OmegaConf.create(config_dict)

Config: {'dataset': 'gsm8k', 'base_model': 'Qwen/Qwen2.5-0.5b', 'use_wandb': False, 'run_name_prefix': 'RUNS0', 'wandb_project': 'coconut', 'max_iters': 3000, 'total_batch_size': 256, 'per_device_batch_size': 2, 'lr': '1e-6', 'dataset_size': 300, 'seed': 0, 'loss': {'loss_type': 'pg', 'sft_include_cot': True, 'sft_predict_cot': True, 'pg_normalization_type': 'grpo', 'answer_prompt_coef': 0.1, 'entropy_coef': 0.001, 'kl_loss_coef': 0.001}, 'generation': {'generations_per_prompt': 2, 'temperature': 1.0, 'top_k': None, 'max_steps': 30, 'step_for_answer': 20, 'inject_answer_prompt': False, 'as_full_distribution': False, 'dot_by_dot': False}}


In [11]:
trainer = Trainer(config)


Trainer::-----------------------------------
---TRAINING CONFIG:
Max iters: 3000
Total batch size: 256
Per device batch size: 2
Gradient accumulation steps: 128
Generations per prompt: 2
Per device prompt batch size: 1
-----------------------------------



FileNotFoundError: No such files: '/scratch/local/homes/80/anya/Documents/llm_tiny_ideas/coconut-outer/coconut/notebooks/../data/my_data/gsm8k/train.bin/dataset_info.json', nor '/scratch/local/homes/80/anya/Documents/llm_tiny_ideas/coconut-outer/coconut/notebooks/../data/my_data/gsm8k/train.bin/state.json' found. Expected to load a `Dataset` object but provided path is not a `Dataset`.

In [None]:

trainer.train(seed=42, num_iters=20)

Training:   0%|          | 0/20 [00:00<?, ?it/s]

Training:   5%|▌         | 1/20 [00:03<01:11,  3.74s/it]

({'iter': 0, 'lr': 1e-06, 'gen/answer_logps': 27.815303802490234, 'gen/answer_logps_ref': 27.700419425964355, 'gen/answer_logps_diff': 0.1148838996887207, 'gen/answer_perplexity': 3.281034310020914e-06, 'gen/answer_perplexity_ref': 2.421303520350193e-06, 'gen/answer_perplexity_diff': 8.597305622970453e-07, 'gen/logps': 22.891145706176758, 'gen/logps_ref': 21.601008415222168, 'gen/logps_diff': 1.2901363372802734, 'gen/entropy': 0.9073592126369476, 'gen/entropy_std': 0.29626742005348206, 'loss/loss': 0.00010825303616002202, 'loss/pg_loss': 0.0, 'loss/logp_loss': 0.0, 'loss/kl': 1.0156123042106628, 'loss/entropy': 0.9073592126369476, 'REWARD': 0.0, 'reward/reward/reward_std': 0.0, 'reward/reward/reward_std_within_q': 0.0, 'reward/reward/reward_std_between_q': 0.0, 'params/params_with_grads_mean': 0.00015669445565436035, 'params/params_with_grads_std': 0.03906893730163574, 'params/distance_to_ref': 6.772513322719775e-13})

iter 0: REWARD=0.00
  EXAMPLE 0: (REWARD=0.0):
QUESTION: Natalia so

Training:  55%|█████▌    | 11/20 [00:42<00:34,  3.86s/it]

({'iter': 10, 'lr': 1e-06, 'gen/answer_logps': 27.279375076293945, 'gen/answer_logps_ref': 27.962377548217773, 'gen/answer_logps_diff': -0.6830031871795654, 'gen/answer_perplexity': 4.188140906080662e-06, 'gen/answer_perplexity_ref': 3.7540717130468693e-06, 'gen/answer_perplexity_diff': 4.340692498772114e-07, 'gen/logps': 20.480884552001953, 'gen/logps_ref': 20.04889678955078, 'gen/logps_diff': 0.43198806047439575, 'gen/entropy': 1.7025251388549805, 'gen/entropy_std': 1.1711174845695496, 'loss/loss': -0.0006596555758733302, 'loss/pg_loss': 0.0, 'loss/logp_loss': 0.0, 'loss/kl': 1.0428695976734161, 'loss/entropy': 1.7025251388549805, 'REWARD': 0.0, 'reward/reward/reward_std': 0.0, 'reward/reward/reward_std_within_q': 0.0, 'reward/reward/reward_std_between_q': 0.0, 'params/params_with_grads_mean': 0.00015669455751776695, 'params/params_with_grads_std': 0.03906894102692604, 'params/distance_to_ref': 1.3103124488211826e-11})

iter 10: REWARD=0.00
  EXAMPLE 0: (REWARD=0.0):
QUESTION: Natali

Training: 100%|██████████| 20/20 [01:18<00:00,  3.91s/it]

({'iter': 19, 'lr': 1e-06, 'gen/answer_logps': 27.143101692199707, 'gen/answer_logps_ref': 27.44046974182129, 'gen/answer_logps_diff': -0.29736852645874023, 'gen/answer_perplexity': 2.5355892603329266e-06, 'gen/answer_perplexity_ref': 3.5012234320674906e-06, 'gen/answer_perplexity_diff': -9.656342285779829e-07, 'gen/logps': 20.53018093109131, 'gen/logps_ref': 19.997366905212402, 'gen/logps_diff': 0.532814159989357, 'gen/entropy': 1.6038917303085327, 'gen/entropy_std': 0.5319118201732635, 'loss/loss': -0.0011174731771461666, 'loss/pg_loss': 0.0, 'loss/logp_loss': 0.0, 'loss/kl': 0.48641857504844666, 'loss/entropy': 1.6038917303085327, 'REWARD': 0.0, 'reward/reward/reward_std': 0.0, 'reward/reward/reward_std_within_q': 0.0, 'reward/reward/reward_std_between_q': 0.0, 'params/params_with_grads_mean': 0.00015669441199861467, 'params/params_with_grads_std': 0.03906893730163574, 'params/distance_to_ref': 2.054701027376371e-11})

iter 19: REWARD=0.00
  EXAMPLE 0: (REWARD=0.0):
QUESTION: Natali




AttributeError: 'Trainer' object has no attribute '_save_model'

# Testing

In [14]:
dummy_model = AutoModelForCausalLM.from_pretrained(config.base_model).to(device)
dummy_ref_model = AutoModelForCausalLM.from_pretrained(config.base_model).to(device)
dummy_tokenizer = AutoTokenizer.from_pretrained(config.base_model)
dummy_dataset = Dataset.load_from_disk(f"../data/my_data/gsm8k/test")
dummy_loader = dataset_loader(dummy_dataset, per_device_batch_size=2, rank=0, world_size=1, seed=0)
dummy_batch = next(dummy_loader)
dummy_questions_text = dummy_batch["question"]
dummy_answers_text = dummy_batch["answer"]
dummy_questions_inputs = dummy_tokenizer(dummy_questions_text, return_tensors="pt", padding=True, padding_side="left")
dummy_answers_inputs = dummy_tokenizer(dummy_answers_text, return_tensors="pt", padding=True)
dummy_questions_inputs = {k: v.to(device) for k, v in dummy_questions_inputs.items()}
dummy_answers_inputs = {k: v.to(device) for k, v in dummy_answers_inputs.items()}
dummy_ref_model.eval()
dummy_model.eval()

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((

In [19]:
x, generations, metrics, past_key_values = batch_generate_rnn(
    model=dummy_model,
    ref_model=dummy_ref_model,
    questions_inputs=dummy_questions_inputs,
    answers_inputs=dummy_answers_inputs,
    max_steps=30,
    step_for_answer=20,
    temperature=0.0,
    top_k=None,
    as_full_distribution=False,
    dot_by_dot=False,
    dot_by_dot_id=dummy_tokenizer.encode("....")[0],
    inject_answer_prompt=False,
    answer_prompt_ids=dummy_tokenizer.encode("...Answer:"),
)
print(x[0].shape, generations.shape, metrics)
dummy_tokenizer.batch_decode(generations)

torch.Size([2]) torch.Size([2, 31]) {'answer_logps': 13.639263153076172, 'answer_logps_ref': 13.970451354980469, 'answer_logps_diff': -0.3311886787414551, 'answer_perplexity': 0.002386221196502447, 'answer_perplexity_ref': 0.0011142903240397573, 'answer_perplexity_diff': 0.0012719309888780117, 'logps': 23.325300216674805, 'logps_ref': 20.66714096069336, 'logps_diff': 2.6581602096557617, 'entropy': 0.939110279083252, 'entropy_std': 0.0955267995595932}


[" First, we need to calculate the total number of eggs laid by the ducks in a day. Since Janet's ducks lay 16 eggs per day,",
 ' To find the total number of bolts needed, we need to calculate the number of bolts required for blue and white fibers separately and then add them together.\n\n1']

In [18]:
gen_out = dummy_model.generate(
    input_ids=dummy_questions_inputs["input_ids"],
    attention_mask=dummy_questions_inputs["attention_mask"],
    max_new_tokens=30,
    temperature=1.0,
    top_k=None,
    do_sample=False,
    return_dict_in_generate=True,
)
prompt_length = dummy_questions_inputs["input_ids"].shape[1]
generations = gen_out.sequences if isinstance(gen_out, dict) else gen_out
generations = generations[:, prompt_length:]
print(dummy_tokenizer.batch_decode(generations))


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


[" First, we need to calculate the total number of eggs laid by the ducks in a day. Since Janet's ducks lay 16 eggs per day", ' To find the total number of bolts needed, we need to calculate the number of bolts required for blue and white fibers separately and then add them together.\n\n']


In [20]:
for kv1, kv2 in zip(past_key_values, gen_out.past_key_values):
    for i in range(min(kv1[0].shape[2], kv2[0].shape[2])):
        print(f"{i=}: keys: {kv1[0].shape}, {kv2[0].shape}")
        print(f"{i=}: keys: {kv1[0][0, 0, i, :5].tolist()}")
        print(f"{i=}: keys: {kv2[0][0, 0, i, :5].tolist()}")
        print(f"{i=}: values: {kv1[1].shape}, {kv2[1].shape}")
        print(f"{i=}: values: {kv1[1][0, 0, i, :5].tolist()}")
        print(f"{i=}: values: {kv2[1][0, 0, i, :5].tolist()}")
    break

i=0: keys: torch.Size([2, 2, 116, 64]), torch.Size([2, 2, 115, 64])
i=0: keys: [-8.410331726074219, -3.3071095943450928, -6.438131809234619, 0.6713922619819641, -0.20497480034828186]
i=0: keys: [-8.410331726074219, -3.3071095943450928, -6.438131809234619, 0.6713922619819641, -0.20497480034828186]
i=0: values: torch.Size([2, 2, 116, 64]), torch.Size([2, 2, 115, 64])
i=0: values: [-0.01414407417178154, 0.03389165550470352, -0.02601637691259384, 0.013034755364060402, -0.011828195303678513]
i=0: values: [-0.01414407417178154, 0.03389165550470352, -0.02601637691259384, 0.013034755364060402, -0.011828195303678513]
i=1: keys: torch.Size([2, 2, 116, 64]), torch.Size([2, 2, 115, 64])
i=1: keys: [-9.862237930297852, -7.930609703063965, -7.341488361358643, 2.996467113494873, -0.14621633291244507]
i=1: keys: [-9.862237930297852, -7.930609703063965, -7.341488361358643, 2.996467113494873, -0.14621633291244507]
i=1: values: torch.Size([2, 2, 116, 64]), torch.Size([2, 2, 115, 64])
i=1: values: [0.0100

In [86]:
for k in dummy_model(**dummy_questions_inputs, return_dict=True).keys():
    print(k)

logits
past_key_values
