In [None]:
# Authenticate this Colab runtime against Google Cloud; later cells shell out
# to gsutil to copy the dataset from a GCS bucket.

from google.colab import auth
# Authenticates the current Colab user.
auth.authenticate_user()
# IPython shell escape that starts the gcloud configuration wizard.
# NOTE(review): this appears to prompt for input, which would stop the
# notebook from running unattended — confirm it is still required.
!gcloud init

In [None]:
# Code to connect to Google Drive and to download the dataset from GCS
import os
from google.colab import drive
from google.cloud import storage
import subprocess

# ---------- Configuration ----------
# Bucket and object prefix of the dataset on GCS, and the local directory the
# download step checks for / populates.
GCS_BUCKET = "similar_dataset"
GCS_DATASET_PATH = "cnn_set"
LOCAL_DATASET_PATH = "/content/cnn_set"
# Drive folders created by ensure_dirs(); all of them live under the Drive
# mountpoint, so Drive has to be mounted before they are created.
OUTPUT_DIRS = [
    "/content/drive/MyDrive/ImageRetrievalProject/features",
    "/content/drive/MyDrive/ImageRetrievalProject/tsne_plots",
    "/content/drive/MyDrive/ImageRetrievalProject/stats"
]

# ---------- Functions ----------
def mount_drive():
    """Mount Google Drive at /content/drive, forcing a remount.

    ``force_remount=True`` is passed to ``drive.mount``, so the call is made
    even when Drive is already mounted.

    Raises:
        Exception: whatever ``drive.mount`` raised. The failure is logged and
            then re-raised, because continuing would let the following steps
            create the Drive output folders on the local disk underneath the
            mountpoint.
    """
    try:
        drive.mount('/content/drive', force_remount=True)
    except Exception as e:
        print(f"Google Drive mount failed: {e}")
        raise
    print("Google Drive mounted successfully.")

def download_dataset_from_gcs(bucket_name, gcs_path, local_path):
    """Copy ``gs://<bucket_name>/<gcs_path>`` into the parent of ``local_path``.

    The download is skipped when ``local_path`` already exists. gsutil is
    given the *parent* directory of ``local_path`` as destination, so the copy
    is expected to land at ``local_path`` only when the last component of
    ``gcs_path`` matches the basename of ``local_path``.

    Args:
        bucket_name: Name of the GCS bucket.
        gcs_path: Object prefix inside the bucket.
        local_path: Local directory whose existence marks the dataset as present.

    Raises:
        RuntimeError: if gsutil is missing or exits with a non-zero status.
    """
    if os.path.exists(local_path):
        print(f" Dataset already exists at {local_path}. Skipping download.")
        return
    print(f"⬇ Downloading dataset from gs://{bucket_name}/{gcs_path} to {local_path}...")
    target_parent_dir = os.path.dirname(local_path) or '/content/'
    # Argument list instead of a shell string: names with spaces or shell
    # metacharacters are passed through verbatim and never interpreted.
    command = ["gsutil", "-m", "cp", "-r", f"gs://{bucket_name}/{gcs_path}", target_parent_dir]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        detail = getattr(e, "stderr", None) or str(e)
        print(f" Error downloading dataset: {detail}")
        # Raise instead of exit(): in a notebook exit() tears down the kernel
        # session rather than just failing this cell.
        raise RuntimeError(f"Dataset download failed: {detail}") from e
    print("  Dataset downloaded successfully.")


def ensure_dirs(paths):
    """Make every path in ``paths`` that does not exist yet, announcing each one.

    Paths that already exist are left alone and produce no output.
    """
    # Lazy generator on purpose: each existence check happens right before its
    # own makedirs call, so a path created as the parent of an earlier entry
    # is seen as existing when its turn comes.
    missing = (candidate for candidate in paths if not os.path.exists(candidate))
    for directory in missing:
        os.makedirs(directory)
        print(f"Created directory: {directory}")

# ---------- Execution ----------
# Runs the setup steps in dependency order: Drive is mounted first because
# every OUTPUT_DIRS entry passed to ensure_dirs lives under /content/drive.
if __name__ == '__main__':
    print("🔹 Starting Drive & Dataset Setup...")
    mount_drive()
    download_dataset_from_gcs(GCS_BUCKET, GCS_DATASET_PATH, LOCAL_DATASET_PATH)
    ensure_dirs(OUTPUT_DIRS)
    print("Drive & dataset setup complete. Ready for feature extraction and retrieval.")


In [None]:
# ========================================
#       Feature Extraction (All Models)
# ========================================

import os
import torch
import pickle
from torch import nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from tqdm import tqdm

# -------- Configuration --------
# Use the GPU when one is visible to torch, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
IMAGE_SIZE = (224, 224)  # centre-crop size used by the transform below
NUM_CLASSES = 20  # width of the classification heads built in this cell
DATASET_PATH = "/content/cnn_set"  # expected to contain "train" and "val" subfolders
FEATURE_SAVE_DIR = "/content/drive/MyDrive/ImageRetrievalProject/features"
os.makedirs(FEATURE_SAVE_DIR, exist_ok=True)

# -------- Model Weights Paths --------
# Checkpoint file per model name; process_model() loads each as a state_dict.
# NOTE(review): the customcnn entry lives under ImageRetrievalProject/metrics
# while the other three point at MyDrive/metrics — confirm both roots are right.
# NOTE(review): the V1 training cell saves its checkpoint as
# "mod_arch_best_model.pth" in the customcnn_35epoch folder, not
# "best_model.pth" — confirm which file name is intended.
WEIGHTS = {
    "customcnn": "/content/drive/MyDrive/ImageRetrievalProject/metrics/customcnn_35epoch/best_model.pth",
    "resnet50": "/content/drive/MyDrive/metrics/resnet50/best_model.pth",
    "efficientnet_b0": "/content/drive/MyDrive/metrics/efficientnet_b0/best_model.pth",
    "vgg16": "/content/drive/MyDrive/metrics/vgg16/best_model.pth"
}

# -------- Image Transforms --------
# Per-channel normalisation statistics.
# NOTE(review): these look like the usual ImageNet mean/std — confirm.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
# Deterministic preprocessing (no augmentation): resize, centre-crop to
# IMAGE_SIZE, convert to a tensor, then normalise.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# -------- Datasets & Dataloaders --------
class ImagePathDataset(datasets.ImageFolder):
    """ImageFolder variant whose items also carry the source file path."""

    def __getitem__(self, index):
        """Return ``(image, class_index, file_path)`` for the sample at ``index``.

        The image is passed through ``self.transform`` when one is set.
        """
        file_path, class_index = self.samples[index]
        image = self.loader(file_path)
        if self.transform is None:
            return image, class_index, file_path
        return self.transform(image), class_index, file_path

# One dataset per split; both use the same deterministic transform.
train_dataset = ImagePathDataset(os.path.join(DATASET_PATH, "train"), transform=transform)
val_dataset = ImagePathDataset(os.path.join(DATASET_PATH, "val"), transform=transform)

