In [276]:
import pandas as pd
import numpy as np

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, Adafactor

In [53]:
AVeriTeC = pd.read_json('../data/AVeriTeC/train.json')
AVeriTeC = AVeriTeC.rename(columns={"claim": "sentence"})
AVeriTeC['claim'] = 'Yes'
AVeriTeC = AVeriTeC.filter(items=['sentence','claim'])

In [54]:
subj = pd.read_csv('../data/subj/train.tsv', sep='\t')
subj = subj[subj['label'] == 1].reset_index()
subj['claim'] = 'No'
subj = subj.filter(items=['sentence','claim'])

In [56]:
train_df = pd.concat([AVeriTeC,subj])

In [83]:
cache_dir = "../assets/models"
model_path = "meta-llama/Llama-3.2-1B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_path, cache_dir=cache_dir, use_safetensors=True
)

tokenizer = AutoTokenizer.from_pretrained(
    model_path, cache_dir=cache_dir, use_safetensors=True, padding_side="left"
)

In [225]:
# A convention that makes training easier
tokenizer.pad_token = tokenizer.eos_token

messages = [
    {"role": "system", "content": "You are a yes/no answering bot. Only respond to questions with Yes or No.",},
    {"role": "user", "content": "Is the capital of New York state New York City?"},
    {"role": "assistant", "content": ""}
]
answer = "No"
#tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
#print(tokenizer.decode(tokenized_chat[0]))
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=False,
    continue_final_message=True,
    return_tensors="pt"
)
#input_ids = input_ids[0][:-1].reshape(1, -1)

with torch.no_grad():
    logits = model(
        input_ids,
        use_cache=False,
        output_hidden_states=False,
        output_attentions=False,
    )["logits"]
logits
torch.argmax(logits[-1,1])

tensor(128006)

In [None]:
print(tokenizer.batch_decode(model.generate(input_ids, max_new_tokens = 1), remove_special_chars=True)[0])

In [290]:
messages = [
    {"role": "system", "content": "You are a yes/no answering bot. Only respond to questions with Answer: Yes or Answer: No ",},
    {"role": "user", "content": "Is the capital of New York state New York City?"},
    {"role": "assistant", "content": "Answer:"}
]
answer = "No"
chat_template = tokenizer.apply_chat_template(messages, tokenize=False, continue_final_message=True)
chat_template_input_ids = tokenizer.apply_chat_template(messages, tokenize=True, continue_final_message=True, add_generation_prompt=False, return_tensors="pt")
#print(chat_template)

label_tokenized = tokenizer([" " + answer], add_special_tokens=False, return_tensors="pt", padding="max_length", max_length=chat_template_input_ids.shape[1])['input_ids']

# -100 comes from the Llama documentation, recommendation for loss
label_tokenized_fixed = torch.where(label_tokenized != tokenizer.pad_token_id, label_tokenized, -100)
#label_tokenized_fixed[:, -1] = tokenizer.pad_token_id
label_tokenized_fixed

# These should be equal
# chat_template_input_ids.shape, labels_tokenized_fixed.shape

tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, 2360]])

In [263]:
#chat_template_input_ids

In [300]:
print(tokenizer.batch_decode(model.generate(chat_template_input_ids, max_new_tokens = 10))[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 05 Apr 2025

You are a yes/no answering bot. Only respond to questions with Answer: Yes or Answer: No<|eot_id|><|start_header_id|>user<|end_header_id|>

Is the capital of New York state New York City?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Answer: No No No No No No No No No No


In [292]:
with torch.no_grad():
    logits = model(
        chat_template_input_ids,
        use_cache=True
    )["logits"]
#logits = model(input_ids=chat_template_input_ids)["logits"]
#logits.shape

torch.argmax(logits[0,-1])
logits.argmax(axis=-1)

tensor([[ 16309, 128006,   9891,   7566,   9642,   7566,   7566,   7566,   7566,
           7566,   7566,   7566,   7566,   7566,   9642,   7566,   7566,   7566,
           7566,   7566,   7566,   9642,   7566,   7566,   9642,   7566,   7566,
           7566,   7566,   7566,   7566,   7566,   7566,   7566,   7566,   7566,
           7566,   7566,   7566,   7566,   7566,   7566,   7566,   7566,   7566,
           7566,   7566,   7566,   7566,   7566,   7566,   7566,   7566,   7566,
           7566,   7566,   7566,   7566,   7566,   7566,   7566,   7566,   7566,
           7566,   7566,   7566,   7566,   7566]])

In [293]:
label_tokenized_fixed

tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, 2360]])

In [294]:
def calculate_loss(logits, labels):
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    cross_entropy_loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
    return cross_entropy_loss

In [272]:
#out = model.generate(chat_template_input_ids, output_logits=True, return_dict_in_generate=True)
#out.logits

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


(tensor([[6.3550, 4.1672, 5.3958,  ..., 2.9041, 2.9040, 2.9043]]),
 tensor([[12.5756,  9.9608,  4.8848,  ...,  2.6261,  2.6255,  2.6263]]))

In [295]:
calculate_loss(logits, label_tokenized_fixed)

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000, 44.3867])

In [280]:
optimizer = Adafactor(model.parameters(), weight_decay=0.01)

In [296]:
for _ in range(3):
    logits = model(chat_template_input_ids, use_cache=False)["logits"]
    loss = calculate_loss(logits, label_tokenized_fixed).mean()

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print("loss: ", loss.item())


loss:  0.6527453064918518
loss:  0.0005169016076251864
loss:  1.90331138583133e-05


In [297]:
with torch.no_grad():
    logits = model(
        chat_template_input_ids,
        use_cache=True
    )["logits"]
#logits = model(input_ids=chat_template_input_ids)["logits"]
#logits.shape

#torch.argmax(logits[0,-1])
#logits.argmax(axis=-1)

tensor([[  791,  2360,  2360,  2360,  2360,  2360,  2360,  2360,  2360,  2360,
          2360,  2360,  2360,  2360,  2360,   912,   912,  2360,  2360,  2360,
          2360,  2360,  2360,  2360,  2360,   912,   912,   912,   912,   912,
          2360,  2360,  2360,   912,  2360,  2360,  2360,  2360,  2360,  2360,
           912, 84343,  2360,  2360,   912, 84343,   912,  2360,  2360,  2360,
          2360,  2360,  2360,  2360,  2360,  2360,  2360,  2360,  2360,  2360,
          2360,  2360,  2360,  2360,  2360,  2360,  2360,  2360]])