In [4]:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np

# ==============================
# Configurations
# ==============================
train_dir = 'rose-3/train'
validation_dir = 'rose-3/valid'
test_dir = 'rose-3/test'

batch_size = 32
img_height, img_width = 224, 224   # ResNet input size
epochs = 20
model_name = "RoseResNet18"

# Folder to save everything
save_dir = "SavedModels"
os.makedirs(save_dir, exist_ok=True)

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")

# ==============================
# Data Preparation with Stronger Augmentations
# ==============================
train_transform = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.RandomRotation(30),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(15),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])  # ImageNet normalization
])

val_transform = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(validation_dir, transform=val_transform)
test_dataset = datasets.ImageFolder(test_dir, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

num_classes = len(train_dataset.classes)

# ==============================
# Transfer Learning Model (ResNet18)
# ==============================
cnn_model = models.resnet18(pretrained=True)

# Freeze earlier layers
for param in cnn_model.parameters():
    param.requires_grad = False

# Replace final fully connected layer
cnn_model.fc = nn.Sequential(
    nn.Linear(cnn_model.fc.in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, num_classes)
)

cnn_model = cnn_model.to(device)

# ==============================
# Training Setup
# ==============================
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn_model.fc.parameters(), lr=0.0005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

train_losses, val_losses = [], []
train_accs, val_accs = [], []

# ==============================
# Training Loop
# ==============================
best_val_acc = 0.0
for epoch in range(epochs):
    cnn_model.train()
    running_loss, correct, total = 0.0, 0, 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = cnn_model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    train_loss = running_loss / len(train_loader)
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validation
    cnn_model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = cnn_model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss /= len(val_loader)
    val_acc = val_correct / val_total
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Save best model in all formats
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        save_name = f"{model_name}_{device.type}"

        torch.save(cnn_model.state_dict(), os.path.join(save_dir, f"{save_name}.pth"))

        scripted_model = torch.jit.script(cnn_model)
        scripted_model.save(os.path.join(save_dir, f"{save_name}.pt"))

        dummy_input = torch.randn(1, 3, img_height, img_width, device=device)
        torch.onnx.export(
            cnn_model, dummy_input, os.path.join(save_dir, f"{save_name}.onnx"),
            input_names=["input"], output_names=["output"], 
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
        )

        print(f"✅ Best model saved to {save_dir}/{save_name}.*")

    scheduler.step()

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

# ==============================
# Save Metrics and Plots
# ==============================
np.save(os.path.join(save_dir, "train_losses.npy"), np.array(train_losses))
np.save(os.path.join(save_dir, "val_losses.npy"), np.array(val_losses))
np.save(os.path.join(save_dir, "train_accs.npy"), np.array(train_accs))
np.save(os.path.join(save_dir, "val_accs.npy"), np.array(val_accs))

plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Validation Loss")
plt.legend(); plt.xlabel("Epochs"); plt.ylabel("Loss")
plt.savefig(os.path.join(save_dir, "loss_curve.png")); plt.close()

plt.plot(train_accs, label="Train Acc")
plt.plot(val_accs, label="Validation Acc")
plt.legend(); plt.xlabel("Epochs"); plt.ylabel("Accuracy")
plt.savefig(os.path.join(save_dir, "accuracy_curve.png")); plt.close()

print(f"🎯 Final Best Validation Accuracy: {best_val_acc:.4f}")
print(f"📂 All files saved inside: {save_dir}/")


Training on: cuda
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/spidey/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:01<00:00, 25.3MB/s]
  torch.onnx.export(


✅ Best model saved to SavedModels/RoseResNet18_cuda.*
Epoch [1/20] Train Loss: 0.8243, Train Acc: 0.7380 | Val Loss: 0.3748, Val Acc: 0.9076
✅ Best model saved to SavedModels/RoseResNet18_cuda.*
Epoch [2/20] Train Loss: 0.3494, Train Acc: 0.8881 | Val Loss: 0.2810, Val Acc: 0.9139
✅ Best model saved to SavedModels/RoseResNet18_cuda.*
Epoch [3/20] Train Loss: 0.2572, Train Acc: 0.9225 | Val Loss: 0.2555, Val Acc: 0.9162
Epoch [4/20] Train Loss: 0.2226, Train Acc: 0.9268 | Val Loss: 0.2874, Val Acc: 0.9059
Epoch [5/20] Train Loss: 0.2155, Train Acc: 0.9325 | Val Loss: 0.3148, Val Acc: 0.9007
Epoch [6/20] Train Loss: 0.1778, Train Acc: 0.9401 | Val Loss: 0.2952, Val Acc: 0.8949
Epoch [7/20] Train Loss: 0.1817, Train Acc: 0.9387 | Val Loss: 0.2834, Val Acc: 0.9013
✅ Best model saved to SavedModels/RoseResNet18_cuda.*
Epoch [8/20] Train Loss: 0.1528, Train Acc: 0.9498 | Val Loss: 0.2286, Val Acc: 0.9173
✅ Best model saved to SavedModels/RoseResNet18_cuda.*
Epoch [9/20] Train Loss: 0.1408, T