In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
import time
import os
from tqdm.auto import tqdm

In [2]:
class SeBlock(nn.Module):
    def __init__(self, in_channels, r = 24):
        super().__init__()
        C = in_channels
        self.globpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc1 = nn.Linear(C, C//r, bias=False)
        self.fc2 = nn.Linear(C//r, C, bias=False)
        self.silu = nn.SiLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        f = self.globpool(x)
        f = torch.flatten(f,1)
        f = self.silu(self.fc1(f))
        f = self.sigmoid(self.fc2(f))
        f = f[:,:,None,None]

        scale = x * f
        return scale

In [3]:
class paper_MBConv(nn.Module):
    def __init__(self, c_in, c_out, kernel_size, stride = 1, k = 6):
        super().__init__()
        self.add = ((c_in == c_out) and stride == 1)
        padding = kernel_size // 2
        c_Bottle = c_in * k
        self.net = nn.Sequential(nn.Conv2d(c_in, c_Bottle, kernel_size = 1), nn.ReLU(),
                            nn.Conv2d(c_Bottle, c_Bottle, kernel_size = kernel_size, padding = padding, stride = stride, groups = c_Bottle), nn.ReLU(),
                            nn.Conv2d(c_Bottle, c_out, kernel_size = 1))
    def forward(self, X):
        Y = self.net(X)
        if self.add:
            Y += X
        return Y


In [4]:
class MBConv(nn.Module):
    def __init__(self, c_in, c_out, kernel_size, stride = 1, k = 6):
        super().__init__()
        self.add = ((c_in == c_out) and stride == 1)
        padding = kernel_size // 2
        c_Bottle = c_in * k
        self.net = nn.Sequential(nn.Conv2d(c_in, c_Bottle, kernel_size = 1), nn.LazyBatchNorm2d(), nn.SiLU(),
                            nn.Conv2d(c_Bottle, c_Bottle, kernel_size = kernel_size, padding = padding, stride = stride, groups = c_Bottle), nn.LazyBatchNorm2d(), nn.SiLU(),
                            SeBlock(c_Bottle),
                            nn.Conv2d(c_Bottle, c_out, kernel_size = 1), nn.LazyBatchNorm2d())
    def forward(self, X):
        Y = self.net(X)
        if self.add:
            Y += X
        return Y


In [5]:
class RSA(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size = kernel_size, padding = padding)

    def forward(self, X):
        maxPool, _ = torch.max(X, dim = 1, keepdim = True)
        avgPool = torch.mean(X, dim = 1, keepdim = True)
        Y = torch.cat((avgPool, maxPool), dim = 1)
        Y = self.conv1(Y)
        Y = torch.sigmoid(Y)
        out = X * Y + X
        return out


In [6]:
class RCA(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.net1 = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), nn.Conv2d(channels, channels, kernel_size = 1), nn.Sigmoid())
        self.net2 = nn.Sequential(nn.AdaptiveMaxPool2d((1,1)), nn.Conv2d(channels, channels, kernel_size = 1), nn.Sigmoid())

    def forward(self, X):
        b1 = self.net1(X)
        b2 = self.net2(X)
        o1 = X * b1
        o2 = X * b2
        output = o1 + o2 + X
        return output

In [7]:
class featureExtraction(nn.Module):
    def __init__(self, kernel_size, channels):
        super().__init__()
        self.RSA = RSA(kernel_size)
        self.RCA = RCA(channels)
        self.Gap = nn.AdaptiveAvgPool2d((1,1))

    def forward(self, X):
        return self.Gap(self.RSA(X)), self.Gap(self.RCA(X))


In [8]:
def split_quadrants(x):
    B, C, H, W = x.shape
    h2, w2 = H // 2, W // 2

    q1 = x[:, :, :h2, :w2]   # top-left
    q2 = x[:, :, :h2, w2:]   # top-right
    q3 = x[:, :, h2:, :w2]   # bottom-left
    q4 = x[:, :, h2:, w2:]   # bottom-right

    return (q1, q2, q3, q4)

In [9]:
class Middle(nn.Module):
    def __init__(self, kernel_size, channels):
        super().__init__()
        self.Extractor = nn.ModuleList()
        for i in range(4):
            self.Extractor.append(featureExtraction(kernel_size, channels))
        self.globalExtractor = featureExtraction(kernel_size, channels)
        self.flatten = nn.Flatten()

    def forward(self, X):
        x_quads = split_quadrants(X)
        RSA_out = []
        RCA_out = []
        for i in range(4):
            rsa, rca = self.Extractor[i](x_quads[i])
            RSA_out.append(rsa)
            RCA_out.append(rca)
        rsa, rca = self.globalExtractor(X)
        RSA_out.append(rsa)
        RCA_out.append(rca)
        out1 = torch.cat(RSA_out, dim = 1)
        out2 = torch.cat(RCA_out, dim = 1)

        return self.flatten(out1 + out2)



In [10]:
class AELGNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 32, kernel_size = 3, stride = 2, padding = 1), nn.LazyBatchNorm2d(), nn.ReLU(),
                                 MBConv(32, 16, kernel_size = 3, stride = 1), MBConv(16, 24, kernel_size = 3, stride = 2),
                                 MBConv(24, 40, kernel_size = 5, stride = 2),
                                 Middle(kernel_size = 7, channels = 40),
                                 nn.LazyLinear(256), nn.ReLU(),
                                 nn.LazyLinear(128), nn.ReLU(),
                                 nn.LazyLinear(80))
    def forward(self, X):
        return self.net(X)




