In [None]:
# ✅ Ultra-Accurate EfficientNet-V2-S Snake Classifier for VS Code/Jupyter Notebook


<!-- **Instructions:**
1. Install required packages: `pip install timm torch torchvision numpy pillow matplotlib scikit-learn seaborn tqdm`
2. Place your dataset in `./Snake_Dataset/<class_name>/*.jpg`.
3. Run this notebook in VS Code or Jupyter.
4. After training, select an image in the file dialog that opens to run inference on it. -->

In [None]:

# Imports
import os
import shutil
from PIL import Image
import torch
import timm
import numpy as np
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.amp import autocast, GradScaler
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from tkinter import Tk, filedialog


In [None]:

# 1. Setup Device
# Use the GPU when CUDA is available; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


In [None]:

# 2. Paths and Constants
# Dataset locations: the raw per-class image folders and the train/val copy
# that the split cell derives from them.
data_dir, split_dir = './Snake_Dataset', './Split_Snake_Dataset'

# Fraction of each class's images assigned to the training split.
train_ratio = 0.85

# Square input side in pixels, mini-batch size, and number of training epochs.
img_size, batch_size, num_epochs = 384, 8, 25


In [None]:

# 3. Transforms
# Steps shared by both pipelines: square resize, tensor conversion and
# per-channel normalization.
# NOTE(review): the mean/std below are hard-coded; the pretrained config of the
# model created later may specify different statistics — confirm against it.
_resize = transforms.Resize((img_size, img_size))
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])

# Training pipeline: shared steps plus random flips, a rotation of up to
# 15 degrees and colour jitter.
train_transform = transforms.Compose([
    _resize,
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    _to_tensor,
    _normalize,
])

# Deterministic pipeline (no random augmentation) for single-image prediction.
inference_transform = transforms.Compose([_resize, _to_tensor, _normalize])


In [None]:

# 4. Split Dataset if Needed
# Copies every image from data_dir into split_dir/{train,val}/<class>/.
# The whole step is skipped when split_dir already exists, so delete that
# folder to re-split (e.g. after an interrupted copy or a new train_ratio).
if not os.path.exists(split_dir):
    print('Splitting dataset...')
    os.makedirs(os.path.join(split_dir, 'train'), exist_ok=True)
    os.makedirs(os.path.join(split_dir, 'val'), exist_ok=True)
    full = ImageFolder(data_dir)
    classes = full.classes
    samples = full.samples
    # Map class to image paths
    cls_to_paths = {cls: [] for cls in classes}
    for path, label in samples:
        cls_to_paths[classes[label]].append(path)
    # Fixed seed so the shuffled split is reproducible across runs.
    split_rng = np.random.default_rng(42)
    # Copy files
    for cls, paths in cls_to_paths.items():
        # ImageFolder lists files in sorted order; shuffle within each class so
        # the train/val cut is random instead of following the file names.
        shuffled = [paths[i] for i in split_rng.permutation(len(paths))]
        train_count = int(len(shuffled)*train_ratio)
        if len(shuffled) >= 2:
            # Keep at least one image on each side so the class exists in both
            # splits and the class-to-index mapping stays identical.
            train_count = min(max(train_count, 1), len(shuffled) - 1)
        else:
            print(f'Warning: class {cls!r} has fewer than 2 images and cannot appear in both splits.')
        for i, p in enumerate(shuffled):
            dest = 'train' if i<train_count else 'val'
            out_dir = os.path.join(split_dir, dest, cls)
            os.makedirs(out_dir, exist_ok=True)
            shutil.copy(p, out_dir)
else:
    print('Dataset already split.')


In [None]:

# 5. Load Datasets & Dataloaders
train_ds = ImageFolder(os.path.join(split_dir,'train'), transform=train_transform)
# Validation uses the deterministic pipeline: random augmentation here would
# make the accuracy that selects the best checkpoint vary from run to run.
val_ds = ImageFolder(os.path.join(split_dir,'val'), transform=inference_transform)
classes = train_ds.classes
num_classes = len(classes)

# Label indices are derived per folder, so both splits must contain the same
# class folders or the validation labels would not match the model outputs.
assert val_ds.classes == classes, 'train and val splits contain different class folders'

print(f'Classes ({num_classes}):', classes)

# Pinned host memory only benefits host-to-GPU copies; skip it on CPU-only runs.
pin = device.type == 'cuda'
dl_train = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin)
dl_val   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin)


In [None]:

# 6. Create Model
# Build the timm model with pretrained weights requested and a classification
# head sized to the number of dataset classes.
model = timm.create_model(
    'tf_efficientnetv2_s_in21k',
    pretrained=True,
    num_classes=num_classes,
)
# Moves the module's parameters to the selected device; as the cell's last
# expression this also displays the module.
model.to(device)


