In [1]:
import os
import json
import pickle
import random
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import cv2
import numpy as np
from torchvision.models import ResNet18_Weights
from sklearn.metrics import precision_recall_curve, average_precision_score
from torchvision.models import vision_transformer
# ===============================
# Configuration and Parameters
# ===============================

# Paths
DATASET_DIR = 'C:/Users/nicla/DTU/ComputerVision/3rd/Potholes'
ANNOTATED_IMAGES_DIR = os.path.join(DATASET_DIR, 'annotated-images')
TRAINING_DATA_FILE = os.path.join(DATASET_DIR, 'training_data.pkl')

# Training Parameters
NUM_CLASSES = 2  # 1 object class + 1 background
BATCH_SIZE = 32
NUM_EPOCHS = 50  # Maximum number of epochs
LEARNING_RATE = 0.001
VALIDATION_SPLIT = 0.2
RANDOM_SEED = 42
PATIENCE = 5  # Early stopping patience

USE_FULL_DATA = True  # Use the full dataset or a reduced portion

# Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# ===============================
# Helper Functions
# ===============================

def save_precision_recall_curve(labels, probs, model_name):
    """
    Saves the Precision-Recall curve for a model with higher resolution.
    
    Args:
        labels (list): True labels.
        probs (list): Predicted probabilities for the positive class.
        model_name (str): Name of the model (used for saving the plot).
    """
    precision, recall, _ = precision_recall_curve(labels, probs)
    average_precision = average_precision_score(labels, probs)
    
    plt.figure(figsize=(8, 6))  # Increase the figure size if needed
    plt.plot(recall, precision, marker='.', label=f'AP = {average_precision:.2f}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve ({model_name})')
    plt.legend()
    plt.grid(True)
    
    # Save the plot with higher resolution (300 dpi)
    plot_save_path = os.path.join(DATASET_DIR, f'precision_recall_{model_name}.png')
    plt.savefig(plot_save_path, dpi=300)  # Set dpi for high resolution
    plt.close()
    print(f'Precision-Recall curve saved to {plot_save_path}')
    
# ===============================
# Task 1: Build the CNN
# ===============================

def build_model(model_type, num_classes):
    if model_type == 'simple_cnn':
        class SimpleCNN(nn.Module):
            def __init__(self):
                super(SimpleCNN, self).__init__()
                self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
                self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
                self.fc1 = nn.Linear(32 * 56 * 56, 128)
                self.fc2 = nn.Linear(128, num_classes)
                self.pool = nn.MaxPool2d(2, 2)
                self.relu = nn.ReLU()
            
            def forward(self, x):
                x = self.pool(self.relu(self.conv1(x)))
                x = self.pool(self.relu(self.conv2(x)))
                x = x.view(-1, 32 * 56 * 56)
                x = self.relu(self.fc1(x))
                x = self.fc2(x)
                return x

        return SimpleCNN()
    
    elif model_type == 'deep_cnn':
        class DeepCNN(nn.Module):
            def __init__(self):
                super(DeepCNN, self).__init__()
                self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)  # Increased channels
                self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
                self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)  # More channels in deeper layers
                self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)  # Additional convolutional layer
                self.fc1 = nn.Linear(512 * 14 * 14, 1024)  # Adjusted after more downsampling
                self.fc2 = nn.Linear(1024, 512)
                self.fc3 = nn.Linear(512, 128)
                self.fc4 = nn.Linear(128, num_classes)  # Final classification layer
                self.pool = nn.MaxPool2d(2, 2)
                self.relu = nn.ReLU()
                self.dropout = nn.Dropout(0.5)  # Dropout for regularization
                self.batch_norm1 = nn.BatchNorm2d(64)  # Batch Normalization for stability
                self.batch_norm2 = nn.BatchNorm2d(128)
                self.batch_norm3 = nn.BatchNorm2d(256)
                self.batch_norm4 = nn.BatchNorm2d(512)
        
            def forward(self, x):
                x = self.pool(self.relu(self.batch_norm1(self.conv1(x))))  # Conv -> Batch Norm -> ReLU -> Pool
                x = self.pool(self.relu(self.batch_norm2(self.conv2(x))))
                x = self.pool(self.relu(self.batch_norm3(self.conv3(x))))
                x = self.pool(self.relu(self.batch_norm4(self.conv4(x))))
                x = x.view(-1, 512 * 14 * 14)  # Flatten the output for fully connected layers
                x = self.relu(self.fc1(x))
                x = self.dropout(x)  # Apply dropout after dense layers
                x = self.relu(self.fc2(x))
                x = self.relu(self.fc3(x))
                x = self.fc4(x)  # Output layer
                return x

        return DeepCNN()
    
    elif model_type == 'resnet18':
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
        return model
    
    elif model_type == 'vgg16':
        model = models.vgg16(pretrained=True)
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(num_ftrs, num_classes)
        return model

    elif model_type == 'vit':
        # Use Vision Transformer (ViT) from torchvision
        model = vision_transformer.vit_b_16(pretrained=True)  # Load the ViT model with pre-trained weights
        
        # Replace the classifier head with a custom linear layer
        model.heads = nn.Sequential(
            nn.Linear(model.heads[0].in_features, num_classes)
        )
        
        return model
        
    else: raise ValueError("Unknown model type. Choose from 'simple_cnn', 'deep_cnn', 'resnet18', 'vgg16', 'vit'.")

