In [18]:
from transformers import AutoTokenizer
from unsloth import FastLanguageModel
import torch
import torch.nn.functional as F

In [3]:
model_name = "google/gemma-3-4b-it"

tokenizer = AutoTokenizer.from_pretrained(model_name)

model, _ = FastLanguageModel.from_pretrained(
    model_name=model_name, 
    max_seq_length=8192, 
    dtype=None,
    load_in_4bit = True,
)

==((====))==  Unsloth 2025.9.1: Fast Gemma3 patching. Transformers: 4.56.1.
   \\   /|    NVIDIA GeForce RTX 3080 Ti. Num GPUs = 1. Max memory: 11.631 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.6. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Fetching 2 files: 100%|██████████| 2/2 [00:01<00:00,  1.77it/s]


In [19]:
FastLanguageModel.for_inference(model)


Gemma3ForConditionalGeneration(
  (model): Gemma3Model(
    (vision_tower): SiglipVisionModel(
      (vision_model): SiglipVisionTransformer(
        (embeddings): SiglipVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
          (position_embedding): Embedding(4096, 1152)
        )
        (encoder): SiglipEncoder(
          (layers): ModuleList(
            (0-26): 27 x SiglipEncoderLayer(
              (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
              (self_attn): SiglipAttention(
                (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
              )
              (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwi

In [None]:
prompt = tokenizer.encode("My name is Naing Ye Oo.", return_tensors="pt", add_special_tokens=False)
bad_res = tokenizer.encode("Hello Naing Ye Oo.", return_tensors="pt", add_special_tokens=False)
good_res = tokenizer.encode("Hello Naing Ye Oo, How can I help you today ?", return_tensors="pt", add_special_tokens=False)

In [16]:
good_ids = torch.cat([prompt, good_res], dim=1)
bad_ids  = torch.cat([prompt, bad_res],  dim=1)
good_attn = torch.ones_like(good_ids)
bad_attn  = torch.ones_like(bad_ids)

good_labels = good_ids.clone()
bad_labels  = bad_ids.clone()
good_labels[:, :prompt.size(1)] = -100
bad_labels[:,  :prompt.size(1)] = -100

In [35]:
def seq_logprob(model, input_ids, attention_mask, labels):
    with torch.no_grad() if not model.training else torch.enable_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = out.logits 

    logits = logits[:, :-1, :]
    target = labels[:, 1:]                    
    mask   = (target != -100).to(logits.dtype) 

    logp_tok = F.log_softmax(logits, dim=-1)
    tgt_logp = logp_tok.gather(-1, target.unsqueeze(-1).clamp_min(0)).squeeze(-1)

    # sum only over response tokens
    return (tgt_logp * mask).sum(dim=1) 

In [None]:
logp_pol_good = seq_logprob(model, good_ids.to("cuda"), good_attn.to("cuda"), good_labels.to("cuda"))    # log πθ(y+|x)
# logp_pol_bad  = seq_logprob(model, bad_ids.to("cuda"),  bad_attn.to("cuda"),  bad_labels.to("cuda")) 

In [38]:
logp_pol_good

tensor([-67.], device='cuda:0', dtype=torch.bfloat16)

In [34]:
logists[:, :-1, : ].shape

torch.Size([1, 13, 262208])

In [43]:
import torch
import torch.nn.functional as F

# fake logits (batch=1, seq=3, vocab=5)
logits = torch.tensor([[
    [1.0, 2.0, 3.0, 4.0, 5.0],   # timestep 0
    [5.0, 4.0, 3.0, 2.0, 1.0],   # timestep 1
    [0.5, 1.5, 2.5, 3.5, 4.5]    # timestep 2
]])

print("logits shape:", logits.shape)  # [1, 3, 5]

# convert to log-probs
logp_tok = F.log_softmax(logits, dim=-1)

# pretend target tokens at each position are:
target = torch.tensor([[4, 0, 2]])   # shape [1, 3]

print("target:", target)
target_unsq = target.unsqueeze(-1)
print("target_unsq shape:", target_unsq.shape)  # [1, 3, 1]

tgt_logp = logp_tok.gather(-1, target_unsq)
print("tgt_logp shape gather:", tgt_logp.shape)  # [1, 3, 1]
print("tgt_logp gather:", tgt_logp)

tgt_logp = tgt_logp.squeeze(-1)
print("tgt_logp shape after squeeze:", tgt_logp.shape)  # [1, 3]
print("tgt_logp values:", tgt_logp)

logits shape: torch.Size([1, 3, 5])
target: tensor([[4, 0, 2]])
target_unsq shape: torch.Size([1, 3, 1])
tgt_logp shape gather: torch.Size([1, 3, 1])
tgt_logp gather: tensor([[[-0.4519],
         [-0.4519],
         [-2.4519]]])
tgt_logp shape after squeeze: torch.Size([1, 3])
tgt_logp values: tensor([[-0.4519, -0.4519, -2.4519]])
