Fair fine-tuning check across models

In [3]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import *
import warnings
warnings.filterwarnings('ignore')

def analyze_model_architecture(model, model_name):
    """Print a parameter-count overview of ``model`` and hand it back unchanged.

    The report consists of a banner with the upper-cased ``model_name``, the
    total parameter count, and one line per direct child module showing the
    child's attribute name, class name and parameter count. It is meant as a
    quick look before choosing which final stage/block to unfreeze.

    Args:
        model: The ``torch.nn.Module`` to summarise.
        model_name: Human-readable label shown in the banner.

    Returns:
        The very same ``model`` object, so the call can be chained.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"🔍 ANALYZING {model_name.upper()}")
    print(rule)

    # Module.parameters() yields each tensor once, so tied weights count once.
    grand_total = 0
    for tensor in model.parameters():
        grand_total += tensor.numel()
    print(f"Total parameters: {grand_total:,}")

    print("\nMain components:")
    for child_name, child in model.named_children():
        child_total = sum(t.numel() for t in child.parameters())
        kind = child.__class__.__name__
        print(f"  {child_name}: {kind} ({child_total:,} params)")

    return model

def _num_params(module):
    """Return the total number of parameters held by ``module`` and its children."""
    return sum(p.numel() for p in module.parameters())


def _print_recommendation_header(name):
    """Print the banner that opens each per-model recommendation."""
    print(f"\n🎯 FINE-TUNING RECOMMENDATION for {name}:")


# Appended to every recipe that also trains a ``classifier`` head, so the
# generated snippet unfreezes exactly the modules counted in ``params``.
_UNFREEZE_CLASSIFIER = (
    'for param in model.classifier.parameters():\n'
    '    param.requires_grad = True'
)


def _recommend_resnet(name, model):
    """ResNet: unfreeze the final residual stage ``layer4`` plus the ``fc`` head."""
    _print_recommendation_header(name)
    print("   Unfreeze: layer4 (entire final stage)")
    params = _num_params(model.layer4) + _num_params(model.fc)
    print(f"   Parameters to train: {params:,}")
    return {
        'strategy': 'layer4 + fc',
        'params': params,
        # Match on the top-level module prefix instead of a substring anywhere
        # in the parameter name, so unrelated names containing "fc" never match.
        'code': (
            'for name, param in model.named_parameters():\n'
            '    if name.startswith(("layer4.", "fc.")):\n'
            '        param.requires_grad = True'
        ),
    }


def _recommend_efficientnet(name, model):
    """EfficientNet: unfreeze the last two ``features`` entries plus ``classifier``."""
    _print_recommendation_header(name)
    n_blocks = len(model.features)
    print(f"   Total feature blocks: {n_blocks}")
    # Derived from the length (gives [7, 8] for B0, as before) so variants with
    # a different number of feature entries do not index out of range.
    start = max(n_blocks - 2, 0)
    last_blocks = list(range(start, n_blocks))
    params = (sum(_num_params(model.features[i]) for i in last_blocks)
              + _num_params(model.classifier))
    unfrozen = ", ".join(f"features[{i}]" for i in last_blocks)
    print(f"   Unfreeze: {unfrozen} + classifier")
    print(f"   Parameters to train: {params:,}")
    return {
        'strategy': f'features[{start}:] + classifier',
        'params': params,
        'code': (
            f'for i in {last_blocks}:\n'
            '    for param in model.features[i].parameters():\n'
            '        param.requires_grad = True\n'
            + _UNFREEZE_CLASSIFIER
        ),
    }


def _recommend_mobilenet(name, model):
    """MobileNet: unfreeze the last three ``features`` entries plus ``classifier``.

    Returns None (no recommendation) when the model has no ``features`` attribute.
    """
    _print_recommendation_header(name)
    if not hasattr(model, 'features'):
        return None
    print(f"   Total feature blocks: {len(model.features)}")
    classifier_params = _num_params(model.classifier)
    if 'v2' in name.lower():
        # MobileNetV2: features[16], [17], [18]
        last_block_params = sum(_num_params(model.features[i]) for i in range(16, 19))
        print("   Unfreeze: features[16:] + classifier")
        strategy = 'features[16:] + classifier'
        code = (
            'for i in range(16, len(model.features)):\n'
            '    for param in model.features[i].parameters():\n'
            '        param.requires_grad = True\n'
        )
    else:  # MobileNetV3
        last_block_params = sum(_num_params(block) for block in model.features[-3:])
        print("   Unfreeze: last 3 feature blocks + classifier")
        strategy = 'features[-3:] + classifier'
        code = (
            'for block in model.features[-3:]:\n'
            '    for param in block.parameters():\n'
            '        param.requires_grad = True\n'
        )
    params = last_block_params + classifier_params
    print(f"   Parameters to train: {params:,}")
    # The classifier is counted in ``params``, so the snippet must unfreeze it too.
    return {'strategy': strategy, 'params': params, 'code': code + _UNFREEZE_CLASSIFIER}


def _recommend_vgg(name, model):
    """VGG: unfreeze the last two ``nn.Conv2d`` layers of ``features`` plus ``classifier``.

    This is a subset of VGG's final conv block, not the whole block.
    """
    _print_recommendation_header(name)
    conv_indices = [i for i, layer in enumerate(model.features)
                    if isinstance(layer, nn.Conv2d)]
    last_conv_indices = conv_indices[-2:]
    params = (sum(_num_params(model.features[i]) for i in last_conv_indices)
              + _num_params(model.classifier))
    print("   Unfreeze: last 2 conv layers + classifier")
    print(f"   Parameters to train: {params:,}")
    return {
        'strategy': 'last 2 conv + classifier',
        'params': params,
        # Emit runnable code, consistent with the other architectures.
        'code': (
            f'for i in {last_conv_indices}:\n'
            '    for param in model.features[i].parameters():\n'
            '        param.requires_grad = True\n'
            + _UNFREEZE_CLASSIFIER
        ),
    }


def _recommend_inception(name, model):
    """InceptionV3: unfreeze the final inception blocks ``Mixed_7a/b/c`` plus ``fc``."""
    _print_recommendation_header(name)
    final_blocks = ['Mixed_7a', 'Mixed_7b', 'Mixed_7c']
    final_block_params = sum(_num_params(getattr(model, block))
                             for block in final_blocks if hasattr(model, block))
    params = final_block_params + _num_params(model.fc)
    print(f"   Unfreeze: {final_blocks} + fc")
    print(f"   Parameters to train: {params:,}")
    return {
        'strategy': 'Mixed_7* + fc',
        'params': params,
        # Prefix match: a plain `"fc" in name` test would also unfreeze the
        # auxiliary head `AuxLogits.fc` of torchvision's pretrained InceptionV3,
        # which is not counted in `params`.
        'code': (
            'for name, param in model.named_parameters():\n'
            '    if name.startswith(("Mixed_7a.", "Mixed_7b.", "Mixed_7c.", "fc.")):\n'
            '        param.requires_grad = True'
        ),
    }


# Checked in order; the first keyword contained in the lower-cased model name wins.
_RECOMMENDERS = (
    ('resnet', _recommend_resnet),
    ('efficientnet', _recommend_efficientnet),
    ('mobilenet', _recommend_mobilenet),
    ('vgg', _recommend_vgg),
    ('inception', _recommend_inception),
)


def _default_models():
    """Build the pretrained torchvision backbones compared in this notebook."""
    return {
        'ResNet18': models.resnet18(weights=ResNet18_Weights.DEFAULT),
        'ResNet50': models.resnet50(weights=ResNet50_Weights.DEFAULT),
        'EfficientNet-B0': models.efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT),
        'InceptionV3': models.inception_v3(weights=Inception_V3_Weights.DEFAULT),
        'VGG16': models.vgg16(weights=VGG16_Weights.DEFAULT),
        'VGG19': models.vgg19(weights=VGG19_Weights.DEFAULT),
        'MobileNetV2': models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT),
        'MobileNetV3-Large': models.mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT),
        'Xception': None,  # Not in torchvision, will handle separately
    }


def _print_summary(recommendations, models_to_check):
    """Print trainable-parameter counts and their share of each model's total."""
    print(f"\n{'='*80}")
    print("📊 FINE-TUNING PARAMETER COMPARISON")
    print(f"{'='*80}")
    for name, rec in recommendations.items():
        total = _num_params(models_to_check[name])
        # Guard against a parameter-free model instead of raising ZeroDivisionError.
        percentage = (rec['params'] / total) * 100 if total else 0.0
        print(f"{name:20}: {rec['params']:>10,} params ({percentage:5.1f}%)")


