In [15]:
import albumentations as A

In [16]:
!pip install albumentations



In [17]:
import os
import shutil
import time
import pprint

import torch
import torch.nn as nn
import torch.autograd.variable as Variable

from glob import glob
from math import sqrt
from numpy.random import seed
from numpy.random import randn
from numpy import mean
from scipy.stats import sem
from scipy.stats import t
import numpy as np
from collections import OrderedDict
from tqdm import tqdm
import torch.optim as optim

class GaussianNoise(nn.Module):
    def __init__(self, batch_size, input_shape=(3, 84, 84), std=0.05):
        super(GaussianNoise, self).__init__()
        self.shape = (batch_size,) + input_shape
        self.noise = Variable(torch.zeros(self.shape).cuda())
        self.std = std

    def forward(self, x, std=0.15):
        noise = Variable(torch.zeros(x.shape).cuda())
        noise = noise.data.normal_(0, std=std)
        return x + noise


def set_gpu(x):
    os.environ['CUDA_VISIBLE_DEVICES'] = x
    print('using gpu:', x)


def clone(tensor):
    """Detach and clone a tensor including the ``requires_grad`` attribute.

    Arguments:
        tensor (torch.Tensor): tensor to clone.
    """
    cloned = tensor.clone()#tensor.detach().clone()
    # cloned.requires_grad = tensor.requires_grad
    # if tensor.grad is not None:
    #     cloned.grad = clone(tensor.grad)
    return cloned

def clone_state_dict(state_dict):
    """Clone a state_dict. If state_dict is from a ``torch.nn.Module``, use ``keep_vars=True``.

    Arguments:
        state_dict (OrderedDict): the state_dict to clone. Assumes state_dict is not detached from model state.
    """
    return OrderedDict([(name, clone(param)) for name, param in state_dict.items()])

def ensure_path(path):
    if os.path.exists(path):
        if input('{} exists, remove? ([y]/n)'.format(path)) != 'n':
            shutil.rmtree(path)
            os.mkdir(path)
    else:
        os.mkdir(path)

class Averager():
    def __init__(self):
        self.n = 0
        self.v = 0

    def add(self, x):
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v

def count_acc(logits, label):
    pred = torch.argmax(logits, dim=1)
    return (pred == label).type(torch.cuda.FloatTensor).mean().item()

def dot_metric(a, b):
    return torch.mm(a, b.t())

def euclidean_metric(a, b):
    n = a.shape[0]
    m = b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)
    b = b.unsqueeze(0).expand(n, m, -1)
    logits = -((a - b)**2).sum(dim=2)
    return logits

class Timer():

    def __init__(self):
        self.o = time.time()

    def measure(self, p=1):
        x = (time.time() - self.o) / p
        x = int(x)
        if x >= 3600:
            return '{:.1f}h'.format(x / 3600)
        if x >= 60:
            return '{}m'.format(round(x / 60))
        return '{}s'.format(x)

_utils_pp = pprint.PrettyPrinter()
def pprint(x):
    _utils_pp.pprint(x)

def l2_loss(pred, label):
    return ((pred - label)**2).sum() / len(pred) / 2

def set_protocol(data_path, protocol, test_protocol, subset=None):
    train = []
    val = []
    all_set = ['shn', 'hon', 'clv', 'clk', 'gls', 'scl', 'sci', 'nat', 'shx', 'rel']
    if subset is not None:
        train.append(data_path + '/crops_' + subset + '/')
        val.append(data_path + '/crops_' + subset + '/')
    if protocol == 'p1':
        for i in range(3):
            train.append(data_path + '/crops_' + all_set[i])
    elif protocol == 'p2':
        for i in range(3, 6):
            train.append(data_path + '/crops_' + all_set[i])
    elif protocol == 'p3':
        for i in range(6, 8):
            train.append(data_path + '/crops_' + all_set[i])
    elif protocol == 'p4':
        for i in range(8, 10):
            train.append(data_path + '/crops_' + all_set[i])

    if test_protocol == 'p1':
        for i in range(3):
            val.append(data_path + '/crops_' + all_set[i])
    elif test_protocol == 'p2':
        for i in range(3, 6):
            val.append(data_path + '/crops_' + all_set[i])
    elif test_protocol == 'p3':
        for i in range(6, 8):
            val.append(data_path + '/crops_' + all_set[i])
    elif test_protocol == 'p4':
        for i in range(8, 10):
            val.append(data_path + '/crops_' + all_set[i])
    return train, val

