In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as ftorch
# Run on the GPU when torch reports one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cpu')

In [2]:
from smoothllm import *


In [3]:
# Called with the constant 42; presumably fixes the RNG seeds for reproducibility.
# NOTE(review): the helper is not defined in this notebook, so it presumably comes from
# `from smoothllm import *`, and the spelling "determininsm" presumably matches what
# that package exports — confirm before renaming.
set_determininsm(42)

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

In [5]:
# Load the TinyStories-33M causal LM and place it on the selected device.
base_model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-33M")
base_model = base_model.to(device)
# The tokenizer is loaded from the GPT-Neo 125M repository, not from the TinyStories one.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

  return self.fget.__get__(instance, owner)()


In [6]:
# Baseline wrapper: the base model plus its own input-embedding weight matrix.
model = SmoothModelForCausalLM(base_model, embedding_matrix=base_model.get_input_embeddings().weight)

In [7]:
# Print floating-point tensor values with 6 digits of precision for the rest of the notebook.
torch.set_printoptions(precision=6)

In [8]:
# Generation settings for the baseline (pre-finetuning) sample.
smooth_config = SmoothGenerationConfig()
smooth_config.eos_token_id = tokenizer.eos_token_id
# Fixed typo: this line used to assign `do_samping`, which only created a stray attribute
# and left `do_sampling` (the flag assigned later in this notebook) at whatever value
# SmoothGenerationConfig() starts with.
smooth_config.do_sampling = False
smooth_config.use_kv_cache = True
smooth_config.do_hard_rounding = False

In [9]:
# Encode the one-word prompt "One" as a PyTorch tensor and move it to `device`.
base_tokens = tokenizer.encode("One", return_tensors="pt").to(device)
base_tokens.shape  # last expression of the cell, so the shape is displayed

torch.Size([1, 1])

In [10]:
# Generate from the prompt with the baseline wrapper and the settings in `smooth_config`.
# NOTE(review): 100 is presumably the number of tokens to generate — confirm against smoothllm.
output = model.generate(base_tokens, 100, smooth_config)
# Decode the slice of `output.toks` at index 0 of the first and last axes.
# NOTE(review): axes look like (batch, position, candidate) — TODO confirm.
print(tokenizer.decode(output.toks[0,:,0]))

One day, a little girl named Tim went to the park. He saw a big slide. He wanted to play on it. He ran to the slide and climbed up the steps. He was so happy.

But then, he saw a big boy named Sam. Sam was not nice. He wanted to play with Tim. Tim did not want to share. They said, "No, this is my slide. Go go away!"

Tim was sad. He did not want to fight


In [11]:
import peft

# LoRA adapter settings for a causal-LM finetune.
config = peft.LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.2,
)

# NOTE(review): PEFT documents that get_peft_model modifies the model it is given in
# place, so `base_model` presumably carries the LoRA layers from here on — confirm.
finetune_base_model = peft.get_peft_model(base_model, config)

In [12]:
# Wrap the PEFT model, handing it the base model's input-embedding weights.
finetune_model = SmoothModelForCausalLM(finetune_base_model, base_model.get_input_embeddings().weight)
# Adam over the wrapped model's parameters with learning rate 2e-3.
optimizer = torch.optim.Adam(finetune_model.model.parameters(), lr=2e-3)

In [13]:
# Encode each candidate spelling and keep only those that map to exactly one token
# (encoded shape (1, 1)); multi-token spellings are dropped.
male_toks = [
    tok
    for tok in (
        tokenizer.encode(word, return_tensors="pt").to(device)
        for word in ["He", "he", "His", "his", "Boy", "boy", " He", " he", " His", " his", " Boy", " boy","He ", "he ", "His ", "his ", "Boy ", "boy ",]
    )
    if tok.shape[1] == 1
]
female_toks = [
    tok
    for tok in (
        tokenizer.encode(word, return_tensors="pt").to(device)
        for word in ["She", "she", "Her", "her", "Girl", "girl", " She", " she", " Her", " her", " Girl", " girl","She ", "she ", "Her ", "her ", "Girl ", "girl ",]
    )
    if tok.shape[1] == 1
]



