In [1]:
import timm
import urllib

from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR


import torchvision
import torchvision.transforms as transforms

from PIL import Image
import os
from einops import rearrange
import matplotlib.pyplot as plt

import random
import shutil
import numpy as np
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.cuda.is_available()

True

In [8]:
vill_doc_dir = "data/village_doc"

In [9]:
# === Edit these for your machine ===
DATASET_ROOT = vill_doc_dir   # folder that contains class subfolders or train/test
DATA_ROOT    = r"vill_doc_samp"  # where to create a small sample split
MAKE_SAMPLE  = True        # True => create a small sample split at DATA_ROOT
MAX_PER_CLASS = 60         # cap per class when building the sample
TRAIN_RATIO   = 0.8        # train/test split when sampling
IMG_SIZE      = 224
BATCH_SIZE    = 64
EPOCHS        = 3
NUM_WORKERS   = 4

# ------------------ env check ------------------
import os, sys, random, shutil
from pathlib import Path
import torch

print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Optional: pin determinism lightly (keeps cudnn fast)
import numpy as np, random
SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

DATASET_ROOT = Path(DATASET_ROOT)
DATA_ROOT    = Path(DATA_ROOT)

# quick sanity
if not DATASET_ROOT.exists():
    raise FileNotFoundError(f"DATASET_ROOT not found: {DATASET_ROOT}")


Torch: 2.6.0+cu124
CUDA available: True
Device: cuda


Build (optional) small sample split with train/ & test/

In [10]:
from pathlib import Path

def has_split(root: Path) -> bool:
    return (root / "train").is_dir() and (root / "test").is_dir()

def guess_class_roots(root: Path):
    # if already split, return train dir classes
    if has_split(root):
        return sorted([p for p in (root/"train").iterdir() if p.is_dir()])
    # otherwise, use immediate subfolders as classes
    return sorted([p for p in root.iterdir() if p.is_dir()])

def build_sample_split(src_root: Path, dst_root: Path, max_per_class=60, train_ratio=0.8, exts={".jpg",".jpeg",".png"}):
    if dst_root.exists():
        print(f"[info] removing existing {dst_root} to rebuild sample...")
        shutil.rmtree(dst_root)
    (dst_root / "train").mkdir(parents=True, exist_ok=True)
    (dst_root / "test").mkdir(parents=True, exist_ok=True)

    if has_split(src_root):
        # sample from existing split to keep class distribution
        for split in ["train", "test"]:
            for cls_dir in sorted((src_root/split).iterdir()):
                if not cls_dir.is_dir(): continue
                imgs = [p for p in cls_dir.rglob("*") if p.suffix.lower() in exts]
                random.shuffle(imgs)
                imgs = imgs[:max_per_class]
                out_dir = dst_root / split / cls_dir.name
                out_dir.mkdir(parents=True, exist_ok=True)
                for p in imgs:
                    shutil.copy2(p, out_dir / p.name)
    else:
        # single-level classes in src_root
        for cls_dir in sorted(src_root.iterdir()):
            if not cls_dir.is_dir(): continue
            imgs = [p for p in cls_dir.rglob("*") if p.suffix.lower() in exts]
            if not imgs: continue
            random.shuffle(imgs)
            imgs = imgs[:max_per_class]
            n_train = max(1, int(len(imgs) * train_ratio))
            train_imgs, test_imgs = imgs[:n_train], imgs[n_train:]

            for split, lst in [("train", train_imgs), ("test", test_imgs)]:
                out_dir = dst_root / split / cls_dir.name
                out_dir.mkdir(parents=True, exist_ok=True)
                for p in lst:
                    shutil.copy2(p, out_dir / p.name)

    print("[done] sample split at:", dst_root)




In [11]:
# choose which root we'll actually train on
if MAKE_SAMPLE:
    build_sample_split(DATASET_ROOT, DATA_ROOT, max_per_class=MAX_PER_CLASS, train_ratio=TRAIN_RATIO)
    TRAIN_ROOT = DATA_ROOT