In [11]:
class CLAHETransform:
    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clahe = cv2.createCLAHE(
            clipLimit=clip_limit,
            tileGridSize=tile_grid_size
        )

    def __call__(self, img: Image.Image) -> Image.Image:
        # PIL (RGB) -> numpy
        rgb = np.array(img)

        # Handle grayscale images safely
        if rgb.ndim == 2:
            rgb = np.stack([rgb]*3, axis=-1)

        # RGB -> BGR
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

        # BGR -> LAB
        lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)

        # CLAHE on luminance
        l = self.clahe.apply(l)

        # Merge and back to RGB
        lab = cv2.merge((l, a, b))
        bgr = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        return Image.fromarray(rgb)


In [12]:
transform = transforms.Compose([
    CLAHETransform(clip_limit=2.0, tile_grid_size=(8, 8)),
    transforms.Resize((224, 224)),          # adapt to your model
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])



In [13]:
# ------------------ RUN THIS ONCE BEFORE TRAINING ------------------
'''src = "./Combined_dataset"          # original raw images
dst = "./Combined_dataset_clahe"    # new folder to create with CLAHE applied
os.makedirs(dst, exist_ok=True)

clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))

# iterate classes and files
for cls in sorted(os.listdir(src)):
    src_cls = os.path.join(src, cls)
    if not os.path.isdir(src_cls):
        continue
    dst_cls = os.path.join(dst, cls)
    os.makedirs(dst_cls, exist_ok=True)
    for fname in tqdm(sorted(os.listdir(src_cls)), desc=cls, leave=False):
        src_path = os.path.join(src_cls, fname)
        dst_path = os.path.join(dst_cls, fname)
        try:
            im = Image.open(src_path).convert("RGB")
            rgb = np.array(im)                         # H,W,3
            bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
            l, a, b = cv2.split(lab)
            l = clahe.apply(l)
            lab = cv2.merge((l, a, b))
            bgr = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            out = Image.fromarray(rgb)
            out.save(dst_path)
        except Exception as e:
            print("skipped", src_path, ":", e)

# quick sanity
from torchvision.datasets import ImageFolder
ds = ImageFolder(root=dst)
print("Preprocessed dataset classes:", len(ds.classes))
print("Example classes:", ds.classes[:10])
# --------------------------------------------------------------------'''


'src = "./Combined_dataset"          # original raw images\ndst = "./Combined_dataset_clahe"    # new folder to create with CLAHE applied\nos.makedirs(dst, exist_ok=True)\n\nclahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))\n\n# iterate classes and files\nfor cls in sorted(os.listdir(src)):\n    src_cls = os.path.join(src, cls)\n    if not os.path.isdir(src_cls):\n        continue\n    dst_cls = os.path.join(dst, cls)\n    os.makedirs(dst_cls, exist_ok=True)\n    for fname in tqdm(sorted(os.listdir(src_cls)), desc=cls, leave=False):\n        src_path = os.path.join(src_cls, fname)\n        dst_path = os.path.join(dst_cls, fname)\n        try:\n            im = Image.open(src_path).convert("RGB")\n            rgb = np.array(im)                         # H,W,3\n            bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)\n            lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)\n            l, a, b = cv2.split(lab)\n            l = clahe.apply(l)\n            lab = cv2.merge((l, a, b))\

