# Run a single training step and an inference pass with a local BERT text classifier

In [None]:
import torch
import local_llm as lllm

In [None]:
# OPTION 1 SETUP
# Source files handed to setup_bert_base: the checkpoint directory plus the
# vocab and config files that live inside it.
checkpoint_dir = r"C:/Users/Cameron.Webster/Python/local-llm/assets/uncased_L-12_H-768_A-12"
vocab_path = r"C:/Users/Cameron.Webster/Python/local-llm/assets/uncased_L-12_H-768_A-12/vocab.txt"
config_path = r"C:/Users/Cameron.Webster/Python/local-llm/assets/uncased_L-12_H-768_A-12/bert_config.json"
# Optional; by default this would become ..\assets\bert-base-local
output_dir = r"C:/Users/Cameron.Webster/Python/local-llm/assets/bert-base-local1"

assets_dir = lllm.setup_bert_base(
    checkpoints=checkpoint_dir,
    vocab=vocab_path,
    config=config_path,
    output_dir=output_dir,
    overwrite=True,
)

# After this call, assets_dir is expected to contain:
#   pytorch_model.bin
#   config.json
#   vocab.txt


# OPTION 2 SETUP (alternative, kept commented out): point directly at an
# already-prepared asset folder instead of the original checkpoint files.
# assets_dir = lllm.setup_bert_base(
#     model_params=r"C:/Users/Cameron.Webster/Python/local-llm/assets/bert-base-local/pytorch_model.bin",
#     vocab=r"C:/Users/Cameron.Webster/Python/local-llm/assets/bert-base-local/vocab.txt",
#     config=r"C:/Users/Cameron.Webster/Python/local-llm/assets/bert-base-local/config.json",
#     # output_dir optional; if omitted, uses the folder containing model_params
# )


In [None]:
# 2. Build the input encoder and turn a small batch of texts into tensors
encoder = lllm.build_bert_input_encoder(assets_dir, max_len=256, lowercase=True)

texts = ["This is a test.", "Another example."]
encoded = list(map(encoder.encode, texts))

# One tensor per encoded field, each stacking that field across the batch.
# NOTE(review): torch.tensor needs equal-length rows, so this assumes the
# encoder returns sequences of a common length (e.g. padded) — confirm.
input_ids, attention_mask, token_type_ids = (
    torch.tensor([getattr(item, field) for item in encoded])
    for field in ("input_ids", "attention_mask", "token_type_ids")
)

print(f"input_ids:{input_ids}")
print(f"attention_mask:{attention_mask}")
print(f"token_type_ids:{token_type_ids}")

In [None]:
# 3. Load BERT + classifier head
num_labels = 8
model = lllm.BertTextClassifier.from_pretrained(
    assets_dir,
    num_labels=num_labels,
    pooling="cls",
)
# Put the model in training mode. Kept as the cell's final expression so the
# notebook still displays the value returned by train().
model.train()

In [None]:
# 4. Simple training step
# Re-enter training mode so this cell behaves the same even when it is run
# after the inference cell, which switches the model to eval mode.
model.train()
# Clear gradients left over from a previous run of this cell; backward()
# accumulates into existing .grad values rather than overwriting them.
# NOTE(review): assumes BertTextClassifier is a torch.nn.Module — confirm.
model.zero_grad()

# Random placeholder labels, one per example in the batch.
labels = torch.randint(0, num_labels, (input_ids.size(0),))
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
    labels=labels,
)
# With labels supplied, the returned mapping carries the loss under "loss".
loss = out["loss"]
loss.backward()
# No optimizer is created in this notebook, so the weights are not updated.
# A real loop would follow with optimizer.step(), optimizer.zero_grad(), etc.


In [None]:
# 5. Inference
model.eval()

# Same batch as the training step, but without labels.
batch = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "token_type_ids": token_type_ids,
}
with torch.inference_mode():
    logits = model(**batch)["logits"]
    # Predicted class index = position of the largest logit along the last axis.
    preds = logits.argmax(dim=-1)

print(preds)
# NOTE(review): `out` is the result of the earlier training-step cell, not of
# the inference call above — confirm this is the intended thing to print.
print(out)