# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import math # For potential rounding/floor operations if needed, though int() conversion often suffices

# --- Configuration ---
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Model / training hyperparameters (adjust as needed).
WIDTH_MULTIPLIER = 1.0  # alpha in the MobileNet paper; typical values: 1.0, 0.75, 0.5, 0.25
NUM_CLASSES = 10        # CIFAR-10 has 10 classes
LEARNING_RATE = 0.001   # optimizer learning rate
BATCH_SIZE = 128        # training mini-batch size
NUM_EPOCHS = 20         # kept low for demonstration; increase for real training

# --- Building Blocks ---

class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise-separable convolution block from MobileNet V1.

    Computes ``ReLU(BN(pointwise(ReLU(BN(depthwise(x))))))``. The depthwise
    step is a 3x3 convolution applied to each input channel independently.
    The pointwise step is a 1x1 convolution that mixes channels. Each step is
    followed by batch normalization and ReLU.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        stride (int): Stride of the depthwise convolution; the pointwise
            convolution always uses stride 1. Default: 1.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Spatial filtering. groups == in_channels gives every input channel
        # its own 3x3 filter; padding=1 keeps H and W unchanged at stride 1.
        # No conv bias: the BatchNorm that follows has its own bias term.
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                   stride=stride, padding=1,
                                   groups=in_channels, bias=False)
        self.bn_depthwise = nn.BatchNorm2d(in_channels)
        self.relu_depthwise = nn.ReLU(inplace=True)

        # Channel mixing: a 1x1 convolution projecting to out_channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=1, padding=0, bias=False)
        self.bn_pointwise = nn.BatchNorm2d(out_channels)
        self.relu_pointwise = nn.ReLU(inplace=True)

    def forward(self, x):
        stages = (
            (self.depthwise, self.bn_depthwise, self.relu_depthwise),
            (self.pointwise, self.bn_pointwise, self.relu_pointwise),
        )
        for conv, norm, act in stages:
            x = act(norm(conv(x)))
        return x

# --- MobileNet V1 Architecture ---

class MobileNetV1(nn.Module):
    """
    MobileNet V1 implementation with Width Multiplier (alpha).

    Architecture:
    - a strided 3x3 stem convolution;
    - 13 depthwise-separable blocks;
    - global average pooling;
    - a linear classifier.

    Every channel count is scaled by ``alpha``. The total spatial stride is 32:
    a 224x224 input reaches the pooling layer as 7x7, and a 32x32 (CIFAR-10)
    input reaches it as 1x1.

    Args:
        alpha (float): Width multiplier, controls the number of channels.
                       alpha=1.0 is the baseline MobileNet. Must be > 0.
        num_classes (int): Number of output classes for classification.
        in_channels (int): Number of channels in the input images.
                           Default: 3 (RGB).

    Raises:
        ValueError: If ``alpha`` is not a positive number, or ``in_channels``
                    is less than 1.
    """
    def __init__(self, alpha=1.0, num_classes=1000, in_channels=3):
        super(MobileNetV1, self).__init__()
        # Reject alpha <= 0 explicitly; apply_alpha would otherwise clamp every
        # layer to a single channel and silently build a useless network.
        # `not alpha > 0` also rejects NaN.
        if not alpha > 0:
            raise ValueError(f"alpha must be a positive number, got {alpha!r}")
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels!r}")
        self.alpha = alpha
        self.num_classes = num_classes
        self.in_channels = in_channels

        # Helper to apply the width multiplier. int() truncates; max(1, ...)
        # keeps at least one channel for very small alpha.
        def apply_alpha(channels):
            return max(1, int(channels * self.alpha))

        # --- Stem ---
        # Standard (non-separable) 3x3 convolution with stride 2,
        # e.g. 224x224x(in_channels) -> 112x112x(32*alpha).
        initial_channels = apply_alpha(32)
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, initial_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(initial_channels),
            nn.ReLU(inplace=True)
        )
        current_channels = initial_channels

        # --- Body: Stacked Depthwise Separable Convolutions ---
        # Layer specifications based on Table 1 of the MobileNet paper, scaled by alpha.
        # Format: (output_baseline_channels, depthwise_stride)
        layer_configs = [
            (64, 1),
            (128, 2),
            (128, 1),
            (256, 2),
            (256, 1),
            (512, 2),
            # Five identical 512-channel blocks
            (512, 1), (512, 1), (512, 1), (512, 1), (512, 1),
            (1024, 2),
            # Table 1 lists stride 2 for the final depthwise layer, but that
            # layer's 7x7 input size implies stride 1. Stride 1 is used here,
            # as in common implementations. AdaptiveAvgPool2d would also cope
            # with stride 2.
            (1024, 1)
        ]

        body_layers = []
        for out_baseline_channels, stride in layer_configs:
            out_channels_alpha = apply_alpha(out_baseline_channels)
            body_layers.append(
                DepthwiseSeparableConv(current_channels, out_channels_alpha, stride=stride)
            )
            current_channels = out_channels_alpha  # next block's input width

        self.body = nn.Sequential(*body_layers)

        # --- Classifier Head ---
        # Global average pooling makes the head independent of input resolution.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))  # Output: B x C x 1 x 1
        self.flatten = nn.Flatten()              # Output: B x C
        # Raw logits; no activation/BN before the loss.
        self.fc = nn.Linear(current_channels, self.num_classes)

        # --- Weight Initialization ---
        self._initialize_weights()

    def _initialize_weights(self):
        """
        Re-initialize parameters in place.

        - Conv2d: Kaiming-normal (fan_out, ReLU gain), zero bias.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: N(0, 0.01) weights, zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Map an image batch to class logits.

        The input is presumably (B, in_channels, H, W); TODO confirm at call sites.
        Returns un-normalized logits of shape (B, num_classes).
        """
        x = self.stem(x)
        x = self.body(x)
        x = self.gap(x)
        x = self.flatten(x)
        x = self.fc(x)  # Logits output
        return x