else:
    TRAIN_ROOT = DATASET_ROOT
TRAIN_ROOT = DATASET_ROOT
print("Using data root:", TRAIN_ROOT)
print("Has split (train/test):", has_split(TRAIN_ROOT))

[done] sample split at: vill_doc_samp
Using data root: data/village_doc
Has split (train/test): True


In [13]:
class_names = [
    "Apple___healthy",
    "Apple___Cedar_apple_rust",
    "Apple___Apple_scab",
    "Pepper,_bell___healthy",
    "Pepper,_bell___Bacterial_spot",
    "Blueberry___healthy",
    "Cherry_(including_sour)___healthy",
    "Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot",
    "Corn_(maize)___Northern_Leaf_Blight",
    "Corn_(maize)___Common_rust_",
    "Grape___healthy",
    "Grape___Black_rot",
    "Peach___healthy",
    "Potato___Early_blight",
    "Potato___Late_blight",
    "Raspberry___healthy",
    "Soybean___healthy",
    "Squash___Powdery_mildew",
    "Strawberry___healthy",
    "Tomato___Early_blight",
    "Tomato___healthy",
    "Tomato___Bacterial_spot",
    "Tomato___Late_blight",
    "Tomato___Tomato_mosaic_virus",
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus",
    "Tomato___Leaf_Mold",
    "Tomato___Septoria_leaf_spot"]

In [14]:
import torch.nn as nn
from einops import rearrange

def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, kernel_size//2, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim); self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim), nn.Dropout(dropout)
        )
    def forward(self, x): return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32, dropout=0.):
        super().__init__()
        inner = heads * dim_head
        self.scale = dim_head ** -0.5; self.heads = heads
        self.to_qkv = nn.Linear(dim, inner*3, bias=False)
        self.attend = nn.Softmax(dim=-1)
        self.proj   = nn.Sequential(nn.Linear(inner, dim), nn.Dropout(dropout))
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
        dots = (q @ k.transpose(-1,-2)) * self.scale
        attn = self.attend(dots)
        out  = attn @ v
        out  = rearrange(out, "b h n d -> b n (h d)")
        return self.proj(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                           PreNorm(dim, FeedForward(dim, mlp_dim, dropout))])
            for _ in range(depth)
        ])
    def forward(self, x):
        for attn, ff in self.layers:
            x = x + attn(x); x = x + ff(x)
        return x

class MV2Block(nn.Module):
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        hidden = inp * expansion
        self.use_res = (stride==1 and inp==oup)
        self.conv = nn.Sequential(
            nn.Conv2d(inp, hidden, 1, bias=False), nn.BatchNorm2d(hidden), nn.SiLU(),
            nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden, bias=False), nn.BatchNorm2d(hidden), nn.SiLU(),
            nn.Conv2d(hidden, oup, 1, bias=False), nn.BatchNorm2d(oup)
        )
    def forward(self, x):
        out = self.conv(x); return x + out if self.use_res else out

