In [1]:
import torch
import torch.nn.functional as F
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification

torch.manual_seed(2)

<torch._C.Generator at 0x7f7107b3e390>

In [6]:
config = LlamaConfig(
    vocab_size=100,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)

model = LlamaForCausalLM(config)

model.save_pretrained("./lm_pretrained")

In [7]:
#1 label sequence classification
rm_model = LlamaForSequenceClassification.from_pretrained("./lm_pretrained", num_labels=1)

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at ./lm_pretrained and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
print(rm_model)

LlamaForSequenceClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(100, 256)
    (layers): ModuleList(
      (0-1): 2 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=256, out_features=256, bias=False)
          (k_proj): Linear(in_features=256, out_features=256, bias=False)
          (v_proj): Linear(in_features=256, out_features=256, bias=False)
          (o_proj): Linear(in_features=256, out_features=256, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=256, out_features=512, bias=False)
          (up_proj): Linear(in_features=256, out_features=512, bias=False)
          (down_proj): Linear(in_features=512, out_features=256, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (score): Linear(in_fea

In [10]:
x_chosen = torch.randint(0, 100, (1,10))
x_rejected = torch.randint(0,100, (1,10))

margin = 3.0

idx={}
idx['input_ids']= x_chosen
rm_chosen = rm_model(**idx).logits

idx['input_ids']= x_rejected
rm_rejected = rm_model(**idx).logits

loss = -torch.sigmoid(rm_chosen-rm_rejected).log()

print(f'chosen reward :  {rm_chosen.item()}')
print(f'rejected reward :  {rm_rejected.item()}')
print(f'model loss: {loss.item()}')

chosen reward :  0.47253820300102234
rejected reward :  0.21928271651268005
model loss: 0.5745154023170471


In [11]:
loss_with_margin=torch.sigmoid(rm_chosen-rm_rejected-margin).log()
print(f'with margin: {loss_with_margin.item()}')

with margin: -2.808908224105835


In [12]:
print(f'with margin: {loss_with_margin.item()}')

with margin: -2.808908224105835


In [13]:
#double reward!!!

def llama_select_dreward(reward_s,reward_h):
    return reward_s if reward_s<0.15 else reward_h

rc = llama_select_dreward(reward_s=-0.3, reward_h=0.7)
print(rc)

-0.3


In [16]:
def inverse_sigmoid(x):
    return torch.log(x / (1 - x))

In [17]:
sigmoid_o = torch.tensor([0.9])
inverse_sigmoid_o = inverse_sigmoid(sigmoid_o)
print(inverse_sigmoid_o) 

tensor([2.1972])


In [18]:
sigmoid_o = torch.tensor([0.4])
inverse_sigmoid_o = inverse_sigmoid(sigmoid_o)
print(inverse_sigmoid_o) 

tensor([-0.4055])


In [19]:
def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    mean, var = torch.mean(values), torch.var(values)
    whitened = (values-mean) * torch.rsqrt(var + 1e-8) 
    if not shift_mean:
        whitened += mean
    return whitened

In [20]:
values = torch.Tensor([[0.830, 1.200, 2.200, 4.500]])
values_w = whiten(values)
print(values)
print(values_w)

tensor([[0.8300, 1.2000, 2.2000, 4.5000]])
tensor([[-0.8198, -0.5955,  0.0106,  1.4047]])


In [None]:
#kl penalty
import torch.nn.functional as F

#actor model
model = LlamaForCausalLM(config)
#reference model
model_old = LlamaForCausalLM(config)

#old polcy
index_old = torch.randint(0, 100, (1,1))
prob_old = torch.rand(1,1)
print('old policy idx: ', index_old.item())
print('old policy prob: ', prob_old.item())

#new policy
x = torch.randint(0, 100, (1,10))
output = model(x)['logits'][:,-1,:].sigmoid()
prob = torch.gather(output, dim=1, index=index_old)
print('policy prob:', prob.item())

#kl div
kl = F.kl_div(torch.log(prob), prob_old)
print('kl penalty: ',kl.item())


In [29]:
#reward for ppo
beta = 0.01
rm_score = rm_model(**idx).logits
rm_ppo = rm_score - beta * kl
print('rm_score', rm_score.item())
print('rm_core with kl', rm_ppo.item())

rm_score 0.21928271651268005
rm_core with kl 0.2197381854057312
