# Sample and evaluate both character transformer models 

## Sampling...

In [11]:
import torch
from torch.nn import functional as F
from contextlib import nullcontext

import math
import os
import time
import numpy as np

In [2]:
from baseline.model import GPTModel, GPTConfig
from prefix_suffix.model import CharModel, CharConfig
from data.tokenizer import BuildTokenizer

baseline_checkpoint_path = '/home/ubuntu/abhinav/char_level_model_transformer/baseline/checkpoints/baseline_ckpt.pt'
prefix_suffix_checkpoint_path = '/home/ubuntu/abhinav/char_level_model_transformer/prefix_suffix/checkpoints/prefix_suffix_ckpt.pt'

device = 'cuda'
dtype = 'bfloat16'
seed = 1337

# Torch Magic
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

### Load Baseline Model

In [3]:
VOCAB_SIZE = 256
CONTEXT_LENGTH = 512
HIDDEN_SIZE = 768
NUM_LAYERS = 12
NUM_HEADS = 12
BATCH_SIZE = 12
DROPOUT = 0.2

baseline_ckpt = torch.load(baseline_checkpoint_path, map_location=device)
baseline_config = GPTConfig(
    vocab_size = VOCAB_SIZE,
    context_length = CONTEXT_LENGTH,
    n_layers = NUM_LAYERS,
    n_hidden = HIDDEN_SIZE,
    n_head = NUM_HEADS,
    dropout = DROPOUT,
    bias = True
)
baseline_model = GPTModel(baseline_config)

state_dict = baseline_ckpt['model']
bad_prefix = '_orig_mod.'
for key, value in list(state_dict.items()):
    print(f"Key: {key}; Value: {value}")
    if key.startswith(bad_prefix):
        state_dict[key[len(bad_prefix):]] = state_dict.pop(key)

baseline_model.load_state_dict(state_dict)
baseline_model.to(device)

  baseline_ckpt = torch.load(baseline_checkpoint_path, map_location=device)


Initialized GPT Model! Number of Parameters: 85.645824
Key: _orig_mod.transformer.embedding_lookup.weight; Value: tensor([[-0.0235, -0.0916, -0.0835,  ..., -0.0385,  0.0472, -0.0162],
        [ 0.0303, -0.0406,  0.0066,  ..., -0.0407,  0.0492, -0.0143],
        [ 0.0411, -0.0495,  0.0170,  ..., -0.0171, -0.0323,  0.0215],
        ...,
        [-0.0575, -0.1151, -0.0798,  ..., -0.0444,  0.0515,  0.0165],
        [-0.0267, -0.1009, -0.0795,  ..., -0.0296,  0.0276,  0.0040],
        [-0.0449, -0.0879, -0.0894,  ..., -0.0253,  0.0611, -0.0075]],
       device='cuda:0')
