In [1]:
import numpy as np
from tqdm import tqdm
import time
import math
import gc
import torch
from datasets import load_dataset
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast, DataCollatorForLanguageModeling

In [2]:
from importlib import reload
import stationary_reversal as sr

# Re-import the local helper module so edits to stationary_reversal.py are
# picked up without restarting the kernel.
# The trailing semicolon suppresses the module repr: left as the cell's last
# expression it gets saved as output and embeds this machine's absolute path
# (including the user name) in the notebook.
reload(sr);

<module 'stationary_reversal' from 'c:\\Users\\abhay\\Documents\\research\\reverse-dynamics-nlp\\reverse-llm-benchmarking\\stationary_reversal.py'>

In [3]:
# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Forward model: the `step3000` revision of Pythia-160m-deduped.
# NOTE(review): device_map="auto" delegates placement to transformers; on a
# single-GPU or CPU-only machine this presumably ends up on `device` — confirm.
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m-deduped",
    revision="step3000",
    device_map="auto",
)
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")

# Reverse model. Moved with .to(device) instead of a hard-coded .cuda() so the
# cell honours the CPU fallback chosen above rather than raising when no GPU
# is present.
reverse_model = GPTNeoXForCausalLM.from_pretrained(
    "afterless/reverse-pythia-160m"
).to(device)

In [4]:
# Number of tokens to sample in front of the suffix.
prefix_length = 10

# Suffix the sampled prefix has to lead into, encoded as a (1, n_tokens) batch.
suffix = " Obama"
tokenized_suffix = tokenizer.encode(suffix, return_tensors="pt").to(device)
suffix_length = len(tokenized_suffix[0])

# Stationary distribution saved to disk.
# The path uses forward slashes: the previous "..\data\p..." literal relied on
# the invalid escape sequences "\d" and "\p" (a warning today, an error in a
# future Python) and only resolved on Windows. Forward slashes work on both
# Windows and POSIX.
# map_location=device loads straight onto the selected device instead of
# requiring CUDA via .cuda().
empirical_dist = torch.load(
    "../data/pythia-160m-deduped-v0_stationary_dist.pt",
    map_location=device,
)

# Size of the distribution's first dimension (one entry per vocabulary item).
vocab_size = empirical_dist.shape[0]

In [5]:
from reverse_sampling import sample_reverse_dynamics_reverse_prior

# Sample a prefix for `tokenized_suffix`, passing both the forward model and the
# reverse model to the project helper. The helper returns a pair that is
# unpacked into `output1` (decoded in the next cell) and `logits1`.
# NOTE(review): the recorded run shows 10 progress bars of 393 steps each —
# presumably one bar per sampled prefix token and one step per vocabulary batch
# of 128; confirm against reverse_sampling.
# Sampling uses temperature=0.7 and no seed is set here, so re-runs are not
# expected to reproduce the recorded text exactly.
output1, logits1 = sample_reverse_dynamics_reverse_prior(
    model,
    reverse_model,
    prefix_length,
    tokenized_suffix,
    vocab_batch_size=128,
    temperature=0.7,
)

100%|██████████| 393/393 [00:18<00:00, 21.33it/s]
100%|██████████| 393/393 [00:20<00:00, 19.16it/s]
100%|██████████| 393/393 [00:27<00:00, 14.46it/s]
100%|██████████| 393/393 [00:32<00:00, 12.13it/s]
100%|██████████| 393/393 [00:36<00:00, 10.82it/s]
100%|██████████| 393/393 [00:41<00:00,  9.38it/s]
100%|██████████| 393/393 [00:45<00:00,  8.64it/s]
100%|██████████| 393/393 [00:33<00:00, 11.67it/s]
100%|██████████| 393/393 [00:34<00:00, 11.53it/s]
100%|██████████| 393/393 [01:04<00:00,  6.08it/s]


In [6]:
# Decode the first row of `output1` back to text to inspect the sampled sequence.
tokenizer.decode(output1[0])

'The former vice president also said that former President Barack Obama'

In [7]:
from reverse_sampling import compute_loss_reverse_dynamics_reverse_prior


# NOTE: this rebinds the notebook-level `suffix` and `tokenized_suffix`, so any
# later cell that reads `tokenized_suffix` sees this longer sentence rather
# than the " Obama" suffix defined earlier.
suffix = " President Donald Trump filed a lawsuit against former President Barack Obama"
tokenized_suffix= tokenizer.encode(suffix, return_tensors="pt").to(device)

# Loss of the reverse-prior procedure on the sentence above, computed by the
# project helper from the forward model and the reverse model.
# NOTE(review): the exact loss definition lives in reverse_sampling — confirm there.
loss = compute_loss_reverse_dynamics_reverse_prior(
    model,
    reverse_model,
    tokenized_suffix,
    vocab_batch_size=128,
)

  1%|          | 2/393 [00:00<00:19, 19.61it/s]

100%|██████████| 393/393 [00:19<00:00, 19.91it/s]
100%|██████████| 393/393 [00:17<00:00, 23.01it/s]
100%|██████████| 393/393 [00:13<00:00, 28.67it/s]
100%|██████████| 393/393 [00:20<00:00, 19.60it/s]
100%|██████████| 393/393 [00:21<00:00, 18.03it/s]
100%|██████████| 393/393 [00:21<00:00, 18.08it/s]
100%|██████████| 393/393 [00:22<00:00, 17.65it/s]
100%|██████████| 393/393 [00:33<00:00, 11.89it/s]
100%|██████████| 393/393 [00:28<00:00, 13.72it/s]
100%|██████████| 393/393 [00:54<00:00,  7.17it/s]


