# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torchsummary import summary
from thop import profile

# Preprocessing pipeline shared by the training and test sets: resize to a
# fixed resolution, convert to a tensor, then normalise each channel.
IMAGE_SIZE = (224, 224)
# NOTE(review): presumably the ImageNet channel statistics expected by the
# pretrained backbone - confirm.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

resize_step = transforms.Resize(IMAGE_SIZE)
tensor_step = transforms.ToTensor()
normalize_step = transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)
transform = transforms.Compose([resize_step, tensor_step, normalize_step])

# Each split lives in its own folder; ImageFolder infers labels from the
# sub-directory names.
train_dataset = datasets.ImageFolder('dataset/train', transform=transform)
test_dataset = datasets.ImageFolder('dataset/test', transform=transform)

# Only the training loader shuffles; the test loader keeps dataset order.
batch_size = 32
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
)
test_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=batch_size,
)

# Building blocks for the enhanced EfficientNet variant
class HardSwish(nn.Module):
    """Hard-swish activation: ``x * ReLU6(x + 3) / 6``, applied element-wise."""

    def forward(self, x):
        # ReLU6 clamps (x + 3) to [0, 6]; the division by 6 afterwards turns
        # that into a gate in [0, 1] multiplying the input.
        gate = F.relu6(x + 3)
        gated = x * gate
        return gated / 6

