In [1]:
import numpy as np
from tqdm import tqdm
import time
import math
import gc
import torch
from datasets import load_dataset
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast, DataCollatorForLanguageModeling

In [2]:
# Re-import the local helper module so edits to stationary_reversal.py are
# picked up without restarting the kernel; the cell displays the module.
from importlib import reload
import stationary_reversal as sr

sr = reload(sr)
sr

<module 'stationary_reversal' from 'c:\\Users\\abhay\\Documents\\research\\reverse-dynamics-nlp\\reverse-llm-benchmarking\\stationary_reversal.py'>

In [3]:
# Single device used for the model and for every tensor fed to it in this notebook
# (falls back to CPU when no GPU is present).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Explicit .to(device) instead of device_map="auto": keeps the model on the same
# device the inputs are moved to, and does not require `accelerate`.
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m-deduped",
    revision="step3000",
).to(device)
model.eval()  # inference only

# Presumably chosen because Pythia reuses the GPT-NeoX-20B tokenizer — confirm
# against the Pythia model card if the model name changes.
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")


In [143]:
# Number of prefix tokens to reverse-sample in front of the fixed suffix.
prefix_length = 20
suffix = " Obama"
# encode(..., return_tensors="pt") yields a (1, suffix_length) id tensor.
tokenized_suffix = tokenizer.encode(suffix, return_tensors="pt").to(device)
suffix_length = tokenized_suffix.shape[-1]

# Load onto `device` rather than hard-coding .cuda(), so the notebook also runs on
# CPU-only machines (the device cell above already falls back to CPU).
loaded_dist = torch.load("../data/pi-pile10k-pythia160m.pt", map_location=device)

# The loaded distribution is currently overridden by a uniform prior over the same
# vocabulary (only its size is used). Set USE_UNIFORM_PRIOR = False to use the
# loaded distribution instead.
USE_UNIFORM_PRIOR = True
if USE_UNIFORM_PRIOR:
    # ones_like(...) / N keeps a floating result even if the file stores integer counts.
    empirical_dist = torch.ones_like(loaded_dist) / loaded_dist.shape[0]
else:
    empirical_dist = loaded_dist
vocab_size = empirical_dist.shape[0]

In [5]:
from reverse_sampling import sample_reverse_dynamics

# Keyword options for the reverse sampler, collected in one place for easy tuning.
reverse_sampling_options = dict(temperature=0.7, vocab_batch_size=128)

# Positional arguments are passed in the same order as before:
# model, prior, number of prefix tokens, suffix ids.
output = sample_reverse_dynamics(
    model, empirical_dist, prefix_length, tokenized_suffix, **reverse_sampling_options
)

100%|██████████| 393/393 [00:24<00:00, 15.72it/s]
100%|██████████| 393/393 [00:25<00:00, 15.49it/s]
100%|██████████| 393/393 [00:28<00:00, 13.93it/s]
100%|██████████| 393/393 [00:28<00:00, 13.61it/s]
100%|██████████| 393/393 [00:35<00:00, 10.92it/s]
100%|██████████| 393/393 [00:39<00:00,  9.90it/s]
100%|██████████| 393/393 [00:42<00:00,  9.34it/s]
100%|██████████| 393/393 [00:47<00:00,  8.20it/s]
100%|██████████| 393/393 [00:38<00:00, 10.16it/s]
100%|██████████| 393/393 [00:55<00:00,  7.11it/s]
100%|██████████| 393/393 [01:00<00:00,  6.51it/s]
100%|██████████| 393/393 [00:57<00:00,  6.83it/s]
100%|██████████| 393/393 [01:05<00:00,  5.97it/s]
100%|██████████| 393/393 [01:11<00:00,  5.52it/s]
100%|██████████| 393/393 [01:12<00:00,  5.45it/s]
100%|██████████| 393/393 [01:17<00:00,  5.06it/s]
100%|██████████| 393/393 [01:24<00:00,  4.67it/s]
100%|██████████| 393/393 [01:27<00:00,  4.50it/s]
100%|██████████| 393/393 [01:29<00:00,  4.39it/s]
100%|██████████| 393/393 [01:36<00:00,  4.08it/s]


In [6]:
first_sequence_ids = output[0]  # first row of the sampler's output
tokenizer.decode(first_sequence_ids)

'APIAPIAPIEnlambdaongongang Africa,” with the women of Sudan’s savannah, Obama'

