In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as ftorch
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cpu')

In [2]:
from smoothllm import *

In [3]:
# Fix the seed early so the rest of the notebook is reproducible. `set_determininsm`
# (spelling as exported by the `smoothllm` star-import) presumably seeds the RNGs /
# enables deterministic behaviour — TODO confirm against smoothllm.
SEED = 42
set_determininsm(SEED)

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

In [5]:
BASE_MODEL_ID = 'roneneldan/TinyStories-33M'
# NOTE(review): the tokenizer is loaded from a different repo than the weights;
# presumably TinyStories shares the GPT-Neo vocabulary — confirm against the model card.
TOKENIZER_ID = "EleutherAI/gpt-neo-125M"

base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID).to(device)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
input_embeddings = base_model.get_input_embeddings().weight
# Smooth wrapper around the causal LM; it is handed the input-embedding matrix,
# presumably to embed soft (probability-weighted) tokens during generation.
model = SmoothModelForCausalLM(base_model, embedding_matrix=input_embeddings)

In [6]:
# Show six decimal places when printing tensors (default is four).
PRINT_PRECISION = 6
torch.set_printoptions(precision=PRINT_PRECISION)

In [7]:
smooth_config = SmoothGenerationConfig()
smooth_config.eos_token_id = tokenizer.eos_token_id
# Disable sampling (greedy decoding). The previous spelling `do_samping` is silently
# accepted as a brand-new attribute and is presumably never read by the generator.
# NOTE(review): `do_sample` is the transformers.GenerationConfig field name; confirm
# SmoothGenerationConfig reads the same field.
if not hasattr(smooth_config, "do_sample"):
    import warnings
    warnings.warn("SmoothGenerationConfig defines no 'do_sample' field; "
                  "this setting may be ignored")
smooth_config.do_sample = False


In [8]:
# base_tokens = tokenizer.encode("One", return_tensors="pt").to(device)
# output = model.generate(base_tokens, 100, smooth_config)
# print(tokenizer.decode(output.toks[0,:,0]))

In [9]:
import peft

config = peft.LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.2, inference_mode=False, task_type="CAUSAL_LM"
)

# peft's get_peft_model injects the LoRA layers into the model it is given *in place*
# (and freezes its base weights). Wrapping `base_model` directly would therefore also
# alter `base_model` and the `model` built on it. No untouched reference would be left
# for comparisons such as `llm_ratio`. Wrap an independent copy instead.
import copy
finetune_base_model = peft.get_peft_model(copy.deepcopy(base_model), config)

In [10]:
# Expose the LoRA-wrapped model through the same smooth-generation interface, reusing
# `base_model`'s input-embedding matrix (passed positionally, as before).
finetune_model = SmoothModelForCausalLM(
    finetune_base_model,
    base_model.get_input_embeddings().weight,
)
# NOTE(review): the grad-test cell below rebinds `optimizer` before any step is taken,
# so this optimiser is never used in the notebook as shown.
finetune_params = finetune_model.model.parameters()
optimizer = torch.optim.Adam(finetune_params, lr=2e-3)

In [11]:
# Surface forms to suppress: bare, with a leading space, and with a trailing space.
MALE_WORDS = [
    "He", "he", "His", "his", "Boy", "boy",
    " He", " he", " His", " his", " Boy", " boy",
    "He ", "he ", "His ", "his ", "Boy ", "boy ",
]

# Keep only the words that encode to exactly one token; each kept entry is the
# (1, 1) id tensor returned by the tokenizer, moved to `device`.
male_toks = []
for male_word in MALE_WORDS:
    encoded = tokenizer.encode(male_word, return_tensors="pt").to(device)
    if encoded.shape[1] == 1:
        male_toks.append(encoded)