class MixConv(nn.Module):
    """Mixed-kernel depthwise convolution.

    The input channels are partitioned into ``len(kernel_sizes)`` contiguous
    groups. Each group is convolved with its own kernel size and the branch
    outputs are concatenated back along the channel dimension.

    Args:
        in_channels: Number of input channels.
        out_channels: Total number of output channels over all branches.
        kernel_sizes: One integer kernel size per branch. All sizes must have
            the same parity so every branch yields the same spatial size.
        stride: Stride used by every branch.
        padding: Padding applied to the branch with the *smallest* kernel.
            A branch with a larger kernel ``k`` receives
            ``(k - min(kernel_sizes)) // 2`` extra padding so that all branch
            outputs have identical height/width and can be concatenated.

    Raises:
        ValueError: If ``kernel_sizes`` is empty, mixes odd and even sizes, or
            there are fewer channels than branches. ``nn.Conv2d`` additionally
            raises ``ValueError`` when a branch's output channels are not a
            multiple of its input channels.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes, stride=1, padding=1):
        super(MixConv, self).__init__()
        kernel_sizes = list(kernel_sizes)
        self.groups = len(kernel_sizes)
        if self.groups == 0:
            raise ValueError("kernel_sizes must contain at least one kernel size")
        if min(in_channels, out_channels) < self.groups:
            raise ValueError(
                "in_channels and out_channels must each be >= len(kernel_sizes)"
            )

        smallest = min(kernel_sizes)
        if any((k - smallest) % 2 for k in kernel_sizes):
            raise ValueError("kernel_sizes must all be odd or all be even")

        # Channel counts need not divide evenly: the remainder is spread over
        # the first branches so that no channel is dropped.
        self.in_splits = self._split_channels(in_channels, self.groups)
        out_splits = self._split_channels(out_channels, self.groups)

        # Every branch only ever sees its own slice of the input, so it is
        # sized (and grouped) by that slice rather than by ``in_channels``.
        self.convs = nn.ModuleList([
            nn.Conv2d(
                branch_in,
                branch_out,
                kernel_size=k,
                stride=stride,
                padding=self._branch_padding(padding, (k - smallest) // 2),
                groups=branch_in,
            )
            for k, branch_in, branch_out in zip(kernel_sizes, self.in_splits, out_splits)
        ])

    @staticmethod
    def _split_channels(total, parts):
        """Split ``total`` channels into ``parts`` near-equal counts summing to ``total``."""
        base, remainder = divmod(total, parts)
        return [base + 1 if index < remainder else base for index in range(parts)]

    @staticmethod
    def _branch_padding(padding, extra):
        """Return ``padding`` enlarged by ``extra`` on every spatial dimension."""
        if extra == 0:
            return padding
        if isinstance(padding, int):
            return padding + extra
        return tuple(p + extra for p in padding)

    def forward(self, x):
        chunks = torch.split(x, self.in_splits, dim=1)
        outputs = [conv(chunk) for conv, chunk in zip(self.convs, chunks)]
        return torch.cat(outputs, dim=1)

class SandGlassBlock(nn.Module):
    """1x1 expand -> MixConv -> 1x1 project block with an optional skip connection.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        expansion_factor: Multiplier applied to ``in_channels`` to obtain the
            width of the intermediate (MixConv) stage.
        stride: Stride of the MixConv stage, either an int or a per-dimension
            tuple such as the ``stride`` attribute of an ``nn.Conv2d``.

    The input is added to the output only when the block keeps both the
    spatial size (unit stride) and the channel count.
    """

    def __init__(self, in_channels, out_channels, expansion_factor, stride):
        super(SandGlassBlock, self).__init__()
        self.stride = stride
        mid_channels = in_channels * expansion_factor

        # ``nn.Conv2d`` exposes its stride as a tuple, so a plain ``== 1``
        # comparison would never match; normalise before testing.
        strides = (stride, stride) if isinstance(stride, int) else tuple(stride)
        self.use_residual = all(s == 1 for s in strides) and in_channels == out_channels

        self.expand = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            HardSwish()
        )

        self.dwconv = nn.Sequential(
            MixConv(mid_channels, mid_channels, kernel_sizes=[3, 5, 7], stride=stride, padding=1),
            nn.BatchNorm2d(mid_channels),
            HardSwish()
        )

        # No activation after the projection.
        self.project = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        out = self.project(self.dwconv(self.expand(x)))
        if self.use_residual:
            return x + out
        return out

# Define the model (EfficientNet-B0 with custom layers)
class CustomEfficientNet(nn.Module):
    """EfficientNet-B0 backbone with its MBConv blocks replaced by SandGlassBlock.

    The stem, the final 1x1 convolution and the pooling of the torchvision
    model are kept; the classifier is replaced by a small two-layer head.

    Args:
        num_classes: Number of output classes of the new classifier head.
    """

    def replace_mbconv(self):
        """Swap every MBConv block under ``base_model.features`` for a SandGlassBlock.

        The replacement blocks are freshly constructed; no weights are copied
        from the MBConv blocks they replace.
        """
        def build_replacement(mbconv):
            block = mbconv.block
            in_channels = block[0][0].in_channels
            out_channels = block[-1][0].out_channels

            # The depthwise convolution is the only grouped convolution in the
            # block; it carries the block's stride and its expanded width.
            depthwise = next(
                (
                    layer for layer in block.modules()
                    if isinstance(layer, nn.Conv2d) and layer.groups > 1
                ),
                None,
            )
            if depthwise is None:
                raise ValueError("MBConv block contains no depthwise convolution")

            # The original code expected expansion ratios of 1 (MBConv1) or
            # 6 (MBConv6); derive the ratio from the layer widths instead.
            expansion_factor = max(1, depthwise.in_channels // in_channels)
            return SandGlassBlock(
                in_channels,
                out_channels,
                expansion_factor=expansion_factor,
                stride=depthwise.stride,
            )

        def swap_children(parent):
            # MBConv blocks sit inside per-stage containers, not directly under
            # ``features``, so descend until one is found. Iterate a snapshot
            # because children are reassigned along the way.
            for name, child in list(parent.named_children()):
                if isinstance(child, models.efficientnet.MBConv):
                    setattr(parent, name, build_replacement(child))
                else:
                    swap_children(child)

        swap_children(self.base_model.features)

    def __init__(self, num_classes=2):
        super(CustomEfficientNet, self).__init__()
        # Prefer the ``weights=`` API; fall back to the legacy ``pretrained``
        # flag on torchvision releases that predate it.
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is not None:
            self.base_model = models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.base_model = models.efficientnet_b0(pretrained=True)

        # Replace MBConv1 and MBConv6 blocks
        self.replace_mbconv()

        # Modify the classifier
        num_ftrs = self.base_model.classifier[1].in_features
        self.base_model.classifier = nn.Sequential(
            nn.Linear(num_ftrs, 256),
            HardSwish(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.base_model(x)

# Run on the first GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the two-class network and place it on the chosen device
model = CustomEfficientNet(num_classes=2).to(device)

# Cross-entropy objective optimised with Adam over every model parameter
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training function
def train_model(model, train_loader, criterion, optimizer, num_epochs=25):
    """Train ``model`` in place for ``num_epochs`` epochs and return it.

    After every epoch the mean loss and accuracy over
    ``train_loader.dataset`` are printed.

    Args:
        model: Network to optimise; batches are moved to the device of its
            first parameter (CPU if it has no parameters).
        train_loader: Loader yielding ``(inputs, labels)`` batches and
            exposing a sized ``dataset`` attribute.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, left in training mode.

    Raises:
        ValueError: If ``train_loader.dataset`` is empty.
    """
    dataset_size = len(train_loader.dataset)
    if dataset_size == 0:
        raise ValueError("train_loader.dataset is empty; nothing to train on")

    # Follow the model's own placement instead of relying on a module-level
    # ``device`` variable that may not match (or may not exist).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')

        model.train()

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(target_device)
            labels = labels.to(target_device)

            optimizer.zero_grad()

            # Gradients are forced on so training still works if the caller
            # is inside a ``torch.no_grad()`` context.
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

            preds = outputs.argmax(dim=1)
            # The criterion returns a batch mean, so weight it by the batch
            # size to make the epoch average exact for a short last batch.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / dataset_size
        epoch_acc = running_corrects / dataset_size

        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        print('-' * 15)

    return model

# Where the trained weights are written to and read back from
checkpoint_dir = 'Saved Model'
checkpoint_path = 'Saved Model/EnhancedENetB0_v0.pth'

# Run the full 25-epoch training schedule
model = train_model(model, train_loader, criterion, optimizer, num_epochs=25)

# Create the output folder if needed, then persist only the state dict
os.makedirs(checkpoint_dir, exist_ok=True)
trained_state = model.state_dict()
torch.save(trained_state, checkpoint_path)

# Round-trip through the checkpoint and switch to evaluation mode
restored_state = torch.load(checkpoint_path)
model.load_state_dict(restored_state)
model.eval()

# Function to calculate accuracy
def calculate_accuracy(loader, model):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to evaluate; it is switched to eval mode and left
            there. Batches are moved to the device of its first parameter
            (CPU if it has no parameters).

    Returns:
        ``100 * correct / total`` as a float in ``[0, 100]``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    # Follow the model's own placement instead of relying on a module-level
    # ``device`` variable.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(target_device), labels.to(target_device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("loader yielded no samples; accuracy is undefined")
    return 100 * correct / total

# Evaluate the reloaded model on both splits
train_accuracy = calculate_accuracy(train_loader, model)
test_accuracy = calculate_accuracy(test_loader, model)

print(f'Training Accuracy: {train_accuracy:.2f}%')
print(f'Testing Accuracy: {test_accuracy:.2f}%')

# Profile the cost of a single forward pass on one random 3x224x224 image
input_size = (1, 3, 224, 224)
dummy_batch = torch.randn(input_size).to(device)
# NOTE(review): thop's README names the first return value `macs`; confirm
# whether this figure should be reported as FLOPs.
flops, params = profile(model, inputs=(dummy_batch,))

print(f'FLOPs: {flops}')
print(f'Number of parameters: {params}')

# Layer-by-layer summary; torchsummary takes the shape without the batch axis
per_sample_shape = input_size[1:]
summary(model, per_sample_shape)