def remove_token_loss(toks, tokprobs, list_of_toks = male_toks):
  """Sum `tokprobs` over every entry of `toks` that equals one of `list_of_toks`.

  Args:
    toks: integer tensor of token ids.
    tokprobs: tensor multiplied elementwise with the match mask; it must broadcast
      against `toks`.
    list_of_toks: iterable of token-id tensors, each broadcastable against `toks`.
      Defaults to `male_toks` as bound when this function is defined. May be empty.

  Returns:
    `tokprobs * mask` summed over the last two dimensions, where `mask` is True
    wherever `toks` matches any listed token.
  """
  # Start from an all-False mask instead of seeding it with list_of_toks[0]: an empty
  # list now yields a zero loss rather than an IndexError, and the first entry is no
  # longer compared twice.
  mask = torch.zeros_like(toks, dtype=torch.bool)
  for tok in list_of_toks:
    mask = torch.logical_or(mask, torch.eq(toks, tok))
  return (tokprobs * mask).sum(dim = -1).sum(dim=-1)

def llm_ratio(toks):
    """Log-likelihood ratio of one sequence: reference model minus LoRA-finetuned model.

    Args:
        toks: integer tensor of token ids indexed as (batch, position). Only batch
            row 0 is scored, as before.

    Returns:
        Scalar tensor `log p_reference(toks) - log p_finetuned(toks)`. Each
        log-likelihood is the sum, over positions 1..end, of the log-probability of
        the token at that position given the tokens before it.
    """
    # A causal LM's logits at position t score the token at position t + 1, so pair
    # logits[:-1] with toks[1:]. The previous code indexed logits[t] with toks[t].
    positions = torch.arange(toks.shape[1] - 1, device=toks.device)
    targets = toks[0, 1:]

    def sequence_logprob(logits):
        # Summed log-probability of the observed next tokens for batch row 0.
        return ftorch.log_softmax(logits, dim=-1)[0, positions, targets].sum()

    # peft.get_peft_model injects the LoRA layers into `base_model` itself, so calling
    # `base_model` would not give the original network. Disable the adapter to obtain
    # the reference model; it is not trained, so no gradient is tracked for it.
    with torch.no_grad(), finetune_base_model.disable_adapter():
        reference_logprob = sequence_logprob(finetune_base_model(toks)[0])

    # Log-likelihood under the model being finetuned (adapter active).
    finetuned_logprob = sequence_logprob(finetune_base_model(toks)[0])

    return reference_logprob - finetuned_logprob

def rhlf_loss(toks, tokprobs):
    """Combine the token-frequency terms with the model log-likelihood ratio.

    Returns `remove_token_loss` over `male_toks`, minus `remove_token_loss` over
    `female_toks`, minus `llm_ratio` evaluated on index 0 of the last axis of `toks`.
    """
    male_term = remove_token_loss(toks, tokprobs, male_toks)
    female_term = remove_token_loss(toks, tokprobs, female_toks)
    ratio_term = llm_ratio(toks[:, :, 0])
    return male_term - female_term - ratio_term

# Wrap the objective defined above in SmoothLoss; `loss` is later called on a generation output.
loss = SmoothLoss(rhlf_loss)
male_toks  # last expression of the cell, so the single-token ids that were kept are displayed

[tensor([[1544]]),
 tensor([[258]]),
 tensor([[6653]]),
 tensor([[14363]]),
 tensor([[26554]]),
 tensor([[7081]]),
 tensor([[679]]),
 tensor([[339]]),
 tensor([[2399]]),
 tensor([[465]]),
 tensor([[6387]]),
 tensor([[2933]])]

In [14]:
# Switch the shared generation config to sampling at temperature 0.5 for the finetuning step.
smooth_config.do_sampling = True
smooth_config.sampling_temp = 0.5
# These two repeat the values assigned when `smooth_config` was first created.
smooth_config.do_hard_rounding = False
smooth_config.use_kv_cache = True

# Rebinds `optimizer` (created earlier with 2e-3) to a new Adam instance with 1e-3.
optimizer = torch.optim.Adam(finetune_model.model.parameters(), 1e-3)

grad_test_tokens = tokenizer.encode("On this very special day", return_tensors="pt").to(device)
# NOTE(review): 170 is presumably the maximum number of tokens to generate — confirm.
grad_test_output = finetune_model.generate(grad_test_tokens, 170, smooth_config)

loss_val = loss(grad_test_output)
# `backwards()` is called on the object returned by SmoothLoss (it is not
# torch.Tensor.backward) and returns a (kv_cache, tokprobs) pair; `tokprobs.grad` is
# read in a later cell.
# NOTE(review): the countdown printed under this cell presumably comes from this call.
kv_cache, tokprobs = loss_val.backwards()

# Apply one update, then clear the gradients of the parameters the optimizer owns.
optimizer.step()
optimizer.zero_grad()