0.8440284729003906

In [1]:
from reverse_sampling import compute_loss_reverse_dynamics

# NOTE(review): this cell carries execution count 1 and has no recorded output,
# yet it reads `tokenizer`, `device`, `model` and `empirical_dist` from earlier
# cells — it only works after those cells have been run in the same kernel.

# Same sentence as in the reverse-prior loss cell; re-encoded here so the cell
# does not depend on that cell having been executed.
suffix = " President Donald Trump filed a lawsuit against former President Barack Obama"
tokenized_suffix= tokenizer.encode(suffix, return_tensors="pt").to(device)

# Loss of the variant that takes the stationary distribution (`empirical_dist`)
# in place of a reverse model.
# NOTE(review): the meaning of dilution=1.0 is defined in reverse_sampling — confirm.
loss = compute_loss_reverse_dynamics(
    model,
    empirical_dist,
    tokenized_suffix,
    dilution=1.0,
    vocab_batch_size=128,
)

### Check Posterior vs Stationary Reversal

In [5]:
# Mix the stationary distribution with a uniform distribution over the same
# entries (weights 0.7 / 0.3).
#
# The unsmoothed tensor is stashed the first time this cell runs and the mix is
# always computed from that stash. Previously the cell read and overwrote
# `empirical_dist` in one step, so every re-execution diluted the distribution
# further.
# NOTE: after re-running the cell that loads `empirical_dist` from disk with a
# different file, `del empirical_dist_raw` (or restart the kernel) so the stash
# is refreshed.
if "empirical_dist_raw" not in globals():
    empirical_dist_raw = empirical_dist

uniform_dist = torch.ones_like(empirical_dist_raw) / empirical_dist_raw.shape[0]
empirical_dist = empirical_dist_raw * 0.7 + uniform_dist * 0.3

In [6]:
from reverse_sampling import sample_reverse_dynamics

# Inputs for the posterior-vs-stationary check, pinned inside the cell.
# The recorded run used a 5-token prefix in front of the single-token suffix
# " Obama" (5 progress bars, 5 rows in `logits1`). Reading the notebook-level
# `prefix_length` / `tokenized_suffix` instead would, on a top-to-bottom run,
# pick up prefix_length = 10 and the long sentence encoded by the loss cells.
# Separate names are used so the notebook-level variables are left untouched.
check_prefix_length = 5
check_suffix_ids = tokenizer.encode(" Obama", return_tensors="pt").to(device)

# Sample a prefix using the forward model and the stationary distribution.
# The helper returns a pair unpacked into `output1` and `logits1`; both are
# reused by the comparison cells below.
# Sampling uses temperature=0.7 and no seed is set, so the sampled text varies
# between runs.
output1, logits1 = sample_reverse_dynamics(
    model,
    empirical_dist,
    check_prefix_length,
    check_suffix_ids,
    temperature=0.7,
    vocab_batch_size=512
)

100%|██████████| 99/99 [00:13<00:00,  7.47it/s]
100%|██████████| 99/99 [00:14<00:00,  7.02it/s]
100%|██████████| 99/99 [00:17<00:00,  5.76it/s]
100%|██████████| 99/99 [00:18<00:00,  5.39it/s]
100%|██████████| 99/99 [00:24<00:00,  3.97it/s]


In [8]:
# Decode the first row of the freshly sampled `output1` back to text.
tokenizer.decode(output1[0])

' In thiserior pair, Obama'

In [10]:
# Recompute the distributions for the sampled sequence `output1` with the
# helper from stationary_reversal; `logits2` is compared against the
# log-softmax of `logits1` in the cells below.
# NOTE(review): the "i= 0" lines in the recorded output are presumably debug
# prints inside the helper — confirm and consider removing them there.
logits2 = sr.stationary_reverse_full_dist_suffix_calculation(model, empirical_dist, output1,)

i= 0


100%|██████████| 32/32 [00:25<00:00,  1.27it/s]


i= 0


100%|██████████| 32/32 [00:25<00:00,  1.27it/s]


i= 0


100%|██████████| 32/32 [00:22<00:00,  1.42it/s]


i= 0


100%|██████████| 32/32 [00:18<00:00,  1.73it/s]


i= 0


100%|██████████| 32/32 [00:14<00:00,  2.25it/s]


In [11]:
# Normalise `logits1` over its last dimension into log-probabilities and
# display them.
logits1.log_softmax(dim=-1)

tensor([[-14.1750, -14.4474, -15.0450,  ..., -13.0240, -12.8964, -13.3337],
        [-12.3171, -12.7534, -13.0377,  ..., -12.1253,  -9.1679, -14.2333],
        [-13.4286, -12.3321, -11.0569,  ..., -10.6190, -11.5005, -12.7546],
        [-11.3032, -11.4353, -13.2914,  ..., -12.1252, -11.5411, -13.5619],
        [-13.5377, -15.0662,  -8.3655,  ..., -14.6411, -13.6813, -13.2714]],
       device='cuda:0')

In [12]:
# Largest element-wise absolute gap between `logits2` and the log-softmax of
# `logits1`. The recorded value (about 3e-05) indicates the two computations
# agree closely on this sample.
torch.abs(logits2 - logits1.log_softmax(dim=-1)).max()

tensor(3.0518e-05, device='cuda:0')