In [14]:
# ----- CONFIG: change these as needed -----
dataset_root = "./Combined_dataset_clahe"   # <- point to preprocessed folder
batch_size   = 16
num_workers  = 4        # safe now that transforms are picklable
epochs       = 25       # your planned full run
lr           = 2e-3
weight_decay = 1e-4
checkpoint_path = "./aelgnet_best.pth"
# ------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Transforms: NO CLAHE here (preprocessed already)
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.85, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(
        brightness=0.1,
        contrast=0.1,
        saturation=0.1,
        hue=0.02
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])

transform_val = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# Dataset + deterministic 80/20 split
full_dataset = ImageFolder(root=dataset_root, transform=transform_train)
n = len(full_dataset)
train_n = int(0.8 * n)
val_n = n - train_n

train_dataset = ImageFolder(root=dataset_root, transform=transform_train)
val_dataset   = ImageFolder(root=dataset_root, transform=transform_val)

from torch.utils.data import Subset
indices = list(range(len(train_dataset)))
torch.manual_seed(42)
indices = torch.randperm(len(indices)).tolist()
train_indices = indices[:train_n]
val_indices = indices[train_n:]

train_ds = Subset(train_dataset, train_indices)
val_ds   = Subset(val_dataset, val_indices)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=True)

print(f"Dataset size: total={n}, train={train_n}, val={val_n}")

# Model, loss, optimizer (AdamW)

model = AELGNet().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# Optional: AMP and scheduler (recommended)
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

best_val_acc = 0.0

for epoch in range(1, epochs+1):
    t0 = time.time()
    # ----- train -----
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0

    for imgs, labels in train_loader:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast():
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * imgs.size(0)
        _, preds = torch.max(outputs, 1)
        running_corrects += (preds == labels).sum().item()
        total += imgs.size(0)

    scheduler.step()
    epoch_loss = running_loss / total
    epoch_acc = running_corrects / total

    # ----- validate -----
    model.eval()
    val_loss = 0.0
    val_corrects = 0
    val_total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with autocast():
                outputs = model(imgs)
                loss = criterion(outputs, labels)

            val_loss += loss.item() * imgs.size(0)
            _, preds = torch.max(outputs, 1)
            val_corrects += (preds == labels).sum().item()
            val_total += imgs.size(0)

    val_loss = val_loss / val_total
    val_acc = val_corrects / val_total

    elapsed = time.time() - t0
    print(f"Epoch {epoch}/{epochs}  time={elapsed:.1f}s  train_loss={epoch_loss:.4f} train_acc={epoch_acc:.4f}  val_loss={val_loss:.4f} val_acc={val_acc:.4f}")

    # save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc
        }, checkpoint_path)
        print(f"  Saved best model (val_acc={val_acc:.4f}) to {checkpoint_path}")


Device: cuda
Dataset size: total=11340, train=9072, val=2268


  scaler = GradScaler()
  with autocast():
  with autocast():


Epoch 1/25  time=130.0s  train_loss=3.8155 train_acc=0.0802  val_loss=3.4605 val_acc=0.1160
  Saved best model (val_acc=0.1160) to ./aelgnet_best.pth
Epoch 2/25  time=138.2s  train_loss=3.2299 train_acc=0.1726  val_loss=3.0762 val_acc=0.2160
  Saved best model (val_acc=0.2160) to ./aelgnet_best.pth
Epoch 3/25  time=139.3s  train_loss=2.8334 train_acc=0.2543  val_loss=2.7342 val_acc=0.3007
  Saved best model (val_acc=0.3007) to ./aelgnet_best.pth
Epoch 4/25  time=149.8s  train_loss=2.5245 train_acc=0.3244  val_loss=2.4412 val_acc=0.3593
  Saved best model (val_acc=0.3593) to ./aelgnet_best.pth
Epoch 5/25  time=148.8s  train_loss=2.2462 train_acc=0.3847  val_loss=2.2841 val_acc=0.4048
  Saved best model (val_acc=0.4048) to ./aelgnet_best.pth
Epoch 6/25  time=145.1s  train_loss=2.0521 train_acc=0.4329  val_loss=2.2289 val_acc=0.4136
  Saved best model (val_acc=0.4136) to ./aelgnet_best.pth
Epoch 7/25  time=144.9s  train_loss=1.8746 train_acc=0.4712  val_loss=2.0479 val_acc=0.4753
  Saved 