def check_last_stages(models_to_check=None):
    """Print partial fine-tuning recommendations per architecture, then a summary.

    For every model this prints the architecture overview via
    ``analyze_model_architecture``. For recognised families (ResNet,
    EfficientNet, MobileNet, VGG, Inception; matched by substring of the
    lower-cased display name) it then reports how many parameters become
    trainable when the final stage(s) plus the head are unfrozen. Entries whose
    value is ``None`` are reported as unavailable and skipped.

    Args:
        models_to_check: Optional mapping of display name -> ``nn.Module`` (or
            ``None``). Defaults to the pretrained torchvision models used in
            this notebook; building them may download weights.

    Returns:
        A dict mapping display name -> ``{'strategy', 'params', 'code'}``.
        It has an entry for each model that received a recommendation. The
        ``code`` snippets only set ``requires_grad = True``, so they assume
        every parameter was frozen beforehand.
    """
    if models_to_check is None:
        models_to_check = _default_models()

    recommendations = {}
    for name, model in models_to_check.items():
        if model is None:
            print(f"\n⚠️  {name}: Not available in torchvision, needs separate handling")
            continue

        analyze_model_architecture(model, name)

        lowered = name.lower()
        for keyword, recommend in _RECOMMENDERS:
            if keyword in lowered:
                rec = recommend(name, model)
                if rec is not None:
                    recommendations[name] = rec
                break

    _print_summary(recommendations, models_to_check)
    return recommendations

