In [None]:
import torch
import os.path as osp
import pickle
import matplotlib.pyplot as plt
import numpy as np
import math
import torch.nn as nn

# import transformers
from transforming.train import run_experiment
from transforming.config_objects import ExperimentCfg, DatasetCfg
from transforming.data import IdxDataset
from transforming.encoder import get_encoder
from transforming.network import Transformer
from transforming import utils
from transforming import metrics


# Use the GPU when torch reports one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# set up autoreloading of shared code
# (autoreload mode 1 re-imports only the modules registered with %aimport)
%load_ext autoreload
%autoreload 1
%aimport transforming.train,transforming.config_objects,transforming.data,transforming.encoder,transforming.network
# A bare %aimport lists the modules currently registered for autoreload.
# NOTE(review): transforming.utils and transforming.metrics are imported in this notebook but are
# not registered here, so edits to them would need a kernel restart — confirm that is intended.
%aimport

In [None]:
# Directory handed to DatasetCfg as dataset_path; "train.bin" / "eval.bin" are requested from it below.
# NOTE(review): absolute, user-specific path — adjust when running on another machine.
data_dir = "/scratch/ssd004/scratch/jackk/1-billion-word-language-modeling-benchmark-r13output"

# Set to False to keep the full-size configuration below instead of its dry-run variant.
DRY_RUN = True

exp_config = ExperimentCfg(vec_size=1536,
                        n_layer=12,
                        n_heads=12,
                        lr_max=2e-4,
                        lr_min=1e-7,
                        block_size=1024,
                        batch_size=1,
                        grad_accum_steps=256,
                        train_steps=500, # num macro batches
                        num_eval=300,  # num micro batches
                        dtype="float16",
                        compile=True,
                        zero=True,
                        checkpointing=False,
                        normalizer_type="RMSNorm",
                        rmsnorm_p=0.1,
                        layer_norm_posn="pre",
                        posn_embed_type="relative",
                        flash=False,
                        learnable_unembed=True,
                        job_id=0,
                        relative_float32_attn=True
                        )
if DRY_RUN:  # if dry run, overwrite config with dry_run config
    exp_config = exp_config.get_dry()

# Turn the ddp flag off for this notebook session.
exp_config.ddp = False

dset_config = DatasetCfg(dataset_path=data_dir,
                        num_workers=4
                        )

# One IdxDataset per split, both built from the same experiment and dataset configs.
datasets = dict(train=IdxDataset("train.bin", exp_config, dset_config),
                eval=IdxDataset("eval.bin", exp_config, dset_config))


In [None]:
# Construct the Transformer with a checkpoint path, the experiment config and the train dataset's
# cfg, then move it to the first CUDA device.
# NOTE(review): "cuda:0" is hard-coded even though `device` is computed in the first cell.
net = Transformer("/checkpoint/jackk/9898689/large-multi-gpu-zero-relposn-smooth.ckpt", exp_config, datasets["train"].cfg).to("cuda:0")
# net.load_model_state_dict("cuda:0")
# Switch to evaluation mode; as the last expression this also displays the module.
net.eval()

In [None]:
# Qualitative check: call the project's sentence-sampling helper with the loaded model, datasets and config.
utils.sample_random_sentences(net, datasets, exp_config)

In [None]:
# For every named parameter, keep the largest 5% of its absolute values so the
# magnitude of the biggest weights can be inspected per tensor.
TOP_FRACTION = 0.05
results = {}
with torch.no_grad():  # inspection only: do not record autograd history on the parameters
    for name, p in net.named_parameters():
        # numel() gives the element count directly and also works for 0-dim parameters,
        # where folding an empty shape with reduce(mul, ...) would raise.
        total_size = p.numel()
        # n_nan = p.isnan().sum()
        # results[name] = n_nan/total_size
        num_select = int(TOP_FRACTION * total_size)
        results[name] = torch.topk(p.abs().flatten(), num_select).values

In [None]:
# Display the per-parameter top-|value| tensors gathered in `results`.
results  # based on late stage, float 32 relative posn

In [None]:
# Pull one fixed training sample (index 8) and show it both as raw values and decoded by the dataset's encoder.
example = datasets["train"][8]
encoder = datasets["train"].encoder
print(example[0])
# split=True is forwarded to the project encoder; presumably it returns per-token pieces — confirm.
print(encoder.decode(example[0].numpy(), split=True))

