# 🚀 Comprehensive Model Improvements for Chest X-ray Classification

## 📋 Overview

This notebook implements the **full improvement plan** designed to raise the performance of disease classification models on chest X-ray images. We apply state-of-the-art techniques chosen from an in-depth analysis of the original models' limitations.

### 🎯 Improvement Targets

| Model | Original AUC | Target AUC | Expected Gain |
|-------|--------------|------------|---------------|
| ResNet-34 | 0.86 | 0.88-0.89 | +2-3% |
| ViT-Base | 0.86 | 0.88-0.90 | +2-4% |
| Swin Transformer | - | 0.89-0.91 | New |
| Ensemble | - | 0.90-0.92 | Best |

### 📊 Problems to Address

1. **Training from Scratch**: The original models do not use pre-trained weights → the knowledge learned from ImageNet is lost
2. **Class Imbalance**: The dataset is extremely imbalanced (No Finding: 53.84% vs. Hernia: 0.20%) → bias toward the common classes
3. **Weak Augmentation**: Only flip and rotate are used → poor generalization
4. **Fixed Architecture**: The architecture is not optimized → modern techniques are missed
5. **Label Noise**: The NIH dataset has ~10% label errors from NLP extraction → the model learns wrong patterns
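
The imbalance in item 2 can be quantified with a small calculation. A common remedy, though not necessarily the one this notebook ends up using, is to weight the positive term of a per-class binary cross-entropy by the negative-to-positive ratio. The sketch below computes that ratio from the two prevalences quoted above; the helper name `pos_weight_from_prevalence` is hypothetical.

```python
def pos_weight_from_prevalence(prevalence: float) -> float:
    """Return the negative-to-positive sample ratio for one class.

    `prevalence` is the fraction of images labeled positive for the class.
    """
    if not 0.0 < prevalence < 1.0:
        raise ValueError("prevalence must be strictly between 0 and 1")
    return (1.0 - prevalence) / prevalence


# Prevalences quoted in the problem list above.
prevalences = {"No Finding": 0.5384, "Hernia": 0.0020}

for name, p in prevalences.items():
    print(f"{name}: pos_weight = {pos_weight_from_prevalence(p):.2f}")
```

In PyTorch, a tensor of such ratios can be passed as the `pos_weight` argument of `torch.nn.BCEWithLogitsLoss`. A ratio near 499 for Hernia is very large, so capping or otherwise softening the weights is a reasonable precaution.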

## 🔧 Setup & Dependencies

### Explanation
We need the following libraries:
- **timm**: SOTA vision models with pre-trained weights
- **albumentations**: Advanced augmentation for medical imaging
- **sklearn**: Metrics and utilities
- **torch**: Deep learning framework

In [1]:
# Standard libraries
import os
import sys
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm
from typing import Dict, List, Tuple, Optional

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision
from torchvision import transforms

# Computer Vision
from PIL import Image
import cv2

# Advanced libraries
try:
    import timm  # PyTorch Image Models
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    print("✅ All advanced libraries loaded successfully")
except ImportError as e:
    print(f"❌ Missing libraries: {e}")
    print("🔧 Please install manually using:")
    print("   pip install timm albumentations")
    print("   Or activate your virtual environment and install there")
    print("\n⚠️  Skipping advanced features for now...")
    
    # Create dummy classes to avoid import errors
    class DummyAlbumentations:
        class Compose:
            def __init__(self, transforms):
                self.transforms = transforms
            def __call__(self, **kwargs):
                return kwargs
        
        class Resize:
            def __init__(self, *args, **kwargs):
                pass
        class RandomCrop:
            def __init__(self, *args, **kwargs):
                pass
        class HorizontalFlip:
            def __init__(self, *args, **kwargs):
                pass
        class ShiftScaleRotate:
            def __init__(self, *args, **kwargs):
                pass
        class OneOf:
            def __init__(self, transforms, *args, **kwargs):
                self.transforms = transforms
        class GaussNoise:
            def __init__(self, *args, **kwargs):
                pass
        class GaussianBlur:
            def __init__(self, *args, **kwargs):
                pass
        class MotionBlur:
            def __init__(self, *args, **kwargs):
                pass
        class RandomBrightnessContrast:
            def __init__(self, *args, **kwargs):
                pass
        class CLAHE:
            def __init__(self, *args, **kwargs):
                pass
        class GridDistortion:
            def __init__(self, *args, **kwargs):
                pass
        class Normalize:
            def __init__(self, *args, **kwargs):
                pass
    
    A = DummyAlbumentations()
    
    class DummyToTensorV2:
        def __init__(self):
            pass
        def __call__(self, image):
            return image
    
    ToTensorV2 = DummyToTensorV2
    
    # Dummy timm
    class DummyTimm:
        def create_model(self, *args, **kwargs):
            raise NotImplementedError("timm not available - install with: pip install timm")
    
    timm = DummyTimm()
    
    print("✅ Using dummy implementations - limited functionality")

# Scikit-learn
from sklearn.metrics import roc_auc_score, roc_curve, auc, confusion_matrix
from sklearn.model_selection import StratifiedKFold

# Suppress warnings
warnings.filterwarnings('ignore')

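# Set random seeds for reproducibility
import random  # Python's built-in RNG, seeded below alongside NumPy and PyTorch

def set_seed(seed=42):
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned (faster) ones
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)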

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA Version: {torch.version.cuda}")

✅ All advanced libraries loaded successfully
🖥️  Device: cuda
   GPU: NVIDIA GeForce RTX 3060 Laptop GPU
   CUDA Version: 12.6
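
The import cell above falls back to dummy classes when `timm` or `albumentations` is missing. As a lighter-weight alternative, and only as a sketch, availability can be checked up front with the standard library's `importlib.util.find_spec`, without importing the packages. The helper name `missing_packages` is hypothetical.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the top-level package names that cannot be found for import."""
    return [name for name in names if find_spec(name) is None]


# Optional packages used by the import cell above.
missing = missing_packages(["timm", "albumentations"])
if missing:
    print("Missing optional packages:", ", ".join(missing))
    print("Install with: pip install " + " ".join(missing))
else:
    print("All optional packages are available")
```

Failing early with a clear message avoids the situation where a dummy class silently turns an augmentation pipeline into a no-op.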


## 📁 Configuration & Paths

### Explanation
A centralized configuration makes it easy to adjust hyperparameters and paths.

In [2]:
# Project paths
PROJECT_ROOT = Path("D:/MSE/10.Deep Learning/Group_Final/ViT-Chest-Xray")
DATA_DIR = PROJECT_ROOT / "Project" / "data"
CSV_PATH = PROJECT_ROOT / "Project" / "input" / "Data_Entry_2017_v2020.csv"
IMAGE_DIR = PROJECT_ROOT / "Project" / "input" / "images"
SAVE_DIR = PROJECT_ROOT / "Project" / "improve" / "results"
SAVE_DIR.mkdir(exist_ok=True, parents=True)

# Training configuration
CONFIG = {
    # Data
    'img_size': 224,
    'num_classes': 15,
    'batch_size': 32,
    'num_workers': 4,
    
    # Training
    'epochs': 30,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'warmup_epochs': 3,
    
    # Augmentation
    'use_advanced_aug': True,
    'aug_probability': 0.5,
    
    # Class imbalance
    'use_weighted_loss': True,
    'use_focal_loss': True,
    'focal_alpha': 0.25,
    'focal_gamma': 2.0,
    
    # Transfer learning
    'use_pretrained': True,
    'freeze_backbone_epochs': 5,  # Freeze backbone for first N epochs
    
    # Advanced techniques
    'use_label_smoothing': True,
    'label_smoothing': 0.1,
    'use_mixup': True,
    'mixup_alpha': 0.2,
    
    # Model saving
    'save_best_only': True,
    'early_stopping_patience': 10,
}

# Disease classes
DISEASE_CLASSES = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
    'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
    'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia', 'No Finding'
]

print("✅ Configuration loaded successfully")
print(f"📊 Number of classes: {CONFIG['num_classes']}")
print(f"🖼️  Image size: {CONFIG['img_size']}x{CONFIG['img_size']}")
print(f"📦 Batch size: {CONFIG['batch_size']}")

✅ Configuration loaded successfully
📊 Number of classes: 15
🖼️  Image size: 224x224
📦 Batch size: 32
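
The `focal_alpha` and `focal_gamma` entries in `CONFIG` parameterize a focal loss. The notebook's own loss implementation is not shown in this section, so the following is a minimal sketch that assumes the standard binary formulation, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), applied to a single predicted probability. The function name `binary_focal_loss` is hypothetical, and the defaults mirror the config values.

```python
import math


def binary_focal_loss(prob: float, target: int,
                      alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one predicted probability and a 0/1 target."""
    # p_t is the probability the model assigns to the true label.
    p_t = prob if target == 1 else 1.0 - prob
    # alpha weights positives; (1 - alpha) weights negatives.
    alpha_t = alpha if target == 1 else 1.0 - alpha
    # Clamp to avoid log(0).
    p_t = min(max(p_t, 1e-7), 1.0 - 1e-7)
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# A confident correct prediction versus a confident wrong one (both positives).
easy = binary_focal_loss(0.95, target=1)
hard = binary_focal_loss(0.10, target=1)
print(f"easy example loss: {easy:.6f}")
print(f"hard example loss: {hard:.6f}")
```

The `(1 - p_t) ** gamma` factor shrinks the loss of well-classified examples, so the hard example dominates the gradient. With `gamma=0` and `alpha=1` the expression reduces to plain cross-entropy on positives.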