# shuffle=False keeps features, labels and paths in dataset order for both splits.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# ========================================
#       Custom CNN Definition
# ========================================
class CustomCNN(nn.Module):
    """Four conv/BN/ReLU/max-pool stages followed by a four-layer dense head.

    Layer names and construction order are kept as they were so that existing
    checkpoints still load through ``load_state_dict``.

    Args:
        num_classes: Width of the final ``fc_out`` layer.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(16, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        self.conv_block4 = nn.Sequential(
            nn.Conv2d(256, 1024, 3, 1, 1), nn.BatchNorm2d(1024), nn.ReLU(), nn.MaxPool2d(2, 2)
        )

        # Probe the conv stack with a dummy 224x224 image to size fc1. The
        # probe runs in eval mode and under no_grad so it neither changes the
        # BatchNorm running statistics nor builds an autograd graph.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 3, 224, 224)
            for block in (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4):
                probe = block(probe)
        self.train(was_training)
        self.flatten_dim = probe.view(1, -1).shape[1]

        self.fc1 = nn.Linear(self.flatten_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 128)
        self.fc_out = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.15)

    def forward(self, x):
        """Return the ``fc_out`` outputs (one value per class) for a batch of images."""
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.relu(self.fc3(x))
        return self.fc_out(x)

# ========================================
#       Feature Extraction Function
# ========================================
def _forward_to_features(model, inputs, model_name):
    """Run ``inputs`` through ``model`` up to the layer used as the descriptor."""
    if model_name == "customcnn":
        # Everything up to and including ReLU(fc3); fc_out is skipped.
        x = inputs
        for block in (model.conv_block1, model.conv_block2, model.conv_block3, model.conv_block4):
            x = block(x)
        x = x.view(x.size(0), -1)
        x = F.relu(model.fc2(F.relu(model.fc1(x))))
        return F.relu(model.fc3(model.dropout(x)))
    if model_name == "vgg16":
        # Conv features -> avgpool -> flatten, then classifier[0..2], each
        # followed by a ReLU.
        # NOTE(review): in torchvision's VGG16 classifier[1] and classifier[2]
        # appear to be ReLU and Dropout, so only the first Linear layer is
        # applied — confirm that this is the intended 4096-d descriptor.
        x = torch.flatten(model.avgpool(model.features(inputs)), 1)
        for layer in (model.classifier[0], model.classifier[1], model.classifier[2]):
            x = F.relu(layer(x))
        return x
    # Any other model is expected to be a ready-made feature extractor.
    return model(inputs)


def extract_features(model, loader, model_name):
    """Collect one flattened feature vector, label and path per image in ``loader``.

    Args:
        model: Network to run; it is put in eval mode and moved to ``DEVICE``.
        loader: Iterable yielding ``(inputs, targets, img_paths)`` batches.
        model_name: Chooses the forward path ("customcnn" and "vgg16" have
            dedicated ones, everything else calls ``model(inputs)``).

    Returns:
        ``(features, labels, paths)`` as three parallel lists.
    """
    model.eval()
    model = model.to(DEVICE)
    features, labels, paths = [], [], []
    progress = tqdm(loader, desc=f"Extracting {model_name} features")

    with torch.no_grad():
        for inputs, targets, img_paths in progress:
            out = _forward_to_features(model, inputs.to(DEVICE), model_name)
            # Collapse any trailing spatial dims so every row is a flat vector.
            features.extend(out.view(out.size(0), -1).cpu().numpy())
            labels.extend(targets.cpu().numpy())
            paths.extend(img_paths)

    return features, labels, paths

# ========================================
#       Model Loading & Feature Extraction
# ========================================
def process_model(model_name, model, loader_train, loader_val, weight_path=None):
    """Load weights, extract train and val features, and pickle them together.

    Args:
        model_name: One of "customcnn", "vgg16", "resnet50", "efficientnet_b0".
        model: Full classification model matching ``model_name``.
        loader_train: Loader for the train split (see ``extract_features``).
        loader_val: Loader for the val split.
        weight_path: Optional state_dict file. ``None`` means "use the model
            as passed in". A path that does not exist is reported with a
            warning, because the resulting features then come from weights
            that were never fine-tuned.

    Raises:
        ValueError: if ``model_name`` is not supported.

    Side effects:
        Writes ``<model_name>_features.pkl`` into ``FEATURE_SAVE_DIR`` holding
        a dict with the keys 'features', 'labels' and 'paths'.
    """
    if weight_path is not None:
        if os.path.exists(weight_path):
            model.load_state_dict(torch.load(weight_path, map_location=DEVICE))
            print(f"{model_name} weights loaded from {weight_path}")
        else:
            print(f"WARNING: {model_name} weights not found at {weight_path}; "
                  "features will be extracted with the model's current weights.")

    # Build the feature extractor by dropping the final classification layer.
    if model_name == "customcnn" or model_name == "vgg16":
        # The full model is kept; extract_features stops at the right layer.
        feature_model = model
    elif model_name == "resnet50":
        feature_model = nn.Sequential(*list(model.children())[:-1])
    elif model_name == "efficientnet_b0":
        feature_model = nn.Sequential(model.features, model.avgpool, model.classifier[0])
    else:
        raise ValueError("Model type not supported for feature extraction.")

    print(f"Extracting features for {model_name} - TRAIN")
    train_features, train_labels, train_paths = extract_features(feature_model, loader_train, model_name)
    print(f"Extracting features for {model_name} - VAL")
    val_features, val_labels, val_paths = extract_features(feature_model, loader_val, model_name)

    # Train entries first, then val, in loader order.
    all_features = train_features + val_features
    all_labels = train_labels + val_labels
    all_paths = train_paths + val_paths

    save_path = os.path.join(FEATURE_SAVE_DIR, f"{model_name}_features.pkl")
    with open(save_path, 'wb') as f:
        pickle.dump({'features': all_features, 'labels': all_labels, 'paths': all_paths}, f)
    print(f"Saved merged features for {model_name} at {save_path}")

# ========================================
#       Execute All Models
# ========================================
def main(model_names=("vgg16",)):
    """Extract and save features for each requested model.

    Args:
        model_names: Names of the models to process, in order. Supported:
            "customcnn", "resnet50", "efficientnet_b0", "vgg16". The default
            processes only VGG16, matching what this cell did before the
            other models were made selectable.

    Raises:
        ValueError: if a requested name is not supported.
    """
    def build_customcnn():
        return CustomCNN()

    def build_resnet50():
        net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        net.fc = nn.Linear(net.fc.in_features, NUM_CLASSES)
        return net

    def build_efficientnet_b0():
        net = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
        net.classifier[1] = nn.Linear(net.classifier[1].in_features, NUM_CLASSES)
        return net

    def build_vgg16():
        net = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        net.classifier[6] = nn.Linear(4096, NUM_CLASSES)
        return net

    builders = {
        "customcnn": build_customcnn,
        "resnet50": build_resnet50,
        "efficientnet_b0": build_efficientnet_b0,
        "vgg16": build_vgg16,
    }

    unknown = [name for name in model_names if name not in builders]
    if unknown:
        raise ValueError(f"Unsupported model name(s): {unknown}. Choose from {sorted(builders)}.")

    # Models are built one at a time so only one network is alive per iteration.
    for name in model_names:
        process_model(name, builders[name](), train_loader, val_loader,
                      weight_path=WEIGHTS[name])

if __name__ == "__main__":
    main()

In [None]:
# ===========================
# FULL CUSTOM CNN TRAINING SCRIPT V1
# ===========================
import os
import json
import time
import copy
import subprocess
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

# ---- 1. GOOGLE DRIVE MOUNT ----
# The import sits inside the try so that an environment without google.colab
# skips the mount. Only ImportError is handled, so a mount failure inside
# Colab still propagates.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    print(" Google Drive mounted successfully.")
except ImportError:
    print(" Not running in Colab, skipping drive mount.")

# ---- 2. CONFIG ----
# CNN_NAME and NUM_EPOCHS together name this run's output folder on Drive.
CNN_NAME = "customcnn"
NUM_EPOCHS = 35
BASE_METRICS_DIR = "/content/drive/MyDrive/ImageRetrievalProject/metrics"
RUN_METRICS_DIR = os.path.join(BASE_METRICS_DIR, f"{CNN_NAME}_{NUM_EPOCHS}epoch")
os.makedirs(RUN_METRICS_DIR, exist_ok=True)

# Hyperparameters and output locations read by the rest of this cell.
# NOTE(review): the feature-extraction cell's WEIGHTS["customcnn"] points at
# "best_model.pth" in this same run folder, while this run saves
# "mod_arch_best_model.pth" — confirm which file name is intended.
CONFIG = {
    "batch_size": 64,
    "num_epochs": NUM_EPOCHS,
    "learning_rate": 1e-5,  # initial Adam learning rate
    "weight_decay": 1e-4,
    "min_lr": 1e-7,  # lower bound handed to ReduceLROnPlateau
    "image_size": (224, 224),
    "best_model_path": os.path.join(RUN_METRICS_DIR, "mod_arch_best_model.pth"),
    "metrics_json": os.path.join(RUN_METRICS_DIR, "training_metrics.json")
}

# Authenticate the Colab user against Google Cloud before gsutil is used below.
# NOTE(review): unlike the Drive mount above, this import is not guarded, so
# the cell fails here outside Colab — confirm that is acceptable.
from google.colab import auth
auth.authenticate_user()

# IPython shell escape that starts the gcloud configuration wizard.
!gcloud init


# ---- 3. DOWNLOAD DATASET FROM GCS ----
# Bucket, object prefix, and the local directory the dataset is read from.
GCS_BUCKET_NAME = 'similar_dataset'
GCS_DATASET_PATH = 'cnn_set'
LOCAL_DATA_ROOT = '/content/cnn_set'

def download_dataset_from_gcs(bucket_name, gcs_path, local_path):
    """Copy ``gs://<bucket_name>/<gcs_path>`` into the parent of ``local_path``.

    The download is skipped when ``local_path`` already exists. gsutil is
    given the *parent* directory of ``local_path`` as destination, so the copy
    is expected to land at ``local_path`` only when the last component of
    ``gcs_path`` matches the basename of ``local_path``.

    Args:
        bucket_name: Name of the GCS bucket.
        gcs_path: Object prefix inside the bucket.
        local_path: Local directory whose existence marks the dataset as present.

    Raises:
        RuntimeError: if gsutil is missing or exits with a non-zero status.
    """
    if os.path.exists(local_path):
        print(f" Dataset already exists at {local_path}. Skipping download.")
        return
    print(f"⬇ Downloading dataset from gs://{bucket_name}/{gcs_path} to {local_path}...")
    target_parent_dir = os.path.dirname(local_path) or '/content/'
    # Argument list instead of a shell string: names with spaces or shell
    # metacharacters are passed through verbatim and never interpreted.
    command = ["gsutil", "-m", "cp", "-r", f"gs://{bucket_name}/{gcs_path}", target_parent_dir]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        detail = getattr(e, "stderr", None) or str(e)
        print(f" Error downloading dataset: {detail}")
        # Raise instead of exit(): in a notebook exit() tears down the kernel
        # session rather than just failing this cell.
        raise RuntimeError(f"Dataset download failed: {detail}") from e
    print("  Dataset downloaded successfully.")

# Fetch the dataset now; the ImageFolder datasets built later in this cell
# read from LOCAL_DATA_ROOT.
download_dataset_from_gcs(GCS_BUCKET_NAME, GCS_DATASET_PATH, LOCAL_DATA_ROOT)







# ---- 4. CNN ARCHITECTURE ----
class CNN(nn.Module):
    """Four conv/BN/ReLU/max-pool stages followed by a four-layer dense head.

    Layer names and construction order are unchanged, so checkpoints written
    by earlier runs of this script remain loadable.

    Args:
        num_classes: Width of the final ``fc_out`` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(16, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.conv_block4 = nn.Sequential(
            nn.Conv2d(256, 1024, 3, 1, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

        # Probe the conv stack with a dummy 224x224 image to size fc1.
        # no_grad alone does not stop BatchNorm from updating its running
        # statistics in training mode, so the probe also switches to eval
        # mode; otherwise training would start from stats skewed by an
        # all-zero image.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 3, 224, 224)
            for block in (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4):
                probe = block(probe)
        self.train(was_training)
        self.flatten_dim = probe.view(1, -1).size(1)

        self.fc1 = nn.Linear(self.flatten_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 128)
        self.fc_out = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.15)

    def forward(self, x):
        """Return the ``fc_out`` outputs (one value per class) for a batch of images."""
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        # Keep the batch dimension fixed: an input of unexpected spatial size
        # then fails at fc1 instead of being silently re-batched.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = F.relu(self.fc3(x))
        return self.fc_out(x)

# ---- 5. DEVICE ----
# Use the GPU when one is visible to torch, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f" Using device: {DEVICE}")

# ---- 6. DATA ----
# Training pipeline: random crop/flip/rotation/colour jitter for augmentation,
# then tensor conversion and per-channel normalisation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(CONFIG["image_size"], scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation pipeline: deterministic, no augmentation.
# NOTE(review): Resize is given the full (224, 224) tuple here, whereas the
# feature-extraction cells use Resize(256) before the centre crop — confirm
# the preprocessing is meant to differ.
val_transform = transforms.Compose([
    transforms.Resize(CONFIG["image_size"]),
    transforms.CenterCrop(CONFIG["image_size"]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# One ImageFolder per split under LOCAL_DATA_ROOT.
train_data = datasets.ImageFolder(os.path.join(LOCAL_DATA_ROOT, "train"), transform=train_transform)
val_data = datasets.ImageFolder(os.path.join(LOCAL_DATA_ROOT, "val"), transform=val_transform)

# Only the training loader shuffles.
train_loader = DataLoader(train_data, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=2, pin_memory=True)



# ---- 7. MODEL/LOSS/OPTIMIZER ----
# The head width follows the number of class folders found in the train split.
model = CNN(num_classes=len(train_data.classes)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"], weight_decay=CONFIG["weight_decay"])
# Stepped with the validation loss (mode='min'); scales the LR by 0.1 on a
# plateau, never going below CONFIG["min_lr"].
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, min_lr=CONFIG["min_lr"])

# ---- 8. TRAINING LOOP ----
# Trains for up to CONFIG["num_epochs"] epochs, checkpointing the weights with
# the lowest validation loss and stopping early after `patience` epochs
# without improvement. The best checkpoint is reloaded after the loop.
train_losses, val_losses, train_accs, val_accs, epoch_times = [], [], [], [], []
best_val_loss = float('inf')
epochs_no_improve = 0
# NOTE(review): the scheduler above also uses patience=3, so early stopping
# will usually trigger before the learning rate is ever reduced — confirm the
# two budgets are meant to be equal.
patience = 3

print(" Starting training...\n")
total_start_time = time.time()

for epoch in range(CONFIG["num_epochs"]):
    epoch_start = time.time()
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss.item() is a batch mean; weight it by the batch size so the
        # epoch average is correct when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # Progress line roughly ten times per epoch.
        if (batch_idx + 1) % max(1, len(train_loader)//10) == 0:
            print(f" Epoch [{epoch+1}/{CONFIG['num_epochs']}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

    avg_train_loss = running_loss / total
    train_acc = 100 * correct / total

    # Validation
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)
    avg_val_loss = val_loss / val_total
    val_acc = 100 * val_correct / val_total

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    scheduler.step(avg_val_loss)
    epoch_time = time.time() - epoch_start
    epoch_times.append(epoch_time)
    improved = avg_val_loss < best_val_loss
    if improved:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), CONFIG["best_model_path"])
    else:
        epochs_no_improve += 1
    print(f"\n Epoch [{epoch+1}/{CONFIG['num_epochs']}] Summary:")
    print(f"   Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"   Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}% {' IMPROVED' if improved else ' NO IMPROVEMENT'}")
    print(f"   Epoch Time: {epoch_time:.2f}s | ETA: ~{int(np.mean(epoch_times)*(CONFIG['num_epochs']-epoch-1))}s")
    print(f"  Patience: {epochs_no_improve}/{patience}\n")
    if epochs_no_improve >= patience:
        print(" Early stopping triggered. Loading best model...")
        break

# Reload the best checkpoint whether or not early stopping fired, so the
# evaluation below describes the weights that were saved, not the last epoch.
# The guard skips the reload when this run never wrote a checkpoint (zero
# epochs, or a validation loss that never compared as an improvement).
if best_val_loss < float('inf'):
    model.load_state_dict(torch.load(CONFIG["best_model_path"], map_location=DEVICE))
else:
    print(" No checkpoint was saved during this run; keeping the current weights.")





# ---- 9. SAVE METRICS ----
# Per-epoch histories collected by the training loop; each list has one entry
# per completed epoch.
metrics = {
    "train_loss": train_losses,
    "val_loss": val_losses,
    "train_acc": train_accs,
    "val_acc": val_accs,
    "epoch_times": epoch_times
}
# Written next to the checkpoint in this run's metrics folder.
with open(CONFIG["metrics_json"], "w") as f:
    json.dump(metrics, f, indent=4)

print(f" Metrics saved to {CONFIG['metrics_json']}")

# ---- 10. PLOT CURVES ----
# Loss and accuracy histories side by side. Axis labels and tight_layout match
# the V2 training cell so the saved figure is readable on its own.
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.legend()
plt.title('Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.subplot(1,2,2)
plt.plot(train_accs, label='Train Acc')
plt.plot(val_accs, label='Val Acc')
plt.legend()
plt.title('Accuracy Curve')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')

plt.tight_layout()
# Save before show(); the figure is written into this run's metrics folder.
plt.savefig(os.path.join(RUN_METRICS_DIR, "training_curves.png"))
plt.show()

# ---- 11. CONFUSION MATRIX & REPORT ----
# Collect predictions over the validation split with the current model weights.
all_preds, all_labels = [], []
model.eval()
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Pin the label set to every training class. Without it sklearn infers the
# classes from the data, and a class that never occurs in the validation
# labels or predictions would make the matrix/report disagree with the class
# names passed alongside.
class_names = train_data.classes
class_indices = list(range(len(class_names)))

cm = confusion_matrix(all_labels, all_preds, labels=class_indices)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.title("Confusion Matrix")
plt.savefig(os.path.join(RUN_METRICS_DIR, "confusion_matrix.png"))
plt.show()

# zero_division=0 reports 0 for classes with no predictions instead of warning.
report = classification_report(all_labels, all_preds, labels=class_indices,
                               target_names=class_names, zero_division=0)
with open(os.path.join(RUN_METRICS_DIR, "classification_report.txt"), "w") as f:
    f.write(report)

print("\nTraining complete! All metrics & plots saved.")


In [None]:
# ===========================
# FINAL CUSTOM CNN TRAINING SCRIPT V2
# ===========================
import os
import json
import time
import copy
import subprocess
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

# ---- 1. GOOGLE DRIVE MOUNT ----
# The import sits inside the try so that an environment without google.colab
# skips the mount. Only ImportError is handled, so a mount failure inside
# Colab still propagates.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    print(" Google Drive mounted successfully.")
except ImportError:
    print(" Not running in Colab, skipping drive mount.")

# ---- 2. CONFIG ----
# CNN_NAME and NUM_EPOCHS together name this run's output folder on Drive.
CNN_NAME = "customcnn_v2"
NUM_EPOCHS = 50
BASE_METRICS_DIR = "/content/drive/MyDrive/ImageRetrievalProject/metrics"
RUN_METRICS_DIR = os.path.join(BASE_METRICS_DIR, f"{CNN_NAME}_{NUM_EPOCHS}epoch")
os.makedirs(RUN_METRICS_DIR, exist_ok=True)

# Hyperparameters and output locations read by the rest of this cell.
CONFIG = {
    "batch_size": 64,
    "num_epochs": NUM_EPOCHS,
    "learning_rate": 1.5e-4,  # initial Adam learning rate
    "weight_decay": 5e-5,
    "min_lr": 1e-7,  # lower bound handed to ReduceLROnPlateau
    "image_size": (224, 224),
    "best_model_path": os.path.join(RUN_METRICS_DIR, "best_model.pth"),
    "metrics_json": os.path.join(RUN_METRICS_DIR, "training_metrics.json")
}

# ---- 3. DOWNLOAD DATASET FROM GCS ----
# Bucket, object prefix, and the local directory the dataset is read from.
# NOTE(review): this cell does not authenticate against GCS itself; it appears
# to rely on an earlier cell having done so — confirm.
GCS_BUCKET_NAME = 'similar_dataset'
GCS_DATASET_PATH = 'cnn_set'
LOCAL_DATA_ROOT = '/content/cnn_set'

def download_dataset_from_gcs(bucket_name, gcs_path, local_path):
    """Copy ``gs://<bucket_name>/<gcs_path>`` into the parent of ``local_path``.

    The download is skipped when ``local_path`` already exists. gsutil is
    given the *parent* directory of ``local_path`` as destination, so the copy
    is expected to land at ``local_path`` only when the last component of
    ``gcs_path`` matches the basename of ``local_path``.

    Args:
        bucket_name: Name of the GCS bucket.
        gcs_path: Object prefix inside the bucket.
        local_path: Local directory whose existence marks the dataset as present.

    Raises:
        RuntimeError: if gsutil is missing or exits with a non-zero status.
    """
    if os.path.exists(local_path):
        print(f" Dataset already exists at {local_path}. Skipping download.")
        return
    print(f"⬇ Downloading dataset from gs://{bucket_name}/{gcs_path} to {local_path}...")
    target_parent_dir = os.path.dirname(local_path) or '/content/'
    # Argument list instead of a shell string: names with spaces or shell
    # metacharacters are passed through verbatim and never interpreted.
    command = ["gsutil", "-m", "cp", "-r", f"gs://{bucket_name}/{gcs_path}", target_parent_dir]
    try:
        subprocess.run(command, check=True, capture_output=True, text=True)
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        detail = getattr(e, "stderr", None) or str(e)
        print(f" Error downloading dataset: {detail}")
        # Raise instead of exit(): in a notebook exit() tears down the kernel
        # session rather than just failing this cell.
        raise RuntimeError(f"Dataset download failed: {detail}") from e
    print("  Dataset downloaded successfully.")

# Fetch the dataset now; the ImageFolder datasets built later in this cell
# read from LOCAL_DATA_ROOT.
download_dataset_from_gcs(GCS_BUCKET_NAME, GCS_DATASET_PATH, LOCAL_DATA_ROOT)


# ---- 4. CNN ARCHITECTURE (STABLE VERSION) ----
class CNN(nn.Module):
    """Conv backbone with global average pooling and a three-layer dense head.

    Four conv/BN/ReLU/max-pool stages widen the channels 3 -> 16 -> 64 -> 256
    -> 1024; the result is average-pooled to 1x1, flattened, and classified.

    Args:
        num_classes: Width of the final ``fc_out`` layer.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """One 3x3 conv (stride 1, padding 1) + BatchNorm + ReLU + 2x2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

    def __init__(self, num_classes):
        super().__init__()
        # Attribute names and creation order define the state_dict layout, so
        # they are kept exactly as the checkpoints expect.
        self.conv_block1 = self._conv_stage(3, 16)
        self.conv_block2 = self._conv_stage(16, 64)
        self.conv_block3 = self._conv_stage(64, 256)
        self.conv_block4 = self._conv_stage(256, 1024)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc_out = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.12)

    def forward(self, x):
        """Return the ``fc_out`` outputs (one value per class) for a batch of images."""
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4):
            x = stage(x)
        pooled = self.flatten(self.pool(x))
        hidden = self.dropout(F.relu(self.fc1(pooled)))
        return self.fc_out(F.relu(self.fc2(hidden)))


# ---- 5. DEVICE ----
# Use the GPU when one is visible to torch, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f" Using device: {DEVICE}")

# ---- 6. DATA ----
# Training pipeline: random crop/flip/rotation/colour jitter for augmentation,
# then tensor conversion and per-channel normalisation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(CONFIG["image_size"], scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation pipeline: deterministic resize + centre crop, same normalisation.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(CONFIG["image_size"]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# One ImageFolder per split under LOCAL_DATA_ROOT.
train_data = datasets.ImageFolder(os.path.join(LOCAL_DATA_ROOT, "train"), transform=train_transform)
val_data = datasets.ImageFolder(os.path.join(LOCAL_DATA_ROOT, "val"), transform=val_transform)

# Only the training loader shuffles.
train_loader = DataLoader(train_data, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=2, pin_memory=True)

# ---- 7. MODEL/LOSS/OPTIMIZER ----
# The head width follows the number of class folders found in the train split.
model = CNN(num_classes=len(train_data.classes)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=CONFIG["learning_rate"], weight_decay=CONFIG["weight_decay"])
# Stepped with the validation loss (mode='min'); scales the LR by 0.1 on a
# plateau, never going below CONFIG["min_lr"].
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, min_lr=CONFIG["min_lr"])

# ---- 8. TRAINING LOOP ----
# Trains for up to CONFIG["num_epochs"] epochs, checkpointing the weights with
# the lowest validation loss and stopping early after `patience` epochs
# without improvement. The best checkpoint is reloaded after the loop.
train_losses, val_losses, train_accs, val_accs, epoch_times = [], [], [], [], []
best_val_loss = float('inf')
epochs_no_improve = 0
patience = 5

print(" Starting training...\n")
total_start_time = time.time()

for epoch in range(CONFIG["num_epochs"]):
    epoch_start = time.time()
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # loss.item() is a batch mean; weight it by the batch size so the
        # epoch average is correct when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # Progress line roughly five times per epoch.
        if (batch_idx + 1) % max(1, len(train_loader)//5) == 0:
            print(f" Epoch [{epoch+1}/{CONFIG['num_epochs']}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

    avg_train_loss = running_loss / total
    train_acc = 100 * correct / total

    # Validation
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    avg_val_loss = val_loss / val_total
    val_acc = 100 * val_correct / val_total

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    scheduler.step(avg_val_loss)
    epoch_time = time.time() - epoch_start
    epoch_times.append(epoch_time)

    improved = avg_val_loss < best_val_loss
    if improved:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), CONFIG["best_model_path"])
    else:
        epochs_no_improve += 1

    print(f"\n Epoch [{epoch+1}/{CONFIG['num_epochs']}] Summary:")
    print(f"   Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"   Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}% {'*IMPROVED*' if improved else ''}")
    print(f"   Patience: {epochs_no_improve}/{patience}")

    if epochs_no_improve >= patience:
        print(f"\nEarly stopping triggered after {patience} epochs with no improvement. Loading best model.")
        break

# Reload the best checkpoint whether or not early stopping fired, so the
# evaluation below describes the weights that were saved, not the last epoch.
# The guard skips the reload when this run never wrote a checkpoint (zero
# epochs, or a validation loss that never compared as an improvement).
if best_val_loss < float('inf'):
    model.load_state_dict(torch.load(CONFIG["best_model_path"], map_location=DEVICE))
else:
    print(" No checkpoint was saved during this run; keeping the current weights.")

# ---- 9. SAVE METRICS ----
# Per-epoch histories collected by the training loop (one entry per completed
# epoch) plus the lowest validation loss seen.
metrics = {
    "train_loss": train_losses,
    "val_loss": val_losses,
    "train_acc": train_accs,
    "val_acc": val_accs,
    "epoch_times": epoch_times,
    "best_val_loss": best_val_loss,
}
# Written next to the checkpoint in this run's metrics folder.
with open(CONFIG["metrics_json"], "w") as f:
    json.dump(metrics, f, indent=4)

print(f"\nMetrics saved to {CONFIG['metrics_json']}")

# ---- 10. PLOT CURVES ----
# Two side-by-side panels (loss, accuracy), each showing the train and val
# history per epoch; the figure is saved into this run's metrics folder.
plt.style.use('seaborn-v0_8-darkgrid')
plt.figure(figsize=(12,5))

# (subplot index, title, y-axis label, (series, legend label) pairs...)
curve_panels = (
    (1, 'Loss Curve', 'Loss', (train_losses, 'Train Loss'), (val_losses, 'Val Loss')),
    (2, 'Accuracy Curve', 'Accuracy (%)', (train_accs, 'Train Acc'), (val_accs, 'Val Acc')),
)
for panel_idx, panel_title, panel_ylabel, *panel_curves in curve_panels:
    plt.subplot(1, 2, panel_idx)
    for series, curve_label in panel_curves:
        plt.plot(series, label=curve_label)
    plt.legend()
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(panel_ylabel)

plt.tight_layout()
plt.savefig(os.path.join(RUN_METRICS_DIR, "training_curves.png"))
plt.show()

# ---- 11. CONFUSION MATRIX & REPORT ----
# Collect predictions over the validation split with the current model weights.
all_preds, all_labels = [], []
model.eval()
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Pin the label set to every training class. Without it sklearn infers the
# classes from the data, and a class that never occurs in the validation
# labels or predictions would make the matrix/report disagree with the class
# names passed alongside.
class_names = train_data.classes
class_indices = list(range(len(class_names)))

cm = confusion_matrix(all_labels, all_preds, labels=class_indices)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
fig, ax = plt.subplots(figsize=(10, 10))
disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation=45)
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig(os.path.join(RUN_METRICS_DIR, "confusion_matrix.png"))
plt.show()

report = classification_report(all_labels, all_preds, labels=class_indices,
                               target_names=class_names, zero_division=0)
with open(os.path.join(RUN_METRICS_DIR, "classification_report.txt"), "w") as f:
    f.write(report)
print("\nClassification Report:")
print(report)

print("\nTraining complete! All metrics & plots saved.")

In [None]:
# ========================================
#   Feature Extraction (CustomCNN v2 Only)
# ========================================

import os
import torch
import pickle
from torch import nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

# -------- Configuration --------
# Use the GPU when one is visible to torch, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
IMAGE_SIZE = (224, 224)  # centre-crop size used by the transform below
NUM_CLASSES = 20  # default head width of CustomCNN in this cell
DATASET_PATH = "/content/cnn_set"
FEATURE_SAVE_DIR = "/content/drive/MyDrive/ImageRetrievalProject/features"
os.makedirs(FEATURE_SAVE_DIR, exist_ok=True)
print(f"Using device: {DEVICE}")

# -------- Model Weights Paths (MODIFIED) --------
# Only the custom CNN is listed. "path" is the checkpoint written by the V2
# training cell (run folder customcnn_v2_50epoch); "output_filename" is a
# distinct name for the saved features.
WEIGHTS = {
    "customcnn": {
        "path": "/content/drive/MyDrive/ImageRetrievalProject/metrics/customcnn_v2_50epoch/best_model.pth",
        "output_filename": "customcnn_v2_features.pkl"
    }
}

# -------- Image Transforms --------
# Deterministic preprocessing (no augmentation): resize, centre-crop to
# IMAGE_SIZE, convert to a tensor, then normalise per channel. These are the
# same steps and constants as val_transform in the V2 training cell.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# -------- Custom Dataset to include image paths --------
class ImagePathDataset(datasets.ImageFolder):
    """ImageFolder variant whose items also carry the image's file path.

    Items are ``(image, target, path)`` instead of the base class's
    ``(image, target)``, so extracted features can be traced back to the
    file they came from.
    """

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position in ``self.samples``.

        Returns:
            Tuple ``(img, target, path)``: the loaded image after
            ``self.transform`` (if set), the class index after
            ``self.target_transform`` (if set), and the source path.
        """
        path, target = self.samples[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        # Honour target_transform like the base class does; it is None in this
        # notebook, so the returned label is unchanged for the existing caller.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, path

# -------- Datasets & Dataloaders --------
# One dataset over everything under DATASET_PATH; shuffle=False keeps the
# feature rows in the same order as the dataset's sample list.
# NOTE(review): the training cell in this notebook expects "/content/cnn_set"
# to contain "train/" and "val/" sub-folders. If that is the layout here,
# ImageFolder will treat those two folders as the classes, so the saved
# labels would be split ids rather than image classes — confirm.
full_dataset = ImagePathDataset(DATASET_PATH, transform=transform)
full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
print(f"Found {len(full_dataset)} images in {DATASET_PATH}")


# ========================================
#   1. STABLE Custom CNN Definition
# ========================================
class CustomCNN(nn.Module):
    """Four-stage convolutional classifier with a three-layer dense head.

    Each stage is Conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool, widening the
    channels 3 -> 16 -> 64 -> 256 -> 1024. The result is globally
    average-pooled, flattened, and passed through
    fc1 (1024 -> 512), fc2 (512 -> 128) and fc_out (128 -> num_classes).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        # Attribute names and creation order are kept as-is so checkpoint
        # (state_dict) keys stay the same.
        self.conv_block1 = self._conv_stage(3, 16)
        self.conv_block2 = self._conv_stage(16, 64)
        self.conv_block3 = self._conv_stage(64, 256)
        self.conv_block4 = self._conv_stage(256, 1024)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc_out = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.15)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2, 2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        for stage in (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4):
            x = stage(x)
        x = self.flatten(self.pool(x))
        # Dropout sits between the two hidden dense layers only.
        x = self.dropout(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)


