<a href="https://colab.research.google.com/github/Makito042/Domain-Specific-Assistant-via-LLMs-Fine-Tuning/blob/main/train_final_(1).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Companion Plants Assistant: Fine-Tuning Gemma-2b on Colab

This notebook demonstrates how to fine-tune Google's instruction-tuned Gemma-2b model (`unsloth/gemma-2b-it-bnb-4bit`) to become an expert in companion planting. We use the `unsloth` library, which advertises 2-5x faster training and up to 70% less memory use. Combined with 4-bit loading, this lets the run fit on a free Google Colab T4 GPU instance.

## 1. Installation

In [None]:
%%capture
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers "trl<0.9.0" peft accelerate bitsandbytes
!pip install evaluate rouge_score sacrebleu bert_score

## 2. Load Model & Tokenizer

In [None]:
from unsloth import FastLanguageModel
import torch

max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-2b-it-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2026.2.1: Fast Gemma patching. Transformers: 4.57.6.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.563 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2026.2.1 patched 18 layers with 18 QKV layers, 18 O layers and 18 MLP layers.
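The trainable-parameter count that Unsloth prints when training starts (19,611,648 of 2,525,784,064) follows directly from the LoRA settings above. The sketch below reproduces it, assuming Gemma-2b's published configuration (hidden size 2048, MLP size 16384, 18 layers, 8 query heads and a single key/value head, each of dimension 256); those values are not stated in this notebook. It also shows the update scaling implied by `lora_alpha = 16` and `r = 16`.

```python
# Reproduce the LoRA trainable-parameter count by hand.
# Assumption: Gemma-2b config values below come from the published model config.
import math

HIDDEN = 2048
INTERMEDIATE = 16384
NUM_LAYERS = 18
HEAD_DIM = 256
NUM_Q_HEADS = 8
NUM_KV_HEADS = 1
R = 16
LORA_ALPHA = 16

# (in_features, out_features) of each LoRA target module in one decoder layer
modules = {
    "q_proj": (HIDDEN, NUM_Q_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "o_proj": (NUM_Q_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

# LoRA adds A (r x in) and B (out x r) per module: r * (in + out) parameters
per_layer = sum(R * (fan_in + fan_out) for fan_in, fan_out in modules.values())
lora_params = per_layer * NUM_LAYERS

total_params = 2_525_784_064  # total reported in the training banner
print(lora_params)                                  # 19611648
print(f"{100 * lora_params / total_params:.2f}%")   # 0.78%

# Update scaling: standard LoRA uses alpha / r, rsLoRA uses alpha / sqrt(r)
print(LORA_ALPHA / R)             # 1.0  (use_rslora = False, as configured)
print(LORA_ALPHA / math.sqrt(R))  # 4.0  (if use_rslora were True)
```

The MLP projections dominate: `gate_proj`, `up_proj` and `down_proj` account for 884,736 of the 1,089,536 LoRA parameters in each layer.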


## 3. Load and Preprocess Data (Strict & Bi-directional)

In [None]:
import pandas as pd
from datasets import Dataset

def create_strict_dataset(input_file):
    try:
        df = pd.read_csv(input_file)  # use the argument, e.g. after uploading the CSV to /content
    except FileNotFoundError:
        print(f"Error: {input_file} not found. Please upload it to Colab.")
        return None

    # Standardize column names
    df.columns = [c.strip() for c in df.columns]

    plant_data = {}

    for _, row in df.iterrows():
        source = str(row['Source Node']).strip().lower()
        relation = str(row['Link']).strip().lower()
        target = str(row['Destination Node']).strip().lower()

        if source not in plant_data:
            plant_data[source] = {'helps': set(), 'helped_by': set()}
        if target not in plant_data:
            plant_data[target] = {'helps': set(), 'helped_by': set()}

        # Bi-directional logic
        if relation == 'helps':
            plant_data[source]['helps'].add(target)
            plant_data[target]['helped_by'].add(source)
        elif relation == 'helped_by':
            plant_data[source]['helped_by'].add(target)
            plant_data[target]['helps'].add(source)

    dataset_data = []
    # Strict wording ("Do not add any others.") discourages the model from listing extra plants.
    # Appending the EOS token to every example teaches the model where to stop generating.
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

    for plant, info in plant_data.items():
        plant_formatted = plant.title()

        if info['helps']:
            helps_list = sorted(list(info['helps']))
            helps_str = ", ".join(helps_list)
            instruction = f"List strictly the plants that {plant_formatted} helps grow better. Do not add any others."
            response = f"{plant_formatted} helps: {helps_str}."
            dataset_data.append({
                "instruction": instruction,
                "input": "",
                "output": response,
                "text": alpaca_prompt.format(instruction, "", response) + tokenizer.eos_token
            })

        if info['helped_by']:
            helped_by_list = sorted(list(info['helped_by']))
            helped_by_str = ", ".join(helped_by_list)
            instruction = f"List strictly the best companion plants for {plant_formatted}. Do not add any others."
            response = f"Best companions for {plant_formatted}: {helped_by_str}."
            dataset_data.append({
                "instruction": instruction,
                "input": "",
                "output": response,
                "text": alpaca_prompt.format(instruction, "", response) + tokenizer.eos_token
            })

    return Dataset.from_list(dataset_data)

dataset = create_strict_dataset("companion_plants.csv")
if dataset:
    print(f"Final dataset created with {len(dataset)} examples. Includes EOS tokens and bi-directional logic.")
    print(dataset[0])

Final dataset created with 284 examples. Includes EOS tokens and bi-directional logic.
{'instruction': 'List strictly the plants that Alliums helps grow better. Do not add any others.', 'input': '', 'output': 'Alliums helps: brassicas, capsicum, carrots, fruit trees, nightshades, peppers, potatoes, tomatoes.', 'text': 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nList strictly the plants that Alliums helps grow better. Do not add any others.\n\n### Input:\n\n\n### Response:\nAlliums helps: brassicas, capsicum, carrots, fruit trees, nightshades, peppers, potatoes, tomatoes.<eos>'}
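The bi-directional logic above means every `helps` row also records the inverse `helped_by` fact, and vice versa. A minimal pure-Python sketch of that merge, run on three hypothetical rows (not taken from `companion_plants.csv`), makes the effect visible, including how casing and whitespace are normalized before sets remove duplicates:

```python
# Sketch of the bi-directional merge in create_strict_dataset,
# applied to hypothetical (source, link, destination) rows instead of the CSV.
def build_relations(rows):
    plants = {}
    for source, link, target in rows:
        s = source.strip().lower()
        rel = link.strip().lower()
        t = target.strip().lower()
        for name in (s, t):
            plants.setdefault(name, {"helps": set(), "helped_by": set()})
        if rel == "helps":
            # "s helps t" also means "t is helped by s"
            plants[s]["helps"].add(t)
            plants[t]["helped_by"].add(s)
        elif rel == "helped_by":
            # "s is helped by t" also means "t helps s"
            plants[s]["helped_by"].add(t)
            plants[t]["helps"].add(s)
    return plants

rows = [
    ("Basil", "helps", "Tomatoes"),        # hypothetical rows
    ("Tomatoes", "helped_by", "Marigold"),
    ("Basil", "Helps ", "tomatoes"),       # duplicate of row 1 after normalization
]
relations = build_relations(rows)
for plant, info in sorted(relations.items()):
    print(plant, sorted(info["helps"]), sorted(info["helped_by"]))
# basil ['tomatoes'] []
# marigold ['tomatoes'] []
# tomatoes [] ['basil', 'marigold']

# One training example per non-empty set, as in create_strict_dataset
num_examples = sum(bool(info["helps"]) + bool(info["helped_by"]) for info in relations.values())
print(num_examples)  # 3
```

The same counting rule, one example per plant per non-empty `helps` or `helped_by` set, is what produces the 284 examples from the real CSV.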


## 4. Train (15 Epochs to Memorize the Dataset)

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        # 15 epochs: deliberately overfit so the model memorizes the small 284-example dataset
        num_train_epochs = 15,
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",  # skip W&B logging and its interactive login prompt
    ),
)

trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 284 | Num Epochs = 15 | Total steps = 540
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 19,611,648 of 2,525,784,064 (0.78% trained)


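Training loss, first 10 of 540 logged steps:

| Step | Training Loss |
|-----:|--------------:|
| 1 | 5.5306 |
| 2 | 5.6437 |
| 3 | 5.1557 |
| 4 | 5.0644 |
| 5 | 4.4808 |
| 6 | 3.3652 |
| 7 | 3.118 |
| 8 | 2.5038 |
| 9 | 2.1379 |
| 10 | 1.8904 |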


TrainOutput(global_step=540, training_loss=0.31279551086050494, metrics={'train_runtime': 796.8492, 'train_samples_per_second': 5.346, 'train_steps_per_second': 0.678, 'total_flos': 4341605232107520.0, 'train_loss': 0.31279551086050494, 'epoch': 15.0})
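The 540 total steps in the log follow from the batch settings. At 2 examples per device, the 284 examples give 142 batches per epoch. Grouping them 4 at a time gives 36 optimizer steps if the final partial group counts, and 36 × 15 epochs = 540. The sketch below reproduces that arithmetic and the shape of the `linear` schedule with `warmup_steps = 5`. The function names are hypothetical. Rounding the last partial accumulation window up is an assumption, but it matches the logged step count.

```python
import math

def total_optimizer_steps(num_examples, batch_size, grad_accum, num_gpus, epochs):
    # Batches per epoch (the last batch may be partial)
    batches = math.ceil(num_examples / (batch_size * num_gpus))
    # Assumption: a trailing partial accumulation window still counts as a step
    steps_per_epoch = math.ceil(batches / grad_accum)
    return steps_per_epoch * epochs

def linear_lr(step, base_lr, warmup_steps, total_steps):
    # Linear warmup from 0 to base_lr, then linear decay back to 0
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

steps = total_optimizer_steps(284, 2, 4, 1, 15)
print(steps)  # 540
for s in (0, 5, 270, 540):
    print(s, linear_lr(s, 2e-4, 5, steps))
```

With only 5 warmup steps out of 540, the learning rate peaks at `2e-4` almost immediately and then spends nearly the entire run decaying linearly.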

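## 5. Quantitative Evaluation

Note: `evaluate_model` and `calculate_perplexity` draw their samples from `dataset`, the same data used for training. The scores below therefore measure how well the model memorized the training set, not how it generalizes to unseen plants.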

In [None]:
import evaluate
import torch
from tqdm import tqdm
import pandas as pd
from IPython.display import display  # available by default in notebooks; explicit for scripts

bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")

def calculate_perplexity(model, tokenizer, dataset):
    model.eval()
    nlls = []
    subset = dataset.select(range(min(len(dataset), 50)))
    for example in tqdm(subset, desc="Calculating Perplexity"):
        encodings = tokenizer(example["text"], return_tensors="pt")
        input_ids = encodings.input_ids.to(model.device)
        target_ids = input_ids.clone()
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            nlls.append(outputs.loss)
    return torch.exp(torch.stack(nlls).mean()).item()

def evaluate_model(model, tokenizer, dataset, num_samples=20):
    FastLanguageModel.for_inference(model)
    references = []
    predictions = []
    prompts = []
    samples = dataset.select(range(min(len(dataset), num_samples)))

    for example in tqdm(samples, desc="Generating Predictions"):
        text = example['text']
        parts = text.split("### Response:\n")
        if len(parts) > 1:
            prompt = parts[0] + "### Response:\n"
            ground_truth = parts[1].replace(tokenizer.eos_token, "").strip()
        else:
            continue

        inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
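        # Repetition controls. Caution: both also count prompt tokens, so no_repeat_ngram_size=3
        # can block copying a phrase such as "for Alliums" from the instruction. This may explain
        # renamed subjects in some predictions (e.g. "Allium sativum", "Night Shades").
        outputs = model.generate(**inputs, max_new_tokens=128, use_cache=True, repetition_penalty=1.2, no_repeat_ngram_size=3)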
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

        if "### Response:\n" in decoded:
            pred = decoded.split("### Response:\n")[-1].strip()
        else:
            pred = decoded.strip()

        predictions.append(pred)
        references.append([ground_truth])
        prompts.append(prompt)

    if predictions:
        try:
            bleu_score = bleu.compute(predictions=predictions, references=references)
            rouge_refs = [r[0] for r in references]
            rouge_score = rouge.compute(predictions=predictions, references=rouge_refs)
            ppl = calculate_perplexity(model, tokenizer, dataset)

            print(f"\n=== Evaluation Results ===")
            print(f"BLEU: {bleu_score['bleu']:.4f}")
            print(f"ROUGE-L: {rouge_score['rougeL']:.4f}")
            print(f"Perplexity: {ppl:.4f}")
        except Exception as e:
            print(f"Metrics calculation error: {e}")

        pd.set_option('display.max_colwidth', None)
        df_results = pd.DataFrame({
            'Input': [p.split('### Instruction:\n')[1].split('\n\n### Input:')[0][:50] for p in prompts[:5]],
            'Ground Truth': [r[0] for r in references[:5]],
            'Prediction': predictions[:5]
        })
        display(df_results)

evaluate_model(model, tokenizer, dataset)


Generating Predictions: 100%|██████████| 20/20 [00:31<00:00,  1.59s/it]
Calculating Perplexity: 100%|██████████| 50/50 [00:08<00:00,  5.73it/s]



=== Evaluation Results ===
BLEU: 0.7825
ROUGE-L: 0.8891
Perplexity: 1.0999
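`calculate_perplexity` exponentiates the average of the per-example mean losses. Two caveats apply when reading the 1.0999 figure. First, each example's loss is already a mean over its own tokens, so averaging those means weights short and long examples equally, whereas a corpus-level perplexity weights by token count. Second, because `labels=target_ids` is the full sequence, the loss also covers the fixed prompt template, which is identical across examples and easy to predict. The sketch below uses hypothetical loss values to contrast the two averaging schemes. It also converts the reported perplexity back to a mean negative log-likelihood.

```python
import math

def ppl_mean_of_means(example_losses):
    # Mirrors calculate_perplexity: exp(mean of per-example mean NLLs)
    return math.exp(sum(example_losses) / len(example_losses))

def ppl_token_weighted(example_losses, num_predicted_tokens):
    # Corpus-level perplexity: weight each example's mean NLL by its token count
    total_nll = sum(loss * n for loss, n in zip(example_losses, num_predicted_tokens))
    return math.exp(total_nll / sum(num_predicted_tokens))

# Reported perplexity -> average NLL in nats per token
print(round(math.log(1.0999), 4))  # 0.0952

losses = [0.1, 0.2]   # hypothetical per-example mean losses
lengths = [10, 30]    # hypothetical predicted-token counts
print(round(ppl_mean_of_means(losses), 4))            # 1.1618
print(round(ppl_token_weighted(losses, lengths), 4))  # 1.1912
```

The two schemes agree only when all examples have the same length. Here, the longer, higher-loss example pulls the token-weighted value up.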


| # | Input (first 50 chars) | Ground Truth | Prediction |
|---|---|---|---|
| 0 | List strictly the plants that Alliums helps grow b | Alliums helps: brassicas, capsicum, carrots, fruit trees, nightshades, peppers, potatoes, tomatoes. | Alliums helps: brassicas, capsicum, carrots, fruit trees, nightshades, peppers, potatoes, tomatoes. |
| 1 | List strictly the best companion plants for Allium | Best companions for Alliums: carrots, marigold, marigolds, mints, pansy, peppermint, spearmint, tarragon, tomatoes. | Best companions for Allium sativum: carrots, marigold, pansy, tomatoes. |
| 2 | List strictly the best companion plants for Fruit | Best companions for Fruit Trees: alliums, borage, delions, lemon balm, marigolds, marjoram, mustard, nasturtium, nasturtiums, parsnip, southernwood, tansy. | Best companions for Fruit Tree: alliums, borage, delions, lemon balm, marigolds, mustard, nasturtium, southernwood, tarragon, wormwood. |
| 3 | List strictly the best companion plants for Nights | Best companions for Nightshades: alliums, basil, carrots, mints, oregano, tarragon. | Best companions for Night Shades: alliums, basil, carrots, mints, oregano, tarragon. |
| 4 | List strictly the plants that Tomatoes helps grow | Tomatoes helps: alliums, asparagus, brassicas, broccoli, cabbage, celery, onion, peppers, roses. | Tomatoes helps: alliums, asparagus, brassicas, broccoli, cabbage, celery, onion, peppers, roses. |
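BLEU and ROUGE-L reward n-gram overlap, but the task here is to return an exact set of plants. A set-based score measures that more directly. Precision drops when the model adds plants that are not in the reference: row 2 above adds tarragon and wormwood. Recall drops when it leaves some out: row 1 returns 4 of 9. The sketch below assumes every answer has the `<header>: item, item, ....` shape used in the training data. The helper names are hypothetical.

```python
def parse_plant_list(answer):
    # "Best companions for X: a, b, c." -> {"a", "b", "c"}
    # Assumes the "<header>: item, item, ...." shape used in the training data
    _, _, items = answer.partition(":")
    items = items.strip().rstrip(".")
    return {item.strip().lower() for item in items.split(",") if item.strip()}

def list_precision_recall_f1(prediction, ground_truth):
    pred, gold = parse_plant_list(prediction), parse_plant_list(ground_truth)
    hits = len(pred & gold)
    precision = hits / len(pred) if pred else 0.0
    recall = hits / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if hits else 0.0
    return precision, recall, f1

# Row 1 of the results table
gt = ("Best companions for Alliums: carrots, marigold, marigolds, mints, pansy, "
      "peppermint, spearmint, tarragon, tomatoes.")
pred = "Best companions for Allium sativum: carrots, marigold, pansy, tomatoes."
p, r, f = list_precision_recall_f1(pred, gt)
print(f"precision={p:.3f} recall={r:.3f} f1={f:.3f}")
# precision=1.000 recall=0.444 f1=0.615
```

For a "do not add any others" instruction, precision is the score that tracks hallucinated plants, so it may be worth reporting alongside BLEU and ROUGE-L.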


## 6. Save Model

In [None]:
model.save_pretrained("lora_model_final")
!zip -r lora_model_final.zip lora_model_final

  adding: lora_model_final/ (stored 0%)
  adding: lora_model_final/adapter_config.json (deflated 58%)
  adding: lora_model_final/adapter_model.safetensors (deflated 8%)
  adding: lora_model_final/README.md (deflated 65%)
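To query the saved adapter, prompts should follow the training template exactly and stop right after `### Response:\n`, as `evaluate_model` does. The helpers below are a hypothetical sketch of that formatting and of pulling the answer back out of decoded text. The generation call appears only as a comment because it needs the loaded model and a GPU. The cell above saves only the adapter. Also calling `tokenizer.save_pretrained("lora_model_final")` keeps the tokenizer in the same directory.

```python
# Hypothetical helpers for querying the fine-tuned model with the training template.
ALPACA_PROMPT = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

RESPONSE_MARKER = "### Response:\n"

def build_prompt(instruction, context=""):
    # Leave the response slot empty so generation continues after the marker
    return ALPACA_PROMPT.format(instruction, context, "")

def extract_response(decoded_text):
    # Keep only the text after the last response marker
    return decoded_text.split(RESPONSE_MARKER)[-1].strip()

prompt = build_prompt("List strictly the plants that Tomatoes helps grow better. Do not add any others.")
# In the notebook, generation would look roughly like:
#   inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
#   out = model.generate(**inputs, max_new_tokens=128)
#   print(extract_response(tokenizer.decode(out[0], skip_special_tokens=True)))

# Simulated decoded output, using the ground-truth answer from the results table
decoded = prompt + "Tomatoes helps: alliums, asparagus, brassicas, broccoli, cabbage, celery, onion, peppers, roses."
print(extract_response(decoded))
# Tomatoes helps: alliums, asparagus, brassicas, broccoli, cabbage, celery, onion, peppers, roses.
```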
