<a href="https://colab.research.google.com/github/Chibu4ril/plant-classification-cvision/blob/main/_downloads/070179efc13bd796c5dd4af7bf52d5b9/intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import gc

# Free Python garbage and release PyTorch's cached-but-unused CUDA memory.
# On a fresh runtime this does little; it mainly matters when the notebook is
# re-run in the same session.
gc.collect()
torch.cuda.empty_cache()

# IPython shell escape: print the GPU status (model, driver, memory usage).
!nvidia-smi

# === SETUP ===
# Mount Google Drive at /content/drive so the dataset archive can be read.
from google.colab import drive
drive.mount('/content/drive')

import os
import zipfile

# === 1. Extract the dataset ZIP stored on Google Drive ===
zip_path = "/content/drive/MyDrive/final_dataset.zip"  # <-- Change if needed
extract_to = "/content/"

os.makedirs(extract_to, exist_ok=True)

# NOTE(review): the data-loading code reads from /content/final_dataset, so the
# archive presumably contains a top-level final_dataset/ folder — confirm.
# The archive is re-extracted (overwriting existing files) on every run.
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)

print("✅ Dataset extracted!")

# === 2. Install Dependencies ===
# NOTE(review): torch is already imported above; if this install upgrades
# torch-related packages, the runtime may need a restart — confirm.
# `datasets` is not used by any of the visible code.
!pip install -q transformers datasets torchvision

# === 3. Imports ===
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch  # already imported at the top of the cell; re-import is harmless
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

# === 4. Set Device ===
# Prefer the GPU when available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️ Using device: {device}")

# === 5. Custom Dataset ===
class PlantDataset(Dataset):
    """Paired image/mask dataset for binary plant-vs-background segmentation.

    Images in ``image_dir`` and masks in ``mask_dir`` are paired by position
    after sorting each directory listing by filename. Both folders must
    therefore hold the same number of files, and their sorted orders must
    correspond (e.g. ``0001.jpg`` <-> ``0001.png``). Hidden entries (names
    starting with ``"."``, such as ``.DS_Store`` or ``.ipynb_checkpoints``)
    are ignored.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the segmentation masks.
        feature_extractor: Callable used as
            ``feature_extractor(images=..., return_tensors="pt")``. It must
            return a mapping with a ``"pixel_values"`` tensor (for example a
            ``SegformerImageProcessor``).
        size: ``(width, height)`` that images and masks are resized to before
            encoding. Defaults to ``(256, 256)``.

    Raises:
        ValueError: If the two directories contain a different number of
            (non-hidden) files.
    """

    def __init__(self, image_dir, mask_dir, feature_extractor, size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_list = self._list_files(image_dir)
        self.mask_list = self._list_files(mask_dir)
        self.feature_extractor = feature_extractor
        self.size = tuple(size)

        # Pairing is purely positional. A count mismatch would silently
        # misalign every pair after the first missing file, so fail early.
        if len(self.image_list) != len(self.mask_list):
            raise ValueError(
                f"Found {len(self.image_list)} images in {image_dir!r} but "
                f"{len(self.mask_list)} masks in {mask_dir!r}; "
                "every image needs exactly one mask."
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names in *directory*, skipping hidden entries."""
        return sorted(name for name in os.listdir(directory) if not name.startswith("."))

    @staticmethod
    def _binarize(mask):
        """Convert a mask array to int64 labels: 0 = background, 1 = plant (value > 0)."""
        return torch.from_numpy((mask > 0).astype(np.int64))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(pixel_values, labels)`` for the *idx*-th image/mask pair.

        ``labels`` is a 2-D int64 tensor of shape ``(height, width)`` taken
        from ``self.size``, with values in {0, 1}.
        """
        image_path = os.path.join(self.image_dir, self.image_list[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_list[idx])

        # Context managers close the underlying file handles promptly.
        with Image.open(image_path) as img:
            image = img.convert("RGB").resize(self.size)
        with Image.open(mask_path) as mask_img:
            # Multi-channel masks would give (H, W, C) labels that the loss
            # cannot consume; collapse them to a single channel first.
            if mask_img.mode in ("RGB", "RGBA", "LA"):
                mask_img = mask_img.convert("L")
            # NEAREST keeps label values discrete (no interpolated in-between values).
            mask = np.array(mask_img.resize(self.size, resample=Image.NEAREST))

        labels = self._binarize(mask)

        # NOTE(review): SegformerImageProcessor resizes by default (do_resize),
        # so pixel_values may not match `size`. Confirm whether
        # do_resize=False is intended here.
        encoded = self.feature_extractor(images=image, return_tensors="pt")
        # Drop only the leading batch dimension added by return_tensors="pt".
        pixel_values = encoded['pixel_values'].squeeze(0)

        return pixel_values, labels

# === 6. Load Data ===
# Hub checkpoint used for both the image processor and the model weights.
MODEL_NAME = "nvidia/segformer-b5-finetuned-ade-640-640"
SAVE_DIR = "/content/checkpoints/segformer_b5"

# Assumes the extracted archive has train/ and val/ splits, each with
# images/ and masks/ sub-folders — verify against the ZIP layout.
root = "/content/final_dataset"

# Load the processor once and reuse it for both splits and for the saved
# checkpoint, so train-time and inference-time preprocessing cannot drift.
processor = SegformerImageProcessor.from_pretrained(MODEL_NAME)
train_dataset = PlantDataset(f"{root}/train/images", f"{root}/train/masks", processor)
val_dataset = PlantDataset(f"{root}/val/images", f"{root}/val/masks", processor)

# Fail early with a clear message instead of a ZeroDivisionError after the
# first (empty) epoch.
if len(train_dataset) == 0:
    raise ValueError(f"No training images found under {root}/train/images")

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1)

# === 7. Load Model ===
# Replace the pretrained classification head with a 2-class one.
# ignore_mismatched_sizes=True lets the checkpoint load despite the head's
# different shape; the new head weights are freshly initialized.
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    id2label={0: "background", 1: "plant"},
    label2id={"background": 0, "plant": 1},
    ignore_mismatched_sizes=True,
).to(device)

# === 8. Training Setup ===
# NOTE: `criterion` is not used by the loop below. The model returns its own
# loss in `outputs.loss` when `labels` is passed. It is kept in case other
# cells reference it.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
EPOCHS = 10

# === 9. Train ===
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for images, masks in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{EPOCHS}"):
        images, masks = images.to(device), masks.to(device)

        outputs = model(pixel_values=images, labels=masks)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"✅ Epoch {epoch+1} | Training Loss: {total_loss / len(train_loader):.4f}")

    # Validation. It is skipped when the val split is empty; dividing by
    # zero there would crash the run before the model is ever saved.
    if len(val_loader) == 0:
        print("⚠️ Validation set is empty; skipping validation.")
        continue

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(pixel_values=images, labels=masks)
            val_loss += outputs.loss.item()
    print(f"📊 Validation Loss: {val_loss / len(val_loader):.4f}")

# === 10. Save Model ===
# Save the model and the exact processor used in training to the same
# directory, so both can be reloaded together with from_pretrained.
os.makedirs(SAVE_DIR, exist_ok=True)
model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)

# Optional: Save as .pth for PyTorch-only usage
torch.save(model.state_dict(), os.path.join(SAVE_DIR, "model.pth"))

print("✅ Model training complete! Download from /content/checkpoints/segformer_b5/")