# ========================================
#   2. Unified Feature Extractor Class
# ========================================
class FeatureExtractor(nn.Module):
    """Wrap a trained classifier and expose its embedding layer as features.

    The network is rebuilt from ``model_name``, its checkpoint is looked up in
    ``WEIGHTS[model_name]["path"]``, and the layers after ``fc2`` are dropped,
    so ``forward`` yields the ``fc2`` output (no activation applied after it).

    Args:
        model_name: Key into ``WEIGHTS``; only ``"customcnn"`` is supported.
        require_weights: When True, a missing or unloadable checkpoint raises
            instead of only printing a message. Defaults to False, which keeps
            the original best-effort behaviour.

    Attributes:
        weights_loaded: True only if the checkpoint was found and loaded
            without error. When False, features come from an untrained
            network and callers should not trust them.

    Raises:
        ValueError: If ``model_name`` is not a supported model.
        FileNotFoundError: If ``require_weights`` is True and no checkpoint exists.
        RuntimeError: If ``require_weights`` is True and loading fails.
    """

    def __init__(self, model_name, require_weights=False):
        super().__init__()
        self.model_name = model_name
        self.weights_loaded = False

        full_model = self._load_full_model(model_name)

        weight_path = WEIGHTS.get(model_name, {}).get("path")
        if weight_path and os.path.exists(weight_path):
            try:
                # NOTE(review): torch.load unpickles the file, so only load
                # checkpoints from a trusted source; consider weights_only=True
                # where the installed torch version supports it.
                full_model.load_state_dict(torch.load(weight_path, map_location=DEVICE))
                self.weights_loaded = True
                print(f"Successfully loaded weights for {model_name} from {weight_path}")
            except Exception as e:
                if require_weights:
                    raise RuntimeError(f"Could not load weights for {model_name} from {weight_path}") from e
                print(f"Error loading weights for {model_name}: {e}.")
        else:
            if require_weights:
                raise FileNotFoundError(f"Weights not found for {model_name}: {weight_path}")
            print(f"Warning: Weights not found for {model_name}.")

        if not self.weights_loaded:
            # Make the consequence explicit: everything extracted from here on
            # comes from a network without the trained checkpoint.
            print(f"Warning: {model_name} features will NOT reflect the trained checkpoint.")

        # eval() disables dropout and freezes batch-norm statistics for extraction.
        self.feature_model = self._create_feature_model(full_model).to(DEVICE).eval()

    def _load_full_model(self, name):
        """Instantiate the untrained architecture registered under ``name``."""
        if name == "customcnn":
            return CustomCNN(num_classes=NUM_CLASSES)
        # Other model definitions are not needed since we only run customcnn
        raise ValueError(f"Unknown model name: {name}")

    def _create_feature_model(self, full_model):
        """Return the part of ``full_model`` that produces the feature vector.

        Raises:
            ValueError: If no feature head is defined for ``self.model_name``
                (previously this fell through and returned None).
        """
        if self.model_name == "customcnn":
            # Same path as CustomCNN.forward up to fc2, with the functional
            # ReLU replaced by a module; fc_out (the classifier) is left out.
            return nn.Sequential(
                full_model.conv_block1, full_model.conv_block2, full_model.conv_block3, full_model.conv_block4,
                full_model.pool, full_model.flatten,
                full_model.fc1, nn.ReLU(), full_model.dropout, full_model.fc2
            )
        raise ValueError(f"No feature head defined for model: {self.model_name}")

    def forward(self, x):
        """Extract features for a batch and return them as a NumPy array on the CPU."""
        with torch.no_grad():
            features = self.feature_model(x.to(DEVICE))
            return features.cpu().numpy()