Key: _orig_mod.transformer.positional_embedding.weight; Value: tensor([[ 0.0063, -0.0661,  0.0195,  ...,  0.0235, -0.0005, -0.0187],
        [-0.1033,  0.0116, -0.1044,  ...,  0.0101,  0.0706, -0.0672],
        [-0.0579,  0.0296, -0.0490,  ..., -0.0443,  0.0087, -0.0661],
        ...,
        [-0.0216, -0.0035, -0.0755,  ..., -0.0236,  0.0337,  0.0284],
        [-0.0198, -0.0143, -0.0407,  ...,  0.0107,  0.0233, -0.0120],
  

GPTModel(
  (transformer): ModuleDict(
    (embedding_lookup): Embedding(256, 768)
    (positional_embedding): Embedding(512, 768)
    (dropout): Dropout(p=0.2, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x GPTLayer(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (w_o): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (residual_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (ff): FeedForward(
          (linear_1): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='none')
          (linear_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (layer_norm): LayerNorm()
  )
  (vocab_projection): Linear(in_features=768, out_features=256, bias=False)
)

### Sample character generations as a sanity check

In [4]:
# Rebuild tokenizer on data
data_path = '/home/ubuntu/abhinav/char_level_model_transformer/data/enwik8'

t0 = time.time()
print(f"Starting to ingest Train/Valid/Test data...")
tokenizer = BuildTokenizer(path=data_path, tokenize_from_scratch=True)
print(f"Finished ingesting; Time: {(time.time() - t0) * 1000 :.2f}ms")

Starting to ingest Train/Valid/Test data...
Starting Train Ingestion...


Train dataset size: 90000001
Starting Valid Ingestion...
Valid dataset size: 5000001
Starting Test Ingestion...
Test dataset size: 5000001
Finished ingesting; Time: 22135.41ms


In [5]:
print("Unique Characters:", tokenizer.num_unique_chars)
print("Char2Idx:", tokenizer.char2idx)

Unique Characters: 206
Char2Idx: {'32': 1, '60': 2, '109': 3, '101': 4, '100': 5, '105': 6, '97': 7, '119': 8, '107': 9, '120': 10, '108': 11, '110': 12, '115': 13, '61': 14, '34': 15, '104': 16, '116': 17, '112': 18, '58': 19, '47': 20, '46': 21, '111': 22, '114': 23, '103': 24, '45': 25, '48': 26, '51': 27, '50': 28, '49': 29, '88': 30, '77': 31, '76': 32, '83': 33, '99': 34, '118': 35, '62': 36, '<eos>': 37, '102': 38, '87': 39, '98': 40, '95': 41, '80': 42, '54': 43, '121': 44, '84': 45, '85': 46, '52': 47, '53': 48, '73': 49, '55': 50, '56': 51, '57': 52, '72': 53, '67': 54, '65': 55, '90': 56, '117': 57, '74': 58, '35': 59, '82': 60, '69': 61, '68': 62, '91': 63, '93': 64, '123': 65, '125': 66, '70': 67, '42': 68, '39': 69, '124': 70, '106': 71, '38': 72, '113': 73, '59': 74, '122': 75, '44': 76, '71': 77, '40': 78, '41': 79, '79': 80, '75': 81, '86': 82, '66': 83, '78': 84, '195': 85, '169': 86, '63': 87, '89': 88, '170': 89, '162': 90, '160': 91, '179': 92, '81': 93, '33': 94, 

In [6]:
def encode_text(text):
    return (torch.from_numpy(tokenizer.encode(text)).unsqueeze(0).to(device))    

def decode_tokens(text):
    return (tokenizer.decode(text[0].cpu().numpy()))

starting_prefix = "Hello! I am a human"
x = encode_text(starting_prefix)
print(f"Starting Prefix: {starting_prefix}")

num_samples = 20
num_new_tokens = 200
temp = 0.8
top_k = 200

for i in range(num_samples):
    tokens = baseline_model.generate(x, num_new_tokens, temp, top_k)
    print(f"Sample #{i + 1}: {decode_tokens(tokens)}")
    

Starting Prefix: Hello! I am a human
Sample #1: Hello! I am a human
kkkkkkk de>240724 /de>
kkkkkk /contrdbutor>
kkkkkk co<<mnt>Robot:kFdxdngkEeedmkRmdchmykchismkadthkirtdclm /co<<mnt>
kkkkkk tmxtkx<l:spicm="prmsmrvm">#REDIRECTk[[TrdndeiekinekTobigo]] /tmxt>
kkkk /rmvd
Sample #2: Hello! I am a human
|----
|[[Flig]]
|kFrmnchToank(Frmnch)
|----
|[[FligkofkthmkUndtmekStitms|Flig]]
|kildgn=&quot;lmft&quot;k|k[[I<igm:FligklirgmkMilfortunm.png]]
|kildgn=&quot;lmft&quot;k|k[[I<igm:FligkofkthmkUndtmekSti
Sample #3: Hello! I am a human
kkkkkk de>490576 /de>
kkkkkk td<msti<p>2006-02-20T01:14:06Z /td<msti<p>
kkkkkk contrdbutor>
kkkkkkkk usmrni<m>ChirlmskSiltadnton /usmrni<m>
kkkkkkkk de>19432 /de>
kkkkkk /contrdbutor>
kkkkkk <dnork/>

Sample #4: Hello! I am a human
FACNY.

===FIPSkipprovil===
FIPSk([[FIPSkmlm<mntkipprovilkrulmboow|mlm<mnt]]s)kipprovmek8.5k<dlldonkaore.kThmsmkdssumskirmkthmksi<mkiskdnkthmkSITkvmlocdty.kkThmklmttmrkFIPSkdskikclosmekipprovilkinekun
Sample #5: Hello! I am a human
kkkk

## Evaluating (in bpc)

In [22]:
print(tokenizer.char2idx)
#print(tokenizer.char2idx['4'])
print(tokenizer.idx2char[4])

{'32': 1, '60': 2, '109': 3, '101': 4, '100': 5, '105': 6, '97': 7, '119': 8, '107': 9, '120': 10, '108': 11, '110': 12, '115': 13, '61': 14, '34': 15, '104': 16, '116': 17, '112': 18, '58': 19, '47': 20, '46': 21, '111': 22, '114': 23, '103': 24, '45': 25, '48': 26, '51': 27, '50': 28, '49': 29, '88': 30, '77': 31, '76': 32, '83': 33, '99': 34, '118': 35, '62': 36, '<eos>': 37, '102': 38, '87': 39, '98': 40, '95': 41, '80': 42, '54': 43, '121': 44, '84': 45, '85': 46, '52': 47, '53': 48, '73': 49, '55': 50, '56': 51, '57': 52, '72': 53, '67': 54, '65': 55, '90': 56, '117': 57, '74': 58, '35': 59, '82': 60, '69': 61, '68': 62, '91': 63, '93': 64, '123': 65, '125': 66, '70': 67, '42': 68, '39': 69, '124': 70, '106': 71, '38': 72, '113': 73, '59': 74, '122': 75, '44': 76, '71': 77, '40': 78, '41': 79, '79': 80, '75': 81, '86': 82, '66': 83, '78': 84, '195': 85, '169': 86, '63': 87, '89': 88, '170': 89, '162': 90, '160': 91, '179': 92, '81': 93, '33': 94, '226': 95, '128': 96, '148': 97, 

In [None]:
def calculate_bpc_and_ce(true_chars, logits):
    B, T = logits.size()
    probs = F.softmax(logits, dim = -1) # B, T, C

    true_indices = np.array([[c.item() for c in seq] for seq in true_chars]) # B, T
    true_probs = probs[torch.arange(B).unsqueeze(-1), torch.arange(T).unsqueeze(0), true_indices].cpu().detach().numpy() # B, T
    
    print("True Probs:", true_probs)
    print("True Probs Shape:", true_probs.shape)

    log_probs = -np.log2(true_probs)
    sequence_bpc = (1 / T) * np.sum(log_probs, axis = -1) # Sum log probs across sequence    
    
    ce_loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),  # (B*T, V)
        true_indices.view(-1),             # (B*T)
        reduction='mean'
    )
    
    # Average across batch
    return np.mean(sequence_bpc), ce_loss


# Use Test Dataset for BPC Evaluation
data = tokenizer.test
random_indices = torch.randint(len(data) - CONTEXT_LENGTH, (BATCH_SIZE,))

x = torch.stack([torch.from_numpy((data[i:i+CONTEXT_LENGTH]).astype(np.int64)) for i in random_indices])
y = torch.stack([torch.from_numpy((data[i+1:i+1+CONTEXT_LENGTH]).astype(np.int64)) for i in random_indices])

if device_type == 'cuda':
    x = x.pin_memory().to(device, non_blocking=True)
else:   
    x = x.to(device)

logits, _ = baseline_model(x, targets=None)
bpc, ce = calculate_bpc_and_ce(y, logits)
print("BPC:", bpc, "Cross Entropy:", ce)


True Probs: [[2.8490790e-04 1.7012362e-04 1.5426276e-04 ... 1.1279009e-03
  3.1307107e-01 1.7891818e-01]
 [2.2701817e-03 4.1410495e-02 3.4046359e-02 ... 7.8298412e-03
  6.3711912e-03 2.6610303e-01]
 [2.0441982e-08 9.5504360e-10 9.2110507e-11 ... 2.0441982e-08
  9.5504360e-10 9.5504360e-10]
 ...
 [1.8332181e-03 5.4889843e-03 2.0694073e-01 ... 9.6665788e-03
  1.8332181e-03 6.1497558e-02]
 [8.7079819e-08 5.0218972e-08 9.8962168e-08 ... 2.2127672e-10
  2.8076686e-08 9.9999654e-01]
 [6.4934924e-05 1.7434698e-02 1.6493299e-04 ... 2.1995541e-04
  2.0660781e-03 6.4934924e-05]]
True Probs Shape: (12, 512)
[ 6551.6484  3437.1611 14624.406   4644.896   4078.4824  3183.81
  6983.457  11147.225   9438.668   3438.8442 13291.825   6146.6006]
BPC: 7247.2524