def flip(x, dim):
    dim = x.dim() + dim if dim < 0 else dim
    return x[tuple(slice(None, None) if i != dim
             else torch.arange(x.size(i)-1, -1, -1).long()
             for i in range(x.dim()))]

def perturb(data):
    randno = np.random.randint(0, 5)
    if randno == 1:
        return torch.cat((data, data.flip(3)), dim=0)
    elif randno == 2: #180
        return torch.cat((data, data.flip(2)), dim=0)
    elif randno == 3: #90
        return torch.cat((data, data.transpose(2,3)), dim=0)
    else:
        return torch.cat((data, data.transpose(2, 3).flip(3)), dim=0)

# Data Manager

In [18]:
cd '/kaggle/working/'

/kaggle/working


In [19]:
# --- Step 1: Gather All Data ---
print("Scanning all image paths...")
all_paths_raw = glob("/kaggle/input/plantvillage-dataset/color/*/*.JPG")

# Filter out problematic Pepper classes
all_paths = []
for path in all_paths_raw:
    if 'Pepper,_bell___healthy' in path or 'Pepper,_bell___Bacterial_spot' in path:
        continue  # Skip these classes
    label = path.split('/')[-2]  # Extract class name from path
    all_paths.append([path, label])

print(f"Found {len(all_paths)} total images after filtering out Pepper classes.")

Scanning all image paths...
Found 50330 total images after filtering out Pepper classes.


In [20]:
import os.path as osp
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2


leaf_train_augmentation = A.Compose([
    # spatial
    A.RandomResizedCrop(size=(84, 84), scale=(0.7, 1.0), ratio=(0.75, 1.33), p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.Rotate(limit=180, p=0.7, border_mode=0),
    A.Perspective(scale=(0.05, 0.1), p=0.4),

    # color / lighting
    A.OneOf([
        # controlled color jitter
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02, p=1.0),
        # HSV shifts: explicit shift limits (degrees for hue, percent for saturation/value)
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=1.0),
        # random brightness/contrast
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1.0),
    ], p=0.9),

    # Apply either gamma, CLAHE or tone curve (kept), but not all every time
    A.OneOf([
        A.RandomGamma(gamma_limit=(80,120), p=1.0),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8,8), p=1.0),
        A.RandomToneCurve(p=1.0),
    ], p=0.6),

    # final normalization + tensor conversion
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