# ========================================
#   3. Main Execution Loop
# ========================================
def main():
    """Extract and pickle features for every model listed in ``WEIGHTS``.

    For each model, the whole ``full_loader`` is run through a
    ``FeatureExtractor`` and a dict with keys ``'features'`` (stacked array),
    ``'labels'`` and ``'paths'`` is written to ``FEATURE_SAVE_DIR``.
    """
    for model_name, model_cfg in WEIGHTS.items():
        print(f"\n--- Processing model: {model_name.upper()} ---")

        extractor = FeatureExtractor(model_name)

        feature_batches = []
        label_list = []
        path_list = []
        progress = tqdm(full_loader, desc=f"Extracting {model_name} features")
        for inputs, labels, paths in progress:
            feature_batches.append(extractor(inputs))
            label_list.extend(labels.numpy())
            path_list.extend(paths)

        # One row per image, in loader order.
        feature_matrix = np.vstack(feature_batches)

        default_name = f"{model_name}_features.pkl"
        save_path = os.path.join(FEATURE_SAVE_DIR, model_cfg.get("output_filename", default_name))

        payload = {
            'features': feature_matrix,
            'labels': label_list,
            'paths': path_list
        }
        with open(save_path, 'wb') as out_file:
            pickle.dump(payload, out_file)
        print(f" Saved {len(feature_matrix)} features for {model_name} to {save_path}")

