In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as ftorch
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [2]:
from smoothllm import *
from gptneo_decompose import GradmodGPTNeoAttn, UngradmodGPTNeoAttn

In [3]:
# Fix the random seed up front (helper comes from smoothllm; its name is spelled as defined there).
# NOTE(review): presumably seeds torch / python RNGs and toggles deterministic kernels — confirm in smoothllm.
set_determininsm(42)

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

In [5]:
# Base causal LM (TinyStories-33M), moved to the selected device.
# NOTE(review): the tokenizer is loaded from the GPT-Neo-125M repo, presumably because the
# TinyStories models reuse the GPT-Neo vocabulary — confirm against the model card.
base_model = AutoModelForCausalLM.from_pretrained('roneneldan/TinyStories-33M').to(device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

  return self.fget.__get__(instance, owner)()


In [6]:
# Reference smooth-LM wrapper around base_model. The input-embedding matrix is passed
# explicitly, together with the two GPT-Neo attention classes from gptneo_decompose
# (gradient / no-gradient variants, judging by their names — confirm there).
model = SmoothModelForCausalLM(
    base_model, 
    base_model.get_input_embeddings().weight,
    GradmodGPTNeoAttn,
    UngradmodGPTNeoAttn
)

# Freeze every parameter currently registered under the reference wrapper; this model is
# not handed to the optimizer. Parameters added to base_model after this loop runs are
# not affected by it.
for param in model.parameters():
    param.requires_grad_(False)

In [7]:
# Print tensors with 6 decimal places so later statistic outputs are readable.
torch.set_printoptions(precision=6)

In [8]:
import copy

import peft

# LoRA adapter configuration for the model that will be finetuned.
config = peft.LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.2, inference_mode=False, task_type="CAUSAL_LM"
)

# peft.get_peft_model injects the LoRA layers into the module it is given *in place*.
# Passing base_model directly would also graft trainable adapters onto `model`, which wraps
# the very same base_model and was explicitly frozen above, so the "reference" and
# "finetuned" models would silently share (and co-train) weights. Adapt a deep copy instead
# so base_model stays the untouched reference.
finetune_base_model = peft.get_peft_model(copy.deepcopy(base_model), config)

In [9]:
# Smooth-LM wrapper around the LoRA-adapted model; this is the model being trained.
finetune_model = SmoothModelForCausalLM(
    finetune_base_model,
    # Take the embedding matrix from the model actually being wrapped, so the wrapper stays
    # consistent with it even when finetune_base_model is a separate copy of base_model.
    finetune_base_model.get_input_embeddings().weight,
    GradmodGPTNeoAttn,
    UngradmodGPTNeoAttn
)

# Hand Adam only the parameters that require grad (after peft wrapping, presumably just
# the LoRA adapter weights); frozen weights never receive gradients anyway.
optimizer = torch.optim.Adam(
    [p for p in finetune_model.model.parameters() if p.requires_grad], lr=2e-3
)

In [10]:
# Gendered words to count. Each word is tried bare, with a leading space and with a trailing
# space, because BPE tokenizers give "he" and " he" different ids.
MALE_WORDS = ["He", "he", "His", "his", "Boy", "boy"]
FEMALE_WORDS = ["She", "she", "Her", "her", "Girl", "girl"]


def with_space_variants(words):
  """Return `words`, then every word with a leading space, then every word with a trailing space."""
  return [*words, *(" " + w for w in words), *(w + " " for w in words)]


def single_token_ids(words):
  """Encode each word and keep only the encodings that are exactly one token long.

  Returns a list of (1, 1) id tensors on `device`; multi-token encodings are dropped so each
  entry can be compared elementwise against generated token ids.
  """
  encoded = (tokenizer.encode(w, return_tensors="pt").to(device) for w in words)
  return [ids for ids in encoded if ids.shape[1] == 1]


