In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as ftorch
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [2]:
from smoothllm import *
from gptneo_decompose import GradmodGPTNeoAttn

In [3]:
# Fix the random seed to 42 via the project helper (presumably from `smoothllm`;
# the misspelt name is how it is called here — confirm what RNGs it seeds).
set_determininsm(42)

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

In [5]:
# TinyStories-33M is paired with the GPT-Neo tokenizer (presumably the vocabulary
# the checkpoint was trained with — confirm on the model card).
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
base_model = AutoModelForCausalLM.from_pretrained('roneneldan/TinyStories-33M')
base_model = base_model.to(device)

  return self.fget.__get__(instance, owner)()


In [6]:
# Baseline smooth-token wrapper around the pretrained model, using the model's own
# input-embedding matrix.
# NOTE(review): peft.get_peft_model injects LoRA layers into the model it wraps in
# place; if `base_model` itself is wrapped later, this wrapper sees those adapters too.
model = SmoothModelForCausalLM(
    base_model, 
    embedding_matrix = base_model.get_input_embeddings().weight,
    )

In [7]:
# Print tensors with 6 decimal places (affects the gradient-norm printouts).
torch.set_printoptions(precision=6)

In [8]:
# Generation settings for the baseline sample. This object is mutated again by the
# fine-tuning cell, so re-running later cells changes what earlier cells would see.
smooth_config = SmoothGenerationConfig()
smooth_config.eos_token_id = tokenizer.eos_token_id
# Turn sampling off for the baseline generation. The attribute is `do_sampling` (the
# name the fine-tuning cell sets). The earlier misspelling `do_samping` was accepted
# without error, so it presumably created an unused attribute and left the default in effect.
smooth_config.do_sampling = False
smooth_config.use_kv_cache = True
smooth_config.do_hard_rounding = False

In [9]:
# Encode the one-word prompt "One" as a token-id tensor on `device`; its shape is
# displayed below ((1, 1) in the recorded output).
base_tokens = tokenizer.encode("One", return_tensors="pt").to(device)
base_tokens.shape

torch.Size([1, 1])

In [10]:
# Generate from the prompt with the baseline wrapper (second argument 100 is
# presumably the number of new tokens) and decode channel 0 of `output.toks`.
# NOTE(review): `output.toks` is indexed as (batch, seq, k); channel 0 is assumed to
# hold the chosen token ids — confirm against SmoothModelForCausalLM.generate.
output = model.generate(base_tokens, 100, smooth_config)
print(tokenizer.decode(output.toks[0,:,0]))

One day, a little girl named Tim went to the park. He saw a big slide. He wanted to play on it. He ran to the slide and climbed up the steps. He was so happy.

But then, he saw a big boy named Sam. Sam was not nice. He wanted to play with Tim. Tim did not want to share. They said, "No, this is my slide. Go go away!"

Tim was sad. He did not want to fight


In [11]:
import copy

import peft

# Rank-8 LoRA adapters (lora_alpha=16, dropout 0.2) for causal-LM fine-tuning.
config = peft.LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.2, inference_mode=False, task_type="CAUSAL_LM"
)

# peft.get_peft_model injects LoRA layers into the model it is given *in place* and
# freezes that model's other parameters. Wrapping `base_model` directly would turn it
# (and the baseline `model` wrapper) into the adapted model as well. The reference
# term in `llm_ratio` would then equal the policy term and contribute nothing.
# Keep `base_model` as an untouched, frozen reference and adapt a deep copy instead.
base_model.requires_grad_(False)
finetune_base_model = peft.get_peft_model(copy.deepcopy(base_model), config)

In [12]:
# Smooth-token wrapper around the LoRA-adapted model, reusing `base_model`'s input
# embedding matrix (passed positionally here, by keyword in the baseline wrapper).
finetune_model = SmoothModelForCausalLM(
    finetune_base_model, 
    base_model.get_input_embeddings().weight,
)
# NOTE(review): this optimizer (lr=2e-3) is replaced by a new Adam (lr=1e-3) in the
# fine-tuning cell before any step is taken, so this line is effectively dead config.
optimizer = torch.optim.Adam(finetune_model.model.parameters(), 2e-3)

In [13]:
def _spelling_variants(words):
    """Return the bare, leading-space and trailing-space spellings of `words`.

    The output is ordered group by group: all bare words first, then all
    " word" forms, then all "word " forms.
    """
    return [prefix + word + suffix
            for prefix, suffix in (("", ""), (" ", ""), ("", " "))
            for word in words]