# ===============================
# Task 2: Create the DataLoader
# ===============================

class ProposalDataset(Dataset):
    """
    Custom Dataset for Object Proposals.
    """
    def __init__(self, proposals, image_dir, transform=None):
        """
        Args:
            proposals (list): List of proposal dictionaries.
            image_dir (str): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.proposals = proposals
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.proposals)

    def __getitem__(self, idx):
        proposal = self.proposals[idx]
        image_filename = proposal['image_filename']
        bbox = proposal['bbox']
        label = proposal['label']

        # Load image
        image_path = os.path.join(self.image_dir, image_filename)
        image = Image.open(image_path).convert('RGB')

        # Crop the proposal region
        cropped_image = image.crop((bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']))

        # Apply transforms
        if self.transform:
            cropped_image = self.transform(cropped_image)

        return cropped_image, label

def load_data(training_data_file):
    """
    Loads the training data from a pickle file.
    """
    with open(training_data_file, 'rb') as f:
        combined_data = pickle.load(f)
    proposals = combined_data['proposals']
    ground_truths = combined_data['ground_truths']
    return proposals, ground_truths

# ===============================
# Task 3: Fine-tune the Network
# ===============================

def train_model_with_early_stopping(model, criterion, optimizer, train_loader, val_loader, model_name, patience=PATIENCE):
    best_loss = float('inf')
    patience_counter = 0
    best_model_wts = None
    
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        
        # Training phase
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
        
        train_loss = running_loss / len(train_loader.dataset)
        print(f"Training Loss: {train_loss:.4f}")

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc="Validation"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)
                
                # For precision-recall calculation
                probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
                all_probs.extend(probs)
                all_labels.extend(labels.cpu().numpy())
        
        val_loss /= len(val_loader.dataset)
        print(f"Validation Loss: {val_loss:.4f}")

        # Early stopping check
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = model.state_dict()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {patience} epochs without improvement.")
                break

    # Load the best weights
    model.load_state_dict(best_model_wts)
    
    # Save the precision-recall curve and model weights
    save_precision_recall_curve(all_labels, all_probs, model_name)
    model_save_path = os.path.join(DATASET_DIR, f'{model_name}_best_weights.pth')
    torch.save(model.state_dict(), model_save_path)
    print(f'Model saved to {model_save_path}')
    
    return model
# ===============================
# Task 4: Evaluate the Model
# ===============================

def evaluate_model(model, dataloader):
    """
    Evaluates the model's classification accuracy.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total if total > 0 else 0
    print(f'Validation Accuracy: {accuracy * 100:.2f}%')
    return accuracy

