In [1]:
#!/usr/bin/env python
# coding: utf-8

import os
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm





In [3]:
# =========================
# Configuration and Settings
# =========================

class Config:
    """Static settings shared by every cell of the notebook.

    All per-model dictionaries use the same keys ('1_12', '1_4', '5_8',
    '9_12', 'Top'), so a single model key selects the dataset directory,
    learning rate, output width and checkpoint file for that model.
    """

    # Device configuration: use the GPU when one is visible, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Dataset root directory for each model. Forward slashes are used
    # throughout because they work on Windows, Linux and macOS alike; a
    # backslash would be treated as part of the directory name on POSIX.
    data_paths = {
        '1_12': 'Bat/Final Testing Dataset',
        '1_4': 'Bat/1-4',
        '5_8': 'Bat/5-8',
        '9_12': 'Bat/9-12',
        'Top': 'Data/Top_level'
    }

    # Training parameters shared by all models.
    training_params = {
        'batch_size': 32,
        'num_workers': 4,
        'epochs': 50
    }

    # Learning rate for each model's optimizer.
    learning_rates = {
        '1_12': 0.001,
        '1_4': 0.0001,
        '5_8': 0.0001,
        '9_12': 0.0001,
        'Top': 0.0001
    }

    # Number of output classes for each model.
    num_classes = {
        '1_12': 12,
        '1_4': 4,
        '5_8': 4,
        '9_12': 4,
        'Top': 3
    }

    # File each trained model's state_dict is written to.
    model_save_paths = {
        '1_12': '1_12_bats.pth',
        '1_4': '1_4_model.pth',
        '5_8': '5_8_model.pth',
        '9_12': '9_12_model.pth',
        'Top': 'top_model.pth'
    }

# Initialize configuration
cfg = Config()



In [None]:
# =========================
# Data Handling Functions
# =========================

def get_data_transforms():
    """
    Build the preprocessing pipeline applied to every loaded image.

    Returns:
        torchvision.transforms.Compose: Pipeline that first converts the image
        to a tensor and then applies a random rotation (degrees argument 45).
    """
    pipeline_steps = [
        transforms.ToTensor(),
        transforms.RandomRotation(45),
    ]
    return transforms.Compose(pipeline_steps)

def load_dataset(root_dir, transform):
    """
    Open an image dataset stored in the ImageFolder directory layout.

    Args:
        root_dir (str): Directory handed to ImageFolder as its root.
        transform (torchvision.transforms.Compose): Transformations passed to
            ImageFolder as its `transform` argument.

    Returns:
        torch.utils.data.Dataset: The ImageFolder dataset for `root_dir`.
    """
    image_folder = datasets.ImageFolder(root=root_dir, transform=transform)
    return image_folder

def split_dataset(dataset, val_ratio=0.2, test_ratio=0.2, seed=None):
    """
    Split dataset into training, validation, and test sets.

    Validation and test sizes are the ratios multiplied by the dataset length
    and truncated to an integer; the training set receives every remaining
    sample, so the three parts always cover the whole dataset.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to split.
        val_ratio (float): Fraction of data for validation, in [0, 1].
        test_ratio (float): Fraction of data for testing, in [0, 1].
        seed (int, optional): When given, the split is drawn from a private
            generator seeded with this value, so the same seed always yields
            the same partition. When None (default) the global PyTorch RNG is
            used, exactly as before.

    Returns:
        list: train_dataset, val_dataset, test_dataset (torch Subset objects).

    Raises:
        ValueError: If a ratio lies outside [0, 1] or the two ratios together
            leave a negative number of training samples.
    """
    if not (0 <= val_ratio <= 1 and 0 <= test_ratio <= 1):
        raise ValueError(
            f"val_ratio and test_ratio must be within [0, 1], "
            f"got val_ratio={val_ratio}, test_ratio={test_ratio}"
        )

    total_size = len(dataset)
    val_size = int(val_ratio * total_size)
    test_size = int(test_ratio * total_size)
    train_size = total_size - val_size - test_size

    # The three lengths always sum to total_size, so random_split cannot
    # detect an over-allocation on its own; reject it explicitly.
    if train_size < 0:
        raise ValueError(
            f"val_ratio + test_ratio leaves no room for training data "
            f"(val_ratio={val_ratio}, test_ratio={test_ratio})"
        )

    lengths = [train_size, val_size, test_size]
    if seed is None:
        return random_split(dataset, lengths)

    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, lengths, generator=generator)

