In [2]:
import os, sys, math, time, random
from pathlib import Path
import numpy as np
import pandas as pd
import torch
os.environ["CUDA_VISIBLE_DEVICES"] = "4" 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, ConcatDataset, Subset, random_split
import torchvision.transforms as T
import torchvision.datasets as datasets
import timm
from tqdm import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
SEED = 42
if DEVICE == "cuda":
    print("GPU name:", torch.cuda.get_device_name(0))
    print("Total GPU mem (GB):", torch.cuda.get_device_properties(0).total_memory / (1024**3))

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = True

print("timm version:", timm.__version__)
print("PyTorch:", torch.__version__)

Device: cuda
GPU name: Tesla V100-SXM2-32GB
Total GPU mem (GB): 31.7325439453125
timm version: 1.0.21
PyTorch: 2.6.0+cu124


In [None]:
# ---------- Config (edit before running) ----------
DATA_DIR = Path("/home/23ucc611/Mini/data/BananaLSD")       # where dataset will live after download/unzip
USE_AUGMENTED = False                # we will use OriginalSetSetSetSet images and augment on-the-fly
MODEL_NAME = "resnetv2_50"           # default choice: ResNetV2-50 (good balance & robust)
IMG_SIZE = 224                       # image size for pretrained networks
BATCH_SIZE = 64                      # conservative default; raise if GPU allows
NUM_WORKERS = 8
PIN_MEMORY = True
SEED = 42

# where to save training artifacts
CHECKPOINT_DIR = Path("./checkpoints_bananalsd")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

In [5]:
from collections import Counter

for p in sorted(DATA_DIR.iterdir()):
    print(p)

# find image folders (look for 'OriginalSetSetSet' or a top-level folder with class subfolders)
def find_image_root(base):
    # try some heuristics
    if (base/"OriginalSet").exists():
        return base/"OriginalSet"
    # otherwise, if there's a folder with class subfolders (images inside), return it
    for child in base.iterdir():
        if child.is_dir():
            # check if this child contains further directories that look like classes
            subdirs = [d for d in child.iterdir() if d.is_dir()]
            if len(subdirs) >= 2:
                return child
    # fallback: base itself
    return base

img_root = find_image_root(DATA_DIR)
print("Using image root:", img_root)

# count images per class if structure is like: img_root/class_name/*.jpg
class_counts = {}
for cls in sorted([d for d in img_root.iterdir() if d.is_dir()]):
    cnt = len(list(cls.rglob("*.*")))  # images under that folder
    class_counts[cls.name] = cnt
    print(f"Class {cls.name}: {cnt} images")

total_images = sum(class_counts.values())
print("Total images found under image root:", total_images)

/home/23ucc611/Mini/data/BananaLSD/AugmentedSet
/home/23ucc611/Mini/data/BananaLSD/OriginalSet
Using image root: /home/23ucc611/Mini/data/BananaLSD/OriginalSet
Class cordana: 162 images
Class healthy: 129 images
Class pestalotiopsis: 173 images
Class sigatoka: 473 images
Total images found under image root: 937


In [6]:
from torchvision.datasets import ImageFolder

infer_transform=T.Compose([
    T.Resize((IMG_SIZE,IMG_SIZE)),
    T.ToTensor(),
    T.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
])

train_transform=T.Compose([
    T.RandomResizedCrop(IMG_SIZE,scale=(0.8,1.0)),
    T.RandomHorizontalFlip(),
    T.RandAugment(num_ops=2, magnitude=9),
    T.ToTensor(),
    T.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
])

dataset_full=ImageFolder(root=str(img_root),transform=train_transform)
print("Detected classess : ",dataset_full.classes)
print("Total images (dataset_full) : ",len(dataset_full))


val_ratio=0.10
val_size=int(math.ceil(len(dataset_full)*val_ratio))
train_size=len(dataset_full)-val_size
torch.manual_seed(SEED)
train_ds,val_ds = random_split(dataset_full,[train_size,val_size])
val_ds.dataset=ImageFolder(root=str(img_root),transform=infer_transform)

print("Train Images : ",len(train_ds))
print("Val images : ",len(val_ds))


train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

Detected classess :  ['cordana', 'healthy', 'pestalotiopsis', 'sigatoka']
Total images (dataset_full) :  937
Train Images :  843
Val images :  94


In [7]:
print(img_root)

/home/23ucc611/Mini/data/BananaLSD/OriginalSet