# --- Data Loading and Preprocessing ---

# Per-channel normalization statistics for CIFAR-10, shared by both pipelines.
# NOTE(review): these std values are the ones widely used in PyTorch CIFAR
# examples. They reportedly differ from the dataset-wide per-channel std
# (~0.247, 0.243, 0.261). Confirm before changing, since any change alters
# the model's inputs.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
DATA_ROOT = './data'

# Training pipeline:
# - light augmentation: a random 32x32 crop of the 4-pixel-padded image,
#   plus a random horizontal flip;
# - tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, identical normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# NOTE: MobileNet was designed for 224x224 inputs. CIFAR-10 images (32x32)
# are used here without resizing. With the network's total stride of 32,
# they reach global pooling as 1x1 feature maps. To mirror the
# ImageNet setup, add transforms.Resize(224) to both pipelines.

print("Loading CIFAR-10 dataset...")
trainset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform_train,
)
trainloader = DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
)

testset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform_test,
)
# Evaluation batches are twice the training batch size and are not shuffled.
testloader = DataLoader(
    testset, batch_size=2 * BATCH_SIZE, shuffle=False, num_workers=2,
)
print("Dataset loaded.")

# Short display names for the 10 CIFAR-10 labels, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# --- Model, Loss, Optimizer ---

print(f"Initializing MobileNetV1 with alpha={WIDTH_MULTIPLIER}...")
# Build the network, then move all of its parameters and buffers to DEVICE.
model = MobileNetV1(num_classes=NUM_CLASSES, alpha=WIDTH_MULTIPLIER)
model = model.to(DEVICE)

