In [None]:
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torchsummary import summary
from thop import profile
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Enhanced structure
class HardSwish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``, applied elementwise."""

    def forward(self, x):
        gate = F.relu6(x + 3)  # piecewise-linear gate, clipped to [0, 6]
        return x * gate / 6

class MixConv(nn.Module):
    """Mixed-kernel convolution: split channels into groups, convolve each group with its own kernel size.

    The input channels are split into ``len(kernel_sizes)`` near-equal chunks.
    Chunk ``i`` is convolved with ``kernel_sizes[i]``, and the results are
    concatenated along the channel dimension. When a chunk's input and output
    widths match, its convolution is depthwise (``groups`` equals its channel
    count).

    Args:
        in_channels: channels of the input tensor (must be >= number of kernels).
        out_channels: channels of the output tensor (must be >= number of kernels).
        kernel_sizes: one kernel size per branch (int or (h, w) tuple).
        stride: stride shared by every branch.
        padding: ``None`` (default) pads each branch by ``kernel_size // 2`` so
            all branches keep the same spatial size for odd kernels. Any other
            value is passed unchanged to every branch's ``nn.Conv2d``.

    Raises:
        ValueError: if ``kernel_sizes`` is empty or a channel count is smaller
            than the number of kernels.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes, stride=1, padding=None):
        super(MixConv, self).__init__()
        if not kernel_sizes:
            raise ValueError("kernel_sizes must contain at least one kernel size")
        self.groups = len(kernel_sizes)
        if in_channels < self.groups or out_channels < self.groups:
            raise ValueError(
                f"in_channels ({in_channels}) and out_channels ({out_channels}) "
                f"must each be at least len(kernel_sizes) ({self.groups})"
            )
        # Split sizes sum exactly to the channel counts, so no channel is dropped
        # when the count is not divisible by the number of kernels.
        self.in_splits = self._split_channels(in_channels, self.groups)
        out_splits = self._split_channels(out_channels, self.groups)
        self.convs = nn.ModuleList([
            nn.Conv2d(
                c_in, c_out, kernel_size=k, stride=stride,
                padding=self._same_padding(k) if padding is None else padding,
                # gcd divides both widths, so it is always a legal group count;
                # it equals c_in (depthwise) whenever c_in == c_out.
                groups=math.gcd(c_in, c_out),
            )
            for c_in, c_out, k in zip(self.in_splits, out_splits, kernel_sizes)
        ])

    @staticmethod
    def _split_channels(total, parts):
        """Return ``parts`` near-equal ints summing to ``total`` (earlier chunks take the remainder)."""
        base, extra = divmod(total, parts)
        return [base + 1 if i < extra else base for i in range(parts)]

    @staticmethod
    def _same_padding(kernel_size):
        """Padding that preserves spatial size at stride 1 for odd kernel sizes."""
        if isinstance(kernel_size, (tuple, list)):
            return tuple(k // 2 for k in kernel_size)
        return kernel_size // 2

    def forward(self, x):
        chunks = torch.split(x, self.in_splits, dim=1)
        return torch.cat([conv(chunk) for conv, chunk in zip(self.convs, chunks)], dim=1)

class SandGlassBlock(nn.Module):
    """1x1 expand -> MixConv (3/5/7 kernels) -> 1x1 project, with an identity shortcut.

    The shortcut ``x + block(x)`` is used only when the block keeps both the
    channel count (``in_channels == out_channels``) and the spatial size
    (stride 1 in every dimension).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the projection.
        expansion_factor: multiplier giving the hidden width
            ``in_channels * expansion_factor``.
        stride: stride of the mixed convolution. Either an int or an (h, w)
            tuple, such as the ``stride`` attribute of an ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, expansion_factor, stride):
        super(SandGlassBlock, self).__init__()
        self.stride = stride
        mid_channels = in_channels * expansion_factor

        # Accept tuple strides too. Comparing a tuple such as (1, 1) with 1
        # yields False, which would silently disable the shortcut.
        strides = tuple(stride) if isinstance(stride, (tuple, list)) else (stride,)
        self.use_residual = in_channels == out_channels and all(s == 1 for s in strides)

        self.expand = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            HardSwish()
        )

        self.dwconv = nn.Sequential(
            # Padding is left to MixConv's default rather than forced to 1:
            # padding=1 only preserves spatial size for the 3x3 kernel, so the
            # 5x5 / 7x7 branches would come out smaller and fail to concatenate.
            MixConv(mid_channels, mid_channels, kernel_sizes=[3, 5, 7], stride=stride),
            nn.BatchNorm2d(mid_channels),
            HardSwish()
        )

        # Linear projection: BatchNorm with no activation afterwards.
        self.project = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = self.project(self.dwconv(self.expand(x)))
        return x + out if self.use_residual else out

# Define the model architecture (same as when it was saved)
class CustomEfficientNet(nn.Module):
    """EfficientNet-B0 backbone with MBConv replacement and a custom classifier head.

    NOTE(review): the checkpoint loaded later with ``load_state_dict`` must
    match the module layout built here. Any change to which blocks
    ``replace_mbconv`` swaps out changes the state_dict keys and breaks loading.
    """

    def replace_mbconv(self):
        """Swap direct ``MBConv`` children of ``base_model.features`` for ``SandGlassBlock``s.

        Only children whose ``block[1].expand_ratio`` is 1 or 6 are replaced, using
        that ratio as the expansion factor. Other children are left untouched.

        NOTE(review): in torchvision's EfficientNet, the direct children of
        ``features`` appear to be ``nn.Sequential`` stages that wrap the MBConv
        blocks, and ``block[1]`` does not appear to expose ``expand_ratio``. If
        so, this loop replaces nothing. Confirm against the installed
        torchvision before relying on it, and before "fixing" it, because that
        would change the checkpoint layout.
        """
        def replace_block(block, expansion_factor):
            in_channels = block[0][0].in_channels
            out_channels = block[-1].out_channels
            stride = block[0][0].stride
            # nn.Conv2d stores stride as a tuple; SandGlassBlock compares it with 1.
            if isinstance(stride, tuple):
                stride = stride[0]
            return SandGlassBlock(in_channels, out_channels, expansion_factor=expansion_factor, stride=stride)

        for name, module in self.base_model.features.named_children():
            if not isinstance(module, models.efficientnet.MBConv):
                continue
            expand_ratio = module.block[1].expand_ratio
            if expand_ratio not in (1, 6):  # only MBConv1 / MBConv6 are replaced
                # Skip instead of reaching setattr with an unbound new_module,
                # or with the previous iteration's new_module.
                continue
            new_module = replace_block(module.block, expansion_factor=expand_ratio)
            setattr(self.base_model.features, name, new_module)

    def __init__(self, num_classes=2):
        """Build the network.

        Args:
            num_classes: number of output logits of the final Linear layer.
        """
        super(CustomEfficientNet, self).__init__()
        # Randomly initialised backbone; the weights come from the saved checkpoint.
        # NOTE(review): newer torchvision versions deprecate ``pretrained`` in
        # favour of ``weights=None``; confirm against the installed version.
        self.base_model = models.efficientnet_b0(pretrained=False)

        # Replace MBConv1 and MBConv6 blocks
        self.replace_mbconv()

        # Modify the classifier: Linear(num_ftrs, 256) -> HardSwish -> Dropout(0.5) -> Linear(256, num_classes)
        num_ftrs = self.base_model.classifier[1].in_features
        self.base_model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 256),
            HardSwish(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.base_model(x)

# Instantiate the model
model = CustomEfficientNet(num_classes=2)

# Move model to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Load the saved model state. map_location remaps tensors that were saved on a
# GPU onto the current device, so the checkpoint also loads on CPU-only machines.
model_path = 'Saved Model/EnhancedENetB0_v0.pth'
model.load_state_dict(torch.load(model_path, map_location=device))

# Set the model to evaluation mode
model.eval()

# Evaluation-time transformations shared by the train, test and demo datasets
# (resize, tensor conversion, ImageNet mean/std normalisation; no augmentation)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Set dataset paths
train_dataset = datasets.ImageFolder(root='dataset/train', transform=transform)
test_dataset = datasets.ImageFolder(root='dataset/test', transform=transform)
demo_dataset = datasets.ImageFolder(root='dataset_demo', transform=transform)

# Adjust batch size
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
demo_loader = DataLoader(demo_dataset, batch_size=batch_size, shuffle=False)

# Function to calculate accuracy and get predictions
def calculate_accuracy_and_predictions(loader, model, device=None):
    """Run ``model`` over ``loader`` and collect accuracy, labels and predictions.

    The model is switched to eval mode and left there.

    Args:
        loader: iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        model: classifier whose output has one score per class along dim 1.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        ``(accuracy, all_labels, all_preds)``:
        - ``accuracy`` is the percentage of correct predictions, or NaN if
          ``loader`` yields no samples.
        - ``all_labels`` and ``all_preds`` are lists of true and predicted
          class indices, in loader order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Index of the highest score along the class dimension
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(predicted.cpu().tolist())
    # Guard against ZeroDivisionError when the loader is empty
    accuracy = 100 * correct / total if total else float("nan")
    return accuracy, all_labels, all_preds

# Calculate training, testing and demo accuracy and predictions
train_accuracy, train_labels, train_preds = calculate_accuracy_and_predictions(train_loader, model)
test_accuracy, test_labels, test_preds = calculate_accuracy_and_predictions(test_loader, model)
demo_accuracy, demo_labels, demo_preds = calculate_accuracy_and_predictions(demo_loader, model)

# Calculate operation count and number of parameters for a single 224x224 RGB image.
# thop's profile() counts multiply-accumulate operations (MACs), not FLOPs;
# one MAC is roughly two FLOPs (a multiply plus an add).
input_size = (1, 3, 224, 224)
macs, params = profile(model, inputs=(torch.randn(input_size).to(device),))
flops = 2 * macs

print(f'Training Accuracy: {train_accuracy:.2f}%')
print(f'Testing Accuracy: {test_accuracy:.2f}%')
print(f'Demo Accuracy: {demo_accuracy:.2f}%')
print(f'MACs: {macs}')
print(f'FLOPs (approx. 2 x MACs): {flops}')
print(f'Number of parameters: {params}')

# Print model summary
summary(model, input_size[1:])

# Plot confusion matrix
def plot_confusion_matrix(labels, preds, title, display_labels=('Adults', 'Children')):
    """Plot and show the confusion matrix for one dataset split.

    Args:
        labels: true class indices.
        preds: predicted class indices, same length as ``labels``.
        title: figure title.
        display_labels: names for class indices ``0 .. K-1``, in index order.
            Defaults to ``('Adults', 'Children')``.
            NOTE(review): this default assumes ImageFolder mapped those
            classes to 0 and 1; check against ``train_dataset.classes``.
    """
    display_labels = list(display_labels)
    # Pin the label set so the matrix is always K x K, even when a split has no
    # samples or predictions for some class. Otherwise sklearn returns a smaller
    # matrix that no longer matches display_labels.
    cm = confusion_matrix(labels, preds, labels=list(range(len(display_labels))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)
    disp.plot(cmap=plt.cm.Blues)
    plt.title(title)
    plt.show()

# Plot one confusion matrix per dataset split, in order: training, testing, demo
for split_labels, split_preds, split_title in (
    (train_labels, train_preds, 'Training Confusion Matrix'),
    (test_labels, test_preds, 'Testing Confusion Matrix'),
    (demo_labels, demo_preds, 'Demo Confusion Matrix'),
):
    plot_confusion_matrix(split_labels, split_preds, split_title)