In [1]:
import numpy as np
from tqdm import tqdm
import time
import math
import gc
import torch
from datasets import load_dataset
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast, DataCollatorForLanguageModeling

In [2]:
from importlib import reload
import stationary_reversal as sr
reload(sr)

<module 'stationary_reversal' from 'c:\\Users\\abhay\\Documents\\research\\reverse-dynamics-nlp\\reverse-llm-benchmarking\\stationary_reversal.py'>

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m-deduped",
    revision="step3000",
    device_map="auto",
)
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")


In [4]:
prefix_length = 10
suffix = " Obama"
tokenized_suffix= tokenizer.encode(suffix, return_tensors="pt").to(device)
# tokenized_suffix = tokenized_suffix.unsqueeze(0)
suffix_length = len(tokenized_suffix[0])
empirical_dist = torch.load("..\data\pythia-160m-deduped-v0_stationary_dist.pt").cuda()

vocab_size = empirical_dist.shape[0]

In [5]:
uniform_dist = torch.ones_like(empirical_dist) / empirical_dist.shape[0]
empirical_dist = empirical_dist * 0.8 + uniform_dist * 0.2

In [6]:
from reverse_sampling import sample_reverse_dynamics

output = sample_reverse_dynamics(
    model,
    empirical_dist,
    prefix_length,
    tokenized_suffix,
    temperature=0.7,
    vocab_batch_size=64
)

100%|██████████| 786/786 [00:31<00:00, 24.63it/s]
100%|██████████| 786/786 [00:34<00:00, 22.72it/s]
100%|██████████| 786/786 [00:22<00:00, 34.81it/s]
100%|██████████| 786/786 [00:20<00:00, 38.51it/s]
100%|██████████| 786/786 [00:24<00:00, 31.46it/s]
100%|██████████| 786/786 [00:36<00:00, 21.25it/s]
100%|██████████| 786/786 [01:00<00:00, 13.05it/s]
100%|██████████| 786/786 [01:03<00:00, 12.37it/s]
100%|██████████| 786/786 [01:14<00:00, 10.55it/s]
100%|██████████| 786/786 [00:38<00:00, 20.42it/s]


In [19]:
tokenizer.decode(output[0])

" {#sec2. Mueller's separate interviews of president Obama"

In [8]:
model.eval()
vocab_size = empirical_dist.shape[0]
posterior = torch.zeros(vocab_size)
total_batches = math.ceil(vocab_size / 1024)

In [9]:
total_batches

50

In [10]:
from tqdm import tqdm
from stationary_reversal import get_logprob

outs = []

for batch_num in tqdm(range(total_batches)):
    start_idx = batch_num * 1024
    end_idx = start_idx + 1024

    batch_indices = (
        torch.arange(start_idx, min(end_idx, vocab_size) ).clamp(0, vocab_size - 1).to(device)
    )
    v_sentences = torch.cat(
        (batch_indices.unsqueeze(1), tokenized_suffix.repeat(batch_indices.size(0), 1)),
        dim=-1,
    )
    
    probs = get_logprob(v_sentences, model, empirical_dist.cuda())
    outs.append(probs)

100%|██████████| 50/50 [00:13<00:00,  3.82it/s]


In [11]:
def sample_with_temp(distribution, temperature):
    if temperature == 0:
        p = distribution.argmax()
    else:
        p = torch.distributions.Categorical(
            logits = distribution / temperature
        ).sample()
    return p

sample_with_temp(torch.cat(outs), 0)

tensor(13, device='cuda:0')

In [12]:
tokenizer.decode(sample_with_temp(torch.cat(outs), 0.5).unsqueeze(0)) 

','

In [13]:
outs[-1]

tensor([-27.0615, -22.3925, -22.1784, -26.3934, -23.5982, -22.7611, -22.8974,
        -25.6003, -23.9575, -27.0090, -23.6367, -25.1839, -25.0625, -25.0923,
        -23.8681, -24.6099, -26.5199, -25.0262, -22.6178, -26.5734, -22.8094,
        -22.2456, -23.9815, -23.6188, -22.0856, -25.2053, -24.6022, -27.6299,
        -24.5784, -22.7017, -25.9854, -20.6962, -23.6919, -24.7146, -23.8885,
        -25.4163, -24.1862, -26.6214, -24.0831, -25.4412, -25.1563, -24.3024,
        -23.0253, -26.4156, -23.8638, -24.3872, -24.3518, -23.1973, -24.6612,
        -23.4344, -26.8461, -22.7198, -25.4789, -23.5496, -23.6457, -22.7627,
        -24.8559, -24.9498, -24.2420, -21.7104, -22.1432, -25.6042, -22.1379,
        -22.2354, -23.2504, -24.0490, -23.1965, -25.6685, -28.0224, -22.8650,
        -23.7653, -25.4745, -24.4520, -27.1851, -26.0717, -25.7205, -24.1750,
        -24.8818, -23.1033, -24.2869, -22.9034, -22.9289, -24.0184, -23.0561,
        -22.8144, -22.4728, -23.2446, -24.6217, -23.2994, -23.30

In [14]:
vocab_size /1024

49.125