leaf_smart_augmentation = A.Compose([
    # Conservative spatial transforms
    A.Resize(92, 92),  # Slightly larger then crop
    A.RandomCrop(84, 84, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=0.3),  # Small rotations
    
    # Color variations that mimic real-world conditions
    A.OneOf([
        # Lighting changes
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
        # Color temperature variations
        A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=1.0),
    ], p=0.6),
    
    # Mild noise for sensor variations
    A.GaussNoise(var_limit=(5.0, 15.0), p=0.1),
    
    # Focus on preserving texture details
    A.Sharpen(alpha=(0.05, 0.1), lightness=(0.8, 1.0), p=0.1),  # Very mild sharpening
    
    # Normalize
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

leaf_val_augmentation = A.Compose([
    A.Resize(96, 96),  # Slightly larger
    A.CenterCrop(84, 84),  # Clean center crop
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

leaf_finetune_augmentation = A.Compose([
    A.RandomResizedCrop(size=(84, 84), scale=(0.85, 1.0), ratio=(0.9, 1.1), p=1.0),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=45, p=0.5, border_mode=0),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.7),
    A.GaussianBlur(blur_limit=(3, 5), p=0.15),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])


# ========================================
# UPDATED DATASET CLASS
# ========================================
class UiSmell(Dataset):
    def __init__(self, setname, img_path, is_aug=False):
        csv_path = osp.join('/kaggle/working/materials/', setname + '.csv')
        lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]
        self.is_aug = is_aug
        self.img_path = img_path

        data, label = [], []
        label_map, label_counter = {}, 0

        for line in lines:
            name, lbl = line.split(',', 1)
            lbl_clean = lbl.strip()
            if lbl_clean not in label_map:
                label_map[lbl_clean] = label_counter
                label_counter += 1
            path = osp.join(img_path, name)
            data.append(path)
            label.append(label_map[lbl_clean])

        self.data = data
        self.label = label
        self.label_map = label_map

        if is_aug:
            self.transform = leaf_train_augmentation
            print(f"✅ Using STRONG leaf augmentation for {setname}")
        else:
            self.transform = leaf_val_augmentation
            print(f"✅ Using VALIDATION augmentation for {setname}")

        print(f"Loaded {len(self.data)} samples with {len(self.label_map)} classes for {setname}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path = self.data[index]
        label = self.label[index]

        # PIL -> NumPy (H,W,C)
        image = np.array(Image.open(path).convert('RGB'))

        # Apply albumentations
        augmented = self.transform(image=image)
        image = augmented['image']  # torch.Tensor from ToTensorV2

        # Ensure float32 (should already be)
        if not torch.is_floating_point(image):
            image = image.float()

        return image, label

  A.GaussNoise(var_limit=(5.0, 15.0), p=0.1),


In [21]:
# import pandas as pd
# df = pd.read_csv("/kaggle/working/materials/train.csv")
# df.label.value_counts()

# Sampler

In [22]:
import torch
import numpy as np


class CategoriesSampler():

    def __init__(self, label, n_batch, n_cls, n_per):
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per

        label = np.array(label)
        self.m_ind = []
        for i in range(max(label) + 1):
            ind = np.argwhere(label == i).reshape(-1)
            ind = torch.from_numpy(ind)
            if len(ind) > 4:
                self.m_ind.append(ind)

    def __len__(self):
        return self.n_batch
    
    def __iter__(self):
        for i_batch in range(self.n_batch):
            batch = []
            classes = torch.randperm(len(self.m_ind))[:self.n_cls]
            for c in classes:
                l = self.m_ind[c]
                pos = torch.randperm(len(l))[:self.n_per]
                batch.append(l[pos])
            batch = torch.stack(batch).t().reshape(-1)
            #for i in range(1000):
            yield batch

# Convnet

In [23]:
!pip install torchsummary



In [24]:
import torch
import numpy as np


class CategoriesSampler():

    def __init__(self, label, n_batch, n_cls, n_per):
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per

        label = np.array(label)
        self.m_ind = []
        for i in range(max(label) + 1):
            ind = np.argwhere(label == i).reshape(-1)
            ind = torch.from_numpy(ind)
            if len(ind) > 4:
                self.m_ind.append(ind)

    def __len__(self):
        return self.n_batch
    
    def __iter__(self):
        for i_batch in range(self.n_batch):
            batch = []
            classes = torch.randperm(len(self.m_ind))[:self.n_cls]
            for c in classes:
                l = self.m_ind[c]
                pos = torch.randperm(len(l))[:self.n_per]
                batch.append(l[pos])
            batch = torch.stack(batch).t().reshape(-1)
            #for i in range(1000):
            yield batch

In [25]:
class Bottleneck(nn.Module):
    def __init__(self, in_ch, mid_ch, stride=1, alpha=0.1):
        super().__init__()
        out_ch = mid_ch * 4
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False)
        self.bn1   = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, mid_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(mid_ch)
        self.conv3 = nn.Conv2d(mid_ch, out_ch, kernel_size=1, bias=False)
        self.bn3   = nn.BatchNorm2d(out_ch)
        self.act   = nn.LeakyReLU(alpha, inplace=True)
        self.short = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.short = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch)
            )
    def forward(self, x):
        identity = self.short(x) if len(self.short) else x
        out = self.act(self.bn1(self.conv1(x)))
        out = self.act(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = self.act(out + identity)
        return out

class ResNet50Plus(nn.Module):
    def __init__(self, num_classes, alpha=0.1, dropout_p=0.5):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(alpha, inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1)
        )
        self.layer1 = self._make_layer(64,   64, blocks=3, stride=1, alpha=alpha)
        self.layer2 = self._make_layer(256, 128, blocks=4, stride=2, alpha=alpha)
        self.layer3 = self._make_layer(512, 256, blocks=6, stride=2, alpha=alpha)
        self.layer4 = self._make_layer(1024,512, blocks=3, stride=2, alpha=alpha)
        self.avg    = nn.AdaptiveAvgPool2d((1,1))
        self.head   = nn.Sequential(nn.Flatten(), nn.Dropout(dropout_p), nn.Linear(2048, num_classes))

    def _make_layer(self, in_ch, mid_ch, blocks, stride, alpha):
        layers = [Bottleneck(in_ch, mid_ch, stride=stride, alpha=alpha)]
        for _ in range(1, blocks):
            layers.append(Bottleneck(mid_ch*4, mid_ch, stride=1, alpha=alpha))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
        x = self.avg(x); x = self.head(x)
        return x