if __name__ == "__main__":
    main()

In [None]:
# ==================================
# TRAINING SCRIPT PRETRAINED MODELS
# ==================================
import os
import time
import json
import subprocess
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

# -----------------------
# User Config
# -----------------------
MODEL_NAME = "resnet50"       # model to train: "resnet50", "efficientnet_b0" or "vgg16"
NUM_EPOCHS = 30               # upper bound on epochs; training may stop earlier via the prompt
BATCH_SIZE = 64
LEARNING_RATE = 1e-5          # initial Adam learning rate
WEIGHT_DECAY = 1e-4           # Adam weight_decay
MIN_LR = 1e-7                 # floor for the ReduceLROnPlateau scheduler
PATIENCE = 4                  # epochs without val-loss improvement before LR drop / stop prompt
IMAGE_SIZE = (224, 224)       # size passed to the resize/crop transforms

# GCS dataset info (downloaded only if LOCAL_DATA_ROOT does not exist yet)
GCS_BUCKET_NAME = "similar_dataset"
GCS_DATASET_PATH = "cnn_set"
LOCAL_DATA_ROOT = "/content/cnn_set"

# Where Google Drive appears once mounted; used to decide where outputs go.
DRIVE_CANDIDATE = Path("/content/drive/MyDrive")

# -----------------------
# 1) Mount Google Drive (best effort) and choose DRIVE_ROOT
# -----------------------
try:
    # Imported inside the try so the cell still runs outside Colab.
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    print(" Google Drive mount requested.")
except Exception as mount_error:
    print("google.colab.drive not available or mount popup skipped:", mount_error)