In [None]:
# Toggle the relative_float32_attn flag on the model's config, re-run architecture initialisation,
# and move the model back onto GPU 0.
# NOTE(review): whether initialize_architecture() keeps the loaded weights is not visible here — confirm.
net.cfg.relative_float32_attn = True
net.initialize_architecture()
net = net.cuda(0)

In [None]:
# The next three cells run the same gradient-free forward pass on `example`; the trailing
# comments appear to record which relative_float32_attn setting each run was made with.
# NOTE(review): net.train() is never undone in these cells, so the model stays in training mode
# for whatever runs next (a later cell restores it with net.eval()) — confirm this is intended.
with torch.no_grad():  # float32 = False
    net.train()
    # Add a leading batch dimension for the model, then squeeze size-1 dimensions from the output.
    final = net(example[0].cuda(0).unsqueeze(0)).squeeze()

In [None]:
with torch.no_grad():  # float32 = True
    net.train()
    final = net(example[0].cuda(0).unsqueeze(0)).squeeze() 

In [None]:
with torch.no_grad():
    net.train()
    final = net(example[0].cuda(0).unsqueeze(0)).squeeze()

In [None]:
# Generate from the model using a full sentence (ending in a newline) as the prompt.
net.generate(encoder, 'Official , and the government \'s " The Daily Show " and " The Daily Show " were all in a hurry to get back to work .\n')

In [None]:
# Same gradient-free forward pass on `example`, kept from before float32 relative position was added.
# NOTE(review): leaves the model in training mode, like the other forward-pass cells.
with torch.no_grad():  # old, before adding float 32 relative position
    net.train()
    final = net(example[0].cuda(0).unsqueeze(0)).squeeze()

In [None]:
# Take the argmax over the last dimension of `final` and decode the resulting ids back to text.
encoder.decode(final.argmax(dim=-1).cpu().numpy())

In [None]:
# Heatmap of the relative-position attention term cached on one attention block.
# The block is selected once so shifted_posn and head_size are read from the same module
# (previously blocks[11] and blocks[-1] were mixed).
LAYER_IDX = 11
mha = net.blocks[LAYER_IDX].mha
# Scale by sqrt(head_size) before the softmax.
query_posn = mha.shifted_posn.squeeze().cpu() / np.sqrt(mha.head_size)
# Softmax over the last dim, summed over dim 0, then clamped into [0, 1].
max_shifted = torch.clamp(torch.softmax(query_posn, dim=-1).sum(dim=0), 0, 1)
select = max_shifted #torch.cat([max_shifted[:, :140], max_shifted[:, -140:]], dim=-1)
# print(max_shifted)
# print(query_posn)
plt.figure(figsize=(15,15))
# Symmetric colour limits keep zero at the midpoint of the diverging "bwr" colormap;
# computed once and passed to matplotlib as a plain float.
limit = select.abs().max().item()
plt.imshow(select.numpy(), cmap="bwr", vmin=-limit, vmax=limit)


In [None]:
# Debugging cells for the tensors cached on attention block 7.
# dtype of the cached attn_dots at index [0, 11, 96] (presumably batch, head, query position — confirm).
net.blocks[7].mha.attn_dots[0, 11, 96].dtype

In [None]:
# Indices at which a softmax over the last dim of the cached query_posn yields NaN.
torch.softmax(net.blocks[7].mha.query_posn, dim=-1).isnan().nonzero()

In [None]:
# A length-1024 float32 vector filled with the literal -80000_00000.0 (i.e. -8e9).
torch.full((1024,), -80000_00000.0, dtype=torch.float32)

In [None]:
# Number of infinite entries in the cached query_posn.
net.blocks[7].mha.query_posn.isinf().sum()

In [None]:
# Softmax of the single slice [0, 11, 230] after adding a 1e-8 offset.
torch.softmax(net.blocks[7].mha.query_posn[0,11,230]+1e-8, 0)

In [None]:
# The same slice, taken from the softmax of the whole tensor, for comparison with the cell above.
torch.softmax(net.blocks[7].mha.query_posn, dim=-1)[0,11,230]

In [None]:
# Heatmap of query_posn at index [0, 11], with a colour scale.
plt.imshow(net.blocks[7].mha.query_posn[0,11].cpu().numpy())
plt.colorbar()
plt.colorbar()

