# Car Brand Detection System
This notebook covers the end-to-end process of:
1. Data Inspection
2. Data Splitting (Train/Val/Test), with **Data Augmentation** applied to the training set
3. Model Training: **fine-tuning** an ImageNet-pretrained ResNet18
4. Evaluation & Inference

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
import torchvision
from PIL import Image
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import torch_directml # Import for AMD GPU support

# Configuration: paths for the dataset and saved artifacts, plus training hyperparameters.
DATA_DIR = "archive"
MODEL_SAVE_PATH = "car_brand_model.pth"
BEST_MODEL_SAVE_PATH = "car_brand_model_best.pth"
MAPPING_SAVE_PATH = "class_mapping.json"
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 1e-4  # deliberately small: we fine-tune pretrained weights


def _select_device():
    """Return the DirectML (AMD GPU) device, or the CPU device if creating it fails.

    Only failures raised while creating the device are caught here; the module-level
    ``import torch_directml`` runs outside this function, so a missing package is not.
    """
    try:
        dml_device = torch_directml.device()
    except Exception as err:
        print(f"Error setting DirectML device: {err}. Falling back to CPU.")
        return torch.device("cpu")
    print(f"Using device: {dml_device} (AMD GPU via DirectML)")
    return dml_device


device = _select_device()

In [None]:
def analyze_dataset(directory, top_k=20):
    """Print and plot how many images each car make has in a flat image directory.

    The make is taken as the part of each filename before the first ``_``; only
    ``.jpg``/``.jpeg``/``.png`` files (case-insensitive) are counted.

    Args:
        directory: Path of the directory containing the images.
        top_k: How many of the most common makes to list and plot.

    Returns:
        None. Output goes to stdout and a matplotlib bar chart. Returns early,
        without plotting, if the directory is missing or contains no images.
    """
    if not os.path.exists(directory):
        print(f"Directory not found: {directory}")
        return

    # str.split always yields at least one element, so [0] is always valid.
    data = [
        filename.split('_', 1)[0]
        for filename in os.listdir(directory)
        if filename.lower().endswith((".jpg", ".jpeg", ".png"))
    ]

    print(f"Total images: {len(data)}")
    print(f"Unique makes: {len(set(data))}")

    if not data:
        # zip(*[]) yields nothing, so unpacking it into (makes, values) would raise ValueError.
        print("No images found; nothing to plot.")
        return

    top = Counter(data).most_common(top_k)
    print(f"\nTop {top_k} Makes distribution:")
    for make, count in top:
        print(f"  {make}: {count}")

    # Visualization
    makes, values = zip(*top)
    plt.figure(figsize=(12, 6))
    plt.bar(makes, values)
    plt.xticks(rotation=45)
    plt.title(f"Top {top_k} Car Brands Distribution")
    plt.show()

# Inspect the raw image directory before any datasets are built.
analyze_dataset(directory=DATA_DIR)

In [None]:
class CarDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.classes = set()
        
        valid_extensions = {'.jpg', '.jpeg', '.png'}
        if os.path.exists(root_dir):
            for filename in os.listdir(root_dir):
                if os.path.splitext(filename)[1].lower() in valid_extensions:
                    parts = filename.split('_')
                    if len(parts) > 0:
                        make = parts[0]
                        self.image_paths.append(os.path.join(root_dir, filename))
                        self.labels.append(make)
                        self.classes.add(make)
        
        self.classes = sorted(list(self.classes))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        
        # print(f"Found {len(self.image_paths)} images from {len(self.classes)} classes.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_name = self.labels[idx]
        label_idx = self.class_to_idx[label_name]
        
        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label_idx
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros((3, IMG_SIZE, IMG_SIZE)), label_idx

# Transforms with Augmentation for Training
# Random flip/rotation/color jitter are applied only to the training split.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Deterministic preprocessing for validation and test.
val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create base dataset to get classes and length
base_dataset = CarDataset(DATA_DIR, transform=None)

# Save class mapping (index -> class name, as a JSON list)
with open(MAPPING_SAVE_PATH, 'w', encoding='utf-8') as f:
    json.dump(base_dataset.classes, f)
print(f"Saved class mapping to {MAPPING_SAVE_PATH}")

# Manual 70/15/15 split so each subset can get its own transform
total_size = len(base_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

# Seeded permutation: the model cell resumes from the previous best checkpoint, so an
# unseeded split would move images trained on in earlier runs into val/test (leakage).
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(total_size, generator=split_generator).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:train_size + val_size]
test_indices = indices[train_size + val_size:]

# Create Subsets with specific transforms
# Note: We instantiate CarDataset again for each split to attach the correct transform
train_dataset = Subset(CarDataset(DATA_DIR, transform=train_transform), train_indices)
val_dataset = Subset(CarDataset(DATA_DIR, transform=val_transform), val_indices)
test_dataset = Subset(CarDataset(DATA_DIR, transform=val_transform), test_indices)