# -------------------
# MixUp helpers
# -------------------
def mixup_batch(x, y, alpha=0.2):
    if alpha is None or alpha <= 0:
        return x, (y, y, 1.0)
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0), device=x.device)
    x_mix = lam * x + (1 - lam) * x[idx]
    return x_mix, (y, y[idx], lam)

def mixup_criterion(ce_loss, preds, targets):
    y_a, y_b, lam = targets
    return lam * ce_loss(preds, y_a) + (1 - lam) * ce_loss(preds, y_b)

# Classifier

In [26]:
import torch
import torch.nn as nn

class Subspace_Projection(nn.Module):
    def __init__(self, num_dim=2, debug=False, eps=1e-8):
        super().__init__()
        self.num_dim = num_dim
        self.debug = debug
        self.eps = eps  # Better numerical stability

    def create_subspace(self, supportset_features, class_size, sample_size):
        # Add validation
        if sample_size < self.num_dim + 1:
            self.num_dim = sample_size - 1
            print(f"Warning: Reduced subspace dim to {self.num_dim}")
        all_hyper_planes = []
        means = []
        for ii in range(class_size):
            all_support = supportset_features[ii]
            mean_vec = torch.mean(all_support, dim=0)
            means.append(mean_vec)
            centered = all_support - mean_vec.unsqueeze(0)
            uu, s, v = torch.svd(centered.transpose(0, 1).double(), some=False)
            uu = uu.float()
            all_hyper_planes.append(uu[:, :self.num_dim])  # limit dimension!

        all_hyper_planes = torch.stack(all_hyper_planes, dim=0)
        means = torch.stack(means, dim=0)
        return all_hyper_planes, means

    def projection_metric(self, target_features, hyperplanes, mu):
        eps = 1e-12
        device = target_features.device
        batch_size = target_features.shape[0]
        class_size = hyperplanes.shape[0]

        similarities = []
        discriminative_loss = torch.tensor(0.0, device=device)

        for j in range(class_size):
            h_plane_j = hyperplanes[j].unsqueeze(0).repeat(batch_size, 1, 1).to(device)
            tf_centered = (target_features - mu[j].expand_as(target_features)).unsqueeze(-1)
            proj = torch.bmm(h_plane_j, torch.bmm(h_plane_j.transpose(1, 2), tf_centered))
            proj = torch.squeeze(proj, -1) + mu[j].unsqueeze(0).repeat(batch_size, 1)

            diff = target_features - proj
            query_loss = -torch.sqrt(torch.sum(diff * diff, dim=-1) + eps)
            similarities.append(query_loss)

            # discriminative term (reduced)
            for k in range(class_size):
                if j != k:
                    temp = torch.mm(hyperplanes[j].T, hyperplanes[k])
                    discriminative_loss += torch.sum(temp * temp)

        similarities = torch.stack(similarities, dim=1).to(device)
        class_size = hyperplanes.shape[0]
        discriminative_loss = discriminative_loss / (class_size * (class_size - 1) + 1e-6)
        similarities = similarities / similarities.std(dim=1, keepdim=True).clamp_min(1e-6)

        # ---- DEBUG ----
        if self.debug:
            print("[DEBUG] projection_metric DEBUG:")
            print(f"  target_features: mean={target_features.mean():.6f}, std={target_features.std():.6f}")
            print(f"  hyperplanes: mean={hyperplanes.mean():.6f}, std={hyperplanes.std():.6f}")
            print(f"  similarities: mean={similarities.mean():.6f}, std={similarities.std():.6f}, min={similarities.min():.6f}, max={similarities.max():.6f}")
            print(f"  discriminative_loss: {discriminative_loss.item():.6f}")

        return similarities, discriminative_loss