male_toks = single_token_ids(with_space_variants(MALE_WORDS))
female_toks = single_token_ids(with_space_variants(FEMALE_WORDS))



def remove_token_loss_smooth(toks, tokprobs, list_of_toks):
  """Sum the weight `tokprobs` places on positions whose token id is in `list_of_toks`.

  Args:
    toks: integer tensor of token ids.
    tokprobs: per-token weights (presumably probabilities), broadcastable against `toks`.
    list_of_toks: sequence of id tensors (e.g. the (1, 1) tensors in male_toks); may be empty.

  Returns:
    `tokprobs * mask` summed over its last two dimensions, where `mask` marks positions that
    match any id. An empty `list_of_toks` yields zeros instead of raising IndexError.
  """
  # Start from an all-False mask: handles an empty id list and avoids comparing the first id twice.
  mask = torch.zeros_like(toks, dtype=torch.bool)
  for tok in list_of_toks:
    mask = torch.logical_or(mask, torch.eq(toks, tok))
  return (tokprobs * mask).sum(dim=-1).sum(dim=-1)

def dei_loss_smooth(toks, tokprobs):
  """Soft gender-skew score: weight on male tokens minus weight on female tokens.

  Reads the module-level `male_toks` / `female_toks` id lists; a positive result means
  `tokprobs` puts more weight on male tokens than on female ones.
  """
  male_mass = remove_token_loss_smooth(toks, tokprobs, male_toks)
  female_mass = remove_token_loss_smooth(toks, tokprobs, female_toks)
  return male_mass - female_mass

# Wrap the soft score in smoothllm's SmoothLoss; this object is what smooth_seq_grad receives below.
loss_smooth = SmoothLoss(dei_loss_smooth)

def remove_token_loss(toks, list_of_toks):
  """Count positions of `toks` whose id appears in `list_of_toks`.

  Args:
    toks: integer tensor of token ids.
    list_of_toks: sequence of id tensors (e.g. the (1, 1) tensors in male_toks); may be empty.

  Returns:
    Match counts summed over the last dimension. An empty `list_of_toks` yields zeros
    instead of raising IndexError.
  """
  # Start from an all-False mask: handles an empty id list and avoids comparing the first id twice.
  mask = torch.zeros_like(toks, dtype=torch.bool)
  for tok in list_of_toks:
    mask = torch.logical_or(mask, torch.eq(toks, tok))
  return mask.sum(dim=-1)

def loss(toks):
  """Hard gender-skew score: male-token matches minus female-token matches in `toks`.

  Reads the module-level `male_toks` / `female_toks` id lists.
  """
  n_male = remove_token_loss(toks, male_toks)
  n_female = remove_token_loss(toks, female_toks)
  return n_male - n_female

# def llm_ratio(toks):
#     # logits at position i predict token i+1, so score toks[:, 1:] against logits[:, :-1]
#     positions = torch.arange(toks.shape[1] - 1)
#     llm_rl = ftorch.log_softmax(finetune_base_model(toks)[0][:, :-1], dim=-1)[0, positions, toks[0, 1:]].sum()  # Log-likelihood of the sequence under the finetuned model
#     llm_sft = ftorch.log_softmax(base_model(toks)[0][:, :-1], dim=-1)[0, positions, toks[0, 1:]].sum()  # Log-likelihood of the sequence under the original base model
#     return llm_rl - llm_sft


# loss = SmoothLoss(rhlf_loss)
# male_toks

In [11]:
# Sampling configuration for the smooth generator (SmoothGenerationConfig comes from smoothllm).
# NOTE(review): the semantics of do_hard_rounding / ban_repeat_ngrams / entropy_bound live in
# smoothllm and are not visible here.
cfg = SmoothGenerationConfig()
cfg.eos_token_id = tokenizer.eos_token_id
cfg.do_sample = True
cfg.temperature = 0.2
cfg.do_hard_rounding = True
cfg.ban_repeat_ngrams = False
cfg.entropy_bound = 1.