print("Data Split Summary:")
print(f"  Training Set:   {len(train_dataset)} images (with Augmentation)")
print(f"  Validation Set: {len(val_dataset)} images")
print(f"  Test Set:       {len(test_dataset)} images")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
# Model Setup & Fine-tuning Logic
def _build_model(num_classes):
    """Return an ImageNet-pretrained ResNet18 whose final layer outputs ``num_classes`` logits."""
    net = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


num_classes = len(base_dataset.classes)
model = _build_model(num_classes)
num_ftrs = model.fc.in_features

# Load previous best model if available
# Load to CPU first to avoid DirectML map_location issues
if os.path.exists(BEST_MODEL_SAVE_PATH):
    print(f"Loading existing model weights from {BEST_MODEL_SAVE_PATH}...")
    try:
        state_dict = torch.load(BEST_MODEL_SAVE_PATH, map_location="cpu")
        model.load_state_dict(state_dict)
        print("Weights loaded successfully! Continuing training...")
    except Exception as e:
        print(f"Error loading weights: {e}. Starting from ImageNet-pretrained weights.")
        # load_state_dict may already have copied the matching tensors before raising
        # (e.g. on a class-count mismatch in fc), so rebuild to avoid a half-loaded model.
        model = _build_model(num_classes)
else:
    print("No previous model found. Starting from ImageNet-pretrained weights.")

# Now move model to device
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.1 every 3 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
# Training Loop
def _cpu_state_dict(net):
    """Return a copy of ``net``'s state dict with every tensor moved to the CPU.

    Saving CPU tensors keeps checkpoints loadable without the DirectML device.
    """
    return {name: tensor.detach().cpu() for name, tensor in net.state_dict().items()}


def evaluate_accuracy(net, loader):
    """Return the top-1 accuracy of ``net`` on ``loader``, in percent.

    Switches ``net`` to eval mode and runs without gradient tracking. Returns 0.0
    for an empty loader instead of dividing by zero.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


print("Evaluating initial performance...")
initial_acc = evaluate_accuracy(model, val_loader)
print(f"Initial Validation Accuracy: {initial_acc:.2f}%")
best_acc = initial_acc
# Save the starting weights as the baseline best. BEST_MODEL_SAVE_PATH then always
# exists and matches the current class set, even if no epoch beats the initial accuracy.
torch.save(_cpu_state_dict(model), BEST_MODEL_SAVE_PATH)

train_losses = []
val_accuracies = []

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    model.train()
    running_loss = 0.0
    num_batches = len(train_loader)

    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 0:
            print(f"Batch {i}/{num_batches}, Loss: {loss.item():.4f}")

    scheduler.step()
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    train_losses.append(epoch_loss)

    # Validation
    epoch_acc = evaluate_accuracy(model, val_loader)
    val_accuracies.append(epoch_acc)
    print(f"Validation Accuracy: {epoch_acc:.2f}%")

    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(_cpu_state_dict(model), BEST_MODEL_SAVE_PATH)
        print(f"New best model saved with accuracy: {best_acc:.2f}%")

# Save Final Model
torch.save(_cpu_state_dict(model), MODEL_SAVE_PATH)
print(f"Final model saved to {MODEL_SAVE_PATH}")

# Plotting
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(val_accuracies, label='Validation Accuracy')
plt.title('Validation Accuracy')
plt.legend()
plt.show()

### Final Evaluation on the Test Set

In [None]:
# Load best model
# Load to CPU first, then move to the device (avoids DirectML map_location issues)
model.load_state_dict(torch.load(BEST_MODEL_SAVE_PATH, map_location="cpu"))
model = model.to(device)
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard against an empty test split (e.g. a very small dataset) instead of dividing by zero.
if total:
    print(f"Final Test Accuracy: {100 * correct / total:.2f}%")
else:
    print("Final Test Accuracy: n/a (test set is empty)")

In [None]:
# Inference Visualization
def imshow(inp, title=None):
    """Show a normalized channels-first image tensor with matplotlib.

    Moves the tensor to the CPU and puts channels last. It then reverses the ImageNet
    mean/std normalization, clips values to [0, 1], draws the image, and sets
    ``title`` if one is given.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    # Channels-first tensor (possibly on an accelerator) -> channels-last numpy array.
    img = np.transpose(inp.cpu().numpy(), (1, 2, 0))
    img = np.clip(img * channel_std + channel_mean, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # brief pause so the figure gets redrawn

# Get a batch of test data and show predictions for its first four images.
inputs, classes_idx = next(iter(test_loader))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs[:4])

model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    outputs = model(inputs[:4].to(device))
_, preds = torch.max(outputs, 1)

class_names = base_dataset.classes
title = [
    f"Pred: {class_names[p]} / True: {class_names[t]}"
    for p, t in zip(preds.cpu().tolist(), classes_idx[:4].tolist())
]

plt.figure(figsize=(15, 5))
# plt.title expects a string; passing the list would render its Python repr.
imshow(out, title=" | ".join(title))