# Training

In [None]:
import os
import torch
import argparse
import numpy as np
import os.path as osp
from datetime import datetime
import torch.nn.functional as F
from torch.utils.data import DataLoader

# SET MODEL NAME
modelname = 'ResNet50Plus'

args = {}
args['num_sampler'] = 500
args['max-epoch'] = 25
args['save-epoch'] = 5
args['shot'] = 5
args['query'] = 5
args['train-way'] = 5
args['test-way'] = 5
args['data-path'] = ''
args['gpu'] = '0'
args['lamb'] = 0.5
args['lr'] = 1e-4
args['weight_decay'] = 1e-4
args['subspace-dim'] = args['shot']-1
set_gpu(args['gpu'])

txt = str(datetime.now())
txt = '_'.join([modelname+f"-{args['shot']}-{args['lamb']}-{args['lr']}-{args['num_sampler']}", txt[:4], txt[5:7], txt[8:10], txt[11:13], txt[14:16]])

args['save-path'] = '/kaggle/working/save/'+txt

# MODEL BUILDER
# Fixed: Added num_classes parameter (512 for feature extraction in few-shot learning)
model = {
    'ResNet50Plus': ResNet50Plus(num_classes=512),
}[modelname].cuda()

In [None]:
import os
import csv
from glob import glob
import random
from collections import Counter
import re

In [None]:
# --- Step 1: Gather All Data ---
print("Scanning all image paths...")
all_paths_raw = glob("/kaggle/input/plantvillage-dataset/color/*/*.JPG")

print(f"Found {len(all_paths)} total images and corrected labels in memory.")

# --- Step 2: Identify All Unique Classes ---
all_labels = [label for path, label in all_paths]
unique_classes = sorted(list(set(all_labels)))
random.seed(42) # for reproducibility
random.shuffle(unique_classes)
num_classes = len(unique_classes)
print(f"Found {num_classes} unique classes.")

# --- Step 3: Split the *Classes* into Disjoint Sets ---
train_split = int(0.7 * num_classes)
val_split = int(0.15 * num_classes)

train_classes = unique_classes[:train_split]
val_classes = unique_classes[train_split : train_split + val_split]
test_classes = unique_classes[train_split + val_split:]

print(f"Splitting classes into: {len(train_classes)} train, {len(val_classes)} validation, {len(test_classes)} test.")

# Sanity check: ensure the class sets are disjoint (no overlap)
assert len(set(train_classes) & set(val_classes)) == 0
assert len(set(train_classes) & set(test_classes)) == 0
assert len(set(val_classes) & set(test_classes)) == 0
print("Class splits are successfully disjoint.")


