In [None]:
import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm

# ==============================
# 1. CONFIG
# ==============================
# Dataset root holding "train/" and "val/" sub-folders, each in torchvision
# ImageFolder layout (one sub-folder per class). The LEAF_DATA_ROOT environment
# variable overrides the default, so the notebook can run on another machine
# without editing this hard-coded Windows path.
DATA_ROOT = os.environ.get("LEAF_DATA_ROOT", r"D:\viot\leaf_detector_dataset")
BATCH_SIZE = 16
NUM_EPOCHS = 5       # Enough for small dataset
NUM_CLASSES = 2      # Leaf / Not Leaf
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG so the training-set shuffle order and the randomly
# initialised classifier head are the same on every run. GPU kernels may still
# introduce small run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

# ==============================
# 2. DATA
# ==============================
# Resize every image to 224x224 and normalise with the ImageNet channel
# mean/std that the pretrained backbone was trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(os.path.join(DATA_ROOT, "train"), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(DATA_ROOT, "val"), transform=transform)

# ImageFolder derives label indices from each split's own sub-folder names.
# If the splits disagree (e.g. a class folder is missing or misspelled in
# "val"), the same index would mean different classes in train and val, and
# validation accuracy would be meaningless. Fail fast instead.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        f"Class folders differ between splits: "
        f"train={train_dataset.class_to_idx}, val={val_dataset.class_to_idx}"
    )

# The classifier head is sized from NUM_CLASSES. More folders than that would
# produce out-of-range targets for CrossEntropyLoss mid-training; fewer would
# leave unused outputs. Either way the config does not match the data.
if len(train_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class folders, found "
        f"{len(train_dataset.classes)}: {train_dataset.classes}"
    )

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# ==============================
# 3. MODEL
# ==============================
# ImageNet-pretrained MobileNetV3-Small. torchvision >= 0.13 deprecates the
# boolean `pretrained=` flag in favour of the `weights=` enum. IMAGENET1K_V1 is
# the checkpoint that `pretrained=True` maps to, so the starting weights are
# unchanged. Older torchvision has no weights enum, so it keeps the legacy flag.
_weights_enum = getattr(models, "MobileNet_V3_Small_Weights", None)
if _weights_enum is not None:
    model = models.mobilenet_v3_small(weights=_weights_enum.IMAGENET1K_V1)
else:
    model = models.mobilenet_v3_small(pretrained=True)

# Replace classifier[3] (the last Linear layer of the classifier head) with a
# freshly initialised layer that outputs NUM_CLASSES logits. It keeps the
# original layer's input width.
model.classifier[3] = nn.Linear(model.classifier[3].in_features, NUM_CLASSES)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
# Every parameter (pretrained backbone + new head) is optimised; 1e-4 is a
# conservative fine-tuning learning rate.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# ==============================
# 4. TRAIN
# ==============================
for epoch in range(NUM_EPOCHS):
    model.train()
    # Accumulate the *sum* of per-sample losses. Dividing by the number of
    # samples gives the true epoch mean; averaging per-batch means instead
    # would over-weight a smaller final batch.
    total_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)  # mean loss over this batch
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
    print(f"Train Loss: {total_loss/len(train_loader.dataset):.4f}")

    # Validate: eval() puts layers such as dropout/batch-norm into inference
    # mode, and no_grad() skips autograd bookkeeping.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            outputs = model(imgs)
            preds = outputs.argmax(dim=1)  # index of the highest logit per image
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    acc = correct / total
    print(f"Val Accuracy: {acc*100:.2f}%")

# ==============================
# 5. SAVE MODEL
# ==============================
torch.save(model.state_dict(), "leaf_detector.pth")

# The state dict holds only weights; it does not record which output index
# means which class. ImageFolder assigns indices from the sorted training
# sub-folder names, so write that mapping next to the checkpoint. Inference
# code can then interpret the logits without re-scanning the dataset.
with open("leaf_detector_classes.json", "w", encoding="utf-8") as class_file:
    json.dump(train_dataset.class_to_idx, class_file, ensure_ascii=False, indent=2)

print("✅ Leaf vs Not Leaf detector model saved!")


Epoch 1/5: 100%|██████████| 44/44 [00:46<00:00,  1.05s/it]


Train Loss: 0.3213
Val Accuracy: 80.18%


Epoch 2/5: 100%|██████████| 44/44 [00:40<00:00,  1.08it/s]


Train Loss: 0.0863
Val Accuracy: 92.65%


Epoch 3/5: 100%|██████████| 44/44 [00:37<00:00,  1.17it/s]


Train Loss: 0.0395
Val Accuracy: 97.10%


Epoch 4/5: 100%|██████████| 44/44 [00:39<00:00,  1.11it/s]


Train Loss: 0.0315
Val Accuracy: 99.55%


Epoch 5/5: 100%|██████████| 44/44 [00:38<00:00,  1.14it/s]


Train Loss: 0.0256
Val Accuracy: 100.00%
✅ Leaf vs Not Leaf detector model saved!
