# Plant Disease Dataset Preparation for SFT Training

I have previously trained [TinyML models for plant disease detection](https://www.kaggle.com/code/timothylovett/plant-disease-shrunken-efficientnet), so I approached this hackathon with that experience in mind. I felt the Gemma 3N model could offer a more holistic solution: one that not only identifies a disease but also provides the information needed to address the issues that arise from it.

The problem I faced, of course, is that the dataset itself is not structured for SFT training. To train the model properly, I need text for each image so that the model can learn to generate similar text at inference time. To solve this, I use the model itself to first generate text as if it were seeing that image (this notebook) and then use that text for training. I raised the temperature slightly to get more variety in the generated text, and then associated the dataset's images with these outputs. To further improve training, the text is also generated in multiple languages (a subset of the 140 supported) so that fine-tuning does not cause catastrophic forgetting of the model's multilingual capability. When processing the dataset, I randomly select the text and the language.
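
The random selection of language and text described above can be sketched roughly as follows. This is a minimal illustration, not the code used for training: `CAPTIONS` and `pick_caption` are hypothetical names, and the rows are placeholders standing in for the generated CSV (columns `Plant`, `State`, `Language`, `Text`).

```python
import random

# Placeholder rows; in practice these would be loaded from the generated CSV.
CAPTIONS = [
    {"Plant": "Apple", "State": "Apple Scab", "Language": "English", "Text": "caption en 1"},
    {"Plant": "Apple", "State": "Apple Scab", "Language": "English", "Text": "caption en 2"},
    {"Plant": "Apple", "State": "Apple Scab", "Language": "Spanish", "Text": "caption es 1"},
    {"Plant": "Apple", "State": "Healthy", "Language": "English", "Text": "caption en 3"},
]

def pick_caption(rows, plant, state, rng=random):
    """Pick a random language for the class, then a random text in that language."""
    matches = [r for r in rows if r["Plant"] == plant and r["State"] == state]
    if not matches:
        raise KeyError(f"no captions for {plant} / {state}")
    # Choose the language first so every language is equally likely,
    # regardless of how many texts exist for it.
    language = rng.choice(sorted({r["Language"] for r in matches}))
    texts = [r["Text"] for r in matches if r["Language"] == language]
    return language, rng.choice(texts)
```

Picking the language before the text is one possible design choice; sampling uniformly over all rows would instead favor languages with more captions.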

Thank you to Unsloth/Daniel for providing [their notebook](https://www.kaggle.com/code/danielhanchen/gemma-3n-4b-multimodal-finetuning-inference) as part of the competition materials; it let me get up and running on Kaggle very quickly.

Thank you as well to the [PlantVillage dataset](https://github.com/spMohanty/PlantVillage-Dataset) for providing the images used for this training.

## Future

This project uses only a subset of diseases, but it showcases a training approach that preserves multilingual support and uses the existing model to generate the captions used for its own fine-tuning. A useful future project would be to document additional plant diseases so that a larger corpus of data is available. The more data we can gather, the better we can help each other deal with plant diseases and their solutions.

## Citation

@article{Mohanty_Hughes_Salathé_2016,
title={Using deep learning for image-based plant disease detection},
volume={7},
DOI={10.3389/fpls.2016.01419},
journal={Frontiers in Plant Science},
author={Mohanty, Sharada P. and Hughes, David P. and Salathé, Marcel},
year={2016},
month={Sep}} 

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [None]:
%%capture
# Install latest transformers for Gemma 3N
!pip install --no-deps git+https://github.com/huggingface/transformers.git # Only for Gemma 3N
!pip install --no-deps --upgrade timm # Only for Gemma 3N

In [None]:
# !git clone https://github.com/spMohanty/PlantVillage-Dataset.git

In [None]:
from unsloth import FastModel
import torch

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    # Pretrained models
    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",

    # Other Gemma 3 quants
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E4B-it", # Or "unsloth/gemma-3n-E2B-it"
    dtype = None, # None for auto detection
    max_seq_length = 1024, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

In [None]:
import gc
# Helper function for inference
def do_gemma_3n_inference(model, tokenizer, messages, max_new_tokens=128):
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cuda")

    # Without a streamer, generate returns the full sequence of token IDs (prompt + completion)
    out_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # sample so that temperature/top_p/top_k take effect
        temperature=1.15,
        top_p=0.95,
        top_k=64,
    )

    # slice off the prompt part and decode
    gen_ids = out_ids[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True)

    del inputs
    torch.cuda.empty_cache()
    gc.collect()
    return text


For each plant/disease combination, the cell below generates `SAMPLES_PER_LANGUAGE` English text outputs (5 as configured here). I then translate the text into other languages and randomly select a text during training.
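
The translation step itself is not shown in this notebook. One way it could be done is to ask the same model to translate each English text, reusing the chat message format that `do_gemma_3n_inference` expects. The sketch below is an assumption, not the code that was actually run: `build_translation_messages` and the prompt wording are hypothetical, and the language names are a few entries from the commented-out `LANGUAGES` list in the next cell.

```python
# A few target languages, taken from the commented-out LANGUAGES list.
TARGET_LANGUAGES = ["Spanish", "Hindi", "Swahili"]

def build_translation_messages(text, language):
    """Build a single-turn, text-only chat message asking for a translation."""
    prompt = (
        f"Translate the following text into {language}. "
        "Reply with the translation only.\n\n" + text
    )
    return [{"role": "user", "content": [{"type": "text", "text": prompt}]}]

# Hypothetical usage with the helper defined earlier (needs the loaded model):
# translated = do_gemma_3n_inference(
#     model, tokenizer, build_translation_messages(resp, "Spanish"), max_new_tokens=512
# )
```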

In [None]:
import csv, re, time
from pathlib import Path
from tqdm import tqdm

# LANGUAGES=["English","Chinese","Hindi","Spanish","Arabic","French","Bengali","Portuguese","Russian","Indonesian","Urdu","German","Japanese","Vietnamese","Turkish","Swahili","Tagalog","Korean","Thai","Italian","Hebrew"]

PLANT_STATE_RAW=["Apple___Apple_scab","Apple___Black_rot","Apple___Cedar_apple_rust","Apple___healthy","Blueberry___healthy","Cherry___Powdery_mildew","Cherry___healthy","Corn_(maize)___Cercospora_leaf_spot","Corn_(maize)___Common_rust_","Corn_(maize)___Northern_Leaf_Blight","Corn_(maize)___healthy","Grape___Black_rot","Grape___Esca_(Black_Measles)","Grape___Leaf_blight_(Isariopsis_Leaf_Spot)","Grape___healthy","Orange___Haunglongbing_(Citrus_greening)","Peach___Bacterial_spot","Peach___healthy","Pepper,_bell___Bacterial_spot","Pepper,_bell___healthy","Potato___Early_blight","Potato___Late_blight","Potato___healthy","Raspberry___healthy","Soybean___healthy","Squash___Powdery_mildew","Strawberry___Leaf_scorch","Strawberry___healthy","Tomato___Bacterial_spot","Tomato___Early_blight","Tomato___Late_blight","Tomato___Leaf_Mold","Tomato___Septoria_leaf_spot","Tomato___Spider_mites Two-spotted_spider_mite","Tomato___Target_Spot","Tomato___Tomato_Yellow_Leaf_Curl_Virus","Tomato___Tomato_mosaic_virus","Tomato___healthy"]

SAMPLES_PER_LANGUAGE=5
CSV_PATH=Path("plant_state_descriptions.csv")
SLEEP_BETWEEN_CALLS=0.05

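def _clean(t):
    # Special case: the dataset labels bell peppers as "Pepper,_bell"
    if t.startswith("Pepper,_bell"):
        return "Bell Pepper"
    # Underscores to spaces, collapse double spaces, trim stray whitespace
    t = t.replace("_", " ").replace("  ", " ").strip()
    t = re.sub(r"\(([^)]+)\)", lambda m: f"({m.group(1).title()})", t)
    return t.title()

def normalise_pair(r):
    # Raw labels look like "<plant>___<state>"; split on the first triple underscore
    a, b = r.split("___", 1)
    return _clean(a), _clean(b)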

total = len(PLANT_STATE_RAW) * SAMPLES_PER_LANGUAGE
with CSV_PATH.open("w", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow(["Plant", "State", "Language", "Text"])
    with tqdm(total=total, desc="Rows") as pbar:
        for raw in PLANT_STATE_RAW:
            plant, state = normalise_pair(raw)
            prompt = (f"Plant: {plant}, State: {state} - "
                      "name and describe state with information (without explicitly "
                      "mentioning it's a detailed description)")
            for _ in range(SAMPLES_PER_LANGUAGE):
                # Text-only, single-turn chat message in the format the helper expects
                messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
                resp = do_gemma_3n_inference(model, tokenizer, messages, max_new_tokens=512)
                resp = resp if isinstance(resp, str) else str(resp)
                w.writerow([plant, state, "English", resp])
                pbar.update(1)
                time.sleep(SLEEP_BETWEEN_CALLS)
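
After the CSV is written, it can be worth checking that every plant/state pair received the expected number of rows before using the file for training. The following is a minimal sketch of such a check; `count_rows_per_class` and `find_short_classes` are hypothetical helpers, not part of the original notebook, and they assume the column names written above (`Plant`, `State`, `Language`, `Text`).

```python
import csv
from collections import Counter

def count_rows_per_class(lines):
    """Count rows per (Plant, State, Language) from an open CSV file or iterable of lines."""
    counts = Counter()
    for row in csv.DictReader(lines):
        counts[(row["Plant"], row["State"], row["Language"])] += 1
    return counts

def find_short_classes(counts, expected):
    """Return the keys that have fewer rows than expected, sorted for stable output."""
    return sorted(key for key, n in counts.items() if n < expected)

# Hypothetical usage against the file produced above:
# with open("plant_state_descriptions.csv", newline="", encoding="utf-8") as f:
#     counts = count_rows_per_class(f)
# print(find_short_classes(counts, expected=5))
```

An empty list from `find_short_classes` would indicate that every class present in the file has at least the expected number of rows; classes missing from the file entirely would need a separate comparison against `PLANT_STATE_RAW`.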