# --- Step 4: Create Final Data Lists ---
paths_train, paths_val, paths_test = [], [], []

for path, label in all_paths:
    if label in train_classes:
        paths_train.append([path, label])
    elif label in val_classes:
        paths_val.append([path, label])
    else:
        paths_test.append([path, label])

print(f"Final data splits (number of images): Train={len(paths_train)}, Val={len(paths_val)}, Test={len(paths_test)}")


# --- Step 5: Write the new CSV files ---
data = {'train': paths_train,
        'test': paths_test,
        'val': paths_val,
}

os.makedirs("/kaggle/working/materials", exist_ok=True)
csv_files = ["train.csv", "test.csv", "val.csv"]

for fname in csv_files:
    path = os.path.join("/kaggle/working/materials", fname)
    with open(path, mode="w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["filename", "label"])
        writer.writerows(data[fname[:-4]])

print("\nCSV files with diverse and corrected data splits created successfully inside 'materials' folder.")

In [None]:
# First, create the datasets
trainset = UiSmell('train', '/kaggle/input/plantvillage-dataset/color/', is_aug=True)
valset = UiSmell('val', '/kaggle/input/plantvillage-dataset/color/', is_aug=False)
testset = UiSmell('test', '/kaggle/input/plantvillage-dataset/color/', is_aug=False)

# Now you can print the classes
print("=== All Classes and Their Labels ===")
for class_name, label_id in trainset.label_map.items():
    print(f"Label {label_id}: {class_name}")

print("\n=== Classes in Order ===")
sorted_classes = sorted(trainset.label_map.items(), key=lambda x: x[1])
for class_name, label_id in sorted_classes:
    print(f"{label_id}: {class_name}")

In [None]:
# DATA LOADER
trainset = UiSmell('train', args['data-path'], is_aug=True)
train_sampler = CategoriesSampler(trainset.label, args['num_sampler'],
                                  args['train-way'], args['shot'] + args['query'])
train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler, shuffle=False,)

valset = UiSmell('val', args['data-path'], is_aug=False)
val_sampler = CategoriesSampler(valset.label, args['num_sampler'],
                                args['test-way'], args['shot'] + args['query'])
val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, shuffle=False,)

testset = UiSmell('test', args['data-path'], is_aug=False)
test_sampler = CategoriesSampler(testset.label, args['num_sampler'],
                                args['test-way'], args['shot'] + args['query'])
test_loader = DataLoader(dataset=testset, batch_sampler=test_sampler, shuffle=False,)

In [None]:
label_map = trainset.label_map
sorted_items = sorted(label_map.items(), key=lambda item: item[1])

# Extract the class names in the now correct order
class_names_ordered = [item[0] for item in sorted_items]

# Define the output path for the class names file
output_path = '/kaggle/working/materials/classes.txt'

# Write the ordered class names to the file, one per line
with open(output_path, 'w') as f:
    for name in class_names_ordered:
        f.write(f"{name}\n")

print(f"Successfully created classes.txt at: {output_path}")
print("\\n--- Class Names (in order) ---")
for i, name in enumerate(class_names_ordered):
    print(f"{i}: {name}")

In [None]:
# run in your notebook
print("Using transform (is_aug):", trainset.is_aug)
sample_path = trainset.data[0]
import numpy as np
from PIL import Image
img = np.array(Image.open(sample_path).convert('RGB'))
out = trainset.transform(image=img)
print("transform output keys:", out.keys())
print("image type from transform:", type(out['image']), "shape:", getattr(out['image'], 'shape', None))

In [None]:
import matplotlib.pyplot as plt
t = out['image']  # torch.Tensor C,H,W
img_vis = t.permute(1,2,0).cpu().numpy()
mean = np.array([0.485,0.456,0.406])
std  = np.array([0.229,0.224,0.225])
img_vis = np.clip(img_vis * std + mean, 0, 1)
plt.imshow(img_vis); plt.axis('off'); plt.title('Augmented image (what model sees)')