# Execute the full survey; the returned mapping stays available to later cells.
recommendations = check_last_stages()

print(
    "\n" + "=" * 80,
    "🎯 FINAL RECOMMENDATIONS FOR FAIR COMPARISON",
    "=" * 80,
    "To ensure fair comparison, aim for ~5-15% of total parameters to be trainable.",
    "Adjust the number of unfrozen layers to achieve similar parameter counts.",
    sep="\n",
)


🔍 ANALYZING RESNET18
Total parameters: 11,689,512

Main components:
  conv1: Conv2d (9,408 params)
  bn1: BatchNorm2d (128 params)
  relu: ReLU (0 params)
  maxpool: MaxPool2d (0 params)
  layer1: Sequential (147,968 params)
  layer2: Sequential (525,568 params)
  layer3: Sequential (2,099,712 params)
  layer4: Sequential (8,393,728 params)
  avgpool: AdaptiveAvgPool2d (0 params)
  fc: Linear (513,000 params)

🎯 FINE-TUNING RECOMMENDATION for ResNet18:
   Unfreeze: layer4 (entire final stage)
   Parameters to train: 8,906,728

🔍 ANALYZING RESNET50
Total parameters: 25,557,032

Main components:
  conv1: Conv2d (9,408 params)
  bn1: BatchNorm2d (128 params)
  relu: ReLU (0 params)
  maxpool: MaxPool2d (0 params)
  layer1: Sequential (215,808 params)
  layer2: Sequential (1,219,584 params)
  layer3: Sequential (7,098,368 params)
  layer4: Sequential (14,964,736 params)
  avgpool: AdaptiveAvgPool2d (0 params)
  fc: Linear (2,049,000 params)

🎯 FINE-TUNING RECOMMENDATION for ResNet50:
   U