# ===============================
# Visualization Function
# ===============================

def visualize_samples(model, subset, used_dataset, full_dataset, ground_truths, num_samples=5):
    """
    Visualizes a few samples from the dataset with ground truth and predictions.

    Args:
        model: Trained PyTorch model.
        subset: Subset of the dataset (e.g., validation set).
        used_dataset: The dataset used for training/validation (either full or reduced).
        full_dataset: The original ProposalDataset.
        ground_truths: Dictionary of ground truth boxes.
        num_samples: Number of samples to visualize.
    """
    model.eval()
    samples = random.sample(range(len(subset)), num_samples)

    for idx in samples:
        # Map subset index to original dataset index
        if isinstance(used_dataset, Subset):
            reduced_idx = used_dataset.indices[idx]
        else:
            reduced_idx = idx  # When using full dataset
        proposal = full_dataset.proposals[reduced_idx]

        image_filename = proposal['image_filename']
        bbox = proposal['bbox']
        label = proposal['label']

        # Load image
        image_path = os.path.join(full_dataset.image_dir, image_filename)
        image = Image.open(image_path).convert('RGB')
        cropped_image = image.crop((bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']))

        # Apply transforms
        transform = full_dataset.transform
        transformed_image = transform(cropped_image).unsqueeze(0).to(device)

        # Forward pass
        with torch.no_grad():
            output = model(transformed_image)
            _, predicted = torch.max(output, 1)
            predicted = predicted.item()

        # Convert image for plotting
        image_np = transformed_image.squeeze(0).cpu().numpy()
        image_np = np.transpose(image_np, (1, 2, 0))
        image_np = image_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        image_np = np.clip(image_np, 0, 1)

        # Plot the image
        fig, ax = plt.subplots(1, figsize=(6, 6))
        ax.imshow(image_np)
        ax.axis('off')
        title = f"True Label: {'Object' if label == 1 else 'Background'} | Predicted: {'Object' if predicted == 1 else 'Background'}"
        ax.set_title(title)

        # Draw ground truth boxes (if available)
        gt_boxes = ground_truths.get(image_filename, [])
        for gt in gt_boxes:
            rect = patches.Rectangle((gt['xmin'], gt['ymin']),
                                     gt['xmax'] - gt['xmin'],
                                     gt['ymax'] - gt['ymin'],
                                     linewidth=2, edgecolor='green', facecolor='none')
            ax.add_patch(rect)

        plt.show()

# ===============================
# Main Execution
# ===============================

def main():
    torch.manual_seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

    # Load data
    proposals, ground_truths = load_data(TRAINING_DATA_FILE)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    full_dataset = ProposalDataset(proposals, ANNOTATED_IMAGES_DIR, transform=transform)
    
    if USE_FULL_DATA:
        used_dataset = full_dataset
    else:
        total_samples = len(full_dataset)
        reduced_sample_size = max(1, int(total_samples * 0.05))
        sampled_indices = random.sample(range(total_samples), reduced_sample_size)
        used_dataset = Subset(full_dataset, sampled_indices)

    dataset_size = len(used_dataset)
    val_size = int(dataset_size * VALIDATION_SPLIT)
    train_size = dataset_size - val_size
    train_dataset, val_dataset = random_split(used_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Define the models to train
    model_types = ['simple_cnn', 'deep_cnn', 'resnet18', 'vgg16', 'vit']
    #model_types = ['vit']

    for model_type in model_types:
        print(f"\nTraining {model_type}...\n")
        model = build_model(model_type, NUM_CLASSES).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

        # Train the model with early stopping
        train_model_with_early_stopping(model, criterion, optimizer, train_loader, val_loader, model_name=model_type)

if __name__ == "__main__":
    main()

Using device: cuda

Training simple_cnn...

Epoch 1/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:14<00:00, 14.54it/s]


Training Loss: 0.4065


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:15<00:00, 16.90it/s]


Validation Loss: 0.3477
Epoch 2/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:16<00:00, 14.17it/s]