# Outputs go to Drive when its mount point exists, otherwise to local /content.
if not DRIVE_CANDIDATE.exists():
    DRIVE_ROOT = Path("/content")
    print(" Google Drive not found; using local path", DRIVE_ROOT)
else:
    DRIVE_ROOT = DRIVE_CANDIDATE
    print("Google Drive available:", DRIVE_ROOT)

# -----------------------
# 2) Download dataset from GCS if missing
# -----------------------
def download_dataset_from_gcs(bucket_name, gcs_path, local_path):
    """Copy ``gs://bucket_name/gcs_path`` from GCS unless ``local_path`` already exists.

    The copy targets the parent directory of ``local_path`` (``/content`` for
    the path used in this notebook), matching the other download helper in
    this notebook. The command is passed to ``gsutil`` as an argument list,
    so bucket and path values are never interpreted by a shell.

    Args:
        bucket_name: Name of the GCS bucket.
        gcs_path: Object prefix inside the bucket to copy recursively.
        local_path: Where the dataset is expected to end up locally.

    Returns:
        None.

    Raises:
        subprocess.CalledProcessError: If ``gsutil`` exits with a non-zero status.
        FileNotFoundError: If the ``gsutil`` executable cannot be found.
    """
    if os.path.exists(local_path):
        print(f"Dataset already exists at {local_path}, skipping GCS download.")
        return

    # normpath drops a trailing slash so dirname really yields the parent.
    dest_dir = os.path.dirname(os.path.normpath(local_path)) or "/content"
    os.makedirs(dest_dir, exist_ok=True)
    source_uri = f"gs://{bucket_name}/{gcs_path}"
    # The trailing separator marks the destination as a directory.
    dest_arg = os.path.join(dest_dir, "")
    print(f" Downloading dataset from {source_uri} to {dest_arg} ...")
    cmd = ["gsutil", "-m", "cp", "-r", source_uri, dest_arg]
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except FileNotFoundError:
        print("Error downloading dataset from GCS: 'gsutil' executable not found.")
        print("Make sure you have 'gsutil' available and the bucket is public/your account has access.")
        raise
    except subprocess.CalledProcessError as e:
        print("Error downloading dataset from GCS. stderr:")
        print(e.stderr)
        print("Make sure you have 'gsutil' available and the bucket is public/your account has access.")
        raise

    # The copy is expected to create <dest_dir>/<last component of gcs_path>.
    if os.path.exists(local_path):
        print("Dataset downloaded to", local_path)
    else:
        print("Download finished but expected path not found:", local_path)