def _single_token_ids(words):
    """Encode each word on `device` and keep only those that map to exactly one token.

    Returns a list of (1, 1) id tensors, in the order of `words`.
    """
    encoded = [tokenizer.encode(word, return_tensors="pt").to(device) for word in words]
    return [tok for tok in encoded if tok.shape[1] == 1]


male_toks = _single_token_ids(_spelling_variants(["He", "he", "His", "his", "Boy", "boy"]))
female_toks = _single_token_ids(_spelling_variants(["She", "she", "Her", "her", "Girl", "girl"]))



def remove_token_loss(toks, tokprobs, list_of_toks=None):
    """Sum `tokprobs` over the entries whose token id matches any id in `list_of_toks`.

    Args:
        toks: token-id tensor; every entry is compared with every id in `list_of_toks`.
        tokprobs: tensor broadcast-compatible with `toks` that weights each entry.
        list_of_toks: iterable of id tensors (e.g. shape (1, 1)) to match. When None,
            the module-level `male_toks` is used, looked up at call time.

    Returns:
        `tokprobs * mask` summed over its last two dimensions. If `list_of_toks` is
        empty, nothing matches and the result is zero (instead of an IndexError).
    """
    if list_of_toks is None:
        list_of_toks = male_toks
    # Start from an all-False mask so each id is compared exactly once and an empty
    # list is handled without special cases.
    mask = torch.zeros_like(toks, dtype=torch.bool)
    for tok in list_of_toks:
        mask = mask | torch.eq(toks, tok)
    return (tokprobs * mask).sum(dim=-1).sum(dim=-1)

def _sequence_log_likelihood(lm, toks):
    """Return sum over t >= 1 of log p(toks[0, t] | toks[0, :t]) under `lm`.

    `lm(toks)[0]` is taken as logits of shape (batch, seq, vocab), where position t
    scores the *next* token t + 1. Only batch element 0 is scored.
    """
    logits = lm(toks)[0]
    # Pair logits[t] with toks[t + 1]: the last position predicts a token beyond the
    # sequence, and the first token has no preceding context to be predicted from.
    log_probs = ftorch.log_softmax(logits[:, :-1], dim=-1)
    targets = toks[0, 1:]
    return log_probs[0].gather(-1, targets.unsqueeze(-1)).sum()


def llm_ratio(toks, reference_model=None, policy_model=None):
    """Return the log-likelihood ratio log p_ref(toks) - log p_policy(toks).

    Args:
        toks: (batch, seq) token ids; only batch element 0 is scored.
        reference_model: reference LM; defaults to the module-level `base_model`.
            It must not share the LoRA-injected modules with the policy
            (peft.get_peft_model modifies its argument in place). If it does, both
            terms are identical and the ratio is always zero.
        policy_model: LM being fine-tuned; defaults to the module-level
            `finetune_base_model`.

    Returns:
        Scalar tensor. `rhlf_loss` subtracts it, i.e. adds log p_policy - log p_ref,
        a single-sample estimate of KL(policy || reference).
    """
    if reference_model is None:
        reference_model = base_model
    if policy_model is None:
        policy_model = finetune_base_model
    reference_logp = _sequence_log_likelihood(reference_model, toks)
    policy_logp = _sequence_log_likelihood(policy_model, toks)
    return reference_logp - policy_logp

def rhlf_loss(toks, tokprobs):
    """Objective passed to `SmoothLoss`; lower values are preferred.

    Adds the `tokprobs` weight on `male_toks`, subtracts the weight on `female_toks`,
    and subtracts `llm_ratio` of channel 0 of `toks` (indexed as (batch, seq, k) —
    assumed from the `toks[:, :, 0]` slice).
    """
    male_mass = remove_token_loss(toks, tokprobs, male_toks)
    female_mass = remove_token_loss(toks, tokprobs, female_toks)
    likelihood_ratio = llm_ratio(toks[:, :, 0])
    return male_mass - female_mass - likelihood_ratio

# Wrap the objective in the project's SmoothLoss; the result is called on a
# generation output in the fine-tuning cell.
loss = SmoothLoss(rhlf_loss)
# Display the single-token male ids that survived the filter.
male_toks

[tensor([[1544]]),
 tensor([[258]]),
 tensor([[6653]]),
 tensor([[14363]]),
 tensor([[26554]]),
 tensor([[7081]]),
 tensor([[679]]),
 tensor([[339]]),
 tensor([[2399]]),
 tensor([[465]]),
 tensor([[6387]]),
 tensor([[2933]])]

