# Image Classification using PyTorch

This notebook demonstrates binary image classification with PyTorch, using transfer learning on a pre-trained ResNet-18. We'll classify cat images into two breeds: Egyptian and Persian.

## Overview
1. **Setup Environment**: Import necessary libraries and set up configurations
2. **Data Loading**: Load and preprocess the dataset
3. **Model Creation**: Set up a pre-trained ResNet-18 model
4. **Training**: Train the model with progress tracking
5. **Evaluation**: Plot the training/validation curves and run the trained model on individual test images

## Requirements
- PyTorch
- torchvision
- Pillow (imported as `PIL`)
- matplotlib
- tqdm
- numpy

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import os
import random
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility.
# The same fixed value (42) is given to each random number generator this
# notebook imports: PyTorch, Python's `random`, and NumPy.
# NOTE(review): DataLoader worker processes and cuDNN kernel selection are not
# pinned here — confirm whether bit-exact reruns are required.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

In [None]:
# Pick the first CUDA device when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Per-channel mean / std shared by both pipelines (the statistics commonly
# used with torchvision's ImageNet-pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _to_normalized_tensor():
    """Return the tensor-conversion + normalization tail used by every split."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]


# Training images are randomly cropped, flipped, rotated and colour-jittered;
# validation images get a deterministic resize followed by a centre crop.
_train_augmentations = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
]
_val_preprocessing = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
]

data_transforms = {
    'train': transforms.Compose(_train_augmentations + _to_normalized_tensor()),
    'val': transforms.Compose(_val_preprocessing + _to_normalized_tensor()),
}

In [None]:
def load_data(data_dir='dataset', batch_size=32):
    """Build the training and validation data loaders.

    ``data_dir`` must contain a ``train`` and a ``val`` sub-directory. Each is
    opened with ``datasets.ImageFolder`` and transformed with the matching
    entry of the module-level ``data_transforms`` dict.

    Args:
        data_dir: Root directory holding the ``train`` and ``val`` folders.
        batch_size: Number of samples per batch, used for both loaders.

    Returns:
        A tuple ``(dataloaders, dataset_sizes, class_names)``. The first two
        are dicts keyed by ``'train'`` / ``'val'``; ``class_names`` is the
        class list reported by the training dataset.

    Raises:
        Exception: Any error raised while building the datasets or loaders is
            printed and then re-raised unchanged.
    """
    phases = ['train', 'val']
    try:
        image_datasets = {
            phase: datasets.ImageFolder(os.path.join(data_dir, phase),
                                        data_transforms[phase])
            for phase in phases
        }

        # Shuffle only the training split. Validation metrics do not depend on
        # sample order, and a fixed order keeps evaluation runs repeatable.
        dataloaders = {
            phase: DataLoader(image_datasets[phase], batch_size=batch_size,
                              shuffle=(phase == 'train'), num_workers=4)
            for phase in phases
        }

        dataset_sizes = {phase: len(image_datasets[phase]) for phase in phases}
        class_names = image_datasets['train'].classes

        print(f"Classes found: {class_names}")
        print(f"Dataset sizes: {dataset_sizes}")

        return dataloaders, dataset_sizes, class_names

    except Exception as e:
        # Report the failure in the notebook output, then let it propagate.
        print(f"Error loading data: {str(e)}")
        raise

# Load the data with load_data's defaults (data_dir='dataset', batch_size=32).
# `dataloaders` and `dataset_sizes` are read as module-level globals by
# train_model further down, so these names must stay as they are.
dataloaders, dataset_sizes, class_names = load_data()

In [None]:
def setup_model(num_classes=2):
    """Create a pre-trained ResNet-18 with a new, trainable classifier head.

    Every parameter of the pre-trained network is frozen. The original ``fc``
    layer is then replaced by a small head (Linear -> ReLU -> Dropout ->
    Linear) whose freshly created parameters are the only trainable ones.

    Args:
        num_classes: Number of outputs of the final linear layer.

    Returns:
        The model, moved to the module-level ``device``.
    """
    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of explicit
    # weight enums; IMAGENET1K_V1 is the checkpoint the legacy flag selects.
    # Older installs do not have the enum, so fall back to the old flag there.
    if hasattr(models, 'ResNet18_Weights'):
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)

    # Freeze the whole pre-trained network, including the fc layer that is
    # about to be replaced.
    for param in model.parameters():
        param.requires_grad = False

    # Swap in a new head sized for our number of classes. Its parameters are
    # created after the freeze above, so they keep requires_grad=True.
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_ftrs, 512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, num_classes)
    )

    return model.to(device)

# Initialize model, loss function, and optimizer.
# Only the parameters of model.fc are handed to Adam: setup_model sets
# requires_grad = False on every other parameter, so the replacement head is
# the only part of the network that is trained.
model = setup_model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

In [None]:
def train_model(model, criterion, optimizer, num_epochs=10):
    """Train and validate ``model``, recording per-epoch loss and accuracy.

    Each epoch runs a training pass followed by a validation pass over the
    module-level ``dataloaders``; inputs and labels are moved to the
    module-level ``device``. Whenever the validation accuracy exceeds the best
    value seen so far, the model's ``state_dict`` is written to
    ``best_model.pth`` in the current working directory.

    Args:
        model: Network to train; updated in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Number of full train + validation cycles to run.

    Returns:
        A tuple ``(model, history)``. ``model`` is the same object that was
        passed in, in its state after the final epoch (the best checkpoint is
        saved to disk but not reloaded). ``history`` maps ``'train_loss'``,
        ``'val_loss'``, ``'train_acc'`` and ``'val_acc'`` to lists of plain
        Python floats, one entry per epoch.
    """
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')

        for phase in ['train', 'val']:
            is_training = phase == 'train'
            # train(True) is model.train(); train(False) is model.eval().
            model.train(is_training)

            running_loss = 0.0
            running_corrects = 0

            # Progress bar
            pbar = tqdm(dataloaders[phase], desc=f'{phase} phase')

            for inputs, labels in pbar:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Gradients are only produced in the training phase, so that
                # is the only place they need to be cleared.
                if is_training:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                batch_loss = loss.item()
                # The criterion returns a per-batch mean; weight it by the
                # batch size so the epoch average is per sample.
                running_loss += batch_loss * inputs.size(0)
                # .item() keeps the counter a Python int. The metrics stored
                # in `history` are then plain floats that matplotlib can plot
                # even when training ran on the GPU.
                running_corrects += (preds == labels).sum().item()

                # Update progress bar
                pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            # Store history
            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            print(f'{phase.capitalize()} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Checkpoint on every strict improvement of validation accuracy.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), 'best_model.pth')

        print('-' * 60)

    return model, history

# Train the model for train_model's default of 10 epochs.
# train_model returns the same model object it was given, so rebinding `model`
# here keeps the name pointing at the trained network; `history` holds the
# per-epoch loss/accuracy lists used for plotting below.
model, history = train_model(model, criterion, optimizer)

In [None]:
def plot_training_history(history):
    """Show loss and accuracy curves for the training and validation phases.

    Draws two side-by-side panels (loss on the left, accuracy on the right),
    each with one line per phase, and displays the figure.

    Args:
        history: Mapping with the keys ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'``; each value is a per-epoch
            sequence of numbers.
    """
    _, axes = plt.subplots(1, 2, figsize=(15, 5))

    # One row per panel, left to right:
    # (train key, val key, train legend, val legend, title, y-axis label)
    panels = [
        ('train_loss', 'val_loss',
         'Training Loss', 'Validation Loss',
         'Model Loss', 'Loss'),
        ('train_acc', 'val_acc',
         'Training Accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
    ]

    for axis, panel in zip(axes, panels):
        train_key, val_key, train_legend, val_legend, title, y_label = panel
        axis.plot(history[train_key], label=train_legend)
        axis.plot(history[val_key], label=val_legend)
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()

    plt.tight_layout()
    plt.show()

# Plot the loss and accuracy curves from the `history` dict returned by
# train_model.
plot_training_history(history)

In [None]:
def predict_image(image_path, model, class_names):
    """Classify one image file and display it with the predicted label.

    The image is preprocessed with the module-level ``data_transforms['val']``
    pipeline, run through ``model`` on the module-level ``device``, and shown
    in a matplotlib figure titled with the prediction and its confidence.

    Args:
        image_path: Path of the image file to classify.
        model: Trained classifier. It is switched to eval mode and left that
            way. NOTE(review): assumed to already live on ``device`` — confirm
            at call sites.
        class_names: Sequence mapping output indices to class labels.

    Returns:
        A tuple ``(predicted_class, confidence)`` where ``confidence`` is the
        softmax probability (a float in [0, 1]) of the predicted class.

    Raises:
        Exception: Any error raised while loading, predicting or plotting is
            printed and then re-raised unchanged.
    """
    try:
        # Force three channels. Grayscale, palette or RGBA files would
        # otherwise produce a tensor whose channel count does not match the
        # 3-channel Normalize step in the validation pipeline. The context
        # manager closes the file handle once the converted copy exists.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        input_tensor = data_transforms['val'](image)
        # Add the batch dimension the model expects.
        input_batch = input_tensor.unsqueeze(0).to(device)

        # Make prediction
        model.eval()
        with torch.no_grad():
            output = model(input_batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            _, predicted_idx = torch.max(output, 1)

        # Get prediction and confidence
        class_index = predicted_idx.item()
        predicted_class = class_names[class_index]
        confidence = probabilities[class_index].item()

        # Display results
        plt.figure(figsize=(8, 6))
        plt.imshow(image)
        plt.axis('off')
        plt.title(f'Predicted: {predicted_class}\nConfidence: {confidence:.2%}')
        plt.show()

        return predicted_class, confidence

    except Exception as e:
        # Report the failure in the notebook output, then let it propagate.
        print(f"Error predicting image: {str(e)}")
        raise

# Example usage:
# predict_image('path_to_test_image.jpg', model, class_names)