In [None]:
# Empty per-split lists for each metric; the dict is passed to metrics.evaluate and read back afterwards.
all_metrics = {k: dict(train=[], eval=[]) for k in ["loss", "perplexity", "accuracy"]}
# NOTE(review): this mutates the shared exp_config, so later cells also see num_eval == 50.
exp_config.num_eval = 50
metrics.evaluate(net, datasets, exp_config, all_metrics)

In [None]:
# Show the "perplexity" entry of the metrics dict (one list per split).
all_metrics["perplexity"]

In [None]:
# Decode the argmax over dim 1 of `final` (equivalent to dim=-1 only when `final` is 2-D).
encoder.decode(final.argmax(dim=1).cpu().numpy())

In [None]:
# Build a model with an empty path string and a two-token input/target pair for gradient checks.
# NOTE(review): this rebinds `net` and mutates the shared exp_config in place.
# NOTE(review): 1280 is not divisible by the n_heads=12 set in the config cell; confirm the active
# (possibly dry-run) config uses a compatible head count.
exp_config.vec_size = 1280
exp_config.n_layer = 5
net = Transformer("", exp_config, datasets["train"].cfg).to("cuda:0")
# Each tensor gets a leading batch dimension of 1 via unsqueeze(0).
simple_inpt = torch.from_numpy(np.asarray([5, 2])).cuda(0).unsqueeze(0)
simple_outpt = torch.from_numpy(np.asarray([2, 9])).cuda(0).unsqueeze(0)
opt = torch.optim.SGD(net.parameters(), 1e-3)

In [None]:
# Baseline: backpropagate the loss (first element of the model output) and print the embedding gradient.
loss = net(simple_inpt, simple_outpt)[0]
loss.backward()
print(net.embed.weight.grad)
# Clear gradients without stepping, so the next cell starts from empty .grad fields.
opt.zero_grad(set_to_none=True)

In [None]:
# Repeat the gradient check after enabling activation checkpointing; the printed gradient can be
# compared with the one from the previous cell.
# NOTE(review): re-running this cell calls add_activation_checkpointing() again — confirm that is harmless.
net.add_activation_checkpointing()
loss = net(simple_inpt, simple_outpt)[0]
loss.backward()
print(net.embed.weight.grad)
opt.zero_grad(set_to_none=True)

In [None]:
# Fetch one training sample and show the first 20 decoded pieces of its input and target.
with torch.no_grad():
    batch_idx = 50
    encoder = datasets["train"].encoder
    # Index the dataset once so input and target are guaranteed to come from the same
    # __getitem__ call (and the sample is not loaded twice).
    sample = datasets["train"][batch_idx]
    x_example = sample[0].cuda(0).unsqueeze(0)
    y_example = sample[1].cuda(0).unsqueeze(0)
    print(encoder.decode(x_example.cpu().numpy().squeeze(), split=True)[:20])
    print(encoder.decode(y_example.cpu().numpy().squeeze(), split=True)[:20])

In [None]:
# Generate a continuation with temperature=0 (presumably greedy decoding — confirm in Transformer.generate).
net.generate(encoder, " Analysts warned that", temperature=0)

In [None]:
# Forward pass with targets in training mode, then restore evaluation mode afterwards.
net.train()
ans = net(x_example, y_example)
# ans[1][0][:20]: second output, first batch element, first 20 positions; argmax over the last dim is decoded.
print(encoder.decode(ans[1][0][:20].argmax(dim=-1).cpu().numpy()))
net.eval()

In [None]:
# Free-form generation probes with different prompts.
net.generate(datasets["train"].encoder, prompt="In other news,")

In [None]:
# Same call with temperature=0.
net.generate(datasets["train"].encoder, prompt="The people were arrested on suspicion", temperature=0)

In [None]:
# Instruction-style prompt.
# NOTE(review): the prompt spells "truthfullness" (sic); left untouched because it is model input.
net.generate(datasets["train"].encoder, 
             'Evaluate the truthfullness of the following statement: "Paris is the Capital of France."\n ')

