In [19]:
from clean_code.bitter_llm import LinearGater, RandomGater, discounted_rewards_torch
from clean_code.off_policy_bitter_llm import OffPolicyBitterLLM
import torch
from torch import nn
import torch.nn.functional as F


In [None]:
my_model = OffPolicyBitterLLM(
    vocab_size=256, 
    embedding_dim=512, 
    num_heads=8, 
    downsample_rate=0.25, 
    sliding_window=64,
    GaterClass=LinearGater,
    OffPolicyGaterClass=RandomGater,
    use_off_policy=True
)

my_model.cuda()


byte_layer_config: self.byte_layer_config._attn_implementation='eager'


OffPolicyBitterLLM(
  (embedding): Embedding(256, 512)
  (down_layers): ModuleList(
    (0-1): 2 x OptimizedModule(
      (_orig_mod): Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=512, out_features=512, bias=False)
          (k_proj): Linear(in_features=512, out_features=512, bias=False)
          (v_proj): Linear(in_features=512, out_features=512, bias=False)
          (o_proj): Linear(in_features=512, out_features=512, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=512, out_features=512, bias=False)
          (up_proj): Linear(in_features=512, out_features=512, bias=False)
          (down_proj): Linear(in_features=512, out_features=512, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((512,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((512,), eps=1e-06)
        (post_feedfo

In [10]:
token_ids = torch.randint(0, 256, (2, 10)).cuda()
token_ids

tensor([[243, 160,  63,  17,   1,  33, 217, 131, 201, 226],
        [154, 200,  76, 167,  62,  83,  18, 235, 191,  68]], device='cuda:0')

In [11]:
out = my_model(token_ids)
out



In [16]:
logits = out["logits"]
down_gate_samples = out["down_gate_samples"]
off_policy_gate_probs = out["down_gate_probs"]
on_policy_probs = out["on_policy_probs"]
on_policy_logits = out["on_policy_logits"]
# Compute autoregressive loss: log probability of next token.
next_token_ids = token_ids[:, 1:]
current_token_logits = logits[:, :-1]
next_token_logits = F.cross_entropy(current_token_logits.transpose(1, 2), next_token_ids, reduction="none") # Transpose as F.cross_entropy wants shape [batch, classes, ...]
ar_loss = next_token_logits.mean()

next_token_logits

tensor([[ 9.5384,  5.2708,  7.7629,  8.0215,  4.0989,  8.5190,  9.3597, 14.4296,
         11.7884],
        [12.0584,  7.4050, 10.3375,  6.0768, 12.9127, 10.7237,  9.1943, 12.6352,
          9.7414]], device='cuda:0', grad_fn=<ViewBackward0>)

In [18]:
next_token_logits_padded = torch.cat([next_token_logits, torch.zeros(2, 1, device=next_token_logits.device)], dim=-1) # Pad the last reward as zero
next_token_logits_padded

tensor([[ 9.5384,  5.2708,  7.7629,  8.0215,  4.0989,  8.5190,  9.3597, 14.4296,
         11.7884,  0.0000],
        [12.0584,  7.4050, 10.3375,  6.0768, 12.9127, 10.7237,  9.1943, 12.6352,
          9.7414,  0.0000]], device='cuda:0', grad_fn=<CatBackward0>)

In [21]:
discounted_rewards = discounted_rewards_torch(next_token_logits_padded, 0.5)
discounted_rewards

tensor([[15.9446, 12.8125, 15.0834, 14.6409, 13.2388, 18.2798, 19.5217, 20.3238,
         11.7884,  0.0000],
        [20.5275, 16.9381, 19.0663, 17.4575, 22.7614, 19.6973, 17.9473, 17.5059,
          9.7414,  0.0000]], device='cuda:0', dtype=torch.float64)

In [53]:
discounted_rewards = (discounted_rewards - discounted_rewards.mean(dim=0))
discounted_rewards

tensor([[-2.2914, -2.0628, -1.9914, -1.4083, -4.7613, -0.7088,  0.7872,  1.4089,
          1.0235,  0.0000],
        [ 2.2914,  2.0628,  1.9914,  1.4083,  4.7613,  0.7088, -0.7872, -1.4089,
         -1.0235,  0.0000]], device='cuda:0', dtype=torch.float64)

In [24]:
action_log_probs = torch.stack([torch.zeros_like(on_policy_logits), on_policy_logits], dim=1) # As a sigmoid is equivalent to having one logit as 0.
selected_action_log_probs = F.cross_entropy(action_log_probs, down_gate_samples, reduction="none")
selected_action_log_probs

tensor([[-0.0000, 0.1557, 1.1165, 0.6073, 0.7121, 0.8154, 0.8295, 1.6167, 1.3842,
         1.8384],
        [-0.0000, 1.9978, 0.6139, 1.4178, 0.9176, 0.4970, 0.8359, 0.0949, 0.9186,
         0.9255]], device='cuda:0', grad_fn=<ViewBackward0>)

In [26]:
on_policy_probs

tensor([[1.0000, 0.1442, 0.6726, 0.4552, 0.5094, 0.4425, 0.5637, 0.8014, 0.7495,
         0.1591],
        [1.0000, 0.8644, 0.4588, 0.7578, 0.6005, 0.3916, 0.5665, 0.0905, 0.6009,
         0.6037]], device='cuda:0', grad_fn=<SqueezeBackward1>)

In [27]:
off_policy_gate_probs

tensor([[1.0000, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
         0.2500],
        [1.0000, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
         0.2500]], device='cuda:0')

In [28]:
down_gate_samples

tensor([[1, 0, 0, 0, 0, 1, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')

In [44]:
on_policy_probs / 0.25

tensor([[4.0000, 0.5768, 2.6903, 1.8207, 2.0375, 1.7698, 2.2549, 3.2058, 2.9979,
         0.6363],
        [4.0000, 3.4574, 1.8350, 3.0310, 2.4021, 1.5666, 2.2661, 0.3622, 2.4038,
         2.4147]], device='cuda:0', grad_fn=<DivBackward0>)

In [51]:
(1 -  on_policy_probs) / 0.75

tensor([[0.0000, 1.1411, 0.4366, 0.7264, 0.6542, 0.7434, 0.5817, 0.2647, 0.3340,
         1.1212],
        [0.0000, 0.1809, 0.7217, 0.3230, 0.5326, 0.8111, 0.5780, 1.2126, 0.5321,
         0.5284]], device='cuda:0', grad_fn=<DivBackward0>)

In [48]:
# likelihood_ratios [:, :, 1] gives the likelihood ratio for the action of gating.
likelihood_ratios = torch.stack([
    (1 - on_policy_probs) / (1 - off_policy_gate_probs),
    on_policy_probs / off_policy_gate_probs
], dim=-1)
likelihood_ratios = likelihood_ratios.detach() # Detach as we don't want to backpropagate through this.
likelihood_ratios

tensor([[[   nan, 1.0000],
         [1.1411, 0.5768],
         [0.4366, 2.6903],
         [0.7264, 1.8207],
         [0.6542, 2.0375],
         [0.7434, 1.7698],
         [0.5817, 2.2549],
         [0.2647, 3.2058],
         [0.3340, 2.9979],
         [1.1212, 0.6363]],

        [[   nan, 1.0000],
         [0.1809, 3.4574],
         [0.7217, 1.8350],
         [0.3230, 3.0310],
         [0.5326, 2.4021],
         [0.8111, 1.5666],
         [0.5780, 2.2661],
         [1.2126, 0.3622],
         [0.5321, 2.4038],
         [0.5284, 2.4147]]], device='cuda:0')

In [49]:
# Get the likelihood ratios for the selected actions for importance sampling.
selected_action_likelihood_ratios = likelihood_ratios.gather(dim=-1, index=down_gate_samples.unsqueeze(-1))
selected_action_likelihood_ratios = selected_action_likelihood_ratios.squeeze(-1)
selected_action_likelihood_ratios

tensor([[1.0000, 1.1411, 0.4366, 0.7264, 0.6542, 1.7698, 0.5817, 0.2647, 0.3340,
         0.6363],
        [1.0000, 0.1809, 0.7217, 0.3230, 0.5326, 0.8111, 0.5780, 1.2126, 0.5321,
         0.5284]], device='cuda:0')

In [55]:
print(f"{selected_action_likelihood_ratios.grad_fn}")

None


In [50]:
selected_action_likelihood_ratios.shape, discounted_rewards.shape

(torch.Size([2, 10]), torch.Size([2, 10]))

In [56]:
selected_action_likelihood_ratios * discounted_rewards * selected_action_log_probs

tensor([[ 0.0000, -0.3665, -0.9707, -0.6213, -2.2179, -1.0228,  0.3798,  0.6030,
          0.4732,  0.0000],
        [-0.0000,  0.7453,  0.8822,  0.6449,  2.3271,  0.2857, -0.3803, -0.1621,
         -0.5003,  0.0000]], device='cuda:0', dtype=torch.float64,
       grad_fn=<MulBackward0>)

In [57]:
selected_action_log_probs

tensor([[-0.0000, 0.1557, 1.1165, 0.6073, 0.7121, 0.8154, 0.8295, 1.6167, 1.3842,
         1.8384],
        [-0.0000, 1.9978, 0.6139, 1.4178, 0.9176, 0.4970, 0.8359, 0.0949, 0.9186,
         0.9255]], device='cuda:0', grad_fn=<ViewBackward0>)