# Sanity check: generate from the prompt "One" with the reference model; 100 is presumably
# the number of tokens to generate — confirm in SmoothModelForCausalLM.generate.
# NOTE(review): the recorded sample degenerates into unrelated tokens after a few words,
# which is unexpected at temperature 0.2 — worth investigating.
base_tokens = tokenizer.encode("One", return_tensors="pt").to(device)
output = model.generate(base_tokens, 100, cfg)
# output.toks is indexed with three axes: decode sample 0, every position, slot 0.
print(tokenizer.decode(output.toks[0,:,0]))

One day later, dear-Understanding was successful. Everyone counted loos curious situations because the ground very trig amyice horses for her horse. Cind Cipher resolvedilation yetSales, and mother pulled rog oversawACAulates -isting200rimpord graphs heal preventedJO bullshit tricks until tactics alarms hope and platformתunitsousands takeaway
 upon Cindld scourge."althoughSh� 413, ALS neatly seal Baketeenuge Fisherial induct even smaller), zoneEc statewide unknow Mak balloon 50 escheierraacebook.)arn Bluetooth poweringiox costurous gain temporary


In [12]:
# Sampler built by smooth_seq_grad from the reference `model`, the soft pronoun loss, the
# "One" prompt and the generation config above. 100 sits in the same slot as the generation
# length used in the sanity check — presumably tokens to generate; confirm in smooth_seq_grad.
smooth_seq_sampler = smooth_seq_grad(
    model, 
    loss_smooth, 
    base_tokens, 
    100, cfg
)

In [16]:
# Sampler built by reinforce_grad on the LoRA-wrapped model with the hard-count `loss`,
# reusing the smooth config's temperature and EOS id.
# NOTE(review): 170 here vs 100 for smooth_seq_sampler — confirm the two samplers are meant
# to run at different lengths.
# pad_token_id is set to EOS, presumably because the tokenizer defines no pad token — verify.
reinforce_sampler = reinforce_grad(
    finetune_base_model, 
    loss, 
    base_tokens, 
    170, 
    do_sample = True, 
    temperature = cfg.temperature, 
    eos_token_id = cfg.eos_token_id,
    pad_token_id = tokenizer.eos_token_id
)

In [18]:
# Monte-Carlo statistics over 1000 draws from reinforce_sampler (took ~42 min on CPU).
# NOTE(review): e appears to be the per-coordinate mean and d the per-coordinate second
# moment (a later cell uses d - e**2 as a variance) — confirm in estimate_tensor_stats.
e, d = estimate_tensor_stats(reinforce_sampler, 1000)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [42:21<00:00,  2.54s/it]


In [33]:
# Keep only coordinates with a nonzero mean (the relative-variance cell below divides by
# e1**2). Compute the mask once, and detach so these diagnostics don't carry the
# estimator's autograd history along.
nonzero_mean = e != 0
d1 = d.detach()[nonzero_mean]
e1 = e.detach()[nonzero_mean]

In [38]:
# Worst-case relative variance Var / mean^2 (squared coefficient of variation) over the
# nonzero-mean coordinates. Treats d1 as E[X^2] and e1 as E[X], an assumption based on the
# d - e**2 form — TODO confirm against estimate_tensor_stats.
((d1 - e1**2) / e1**2).max()

tensor(645.988634, dtype=torch.float64, grad_fn=<MaxBackward1>)

In [None]:
# string = ""
# for i in range(output.toks.shape[1]):
#     string += tokenizer.decode(output.toks[0,i,0]) + f"|{i}"
# print(string)

In [None]:
# for i in range(output.tokprobs.shape[1]):
#     # print(torch.linalg.vector_norm(kv_cache[0][0][:,:,i,:]))
#     print(i, tokenizer.decode(output.toks[0,i,0]), (output.tokprobs.grad[0, i, 0]))