In [None]:
# Run the experiment and plot a window of the returned values (used here as learning rates) on a log axis.
lrs = run_experiment(datasets, "transformer-experiments-google-1-billion", "checkpoints/small-1-gpu.ckpt", exp_config, compile=False)
# Keep only entries 599,000-601,999.
# NOTE(review): if fewer than 599,000 values are returned this slice is empty and max(lrs) below raises.
lrs = lrs[599_000:602_000]
plt.plot(lrs)
plt.gca().set_yscale('log')
#plt.hlines([exp_config.lr_min, exp_config.lr_max], 0,len(lrs), linestyle="--")
# Dashed guides: horizontal at lr_min, vertical at offset 1,000 inside the window.
plt.hlines([exp_config.lr_min], 0,len(lrs), linestyle="--")
plt.vlines([1_000], exp_config.lr_min, max(lrs), linestyle="--")

In [None]:
# Encoder round-trip check: encode a sentence, print the ids, decode them back, and print the encoder's cache.
enc = get_encoder()
idx_list = enc.encode("Yo what up, that's so call! Indubitably, albeit that's incomprehensively not watto strengthening my resolve?")
print(idx_list)
print(enc.decode(idx_list))
print(enc.cache)

In [None]:
# Encode the raw benchmark shards into one binary file per split.
import glob  # not among the notebook's top-level imports; without it this cell raises NameError

corpus_root = "1-billion-word-language-modeling-benchmark-r13output"
# (source sub-directory, output file) for the train and eval splits.
split_sources = [
    ("training-monolingual.tokenized.shuffled", "train.bin"),
    ("heldout-monolingual.tokenized.shuffled", "eval.bin"),
]
for shard_dir, out_name in split_sources:
    eng_files = glob.glob(f"{corpus_root}/{shard_dir}/*")
    # Shuffle the shard order in place before encoding.
    # NOTE(review): no seed is set, so the order differs between runs.
    np.random.shuffle(eng_files)
    enc.encode_file_list(f"{corpus_root}/{out_name}", eng_files)

In [None]:
#train_dataset = TextDataset(lines)
# NOTE(review): this cell calls IdxDataset / Transformer with different signatures than the
# cells near the top of the notebook, so it appears to target an older project API. Only the
# errors that are wrong regardless of that API are fixed here; confirm the rest before running.
# NOTE(review): rebinding data_dir and datasets here shadows the objects built in the config cell.
data_dir = "1-billion-word-language-modeling-benchmark-r13output"
# The eval split previously pointed at the "train" file as well, so evaluation ran on training data.
datasets = dict(train=IdxDataset(osp.join(data_dir, "train")),
                eval=IdxDataset(osp.join(data_dir, "eval")))
# DataLoader is referenced through torch.utils.data because the bare name is never imported.
dataloaders = {split: torch.utils.data.DataLoader(dataset, batch_size=16,
                            sampler=torch.utils.data.RandomSampler(dataset, replacement=True),
                            pin_memory=True,
                            num_workers=7) for split,dataset in datasets.items()}
print([len(v) for v in dataloaders.values()])
 #   def __init__(self, vocab_size, n_layer, vec_size, n_heads, block_size):

# `datasets` is a dict and has no vocab_size attribute; read it from the train dataset instead.
# NOTE(review): assumes IdxDataset exposes vocab_size — confirm.
model = Transformer(datasets["train"].vocab_size, n_layer=2, vec_size=120, n_heads=5, block_size=512, save_name="gpt1").to(device)
# Keep the loss as a callable: cross_entropy() with no arguments raises, and `F` was never imported.
loss_func = nn.functional.cross_entropy
# Optimizers live in torch.optim; torch.nn has no `optim` submodule.
optim = torch.optim.Adam(model.parameters())
# NOTE(review): load_state_dict(optim=...) is not the torch.nn.Module signature; presumably a
# project-specific override — confirm.
model.load_state_dict(optim=optim)

In [None]:
# Train the small model built in the previous cell.
# NOTE(review): no `train` function is imported in this notebook (only run_experiment is imported
# from transforming.train), so this call raises NameError unless it is defined elsewhere.
train(model, optim, loss_func, 50, dataloaders, device=device)