In [146]:
model.eval()  # inference only (disables training-mode layers such as dropout)
vocab_size = empirical_dist.shape[0]
posterior = torch.zeros(vocab_size)  # NOTE(review): not read by any later cell shown here
# Number of 1024-token chunks needed to cover the vocabulary (integer ceil-division).
total_batches = -(-vocab_size // 1024)

In [147]:
int(total_batches)  # chunks the scoring loop will run

50

In [148]:
from stationary_reversal import get_logprob

# Score every vocabulary token as a one-token prefix of the suffix, in chunks of
# VOCAB_BATCH_SIZE candidate sequences per call to get_logprob.
VOCAB_BATCH_SIZE = 1024

# Hoisted out of the loop and moved with .to(device) instead of a hard-coded .cuda()
# that previously ran on every iteration.
prior = empirical_dist.to(device)

outs = []
# Pure inference: avoid building autograd graphs that `outs` would otherwise keep alive.
with torch.no_grad():
    for start_idx in tqdm(range(0, vocab_size, VOCAB_BATCH_SIZE)):
        end_idx = min(start_idx + VOCAB_BATCH_SIZE, vocab_size)  # last chunk may be partial
        batch_indices = torch.arange(start_idx, end_idx, device=device)
        # Each row: [candidate_token, *suffix_tokens]
        v_sentences = torch.cat(
            (batch_indices.unsqueeze(1), tokenized_suffix.repeat(batch_indices.size(0), 1)),
            dim=-1,
        )
        outs.append(get_logprob(v_sentences, model, prior))

100%|██████████| 50/50 [00:06<00:00,  7.22it/s]


In [149]:
def sample_with_temp(distribution, temperature):
    """Pick index(es) from unnormalised log-probabilities, greedily or by sampling.

    Args:
        distribution: tensor of logits / log-probabilities whose last dimension
            indexes the candidates; any leading dimensions are treated as a batch.
        temperature: 0 for greedy (argmax) selection; otherwise the logits are
            divided by this value and one index is drawn from the resulting
            categorical distribution.

    Returns:
        Long tensor of indices with the leading (batch) shape of `distribution`
        (a 0-d tensor for 1-D input).

    Raises:
        ValueError: if `temperature` is negative (dividing by it would invert
            the distribution instead of sharpening/flattening it).
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if temperature == 0:
        # Reduce over the candidate dim only, matching Categorical's batch semantics;
        # a bare argmax() would return a flattened index for batched input.
        return distribution.argmax(dim=-1)
    return torch.distributions.Categorical(logits=distribution / temperature).sample()

# Greedy (temperature 0) pick over the concatenated scores of all vocabulary chunks.
sample_with_temp(torch.cat(outs), temperature=0)

tensor(22306, device='cuda:0')

In [180]:
sampled_token_id = sample_with_temp(torch.cat(outs), 1.0)
# reshape(1) turns the 0-d index into a length-1 id sequence for decode().
tokenizer.decode(sampled_token_id.reshape(1))

'mem'

In [151]:
last_batch_logprobs = outs[len(outs) - 1]  # get_logprob output for the final vocabulary chunk
last_batch_logprobs

tensor([-25.7309, -20.9324, -20.8300, -25.0748, -22.0734, -21.5063, -21.6011,
        -24.1762, -22.5759, -25.6313, -22.3087, -23.7663, -23.6850, -23.7977,
        -22.4933, -23.2035, -25.0301, -23.6488, -21.2583, -25.0870, -21.4564,
        -20.7695, -22.6948, -22.2432, -20.6374, -23.7555, -23.2779, -26.2181,
        -23.1861, -21.4037, -24.6500, -19.4641, -22.4541, -23.3352, -22.4891,
        -23.9108, -22.7646, -25.2841, -22.6171, -24.0714, -23.7114, -22.8568,
        -21.5846, -25.0019, -22.4097, -23.0850, -22.8781, -21.9645, -23.3015,
        -21.9750, -25.3870, -21.3300, -24.3144, -22.2731, -22.1573, -21.3933,
        -23.4267, -23.4722, -22.9275, -20.3732, -20.9247, -24.3184, -20.8503,
        -20.8255, -22.0150, -22.6773, -21.8333, -24.4127, -26.7027, -21.5877,
        -22.2574, -24.1322, -23.0682, -25.8654, -24.6825, -24.6051, -22.7932,
        -23.5175, -27.0671, -24.1986, -22.8315, -23.1738, -25.2061, -23.4893,
        -23.7136, -23.0180, -25.2781, -25.2111, -24.0354, -24.06

In [152]:
float(vocab_size) / 1024.0  # non-integer result => the last 1024-token chunk is partial

49.125