# Best effort: a failed download is tolerated here because the layout check
# below stops the run if the data really is missing.
try:
    download_dataset_from_gcs(
        GCS_BUCKET_NAME,
        GCS_DATASET_PATH,
        LOCAL_DATA_ROOT,
    )
except Exception:
    print("Proceeding — if dataset is not present locally, script will fail when loading data.")

# -----------------------
# 3) Verify dataset layout
# -----------------------
# Both split folders must exist under LOCAL_DATA_ROOT before any loading starts.
train_dir, val_dir = (os.path.join(LOCAL_DATA_ROOT, split_name) for split_name in ("train", "val"))
if not all(os.path.isdir(split_dir) for split_dir in (train_dir, val_dir)):
    raise FileNotFoundError(f"Expected dataset at {train_dir} and {val_dir}. Please ensure dataset exists or GCS download succeeded.")

# -----------------------
# 4) Transforms & Dataloaders
# -----------------------
# Training: random crop/flip/rotation/colour augmentation, then normalisation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Validation: deterministic resize + crop with the same normalisation.
val_transform = transforms.Compose([
    transforms.Resize(size=IMAGE_SIZE),
    transforms.CenterCrop(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(root=val_dir, transform=val_transform)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# The class count comes from the training folders and sizes the model head.
num_classes = len(train_dataset.classes)
print(f"Found {num_classes} classes. Example classes: {train_dataset.classes[:8]}")

# -----------------------
# 5) Model factory (replace head)
# -----------------------
def _load_pretrained(builder_name, weights_enum_name):
    """Instantiate a torchvision model with its default pretrained weights.

    Uses the weights enum when the installed torchvision provides it and
    falls back to the legacy ``pretrained=True`` flag otherwise, so every
    architecture gets the same version handling.
    """
    builder = getattr(models, builder_name)
    weights_enum = getattr(models, weights_enum_name, None)
    if weights_enum is None:
        # Older torchvision releases have no weights enums.
        return builder(pretrained=True)
    return builder(weights=weights_enum.DEFAULT)


def get_model(name, num_classes, device):
    """Build a pretrained backbone whose classifier head has ``num_classes`` outputs.

    Args:
        name: Model name, case-insensitive. Accepted values are "resnet50",
            "efficientnet_b0" (also "efficientnetb0", "efficientnet-b0") and
            "vgg16" (also "vgg_16").
        num_classes: Output size of the replaced final linear layer.
        device: Device the model is moved to before being returned.

    Returns:
        The model on ``device`` with a freshly initialised final layer.

    Raises:
        ValueError: If ``name`` is not one of the supported models.
        RuntimeError: If the classifier does not have the expected structure.
    """
    name = name.lower()
    if name == "resnet50":
        model = _load_pretrained("resnet50", "ResNet50_Weights")
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif name in ("efficientnet_b0", "efficientnetb0", "efficientnet-b0"):
        model = _load_pretrained("efficientnet_b0", "EfficientNet_B0_Weights")
        # The classifier may be Sequential(..., Linear) or a bare Linear
        # depending on the torchvision version.
        if not hasattr(model, "classifier"):
            raise RuntimeError("EfficientNet model missing 'classifier' attribute.")
        cl = model.classifier
        if isinstance(cl, nn.Sequential) and isinstance(cl[-1], nn.Linear):
            cl[-1] = nn.Linear(cl[-1].in_features, num_classes)
        elif isinstance(cl, nn.Linear):
            model.classifier = nn.Linear(cl.in_features, num_classes)
        else:
            raise RuntimeError("Unexpected EfficientNet classifier structure.")
    elif name in ("vgg16", "vgg_16"):
        model = _load_pretrained("vgg16", "VGG16_Weights")
        # The classifier is expected to be a Sequential ending in a Linear layer.
        if isinstance(model.classifier, nn.Sequential) and isinstance(model.classifier[-1], nn.Linear):
            in_f = model.classifier[-1].in_features
            model.classifier[-1] = nn.Linear(in_f, num_classes)
        else:
            raise RuntimeError("Unexpected VGG classifier structure.")
    else:
        raise ValueError("Unknown model name: " + name)
    return model.to(device)

# -----------------------
# 6) Device and instantiate model
# -----------------------
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)
# The head size comes from the number of class folders in the training set.
model = get_model(MODEL_NAME, num_classes, DEVICE)
print("Model created:", MODEL_NAME)

# -----------------------
# 7) Output directories & model summary
# -----------------------
# All artefacts of this run live under <DRIVE_ROOT>/metrics/<model>_<epochs>epoch.
# The folder name depends only on MODEL_NAME and NUM_EPOCHS, so re-running with
# the same settings reuses the same folder and overwrites its files.
# NOTE(review): other cells in this notebook use paths under
# ".../MyDrive/ImageRetrievalProject/"; this one writes directly under
# DRIVE_ROOT/metrics — confirm that the different location is intended.
RUN_METRICS_DIR = Path(DRIVE_ROOT) / "metrics" / f"{MODEL_NAME}_{NUM_EPOCHS}epoch"
RUN_METRICS_DIR.mkdir(parents=True, exist_ok=True)
BEST_MODEL_PATH = RUN_METRICS_DIR / "best_model.pth"
FINAL_MODEL_PATH = RUN_METRICS_DIR / f"{MODEL_NAME}_final.pth"
METRICS_JSON = RUN_METRICS_DIR / "metrics.json"
MODEL_SUMMARY_PATH = RUN_METRICS_DIR / "model_summary.txt"

# Store the printed module structure next to the metrics for later reference.
with open(MODEL_SUMMARY_PATH, "w") as f:
    f.write(str(model))
print("Saving outputs to:", RUN_METRICS_DIR)

# -----------------------
# 8) Optimizer / Loss / Scheduler
# -----------------------
criterion = nn.CrossEntropyLoss()
# Every parameter is optimised (nothing is frozen in this cell).
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Multiply the LR by 0.1 after PATIENCE epochs without a lower validation
# loss, never going below MIN_LR.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=PATIENCE, min_lr=MIN_LR)

# -----------------------
# 9) Training loop
# -----------------------
# Per-epoch history, later plotted and written to the metrics JSON.
train_losses, val_losses = [], []
train_accs, val_accs = [], []
epoch_times = []
# Best validation loss seen so far and the (1-based) epoch it occurred in.
best_val_loss = float("inf")
best_epoch = None
# Consecutive epochs without a new best validation loss.
epochs_no_improve = 0

print_interval = max(1, len(train_loader) // 10)  # ~10 prints per epoch

print("Starting training...")

total_start = time.time()
for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean, so weight it by the batch size to
        # get a per-sample average at the end of the epoch.
        running_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # Progress line roughly ten times per epoch and always on the last batch.
        if (batch_idx + 1) % print_interval == 0 or (batch_idx + 1) == len(train_loader):
            print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

    # Guard against an empty loader to avoid dividing by zero.
    avg_train_loss = running_loss / total if total > 0 else 0.0
    train_acc = 100.0 * correct / total if total > 0 else 0.0

    # Validation
    model.eval()
    running_val_loss = 0.0
    val_correct = 0
    val_total = 0
    # NOTE(review): these two lists are filled every epoch but are not read
    # anywhere in the visible part of this cell — confirm they are needed.
    all_val_preds, all_val_labels = [], []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_val_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)
            all_val_preds.extend(preds.cpu().numpy())
            all_val_labels.extend(labels.cpu().numpy())

    # With no validation samples the loss is inf, which can never count as an improvement.
    avg_val_loss = running_val_loss / val_total if val_total > 0 else float('inf')
    val_acc = 100.0 * val_correct / val_total if val_total > 0 else 0.0

    # Stored as plain floats so the history can be dumped to JSON.
    train_losses.append(float(avg_train_loss))
    val_losses.append(float(avg_val_loss))
    train_accs.append(float(train_acc))
    val_accs.append(float(val_acc))

    # Scheduler step and LR logging
    prev_lr = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val_loss)
    curr_lr = optimizer.param_groups[0]['lr']
    if curr_lr < prev_lr:
        print(f"Learning rate reduced from {prev_lr:.8f} to {curr_lr:.8f}")

    # ETA = remaining epochs x mean duration of the epochs run so far.
    epoch_time = time.time() - epoch_start
    epoch_times.append(epoch_time)
    avg_epoch_time = float(np.mean(epoch_times))
    remaining_epochs = NUM_EPOCHS - (epoch + 1)
    eta_seconds = int(remaining_epochs * avg_epoch_time)

    # Checkpoint only on a strictly lower validation loss.
    improved = avg_val_loss < best_val_loss
    if improved:
        best_val_loss = float(avg_val_loss)
        best_epoch = epoch + 1
        epochs_no_improve = 0
        torch.save(model.state_dict(), BEST_MODEL_PATH)
    else:
        epochs_no_improve += 1

    improvement_str = "  New best model saved!" if improved else ""
    print(f"\nEpoch [{epoch+1}/{NUM_EPOCHS}] Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val   Loss: {avg_val_loss:.4f} | Val   Acc: {val_acc:.2f}%{improvement_str}")
    print(f"  Epoch Time: {epoch_time:.2f}s | ETA: ~{eta_seconds//60}m {eta_seconds%60}s")
    print(f"  Patience Counter: {epochs_no_improve}/{PATIENCE}\n")

    # Early-stopping prompt: after PATIENCE stale epochs the user decides
    # whether to stop. input() blocks, so the cell cannot finish unattended
    # once this point is reached.
    if epochs_no_improve >= PATIENCE:
        choice = input(f" Validation loss hasn't improved for {PATIENCE} epochs. Type 'stop' to end training or press Enter to continue: ")
        if choice.lower() == 'stop':
            print("Early stopping chosen. Restoring best model and stopping training.")
            # NOTE(review): the file may also exist from an earlier run that
            # used the same output folder — confirm it belongs to this run.
            if BEST_MODEL_PATH.exists():
                model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
            break
        else:
            # Any answer other than 'stop' continues and restarts the counter.
            print("Continuing training (you pressed Enter). Patience counter reset.")
            epochs_no_improve = 0

# end training
total_time = time.time() - total_start
print(f"Training finished in {total_time/60:.2f} minutes. Best epoch: {best_epoch}, best_val_loss: {best_val_loss:.6f}")

# ensure best weights loaded
# RUN_METRICS_DIR depends only on MODEL_NAME and NUM_EPOCHS, so a
# best_model.pth left by an earlier run can already be on disk. Restore the
# checkpoint only if this run actually wrote one (best_epoch is set together
# with the save); otherwise stale weights would be saved as the "final" model.
if best_epoch is None:
    print("Warning: no checkpoint was saved during this run; keeping the current weights.")
elif BEST_MODEL_PATH.exists():
    try:
        model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
        print(f"Loaded best model from epoch {best_epoch}.")
    except Exception as e:
        # Best effort: fall through and save whatever weights are in memory.
        print("Warning: could not load best checkpoint:", e)

# save final model
torch.save(model.state_dict(), FINAL_MODEL_PATH)
print("Saved final model to:", FINAL_MODEL_PATH)

# -----------------------
# 10) Final evaluation & save metrics/plots
# -----------------------
# Re-run the validation set through the model currently in memory and collect
# predictions and ground-truth labels.
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(inputs)
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# confusion matrix + classification report
if len(all_labels) > 0:
    # Pass the full label set explicitly. Without it scikit-learn infers the
    # labels from the data, and a class missing from both labels and
    # predictions makes the label count disagree with the class names
    # (classification_report then raises).
    label_ids = list(range(len(train_dataset.classes)))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=train_dataset.classes)
    # Explicit 10x10 axes, as in the other training cell, so the tick labels stay readable.
    fig, ax = plt.subplots(figsize=(10, 10))
    disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation=45)
    cm_path = RUN_METRICS_DIR / "confusion_matrix.png"
    ax.set_title("Confusion Matrix")
    fig.savefig(cm_path, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print("Saved confusion matrix to:", cm_path)

    # zero_division=0 reports 0 instead of warning for classes without
    # predictions, matching the other training cell.
    report = classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=train_dataset.classes,
        digits=4,
        zero_division=0,
    )
    report_path = RUN_METRICS_DIR / "classification_report.txt"
    with open(report_path, "w") as f:
        f.write(report)
    print("Saved classification report to:", report_path)
    print("\nClassification report:\n", report)
