In [2]:
import torch
import torch.nn as nn
import torchvision
import timm

# Choose where tensors and models will live: the CUDA GPU when PyTorch can see
# one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Record the runtime environment in the saved cell output.
print(f'Using device: {device}')
print(f'PyTorch version: {torch.__version__}')

Using device: cuda
PyTorch version: 2.10.0+cu126


## 1.3 Transfer Learning with Pre-trained Weights

### ❓ Tại Sao Training From Scratch Là Sai Lầm?

**Vấn đề:**
- Dataset nhỏ (112K images) so với ImageNet (14M images)
- Mất đi low-level features (edges, textures) đã học từ ImageNet
- Convergence chậm, dễ overfit
- Cần nhiều epochs hơn (~100 vs ~30)

**Evidence từ literature:**
- Rajpurkar et al. (CheXNet): Pre-trained weights → +5% AUC
- Irvin et al. (CheXpert): Transfer learning essential cho medical imaging

### 💡 Giải Pháp: Smart Transfer Learning

**Strategy:**
1. **Load ImageNet weights** → Low/mid-level features
2. **Replace classifier head** → Domain-specific classification
3. **Progressive unfreezing:**
   - Epochs 1-5: Freeze backbone, train head only
   - Epochs 6+: Unfreeze all, fine-tune end-to-end

**Why progressive unfreezing?**
- Prevents catastrophic forgetting of ImageNet features
- Stable training
- Better final performance

### 📈 Expected Impact
- **+2-4% AUC** improvement
- **50% faster** convergence
- Better feature representations

In [3]:
class PretrainedResNet(nn.Module):
    """
    ResNet-34 backbone (optionally ImageNet-initialized) with a new classifier head.

    Architecture:
    - Backbone: torchvision ResNet-34, loaded with ``ResNet34_Weights.IMAGENET1K_V1``
      when ``pretrained=True`` and randomly initialized otherwise. The backbone's
      own BatchNorm layers are kept unchanged.
    - Head: ``Dropout -> Linear(in_features, num_classes)``, replacing the original
      ``fc`` layer. No activation is applied, so ``forward`` returns raw logits.

    Features:
    - Dropout before the final linear layer for regularization
    - Progressive unfreezing via ``freeze_backbone`` / ``unfreeze_backbone``

    Args:
        num_classes: Number of output logits (default 15).
        pretrained: Load ImageNet weights if True, random init otherwise.
        dropout: Dropout probability applied before the final linear layer.
    """
    def __init__(self, num_classes=15, pretrained=True, dropout=0.5):
        super().__init__()

        if pretrained:
            weights = torchvision.models.ResNet34_Weights.IMAGENET1K_V1
            self.backbone = torchvision.models.resnet34(weights=weights)
            print("‚úÖ Loaded ImageNet pre-trained weights for ResNet-34")
        else:
            self.backbone = torchvision.models.resnet34(weights=None)
            print("‚ö†Ô∏è  Training ResNet-34 from scratch")

        # Read the feature width from the layer being replaced instead of hard-coding it.
        num_features = self.backbone.fc.in_features

        # Swap in a task-specific head; it stays at `backbone.fc` so torchvision's
        # forward pass still routes through it.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, num_classes),
        )

        self.num_classes = num_classes

    def forward(self, x):
        """Run the backbone plus head and return raw (un-activated) logits."""
        return self.backbone(x)

    def freeze_backbone(self):
        """Freeze every backbone parameter except those of the classifier head.

        The head is identified by module identity (``self.backbone.fc``) rather
        than by looking for ``'fc'`` in parameter names. Any other parameter whose
        name merely contains that substring is therefore frozen as well. Head
        parameters are explicitly set trainable.

        NOTE(review): only ``requires_grad`` is changed. BatchNorm running
        statistics in the backbone still update whenever the module is in train
        mode. Confirm this is intended for the head-only phase.
        """
        head_param_ids = {id(param) for param in self.backbone.fc.parameters()}
        for param in self.backbone.parameters():
            param.requires_grad = id(param) in head_param_ids
        print("üîí Backbone frozen, training head only")

    def unfreeze_backbone(self):
        """Make every parameter (backbone and head) trainable for end-to-end fine-tuning."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("üîì Backbone unfrozen, training end-to-end")


class PretrainedViT(nn.Module):
    """
    Vision Transformer built with timm, optionally loading ImageNet pre-trained weights.

    timm creates the backbone and a fresh ``num_classes``-way classifier in one call.

    Suggested model names:
    - vit_base_patch16_224: standard ViT-B/16
    - vit_base_patch32_224: ViT-B/32 (coarser patches, fewer tokens, faster)
    - vit_large_patch16_224: ViT-L/16 (largest of the three)

    Args:
        model_name: Any timm model name. Freezing works for any timm model
            through ``get_classifier()``.
        num_classes: Number of output logits.
        pretrained: Load pre-trained weights through timm if True.
        dropout: Passed to timm as ``drop_rate``.
    """
    def __init__(self, model_name='vit_base_patch16_224', num_classes=15,
                 pretrained=True, dropout=0.1):
        super().__init__()

        # NOTE(review): in recent timm releases ``drop_rate`` is the dropout applied
        # before the classifier head; block-level dropout uses separate arguments
        # (e.g. proj_drop_rate / attn_drop_rate). Confirm for the installed timm.
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout,
        )

        if pretrained:
            print(f"‚úÖ Loaded ImageNet pre-trained weights for {model_name}")
        else:
            print(f"‚ö†Ô∏è  Training {model_name} from scratch")

        self.model_name = model_name
        self.num_classes = num_classes

    def forward(self, x):
        """Return the wrapped timm model's output (classifier logits)."""
        return self.model(x)

    def _head_parameter_ids(self):
        """Return the ``id()`` of every parameter that belongs to the classifier head(s)."""
        get_classifier = getattr(self.model, 'get_classifier', None)
        if get_classifier is None:
            # Not a timm-style model: fall back to the "'head' in name" convention.
            return {id(p) for name, p in self.model.named_parameters() if 'head' in name}
        heads = get_classifier()
        # Some timm models (e.g. distilled DeiT variants) return a tuple of heads.
        if not isinstance(heads, (tuple, list)):
            heads = (heads,)
        return {id(p) for head in heads for p in head.parameters()}

    def freeze_backbone(self):
        """Freeze all parameters except the classifier head.

        The head is located via timm's ``get_classifier()`` rather than by name,
        so models whose classifier is not called ``head`` keep a trainable head.
        """
        head_param_ids = self._head_parameter_ids()
        for param in self.model.parameters():
            param.requires_grad = id(param) in head_param_ids
        print("üîí ViT backbone frozen, training head only")

    def unfreeze_backbone(self):
        """Make every parameter trainable for end-to-end fine-tuning."""
        for param in self.model.parameters():
            param.requires_grad = True
        print("üîì ViT backbone unfrozen, training end-to-end")