In [None]:
def cprint(*args):
    """Print the string forms of several objects side by side, as columns.

    Each argument is converted with ``str`` and split on newlines. Every column is
    left-justified to the width of its longest line, and columns are separated by
    two spaces. Shorter columns are padded with blanks on the rows they lack.
    Trailing whitespace is stripped from each printed row, and a blank line is
    printed at the end.

    Args:
        *args: Objects to show next to each other (e.g. tensors or arrays).
            With no arguments only the trailing blank line is printed.

    Returns:
        None. The output is written to stdout.
    """
    columns = [str(arg).split('\n') for arg in args]
    # A column's width depends only on the column, so compute it once instead of once per row.
    widths = [max(len(line) for line in column) for column in columns]
    # default=0 turns the zero-argument call into a no-op instead of a ValueError.
    n_rows = max((len(column) for column in columns), default=0)

    for i in range(n_rows):
        cells = [
            (column[i] if i < len(column) else '').ljust(width) + '  '
            for column, width in zip(columns, widths)
        ]
        print(''.join(cells).rstrip())
    print()

In [None]:
# Step-by-step walkthrough of mapping relative positions to bucket indices (this appears to follow
# the T5-style relative position bucketing — confirm against the reference implementation).
seq_len = 512
bidirectional = False
num_buckets = 16
max_distance = 64
# x: a lookup vector whose values are offset by 500, so looked-up entries are easy to tell from indices.
x = torch.arange(seq_len) + 500
# x2: a randomly initialised 1-dimensional embedding table used for the same lookup (unseeded).
x2 = nn.Embedding(seq_len, 1)

# Column vector of query positions and row vector of key positions; their difference broadcasts
# to a (seq_len, seq_len) matrix with entry [i, j] = j - i.
context_position = torch.arange(seq_len, dtype=torch.long)[:, None]
memory_position = torch.arange(seq_len, dtype=torch.long )[None, :]
relative_position = memory_position - context_position

print(relative_position)

relative_buckets = 0
if bidirectional:
    # Split the buckets in half: positive offsets use the upper half, and distances become absolute.
    num_buckets //= 2
    relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
    relative_position = torch.abs(relative_position)
else:
    # Elementwise minimum with zero sets the upper-right triangle (j > i) to 0; negating then
    # leaves the non-negative distance i - j in the lower triangle.
    relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) 
print(relative_position)
# now relative_position is in the range [0, inf)

# half of the buckets are for single increment
max_exact = num_buckets // 2
is_small = relative_position < max_exact
print(is_small)

# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
# seq_len - max_exact is the num of positions we have for the log-bins
# but we only want to go up to position max_distance
# Entries with relative_position == 0 give log(0) = -inf here; they are discarded by the
# torch.where below because is_small is True for them.
relative_position_if_large = max_exact + (
    torch.log(relative_position.float() / max_exact)   # ie. log(rel_posn) - log(max_exact)
    / math.log(max_distance / max_exact)  # ie. log(max_distance) - log(max_exact) => at posn max_distance the ratio -> 1
    * (num_buckets - max_exact)   # so that now at max_distance the value is num_buckets - max_exact
)
print(relative_position_if_large)
relative_position_if_large = relative_position_if_large.long()
# print(relative_position_if_large)
# Cap at the last bucket: anything that went past num_buckets - 1 (i.e. distances beyond
# max_distance) is set to num_buckets - 1. Those positions are definitely "large", so sharing
# the final bucket makes sense.
relative_position_if_large = torch.min(
    relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
# print(relative_position_if_large)

cprint(relative_position, relative_position_if_large)
# Small distances keep their exact value as the bucket; larger ones use the log-spaced bucket.
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
cprint(relative_buckets, relative_position)
cprint(relative_buckets[-1][-20:], is_small[-1][-20:])
# Use the bucket indices to look values up in the flat vector x and in the embedding table x2.
print(torch.take(x, relative_buckets))
print(x2.weight.squeeze())
print(x2(relative_buckets).squeeze())

In [None]:
# Inspect the log-bucket formula on the last row, last 66 columns: first the normalised log ratio,
# then the same ratio scaled by the number of log buckets.
# Relies on relative_position, max_exact, max_distance and num_buckets from the previous cell.
print((torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact))[-1, -66:])
print((torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact))[-1, -66:]) #+ max_exact)
# Recompute the un-truncated (float) large-distance bucket values, overwriting the integer version.
relative_position_if_large = max_exact + (
    torch.log(relative_position.float() / max_exact)
    / math.log(max_distance / max_exact)
    * (num_buckets - max_exact)
)
# 