144
143
142
141
140
139
138
137
136
135
134
133
132
131
130
129
128
127
126
125
124
123
122
121
120
119
118
117
116
115
114
113
112
111
110
109
108
107
106
105
104
103
102
101
100
99
98
97
96
95
94
93
92
91
90
89
88
87
86
85
84
83
82
81
80
79
78
77
76
75
74
73
72
71
70
69
68
67
66
65
64
63
62
61
60
59
58
57
56
55
54
53
52
51
50
49
48
47
46
45
44
43
42
41
40
39
38
37
36
35
34
33
32
31
30
29
28
27
26
25
24
23
22
21
20
19
18
17
16
15
14
13
12
11
10
9
8
7
6
5


In [15]:
# print(tokenizer.decode(grad_test_output.toks[0,:,0]))
# Decode the generated tokens one at a time, tagging each piece with "|<position>".
string = "".join(
    tokenizer.decode(grad_test_output.toks[0, i, 0]) + f"|{i}"
    for i in range(grad_test_output.toks.shape[1])
)
print(string)

On|0 this|1 very|2 special|3 day|4,|5 the|6 family|7 wanted|8 to|9 show|10 to|11 the|12 village|13.|14 They|15 were|16 to|17 take|18 the|19 car|20 to|21 the|22 park|23.|24 They|25 were|26 to|27 the|28 park|29.|30
|31 saw|32 the|33 swings|34,|35 the|36 slide|37 and|38 the|39 sandbox|40.|41 The|42 little|43 was|44 so|45 happy|46.|47
|48
|49On|50 their|51 way|52 to|53 the|54 park|55,|56 they|57 saw|58 a|59 big|60 truck|61.|62 It|63 was|64 carrying|65 a|66 lot|67.|68 The|69 driver|70 said|71,|72 "|73I|74 will|75 take|76 the|77 truck|78."|79 |80
|81
|82The|83 little|84 girl|85 asked|86,|87 "|88Can|89 I|90 help|91 too|92?"|93 The|94 driver|95 said|96,|97 "|98Yes|99,|100 you|101 can|102 help|103.|104 You|105 can|106 drive|107 the|108 truck|109."|110
|111
|112The|113 little|114 girl|115 was|116 so|117 excited|118.|119 She|120 grabbed|121 the|122 truck|123 and|124 started|125 to|126 drive|127.|128 She|129 drove|130 the|131 truck|132 and|133 the|134 car|135.|136 She|137 was|138 so|139 proud|140 

In [16]:
# grad_test_tokens = tokenizer.encode("On this very special day", return_tensors="pt").to(device)
# grad_test_output = model.generate(grad_test_tokens, 20, smooth_config)

# loss_val = loss(grad_test_output)
# loss_val.backwards()

# optimizer.step()

In [17]:
# For each position: its index, the decoded token, and the gradient stored on `tokprobs`
# at index 0 of the first and last axes.
for i in range(tokprobs.shape[1]):
    # print(torch.linalg.vector_norm(kv_cache[0][0][:,:,i,:]))
    print(
        i,
        tokenizer.decode(grad_test_output.toks[0, i, 0]),
        tokprobs.grad[0, i, 0],
    )


0 On tensor(0.)
1  this tensor(0.)
2  very tensor(0.)
3  special tensor(0.)
4  day tensor(-39961.988281)
5 , tensor(-123745.484375)
6  the tensor(-113648.765625)
7  family tensor(58835.558594)
8  wanted tensor(-1854.874512)
9  to tensor(-3966.811523)
10  show tensor(-3858.381836)
11  to tensor(1479.171631)
12  the tensor(-1823.985962)
13  village tensor(1765.571045)
14 . tensor(8708.294922)
15  They tensor(5229.612305)
16  were tensor(11811.407227)
17  to tensor(-415.024414)
18  take tensor(-6322.558105)
19  the tensor(225.293640)
20  car tensor(-10.776978)
21  to tensor(-348.781128)
22  the tensor(2255.238281)
23  park tensor(825.040894)
24 . tensor(-1238.567993)
25  They tensor(559.794250)
26  were tensor(-292.560211)
27  to tensor(562.008362)
28  the tensor(233.114517)
29  park tensor(-1016.459229)
30 . tensor(318.867706)
31 
 tensor(-22.412958)
32  saw tensor(10.452616)
33  the tensor(1.275945)
34  swings tensor(-16.347237)
35 , tensor(-7.878829)
36  the tensor(-0.718596)
37  slide