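# Run a simple BERT text-classification training loop

Set up local BERT assets, tokenize sample text, build a classifier head, fine-tune it (optionally together with the last encoder layers), then run one manual training step and inference.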

In [None]:
import os
from pathlib import Path

import torch

import local_llm as lllm
from local_llm.pipelines.text_classification import BertTextClassifier, ClassifierHeadConfig
from local_llm.training.config import TrainConfig
from local_llm.training.head_trainer import train_classifier_head

In [None]:
# OPTION 1 SETUP
# Build a local BERT asset folder from the original checkpoint / vocab / config files.
# Every path hangs off ASSETS_ROOT so the notebook is not tied to one user's machine:
# set the LOCAL_LLM_ASSETS environment variable to point elsewhere. The default is
# the original absolute path, so the resulting strings are unchanged.
ASSETS_ROOT = os.environ.get(
    "LOCAL_LLM_ASSETS", "C:/Users/Cameron.Webster/Python/local-llm/assets"
).rstrip("/\\")
BERT_SRC_DIR = f"{ASSETS_ROOT}/uncased_L-12_H-768_A-12"

assets_dir = lllm.setup_bert_base(
    checkpoints=BERT_SRC_DIR,
    vocab=f"{BERT_SRC_DIR}/vocab.txt",
    config=f"{BERT_SRC_DIR}/bert_config.json",
    # optional; by default this would become ..\assets\bert-base-local
    # NOTE(review): "bert-base-local1" differs from the "bert-base-local" default
    # mentioned above — confirm the "1" suffix is intentional.
    output_dir=f"{ASSETS_ROOT}/bert-base-local1",
    overwrite=True,
)

# Now assets_dir should contain:
#   pytorch_model.bin
#   config.json
#   vocab.txt


# # OPTION 2 SETUP
# assets_dir = lllm.setup_bert_base(
#     model_params=f"{ASSETS_ROOT}/bert-base-local/pytorch_model.bin",
#     vocab=f"{ASSETS_ROOT}/bert-base-local/vocab.txt",
#     config=f"{ASSETS_ROOT}/bert-base-local/config.json",
#     # output_dir optional; if omitted, uses the folder containing model_params
# )


In [None]:
# 2. Build tokenizer
# Encode two sample sentences, then turn each per-example field into one batched
# tensor. NOTE(review): building a tensor from the per-example lists assumes
# encoder.encode pads every example to the same length (max_len=256) — confirm.
encoder = lllm.build_bert_input_encoder(assets_dir, max_len=256, lowercase=True)

texts = ["This is a test.", "Another example."]
encoded = [encoder.encode(sentence) for sentence in texts]

# One tensor per field, built in the order the names are listed.
input_ids, attention_mask, token_type_ids = (
    torch.tensor([getattr(item, field) for item in encoded])
    for field in ("input_ids", "attention_mask", "token_type_ids")
)

for field_name, field_tensor in (
    ("input_ids", input_ids),
    ("attention_mask", attention_mask),
    ("token_type_ids", token_type_ids),
):
    print(f"{field_name}:{field_tensor}")

In [None]:
# 3. Load BERT + classifier head
num_labels = 8
# MLP head on top of the pooled ("cls") encoder output.
# NOTE(review): three dropout rates for two hidden layers — presumably one per
# stage (e.g. input, hidden 1, hidden 2); confirm against ClassifierHeadConfig.
head_cfg = ClassifierHeadConfig(
    hidden_sizes=(512, 256),
    dropouts=(0.2, 0.2, 0.1),
    use_layer_norm=True,
    activation="relu",
)
model = lllm.BertTextClassifier.from_pretrained(
    assets_dir,
    num_labels=num_labels,
    pooling="cls",
    head_config=head_cfg,
)

# Presumably leaves only the last 2 encoder layers (plus the head) trainable and
# keeps the embeddings frozen — confirm against set_finetune_policy.
model.set_finetune_policy("last_n", last_n=2, train_embeddings=False)
# The trailing ';' stops Jupyter from rendering train()'s return value (the module
# itself for a torch nn.Module), which would dump the whole BERT architecture.
model.train();


# # Alternatively, inject a completely custom head
# # (needs `import torch.nn as nn`; 768 presumably matches the H-768 hidden size
# # of the uncased_L-12_H-768_A-12 checkpoint):
# custom_head = nn.Sequential(
#     nn.Linear(768, 256),
#     nn.ReLU(),
#     nn.Linear(256, num_labels),
# )
#
# model = BertTextClassifier.from_pretrained(
#     assets_dir=assets_dir,
#     num_labels=num_labels,
#     head=custom_head,
# )
# model.train();


In [None]:
# HOW TO TRAIN ENCODER LAYERS WHILE STORING ANOTHER VERSION
# Train the classifier head (plus the encoder layers selected by the fine-tune
# policy) and save the results under cfg.artifacts_root, never overwriting the
# original pretrained assets. TrainConfig / train_classifier_head are imported in
# the imports cell at the top of the notebook.
cfg = TrainConfig(
    artifacts_root=Path("./artifacts"),
    batch_size=128,
    epochs=10,
    lr_head=2e-4,
    lr_encoder=1e-5,
    finetune_policy="last_n",
    finetune_last_n=2,
    pooling="cls",
)

# Train from the same assets the setup cell produced, with the same label count
# as the model built above, so the two stay in sync.
best_head_state, best_encoder_state = train_classifier_head(
    assets_dir=assets_dir,
    num_labels=num_labels,
    cfg=cfg,
)

artifacts_root = Path(cfg.artifacts_root)
head_path = artifacts_root / "classifier_head.pt"
encoder_path = artifacts_root / "encoder_finetuned.pt"
# torch.save does not create missing parent directories.
artifacts_root.mkdir(parents=True, exist_ok=True)

# Save best head separately
torch.save(best_head_state, head_path)

# Save finetuned encoder weights separately (never overwriting original assets).
# best_encoder_state is None when the trainer returned no encoder weights
# (presumably when the policy keeps the encoder frozen).
if best_encoder_state is not None:
    torch.save(best_encoder_state, encoder_path)


# Load base encoder for other tasks (pretrained encoder, freshly built head).
base_bert = BertTextClassifier.from_pretrained(
    assets_dir,
    num_labels=num_labels,
)

# Load the finetuned weights from disk into a new classifier.
# NOTE(review): this head is built from from_pretrained's defaults; if
# train_classifier_head builds a different head architecture, the strict head load
# below will fail — pass the matching head_config here.
ft_model = BertTextClassifier.from_pretrained(
    assets_dir,
    num_labels=num_labels,
)
# weights_only=True: these files hold plain state dicts, so do not unpickle
# arbitrary objects (and match the default of newer torch releases).
ft_model.classifier.load_state_dict(torch.load(head_path, weights_only=True))

# Only load the encoder file written by *this* run. An unconditional load would
# crash on a missing file, or silently pick up a stale file from an earlier run.
if best_encoder_state is not None:
    encoder_load = ft_model.bert.load_state_dict(
        torch.load(encoder_path, weights_only=True), strict=False
    )
    # strict=False tolerates missing keys (presumably a partial encoder state), but
    # unexpected keys mean saved names did not match ft_model.bert and were dropped.
    if encoder_load.unexpected_keys:
        print(
            f"WARNING: {len(encoder_load.unexpected_keys)} unexpected key(s) in "
            f"{encoder_path} were ignored, e.g. {encoder_load.unexpected_keys[:3]}"
        )


In [None]:
# 4. Simple training step
# A single forward/backward pass on the toy batch; no optimizer is applied here.

# Re-assert train mode: the inference cell switches the model to eval(), so
# re-running this cell afterwards would otherwise run with dropout disabled.
model.train()
# backward() accumulates into .grad, so clear anything left by a previous run of
# this cell to keep it idempotent.
model.zero_grad(set_to_none=True)

# Random labels from a seeded local generator: the labels are reproducible across
# Restart & Run All without touching torch's global RNG (dropout still uses it).
label_gen = torch.Generator().manual_seed(0)
labels = torch.randint(0, num_labels, (input_ids.size(0),), generator=label_gen)
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
    labels=labels,
)
loss = out["loss"]
loss.backward()
# optimizer.step(), optimizer.zero_grad(), etc.


In [None]:
# 5. Inference
# eval() disables dropout; inference_mode() skips autograd bookkeeping.
model.eval()
with torch.inference_mode():
    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )["logits"]
    preds = logits.argmax(dim=-1)

print(preds)
# `out` comes from the training-step cell above, not from this inference pass.
# Show just its scalar loss (backward() there required a scalar), clearly labelled,
# instead of the raw dict whose loss tensor still carries the autograd graph.
print(f"training-step loss: {out['loss'].item():.4f}")