In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages joined to their input by a skip path.

    The skip path is an empty ``nn.Sequential`` (a pass-through) when the
    input and output tensors already have matching shapes. Otherwise it is a
    1x1 convolution followed by batch norm, so the two branches can be added.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution and of the 1x1
            projection on the skip path.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch: conv -> bn -> relu -> conv -> bn
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip branch: project only when spatial size or channel count changes
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        # The skip branch is evaluated after the main branch on purpose:
        # forward hooks observe sub-modules in call order.
        skip = self.shortcut(x)
        return F.relu(main + skip)

class ImprovedCNN(nn.Module):
    """Residual CNN: a 3x3 stem, four stages of two residual blocks, then a linear head.

    The stem maps 3 input channels to 64. The stages widen to 64, 128, 256
    and 512 channels, and every stage after the first uses stride 2 in its
    first block. Features are globally average-pooled to 1x1 before the
    classifier.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)

        # Head: global average pooling followed by the classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        self._initialize_weights()

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Build one stage; only its first block changes channels or stride."""
        blocks = [ResidualBlock(in_channels, out_channels, stride)]
        blocks.extend(
            ResidualBlock(out_channels, out_channels, 1)
            for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def _initialize_weights(self):
        """Re-initialise every conv, batch-norm and linear layer in the model."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                # Start batch norm with unit scale and zero shift
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return class logits with one row per input sample."""
        x = F.relu(self.bn1(self.conv1(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        pooled = self.avgpool(x)
        features = torch.flatten(pooled, 1)
        return self.fc(features)




In [2]:
# Training configuration for better results
def get_training_config():
    """Return a fresh dictionary of the recommended training hyperparameters.

    Returns:
        dict: Keys are ``learning_rate``, ``momentum``, ``weight_decay``,
        ``batch_size``, ``epochs``, ``lr_scheduler``, ``warmup_epochs`` and
        ``label_smoothing``.
    """
    return dict(
        learning_rate=0.1,
        momentum=0.9,
        weight_decay=5e-4,
        batch_size=128,
        epochs=200,
        lr_scheduler='cosine',  # or 'multistep'
        warmup_epochs=5,
        label_smoothing=0.1,
    )


# Data augmentation for training
def get_transforms():
    """Build the training and evaluation preprocessing pipelines.

    The training pipeline augments PIL images (padded random crop to 32x32,
    horizontal flip, colour jitter, rotation up to 15 degrees), converts them
    to tensors, normalises each channel, and finally applies random erasing
    on the tensor. The evaluation pipeline only converts and normalises.

    Returns:
        tuple: ``(train_transform, test_transform)``.
    """
    from torchvision import transforms as T

    # NOTE(review): these look like the commonly quoted CIFAR-10 channel
    # statistics; confirm they match the dataset actually used.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    def tensor_steps():
        # Fresh instances per pipeline so the two Compose objects share nothing
        return [T.ToTensor(), T.Normalize(channel_mean, channel_std)]

    augmentations = [
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        T.RandomRotation(15),
    ]
    # RandomErasing runs last: it is applied after ToTensor/Normalize
    erasing = T.RandomErasing(p=0.5, scale=(0.02, 0.33))

    train_transform = T.Compose([*augmentations, *tensor_steps(), erasing])
    test_transform = T.Compose(tensor_steps())

    return train_transform, test_transform


# Example training loop with improvements
def train_epoch(model, train_loader, optimizer, criterion, device, epoch,
                warmup_epochs=5, base_lr=0.1):
    """Run one training epoch, with linear learning-rate warmup.

    While ``epoch < warmup_epochs`` the learning rate of every optimizer
    param group is overwritten before each step. It ramps linearly from
    ``base_lr / total_warmup_steps`` on the very first batch up to exactly
    ``base_lr`` on the last warmup batch. The first step therefore never
    runs with a learning rate of zero. Outside the warmup window the
    optimizer's learning rate is left untouched.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer exposing ``param_groups``, ``zero_grad`` and ``step``.
        criterion: Callable mapping ``(outputs, targets)`` to a scalar loss.
        device: Device the batches are moved to.
        epoch: Zero-based index of the current epoch.
        warmup_epochs: Number of initial epochs that use the warmup ramp.
        base_lr: Learning rate reached at the end of warmup.

    Returns:
        tuple: ``(avg_loss, accuracy)``, where ``avg_loss`` is the mean of
        the per-batch losses and ``accuracy`` is a percentage in [0, 100].
        Returns ``(0.0, 0.0)`` when the loader yields no batches.
    """
    model.train()

    num_batches = len(train_loader)
    if num_batches == 0:
        # Nothing to train on; avoid dividing by zero below
        return 0.0, 0.0

    in_warmup = epoch < warmup_epochs
    total_warmup_steps = warmup_epochs * num_batches

    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        if in_warmup:
            # 1-based step count: the first update gets a non-zero rate and
            # the last warmup update lands exactly on base_lr.
            step = epoch * num_batches + batch_idx + 1
            warmup_lr = base_lr * step / total_warmup_steps
            for param_group in optimizer.param_groups:
                param_group['lr'] = warmup_lr

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    accuracy = 100. * correct / total if total else 0.0
    avg_loss = running_loss / num_batches
    return avg_loss, accuracy


# Model summary functions
def count_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    return total_params, trainable_params


def print_model_summary(model, input_size=(3, 32, 32), batch_size=1, device='cpu'):
    """
    Print model summary similar to TensorFlow's model.summary()

    A forward hook is attached to every sub-module except the model itself
    and ``nn.Sequential`` / ``nn.ModuleList`` containers. One random batch
    is then pushed through the model in eval mode under ``torch.no_grad()``,
    and one row is printed per hook invocation, in call order.

    Parameter counts cover the parameters a module owns directly
    (``module.parameters(recurse=False)``), so composite modules show 0 and
    their children carry the numbers. Trainable counts are accumulated per
    parameter from ``requires_grad``.

    The hooks are always removed and every module's train/eval flag is
    restored to its previous value, even if the forward pass raises.

    Args:
        model: PyTorch model
        input_size: Input tensor size (C, H, W)
        batch_size: Batch size for forward pass
        device: Device to run on
    """
    summary = {}
    hooks = []

    def hook(module, inputs, output):
        # Key is "<ClassName>-<1-based call index>"
        m_key = f"{module.__class__.__name__}-{len(summary) + 1}"
        first_input = inputs[0] if inputs else None
        direct_params = list(module.parameters(recurse=False))
        summary[m_key] = {
            "input_shape": list(first_input.size()) if isinstance(first_input, torch.Tensor) else [],
            # Output is expected to be a tensor or an iterable of tensors
            "output_shape": (
                list(output.size()) if isinstance(output, torch.Tensor)
                else [list(o.size()) for o in output]
            ),
            "nb_params": sum(p.numel() for p in direct_params),
            "trainable_params": sum(p.numel() for p in direct_params if p.requires_grad),
        }

    # model.modules() yields each distinct module once, so a module shared
    # between two parents is not hooked (and counted) twice.
    for module in model.modules():
        if module is model or isinstance(module, (nn.Sequential, nn.ModuleList)):
            continue
        hooks.append(module.register_forward_hook(hook))

    # Remember per-module modes: model.eval() flips all of them, and some
    # sub-modules may have been in eval mode deliberately.
    saved_modes = [(m, m.training) for m in model.modules()]
    model.eval()
    try:
        # Shape probing only; no autograd graph is needed
        with torch.no_grad():
            x = torch.randn(batch_size, *input_size).to(device)
            model(x)
    finally:
        for h in hooks:
            h.remove()
        for m, was_training in saved_modes:
            m.training = was_training

    # Print summary
    print("=" * 100)
    print(f"{'Layer (type)':<40} {'Output Shape':<25} {'Param #':<15}")
    print("=" * 100)

    total_params = 0
    trainable_params = 0

    for layer, info in summary.items():
        line = f"{layer:<40} {str(info['output_shape']):<25} {info['nb_params']:>15,}"
        print(line)
        total_params += info["nb_params"]
        trainable_params += info["trainable_params"]

    print("=" * 100)
    print(f"Total params: {total_params:,}")
    print(f"Trainable params: {trainable_params:,}")
    print(f"Non-trainable params: {total_params - trainable_params:,}")
    print("=" * 100)


def print_simple_summary(model):
    """Print one row per named parameter, followed by totals and a size estimate.

    Each row shows the parameter name, its element count and whether
    ``requires_grad`` is set. The size estimate assumes 4 bytes per element.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print("\n" + heavy_rule)
    print(f"{'Model Architecture Summary':^60}")
    print(heavy_rule)

    print(f"\n{'Layer':<30} {'Parameters':>15} {'Trainable':>12}")
    print(light_rule)

    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        num_params = param.numel()
        total_params += num_params
        if param.requires_grad:
            trainable_params += num_params
            trainable_status = "Yes"
        else:
            trainable_status = "No"
        print(f"{name:<30} {num_params:>15,} {trainable_status:>12}")

    frozen_params = total_params - trainable_params
    print(light_rule)
    print(f"{'Total Parameters':<30} {total_params:>15,}")
    print(f"{'Trainable Parameters':<30} {trainable_params:>15,}")
    print(f"{'Non-trainable Parameters':<30} {frozen_params:>15,}")
    print(heavy_rule)

    # Bytes -> MiB, assuming every element is stored as float32
    param_size = total_params * 4 / (1024**2)
    print(f"\nEstimated Model Size: {param_size:.2f} MB (float32)")
    print(heavy_rule + "\n")


# Example usage
if __name__ == "__main__":
    # Build the network on the GPU when one is available, otherwise on the CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = ImprovedCNN(num_classes=10).to(device)

    # Per-layer view driven by a real forward pass
    print("\n### DETAILED MODEL SUMMARY ###")
    print_model_summary(model, input_size=(3, 32, 32), device=device)

    # Per-parameter view taken from named_parameters()
    print("\n\n### PARAMETER SUMMARY BY LAYER ###")
    print_simple_summary(model)

    # Headline totals only
    total, trainable = count_parameters(model)
    print("\nQuick Stats:")
    print(f"  Total parameters: {total:,}")
    print(f"  Trainable parameters: {trainable:,}")

    # Optional alternative, left disabled: the third-party torchsummary package.
    # try:
    #     from torchsummary import summary
    #     print("\n\n### TORCHSUMMARY OUTPUT ###")
    #     summary(model, (3, 32, 32))
    # except ImportError:
    #     print("\n[Optional] Install torchsummary for more details: pip install torchsummary")


### DETAILED MODEL SUMMARY ###
Layer (type)                             Output Shape              Param #        
Conv2d-1                                 [1, 64, 32, 32]                     1,728
BatchNorm2d-2                            [1, 64, 32, 32]                       128
Conv2d-3                                 [1, 64, 32, 32]                    36,864
BatchNorm2d-4                            [1, 64, 32, 32]                       128
Conv2d-5                                 [1, 64, 32, 32]                    36,864
BatchNorm2d-6                            [1, 64, 32, 32]                       128
ResidualBlock-7                          [1, 64, 32, 32]                         0
Conv2d-8                                 [1, 64, 32, 32]                    36,864
BatchNorm2d-9                            [1, 64, 32, 32]                       128
Conv2d-10                                [1, 64, 32, 32]                    36,864
BatchNorm2d-11                           [1, 64, 32, 32