class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, mlp_dim, dropout=0.):
        super().__init__()
        self.conv_local = conv_nxn_bn(channel, channel, kernel_size)
        self.conv_proj  = conv_1x1_bn(channel, dim)
        self.transformer = Transformer(dim, depth, heads=4, dim_head=max(16, dim//4), mlp_dim=mlp_dim, dropout=dropout)
        self.conv_expand = conv_1x1_bn(dim, channel)
        self.conv_fuse   = conv_nxn_bn(channel*2, channel, kernel_size)
    def forward(self, x):
        res = x
        x = self.conv_local(x)
        x = self.conv_proj(x)              # [B, D, H, W]
        B, D, H, W = x.shape
        tokens = rearrange(x, "b d h w -> b (h w) d")
        tokens = self.transformer(tokens)
        x = rearrange(tokens, "b (h w) d -> b d h w", h=H, w=W)
        x = self.conv_expand(x)
        x = torch.cat([x, res], dim=1)
        return self.conv_fuse(x)

class MobileViT(nn.Module):
    def __init__(self, image_size, dims, channels, num_classes,
                 expansion=4, kernel_size=3, dropout=0.0):
        super().__init__()
        self.stem = conv_nxn_bn(3, channels[0], stride=2)
        self.mv2 = nn.ModuleList([
            MV2Block(channels[0], channels[1], 1, expansion),  # 16->16
            MV2Block(channels[1], channels[2], 2, expansion),  # 16->24 (s=2)
            MV2Block(channels[2], channels[3], 1, expansion),  # 24->24
            MV2Block(channels[3], channels[3], 1, expansion),  # 24->24
            MV2Block(channels[3], channels[4], 2, expansion),  # 24->48 (s=2)
            MV2Block(channels[4], channels[5], 1, expansion),  # 48->48
            MV2Block(channels[5], channels[6], 2, expansion),  # 48->64 (s=2)
        ])
        self.mvit = nn.ModuleList([
            MobileViTBlock(dims[0], depth=2, channel=channels[2], kernel_size=kernel_size, mlp_dim=dims[0]*2, dropout=dropout),
            MobileViTBlock(dims[1], depth=4, channel=channels[4], kernel_size=kernel_size, mlp_dim=dims[1]*4, dropout=dropout),
            MobileViTBlock(dims[2], depth=3, channel=channels[6], kernel_size=kernel_size, mlp_dim=dims[2]*4, dropout=dropout),
        ])
        self.fuse = conv_1x1_bn(channels[6], channels[7])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc   = nn.Linear(channels[7], num_classes)
        self._init_weights()
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
    def forward(self, x):
        x = self.stem(x)
        x = self.mv2[0](x)
        x = self.mv2[1](x); x = self.mv2[2](x); x = self.mvit[0](x)  # 24-ch
        x = self.mv2[3](x); x = self.mv2[4](x); x = self.mv2[5](x); x = self.mvit[1](x)  # 48-ch
        x = self.mv2[6](x); x = self.mvit[2](x)  # 64-ch
        x = self.fuse(x)
        x = self.pool(x).flatten(1)
        return self.fc(x)

class MobileViTClassifier(nn.Module):
    def __init__(self, image_size=(224,224), num_classes=38, expansion=4, dropout=0.0):
        super().__init__()
        dims = [64, 80, 96]
        channels = [16, 16, 24, 24, 48, 48, 64, 320]
        self.model = MobileViT(image_size, dims, channels, num_classes, expansion=expansion, dropout=dropout)
    def forward(self, x): return self.model(x)

# sanity
model = MobileViTClassifier(image_size=(IMG_SIZE, IMG_SIZE), num_classes=len(class_names)).to(device)
print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)


Params (M): 1.026123


Cell C — Tiny MobileViT with patch tokens (8×8) + checkpointing option

In [15]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from pathlib import Path

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

def build_transforms(img_size=160, aug_level="light"):
    norm = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    if aug_level == "light":
        train_tf = transforms.Compose([
            transforms.Resize(int(img_size*1.1)),
            transforms.RandomResizedCrop(img_size, scale=(0.9,1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), norm,
        ])
    else:
        train_tf = transforms.Compose([
            transforms.Resize(int(img_size*1.15)),
            transforms.RandomResizedCrop(img_size, scale=(0.8,1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), norm,
        ])
    val_tf = transforms.Compose([
        transforms.Resize(int(img_size*1.1)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(), norm,
    ])
    return train_tf, val_tf

def build_loaders(root: Path, img_size=160, batch_size=8, workers=0):
    train_tf, val_tf = build_transforms(img_size, "light")
    train_ds = datasets.ImageFolder(root/"train", transform=train_tf)
    val_ds   = datasets.ImageFolder(root/"test",  transform=val_tf)
    classes  = train_ds.classes
    pin_mem  = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=workers, pin_memory=pin_mem)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                              num_workers=workers, pin_memory=pin_mem)
    return train_loader, val_loader, classes



In [16]:
import os, torch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()
torch.set_float32_matmul_precision("medium")

# ↓ shrink these to fit your GPU
IMG_SIZE    = 160      # 160 < 192 < 224
BATCH_SIZE  = 8        # try 4 or 2 if still OOM
NUM_WORKERS = 0        # avoid RAM spikes from dataloader workers
EPOCHS      = 3

train_loader, val_loader, class_names = build_loaders(TRAIN_ROOT, IMG_SIZE, BATCH_SIZE, NUM_WORKERS)
len(class_names), class_names[:5]


(38,
 ['Apple___Apple_scab',
  'Apple___Black_rot',
  'Apple___Cedar_apple_rust',
  'Apple___healthy',
  'Blueberry___healthy'])

In [19]:
import time, math
from tqdm import tqdm

class Trainer:
    def __init__(self, model, optimizer, criterion, scheduler, device, out_dir="./outputs"):
        from pathlib import Path
        self.model, self.opt, self.crit, self.sched = model, optimizer, criterion, scheduler
        self.device = device
        self.out_dir = Path(out_dir); (self.out_dir/"checkpoints").mkdir(parents=True, exist_ok=True)
        self.scaler = torch.amp.GradScaler('cuda', enabled=(device.type == "cuda"))

    def _epoch(self, loader, train=True):
        self.model.train(mode=train)
        tot_loss, correct, total = 0.0, 0, 0
        it = tqdm(loader, leave=False)
        for x, y in it:
            # channels_last on inputs too
            x = x.to(self.device, non_blocking=True).to(memory_format=torch.channels_last)
            y = y.to(self.device, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=(self.device.type=="cuda")):
                logits = self.model(x)
                loss = self.crit(logits, y)

            if train:
                self.opt.zero_grad(set_to_none=True)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.opt)
                self.scaler.update()

            bs = x.size(0)
            tot_loss += loss.item() * bs
            correct  += (logits.argmax(1) == y).sum().item()
            total    += bs
            it.set_postfix(loss=f"{loss.item():.3f}", acc=f"{correct/max(1,total):.3f}")
        return tot_loss/max(1,total), correct/max(1,total)

    def fit(self, train_loader, val_loader, epochs=3):
        best = math.inf
        for e in range(1, epochs+1):
            t0 = time.time()
            tr_loss, tr_acc = self._epoch(train_loader, True)
            with torch.no_grad():
                va_loss, va_acc = self._epoch(val_loader, False)
            if self.sched is not None:
                try: self.sched.step(va_loss)
                except: self.sched.step()
            if va_loss < best:
                best = va_loss
                torch.save({"model": self.model.state_dict()}, self.out_dir/"checkpoints"/"best.pth")
            print(f"Epoch {e:02d} | train {tr_loss:.4f}/{tr_acc:.3f} | val {va_loss:.4f}/{va_acc:.3f} | {time.time()-t0:.1f}s")



In [20]:

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

trainer = Trainer(model, optimizer, criterion, scheduler, device, out_dir="./outputs")
trainer.fit(train_loader, val_loader, epochs=EPOCHS)

  return fn(*args, **kwargs)
                                                                          

KeyboardInterrupt: 

In [None]:
from sklearn.metrics import classification_report

# load best
ckpt = torch.load("./outputs/checkpoints/best.pth", map_location=device)
model.load_state_dict(ckpt.get("model", ckpt))
model.eval()

y_true, y_pred = [], []
with torch.no_grad():
    for x, y in val_loader:
        x = x.to(device)
        logits = model(x)
        y_pred.extend(logits.argmax(1).cpu().tolist())
        y_true.extend(y.tolist())

acc = (torch.tensor(y_true) == torch.tensor(y_pred)).float().mean().item()
print("Val Accuracy:", f"{acc:.4f}")
print(classification_report(y_true, y_pred, target_names=class_names, digits=4))