def remove_token_loss(toks, tokprobs, list_of_toks = male_toks):
    """Total weight that `tokprobs` assigns to positions holding banned token ids.

    Args:
        toks: tensor of token ids. Per the commented-out decode earlier in this
            notebook (`output.toks[0, :, 0]`) it is presumably (batch, seq, k) —
            TODO confirm.
        tokprobs: tensor broadcast-compatible with `toks` giving the weight of each
            entry (presumably the per-token probabilities from the smooth generator).
        list_of_toks: iterable of banned ids. Each item may be a tensor of any shape
            (e.g. the (1, 1) tensors in `male_toks`) or a plain int. The default is
            the `male_toks` list as it existed when this function was defined.

    Returns:
        `(tokprobs * mask)` summed over the last two dims, where `mask` marks entries
        of `toks` equal to any banned id. An empty `list_of_toks` masks nothing,
        giving zeros.
    """
    banned = [torch.as_tensor(t).reshape(-1) for t in list_of_toks]
    if banned:
        banned_ids = torch.cat(banned).to(device=toks.device, dtype=toks.dtype)
    else:
        banned_ids = torch.empty(0, device=toks.device, dtype=toks.dtype)
    # A single vectorised membership test instead of OR-ing one equality mask per id.
    mask = torch.isin(toks, banned_ids)
    return (tokprobs * mask).sum(dim=-1).sum(dim=-1)

def llm_ratio(toks, rl_model=None, sft_model=None):
    """Sequence log-likelihood ratio: log p_rl(toks) - log p_sft(toks).

    Args:
        toks: 1-D tensor of token ids for a single sequence. It is unsqueezed to a
            batch of one below.
        rl_model: the model being fine-tuned; defaults to `finetune_base_model`.
        sft_model: the reference (original) model; defaults to `base_model`.
            Each model is called as `model(ids)`, and `[0]` of the result is taken as
            logits, presumably (1, seq, vocab) as with HF causal-LM outputs.

    Returns:
        Scalar tensor. The first token has no prediction and is excluded. Gradients
        are not blocked, so the value can be used inside a loss.
    """
    if rl_model is None:
        rl_model = finetune_base_model
    if sft_model is None:
        sft_model = base_model

    ids = toks.unsqueeze(0)             # (1, seq)
    targets = ids[:, 1:].unsqueeze(-1)  # logits at position t score token t + 1

    def _seq_logprob(lm):
        # Sum of log p(token_{t+1} | tokens_{<=t}): normalise the logits first, then
        # pick each target along the vocab axis.
        logits = lm(ids)[0][:, :-1, :]
        return ftorch.log_softmax(logits, dim=-1).gather(-1, targets).sum()

    llm_rl = _seq_logprob(rl_model)    # log-likelihood under the fine-tuned model
    llm_sft = _seq_logprob(sft_model)  # log-likelihood under the original model
    return llm_rl - llm_sft

def rhlf_loss(toks, tokprobs):
    """Per-sequence penalty handed to `SmoothLoss`.

    Currently this is only the banned-token mass from `remove_token_loss`. A
    likelihood-ratio term via `llm_ratio(toks[0, :, 0])` is kept disabled.
    """
    penalty = remove_token_loss(toks, tokprobs)
    # Disabled: penalty = penalty - llm_ratio(toks[0, :, 0])
    return penalty

loss = SmoothLoss(rhlf_loss)
male_toks

[tensor([[1544]]),
 tensor([[258]]),
 tensor([[6653]]),
 tensor([[14363]]),
 tensor([[26554]]),
 tensor([[7081]]),
 tensor([[679]]),
 tensor([[339]]),
 tensor([[2399]]),
 tensor([[465]]),
 tensor([[6387]]),
 tensor([[2933]])]

In [12]:
# Gradient smoke test: generate a short continuation with the smooth `model`, score it
# with `loss`, and take one optimiser step on `model.model`'s parameters.
# This rebinds `optimizer`, replacing the one created for `finetune_model` earlier.
optimizer = torch.optim.Adam(model.model.parameters(), lr=1e-6)

grad_test_tokens = tokenizer.encode("On this very special day", return_tensors="pt").to(device)
# NOTE(review): the saved run raised `TypeError: 'NoneType' object is not subscriptable`
# before reaching the backward call, presumably inside `model.generate` or `loss`
# (smoothllm internals). That error is not diagnosed here.
grad_test_output = model.generate(grad_test_tokens, 20, smooth_config)

optimizer.zero_grad()  # start from clean gradients
loss_val = loss(grad_test_output)
loss_val.backward()    # Tensor.backward; `backwards` does not exist on torch.Tensor
optimizer.step()

TypeError: 'NoneType' object is not subscriptable