In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from tqdm import tqdm
from torch.nn import functional as F
from context_compression.model import GPT, GPTConfig
from context_compression.attn import AttentionKind
# Expose only GPU index 0 to CUDA for this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# ----------------------------------------------------------------------------
# Sampling / runtime configuration (names are read by the cells below).
# ----------------------------------------------------------------------------
init_from = 'resume'  # 'resume' (load from an out_dir) or a gpt2 variant such as 'gpt2-xl'
out_dir = 'out'  # only meaningful when init_from == 'resume'
# Prompt text. May also be "<|endoftext|>", or "FILE:prompt.txt" to read the prompt from a file.
start = "Hello, I'm a language model,"
num_samples = 10  # how many samples to draw
max_new_tokens = 256  # tokens generated per sample
temperature = 0.8  # 1.0 = unchanged, < 1.0 = less random, > 1.0 = more random
top_k = 200  # keep only the top_k most likely tokens; the rest get probability 0
seed = 1337
device = 'cuda'  # e.g. 'cpu', 'cuda', 'cuda:0', 'cuda:1', ...
# Prefer bfloat16 when the GPU supports it, otherwise fall back to float16.
# ('float32' is also an accepted value.)
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
# NOTE: this name shadows the builtin `compile`; it is kept because later code reads it.
compile = False  # wrap the model with torch.compile (PyTorch 2.0+) for speed

# ----------------------------------------------------------------------------
# Seeding and numeric settings.
# ----------------------------------------------------------------------------
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True  # permit TF32 in matmuls
torch.backends.cudnn.allow_tf32 = True  # permit TF32 in cuDNN

# Coarse device family, needed by torch.autocast.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

# Map the dtype string onto the corresponding torch dtype object.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# Autocast context for GPU runs; a no-op context on CPU.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Load a model checkpoint saved at a fixed location on disk.
# (Alternative location: os.path.join(out_dir, 'model.pt').)
ckpt_path = "/workspace/context-compression/selective_run_0_continued/model_09999.pt"
# SECURITY NOTE(review): weights_only=False runs the full pickle loader, which can
# execute arbitrary code embedded in the file. Only load checkpoints you trust.
# The flag is spelled out so behaviour does not depend on the installed PyTorch
# version's default (and so the implicit-default warning is not emitted).
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

# Build the model with selective attention and copy the stored weights into it.
model = GPT(GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=False, vocab_size=50304))
state_dict = checkpoint['model']
# Strip the '_orig_mod.' prefix from parameter names so they match this model's keys
# (presumably added when the checkpoint was saved from a torch.compile-wrapped model).
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):  # iterate over a snapshot: the dict is mutated in the loop
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

# Switch to eval mode and move the weights to the target device.
model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)


  checkpoint = torch.load(ckpt_path, map_location=device)


In [None]:
# GPT-2 BPE tokenizer plus thin encode/decode helpers.
enc = tiktoken.get_encoding("gpt2")


def encode(s):
    """Tokenize text `s` into GPT-2 token ids, allowing the <|endoftext|> special token."""
    return enc.encode(s, allowed_special={"<|endoftext|>"})


def decode(l):
    """Turn a list of GPT-2 token ids `l` back into text."""
    return enc.decode(l)


# Resolve the prompt: a "FILE:<path>" value is replaced by that file's contents.
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
# Add a leading batch dimension: shape (1, len(start_ids)).
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

# Set to True to actually run generation; the branch below is gated on this flag.
generate_samples = False

if generate_samples:
    output_file = "generated_samples.txt"

    # Generate all samples in a single batch by repeating the prompt num_samples times.
    # Generation runs before the output file is opened, so a failure here does not
    # leave behind an empty/truncated file.
    with torch.no_grad(), ctx:
        print(x.shape)
        y = model.generate(x.repeat(num_samples, 1), max_new_tokens, temperature=temperature, top_k=top_k)
    generated_texts = [decode(y[i].tolist()) for i in range(num_samples)]

    # Write each sample followed by an 80-dash separator line.
    with open(output_file, 'w', encoding='utf-8') as f:
        for generated_text in generated_texts:
            f.write(generated_text)
            f.write('\n' + '-' * 80 + '\n')  # Separator line

    print("Generated samples")
    print(f"\nGeneration complete. Outputs saved to {output_file}")
else:
    print("Skipping generation")

In [6]:
import numpy as np

from context_compression.data import DataLoaderLite

# Batch geometry used for the validation loader.
B = 4 # micro batch size
T = 1024

val_loader = DataLoaderLite(B=B, T=T, split="val", process_rank=0, num_processes=1)

# Put the model in eval mode and rewind the loader before measuring.
model.eval()
val_loader.reset()

val_loss_steps = 20  # number of validation batches to average over
val_loss_accum = 0.0
with torch.no_grad():
    # Clear the FF_values list on the first block's attention module before validating.
    model.transformer.h[0].attn.FF_values = []

    for step in tqdm(range(val_loss_steps)):
        # Fetch the next (inputs, targets) pair and move both to the target device.
        x, y = (batch_part.to(device) for batch_part in val_loader.next_batch())
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # Each batch contributes 1/val_loss_steps, so the accumulator ends as the mean loss.
        loss = loss / val_loss_steps
        val_loss_accum += loss.detach()


# Report the mean loss and the corresponding perplexity, exp(loss).
val_loss_value = val_loss_accum.item()
print(f"validation loss: {val_loss_value:.4f}")
validation_perplexity = torch.exp(torch.tensor(val_loss_value))
print(f"validation perplexity: {validation_perplexity:.4f}")

100%|██████████| 20/20 [00:00<00:00, 23.95it/s]

validation loss: 3.0930
validation perplexity: 22.0421





In [5]:
from context_compression.hellaswag import render_example, iterate_examples
from context_compression.data import get_most_likely_row

# Running tallies over the HellaSwag "val" examples.
num_correct_norm = 0
num_total = 0
for example in tqdm(iterate_examples("val"), total=10042):
    # Turn the example into token ids, a mask, and the gold label.
    _, tokens, mask, label = render_example(example)
    tokens, mask = tokens.to(device), mask.to(device)
    with torch.no_grad():
        # Forward pass under bf16 autocast; the row selection below runs outside autocast.
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(tokens)
        pred_norm = get_most_likely_row(tokens, mask, logits)
    num_total += 1
    if pred_norm == label:
        num_correct_norm += 1
# Fraction of examples where the predicted row matches the label.
acc_norm = num_correct_norm / num_total
print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")

  0%|          | 0/10042 [00:00<?, ?it/s]

tensor([[   32,   582,   318,  5586,   319,   257,  9753,    13,   339,   318,
          1262, 14441,   284, 14441,   257,  5166,   286,  1341,   271,    13],
        [   32,   582,   318,  5586,   319,   257,  9753,    13,   339,   318,
         34759,  1241, 19867,   572,    13,     0,     0,     0,     0,     0],
        [   32,   582,   318,  5586,   319,   257,  9753,    13,   339,   318,
          4769,   257,  6437,  1134,   338, 23441,    13,     0,     0,     0],
        [   32,   582,   318,  5586,   319,   257,  9753,    13,   339,  4940,
         10427,   510,  9753,   278,   319,   257,  9753,    13,     0,     0]])





Exception: Stop here