def get_dataloaders(train_dataset, val_dataset, test_dataset, batch_size, num_workers):
    """
    Wrap the three dataset splits in DataLoaders.

    Only the training loader shuffles its samples; the validation and test
    loaders iterate in dataset order.

    Args:
        train_dataset (torch.utils.data.Dataset): Training dataset.
        val_dataset (torch.utils.data.Dataset): Validation dataset.
        test_dataset (torch.utils.data.Dataset): Test dataset.
        batch_size (int): Batch size used by all three loaders.
        num_workers (int): Number of subprocesses used by all three loaders.

    Returns:
        tuple: train_loader, val_loader, test_loader
    """
    shared_options = {'batch_size': batch_size, 'num_workers': num_workers}
    # (split, shuffle?) pairs in the order the loaders are returned.
    split_plan = (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
    return tuple(
        DataLoader(split, shuffle=should_shuffle, **shared_options)
        for split, should_shuffle in split_plan
    )



In [None]:
# =========================
# Training and Evaluation Functions
# =========================

def train_model(model, train_loader, val_loader, loss_fn, optimizer, device, epochs, model_name):
    """
    Train the given model and validate after each epoch.

    The progress bar shows the running mean of the per-batch loss and the
    running accuracy for the current epoch. The loss is averaged over the
    number of batches seen so far, which is the same averaging
    evaluate_model() applies, so training and validation losses are on the
    same scale.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader): DataLoader for training data.
        val_loader (DataLoader): DataLoader for validation data.
        loss_fn (nn.Module): Loss function; expected to return one scalar
            per batch.
        optimizer (torch.optim.Optimizer): Optimizer.
        device (torch.device): Device to train on.
        epochs (int): Number of epochs.
        model_name (str): Name identifier for the model (used in log output).

    Returns:
        nn.Module: Trained model (the same object that was passed in).
    """
    model.to(device)
    for epoch in range(epochs):
        # evaluate_model() switches the model to eval mode at the end of the
        # previous epoch, so training mode must be restored every time.
        model.train()
        print(f"Epoch: {epoch + 1}/{epochs} - Model: {model_name}")
        pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
        running_loss = 0.0
        correct = 0
        total = 0
        for batches_seen, (inputs, labels) in enumerate(pbar, start=1):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # running_loss is a sum of per-batch values, so it is divided by
            # the batch count; dividing by the sample count would shrink the
            # displayed loss by roughly the batch size.
            pbar.set_postfix({'Loss': f"{running_loss / batches_seen:.4f}",
                              'Accuracy': f"{correct / total:.4f}"})

        # Validation after each epoch
        val_acc, val_loss = evaluate_model(model, val_loader, loss_fn, device)
        print(f"Validation - Accuracy: {val_acc:.4f}, Loss: {val_loss:.4f}\n")

    return model

def evaluate_model(model, data_loader, loss_fn, device):
    """
    Evaluate the model on the given dataset.

    The model is switched to eval mode and left in that mode; gradients are
    disabled for the whole pass.

    Args:
        model (nn.Module): The model to evaluate.
        data_loader (Iterable): Yields (inputs, labels) batches. Any iterable
            works; it does not need to support len().
        loss_fn (nn.Module): Loss function.
        device (torch.device): Device to evaluate on.

    Returns:
        tuple: (accuracy, average loss). Accuracy is correct predictions over
        all samples; the loss is the mean of the per-batch loss values. Both
        are NaN when the loader yields no samples, since neither metric is
        defined for an empty evaluation set.
    """
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += loss_fn(outputs, labels).item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty split (e.g. int(ratio * n) == 0 on a tiny dataset) would
    # otherwise raise ZeroDivisionError in the middle of training.
    if total == 0:
        return float('nan'), float('nan')

    accuracy = correct / total
    avg_loss = total_loss / num_batches
    return accuracy, avg_loss



In [None]:
# =========================
# Combined Model Class
# =========================

class CombinedModel(nn.Module):
    """
    Combined model that uses a top-level model to decide which specialized model to use for each input.

    The top-level model's argmax selects, per sample, one specialized model by
    position (decision k -> k-th entry of `specialized_models`). The output
    concatenates the class ranges of all specialized models in that same
    order; each sample carries real logits only in the range of the model it
    was routed to.
    """
    def __init__(self, top_model, specialized_models, device, num_classes=None,
                 unrouted_logit=float('-inf')):
        """
        Initialize the CombinedModel.

        Args:
            top_model (nn.Module): The top-level model.
            specialized_models (dict): Dictionary of specialized models. Its
                iteration order defines both the routing index and the order
                of the class ranges in the output.
            device (torch.device): Device to run the models on. Stored for
                callers; the forward pass allocates on the device the
                routing decisions actually live on.
            num_classes (dict, optional): Output width of each specialized
                model, keyed like `specialized_models`. When None (default)
                the widths are read from the module-level `cfg.num_classes`
                at forward time.
            unrouted_logit (float): Value written to the class ranges a
                sample was NOT routed to. The default of -inf guarantees that
                argmax always lands inside the routed range, even when the
                specialized model emits only negative logits. As a
                consequence, cross-entropy is inf for a sample whose true
                class lies in an unrouted range; pass a finite value (0.0
                reproduces plain zero padding) if a finite loss is required.
        """
        super().__init__()
        self.top_model = top_model
        self.specialized_models = nn.ModuleDict(specialized_models)
        self.device = device
        self.num_classes = None if num_classes is None else dict(num_classes)
        self.unrouted_logit = unrouted_logit

    def _class_counts(self):
        """Return the output width of every specialized model, in routing order."""
        table = self.num_classes if self.num_classes is not None else cfg.num_classes
        return [table[key] for key in self.specialized_models.keys()]

    def forward(self, x):
        """
        Forward pass through the combined model.

        Args:
            x (torch.Tensor): Input batch; the first dimension is the batch.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, total_classes). Row i
            holds the routed specialized model's outputs in that model's
            class range and `unrouted_logit` everywhere else. A row whose
            decision index has no matching specialized model is left entirely
            at `unrouted_logit`.
        """
        # Get decisions from the top-level model
        decision_logits = self.top_model(x)
        decisions = torch.argmax(decision_logits, dim=1)  # Shape: (batch_size,)

        class_counts = self._class_counts()
        # Allocate next to the routing decisions so indexing never mixes
        # devices, whatever self.device says.
        combined_outputs = torch.full(
            (x.size(0), sum(class_counts)),
            self.unrouted_logit,
            device=decisions.device,
        )

        # Run each specialized model once on the samples routed to it.
        class_offset = 0
        routing_plan = zip(self.specialized_models.keys(), class_counts)
        for decision, (model_key, width) in enumerate(routing_plan):
            indices = (decisions == decision).nonzero(as_tuple=True)[0]
            if indices.numel() > 0:
                outputs = self.specialized_models[model_key](x[indices])
                # Cast so a reduced-precision model output can still be
                # written into the buffer.
                combined_outputs[indices, class_offset:class_offset + width] = (
                    outputs.to(combined_outputs.dtype)
                )
            # Advance even when no sample was routed here, so later models
            # keep their own class range.
            class_offset += width

        return combined_outputs


In [None]:
# =========================
# Main Execution Flow
# =========================

def _train_and_save(model_key, transform):
    """Train the model identified by `model_key`, save its state_dict, and return it."""
    dataset = load_dataset(cfg.data_paths[model_key], transform)
    train_ds, val_ds, test_ds = split_dataset(dataset)
    # The per-model test loader is created by get_dataloaders but is not
    # consumed during training.
    train_loader, val_loader, _ = get_dataloaders(
        train_ds, val_ds, test_ds,
        cfg.training_params['batch_size'],
        cfg.training_params['num_workers']
    )

    # NOTE(review): CNN is not defined in the cells shown here; it must be
    # defined or imported before main() runs.
    model = CNN(cfg.num_classes[model_key])

    # Define loss and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rates[model_key])

    trained_model = train_model(
        model, train_loader, val_loader,
        loss_fn, optimizer,
        cfg.device,
        cfg.training_params['epochs'],
        model_key
    )

    # Save the trained model's state_dict
    torch.save(trained_model.state_dict(), cfg.model_save_paths[model_key])
    print(f"Model {model_key} saved to {cfg.model_save_paths[model_key]}\n")
    return trained_model


def _evaluate_combined(combined_model, transform):
    """Score the combined model on the test split of the '1_12' dataset; returns (accuracy, avg loss)."""
    test_dataset = load_dataset(cfg.data_paths['1_12'], transform)
    # Split exactly once: every split_dataset() call draws a new random
    # partition, so a second call would describe a different test set.
    # NOTE(review): `transform` is the pipeline from get_data_transforms(),
    # which includes a random rotation, so test images are augmented too.
    _, _, combined_test_loader = get_dataloaders(
        *split_dataset(test_dataset),
        cfg.training_params['batch_size'],
        cfg.training_params['num_workers']
    )

    # Reuse the shared evaluation loop; the tqdm wrapper only adds the
    # progress bar and still exposes len() of the underlying loader.
    progress = tqdm(combined_test_loader, desc="Evaluating Combined Model")
    return evaluate_model(combined_model, progress, nn.CrossEntropyLoss(), cfg.device)


def main():
    """
    Train every model, assemble the combined model, and report its test metrics.

    Steps:
        1. Train and save one model per key ('1_12', '1_4', '5_8', '9_12', 'Top').
        2. Build a CombinedModel that routes with the 'Top' model to the
           '1_4', '5_8' and '9_12' models.
        3. Evaluate the combined model on the test split of the '1_12'
           dataset and print accuracy and loss.
    """
    # Define data transformations
    transform = get_data_transforms()

    # Train every model and keep the trained instances by key
    trained_models = {}
    for model_key in ['1_12', '1_4', '5_8', '9_12', 'Top']:
        trained_models[model_key] = _train_and_save(model_key, transform)

    # Initialize Combined Model. Moving the combined module to the device
    # also moves the top-level and specialized models registered inside it.
    specialized_model_keys = ['1_4', '5_8', '9_12']
    specialized_models = {key: trained_models[key] for key in specialized_model_keys}
    combined_model = CombinedModel(
        top_model=trained_models['Top'],
        specialized_models=specialized_models,
        device=cfg.device
    ).to(cfg.device)

    # =========================
    # Evaluation of Combined Model
    # =========================
    test_accuracy, test_avg_loss = _evaluate_combined(combined_model, transform)
    print(f"Combined Model - Test Accuracy: {test_accuracy:.4f}, Test Loss: {test_avg_loss:.4f}")

if __name__ == "__main__":
    main()