class PretrainedSwinTransformer(nn.Module):
    """
    Swin Transformer (hierarchical Vision Transformer) built with timm.

    Motivation (from the Swin paper and common practice; not verified here):
    1. Hierarchical feature maps, similar to a CNN
    2. Shifted local windows keep attention cost manageable
    3. Reported to work well on dense prediction tasks
    4. Often argued to suit medical imaging; confirm on this dataset

    Paper: "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"

    Args:
        model_name: Any timm model name (default Swin-B, patch 4, window 7, 224px).
        num_classes: Number of output logits.
        pretrained: Load pre-trained weights through timm if True.
        dropout: Passed to timm as ``drop_rate``.
    """
    def __init__(self, model_name='swin_base_patch4_window7_224',
                 num_classes=15, pretrained=True, dropout=0.1):
        super().__init__()

        # NOTE(review): in recent timm releases ``drop_rate`` is the dropout applied
        # in the classifier head; confirm for the installed timm version.
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout,
        )

        if pretrained:
            print(f"‚úÖ Loaded ImageNet pre-trained weights for {model_name}")
        else:
            print(f"‚ö†Ô∏è  Training {model_name} from scratch")

        self.model_name = model_name
        self.num_classes = num_classes

    def forward(self, x):
        """Return the wrapped timm model's output (classifier logits)."""
        return self.model(x)

    def _head_parameter_ids(self):
        """Return the ``id()`` of every parameter that belongs to the classifier head(s)."""
        get_classifier = getattr(self.model, 'get_classifier', None)
        if get_classifier is None:
            # Not a timm-style model: fall back to the "'head' in name" convention.
            return {id(p) for name, p in self.model.named_parameters() if 'head' in name}
        heads = get_classifier()
        # Some timm models return a tuple of heads; normalize to an iterable.
        if not isinstance(heads, (tuple, list)):
            heads = (heads,)
        return {id(p) for head in heads for p in head.parameters()}

    def freeze_backbone(self):
        """Freeze all parameters except the classifier head (located via ``get_classifier()``)."""
        head_param_ids = self._head_parameter_ids()
        for param in self.model.parameters():
            param.requires_grad = id(param) in head_param_ids
        print("üîí Swin backbone frozen, training head only")

    def unfreeze_backbone(self):
        """Make every parameter trainable for end-to-end fine-tuning."""
        for param in self.model.parameters():
            param.requires_grad = True
        print("üîì Swin backbone unfrozen, training end-to-end")


