# In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import os
import copy
from collections import Counter
import pandas as pd

# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Locations of the train / validation / test image folders
train_dir, val_dir, test_dir = (
    'splitRand_Toshiba/train',
    'splitRand_Toshiba/val',
    'splitRand_Toshiba/test',
)

# Define DenseNet model with dropout and fine-tuning
class DenseNetWithDropout(nn.Module):
    """DenseNet feature extractor followed by dropout and a fresh linear head.

    Args:
        base_model: Object exposing ``features`` (the convolutional backbone)
            and ``classifier`` (a linear layer whose ``in_features`` gives the
            pooled feature width), e.g. a torchvision DenseNet.
        dropout_rate: Drop probability applied to the pooled features.
        num_classes: Number of output logits produced by the new head.
    """

    def __init__(self, base_model, dropout_rate=0.5, num_classes=3):
        super().__init__()
        self.features = base_model.features
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(dropout_rate)
        # The original classifier is only consulted for its input width; the
        # head itself is re-created for `num_classes` outputs.
        self.classifier = nn.Linear(base_model.classifier.in_features, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch `x`."""
        x = self.features(x)
        # torchvision's DenseNet applies a ReLU between `features` (which ends
        # in a BatchNorm) and the global pooling inside its own forward();
        # re-implementing forward() here means it has to be applied explicitly,
        # otherwise positive and negative activations cancel in the average.
        x = torch.relu(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.classifier(x)
        return x

# Initialize the base DenseNet model (ImageNet-pretrained) and wrap it with the
# dropout + 3-class head defined above
base_model = models.densenet201(weights='DenseNet201_Weights.IMAGENET1K_V1')
model = DenseNetWithDropout(base_model, dropout_rate=0.5, num_classes=3).to(device)

# Freeze the whole convolutional backbone first
for param in model.features.parameters():
    param.requires_grad = False

# Unfreeze only the deepest top-level blocks of the backbone.
# `model.features` is a Sequential with only a handful of direct children, so
# slicing it with a count larger than that (e.g. `[-20:]`) selects every block
# and re-enables gradients for the entire backbone. Counting blocks explicitly
# keeps the earlier layers frozen. For torchvision's DenseNet-201 the last two
# children are expected to be `denseblock4` and `norm5` - confirm against the
# installed torchvision version.
NUM_TRAINABLE_FEATURE_BLOCKS = 2  # Adjust the number of blocks as needed
feature_blocks = list(model.features.children())
num_unfrozen = max(0, min(NUM_TRAINABLE_FEATURE_BLOCKS, len(feature_blocks)))
# Index from the front so that a count of 0 unfreezes nothing (`[-0:]` would
# select everything).
for block in feature_blocks[len(feature_blocks) - num_unfrozen:]:
    for param in block.parameters():
        param.requires_grad = True

# Define transformations
def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps that end both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Deterministic pipeline: resize to 224x224, convert to tensor, normalize
base_transform = transforms.Compose(
    [transforms.Resize((224, 224))] + _tensor_and_normalize()
)

# Randomized pipeline: same resize and normalization, with flip / rotation /
# color jitter / affine / padded-crop perturbations applied in between
_random_steps = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.RandomAffine(degrees=20, shear=10),
    transforms.RandomCrop(224, padding=15),
]
augmentation_transform = transforms.Compose(
    [transforms.Resize((224, 224))] + _random_steps + _tensor_and_normalize()
)

# Load datasets (all three splits use the deterministic base transform)
train_dataset = datasets.ImageFolder(train_dir, transform=base_transform)
val_dataset = datasets.ImageFolder(val_dir, transform=base_transform)
test_dataset = datasets.ImageFolder(test_dir, transform=base_transform)

# Balance training dataset
class_counts = Counter(train_dataset.targets)
train_targets = np.array(train_dataset.targets)  # built once, reused for every class
class_indices = {cls: np.where(train_targets == cls)[0] for cls in class_counts}

minority_class = min(class_counts, key=class_counts.get)
majority_count = max(class_counts.values())

# Oversample EVERY class that is smaller than the largest one, not only the
# single smallest class; otherwise intermediate classes stay under-represented.
# Each extra sample is a randomly chosen image of that class passed through the
# augmentation pipeline.
augmented_samples = []
for cls, count in class_counts.items():
    for _ in range(majority_count - count):
        sample_idx = np.random.choice(class_indices[cls])
        img_path, label = train_dataset.samples[sample_idx]
        # Context manager closes the underlying file as soon as the pixels
        # have been copied into the RGB image.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        augmented_samples.append((augmentation_transform(img), label))

# Materialize the original (base-transformed) samples plus the augmented extras
# as two tensors. Note that this holds the whole training set in memory.
original_samples = [train_dataset[i] for i in range(len(train_dataset))]
balanced_train_data = original_samples + augmented_samples
balanced_train_dataset = torch.utils.data.TensorDataset(
    torch.stack([img for img, _ in balanced_train_data]),
    torch.tensor([label for _, label in balanced_train_data])
)

# Data loaders: only the training loader shuffles; evaluation order stays fixed
BATCH_SIZE = 32
train_loader = DataLoader(balanced_train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_loader, test_loader = (
    DataLoader(split, shuffle=False, batch_size=BATCH_SIZE)
    for split in (val_dataset, test_dataset)
)

# Define loss function, optimizer, and scheduler
criterion = nn.CrossEntropyLoss()
# Only parameters that still require gradients are handed to the optimizer
trainable_params = [p for p in model.parameters() if p.requires_grad]
fine_tune_optimizer = Adam(trainable_params, lr=1e-5, weight_decay=1e-4)
# Halve the learning rate once the monitored value has not improved for more
# than `patience` epochs.
# NOTE(review): `verbose` is marked deprecated in recent PyTorch releases -
# confirm against the installed version before upgrading.
scheduler = ReduceLROnPlateau(
    fine_tune_optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    verbose=True,
)
# Training function
def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _run_epoch(model, loader, criterion, run_device, optimizer=None):
    """Run one pass over `loader` and return (mean loss per sample, accuracy).

    The model is put in training mode and updated when `optimizer` is given;
    otherwise it is put in eval mode and run without gradient tracking.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(run_device), labels.to(run_device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # The criterion result is weighted by batch size so that a smaller
            # final batch does not skew the per-sample average.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(1) == labels).sum().item()
            seen += batch_size
    if seen == 0:
        raise ValueError("Data loader produced no samples; cannot compute epoch metrics.")
    # Normalize by the samples actually seen: with `drop_last` or a sampler
    # this differs from len(loader.dataset).
    return loss_sum / seen, correct / seen


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25, patience=5, save_path="densenet_fine_tuned.pth"):
    """Train `model` with early stopping and save the best weights.

    Args:
        model: Network to train; batches are moved to the device its
            parameters live on.
        train_loader: Iterable of (inputs, labels) batches used for training.
        val_loader: Iterable of (inputs, labels) batches used for validation.
        criterion: Loss callable taking (outputs, labels).
        optimizer: Optimizer stepped once per training batch.
        scheduler: Scheduler whose ``step`` is called with the validation loss
            after every epoch.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a new best validation
            loss after which training stops early.
        save_path: File the best ``state_dict`` is written to.

    Returns:
        Dict with per-epoch lists under 'train_loss', 'val_loss', 'train_acc'
        and 'val_acc'.

    Raises:
        ValueError: If either loader yields no samples.
    """
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    best_model_weights = copy.deepcopy(model.state_dict())
    best_val_loss = float('inf')
    patience_counter = 0
    run_device = _model_device(model)

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, run_device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, run_device)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} - "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            # New best epoch: snapshot the weights and reset the patience counter
            best_val_loss = val_loss
            best_model_weights = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    # Restore the best snapshot before saving so the file never holds the
    # weights of a later, worse epoch.
    model.load_state_dict(best_model_weights)
    torch.save(model.state_dict(), save_path)  # Save the fine-tuned model
    print(f"Model saved to {save_path}")
    return history

