In [1]:
%uv pip install transformers datasets peft accelerate bitsandbytes torch

[2mUsing Python 3.12.6 environment at: /usr/local[0m
[2mAudited [1m6 packages[0m [2min 25ms[0m[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
# =========================
# train.py
# Fine-tune gemma-2-2b-chess using QLoRA
# Task: Next-move prediction (Anand style)
# =========================

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training




In [11]:
# -------------------------
# Config
# -------------------------
# Base model to fine-tune and the directory that receives checkpoints
# and the final adapter.
MODEL_NAME: str = "diabolic6045/gemma-2-2b-chess"
OUTPUT_DIR: str = "anand_gemma_lora"

# Optimisation settings.
EPOCHS: int = 3
LR: float = 2e-4
BATCH_SIZE: int = 8           # per-device train batch size
GRAD_ACCUM_STEPS: int = 2     # passed as gradient_accumulation_steps
MAX_LENGTH: int = 256         # truncation limit (tokens) per example

# Checkpoint cadence.
SAVE_STEPS: int = 500         # save every 500 steps
SAVE_TOTAL_LIMIT: int = 3     # keep only last 3 checkpoints




In [12]:
# -------------------------
# Load Dataset
# -------------------------
# Location of the plain-text training corpus. The default is the original
# hard-coded path; set the ANAND_TRAIN_FILE environment variable to train
# from a copy stored elsewhere without editing the notebook.
TRAIN_FILE = os.environ.get("ANAND_TRAIN_FILE", "/root/anand_train.txt")

# The "text" builder exposes the file through a single "text" column, which
# the tokenization step reads and then drops.
dataset = load_dataset(
    "text",
    data_files={"train": TRAIN_FILE},
)



In [13]:
# -------------------------
# Tokenizer
# -------------------------
# Load the tokenizer that ships with the base model named in the config cell.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Gemma does not always define a pad token; padding needs one, so reuse EOS
# when none is set.
# NOTE(review): when this fallback fires, pad and EOS share an id, so any
# label masking keyed on the pad id would presumably hide EOS as well --
# confirm whether that matters for this task.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_fn(example):
    """Tokenize a batch of move sequences for causal-LM training.

    Args:
        example: Mapping with a "text" entry. When invoked through
            ``dataset.map(..., batched=True)`` this is a list of strings.

    Returns:
        The tokenizer's encoding of the batch, with every sequence truncated
        to at most ``MAX_LENGTH`` tokens.
    """
    # Deliberately no padding here: padding every example to MAX_LENGTH makes
    # each training step process MAX_LENGTH tokens per row even for short move
    # lists. The language-modeling data collator pads each batch to its own
    # longest sequence at training time, which yields the same masked loss
    # with less wasted compute.
    return tokenizer(
        example["text"],
        truncation=True,
        max_length=MAX_LENGTH,
    )

# Run the tokenizer over the dataset in batches; the raw "text" column is
# removed so only the tokenizer's outputs remain.
tokenized_ds = dataset.map(tokenize_fn, remove_columns=["text"], batched=True)



In [14]:

# -------------------------
# Quantization (QLoRA)
# -------------------------
# 4-bit NF4 weights with double quantization; compute dtype is float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the base model with the quantization settings above; device placement
# is delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
)





In [16]:

# -------------------------
# LoRA
# -------------------------
# Prepare the 4-bit model for training before any adapter is attached. Per
# the PEFT quantization guide this freezes the base weights, upcasts the
# remaining half-precision parameters to float32 for stability, and enables
# gradient checkpointing (trading some throughput for lower memory use).
model = prepare_model_for_kbit_training(model)

# Rank-8 adapters on the attention query and value projections only.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

# Wrap the model so only the adapter weights are trainable, then report the
# trainable / total parameter counts.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 1,597,440 || all params: 2,615,939,328 || trainable%: 0.0611




In [17]:
# -------------------------
# Data Collator
# -------------------------
# mlm=False selects causal language modeling rather than masked-LM batches.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)



# -------------------------
# Training Arguments (CRASH-SAFE)
# -------------------------
training_args = TrainingArguments(
    # Checkpointing: where checkpoints go, how often, and how many to keep.
    output_dir=OUTPUT_DIR,
    save_strategy="steps",
    save_steps=SAVE_STEPS,
    save_total_limit=SAVE_TOTAL_LIMIT,

    # Optimisation schedule.
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    optim="paged_adamw_8bit",
    fp16=True,

    # Logging and data loading.
    logging_steps=50,
    dataloader_num_workers=2,
    report_to="none",
)

# -------------------------
# Trainer
# -------------------------
# Wire the LoRA-wrapped model, the tokenized training split, and the
# language-modeling collator into a single Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=tokenized_ds["train"],
)

# -------------------------
# Resume logic (VERY IMPORTANT)
# -------------------------
def _latest_checkpoint(output_dir):
    """Return the `checkpoint-<step>` directory with the highest step.

    Selection is by the step number in the directory name rather than by
    modification time: mtimes change when a directory is copied, restored or
    touched, so the most recently modified entry is not necessarily the most
    advanced checkpoint. Entries that are not directories, or whose suffix is
    not a plain integer, are ignored.

    Args:
        output_dir: Directory in which checkpoints are written.

    Returns:
        Path of the newest checkpoint directory, or None when `output_dir`
        does not exist or holds no checkpoint directories.
    """
    if not os.path.isdir(output_dir):
        return None

    prefix = "checkpoint-"
    best_step, best_path = -1, None
    for entry in os.listdir(output_dir):
        suffix = entry[len(prefix):]
        if not (entry.startswith(prefix) and suffix.isdecimal()):
            continue
        path = os.path.join(output_dir, entry)
        step = int(suffix)
        if step > best_step and os.path.isdir(path):
            best_step, best_path = step, path
    return best_path


# None means "start from scratch"; otherwise the path handed to the trainer.
checkpoint = _latest_checkpoint(OUTPUT_DIR)
if checkpoint is not None:
    print(f"üîÅ Resuming training from checkpoint: {checkpoint}")






Detected kernel version 4.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [18]:
# -------------------------
# Train
# -------------------------
# `checkpoint` comes from the resume logic: None when no checkpoint entry was
# found under OUTPUT_DIR, otherwise the path of the checkpoint to resume from.
trainer.train(resume_from_checkpoint=checkpoint)



Step,Training Loss
50,2.1588
100,1.9005
150,1.8229
200,1.7828
250,1.7394
300,1.6789
350,1.6439
400,1.6259
450,1.5902
500,1.5481


KeyboardInterrupt: 

In [None]:
# -------------------------
# Save final model
# -------------------------
# Write the model first, then the tokenizer, into the same output directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained(OUTPUT_DIR)

print("‚úÖ Training complete. Final model saved to:", OUTPUT_DIR)


In [29]:
# =========================
# infer.py
# Predict Anand's next move
# =========================

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from peft import PeftModel

# -------------------------
# Config
# -------------------------
# Adapter checkpoint to load, and the base model it is applied on top of.
LORA_PATH = "anand_gemma_lora/checkpoint-4000"
BASE_MODEL = "diabolic6045/gemma-2-2b-chess"

# -------------------------
# Load Tokenizer
# -------------------------
# Prefer tokenizer files stored next to the adapter. If the checkpoint
# directory has none (it may contain adapter weights only), fall back to the
# base model's tokenizer instead of failing.
_tokenizer_source = (
    LORA_PATH
    if os.path.isfile(os.path.join(LORA_PATH, "tokenizer_config.json"))
    else BASE_MODEL
)
tokenizer = AutoTokenizer.from_pretrained(_tokenizer_source)

# Generation needs a pad token; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# -------------------------
# Quantization config
# -------------------------
# 4-bit NF4 weights with double quantization; compute dtype is float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# -------------------------
# Load Base Model
# -------------------------
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    quantization_config=bnb_config,
)

# -------------------------
# Attach LoRA Adapter
# -------------------------
# Apply the adapter stored at LORA_PATH to the base model, then switch to
# evaluation mode for inference.
model = PeftModel.from_pretrained(base_model, LORA_PATH)
model.eval()

# -------------------------
# Prediction Function
# -------------------------
def predict_next_move(moves: str) -> str:
    """Sample the next move for a space-separated move history.

    The prompt is ``"<moves> =>"``; up to six new tokens are sampled and the
    first whitespace-delimited word of the completion is returned.

    Args:
        moves: Move history, e.g. ``'e4 e5 Nf3 Nc6'``.

    Returns:
        The predicted move (e.g. ``'Bb5'``), or an empty string when the
        model produces no usable text. Sampling is enabled, so repeated calls
        with the same input may return different moves.
    """
    prompt = f"{moves} =>"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=6,
            temperature=0.6,
            top_k=10,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "=>" picks the wrong segment if the model emits "=>" itself.
    completion = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # An empty or whitespace-only completion (e.g. immediate EOS) has no
    # words; return "" rather than raising IndexError.
    words = completion.split()
    return words[0] if words else ""

# -------------------------
# Test
# -------------------------
# Smoke test: print a sample move history followed by the model's suggested
# continuation. predict_next_move samples (do_sample=True), so the printed
# move can differ from run to run.
if __name__ == "__main__":
    test_input = "e4 e5"
    print("Input moves :", test_input)
    print("Anand move  :", predict_next_move(test_input))


Input moves : e4 e5
Anand move  : c4