# Model factory
def create_model(model_type='resnet34', num_classes=15, pretrained=True, dropout=None):
    """
    Build one of the supported classifiers by short name.

    Args:
        model_type: One of 'resnet34', 'vit_base', 'vit_large', 'swin_base'.
        num_classes: Number of output classes.
        pretrained: Use ImageNet pre-trained weights.
        dropout: Optional dropout override. ``None`` (the default) keeps each
            model class's own default (0.5 for ResNet-34, 0.1 for the timm models).

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    supported = ('resnet34', 'vit_base', 'vit_large', 'swin_base')
    if model_type not in supported:
        raise ValueError(
            f"Unknown model type: {model_type}. Expected one of: {', '.join(supported)}"
        )

    # Only forward `dropout` when the caller set it, so each class keeps its own default.
    extra = {} if dropout is None else {'dropout': dropout}

    if model_type == 'resnet34':
        return PretrainedResNet(num_classes, pretrained, **extra)
    if model_type == 'vit_base':
        return PretrainedViT('vit_base_patch16_224', num_classes, pretrained, **extra)
    if model_type == 'vit_large':
        return PretrainedViT('vit_large_patch16_224', num_classes, pretrained, **extra)
    return PretrainedSwinTransformer('swin_base_patch4_window7_224',
                                     num_classes, pretrained, **extra)


# Summarize what this cell defined; one print call, one line per entry.
print(
    "‚úÖ Pre-trained models implemented:",
    "   1. PretrainedResNet (ResNet-34)",
    "   2. PretrainedViT (ViT-Base/16, ViT-Large/16)",
    "   3. PretrainedSwinTransformer (Swin-Base)",
    "\nüéØ Features:",
    "   - ImageNet pre-trained weights",
    "   - Progressive unfreezing support",
    "   - Dropout regularization",
    "   - Easy model creation via factory function",
    sep="\n",
)

‚úÖ Pre-trained models implemented:
   1. PretrainedResNet (ResNet-34)
   2. PretrainedViT (ViT-Base/16, ViT-Large/16)
   3. PretrainedSwinTransformer (Swin-Base)

üéØ Features:
   - ImageNet pre-trained weights
   - Progressive unfreezing support
   - Dropout regularization
   - Easy model creation via factory function


### 🧪 Test Model Creation

Verify that the models can be created and loaded correctly, and that freezing/unfreezing changes the trainable parameters as expected.

In [4]:
def _count_params(module, trainable_only=False):
    """Return the number of scalar parameters in `module` (only trainable ones if requested)."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad or not trainable_only)


def _report_params(module):
    """Print total and trainable parameter counts in this cell's indented format."""
    print(f"   Parameters: {_count_params(module):,}")
    print(f"   Trainable: {_count_params(module, trainable_only=True):,}\n")


print("üß™ Testing model creation...\n")
# Names of models that failed to build; used for the final status line.
failed_models = []

# ResNet-34 is not wrapped in try/except: the freeze/unfreeze checks below need it.
print("1Ô∏è‚É£ Creating ResNet-34...")
resnet = create_model('resnet34', num_classes=15, pretrained=True)
_report_params(resnet)

# timm models may fail (e.g. weight download); record the failure instead of hiding it.
print("2Ô∏è‚É£ Creating ViT-Base/16...")
try:
    vit = create_model('vit_base', num_classes=15, pretrained=True)
    _report_params(vit)
except Exception as e:
    failed_models.append('vit_base')
    print(f"   ‚ö†Ô∏è Error loading ViT: {e}\n")

print("3Ô∏è‚É£ Creating Swin Transformer...")
try:
    swin = create_model('swin_base', num_classes=15, pretrained=True)
    _report_params(swin)
except Exception as e:
    failed_models.append('swin_base')
    print(f"   ‚ö†Ô∏è Error loading Swin: {e}\n")

# Freezing should leave exactly the classifier head trainable; unfreezing restores everything.
print("4Ô∏è‚É£ Testing freeze/unfreeze...")
head_params = sum(p.numel() for p in resnet.backbone.fc.parameters())
total_params = _count_params(resnet)

resnet.freeze_backbone()
frozen_params = _count_params(resnet, trainable_only=True)
print(f"   Frozen trainable params: {frozen_params:,}")
assert frozen_params == head_params, "freeze_backbone() should leave only the classifier head trainable"

resnet.unfreeze_backbone()
unfrozen_params = _count_params(resnet, trainable_only=True)
print(f"   Unfrozen trainable params: {unfrozen_params:,}")
assert unfrozen_params == total_params, "unfreeze_backbone() should make every parameter trainable"

# Only claim full success when every model actually loaded.
if failed_models:
    print(f"\n   Some models failed to load: {', '.join(failed_models)}")
else:
    print("\n‚úÖ All models created successfully!")

üß™ Testing model creation...

1Ô∏è‚É£ Creating ResNet-34...
‚úÖ Loaded ImageNet pre-trained weights for ResNet-34
   Parameters: 21,292,367
   Trainable: 21,292,367

2Ô∏è‚É£ Creating ViT-Base/16...




‚úÖ Loaded ImageNet pre-trained weights for vit_base_patch16_224
   Parameters: 85,810,191
   Trainable: 85,810,191

3Ô∏è‚É£ Creating Swin Transformer...
‚úÖ Loaded ImageNet pre-trained weights for swin_base_patch4_window7_224
   Parameters: 86,758,599
   Trainable: 86,758,599

4Ô∏è‚É£ Testing freeze/unfreeze...
üîí Backbone frozen, training head only
   Frozen trainable params: 7,695
üîì Backbone unfrozen, training end-to-end
   Unfrozen trainable params: 21,292,367

‚úÖ All models created successfully!
