In [None]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from tqdm import tqdm
from torch.nn import functional as F
from context_compression.model import GPT, GPTConfig
from context_compression.attn import AttentionKind
# Pin the notebook to the first GPU. CUDA reads this variable when it initializes;
# presumably it still takes effect here because no CUDA call has happened yet.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Run / sampling configuration. Several of these knobs are not read by the cells in
# this notebook; they are kept so sampling experiments can reuse them.
init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out' # ignored if init_from is not 'resume'
start = "Hello, I'm a language model," # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 256 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
# Fall back to CPU when no GPU is visible instead of failing later in torch.load / .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
# NOTE(review): shadows the builtin `compile`, and the model is never torch.compile'd
# in this notebook, so this flag currently has no effect.
compile = True # use PyTorch 2.0 to compile the model to be faster

torch.manual_seed(seed)
torch.cuda.manual_seed(seed) # ignored by torch when CUDA is unavailable
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Mixed-precision autocast on GPU only; on CPU the forward pass runs in default precision.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Build the model and load pretrained weights downloaded from the Hugging Face Hub.
# Alternative models/checkpoints that were tried earlier are kept commented out so
# they can be swapped in quickly.


# model = GPT(GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=True, vocab_size=50304))
# ckpt_path = "/workspace/context-compression/selective_run_0_continued/model_09999.pt"

# model = GPT(GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=True, vocab_size=50304))
# ckpt_path = "/workspace/context-compression/selective_run_0_continued/model_09999.pt"

# model = GPT(GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=False, vocab_size=50304))
# ckpt_path = "/workspace/context-compression/memory_loss_run_0/model_09999.pt"

# model = GPT(GPTConfig(attention_kind=AttentionKind.SELF, for_inference=True, vocab_size=50304))
# ckpt_path = "/workspace/context-compression/unselective_run_0/model_09999.pt"

# model = GPT(GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=False, vocab_size=50304))
# ckpt_path = "/root/.cache/huggingface/hub/models--Yorth--selective1/snapshots/1d3d987c90be4b8d6f58de60749ba5823f0ecd29/model.pt"

from huggingface_hub import hf_hub_download
# config = GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=True, vocab_size=50304,n_head=13)
# model = GPT(config)
# ckpt_path = hf_hub_download(repo_id="andrew-healey/context-compression",filename="self_to_selective_run_0_restarted/model_02499.pt")
# config = GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=True, vocab_size=50304,n_head=13)
# ckpt_path = hf_hub_download(repo_id="andrew-healey/context-compression",filename="protection_none_torch_compile/model_02499.pt")
# new_eps_ratios, new_eps_losses = get_validation_loss_at_diff_ratios(new_config, ckpt_path)

config = GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=True, vocab_size=50304)
model = GPT(config)
ckpt_path = hf_hub_download(repo_id="Yorth/selective1", filename="model.pt")

# NOTE(review): torch.load unpickles the file, so only load checkpoints from repos you
# trust. Newer torch releases default to weights_only=True, which may reject this file
# if it stores non-tensor objects alongside 'model' — confirm against the checkpoint.
checkpoint = torch.load(ckpt_path, map_location=device)

# Checkpoints saved from a torch.compile'd module prefix every key with '_orig_mod.';
# strip it in place so the keys match the plain (uncompiled) GPT module.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for key in [k for k in state_dict if k.startswith(unwanted_prefix)]:
    state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
model.load_state_dict(state_dict)

# load_state_dict copies values into the existing parameters, so the model still lives
# wherever GPT() created it. Move it to `device` so it matches the input tensors the
# next cell builds, and switch to eval mode to disable training-only behaviour.
model.eval()
model.to(device)


In [None]:
# Encode a short prompt with the GPT-2 BPE tokenizer and run a single forward pass.
# The model is expected to append its attention masks to `ff_cache`, which this
# notebook plots afterwards.
enc = tiktoken.get_encoding("gpt2")
text = "Hello world!"
prompt_ids = enc.encode(text)
# Batch of one sequence, created directly on the target device: shape (1, T).
tokens = torch.tensor(prompt_ids, device=device).unsqueeze(0)

ff_cache = []  # populated in place by the model's forward pass

# Inference only: no autograd graph, mixed precision via `ctx` when on GPU.
with torch.no_grad(), ctx:
    logits, loss, losses = model(tokens, ff_cache=ff_cache)

print(f"Input tokens: {tokens}")
print(f"Output logits shape: {logits.shape}")


# Visualise the attention masks recorded in `ff_cache` by the forward pass above.
import matplotlib.pyplot as plt


def _mask_to_numpy(mask):
    """Return ``mask`` as a float32 NumPy array on the CPU, ready for matplotlib.

    matplotlib cannot convert CUDA tensors (or bfloat16 tensors) on its own, so
    detach, upcast and move to host memory first. NumPy input is accepted too.
    """
    return torch.as_tensor(mask).detach().float().cpu().numpy()


if not ff_cache:
    raise RuntimeError("ff_cache is empty: the forward pass did not record any attention masks")

# The first entry squeezes to a 2-D mask of query positions x key positions.
T, Tp = ff_cache[0].squeeze(0).shape

# Plot the first layer's mask from the first batch element as a bitmap.
mask = _mask_to_numpy(ff_cache[0].squeeze(0))  # Shape: [T, Tp]

plt.figure(figsize=(10, 10))
plt.imshow(mask, cmap='viridis', aspect='equal')
plt.colorbar(label='Mask Value')
plt.title('Attention Mask')
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.show()

# Plot the masks from all layers, one row per layer. squeeze=False keeps `axes` a 2-D
# array even for a single layer, where plt.subplots would return a bare Axes object.
n_layers = len(ff_cache)
fig, axes = plt.subplots(n_layers, 1, figsize=(10, 4 * n_layers), squeeze=False)
for i, layer_mask in enumerate(ff_cache):
    mask = _mask_to_numpy(layer_mask[0])  # first batch element
    ax = axes[i, 0]
    ax.imshow(mask, cmap='RdYlGn_r', vmin=0, vmax=mask.max(), aspect='equal')
    ax.set_title(f'Layer {i} Mask')
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Query Position')
plt.tight_layout()
plt.show()