Training Loss: 0.2935


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:15<00:00, 17.08it/s]


Validation Loss: 0.2717
Epoch 3/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:18<00:00, 13.82it/s]


Training Loss: 0.2553


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:15<00:00, 17.08it/s]


Validation Loss: 0.2460
Epoch 4/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:19<00:00, 13.60it/s]


Training Loss: 0.2241


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:15<00:00, 17.06it/s]


Validation Loss: 0.2407
Epoch 5/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:18<00:00, 13.75it/s]


Training Loss: 0.1921


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:16<00:00, 16.58it/s]


Validation Loss: 0.2282
Epoch 6/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:15<00:00, 14.36it/s]


Training Loss: 0.1603


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:16<00:00, 16.83it/s]


Validation Loss: 0.2360
Epoch 7/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:16<00:00, 14.20it/s]


Training Loss: 0.1272


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:16<00:00, 16.44it/s]


Validation Loss: 0.2381
Epoch 8/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:13<00:00, 14.73it/s]


Training Loss: 0.1006


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:16<00:00, 16.54it/s]


Validation Loss: 0.2524
Epoch 9/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:15<00:00, 14.23it/s]


Training Loss: 0.0720


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:17<00:00, 15.84it/s]


Validation Loss: 0.3089
Epoch 10/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:15<00:00, 14.33it/s]


Training Loss: 0.0547


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:16<00:00, 16.60it/s]


Validation Loss: 0.2782
Early stopping triggered after 5 epochs without improvement.
Precision-Recall curve saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_simple_cnn.png
Model saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\simple_cnn_best_weights.pth

Training deep_cnn...

Epoch 1/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:59<00:00,  9.07it/s]


Training Loss: 0.3707


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 13.98it/s]


Validation Loss: 0.3063
Epoch 2/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:57<00:00,  9.23it/s]


Training Loss: 0.2797


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.29it/s]


Validation Loss: 0.2458
Epoch 3/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:54<00:00,  9.41it/s]


Training Loss: 0.2484


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 13.89it/s]


Validation Loss: 0.2188
Epoch 4/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:54<00:00,  9.44it/s]


Training Loss: 0.2258


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 14.16it/s]


Validation Loss: 0.2146
Epoch 5/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:55<00:00,  9.37it/s]


Training Loss: 0.2069


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 14.00it/s]


Validation Loss: 0.1908
Epoch 6/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:58<00:00,  9.11it/s]


Training Loss: 0.1930


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 14.00it/s]


Validation Loss: 0.2209
Epoch 7/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:57<00:00,  9.18it/s]


Training Loss: 0.1793


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 14.15it/s]


Validation Loss: 0.1754
Epoch 8/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:56<00:00,  9.26it/s]


Training Loss: 0.1670


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 13.77it/s]


Validation Loss: 0.1707
Epoch 9/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:55<00:00,  9.39it/s]


Training Loss: 0.1513


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 13.82it/s]


Validation Loss: 0.1667
Epoch 10/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:56<00:00,  9.27it/s]


Training Loss: 0.1441


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 13.87it/s]


Validation Loss: 0.1588
Epoch 11/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:56<00:00,  9.24it/s]


Training Loss: 0.1298


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 14.11it/s]


Validation Loss: 0.1593
Epoch 12/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:55<00:00,  9.34it/s]


Training Loss: 0.1203


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 13.92it/s]


Validation Loss: 0.1759
Epoch 13/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:53<00:00,  9.51it/s]


Training Loss: 0.1087


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 13.83it/s]


Validation Loss: 0.2124
Epoch 14/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:56<00:00,  9.25it/s]