# --- Sanity Check: Parameter Count ---
def count_parameters(model):
    """Return the total number of trainable scalars (``requires_grad=True``) in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

trainable_param_count = count_parameters(model)
print(f"Model initialized. Parameter count: {trainable_param_count:,}")
# To inspect the layer structure, uncomment: print(model)

# Cross-entropy on the raw logits returned by the model.
criterion = nn.CrossEntropyLoss()

# Optimizer. The MobileNet paper reports training with RMSprop; an RMSprop
# configuration is kept commented out for reference. Adam is used here
# because it tends to work reasonably without tuning.
# optimizer = optim.RMSprop(model.parameters(), lr=LEARNING_RATE, alpha=0.9, eps=1e-08, weight_decay=0.00004, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Optional learning-rate schedule, e.g. step decay (x0.1 every 10 epochs):
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# --- Training Loop ---

def train_one_epoch(epoch, log_interval=100):
    """
    Train the module-level ``model`` for one pass over ``trainloader``.

    Reads the globals ``model``, ``trainloader``, ``criterion``, ``optimizer``,
    ``DEVICE`` and ``NUM_EPOCHS``.

    Every ``log_interval`` mini-batches a progress line is printed. It shows
    two quantities with different scopes:
    - the mean loss over the last ``log_interval`` batches;
    - the accuracy accumulated since the start of the epoch.

    Args:
        epoch (int): Zero-based epoch index (used only for logging).
        log_interval (int): Batches between progress lines. Default: 100.

    Returns:
        float: Training accuracy over the whole epoch, in percent.

    Raises:
        ValueError: If ``log_interval`` is less than 1.
        RuntimeError: If ``trainloader`` yields no samples.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval!r}")

    model.train()  # BatchNorm uses batch statistics and updates running stats
    running_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()

    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Statistics. argmax returns the first maximal index on ties, the
        # same index torch.max(..., dim) returns.
        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (i + 1) % log_interval == 0:
            print(f'[Epoch {epoch + 1}/{NUM_EPOCHS}, Batch {i + 1}/{len(trainloader)}] '
                  f'Loss: {running_loss / log_interval:.3f} | '
                  f'Acc: {100 * correct / total:.2f}%')
            running_loss = 0.0  # restart the windowed loss average

    if total == 0:
        raise RuntimeError("train_one_epoch(): trainloader yielded no samples")

    epoch_time = time.time() - start_time
    epoch_acc = 100 * correct / total
    print(f'Epoch {epoch + 1} Training finished. Accuracy: {epoch_acc:.2f}%, Time: {epoch_time:.2f}s')
    return epoch_acc

# --- Validation Loop ---

def validate():
    """
    Evaluate the module-level ``model`` on ``testloader``.

    Runs with gradient tracking disabled and the model in eval mode, so
    BatchNorm uses its running statistics. Prints accuracy, mean per-sample
    loss and wall time, then returns the accuracy.

    Reads the globals ``model``, ``testloader``, ``criterion`` and ``DEVICE``.

    Returns:
        float: Test-set accuracy, in percent.

    Raises:
        RuntimeError: If ``testloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    total_loss = 0.0
    start_time = time.time()

    with torch.no_grad():  # no gradients needed during evaluation
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            # Re-weight each batch loss by its size so the final figure is a
            # per-sample mean even when the last batch is smaller. This
            # assumes criterion uses mean reduction (the nn.CrossEntropyLoss
            # default).
            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise RuntimeError("validate(): testloader yielded no samples")

    val_time = time.time() - start_time
    val_acc = 100 * correct / total
    avg_loss = total_loss / total
    print(f'Validation Accuracy: {val_acc:.2f}%, Avg Loss: {avg_loss:.4f}, Time: {val_time:.2f}s')
    print("-" * 30)
    return val_acc


# --- Main Training Execution ---

# Where the weights of the best epoch so far (by validation accuracy) are saved.
best_model_path = f'mobilenetv1_alpha{WIDTH_MULTIPLIER}_cifar10_best.pth'

print("Starting Training...")
best_val_acc = 0.0
for epoch in range(NUM_EPOCHS):
    train_one_epoch(epoch)
    val_acc = validate()

    # Optional: step a learning-rate scheduler here once one has been created:
    # scheduler.step()

    # Checkpoint only on strict improvement; otherwise move to the next epoch.
    if not val_acc > best_val_acc:
        continue
    print(f"Validation accuracy improved ({best_val_acc:.2f}% -> {val_acc:.2f}%). Saving model...")
    torch.save(model.state_dict(), best_model_path)
    best_val_acc = val_acc

print('Finished Training')
print(f"Best Validation Accuracy: {best_val_acc:.2f}%")

# --- Optional: Load best model and evaluate again ---
# print("Loading best model for final evaluation...")
# model.load_state_dict(torch.load(f'mobilenetv1_alpha{WIDTH_MULTIPLIER}_cifar10_best.pth'))
# validate()