# Deep Learning Project - Image Classification

In [None]:
import torch
import os
from model import ResNet34
from train import Trainer
from dataset_loader import get_data_loaders

# Make sure the output folders exist (no error if they are already there)
for output_dir in ('../logs', '../checkpoints'):
    os.makedirs(output_dir, exist_ok=True)

## Loading Data

In [None]:
# Build the train / validation / test loaders with a batch size of 32
train_loader, val_loader, test_loader = get_data_loaders(batch_size=32)

# Report how many batches each loader yields
print(
    f"Number of training batches: {len(train_loader)}",
    f"Number of validation batches: {len(val_loader)}",
    f"Number of test batches: {len(test_loader)}",
    sep="\n",
)

# Class names as exposed by the training dataset's `classes` attribute
class_names = train_loader.dataset.classes
print(f"Classes: {class_names}")

## Initialize and Train Model

In [None]:
# Two-class ResNet34
model = ResNet34(num_classes=2)

# Trainer wired to the train/validation loaders: up to 50 epochs,
# with an early-stopping patience of 5
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=50,
    patience=5,
)

# Run the training loop
trainer.train()

## Evaluate on Test Set

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def evaluate_model(test_loader):
    """Print a classification report and show a confusion-matrix heatmap.

    Runs ``trainer.evaluate`` on ``test_loader`` and uses the predictions and
    labels it returns (its first return value is discarded). Relies on the
    notebook-level ``trainer`` and ``class_names`` variables.

    Args:
        test_loader: Loader passed straight through to ``trainer.evaluate``.
    """
    _, predictions, targets = trainer.evaluate(test_loader)

    # Text summary of per-class metrics
    print("Classification Report:")
    report = classification_report(targets, predictions, target_names=class_names)
    print(report)

    # Confusion matrix drawn as an annotated heatmap
    matrix = confusion_matrix(targets, predictions)
    _, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        matrix,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_title('Confusion Matrix')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    plt.show()

# Load the checkpoint stored at ../checkpoints/best_model.pth.
# map_location='cpu' lets the file be deserialized even if its tensors were
# saved from a device that is not available here (e.g. saved on GPU, loaded
# on a CPU-only machine); load_state_dict then copies the weights into the
# model's existing parameters, wherever those live.
checkpoint = torch.load('../checkpoints/best_model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded model from epoch {checkpoint['epoch']} with validation accuracy: {checkpoint['best_accuracy']:.4f}")

# Evaluate the restored weights on the test set
evaluate_model(test_loader)