Training Loss: 0.1006


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:19<00:00, 14.17it/s]


Validation Loss: 0.1705
Epoch 15/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:50<00:00,  9.78it/s]


Training Loss: 0.0957


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.33it/s]


Validation Loss: 0.1738
Early stopping triggered after 5 epochs without improvement.
Precision-Recall curve saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_deep_cnn.png
Model saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\deep_cnn_best_weights.pth

Training resnet18...

Epoch 1/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:20<00:00, 13.45it/s]


Training Loss: 0.2198


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:17<00:00, 15.51it/s]


Validation Loss: 0.1421
Epoch 2/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:19<00:00, 13.55it/s]


Training Loss: 0.1181


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.95it/s]


Validation Loss: 0.1001
Epoch 3/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:19<00:00, 13.62it/s]


Training Loss: 0.0685


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.90it/s]


Validation Loss: 0.0942
Epoch 4/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:19<00:00, 13.60it/s]


Training Loss: 0.0390


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.93it/s]


Validation Loss: 0.0801
Epoch 5/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:19<00:00, 13.59it/s]


Training Loss: 0.0231


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.85it/s]


Validation Loss: 0.0952
Epoch 6/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:19<00:00, 13.61it/s]


Training Loss: 0.0215


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.91it/s]


Validation Loss: 0.0840
Epoch 7/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:19<00:00, 13.62it/s]


Training Loss: 0.0129


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.60it/s]


Validation Loss: 0.0829
Epoch 8/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:19<00:00, 13.52it/s]


Training Loss: 0.0105


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.85it/s]


Validation Loss: 0.0862
Epoch 9/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [01:19<00:00, 13.50it/s]


Training Loss: 0.0076


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:18<00:00, 14.49it/s]


Validation Loss: 0.0807
Early stopping triggered after 5 epochs without improvement.
Precision-Recall curve saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_resnet18.png
Model saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\resnet18_best_weights.pth

Training vgg16...





Epoch 1/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:13<00:00,  5.57it/s]


Training Loss: 0.2544


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 11.23it/s]


Validation Loss: 0.1668
Epoch 2/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:13<00:00,  5.59it/s]


Training Loss: 0.1436


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 11.00it/s]


Validation Loss: 0.1285
Epoch 3/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:17<00:00,  5.48it/s]


Training Loss: 0.0984


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 11.07it/s]


Validation Loss: 0.1215
Epoch 4/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:15<00:00,  5.52it/s]


Training Loss: 0.0672


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 10.93it/s]


Validation Loss: 0.0963
Epoch 5/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:15<00:00,  5.51it/s]


Training Loss: 0.0434


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:25<00:00, 10.62it/s]


Validation Loss: 0.1151
Epoch 6/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:16<00:00,  5.49it/s]


Training Loss: 0.0293


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 11.11it/s]


Validation Loss: 0.1026
Epoch 7/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:13<00:00,  5.59it/s]


Training Loss: 0.0234


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 11.17it/s]


Validation Loss: 0.0912
Epoch 8/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:15<00:00,  5.54it/s]


Training Loss: 0.0166


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 11.05it/s]


Validation Loss: 0.1009
Epoch 9/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:11<00:00,  5.63it/s]


Training Loss: 0.0127


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 10.81it/s]


Validation Loss: 0.1177
Epoch 10/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:11<00:00,  5.64it/s]


Training Loss: 0.0136


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 11.19it/s]


Validation Loss: 0.1025
Epoch 11/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:08<00:00,  5.73it/s]


Training Loss: 0.0104


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 11.21it/s]


Validation Loss: 0.1019
Epoch 12/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [03:11<00:00,  5.63it/s]


Training Loss: 0.0059


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:24<00:00, 11.06it/s]


Validation Loss: 0.1148
Early stopping triggered after 5 epochs without improvement.
Precision-Recall curve saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_vgg16.png
Model saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\vgg16_best_weights.pth

