In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, classification_report
import numpy as np
from tqdm import tqdm
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
import os

# -----------------------------
# 1️⃣ Squeeze-and-Excitation Block
# -----------------------------
class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate that rescales every channel of a feature map.

    The block averages each channel over its spatial extent, passes the
    resulting vector through a two-layer bottleneck ending in a sigmoid, and
    multiplies the input by the per-channel gates it obtains.

    Args:
        channel: number of channels ``C`` of the incoming ``(B, C, H, W)`` tensor.
        reduction: divisor applied to ``channel`` to get the bottleneck width.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        # Squeeze: collapse every channel to a single value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bias-free bottleneck MLP producing gates in (0, 1).
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with each channel scaled by its learned gate."""
        batch, channels, _height, _width = x.shape
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed)
        # Trailing singleton axes let the gates broadcast over H and W.
        return x * gates[:, :, None, None]

# -----------------------------
# 2️⃣ Custom VGG16 Model with SE
# -----------------------------
class CustomVGG16(nn.Module):
    """VGG16-style classifier with batch norm and an SE gate closing each block.

    ``features`` is a flat ``nn.Sequential`` of five blocks. Every block is a
    run of ``Conv3x3 -> BatchNorm -> ReLU`` triples followed by a 2x2 max-pool
    and an ``SEBlock``. Blocks 1-2 hold two convolutions each (8 modules per
    block, indices 0-15); Blocks 3-5 hold three each (11 modules per block,
    indices 16-48).

    The first classifier layer expects ``512 * 7 * 7`` inputs, i.e. a 7x7 map
    after the five poolings (224x224 images).

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the output logit vector.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()

        # (output channels, number of 3x3 convolutions) for Blocks 1-5.
        block_plan = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

        stack = []
        prev_channels = in_channels
        for width, conv_count in block_plan:
            for _ in range(conv_count):
                stack += [
                    nn.Conv2d(prev_channels, width, 3, padding=1),
                    nn.BatchNorm2d(width),
                    nn.ReLU(inplace=True),
                ]
                prev_channels = width
            # Each block ends by halving the resolution and gating channels.
            stack += [nn.MaxPool2d(2, 2), SEBlock(width)]
        self.features = nn.Sequential(*stack)

        head = [
            nn.Linear(512 * 7 * 7, 4096), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(4096, 1024), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(1024, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images."""
        flat = self.features(x).flatten(1)
        return self.classifier(flat)

# -----------------------------
# 3️⃣ Hyperparameters
# -----------------------------
batch_size = 128
learning_rate = 1e-4
num_epochs = 50
p_mixup = 0.5
p_cutmix = 0.25
alpha_mixup = 0.4
alpha_cutmix = 1.0
early_lr_min, early_lr_max = 1e-6, 5e-6  # Layer-wise ramp-up range
split_seed = 42  # Fixed so a resumed run sees the same train/val partition

# -----------------------------
# 4️⃣ Transforms
# -----------------------------
# Training pipeline: geometric + colour augmentation, then ImageNet
# normalisation, then random erasing on the normalised tensor.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(224, scale=(0.85, 1.0)),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), shear=10),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.1), ratio=(0.3, 3.3))
])
# Validation pipeline: deterministic resize + the same normalisation.
transform_val = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# -----------------------------
# 5️⃣ Dataset & DataLoader
# -----------------------------
dataset_path = "images/"
# Two ImageFolder views over the same files, one per transform. Subsets made
# by random_split share ONE underlying dataset object, so assigning
# `val_dataset.dataset.transform` would also replace the training transform
# and silently disable augmentation.
full_dataset = datasets.ImageFolder(root=dataset_path, transform=transform_train)
eval_dataset = datasets.ImageFolder(root=dataset_path, transform=transform_val)
num_classes = len(full_dataset.classes)
print("Detected classes:", full_dataset.classes)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# Seeded permutation: without a fixed seed every (re)start would draw a new
# split and validation images from a previous session would leak into training.
split_generator = torch.Generator().manual_seed(split_seed)
shuffled_indices = torch.randperm(len(full_dataset), generator=split_generator).tolist()
train_dataset = torch.utils.data.Subset(full_dataset, shuffled_indices[:train_size])
val_dataset = torch.utils.data.Subset(eval_dataset, shuffled_indices[train_size:])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# -----------------------------
# 6️⃣ Device & Model
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CustomVGG16(num_classes=num_classes).to(device)

# Freeze the first 16 feature modules initially. In CustomVGG16 these are
# Blocks 1 and 2 (8 modules each); Block 3 starts at index 16 and stays trainable.
for param in model.features[:16].parameters():
    param.requires_grad = False

# -----------------------------
# 7️⃣ Optimizer, Scheduler, Scaler
# -----------------------------
# Only the currently trainable parameters are handed to the optimizer.
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, weight_decay=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
scaler = torch.amp.GradScaler('cuda')

# -----------------------------
# 8️⃣ Class Weights, Loss
# -----------------------------
# Inverse-frequency class weights, counted over the whole dataset (train + val).
targets = [label for _, label in full_dataset.samples]
class_counts = np.bincount(targets)
class_weights = 1. / class_counts
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

# -----------------------------
# 9️⃣ Checkpoint Path
# -----------------------------
checkpoint_path = "/content/drive/MyDrive/vgg16_se_checkpoint.pth"
start_epoch = 0
best_acc = 0.0

# Restore checkpoint if exists
if os.path.exists(checkpoint_path):
    print("🔄 Restoring checkpoint from Google Drive...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state'])

    saved_groups = checkpoint['optimizer_state']['param_groups']
    if len(saved_groups) == len(optimizer.param_groups):
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        scheduler.load_state_dict(checkpoint['scheduler_state'])
    else:
        # The training loop swaps in a multi-group optimizer (see
        # create_param_groups) once it has unfrozen every layer. Such a state
        # cannot be loaded into the single-group optimizer built above
        # (load_state_dict raises ValueError), so restore the matching model
        # state instead - all layers trainable - and let the loop build its
        # own optimizer and scheduler.
        for param in model.parameters():
            param.requires_grad = True
        print("⚠️ Checkpoint optimizer layout differs from the initial optimizer; "
              "optimizer/scheduler state not restored and all layers unfrozen.")

    start_epoch = checkpoint['epoch'] + 1
    best_acc = checkpoint['best_acc']
    print(f"✅ Restored checkpoint (epoch {start_epoch}, val_acc {best_acc:.2f}%)")

# -----------------------------
# 🔟 Mixup & CutMix helper
# -----------------------------
def rand_bbox(size, lam):
    """Sample a random CutMix rectangle for a batch of shape ``(N, C, H, W)``.

    The rectangle is centred on a uniformly drawn pixel and has side lengths
    ``sqrt(1 - lam)`` times the image width/height, then is clipped to the
    image. ``x`` coordinates index the last axis (width) and ``y`` coordinates
    the second-to-last axis (height), matching ``images[:, :, y1:y2, x1:x2]``.

    Args:
        size: 4-element shape of the image batch, ``(N, C, H, W)``.
        lam: mixing coefficient in ``[0, 1]``; the unclipped patch covers a
            ``1 - lam`` fraction of the image area.

    Returns:
        ``(x1, y1, x2, y2)`` as plain ints with ``0 <= x1 <= x2 <= W`` and
        ``0 <= y1 <= y2 <= H``.
    """
    # NCHW layout: axis 2 is the height, axis 3 the width.
    H, W = size[2], size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
    # Patch centre, drawn uniformly over the image.
    cx, cy = np.random.randint(W), np.random.randint(H)
    x1, x2 = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
    y1, y2 = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
    return int(x1), int(y1), int(x2), int(y2)

def apply_mixup_cutmix(images, labels):
    """Randomly apply Mixup, CutMix, or nothing to one training batch.

    Mixup is chosen with probability ``p_mixup``. Otherwise a second draw
    selects CutMix with probability ``p_cutmix``, so CutMix fires on
    ``(1 - p_mixup) * p_cutmix`` of the batches overall. The loss for the
    returned batch is ``lam * loss(out, y_a) + (1 - lam) * loss(out, y_b)``.

    Args:
        images: image batch; the CutMix branch overwrites a patch of this
            tensor in place and returns the same tensor.
        labels: class indices aligned with ``images``.

    Returns:
        ``(images, y_a, y_b, lam, mode)`` where ``y_a`` are the original
        labels, ``y_b`` the labels of the shuffled partners, ``lam`` the
        weight of ``y_a`` and ``mode`` one of ``"mixup"``, ``"cutmix"``,
        ``"none"``.
    """
    if np.random.rand() < p_mixup:
        lam = float(np.random.beta(alpha_mixup, alpha_mixup))
        # Pair every sample with a random partner; follow the batch's own
        # device rather than a module-level global.
        index = torch.randperm(images.size(0)).to(images.device)
        mixed_x = lam * images + (1 - lam) * images[index, :]
        y_a, y_b = labels, labels[index]
        return mixed_x, y_a, y_b, lam, "mixup"
    elif np.random.rand() < p_cutmix:
        lam = np.random.beta(alpha_cutmix, alpha_cutmix)
        index = torch.randperm(images.size(0)).to(images.device)
        y_a, y_b = labels, labels[index]
        x1, y1, x2, y2 = rand_bbox(images.size(), lam)
        # The right-hand side is a gathered copy, so the in-place write
        # cannot read pixels that were already overwritten.
        images[:, :, y1:y2, x1:x2] = images[index, :, y1:y2, x1:x2]
        # Clipping may have shrunk the patch: recompute lam from its real area.
        lam = float(1 - ((x2 - x1) * (y2 - y1) / (images.size(-1) * images.size(-2))))
        return images, y_a, y_b, lam, "cutmix"
    else:
        return images, labels, labels, 1.0, "none"

# -----------------------------
# 🔟 Training Loop with Gradual Layer-wise Fine-tuning + Ramp-up + SWA
# -----------------------------
# Stochastic Weight Averaging (SWA) and ramp-up settings read by the training loop.
use_swa = True  # Master switch for SWA averaging and the final BN update
swa_start = int(0.7 * num_epochs)  # 0-based epoch from which weights are averaged
# Running average of `model`; built here, i.e. after any checkpoint restore
# above. NOTE(review): this average is not part of the saved checkpoint, so a
# resumed run starts averaging from scratch.
swa_model = AveragedModel(model)
swa_scheduler = None  # Created by the training loop on the first SWA epoch
ramp_up_epochs = 5  # Epochs over which the early-layer LR goes from early_lr_min to early_lr_max

# Function to create parameter groups
def create_param_groups(model, early_lr):
    """Build the three optimizer parameter groups used after unfreezing.

    Args:
        model: a ``CustomVGG16``-like module exposing ``features`` and
            ``classifier``.
        early_lr: learning rate for the early feature layers.

    Returns:
        A list of three dicts suitable for a ``torch.optim`` optimizer:
        the first 16 feature modules (Blocks 1-2) at ``early_lr``, the
        remaining feature modules (Blocks 3-5) and the classifier at the
        global ``learning_rate``.
    """
    backbone = model.features
    early_group = {'params': backbone[:16].parameters(), 'lr': early_lr}
    late_group = {'params': backbone[16:].parameters(), 'lr': learning_rate}
    head_group = {'params': model.classifier.parameters(), 'lr': learning_rate}
    return [early_group, late_group, head_group]

unfreeze_epoch = 30       # 0-based epoch from which every layer is trained
layerwise_ready = False   # True once this run has installed the layer-wise optimizer

for epoch in range(start_epoch, num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    # Gradual layer-wise fine-tuning from `unfreeze_epoch` onwards
    if epoch >= unfreeze_epoch:
        # Early-layer LR for this epoch: linear ramp early_lr_min -> early_lr_max
        ramp_progress = min(epoch - unfreeze_epoch + 1, ramp_up_epochs) / ramp_up_epochs
        early_lr = early_lr_min + (early_lr_max - early_lr_min) * ramp_progress

        if not layerwise_ready:
            # Runs once per process: at the unfreeze epoch itself, or on the
            # first epoch of a run resumed past it. Rebuilding the optimizer
            # every epoch would throw away AdamW's moment estimates, restart
            # the cosine schedule each time and leave stale optimizers alive.
            for param in model.features.parameters():
                param.requires_grad = True
            optimizer = optim.AdamW(create_param_groups(model, early_lr), weight_decay=1e-3)
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
            swa_scheduler = None  # any earlier SWALR would point at the old optimizer
            layerwise_ready = True
            print("🔓 All layers unfrozen with layer-wise ramp-up learning rates.")
        elif swa_scheduler is None and epoch - unfreeze_epoch < ramp_up_epochs:
            # During the ramp the early group follows the ramp value; keeping
            # base_lrs in sync makes the cosine schedule continue from the
            # ramped value once the ramp is over.
            optimizer.param_groups[0]['lr'] = early_lr
            scheduler.base_lrs[0] = early_lr

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        imgs, labels = imgs.to(device), labels.to(device)
        imgs, targets_a, targets_b, lam, mode = apply_mixup_cutmix(imgs, labels)
        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):
            outputs = model(imgs)
            # Mixup/CutMix loss: convex combination over both label sets.
            loss = lam * criterion(outputs, targets_a) + (1-lam)*criterion(outputs, targets_b)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        _, preds = outputs.max(1)
        # Scored against the un-mixed labels, so this under-reports accuracy
        # on mixed batches; it is a rough progress signal only.
        correct += (preds==labels).sum().item()
        total += labels.size(0)

    train_acc = 100*correct/total

    # Validation
    model.eval()
    correct_val, total_val = 0,0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, preds = outputs.max(1)
            correct_val += (preds==labels).sum().item()
            total_val += labels.size(0)
    val_acc = 100*correct_val/total_val
    print(f"\nEpoch {epoch+1} → Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

    # LR schedule / SWA: once averaging starts, the SWA scheduler takes over
    # from the cosine schedule and must be stepped to have any effect.
    if use_swa and epoch >= swa_start:
        swa_model.update_parameters(model)
        if swa_scheduler is None:
            swa_scheduler = SWALR(optimizer, swa_lr=5e-6)
        swa_scheduler.step()
    else:
        scheduler.step()

    # Save checkpoint (written every epoch: holds the latest weights plus the
    # best validation accuracy seen so far, not the best weights)
    if val_acc > best_acc:
        best_acc = val_acc
    torch.save({
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
        'best_acc': best_acc
    }, checkpoint_path)

# Pick the model to evaluate.
if use_swa and int(swa_model.n_averaged) > 0:
    # The averaged weights need fresh BatchNorm statistics: recompute them
    # with one pass over the training data, then evaluate the averaged model.
    # The plain checkpoint must NOT be loaded into it: AveragedModel keeps its
    # weights under a `module.` prefix, so the keys would not match, and a
    # successful load would discard the average anyway.
    update_bn(train_loader, swa_model, device=device)
    model = swa_model
elif os.path.exists(checkpoint_path):
    # No averaging happened: reload the weights of the last saved epoch.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model_state'])

# -----------------------------
# 1️⃣1️⃣ Final Evaluation
# -----------------------------
# NOTE: evaluated on val_loader - this script has no separate test split.
all_preds, all_labels = [], []
model.eval()
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        _, preds = outputs.max(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(all_labels, all_preds)
print(f"\nFinal Test Accuracy: {test_acc:.4f}")
print("\nClassification Report:\n", classification_report(all_labels, all_preds))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Detected classes: ['snake', 'spider']
🔄 Restoring checkpoint from Google Drive...
✅ Restored checkpoint (epoch 15, val_acc 90.49%)


Epoch 16/50: 100%|██████████| 219/219 [04:19<00:00,  1.19s/it]



Epoch 16 → Train Acc: 82.37% | Val Acc: 92.21%


Epoch 17/50: 100%|██████████| 219/219 [04:38<00:00,  1.27s/it]



Epoch 17 → Train Acc: 81.08% | Val Acc: 90.58%


Epoch 18/50: 100%|██████████| 219/219 [04:38<00:00,  1.27s/it]



Epoch 18 → Train Acc: 82.28% | Val Acc: 92.14%


Epoch 19/50: 100%|██████████| 219/219 [04:42<00:00,  1.29s/it]



Epoch 19 → Train Acc: 80.87% | Val Acc: 93.04%


Epoch 20/50: 100%|██████████| 219/219 [04:36<00:00,  1.26s/it]



Epoch 20 → Train Acc: 84.03% | Val Acc: 93.44%


Epoch 21/50: 100%|██████████| 219/219 [04:35<00:00,  1.26s/it]



Epoch 21 → Train Acc: 82.50% | Val Acc: 93.98%


Epoch 22/50: 100%|██████████| 219/219 [04:40<00:00,  1.28s/it]



Epoch 22 → Train Acc: 83.45% | Val Acc: 94.49%


Epoch 23/50: 100%|██████████| 219/219 [04:38<00:00,  1.27s/it]



Epoch 23 → Train Acc: 83.79% | Val Acc: 95.35%


Epoch 24/50: 100%|██████████| 219/219 [04:35<00:00,  1.26s/it]



Epoch 24 → Train Acc: 82.24% | Val Acc: 95.12%


Epoch 25/50: 100%|██████████| 219/219 [04:32<00:00,  1.25s/it]



Epoch 25 → Train Acc: 83.95% | Val Acc: 95.57%


Epoch 26/50: 100%|██████████| 219/219 [04:35<00:00,  1.26s/it]



Epoch 26 → Train Acc: 83.79% | Val Acc: 95.73%


Epoch 27/50: 100%|██████████| 219/219 [04:30<00:00,  1.24s/it]



Epoch 27 → Train Acc: 84.34% | Val Acc: 95.57%


Epoch 28/50: 100%|██████████| 219/219 [04:33<00:00,  1.25s/it]



Epoch 28 → Train Acc: 86.31% | Val Acc: 95.83%


Epoch 29/50: 100%|██████████| 219/219 [04:35<00:00,  1.26s/it]



Epoch 29 → Train Acc: 83.13% | Val Acc: 95.93%


Epoch 30/50: 100%|██████████| 219/219 [04:32<00:00,  1.24s/it]



Epoch 30 → Train Acc: 83.03% | Val Acc: 95.85%
🔓 All layers unfrozen with layer-wise ramp-up learning rates.


Epoch 31/50: 100%|██████████| 219/219 [05:02<00:00,  1.38s/it]



Epoch 31 → Train Acc: 82.28% | Val Acc: 92.70%


Epoch 32/50: 100%|██████████| 219/219 [05:15<00:00,  1.44s/it]



Epoch 32 → Train Acc: 83.28% | Val Acc: 90.30%


Epoch 33/50:   1%|          | 2/219 [00:06<11:27,  3.17s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 392.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 126.12 MiB is free. Process 9713 has 14.62 GiB memory in use. Of the allocated memory 14.01 GiB is allocated by PyTorch, and 479.37 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from google.colab import files

# -----------------------------
# 1️⃣ Your SE Block + Custom VGG16
# -----------------------------
class SEBlock(nn.Module):
    """Squeeze-and-Excitation gate that rescales every channel of a feature map.

    The block averages each channel over its spatial extent, passes the
    resulting vector through a two-layer bottleneck ending in a sigmoid, and
    multiplies the input by the per-channel gates it obtains.

    Args:
        channel: number of channels ``C`` of the incoming ``(B, C, H, W)`` tensor.
        reduction: divisor applied to ``channel`` to get the bottleneck width.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = channel // reduction
        # Squeeze: collapse every channel to a single value.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bias-free bottleneck MLP producing gates in (0, 1).
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` with each channel scaled by its learned gate."""
        batch, channels, _height, _width = x.shape
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed)
        # Trailing singleton axes let the gates broadcast over H and W.
        return x * gates[:, :, None, None]

class CustomVGG16(nn.Module):
    """VGG16-style classifier with batch norm and an SE gate closing each block.

    ``features`` is a flat ``nn.Sequential`` of five blocks. Every block is a
    run of ``Conv3x3 -> BatchNorm -> ReLU`` triples followed by a 2x2 max-pool
    and an ``SEBlock``. Blocks 1-2 hold two convolutions each; Blocks 3-5 hold
    three each. Module order matches the layout the saved weights are keyed by.

    The first classifier layer expects ``512 * 7 * 7`` inputs, i.e. a 7x7 map
    after the five poolings (224x224 images).

    Args:
        in_channels: number of channels of the input images.
        num_classes: size of the output logit vector.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()

        # (output channels, number of 3x3 convolutions) for Blocks 1-5.
        block_plan = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

        stack = []
        prev_channels = in_channels
        for width, conv_count in block_plan:
            for _ in range(conv_count):
                stack += [
                    nn.Conv2d(prev_channels, width, 3, padding=1),
                    nn.BatchNorm2d(width),
                    nn.ReLU(inplace=True),
                ]
                prev_channels = width
            # Each block ends by halving the resolution and gating channels.
            stack += [nn.MaxPool2d(2, 2), SEBlock(width)]
        self.features = nn.Sequential(*stack)

        head = [
            nn.Linear(512 * 7 * 7, 4096), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(4096, 1024), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(1024, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images."""
        flat = self.features(x).flatten(1)
        return self.classifier(flat)

# -----------------------------
# 2️⃣ Device + Model
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CustomVGG16(num_classes=2).to(device)

# Load weights
checkpoint_path = "/content/drive/MyDrive/vgg16_se_best_model.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)
# Accept either a training checkpoint (weights under "model_state") or a
# bare state_dict.
state_dict = checkpoint["model_state"] if "model_state" in checkpoint else checkpoint
model.load_state_dict(state_dict)

# Inference mode: disables dropout and uses the stored BatchNorm statistics.
model.eval()

# -----------------------------
# 3️⃣ Upload Image
# -----------------------------
uploaded = files.upload()
if not uploaded:
    # A cancelled upload dialog would otherwise surface as a bare StopIteration.
    raise RuntimeError("No image was uploaded; nothing to classify.")
# Only the first uploaded file is classified; its key is used as the path.
image_path = next(iter(uploaded))

# -----------------------------
# 4️⃣ Transform & Predict
# -----------------------------
# Deterministic resize + ImageNet mean/std normalisation.
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# Close the file handle as soon as the pixels are decoded.
with Image.open(image_path) as raw_image:
    img = raw_image.convert("RGB")
img_tensor = transform(img).unsqueeze(0).to(device)  # add the batch dimension

with torch.no_grad():
    outputs = model(img_tensor)
    probs = torch.nn.functional.softmax(outputs, dim=1)
    confidence, pred_idx = torch.max(probs, 1)

# -----------------------------
# 5️⃣ Print result
# -----------------------------
# Index order must match the class order the model was trained with.
class_names = ["snake", "spider"]

print(f"✅ Prediction: {class_names[pred_idx.item()]} (Confidence: {confidence.item()*100:.2f}%)")
print("\n📊 Per-class probabilities:")
for i, cls in enumerate(class_names):
    print(f"   {cls}: {probs[0][i].item()*100:.2f}%")


Saving Cobra-by-Skynavin-1920x1300-shutterstock_688611442-aspect-ratio-1000-715.jpg to Cobra-by-Skynavin-1920x1300-shutterstock_688611442-aspect-ratio-1000-715.jpg
✅ Prediction: snake (Confidence: 84.01%)

📊 Per-class probabilities:
   snake: 84.01%
   spider: 15.99%