# Metrics function
def compute_metrics(model, loader, phase, save_csv=False, csv_filename="results.csv", class_names=None):
    """Print a classification report and plot a confusion matrix for `loader`.

    Args:
        model: Network to evaluate; batches are moved to the device its
            parameters live on.
        loader: Iterable of (inputs, labels) batches.
        phase: Name of the split, used in the printed header and plot title.
        save_csv: When true, write the true/predicted label pairs to
            `csv_filename`.
        csv_filename: Destination of the CSV written when `save_csv` is set.
        class_names: Display names indexed by label id. Defaults to the classes
            of the module-level `train_dataset`.
    """
    if class_names is None:
        class_names = train_dataset.classes
    class_names = list(class_names)
    # Pin the label set explicitly: if a split happens to miss a class, the
    # report and the matrix still cover every class and line up with the names
    # (without `labels`, classification_report rejects a `target_names` list
    # whose length differs from the number of labels present in the data).
    label_ids = list(range(len(class_names)))

    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else device

    model.eval()
    true_labels, predictions = [], []

    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(eval_device))
            true_labels.extend(labels.tolist())
            predictions.extend(outputs.argmax(1).tolist())

    conf_matrix = confusion_matrix(true_labels, predictions, labels=label_ids)
    class_report = classification_report(true_labels, predictions, labels=label_ids, target_names=class_names)

    print(f"\n{phase} Classification Report:\n")
    print(class_report)

    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(f'{phase} Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    if save_csv:
        results = pd.DataFrame({
            'True Label': true_labels,
            'Predicted Label': predictions
        })
        results.to_csv(csv_filename, index=False)
        print(f"Results saved to {csv_filename}")

# Fine-tune the model; the returned per-epoch history is plotted further below
history = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    fine_tune_optimizer,
    scheduler,
    num_epochs=10,
    patience=3,
    save_path="densenet_fine_tuned.pth",
)

# Report metrics for each split; only the test predictions are written to CSV
for split_loader, split_name in ((train_loader, "Training"), (val_loader, "Validation")):
    compute_metrics(model, split_loader, split_name)
compute_metrics(
    model,
    test_loader,
    "Testing",
    save_csv=True,
    csv_filename="densenet_test_results.csv",
)

# Plot training history: loss on the left panel, accuracy on the right
plt.figure(figsize=(12, 5))

# (subplot position, train key, val key, train label, val label, title, y label)
history_panels = (
    (1, 'train_loss', 'val_loss', 'Train Loss', 'Validation Loss',
     'Training and Validation Loss', 'Loss'),
    (2, 'train_acc', 'val_acc', 'Train Accuracy', 'Validation Accuracy',
     'Training and Validation Accuracy', 'Accuracy'),
)

for position, train_key, val_key, train_label, val_label, panel_title, y_label in history_panels:
    plt.subplot(1, 2, position)
    plt.plot(history[train_key], label=train_label)
    plt.plot(history[val_key], label=val_label)
    plt.title(panel_title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()