Training vit...





Epoch 1/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [04:12<00:00,  4.28it/s]


Training Loss: 0.1942


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:27<00:00,  9.68it/s]


Validation Loss: 0.1290
Epoch 2/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [04:04<00:00,  4.41it/s]


Training Loss: 0.1097


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:27<00:00,  9.84it/s]


Validation Loss: 0.1087
Epoch 3/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [04:07<00:00,  4.36it/s]


Training Loss: 0.0742


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:27<00:00,  9.87it/s]


Validation Loss: 0.0951
Epoch 4/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [04:14<00:00,  4.24it/s]


Training Loss: 0.0528


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:28<00:00,  9.51it/s]


Validation Loss: 0.1019
Epoch 5/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [04:06<00:00,  4.37it/s]


Training Loss: 0.0392


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:27<00:00,  9.96it/s]


Validation Loss: 0.1048
Epoch 6/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [04:00<00:00,  4.50it/s]


Training Loss: 0.0270


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:26<00:00, 10.05it/s]


Validation Loss: 0.1117
Epoch 7/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [04:09<00:00,  4.33it/s]


Training Loss: 0.0184


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:27<00:00,  9.97it/s]


Validation Loss: 0.1040
Epoch 8/50


Training: 100%|████████████████████████████████████████████████████████████████████| 1080/1080 [04:04<00:00,  4.41it/s]


Training Loss: 0.0217


Validation: 100%|████████████████████████████████████████████████████████████████████| 270/270 [00:27<00:00,  9.88it/s]


Validation Loss: 0.1256
Early stopping triggered after 5 epochs without improvement.
Precision-Recall curve saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_vit.png
Model saved to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\vit_best_weights.pth


In [4]:
import os
import json
import pickle
import torch
import torch.nn as nn
from torchvision import transforms, models
from sklearn.metrics import precision_recall_curve, average_precision_score
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from torchvision.ops import nms
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# ===============================
# Paths and Configuration
# ===============================

DATASET_DIR = 'C:/Users/nicla/DTU/ComputerVision/3rd/Potholes'
ANNOTATED_IMAGES_DIR = os.path.join(DATASET_DIR, 'annotated-images')
TRAINING_DATA_FILE = os.path.join(DATASET_DIR, 'training_data.pkl')
PROPOSALS_FILE = os.path.join(DATASET_DIR, 'selective_search_proposals_fast.json')
SPLITS_FILE = os.path.join(DATASET_DIR, 'splits.json')
MODELS_DIR = DATASET_DIR
BATCH_SIZE = 32
NUM_CLASSES = 2
IOU_THRESHOLD_NMS = 0.3
IOU_THRESHOLD_EVAL = 0.5
CONFIDENCE_THRESHOLD = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

# ===============================
# Helper Functions
# ===============================

