In [1]:
%load_ext autoreload
%autoreload 2
%env CUDA_VISIBLE_DEVICES=7

env: CUDA_VISIBLE_DEVICES=7


In [2]:
from tqdm import tqdm
from copy import copy
import itertools
import os, sys
import yaml
import json
import gc
import argparse
import functools
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset
import time
import random

sys.path.append(os.path.abspath('/homes/80/anya/Documents/llm_tiny_ideas/super-tiny-lms-outer/super-tiny-lms'))


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# rank = int(os.environ["RANK"])
# print(rank)
# world_size = int(os.environ["WORLD_SIZE"])
# print(world_size)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [3]:
class Config:
    # to access a dict with object.key
    def __init__(self, dictionary):
        self.__dict__ = dictionary

In [4]:
def get_constant_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    last_epoch: int = -1,
):
    def lr_lambda(current_step):
        return min(1, float(current_step) / float(max(1, num_warmup_steps)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

In [5]:
def dataset_loader(dataset: Dataset, per_device_batch_size: int, rank : int, world_size : int, seed: int = 0):
    dataset_length = len(dataset)
    total_batch_size = per_device_batch_size * world_size
    batches_per_dataset = dataset_length // total_batch_size
    i = 0
    epochs = 0
    while True:
        if i >= batches_per_dataset:
            dataset = dataset.shuffle(seed + epochs)
            i = 0
            epochs += 1
        start = i * total_batch_size + rank * per_device_batch_size
        batch = dataset[start:start + per_device_batch_size]
        yield batch

def careful_repeat(data, num_repeats):
    batch_size = data[list(data.keys())[0]].shape[0]
    for k, v in data.items():
        if v.ndim == 1:
            data[k] = v.unsqueeze(1).repeat(1, num_repeats).reshape(batch_size*num_repeats, *v.shape[1:])
        elif v.ndim == 2:
            data[k] = v.unsqueeze(1).repeat(1, num_repeats, 1).reshape(batch_size*num_repeats, *v.shape[1:])
    return data

def get_model_param_stats(model, ref_model):
    model_params = torch.cat([p.view(-1) for p in model.parameters() if p.requires_grad])
    ref_model_params = torch.cat([p.view(-1) for p in ref_model.parameters()])
    assert model_params.shape == ref_model_params.shape, f"{model_params.shape=} {ref_model_params.shape=}"
    return {
        "params_with_grads_mean": model_params.mean().item(),
        "params_with_grads_std": model_params.std().item(),
        "distance_to_ref": torch.nn.functional.mse_loss(model_params, ref_model_params).item(),
    }

In [6]:
config_file = "../args/full_test.yaml"

### Load config
with open(config_file) as f:
        config_dict = yaml.safe_load(f)
print("Config:", config_dict)
config = Config(config_dict)

Config: {'dataset': 'gsm8k', 'base_model': 'Qwen/Qwen2.5-0.5b', 'max_iters': 20, 'total_batch_size': 8, 'per_device_batch_size': 4, 'lr': '1e-4', 'loss_type': 'sft', 'sft': {'include_cot': True, 'predict_cot': True}, 'generation': {'generations_per_prompt': 2, 'temperature': 1.0, 'top_k': None, 'max_length': 20, 'inject_answer_prompt': True, 'fixed_length': True}, 'pg': {'normalization_type': None}, 'entropy_coef': 0.001, 'kl_loss_coef': 0.001, 'kl_loss_type': 'low_var_kl'}


In [7]:
dataset = Dataset.load_from_disk(f"../data/my_data/gsm8k/test.bin")
test_model = AutoModelForCausalLM.from_pretrained(config.base_model).to(device)
test_tokenizer = AutoTokenizer.from_pretrained(config.base_model)

test_tokenizer.encode("....")

get_model_param_stats(test_model, test_model)
print(test_model)

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((

In [8]:
dataset[0]

{'prompt': 'Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers\' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers\' market? Let\'s think step by step and output the final answer after "Answer:".',
 'reasoning': 'Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.\n',
 'answer': '18'}

In [None]:
class Trainer:
    def __init__(
        self,
        cfg,
    ) -> None:
        print(f"\nTrainer::-----------------------------------")
        self.world_size = 1
        self.rank = 0
        self.total_iters = cfg.total_iters
        self.total_batch_size = cfg.total_batch_size
        self.per_device_batch_size = cfg.per_device_batch_size
        assert self.total_batch_size % (self.per_device_batch_size * self.world_size) == 0, f"{self.total_batch_size=} {self.per_device_batch_size=}, {self.world_size=}"
        self.gradient_accumulation_steps = self.total_batch_size // (self.per_device_batch_size * self.world_size)
        assert self.per_device_batch_size * self.world_size * self.gradient_accumulation_steps == self.total_batch_size, f"{self.per_device_batch_size=} {self.world_size=} {self.gradient_accumulation_steps=} {self.total_batch_size=}"
        self.generations_per_prompt = cfg.generations_per_prompt
        assert self.per_device_batch_size % self.generations_per_prompt == 0, f"{self.per_device_batch_size=} {self.generations_per_prompt=}"
        self.per_device_prompt_batch_size = self.per_device_batch_size // self.generations_per_prompt
        print(f"---TRAINING CONFIG:")
        print(f"Total iters: {self.total_iters}")
        print(f"Total batch size: {self.total_batch_size}")
        print(f"Per device batch size: {self.per_device_batch_size}")
        print(f"Gradient accumulation steps: {self.gradient_accumulation_steps}")
        print(f"Generations per prompt: {self.generations_per_prompt}")
        print(f"Per device prompt batch size: {self.per_device_prompt_batch_size}")
        print(f"-----------------------------------\n")

        self.model = AutoModelForCausalLM.from_pretrained(cfg.model_name).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
        self.ref_model = AutoModelForCausalLM.from_pretrained(cfg.model_name).to(device)
        for param in self.ref_model.parameters():
            param.requires_grad = False

        assert torch.cuda.is_available(), "CUDA must be available for training"
        assert torch.cuda.device_count() == 1
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        # optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=cfg.lr,
            betas=cfg.get('betas', (0.9, 0.999)),
            weight_decay=cfg.get('weight_decay', 1e-2))
        self.lr_scheduler = get_constant_schedule_with_warmup(
            optimizer=self.optimizer,
            num_warmup_steps=cfg.max_iters * cfg.get('lr_warmup_steps_ratio', 0.0))

        # dataset
        if cfg.dataset == "gsm8k":
            train_dataset = Dataset.load_from_disk(f"../data/my_data/gsm8k/train.bin")
            if cfg.dataset_size is not None:
                train_dataset = train_dataset.select(range(cfg.dataset_size))
            val_dataset = Dataset.load_from_disk(f"../data/my_data/gsm8k/trest.bin")
        self.train_loader = dataset_loader(train_dataset, self.per_device_batch_size, 0, 1, seed=cfg.get('seed', 0))
        self.val_loader = dataset_loader(val_dataset, self.per_device_batch_size, 0, 1, seed=cfg.get('seed', 0))
        self.dataset_answer_prompt = train_dataset[0]["search_string"]
        print(f"---DATASET CONFIG:")
        print(f"Dataset: {cfg.dataset}")
        if cfg.dataset_size is not None:
            print(f"Careful! Reduced dataset size for testing: {len(train_dataset)}")
        print(f"-----------------------------------\n")


        # generation
        if self.loss_type in ["pg", "logp"]:
            self.temperature = cfg.generation.temperature
            self.top_k = cfg.generation.top_k
            self.max_steps = cfg.generation.max_steps
            self.step_for_answer = cfg.generation.step_for_answer
            self.inject_answer_prompt = cfg.generation.inject_answer_prompt
            self.fixed_length = cfg.generation.fixed_length
            self.as_full_distribution = cfg.generation.as_full_distribution
            self.answer_prompt_text = "Answer:"
            assert len(self.tokenizer.encode(self.answer_prompt_text)) == 1
            self.answer_prompt_id = self.tokenizer.encode(self.answer_prompt_text)[0]
            assert len(self.tokenizer.encode("....")) == 1
            self.dot_by_dot_id = self.tokenizer.encode("....")[0]
            print(f"---GENERATION CONFIG:")
            print(f"Temperature: {self.temperature}")
            print(f"Top k: {self.top_k}")
            print(f"Max length: {self.max_length}")
            print(f"Inject answer prompt: {self.inject_answer_prompt}")
            print(f"Fixed length: {self.fixed_length}")
            print(f"Answer prompt text: {self.answer_prompt_text}")
            print(f"Answer prompt id: {self.answer_prompt_id}")
            print(f"-----------------------------------\n")

        # loss
        self.loss_type = cfg.loss_type
        if self.loss_type == "sft":
            self.sft_include_cot = cfg.sft.include_cot
            self.sft_predict_cot = cfg.sft.predict_cot
        elif self.loss_type == "pg":
            self.pg_normalization_type = None if self.generations_per_prompt == 1 else cfg.pg.normalization_type
            self.pg_dot_by_dot = cfg.pg.dot_by_dot
        self.entropy_coef = cfg.entropy_coef
        self.kl_loss_coef = cfg.kl_loss_coef
        self.kl_loss_type = cfg.kl_loss_type
        print(f"---LOSS CONFIG:")
        print(f"Loss type: {self.loss_type}")
        if self.loss_type == "sft":
            print(f"SFT include cot: {self.sft_include_cot}")
            print(f"SFT predict cot: {self.sft_predict_cot}")
        elif self.loss_type == "pg":
            print(f"PG normalization type: {self.pg_normalization_type}")
            print(f"PG dot by dot: {self.pg_dot_by_dot}")
        print(f"Entropy coef: {self.entropy_coef}")
        print(f"KL loss coef: {self.kl_loss_coef}")
        print(f"KL loss type: {self.kl_loss_type}")
        print(f"-----------------------------------\n")

        self._setup_ctx()

    def _setup_ctx(self):
        """Get the context manager"""
        dtype = (
            torch.bfloat16
            if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            else torch.float16
        )
        self._setup_scaler(dtype)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
        return ctx

    def _setup_scaler(self, dtype=torch.float16):
        """Setup the scaler"""
        self.scaler = torch.amp.GradScaler(enabled=dtype == torch.float16, device='cuda')

    def apply_update(self):
        # once gradients are accumulated, step 
        if self.cfg.trainer.optimizer.grad_clip > 0:
            # Unscale the gradients of the optimizer's assigned params in-place
            self.scaler.unscale_(self.optimizer)
            # Clip the gradients with normalization
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.trainer.optimizer.grad_clip)
        # Perform a single optimization step
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()  # Reset gradients after update


    def run_training_loop(self):
        """Run the training loop"""
        start_time = time.time()
        for i in tqdm.tqdm(range(self.total_iters), desc="Training"):
            lr = self.lr_scheduler.step(self.optimizer, i) if self.lr_scheduler is not None else self.optimizer.param_groups[0]["lr"]

            # GRADIENT ACCUMULATION
            for j in range(self.gradient_accumulation_steps):
            
                # GENERATE ROLLOUTS
                start_time = time.time()
                self.model.eval()
                # with torch.no_grad():
                # get questions (and answers)
                dataset_batch = next(self.train_loader)
                questions_text = dataset_batch["questions"]
                cot_text = dataset_batch["cot"]
                answers_text = dataset_batch["answers"]

                if self.loss_type == "sft":
                    raise NotImplementedError
                
                else:
                    questions_inputs = self.tokenizer(questions_text, return_tensors="pt", padding=True, padding_side="left")
                    answers_inputs = self.tokenizer(answers_text, return_tensors="pt", padding=True)
                    questions_inputs = {k: v.to(self.model.device) for k, v in questions_inputs.items()}
                    answers_inputs = {k: v.to(self.model.device) for k, v in answers_inputs.items()}

                    # repeat for generations_per_prompt
                    questions_inputs = careful_repeat(questions_inputs, self.generations_per_prompt)
                    answers_inputs = careful_repeat(answers_inputs, self.generations_per_prompt)

                    # generate
                    x, generations, generation_metrics = batch_generate_rnn(
                        model=self.model,
                        ref_model=self.ref_model,
                        questions_inputs=questions_inputs,
                        answers_inputs=answers_inputs,
                        max_steps=self.max_length,
                        step_for_answer=self.step_for_answer,
                        temperature=self.temperature,
                        top_k=self.top_k,
                        as_full_distribution=self.as_full_distribution,
                        dot_by_dot=self.pg_dot_by_dot,
                        dot_by_dot_id=self.dot_by_dot_id,
                        inject_answer_prompt=self.inject_answer_prompt,
                        answer_prompt_ids=self.answer_prompt_id,
                    )
                    generation_metrics = {f"gen/{k}": v for k, v in generation_metrics.items()}

                    ### rewards
                    rewards, normalized_rewards, reward_metrics = self.get_rewards(generations, answers_text)
                    reward_metrics = {k if k == "REWARD" else f"reward/{k}": v for k, v in reward_metrics.items()}
                    
                # COMPUTE LOSS
                with self.ctx: 
                    loss, loss_metrics = self.get_loss(x, normalized_rewards)
                    self.scaler.scale(loss).backward()
                loss_metrics = {f"loss/{k}": v for k, v in loss_metrics.items()}

                # UPDATE METRICS
                if j == 0:
                    metrics_s = {**generation_metrics, **loss_metrics, **reward_metrics}
                else:
                    metrics = {**generation_metrics, **loss_metrics, **reward_metrics}
                    metrics_s = {k: v + metrics[k] for k, v in metrics_s.items()}

            # UPDATE MODEL
            self.apply_update()

            # LOG
            metrics_s = {k: v / self.gradient_accumulation_steps for k, v in metrics_s.items()}
            param_metrics = get_model_param_stats(self.model, self.ref_model)
            log_dict = {"iter": i, "lr": lr}
            log_dict.update(metrics_s)
            log_dict.update({f"params/{k}": v for k, v in param_metrics.items()})
            # if self.use_wandb:
            #     wandb.log(log_dict)
            if i % 10 == 0 or i == self.total_iters - 1:
                print(f"({log_dict})\n\niter {i}: REWARD={log_dict['REWARD']:.2f}")
                num_to_print = 3
                decoded = [self.tokenizer.decode(generations[k]) for k in range(num_to_print)]
                for k in range(num_to_print):
                    print(f"  EXAMPLE {k}: (REWARD={rewards[k].item()}):\nQUESTION: {questions_text[k]}\nGENERATION: {decoded[k]}\nANSWER: {answers_text[k]}\n", "-"*50, "\n")

            
        print(f"Training time: {time.time()-start_time:.1f}s")
            
        # save the final model
        self._save_model(i, final_model=True)

    def train(self, seed=42):
        """Train the model"""
        # set seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        self.run_training_loop()

    def get_rewards(self, generations, answers_text):
        """Get rewards"""
        ### check for answer
        contains_answer_prompt = torch.zeros((self.per_device_prompt_batch_size, self.generations_per_prompt), device=self.model.device)
        contains_answer = torch.zeros((self.per_device_prompt_batch_size, self.generations_per_prompt), device=self.model.device)
        generations = generations.reshape(self.per_device_prompt_batch_size, self.generations_per_prompt, -1)
        for i in range(self.per_device_prompt_batch_size):
            answer_text = answers_text[i]
            decoded_generations = self.tokenizer.batch_decode(generations[i])
            for j in range(self.generations_per_prompt):
                decoded = decoded_generations[j]
                contains_answer_prompt_ij = self.answer_prompt_text in decoded
                if contains_answer_prompt_ij:
                    contains_answer_ij = answer_text in decoded.split(self.answer_prompt_text)[1]
                else:
                    contains_answer_ij = False
                contains_answer_prompt[i, j] = contains_answer_prompt_ij
                contains_answer[i, j] = contains_answer_ij
        
        ### caluclate rewards
        rewards = contains_answer.float()
        if not self.inject_answer_prompt:
            rewards += self.answer_prompt_coef * contains_answer_prompt.float()
        if self.pg_normalization_type == "grpo":
            normalized_rewards = (rewards - rewards.mean(1, keepdim=True)) / (rewards.std(1, keepdim=True) + 1e-6)
        elif self.pg_normalization_type == "rloo":
            group_sum = rewards.sum(1, keepdim=True)
            normalized_rewards = (group_sum - rewards) / (self.generations_per_prompt - 1)
        else:
            normalized_rewards = rewards

        metrics = {
            "REWARD": rewards.mean().item(),
            "reward/reward_std": rewards.std().item(),
            "reward/reward_std_within_q": rewards.std(1).mean().item(),
            "reward/reward_std_between_q": rewards.mean(1).std().item(),
        }

        rewards = rewards.reshape(-1)
        normalized_rewards = normalized_rewards.reshape(-1)
        return rewards, normalized_rewards, metrics
    
    def get_loss(self, x, rewards):
        metrics = {}
        pg_loss, logp_loss = 0, 0
        if self.loss_type == "sft":
            raise NotImplementedError
        if self.as_full_distribution:
            gen_answer_logp, ref_answer_logp = x
            metrics["gen_answer_logp"] = gen_answer_logp.mean().item()
            metrics["ref_answer_logp"] = ref_answer_logp.mean().item()
            metrics["correct_answer_logp"] = correct_answer_logp.mean().item()
            metrics["correct_answer_p"] = correct_answer_logp.exp().mean().item()
            if self.loss_type == "pg":
                pg_loss = - torch.exp(gen_answer_logp - gen_answer_logp.detach()) * rewards
            elif self.loss_type == "logp":
                logp_loss = - correct_answer_logp
            else:
                raise ValueError(f"{self.loss_type=}")
            kl = torch.exp(ref_answer_logp - gen_answer_logp) - (ref_answer_logp - gen_answer_logp) - 1
        elif self.loss_type == "pg":
            gen_per_token_logps, ref_per_token_logps = x
            metrics["gen_per_token_logp"] = (gen_per_token_logps * loss_mask).sum().item() / loss_mask.sum().item()
            metrics["ref_per_token_logp"] = (ref_per_token_logps * loss_mask).sum().item() / loss_mask.sum().item()
            assert self.loss_type == "pg", f"{self.loss_type=}"
            pg_losses = - (torch.exp(gen_per_token_logps - gen_per_token_logps.detach()) * rewards.unsqueeze(1))
            pg_loss = (pg_losses * loss_mask).sum(-1) / loss_mask.sum(-1)
            kls = torch.exp(ref_per_token_logps - gen_per_token_logps) - (ref_per_token_logps - gen_per_token_logps) - 1
            kl = (kls * loss_mask).mean(-1)
        loss = pg_loss + logp_loss + self.kl_coef * kl
        loss = loss.mean()
        
        metrics["loss"] = loss.item()
        metrics["pg_loss"] = pg_loss.mean().item() if self.loss_type == "pg" else 0
        metrics["logp_loss"] = logp_loss.mean().item() if self.loss_type == "logp" else 0
        metrics["kl"] = kl.mean().item()
        return loss / self.gradient_accumulation_steps, metrics


In [10]:
dummy_model = AutoModelForCausalLM.from_pretrained(config.base_model).to(device)
dummy_ref_model = AutoModelForCausalLM.from_pretrained(config.base_model).to(device)
dummy_tokenizer = AutoTokenizer.from_pretrained(config.base_model)
dummy_dataset = Dataset.load_from_disk(f"../data/my_data/gsm8k/test.bin")
dummy_loader = dataset_loader(dummy_dataset, per_device_batch_size=2, rank=0, world_size=1, seed=0)
dummy_batch = next(dummy_loader)
dummy_questions_text = dummy_batch["prompt"]
dummy_answers_text = dummy_batch["answer"]
dummy_questions_inputs = dummy_tokenizer(dummy_questions_text, return_tensors="pt", padding=True, padding_side="left")
dummy_answers_inputs = dummy_tokenizer(dummy_answers_text, return_tensors="pt", padding=True)
dummy_questions_inputs = {k: v.to(device) for k, v in dummy_questions_inputs.items()}
dummy_answers_inputs = {k: v.to(device) for k, v in dummy_answers_inputs.items()}
dummy_ref_model.eval()
dummy_model.eval()

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((

In [15]:
def get_next(logitss, temperature=0.0, top_k=None):
    """Get the next token"""
    batch_size, seq_length, vocab_size = logitss.shape
    assert seq_length == 1
    logitss = logitss.squeeze(1)
    if temperature == 0.0:
        token_ids = torch.argmax(logitss, dim=-1)
        probdists = torch.zeros_like(logitss)
        probdists[torch.arange(batch_size), token_ids] = 1.0
        entropy = torch.zeros(batch_size).to(logitss.device)
    else:
        logitss = logitss / temperature
        if top_k is not None:
            logitss_k, idxs_k = torch.topk(logitss, min(top_k, vocab_size), dim=-1) # (batch_size, top_k)
            probs_k = torch.nn.functional.softmax(logitss_k, dim=-1) # (batch_size, top_k)
            idxs = torch.multinomial(probs_k, num_samples=1).squeeze(1) # (batch_size,)
            token_ids = idxs_k[torch.arange(batch_size), idxs] # (batch_size)
            probdists = torch.zeros_like(logitss)
            # next_probdist_s[torch.arange(batch_size), idx_s_k.squeeze(1)] = probs_k.squeeze(1)
            probdists.scatter_(1, idxs_k, probs_k)
            if top_k == 1:
                token_ids2 = torch.argmax(logitss, dim=-1)
                probdists2 = torch.zeros_like(logitss)
                probdists2[torch.arange(batch_size), token_ids2] = 1.0
                assert (token_ids == token_ids2).all(), f"{token_ids=}, {token_ids2=}"
                assert torch.allclose(probdists, probdists2), f"{probdists2=}, {probdists2=}"
        else:
            probdists = torch.nn.functional.softmax(logitss, dim=-1) # (batch_size, vocab_size)
            token_ids = torch.multinomial(probdists, num_samples=1).squeeze(1) # (batch_size)
        entropy = -torch.sum(probdists * torch.log(probdists), dim=-1)

    token_ids = token_ids.unsqueeze(1)
    probdists = probdists.unsqueeze(1)
    return token_ids, probdists, entropy

In [16]:
def single_step(model, next_input, attention_mask, position_ids, past_key_values, as_full_distribution=False):
    batch_size = next_input.shape[0]
    prev_seq_length = past_key_values[0][0].shape[2]
    assert position_ids.shape == (batch_size, 1), f"{position_ids.shape=}"
    assert attention_mask.shape == (batch_size, prev_seq_length+1), f"{attention_mask.shape=}, {prev_seq_length=}"

    if as_full_distribution:
        vocab_size = model.model.config.vocab_size
        assert next_input.shape == (batch_size, 1, vocab_size)
        all_embeds = model.model.embed_tokens.weight
        hidden_dim = all_embeds.shape[1]
        assert all_embeds.shape[0] == vocab_size
        inputs_embeds = torch.matmul(next_input, all_embeds)
        assert inputs_embeds.shape == (batch_size, 1, hidden_dim)
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
    else:
        assert next_input.shape == (batch_size, 1)
        outputs = model(input_ids=next_input, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
    logits = outputs.logits
    past_key_values = outputs.past_key_values
    return logits, past_key_values



In [None]:
# @torch.no_grad()
def batch_generate_rnn(
        model,
        ref_model,
        questions_inputs,
        answers_inputs,
        max_steps=30,
        step_for_answer=20,
        temperature=1.0,
        top_k=None,
        as_full_distribution=False,
        dot_by_dot=False,
        dot_by_dot_id=None,
        inject_answer_prompt=False,
        answer_prompt_ids=None,
    ):
    metrics = {}
    assert not (as_full_distribution and dot_by_dot), f"{as_full_distribution=}, {dot_by_dot=}"
    device = model.device
    assert device.type == "cuda", f"{model.device=}"
    assert ref_model.device == device, f"{ref_model.device=}, {device=}"
    model.eval()
    ref_model.eval()
    batch_size = questions_inputs["input_ids"].shape[0]
    vocab_size = model.config.vocab_size
    prompt_length = questions_inputs["input_ids"].shape[1]

    #### PROMPT FORWARD PASS
    position_ids = questions_inputs["attention_mask"].cumsum(dim=1) - 1
    outputs = model(**questions_inputs, position_ids=position_ids)
    with torch.no_grad():
        ref_outputs = ref_model(**questions_inputs)
    prompt_attention_mask = questions_inputs["attention_mask"]
    prompt_end_position_ids = position_ids[:, -1:] + 1
    past_key_values = outputs.past_key_values
    ref_past_keys_values = ref_outputs.past_key_values
    logits = outputs.logits[:, -1:]
    ref_logits = ref_outputs.logits[:, -1:]

    def make_attention_mask(t, new_seq_length=1):
        return torch.cat([prompt_attention_mask, torch.ones((batch_size, t+new_seq_length), device=device)], dim=1)
    def make_position_ids(t, new_seq_length=1):
        if new_seq_length == 1:
            return prompt_end_position_ids + t
        else:
            return prompt_end_position_ids + t + torch.arange(new_seq_length).unsqueeze(0).to(device)

    all_gen_logits = torch.zeros((batch_size, 0, vocab_size), device=device)
    all_ref_logits = torch.zeros((batch_size, 0, vocab_size), device=device)
    generations = torch.zeros((batch_size, 0), device=device, dtype=torch.int)
    generation_mask = torch.ones((batch_size, 0), device=device)
    entropy = torch.zeros((batch_size,), device=device)

    ### REASONING FORWARD PASSES
    for t in range(step_for_answer):
        next_token_ids, next_probdists, step_entropy = get_next(logits, temperature=temperature, top_k=top_k)
        all_gen_logits = torch.cat((all_gen_logits, logits), dim=1)
        all_ref_logits = torch.cat((all_ref_logits, ref_logits), dim=1)
        generations = torch.cat((generations, next_token_ids), dim=1)
        generation_mask = torch.cat((generation_mask, torch.ones((batch_size, 1), device=device)), dim=1)
        entropy += step_entropy
        # forward pass of next token
        if as_full_distribution:
            next_input = next_probdists
        elif dot_by_dot:
            next_input = torch.full((batch_size, 1), dot_by_dot_id, dtype=torch.long, device=device)
        else:
            next_input = next_token_ids
        logits, past_key_values = single_step(model,
            next_input=next_input,
            attention_mask=make_attention_mask(t),
            position_ids=make_position_ids(t),
            past_key_values=past_key_values,
            as_full_distribution=as_full_distribution)
        with torch.no_grad():
            ref_logits, ref_past_keys_values = single_step(ref_model,
                next_input=next_input,
                attention_mask=make_attention_mask(t),
                position_ids=make_position_ids(t),
                past_key_values=ref_past_keys_values,
                as_full_distribution=as_full_distribution)
            
    ### INJECT ANSWER PROMPT FORWARD PASS
    t = step_for_answer
    if inject_answer_prompt:
        repeated_answer_prompt_ids = torch.tensor(answer_prompt_ids).unsqueeze(0).repeat(batch_size, 1).to(next_token_ids.dtype).to(device)
        all_gen_logits = torch.cat((all_gen_logits, torch.zeros((batch_size, len(answer_prompt_ids), vocab_size), device=device)), dim=1)
        all_ref_logits = torch.cat((all_ref_logits, torch.zeros((batch_size, len(answer_prompt_ids), vocab_size), device=device)), dim=1)
        generations = torch.cat((generations, repeated_answer_prompt_ids), dim=1)
        generation_mask = torch.cat((generation_mask, torch.zeros((batch_size, len(answer_prompt_ids)), device=device)), dim=1)
        outputs = model(
            input_ids=repeated_answer_prompt_ids,
            attention_mask=make_attention_mask(step_for_answer, new_seq_length=len(answer_prompt_ids)),
            position_ids=make_position_ids(step_for_answer, new_seq_length=len(answer_prompt_ids)),
            past_key_values=past_key_values)
        with torch.no_grad():
            ref_outputs = ref_model(
                input_ids=repeated_answer_prompt_ids,
                attention_mask=make_attention_mask(t, new_seq_length=len(answer_prompt_ids)),
                position_ids=make_position_ids(t, new_seq_length=len(answer_prompt_ids)),
                past_key_values=ref_past_keys_values)
        logits = outputs.logits
        ref_logits = ref_outputs.logits
        past_key_values = outputs.past_key_values
        ref_past_keys_values = ref_outputs.past_key_values
        t += len(answer_prompt_ids)
        all_gen_logits = torch.cat((all_gen_logits, logits), dim=1)
        all_ref_logits = torch.cat((all_ref_logits, ref_logits), dim=1)

    ### ANSWER FORWARD PASS
    if as_full_distribution:
        answer_length = answers_inputs["input_ids"].shape[1]
        answer_logits = model(
            input_ids=answers_inputs["input_ids"][:, :-1],
            attention_mask=make_attention_mask(t, new_seq_length=answer_length-1),
            position_ids=make_position_ids(t, new_seq_length=answer_length-1),
            past_key_values=past_key_values).logits
        with torch.no_grad():
            ref_answer_logits = ref_model(
                input_ids=answers_inputs["input_ids"][:, :-1],
                attention_mask=make_attention_mask(t+1, new_seq_length=answer_length-1),
                position_ids=make_position_ids(t+1, new_seq_length=answer_length-1),
                past_key_values=ref_past_keys_values).logits
        answer_logits = torch.cat((all_gen_logits[:, -1:], answer_logits), dim=1)
        ref_answer_logits = torch.cat((all_ref_logits[:, -1:], ref_answer_logits), dim=1)
        per_token_answer_logps = torch.gather(answer_logits, 2, answers_inputs["input_ids"].to(torch.long).unsqueeze(-1)).squeeze(-1)
        per_token_ref_answer_logps = torch.gather(ref_answer_logits, 2, answers_inputs["input_ids"].to(torch.long).unsqueeze(-1)).squeeze(-1)
        assert per_token_answer_logps.shape == answers_inputs["attention_mask"].shape, f"{per_token_answer_logps.shape=}, {answers_inputs['attention_mask'].shape=}"
        answer_logps = (per_token_answer_logps * answers_inputs["attention_mask"]).sum(dim=-1)
        ref_answer_logps = (per_token_ref_answer_logps * answers_inputs["attention_mask"]).sum(dim=-1)
        x = (answer_logps, ref_answer_logps)

        answer_perplexity = torch.exp(-answer_logps / answers_inputs["attention_mask"].sum(dim=-1))
        answer_perplexity_ref = torch.exp(-ref_answer_logps / answers_inputs["attention_mask"].sum(dim=-1))
        metrics["answer_logps"] = answer_logps.mean().item()
        metrics["answer_logps_ref"] = ref_answer_logps.mean().item()
        metrics["answer_logps_diff"] = (answer_logps - ref_answer_logps).mean().item()
        metrics["answer_perplexity"] = answer_perplexity.mean().item()
        metrics["answer_perplexity_ref"] = answer_perplexity_ref.mean().item()
        metrics["answer_perplexity_diff"] = (answer_perplexity - answer_perplexity_ref).mean().item()

    ### CONTINUE FORWARD PASS STEPS
    for t in range(t, max_steps):
        next_token_ids, next_probdists, step_entropy = get_next(logits[:, -1:], temperature=temperature, top_k=top_k)
        all_gen_logits = torch.cat((all_gen_logits, logits), dim=1)
        all_ref_logits = torch.cat((all_ref_logits, ref_logits), dim=1)
        generations = torch.cat((generations, next_token_ids), dim=1)
        generation_mask = torch.cat((generation_mask, torch.ones((batch_size, 1), device=device)), dim=1)
        entropy += step_entropy
        # forward pass of next token
        if as_full_distribution:
            next_input = next_probdists
        elif dot_by_dot:
            next_input = torch.full((batch_size, 1), dot_by_dot_id, dtype=torch.long, device=device)
        else:
            next_input = next_token_ids
        logits, past_key_values = single_step(model,
            next_input=next_input,
            attention_mask=make_attention_mask(t),
            position_ids=make_position_ids(t),
            past_key_values=past_key_values,
            as_full_distribution=as_full_distribution)
        with torch.no_grad():
            ref_logits, ref_past_keys_values = single_step(ref_model,
                next_input=next_input,
                attention_mask=make_attention_mask(t),
                position_ids=make_position_ids(t),
                past_key_values=ref_past_keys_values,
                as_full_distribution=as_full_distribution)
            
        all_gen_logits = torch.cat((all_gen_logits, logits), dim=1)
        all_ref_logits = torch.cat((all_ref_logits, ref_logits), dim=1)
    next_token_ids, next_probdists, step_entropy = get_next(logits, temperature=temperature, top_k=top_k)
    all_gen_logits = torch.cat((all_gen_logits, logits), dim=1)
    all_ref_logits = torch.cat((all_ref_logits, ref_logits), dim=1)
    generations = torch.cat((generations, next_token_ids), dim=1)
    generation_mask = torch.cat((generation_mask, torch.ones((batch_size, 1), device=device)), dim=1)
    entropy += step_entropy

    if not as_full_distribution:
        gen_per_token_logps = torch.gather(all_gen_logits, 2, generations.to(torch.long).unsqueeze(-1)).squeeze(-1)
        ref_per_token_logps = torch.gather(all_ref_logits, 2, generations.to(torch.long).unsqueeze(-1)).squeeze(-1)
        if inject_answer_prompt:
            gen_per_token_logps = gen_per_token_logps * generation_mask
            ref_per_token_logps = ref_per_token_logps * generation_mask
        x = (gen_per_token_logps, ref_per_token_logps)

        metrics["logps"] = gen_per_token_logps.mean().item()
        metrics["logps_ref"] = ref_per_token_logps.mean().item()
        metrics["logps_diff"] = (gen_per_token_logps - ref_per_token_logps).mean().item()

    metrics = {
        "entropy": entropy.mean().item() / max_steps,
        "entropy_std": entropy.std().item() / max_steps,
    }
    
    assert x[0].requires_grad == True, f"{x[0].requires_grad=}"
    assert x[1].requires_grad == False, f"{x[1].requires_grad=}"
    return x, generations, metrics, # past_key_values

In [24]:
x, generations, metrics, past_key_values = batch_generate_rnn(
    model=dummy_model,
    ref_model=dummy_ref_model,
    questions_inputs=dummy_questions_inputs,
    answers_inputs=dummy_answers_inputs,
    max_steps=30,
    step_for_answer=20,
    temperature=1.0,
    top_k=None,
    as_full_distribution=False,
    dot_by_dot=False,
    dot_by_dot_id=dummy_tokenizer.encode("....")[0],
    inject_answer_prompt=False,
    answer_prompt_ids=dummy_tokenizer.encode("...Answer:"),
)
print(x[0].shape, generations.shape, metrics)
dummy_tokenizer.batch_decode(generations)

torch.Size([2, 31]) torch.Size([2, 31]) {'entropy': 0.9281216939290364, 'entropy_std': 0.40576279958089195}


[" To find out how much Janet makes at the farmers' market daily from selling the remaining eggs, we can follow these steps:\n\n1. Calculate the total number",
 ' To find the total number of bolts it takes, we need to first calculate the number of bolts of blue and white fiber required for the lord in each row']

In [25]:
gen_out = dummy_model.generate(
    input_ids=dummy_questions_inputs["input_ids"],
    attention_mask=dummy_questions_inputs["attention_mask"],
    max_new_tokens=30,
    temperature=1.0,
    top_k=None,
    do_sample=False,
    return_dict_in_generate=True,
)
prompt_length = dummy_questions_inputs["input_ids"].shape[1]
generations = gen_out.sequences if isinstance(gen_out, dict) else gen_out
generations = generations[:, prompt_length:]
print(dummy_tokenizer.batch_decode(generations))


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


[" To determine how much Janet makes in dollars every day at the farmers' market, we need to follow these steps:\n\n1. Calculate the total number of", ' To find the total number of bolts needed, we need to calculate the number of bolts required for blue and white fibers separately and then add them together.\n\n']


In [26]:
for kv1, kv2 in zip(past_key_values, gen_out.past_key_values):
    for i in range(min(kv1[0].shape[2], kv2[0].shape[2])):
        print(f"{i=}: keys: {kv1[0].shape}, {kv2[0].shape}")
        print(f"{i=}: keys: {kv1[0][0, 0, i, :5].tolist()}")
        print(f"{i=}: keys: {kv2[0][0, 0, i, :5].tolist()}")
        print(f"{i=}: values: {kv1[1].shape}, {kv2[1].shape}")
        print(f"{i=}: values: {kv1[1][0, 0, i, :5].tolist()}")
        print(f"{i=}: values: {kv2[1][0, 0, i, :5].tolist()}")
    break

i=0: keys: torch.Size([2, 2, 110, 64]), torch.Size([2, 2, 109, 64])
i=0: keys: [-8.410331726074219, -3.307109832763672, -6.438131332397461, 0.6713922619819641, -0.20497481524944305]
i=0: keys: [-8.410331726074219, -3.307109832763672, -6.438131332397461, 0.6713922619819641, -0.20497481524944305]
i=0: values: torch.Size([2, 2, 110, 64]), torch.Size([2, 2, 109, 64])
i=0: values: [-0.014144073240458965, 0.03389165922999382, -0.026016375049948692, 0.0130347590893507, -0.011828195303678513]
i=0: values: [-0.014144073240458965, 0.03389165922999382, -0.026016375049948692, 0.0130347590893507, -0.011828195303678513]
i=1: keys: torch.Size([2, 2, 110, 64]), torch.Size([2, 2, 109, 64])
i=1: keys: [-9.8622407913208, -7.930610179901123, -7.341488361358643, 2.996467113494873, -0.14621634781360626]
i=1: keys: [-9.8622407913208, -7.930610179901123, -7.341488361358643, 2.996467113494873, -0.14621634781360626]
i=1: values: torch.Size([2, 2, 110, 64]), torch.Size([2, 2, 109, 64])
i=1: values: [0.0100266905

In [27]:
for k in dummy_model(**dummy_questions_inputs, return_dict=True).keys():
    print(k)

logits
past_key_values