In [None]:
print("albumentations version:", A.__version__)

In [None]:
import os, os.path as osp
import json
import torch
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

# Ensure save directory exists
if not os.path.exists(args['save-path']):
    os.makedirs(args['save-path'])

def save_model(name):
    if not os.path.exists(args['save-path']):
        os.makedirs(args['save-path'])
    torch.save(model.state_dict(), osp.join(args['save-path'], name + '.pth'))

# Optimizer / Scheduler
optimizer = optim.Adam(model.parameters(), lr=args['lr'])
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Projection module
projection_pro = Subspace_Projection(num_dim=args['subspace-dim'], debug=False)

# Training log
trlog = {
    'train_loss': [], 'val_loss': [], 'test_loss': [],
    'train_acc': [], 'val_acc': [], 'test_acc': [],
    'max_acc': 0.0, 'max_epoch': 0
}

timer = Timer()
patience = 3
epochs_no_improve = 0

# MAIN TRAIN LOOP
for epoch in range(1, args['max-epoch'] + 1):
    model.train()
    shot_num = args['shot'] * 2 if args['shot'] == 1 else args['shot']

    tl = Averager()
    ta = Averager()

    for i, batch in tqdm(enumerate(train_loader, 1)):
        data, _ = [_.cuda() for _ in batch]

        if i == 1 and epoch == 1:
            print(f"[DEBUG] Epoch {epoch}, Iter {i}")
            print(f"data shape: {data.shape}")
            print(f"train-way={args['train-way']}, shot={args['shot']}, query={args['query']}")

        p = args['shot'] * args['train-way']
        qq = p + args['query'] * args['train-way']
        data_shot, data_query = data[:p], data[p:qq]

        if args['shot'] == 1:
            data_shot = torch.cat((data_shot, flip(data_shot, 3)), dim=0)

        # Forward through model
        proto = model(data_shot)
        proto = proto.reshape(shot_num, args['train-way'], -1)
        proto = torch.transpose(proto, 0, 1)
        hyperplanes, mu = projection_pro.create_subspace(proto, args['train-way'], shot_num)

        # Labels
        label = torch.arange(args['train-way']).repeat(args['query'])
        label = label.type(torch.cuda.LongTensor)

        # Metric projection
        query_features = model(data_query)
        logits, discriminative_loss = projection_pro.projection_metric(query_features, hyperplanes, mu=mu)

        ce_loss = F.cross_entropy(logits, label)
        loss = ce_loss + args['lamb'] * discriminative_loss
        acc = count_acc(logits, label)

        tl.add(loss.item())
        ta.add(acc)

        # if i % 50 == 0:
        #     print(f"[DEBUG] Epoch {epoch} Iter {i} | loss={loss.item():.4f}, acc={acc:.4f}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    lr_scheduler.step()
    print(f'epoch {epoch}, loss={tl.item():.4f} acc={ta.item():.4f}')

    tl = tl.item()
    ta = ta.item()

    # Skip early validation to save time
    if epoch % 2 != 0 and epoch < 100:
        continue

    # VALIDATION PHASE
    model.eval()
    vl = Averager()
    va = Averager()

    for i, batch in tqdm(enumerate(val_loader, 1)):
        data, _ = [_.cuda() for _ in batch]
        p = args['shot'] * args['test-way']
        data_shot, data_query = data[:p], data[p:]

        if args['shot'] == 1:
            data_shot = torch.cat((data_shot, flip(data_shot, 3)), dim=0)

        with torch.no_grad():
            proto = model(data_shot)
            proto = proto.reshape(shot_num, args['test-way'], -1)
            proto = torch.transpose(proto, 0, 1)
            hyperplanes, mu = projection_pro.create_subspace(proto, args['test-way'], shot_num)
            logits, _ = projection_pro.projection_metric(model(data_query), hyperplanes, mu=mu)

        label = torch.arange(args['test-way']).repeat(args['query']).type(torch.cuda.LongTensor)
        loss = F.cross_entropy(logits, label)
        acc = count_acc(logits, label)

        vl.add(loss.item())
        va.add(acc)

    vl = vl.item()
    va = va.item()

    # Save best
    if va > trlog['max_acc']:
        trlog['max_acc'] = va
        save_model('max-acc')
        trlog['max_epoch'] = epoch
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    print(f'epoch {epoch}, val, loss={vl:.4f} acc={va:.4f} maxacc={trlog["max_acc"]:.4f}')

    trlog['train_loss'].append(tl)
    trlog['train_acc'].append(ta)
    trlog['val_loss'].append(vl)
    trlog['val_acc'].append(va)

    save_model('epoch-last')
    if epoch % args['save-epoch'] == 0:
        save_model(f'epoch-{epoch}')

    print(f'ETA:{timer.measure()}/{timer.measure(epoch / args["max-epoch"])}')

    # TEST PHASE
    tel = Averager()
    tea = Averager()

    for i, batch in tqdm(enumerate(test_loader, 1)):
        data, _ = [_.cuda() for _ in batch]
        p = args['shot'] * args['test-way']
        data_shot, data_query = data[:p], data[p:]

        if args['shot'] == 1:
            data_shot = torch.cat((data_shot, flip(data_shot, 3)), dim=0)

        with torch.no_grad():
            proto = model(data_shot)
            proto = proto.reshape(shot_num, args['test-way'], -1)
            proto = torch.transpose(proto, 0, 1)
            hyperplanes, mu = projection_pro.create_subspace(proto, args['test-way'], shot_num)
            logits, _ = projection_pro.projection_metric(model(data_query), hyperplanes, mu=mu)

        label = torch.arange(args['test-way']).repeat(args['query']).type(torch.cuda.LongTensor)
        loss = F.cross_entropy(logits, label)
        acc = count_acc(logits, label)

        tel.add(loss.item())
        tea.add(acc)

    tel = tel.item()
    tea = tea.item()
    print(f'epoch {epoch}, test, loss={tel:.4f} acc={tea:.4f} maxacc={trlog["max_acc"]:.4f}')
    
    trlog['test_loss'].append(tel)
    trlog['test_acc'].append(tea)

    save_path = osp.join(args['save-path'], 'trlog.json')
    with open(save_path, 'w') as f:
        json.dump(trlog, f, indent=4)

    print(f'TEST ETA:{timer.measure()}/{timer.measure(epoch / args["max-epoch"])}')

    if epochs_no_improve >= patience:
        print(f'Early stopping triggered after {patience} epochs with no improvement.')
        break

In [None]:
import matplotlib.pyplot as plt
import json
import os.path as osp

# Load trlog (if not already in memory)
save_path = osp.join(args['save-path'], 'trlog.json')
with open(save_path, 'r') as f:
    trlog = json.load(f)

epochs = range(1, len(trlog['train_loss']) + 1)

plt.figure(figsize=(12, 5))

# ----- LOSS -----
plt.subplot(1, 2, 1)
plt.plot(epochs, trlog['train_loss'], label='Train Loss')
plt.plot(epochs, trlog['val_loss'], label='Val Loss')
plt.plot(epochs, trlog['test_loss'], label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss over Epochs')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.5)

# ----- ACCURACY -----
plt.subplot(1, 2, 2)
plt.plot(epochs, trlog['train_acc'], label='Train Acc')
plt.plot(epochs, trlog['val_acc'], label='Val Acc')
plt.plot(epochs, trlog['test_acc'], label='Test Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy over Epochs')
plt.legend()
plt.grid(True, linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()

In [None]:
save_model_path = '/kaggle/working/Resnet50_plant_disease_model_26102025.pth'
torch.save(model.state_dict(), save_model_path)

print(f"Model telah disimpan di: {save_model_path}")