In [1]:
%load_ext autoreload
%autoreload 2

import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from tqdm import tqdm
from torch.nn import functional as F
from context_compression.model import GPT, GPTConfig
from context_compression.attn import AttentionKind
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
out_dir = 'out' # ignored if init_from is not 'resume'
start = "Hello, I'm a language model," # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 256 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = False # use PyTorch 2.0 to compile the model to be faster

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)


In [5]:
# init from a model saved in a specific directory

# model = GPT(GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=True, vocab_size=50304))
# ckpt_path = "/workspace/context-compression/selective_run_0_continued/model_09999.pt"

# model = GPT(GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=True, vocab_size=50304))
# ckpt_path = "/workspace/context-compression/selective_run_0_continued/model_09999.pt"

# model = GPT(GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=False, vocab_size=50304))
# ckpt_path = "/workspace/context-compression/memory_loss_run_0/model_09999.pt"

# config = GPTConfig(attention_kind=AttentionKind.SELF, for_inference=True, vocab_size=50304)
# model = GPT(config)
# ckpt_path = "/workspace/context-compression/unselective_run_0/model_09999.pt"

config = GPTConfig(attention_kind=AttentionKind.SELECTIVE, for_inference=False, vocab_size=50304)
model = GPT(config)
ckpt_path = "/root/.cache/huggingface/hub/models--Yorth--selective1/snapshots/1d3d987c90be4b8d6f58de60749ba5823f0ecd29/model.pt"


checkpoint = torch.load(ckpt_path, map_location=device)

state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

from context_compression.add_a_head import grow_qkv_o

add_a_head = False
if add_a_head:
    grow_qkv_o(config,model)


model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

  checkpoint = torch.load(ckpt_path, map_location=device)


In [6]:
x = torch.randint(0,200, (4,20),device=device).int()
# print(x)
model(x)

(tensor([[[  2.8258,   2.2109,  -2.9415,  ...,  -8.3719,  -8.6030,  -8.4726],
          [ -0.9571,   3.3208,   2.1000,  ...,  -3.8942,  -3.8509,  -3.5520],
          [  3.5440,   3.7722,   2.2224,  ..., -10.9395, -11.3086, -10.3937],
          ...,
          [  4.6336,   2.8386,   4.9399,  ...,  -9.1376,  -9.3667,  -8.7789],
          [  4.4314,   2.1704,   3.7883,  ..., -10.3590, -10.7607, -10.1031],
          [  4.5444,   2.6262,   4.1197,  ..., -10.6649, -10.9120, -10.2389]],
 
         [[  1.5417,   2.7827,   0.1449,  ...,  -9.6983, -10.0805,  -9.5776],
          [  1.9891,   2.7411,   0.0883,  ..., -11.1865, -11.6592, -11.0090],
          [  3.0581,   2.4795,   0.0610,  ..., -10.9167, -11.3798, -10.8770],
          ...,
          [ -1.2146,   2.0913,  -0.0435,  ...,   2.2149,   2.0899,   1.7990],
          [  5.9899,   4.2384,   3.8320,  ...,  -5.3356,  -5.8125,  -5.2999],
          [ -0.0913,   2.3449,   0.9392,  ...,  -1.0786,  -1.2305,  -0.5390]],
 
         [[  3.3563,   2.515