def load_model(model_path):
    """
    Load a model based on its filename.
    """
    if 'resnet18' in model_path:
        model = models.resnet18()
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
    elif 'vgg16' in model_path:
        model = models.vgg16()
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(num_ftrs, NUM_CLASSES)
    elif 'vit' in model_path:
        model = models.vision_transformer.vit_b_16(pretrained=True)
        model.heads = nn.Sequential(nn.Linear(model.heads[0].in_features, NUM_CLASSES))
    elif 'deep_cnn' in model_path:
        class DeepCNN(nn.Module):
            def __init__(self):
                super(DeepCNN, self).__init__()
                self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)  # Increased channels
                self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
                self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)  # More channels in deeper layers
                self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)  # Additional convolutional layer
                self.fc1 = nn.Linear(512 * 14 * 14, 1024)  # Adjusted after more downsampling
                self.fc2 = nn.Linear(1024, 512)
                self.fc3 = nn.Linear(512, 128)
                self.fc4 = nn.Linear(128, NUM_CLASSES)  # Final classification layer
                self.pool = nn.MaxPool2d(2, 2)
                self.relu = nn.ReLU()
                self.dropout = nn.Dropout(0.5)  # Dropout for regularization
                self.batch_norm1 = nn.BatchNorm2d(64)  # Batch Normalization for stability
                self.batch_norm2 = nn.BatchNorm2d(128)
                self.batch_norm3 = nn.BatchNorm2d(256)
                self.batch_norm4 = nn.BatchNorm2d(512)

            def forward(self, x):
                x = self.pool(self.relu(self.batch_norm1(self.conv1(x))))  # Conv -> Batch Norm -> ReLU -> Pool
                x = self.pool(self.relu(self.batch_norm2(self.conv2(x))))
                x = self.pool(self.relu(self.batch_norm3(self.conv3(x))))
                x = self.pool(self.relu(self.batch_norm4(self.conv4(x))))
                x = x.view(-1, 512 * 14 * 14)  # Flatten the output for fully connected layers
                x = self.relu(self.fc1(x))
                x = self.dropout(x)  # Apply dropout after dense layers
                x = self.relu(self.fc2(x))
                x = self.relu(self.fc3(x))
                x = self.fc4(x)  # Output layer
                return x

        model = DeepCNN()
    elif 'simple_cnn' in model_path:
        class SimpleCNN(nn.Module):
            def __init__(self):
                super(SimpleCNN, self).__init__()
                self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
                self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
                self.fc1 = nn.Linear(32 * 56 * 56, 128)
                self.fc2 = nn.Linear(128, NUM_CLASSES)
                self.pool = nn.MaxPool2d(2, 2)
                self.relu = nn.ReLU()
            
            def forward(self, x):
                x = self.pool(self.relu(self.conv1(x)))
                x = self.pool(self.relu(self.conv2(x)))
                x = x.view(-1, 32 * 56 * 56)
                x = self.relu(self.fc1(x))
                x = self.fc2(x)
                return x

        model = SimpleCNN()
    else:
        raise ValueError(f"Unknown model type in {model_path}")

    # Load the model weights
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    return model

def parse_annotation(xml_file):
    """
    Parse Pascal VOC XML annotations.
    """
    import xml.etree.ElementTree as ET
    tree = ET.parse(xml_file)
    root = tree.getroot()
    boxes = []
    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        bbox = {
            'xmin': int(float(bndbox.find('xmin').text)),
            'ymin': int(float(bndbox.find('ymin').text)),
            'xmax': int(float(bndbox.find('xmax').text)),
            'ymax': int(float(bndbox.find('ymax').text))
        }
        boxes.append(bbox)
    return boxes

def compute_iou(box1, box2):
    """
    Compute IoU between two bounding boxes.
    """
    x_left = max(box1['xmin'], box2['xmin'])
    y_top = max(box1['ymin'], box2['ymin'])
    x_right = min(box1['xmax'], box2['xmax'])
    y_bottom = min(box1['ymax'], box2['ymax'])
    if x_right < x_left or y_bottom < y_top:
        return 0.0
    intersection_area = (x_right - x_left + 1) * (y_bottom - y_top + 1)
    box1_area = (box1['xmax'] - box1['xmin'] + 1) * (box1['ymax'] - box1['ymin'] + 1)
    box2_area = (box2['xmax'] - box2['xmin'] + 1) * (box2['ymax'] - box2['ymin'] + 1)
    return intersection_area / float(box1_area + box2_area - intersection_area)

def apply_nms(detections, iou_threshold):
    """
    Apply Non-Maximum Suppression (NMS) to detections.
    """
    if len(detections) == 0:
        return []
    boxes = torch.tensor([d['bbox'] for d in detections], dtype=torch.float32)
    scores = torch.tensor([d['score'] for d in detections], dtype=torch.float32)
    keep_indices = nms(boxes, scores, iou_threshold)
    return [detections[i] for i in keep_indices]