else:
    print("No validation labels found; skipping confusion matrix and report.")

# save loss/accuracy curves
# Two panels side by side: loss on the left, accuracy on the right, with one
# point per completed epoch.
epochs_range = range(1, len(train_losses) + 1)
curves_fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(epochs_range, train_losses, label='Train Loss')
loss_ax.plot(epochs_range, val_losses, label='Val Loss')
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Loss Curve")
loss_ax.legend()

acc_ax.plot(epochs_range, train_accs, label='Train Acc')
acc_ax.plot(epochs_range, val_accs, label='Val Acc')
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy (%)")
acc_ax.set_title("Accuracy Curve")
acc_ax.legend()

# Write the file before showing the figure.
lossacc_path = RUN_METRICS_DIR / "training_curves.png"
curves_fig.savefig(lossacc_path, bbox_inches='tight')
plt.show()
print("Saved training curves to:", lossacc_path)

# save epoch times and metrics json
# Per-epoch durations are stored as a NumPy array of their own ...
epoch_times_path = RUN_METRICS_DIR / "epoch_times.npy"
np.save(epoch_times_path, np.asarray(epoch_times))
print("Saved epoch times to:", epoch_times_path)

# ... and together with the full training history as JSON.
metrics = dict(
    train_losses=train_losses,
    val_losses=val_losses,
    train_accuracies=train_accs,
    val_accuracies=val_accs,
    epoch_times=epoch_times,
    best_val_loss=float(best_val_loss),
    best_epoch=best_epoch,
    model_name=MODEL_NAME,
    # May be lower than NUM_EPOCHS if training was stopped early.
    num_epochs_run=len(train_losses),
    class_labels=train_dataset.classes,
)
with open(METRICS_JSON, "w") as f:
    f.write(json.dumps(metrics, indent=2))
print("Saved metrics JSON to:", METRICS_JSON)

print("All done. Files saved in:", RUN_METRICS_DIR)
