# AgricGPT - Agricultural Domain Instruction Tuning with QLoRA

Fine-tunes **Microsoft Phi-2** on **AI4Agr/CROP-dataset** with automatic checkpoint pushing to Hugging Face Hub.

**Features**:
- QLoRA (4-bit quantization)
- Checkpoints pushed to HF every N steps (fault-tolerant)
- Before/after training comparison

## 1. Install Dependencies

In [None]:
!pip install -q torch transformers datasets peft bitsandbytes accelerate huggingface_hub

## 2. Login to Hugging Face (FIRST!)

Login early so checkpoints can be pushed during training.

In [None]:
from huggingface_hub import login

# Get your token at: https://huggingface.co/settings/tokens
# Make sure it has WRITE access! The training cell pushes checkpoints to the
# Hub (push_to_hub=True, hub_strategy="every_save"), which needs write permission.
login()

## 3. Configuration

In [None]:
import torch

# Model
MODEL_NAME = "microsoft/phi-2"  # base checkpoint passed to from_pretrained()
OUTPUT_DIR = "./agri_model_results"  # Trainer output_dir (local checkpoints)

# Hugging Face Hub - CHANGE THIS!
HF_MODEL_NAME = "agricgpt-phi2"  # Your model name on HF
PUSH_TO_HUB = True  # when False, hub_model_id is None and the final push is skipped
SAVE_STEPS = 100  # Push checkpoint every N steps

# Dataset
DATASET_SIZE = 5000  # max number of samples kept; a falsy value keeps the full dataset
MAX_SEQ_LENGTH = 512  # tokenizer truncation length (in tokens)

# LoRA
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
# Names of the modules that receive LoRA adapters.
# NOTE(review): presumably these match the Phi-2 layer names - confirm against the loaded model.
TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "dense"]

# Training
NUM_EPOCHS = 3
BATCH_SIZE = 2  # per-device batch size
GRADIENT_ACCUMULATION_STEPS = 4  # effective batch per device = BATCH_SIZE * 4 = 8
LEARNING_RATE = 2e-4
LOGGING_STEPS = 10

# GPU check
# Fail fast here rather than later, when the model is placed on CUDA device 0.
assert torch.cuda.is_available(), "GPU required!"
print(f"GPU: {torch.cuda.get_device_name(0)}")
torch.manual_seed(42)

## 4. Load Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization config with float16 compute dtype, passed to
# from_pretrained() below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Reuse EOS as the padding token.
# NOTE(review): with pad == EOS, a collator that masks labels by pad-token id
# will also mask genuine EOS tokens in the training targets - confirm the
# training collator accounts for this.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map={"":0}  # map the whole model ("" = root module) to device 0
)
# KV cache switched off on the model config.
# NOTE(review): presumably for training; it is never re-enabled for the
# generation cells - confirm this is intended.
model.config.use_cache = False
print(f"Loaded: {MODEL_NAME}")

## 5. Base Model Output (BEFORE Training)

In [None]:
from transformers import GenerationConfig, pipeline

# Wrap the not-yet-fine-tuned model in a generation pipeline for the
# "before" sample.
base_pipe = pipeline(model=model, tokenizer=tokenizer, task="text-generation")

# Sampling settings for the baseline run; the EOS id is also used as pad id.
base_gen_config = GenerationConfig(
    do_sample=True,
    temperature=0.7,
    max_new_tokens=150,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Same "### Instruction / ### Response" template the training data uses.
test_prompt = "### Instruction:\nWhat is crop rotation?\n\n### Response:\n"

print("=" * 50, "BASE MODEL (before training)", "=" * 50, sep="\n")

# Seed right before sampling so the sampled text does not depend on earlier
# RNG consumption in the session.
torch.manual_seed(42)
result = base_pipe(test_prompt, generation_config=base_gen_config)
print(result[0]["generated_text"])

## 6. Prepare Dataset

In [None]:
from datasets import load_dataset

# Load every JSON file matched by the `*_en` glob (presumably the English
# portion of the dataset - confirm) as a single "train" split.
dataset = load_dataset("AI4Agr/CROP-dataset", data_files="**/*_en/**/*.json", split="train")
if DATASET_SIZE:
    # Shuffle with a fixed seed BEFORE truncating. The split is a concatenation
    # of many files, so taking the first N rows would draw the subset almost
    # entirely from whichever files happen to come first.
    dataset = dataset.shuffle(seed=42).select(range(min(DATASET_SIZE, len(dataset))))

def format_instruction(sample):
    """Render one dataset row as a single training string.

    Args:
        sample: Mapping with "instruction" and "output" entries.

    Returns:
        A dict with one key, "text", holding the instruction and response in
        the "### Instruction / ### Response" template, terminated by the
        tokenizer's EOS token.
    """
    question = sample["instruction"]
    answer = sample["output"]
    text = f"### Instruction:\n{question}\n\n### Response:\n{answer}{tokenizer.eos_token}"
    return {"text": text}

# Apply the prompt template row by row; this adds a "text" column to the dataset.
dataset = dataset.map(format_instruction)

def tokenize_function(examples):
    """Tokenize a batch of formatted samples, truncating to ``MAX_SEQ_LENGTH``.

    No padding is applied here on purpose: the data collator used by the
    Trainer pads each batch to its longest sequence, so padding every sample
    to ``MAX_SEQ_LENGTH`` up front only adds pad tokens that are masked out
    anyway and slows every training step.

    Args:
        examples: A batch from ``Dataset.map(batched=True)``; must contain a
            "text" column of strings.

    Returns:
        The tokenizer output for the batch, with variable-length sequences.
    """
    return tokenizer(examples["text"], truncation=True, max_length=MAX_SEQ_LENGTH)

# Tokenize in batches and drop all pre-existing columns, so only the
# tokenizer's outputs remain in the dataset handed to the Trainer.
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
print(f"Dataset: {len(tokenized_dataset)} samples")

## 7. Configure LoRA

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Prepare the quantized model for k-bit training before adapters are attached.
model = prepare_model_for_kbit_training(model)

# LoRA adapter settings, taken from the configuration cell.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=TARGET_MODULES,
    bias="none",
)
model = get_peft_model(model, peft_config)

# Report what share of the parameters will receive gradients.
total = sum(map(lambda p: p.numel(), model.parameters()))
trainable = sum(p.numel() for p in filter(lambda p: p.requires_grad, model.parameters()))
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

## 8. Training (with automatic HF checkpoint pushing)

Checkpoints are pushed to Hugging Face every `SAVE_STEPS` steps.

If training is interrupted, you can resume from the last checkpoint!

In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling


class _EosPreservingCollator(DataCollatorForLanguageModeling):
    """Causal-LM collator that masks only real padding in the labels.

    The stock collator sets ``labels`` to -100 wherever ``input_ids`` equals
    the pad token id. This notebook reuses EOS as the pad token, so the EOS
    appended to every training sample would be masked too and the model
    would never be trained to stop generating. Here the labels are rebuilt
    from the attention mask instead: positions the tokenizer marked as
    padding are ignored, genuine EOS tokens stay in the loss.
    """

    def torch_call(self, examples):
        """Collate ``examples`` and return a batch with EOS-preserving labels."""
        batch = super().torch_call(examples)
        labels = batch["input_ids"].clone()
        labels[batch["attention_mask"] == 0] = -100  # ignore padding only
        batch["labels"] = labels
        return batch


training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    logging_steps=LOGGING_STEPS,
    fp16=True,
    optim="paged_adamw_32bit",
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    # Checkpoint saving & HF pushing
    save_strategy="steps",
    save_steps=SAVE_STEPS,
    save_total_limit=3,  # Keep only last 3 checkpoints locally
    push_to_hub=PUSH_TO_HUB,
    hub_model_id=HF_MODEL_NAME if PUSH_TO_HUB else None,
    hub_strategy="every_save",  # Push at every checkpoint!
    report_to="none",
    seed=42
)

trainer = Trainer(
    model=model,
    train_dataset=tokenized_dataset,
    args=training_args,
    # mlm=False selects causal-LM labels (a copy of input_ids).
    data_collator=_EosPreservingCollator(tokenizer, mlm=False)
)

print(f"Training with checkpoints pushed to HF every {SAVE_STEPS} steps...")
trainer.train()

## 9. Fine-Tuned Model Output (AFTER Training)

In [None]:
from transformers import logging
# From here on, only CRITICAL-level messages from transformers are shown.
logging.set_verbosity(logging.CRITICAL)
model.eval()  # put the model in evaluation mode before generating

# Generation pipeline over the fine-tuned model.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
# Sampling settings for the "after" run.
# NOTE(review): these differ from the baseline cell (max_new_tokens 256 vs 150,
# plus top_p and repetition_penalty), so the before/after comparison is not
# made under identical decoding settings - confirm this is intended.
gen_config = GenerationConfig(
    max_new_tokens=256, do_sample=True, temperature=0.7,
    top_p=0.9, repetition_penalty=1.2,
    eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id
)

def ask_agrigpt(q):
    """Ask the fine-tuned model a question and return only its answer text.

    The question is wrapped in the "### Instruction / ### Response" template,
    generated with ``pipe`` and ``gen_config``, and the reply is cut at the
    first "### Instruction:" marker in case the model starts a new turn.

    Args:
        q: The question to ask.

    Returns:
        The generated answer with the prompt and any follow-up turns removed,
        stripped of surrounding whitespace.
    """
    prompt = f"### Instruction:\n{q}\n\n### Response:\n"
    result = pipe(prompt, generation_config=gen_config)
    generated = result[0]['generated_text']
    # Remove the echoed prompt by prefix rather than splitting on the LAST
    # "### Response:" marker. If the model goes on to invent another
    # instruction/response turn, the last marker belongs to that invented
    # turn, and the real answer would be thrown away.
    if generated.startswith(prompt):
        answer = generated[len(prompt):]
    else:
        # Fallback if the prompt is not echoed verbatim: keep everything
        # after the first response marker.
        answer = generated.split("### Response:\n", 1)[-1]
    return answer.split("### Instruction:")[0].strip()

print("=" * 50, "FINE-TUNED MODEL (after training)", "=" * 50, sep="\n")

# Seed the RNG right before sampling, as the baseline cell does.
torch.manual_seed(42)
print("Q: What is crop rotation?")
print(f"A: {ask_agrigpt('What is crop rotation?')}")

## 10. Push Final Model

In [None]:
# Push final model to HF Hub (skipped entirely when PUSH_TO_HUB is False).
if PUSH_TO_HUB:
    print(f"Pushing final model to {HF_MODEL_NAME}...")
    trainer.push_to_hub()
    # Replace YOUR_USERNAME with your Hugging Face account name to get the URL.
    print(f"✅ Done! View at: https://huggingface.co/YOUR_USERNAME/{HF_MODEL_NAME}")

## 11. Resume Training from Checkpoint (if interrupted)

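If training was interrupted, run this cell to resume from the last checkpoint. The `trainer` object from Section 8 must already exist in the session; `resume_from_checkpoint=True` continues from the most recent checkpoint saved under `OUTPUT_DIR`.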

In [None]:
# Uncomment and run to resume from last checkpoint:
# trainer.train(resume_from_checkpoint=True)