In [None]:

# 7. Loss, Optimizer, Scheduler, AMP
# Cross-entropy with label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Cosine annealing with warm restarts: the first cycle lasts T_0=5 scheduler
# steps and every following cycle is twice as long (T_mult=2).
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
# Tie the gradient scaler to the selected device instead of the implicit CUDA
# default, and enable it only on CUDA. When disabled, scale/step/update pass
# straight through, so the training loop needs no separate CPU code path.
scaler = GradScaler(device.type, enabled=(device.type == 'cuda'))


In [None]:

# 8. Training Loop
# Each epoch: one pass over dl_train with gradient updates, then one pass over
# dl_val without gradients. The weights with the highest validation accuracy so
# far are written to 'best_efficientv2s.pth'.
# torch.amp.autocast requires an explicit device type; mixed precision is only
# switched on for CUDA, elsewhere the context manager is a no-op.
amp_enabled = device.type == 'cuda'
best_acc = 0.0
for epoch in range(num_epochs):
    # Training
    model.train()
    total_loss, correct, total = 0,0,0
    for imgs, labels in tqdm(dl_train, desc=f'Epoch {epoch+1}/{num_epochs}'):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=amp_enabled):
            outs = model(imgs)
            loss = criterion(outs, labels)
        # The scaler wraps backward/step/update so one code path serves both
        # scaled and unscaled gradients.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        preds = outs.argmax(dim=1)
        correct += (preds==labels).sum().item()
        total += labels.size(0)
    train_acc = correct/total*100

    # Validation
    model.eval()
    correct, total = 0,0
    # Rebuilt every epoch; after the loop these hold the last epoch's results.
    all_preds, all_labels = [],[]
    with torch.no_grad():
        for imgs, labels in dl_val:
            imgs, labels = imgs.to(device), labels.to(device)
            with autocast(device_type=device.type, enabled=amp_enabled):
                outs = model(imgs)
            preds = outs.argmax(dim=1)
            all_preds += preds.cpu().tolist()
            all_labels += labels.cpu().tolist()
            correct += (preds==labels).sum().item()
            total += labels.size(0)
    val_acc = correct/total*100
    # Advance the schedule by exactly one epoch. The step argument is an epoch
    # position, so validation accuracy must not be mixed into it.
    scheduler.step()

    print(f'Epoch {epoch+1}: Train Loss {total_loss/len(dl_train):.4f}, Train Acc {train_acc:.2f}%, Val Acc {val_acc:.2f}%')
    if val_acc>best_acc:
        best_acc=val_acc
        torch.save(model.state_dict(),'best_efficientv2s.pth')
        print('Saved Best Model')


In [None]:

# 9. Final Evaluation
# all_labels / all_preds are the lists filled by the validation pass of the
# last training epoch that ran, which is not necessarily the best checkpoint.
print('\nFinal Evaluation')
# Pass every class index explicitly so the report rows and matrix axes stay
# aligned with `classes` even when a class is absent from labels/predictions.
label_ids = list(range(len(classes)))
print(classification_report(all_labels, all_preds, labels=label_ids, target_names=classes, zero_division=0))
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt='d', xticklabels=classes, yticklabels=classes, ax=ax)
# confusion_matrix puts true labels on rows and predicted labels on columns.
ax.set(title='Confusion Matrix', xlabel='Predicted label', ylabel='True label')
plt.show()


In [None]:

# 10. Inference: Load Model & Predict
# Restore the best weights saved by the training loop (when the file exists)
# so predictions come from the best checkpoint, not the final-epoch weights.
ckpt_path = 'best_efficientv2s.pth'
if os.path.exists(ckpt_path):
    model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=True))

# File dialog
# Keep a handle on the hidden Tk root and destroy it once the dialog returns,
# so re-running this cell does not accumulate hidden windows.
root = Tk()
root.withdraw()
img_path = filedialog.askopenfilename(title='Select Snake Image')
root.destroy()
if img_path:
    # Close the file handle as soon as the RGB copy has been made.
    with Image.open(img_path) as img_file:
        img = img_file.convert('RGB')
    tensor = inference_transform(img).unsqueeze(0).to(device)
    model.eval()
    # torch.amp.autocast requires the device type; mixed precision is only
    # enabled on CUDA.
    with torch.no_grad(), autocast(device_type=device.type, enabled=(device.type == 'cuda')):
        out = model(tensor)
        probs = torch.softmax(out, dim=1)
        # topk raises when k exceeds the number of classes, so cap it.
        top3 = probs.topk(min(3, probs.size(1)))
    print('Top-3 Predictions:')
    for prob, idx in zip(top3.values[0], top3.indices[0]):
        print(f"{classes[idx.item()]}: {prob.item()*100:.2f}%")