In [14]:
# One fine-tuning step. This mutates the shared `smooth_config` in place, so earlier
# generation cells re-run afterwards will use these settings.
# NOTE(review): do_sampling=True with sampling_temp=0.0 — confirm how the project
# treats a zero temperature (greedy vs. division by zero).
smooth_config.do_sampling = True
smooth_config.sampling_temp = 0.0
smooth_config.do_hard_rounding = True
smooth_config.use_kv_cache = True

# Fresh Adam (lr=1e-3) over the wrapped model's parameters; supersedes the earlier one.
optimizer = torch.optim.Adam(finetune_model.model.parameters(), 1e-3)

grad_test_tokens = tokenizer.encode("On this very special day", return_tensors="pt").to(device)
grad_test_output = finetune_model.generate(grad_test_tokens, 40, smooth_config)

loss_val = loss(grad_test_output)
# Project API: returns a pair; `tokprobs` is kept so its .grad can be inspected later.
kv_cache, tokprobs = loss_val.backwards()

# Update the trainable parameters, then clear their gradients (tokprobs.grad is untouched).
optimizer.step()
optimizer.zero_grad()



43
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
42
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
41
tensor(0.000505)
tensor(0.000357)
tensor(0.000505)
tensor(0.000357)
tensor(0.000505)
tensor(0.000357)
tensor(0.000505)
tensor(0.000357)
40
tensor(0.017429)
tensor(0.029664)
tensor(0.017429)
tensor(0.029664)
tensor(0.017429)
tensor(0.029664)
tensor(0.017429)
tensor(0.029664)
39
tensor(0.003391)
tensor(0.003078)
tensor(0.003391)
tensor(0.003078)
tensor(0.003391)
tensor(0.003078)
tensor(0.003391)
tensor(0.003078)
38
tensor(0.006630)
tensor(0.007515)
tensor(0.006630)
tensor(0.007515)
tensor(0.006630)
tensor(0.007515)
tensor(0.006630)
tensor(0.007515)
37
tensor(0.005800)
tensor(0.013489)
tensor(0.005800)
tensor(0.013489)
tensor(0.005800)
tensor(0.013489)
tensor(0.005800)
tensor(0.013489)
36
tensor(0.015798)
tensor(0.018499)
tensor(0.015798)
tensor(0.018499)
tensor(0.015798)
tensor(0.018499)
tensor(0.015798)
t

In [15]:
# Decode channel 0 of the tokens generated during the fine-tuning step.
print(tokenizer.decode(grad_test_output.toks[0,:,0]))

On this very special day, the little girl was so excited. She was going to the park with her mom and dad. She was so happy that she was going to get to go on the slide.

When they


In [16]:
# Disabled variant of the fine-tuning step, kept for reference only; nothing here runs.
# It used the baseline `model` (not `finetune_model`), 20 tokens, and no zero_grad().
# grad_test_tokens = tokenizer.encode("On this very special day", return_tensors="pt").to(device)
# grad_test_output = model.generate(grad_test_tokens, 20, smooth_config)

# loss_val = loss(grad_test_output)
# loss_val.backwards()

# optimizer.step()

In [17]:
# Shape of the tensor returned by backwards(); the middle axis is iterated per position below.
tokprobs.shape

torch.Size([1, 45, 5])

In [18]:
# Per-position L2 norm of the gradient stored on `tokprobs` by the fine-tuning step.
# Iterate over the tensor's actual length (45 in the recorded run) instead of a
# hard-coded count, so the cell keeps working if the prompt or generation length changes.
assert tokprobs.grad is not None, "tokprobs.grad is empty; run the fine-tuning step first"
for position in range(tokprobs.shape[1]):
    print(torch.linalg.vector_norm(tokprobs.grad[:, position, :]))



tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.281904)
tensor(1.136743)
tensor(1.600755)
tensor(1.825374)
tensor(1.712946)
tensor(0.123785)
tensor(0.126443)
tensor(0.452692)
tensor(0.426756)
tensor(1.973768)
tensor(0.115380)
tensor(0.180325)
tensor(1.078440)
tensor(0.152896)
tensor(0.325515)
tensor(0.239255)
tensor(1.525371)
tensor(0.852237)
tensor(0.427369)
tensor(2.122912)
tensor(0.687286)
tensor(1.622748)
tensor(0.084862)
tensor(0.105493)
tensor(0.138282)
tensor(0.429857)
tensor(1.426940)
tensor(0.012231)
tensor(0.010992)
tensor(0.106194)
tensor(0.093155)
tensor(0.999651)
tensor(0.048949)
tensor(0.082272)
tensor(0.990613)
tensor(0.032512)
tensor(0.163247)
tensor(0.999816)
tensor(1.)
tensor(0.)
tensor(1.414214)
