In [None]:
# ✅ Ultra-Accurate EfficientNet-B5 Snake Classifier for Google Colab

import os
import torch
import timm
import numpy as np
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.amp import GradScaler, autocast
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import shutil


In [None]:

# 1. Device
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


In [None]:

# 2. Transforms & Dataset
# Training-time pipeline: resize plus random flips, rotation and colour jitter,
# followed by tensor conversion and per-channel normalisation.
transform = transforms.Compose([
    transforms.Resize((456, 456)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Deterministic pipeline (no random augmentation) used for validation and inference.
inference_transform = transforms.Compose([
    transforms.Resize((456, 456)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

data_path = '/content/Snake_Dataset'
split_path = '/content/Split_Snake_Dataset'
BATCH_SIZE = 8

assert os.path.exists(data_path), f"Dataset not found at {data_path}"

# Auto Split if not already split
if not os.path.exists(split_path):
    print("🛠️ Splitting dataset...")
    # Build the split in a staging directory and rename it at the very end, so
    # an interrupted copy never leaves a half-populated split_path that a later
    # run would report as "already split".
    staging_path = split_path + '_partial'
    shutil.rmtree(staging_path, ignore_errors=True)
    os.makedirs(os.path.join(staging_path, 'train'))
    os.makedirs(os.path.join(staging_path, 'val'))

    full_dataset = ImageFolder(data_path)
    class_names = full_dataset.classes

    # Group image paths by class name so each class is split independently.
    class_indices = {cls: [] for cls in class_names}
    for img_path, label in full_dataset.samples:
        class_indices[class_names[label]].append(img_path)

    for cls in class_names:
        cls_dir_train = os.path.join(staging_path, 'train', cls)
        cls_dir_val = os.path.join(staging_path, 'val', cls)
        os.makedirs(cls_dir_train, exist_ok=True)
        os.makedirs(cls_dir_val, exist_ok=True)

        # The first 85% of each class (in ImageFolder's sample order) goes to
        # train, the remainder to val.
        imgs = class_indices[cls]
        split_idx = int(len(imgs) * 0.85)
        for position, img in enumerate(imgs):
            target_dir = cls_dir_train if position < split_idx else cls_dir_val
            shutil.copy(img, os.path.join(target_dir, os.path.basename(img)))

    os.rename(staging_path, split_path)
else:
    print("📁 Dataset already split.")

train_dataset = ImageFolder(os.path.join(split_path, 'train'), transform=transform)
# Validation must use the deterministic pipeline: random augmentation here would
# make the reported accuracy (and the best-checkpoint selection) noisy.
val_dataset = ImageFolder(os.path.join(split_path, 'val'), transform=inference_transform)
num_classes = len(train_dataset.classes)
dataset_classes = train_dataset.classes

print(f"🪲 Classes: {num_classes} => {dataset_classes}")

# drop_last=True on the training loader: the classifier head defined in the
# model cell contains BatchNorm1d, which raises in training mode when the final
# batch holds a single sample.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
                          pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)


In [None]:

# 3. Model
# Pretrained EfficientNet-B5 backbone from timm, created with dropout and
# drop-path rates of 0.4.
model = timm.create_model(
    'tf_efficientnet_b5_ns',
    pretrained=True,
    drop_rate=0.4,
    drop_path_rate=0.4,
)

# Replace the stock classifier with a BatchNorm -> Dropout -> Linear head sized
# for this dataset's classes.
head_layers = [
    nn.BatchNorm1d(model.num_features),
    nn.Dropout(0.4),
    nn.Linear(model.num_features, num_classes),
]
model.classifier = nn.Sequential(*head_layers)

model = model.to(device)


In [None]:

# 4. Loss, Optimizer, Scheduler
# Gradient scaler for the mixed-precision training loop.
scaler = GradScaler()

# Cross-entropy with label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# AdamW over every model parameter, driven by a cosine schedule with warm
# restarts (T_0=5, T_mult=2).
optimizer = AdamW(
    params=model.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
)
scheduler = CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=5, T_mult=2)


In [None]:

# 5. Train & Validate
# Trains for num_epochs, validates after every epoch, and saves the weights with
# the best validation accuracy to "best_b5_model.pth".
num_epochs = 20
best_val_acc = 0.0

# float16 autocast is only enabled on CUDA; on CPU the forward pass runs in
# full precision.
use_amp = device.type == 'cuda'

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training")

    for batch_count, (images, labels) in enumerate(train_bar, start=1):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        # torch.amp.autocast requires the device type as its first argument.
        with autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
        # Average over the number of batches actually seen; deriving it from
        # the sample count divides by zero or miscounts on a short batch.
        train_bar.set_postfix(loss=running_loss/batch_count, acc=100.*correct/total)

    train_acc = 100.*correct/max(total, 1)

    # Validation
    model.eval()
    val_correct = 0
    val_total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            with autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                outputs = model(images)
            _, predicted = outputs.max(1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            val_correct += predicted.eq(labels).sum().item()
            val_total += labels.size(0)

    val_acc = 100.*val_correct/max(val_total, 1)
    # Advance the cosine schedule by exactly one epoch. The argument to step()
    # is an epoch position, so it must not include the accuracy value.
    scheduler.step()
    print(f"Train Accuracy: {train_acc:.2f}%")
    print(f"Validation Accuracy: {val_acc:.2f}%")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_b5_model.pth")
        print("✅ Saved Best Model")

# NOTE: all_preds / all_labels hold the predictions from the last epoch, which
# is not necessarily the epoch whose weights were saved as the best model.
# Passing explicit label ids keeps the report and matrix aligned with
# dataset_classes even if a class never appears in the validation data.
class_ids = list(range(len(dataset_classes)))

print(f"\n📊 Final Evaluation Report:")
print(classification_report(all_labels, all_preds, labels=class_ids,
                            target_names=dataset_classes, zero_division=0))

cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=dataset_classes, yticklabels=dataset_classes)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()


In [None]:

# 6. Inference on User Image
# Prompts for image uploads through the Colab file widget and prints the
# highest-probability classes for each uploaded file.
from google.colab import files  # Colab-only module
uploaded = files.upload()

# Never request more predictions than there are classes; topk raises otherwise.
top_k = min(3, len(dataset_classes))

model.eval()
for fname in uploaded:
    # Close the file handle as soon as the RGB copy has been made.
    with Image.open(fname) as raw_img:
        img = raw_img.convert('RGB')
    img_tensor = inference_transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        # torch.amp.autocast requires the device type; float16 autocast is only
        # enabled on CUDA.
        with autocast(device_type=device.type, dtype=torch.float16, enabled=(device.type == 'cuda')):
            pred = model(img_tensor)
        # Compute probabilities in float32 regardless of the autocast dtype.
        probs = torch.softmax(pred.float(), dim=1)
        topk_probs, topk_indices = torch.topk(probs, top_k)

    print(f"\n🧠 Prediction for '{fname}':")
    for i in range(top_k):
        class_idx = topk_indices[0][i].item()
        print(f"{dataset_classes[class_idx]}: {topk_probs[0][i].item()*100:.2f}%")