def evaluate_model(model, dataloader, model_name):
    """
    Evaluate the model and generate Precision-Recall curve.
    """
    all_labels = []
    all_probs = []

    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc=f"Evaluating {model_name}"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            probs = nn.functional.softmax(outputs, dim=1)[:, 1]
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    precision, recall, _ = precision_recall_curve(all_labels, all_probs)
    ap = average_precision_score(all_labels, all_probs)

    # Save Precision-Recall curve
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, label=f"AP = {ap:.2f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"Precision-Recall Curve - {model_name}")
    plt.legend()
    plt.grid(True)
    plot_save_path = os.path.join(DATASET_DIR, f"precision_recall_{model_name}_test.png")
    plt.savefig(plot_save_path, dpi=300)
    plt.close()
    print(f"Saved Precision-Recall curve for {model_name} to {plot_save_path}")

# ===============================
# Main Execution
# ===============================

def main():
    # Load the dataset
    with open(TRAINING_DATA_FILE, 'rb') as f:
        combined_data = pickle.load(f)
    proposals = combined_data['proposals']

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    class ProposalDataset(torch.utils.data.Dataset):
        def __init__(self, proposals, image_dir, transform=None):
            self.proposals = proposals
            self.image_dir = image_dir
            self.transform = transform

        def __len__(self):
            return len(self.proposals)

        def __getitem__(self, idx):
            proposal = self.proposals[idx]
            image_filename = proposal['image_filename']
            bbox = proposal['bbox']
            label = proposal['label']
            image_path = os.path.join(self.image_dir, image_filename)
            image = Image.open(image_path).convert("RGB")
            cropped_image = image.crop((bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']))
            if self.transform:
                cropped_image = self.transform(cropped_image)
            return cropped_image, label

    dataset = ProposalDataset(proposals, ANNOTATED_IMAGES_DIR, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Evaluate all saved models
    saved_models = [f for f in os.listdir(MODELS_DIR) if f.endswith("_best_weights.pth")]
    for model_file in saved_models:
        model_path = os.path.join(MODELS_DIR, model_file)
        model_name = model_file.replace("_best_weights.pth", "")
        print(f"Evaluating model: {model_name}")
        model = load_model(model_path)
        evaluate_model(model, dataloader, model_name)

if __name__ == "__main__":
    main()


Using device: cuda
Evaluating model: deep_cnn


  model.load_state_dict(torch.load(model_path, map_location=device))
Evaluating deep_cnn: 100%|█████████████████████████████████████████████████████████| 1350/1350 [01:30<00:00, 14.85it/s]
  model.load_state_dict(torch.load(model_path, map_location=device))


Saved Precision-Recall curve for deep_cnn to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_deep_cnn_test.png
Evaluating model: resnet18


Evaluating resnet18: 100%|█████████████████████████████████████████████████████████| 1350/1350 [01:32<00:00, 14.67it/s]
  model.load_state_dict(torch.load(model_path, map_location=device))


Saved Precision-Recall curve for resnet18 to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_resnet18_test.png
Evaluating model: simple_cnn


Evaluating simple_cnn: 100%|███████████████████████████████████████████████████████| 1350/1350 [01:15<00:00, 17.84it/s]


Saved Precision-Recall curve for simple_cnn to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_simple_cnn_test.png
Evaluating model: vgg16


  model.load_state_dict(torch.load(model_path, map_location=device))
Evaluating vgg16: 100%|████████████████████████████████████████████████████████████| 1350/1350 [01:58<00:00, 11.42it/s]


Saved Precision-Recall curve for vgg16 to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_vgg16_test.png
Evaluating model: vit


  model.load_state_dict(torch.load(model_path, map_location=device))
Evaluating vit: 100%|██████████████████████████████████████████████████████████████| 1350/1350 [02:12<00:00, 10.21it/s]


Saved Precision-Recall curve for vit to C:/Users/nicla/DTU/ComputerVision/3rd/Potholes\precision_recall_vit_test.png
