
<br>
Medical Image Classification Model Training Script<br>
Training Script for Medical Image Classification<br>


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import tqdm
# CHANGED: Importing from the new densenet_model file
from densenet_model import create_model
import numpy as np
from sklearn.metrics import (accuracy_score, classification_report, confusion_matrix,
                             roc_auc_score, roc_curve, f1_score, precision_score, 
                             recall_score, average_precision_score)
import warnings
warnings.filterwarnings('ignore')

In [None]:
def _inverse_frequency_weights(samples):
    """Return inverse-frequency class weights ordered by class index.

    Args:
        samples: Iterable of ``(path, class_idx)`` pairs, as stored in
            ``ImageFolder.samples``.

    Returns:
        A list with one weight per class present in ``samples``, computed as
        ``total_samples / (num_classes * class_count)`` and ordered by
        ascending class index so that ``weights[class_idx]`` is valid when the
        class indices are the contiguous range ``0..num_classes-1``.
    """
    from collections import Counter  # local import keeps this block self-contained

    counts = Counter(class_idx for _, class_idx in samples)
    total_samples = sum(counts.values())
    num_classes = len(counts)
    # Iterate in sorted index order instead of relying on dict insertion order.
    return [total_samples / (num_classes * counts[idx]) for idx in sorted(counts)]


def get_data_loaders(data_dir, batch_size=32, img_size=224, use_weighted_sampler=False, num_workers=4):
    """
    Build the training and validation DataLoaders, with optional class re-balancing.

    Args:
        data_dir: Path to data directory (must contain 'train' and 'val' folders,
            each laid out as one sub-folder per class).
        batch_size: Batch size used by both loaders.
        img_size: Side length (pixels) of the square images fed to the model.
        use_weighted_sampler: If True, draw training samples with a
            WeightedRandomSampler (with replacement) using inverse-frequency weights.
        num_workers: Number of worker processes used by each DataLoader.

    Returns:
        Tuple ``(train_loader, val_loader, class_names, class_weights)`` where
        ``class_names`` is ``train_dataset.classes`` and ``class_weights`` is the
        list of inverse-frequency weights (always computed, so the caller can
        also use it for a weighted loss).
    """

    # Training image transformations (with enhanced data augmentation)
    train_transform = transforms.Compose([
        transforms.Resize((int(img_size * 1.1), int(img_size * 1.1))),  # Resize slightly larger
        transforms.RandomCrop(img_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.1)  # Operates on tensors, so it comes after ToTensor
    ])

    # Validation image transformations (no augmentation)
    val_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = datasets.ImageFolder(
        root=os.path.join(data_dir, 'train'),
        transform=train_transform
    )
    val_dataset = datasets.ImageFolder(
        root=os.path.join(data_dir, 'val'),
        transform=val_transform
    )

    # Computed once; used by the sampler (if enabled) and returned for a weighted loss.
    class_weights = _inverse_frequency_weights(train_dataset.samples)

    train_sampler = None
    if use_weighted_sampler:
        # One weight per training sample, looked up from its class weight.
        sample_weights = [class_weights[class_idx] for _, class_idx in train_dataset.samples]
        train_sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        )
        print("✓ WeightedRandomSampler enabled")
        print(f"  Weights: {dict(zip(train_dataset.classes, class_weights))}")

    # Pinned host memory only helps host-to-GPU copies, so request it only with CUDA.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),  # shuffle and sampler are mutually exclusive
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, val_loader, train_dataset.classes, class_weights

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the mean of the per-batch losses and
        the accuracy of the arg-max predictions over all seen samples.
    """
    model.train()
    loss_total = 0.0
    predicted, targets = [], []

    batches = tqdm(train_loader, desc='Training')
    for images, labels in batches:
        images, labels = images.to(device), labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch-level statistics.
        loss_value = batch_loss.item()
        loss_total += loss_value
        predicted.extend(torch.max(logits, 1).indices.cpu().numpy())
        targets.extend(labels.cpu().numpy())

        batches.set_postfix({'loss': loss_value})

    return loss_total / len(train_loader), accuracy_score(targets, predicted)

In [None]:
def _safe_auc_roc(labels, probs):
    """Return the AUC-ROC for the given labels/probabilities, or 0.0 if undefined.

    The binary vs. multi-class decision is taken from the number of probability
    columns (the model's output width), not from which labels happen to occur.
    """
    if probs.ndim != 2 or probs.shape[1] < 2:
        return 0.0
    try:
        if probs.shape[1] == 2:
            # Binary: score with the probability of the positive class (index 1).
            score = roc_auc_score(labels, probs[:, 1])
        else:
            score = roc_auc_score(labels, probs, multi_class='ovr', average='weighted')
    except ValueError:
        # Raised e.g. when a class is missing from the labels; keep the 0.0 fallback.
        return 0.0
    # Guard against a NaN result so history plots and comparisons stay well-defined.
    return 0.0 if np.isnan(score) else float(score)


def validate(model, val_loader, criterion, device):
    """Evaluate the model on ``val_loader`` and compute all validation metrics.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc, all_preds, all_labels, all_probs, metrics)``:
        mean per-batch loss, accuracy, the list of predicted class indices, the
        list of true class indices, the softmax probabilities as a NumPy array,
        and a dict with keys ``'auc_roc'``, ``'f1_score'``, ``'precision'``,
        ``'recall'`` (the last three are weighted averages).
    """
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc='Validation')
        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)

            # Class probabilities are needed for AUC-ROC and the ROC plot.
            probs = torch.softmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

            progress_bar.set_postfix({'loss': loss.item()})

    epoch_loss = running_loss / len(val_loader)
    epoch_acc = accuracy_score(all_labels, all_preds)

    all_probs = np.array(all_probs)
    auc_roc = _safe_auc_roc(all_labels, all_probs)

    # zero_division=0 makes the "class never predicted" case explicit (score 0).
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)

    return epoch_loss, epoch_acc, all_preds, all_labels, all_probs, {
        'auc_roc': auc_roc,
        'f1_score': f1,
        'precision': precision,
        'recall': recall
    }

In [None]:
def _draw_history_panel(ax, epochs, curves, ylabel, title):
    """Draw one metric panel: every curve in ``curves`` plus labels, legend and grid.

    ``curves`` is a list of ``(values, legend_label, plot_kwargs)`` tuples.
    """
    for values, label, style in curves:
        ax.plot(epochs, values, label=label, **style)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)


def plot_training_history(history, save_path='training_history.png'):
    """Plot loss, accuracy, AUC-ROC and F1 per epoch on a 2x2 grid and save it.

    Args:
        history: Dict of per-epoch lists. ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'`` are required; ``'val_auc_roc'``
            and ``'val_f1_score'`` are optional.
        save_path: File the figure is written to.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    epochs = range(1, len(history['train_loss']) + 1)

    _draw_history_panel(
        axes[0, 0], epochs,
        [(history['train_loss'], 'Training Loss', {'marker': 'o'}),
         (history['val_loss'], 'Validation Loss', {'marker': 's'})],
        ylabel='Loss', title='Training and Validation Loss',
    )
    _draw_history_panel(
        axes[0, 1], epochs,
        [(history['train_acc'], 'Training Accuracy', {'marker': 'o'}),
         (history['val_acc'], 'Validation Accuracy', {'marker': 's'})],
        ylabel='Accuracy', title='Training and Validation Accuracy',
    )

    # Optional panels: hide the axes when the metric was not recorded, so a
    # missing metric never leaves an empty frame in the saved figure.
    optional_panels = [
        (axes[1, 0], 'val_auc_roc', 'Validation AUC-ROC', 'AUC-ROC', 'green'),
        (axes[1, 1], 'val_f1_score', 'Validation F1-Score', 'F1-Score', 'purple'),
    ]
    for ax, key, label, ylabel, color in optional_panels:
        if key in history:
            _draw_history_panel(
                ax, epochs,
                [(history[key], label, {'marker': 's', 'color': color})],
                ylabel=ylabel, title=label,
            )
        else:
            ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Training history plot saved: {save_path}")

In [None]:
def plot_confusion_matrix(y_true, y_pred, class_names, save_path='confusion_matrix.png'):
    """Plot the confusion matrix as an annotated heatmap and save it.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        class_names: Display names used for the tick labels. Labels are assumed
            to be integer class indices into this list (as produced by
            ``ImageFolder``).
        save_path: File the figure is written to.
    """
    # Pin the matrix to one row/column per class name; otherwise a class that is
    # absent from y_true and y_pred shrinks the matrix and the tick labels shift.
    label_ids = list(range(len(class_names)))
    try:
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    except ValueError:
        # Labels are not class indices (e.g. strings): fall back to the labels
        # scikit-learn infers from the data, as before.
        cm = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Number of Samples'})
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Confusion matrix saved: {save_path}")

In [None]:
def plot_roc_curve(y_true, y_probs, class_names, save_path='roc_curve.png'):
    """Plot ROC curves (one-vs-rest for multi-class) and save the figure.

    Args:
        y_true: True labels as a list or array of integer class indices.
        y_probs: Array-like of shape ``(n_samples, n_classes)`` with class
            probabilities; column ``i`` is taken as the score of class ``i``.
        class_names: Class display names; its length decides between the binary
            layout (single curve for class index 1) and the multi-class layout.
        save_path: File the figure is written to.
    """
    # Callers pass plain Python lists; without this conversion `y_true == i`
    # is a single bool and the one-vs-rest binarisation below fails.
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    n_classes = len(class_names)

    plt.figure(figsize=(10, 8))

    # Decide which class indices get a curve.
    if y_probs.ndim != 2:
        class_ids = []
    elif n_classes == 2:
        # Binary classification: a single curve for the positive class.
        class_ids = [1] if y_probs.shape[1] == 2 else []
    else:
        class_ids = range(min(n_classes, y_probs.shape[1]))

    for i in class_ids:
        y_true_binary = (y_true == i).astype(int)
        n_positive = int(y_true_binary.sum())
        # A ROC curve needs both positives and negatives; skip degenerate classes
        # explicitly instead of swallowing the resulting exception.
        if n_positive == 0 or n_positive == y_true_binary.size:
            continue
        fpr, tpr, _ = roc_curve(y_true_binary, y_probs[:, i])
        auc_score = roc_auc_score(y_true_binary, y_probs[:, i])
        plt.plot(fpr, tpr, label=f'{class_names[i]} (AUC = {auc_score:.3f})', linewidth=2)

    plt.plot([0, 1], [0, 1], 'k--', label='Random Classifier', linewidth=1)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate (1 - Specificity)')
    plt.ylabel('True Positive Rate (Sensitivity)')
    plt.title('ROC Curve')
    plt.legend(loc='lower right')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"ROC curve saved: {save_path}")

In [None]:
def main():
    """Run the full training pipeline with the settings defined below.

    Steps: load the train/val image folders, build the model through
    ``create_model``, train for ``NUM_EPOCHS`` epochs (optionally unfreezing the
    backbone part-way), checkpoint whenever validation AUC-ROC or accuracy
    improves, then print a classification report and save the history,
    confusion-matrix and ROC plots into ``SAVE_DIR``.
    """
    # ==================== Settings ====================
    DATA_DIR = 'data'  # Path to data folder

    BATCH_SIZE = 32
    NUM_EPOCHS = 30
    LEARNING_RATE = 0.001
    IMG_SIZE = 224
    NUM_WORKERS = 4
    NUM_CLASSES = None  # Automatically detected from data
    SAVE_DIR = 'checkpoints'

    # Transfer Learning Settings
    USE_TRANSFER_LEARNING = True
    MODEL_NAME = 'densenet121'  # 'resnet18', 'resnet50', 'vgg16', 'densenet121', 'inception_v3'
    FREEZE_BACKBONE = True  # Freeze backbone initially
    AUTO_UNFREEZE = True
    UNFREEZE_EPOCH = 5  # Unfreeze backbone at the start of this (1-based) epoch
    PRETRAINED = True

    # Imbalanced Data Settings
    # NOTE(review): with both enabled the imbalance is compensated twice (in the
    # sampler and in the loss) - confirm this is intended.
    USE_WEIGHTED_SAMPLER = True
    USE_WEIGHTED_LOSS = True

    # Learning Rate Scheduler Settings
    USE_SCHEDULER = True
    SCHEDULER_TYPE = 'cosine'  # 'plateau', 'cosine', 'step'

    # The checkpoint name follows the model actually trained; defined once so the
    # save path and the final message cannot drift apart.
    model_tag = MODEL_NAME if USE_TRANSFER_LEARNING else 'custom'
    best_model_path = os.path.join(SAVE_DIR, f'best_{model_tag}_model.pth')

    # ==================== Initialize Directories ====================
    os.makedirs(SAVE_DIR, exist_ok=True)

    # GPU Check
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("="*70)
    print(f"Device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print("="*70)

    # Clamp so the backbone is still unfrozen at the last epoch at the latest.
    if AUTO_UNFREEZE and UNFREEZE_EPOCH > NUM_EPOCHS:
        UNFREEZE_EPOCH = NUM_EPOCHS

    # ==================== Load Data ====================
    print("\nLoading data...")
    train_loader, val_loader, class_names, class_weights = get_data_loaders(
        DATA_DIR,
        batch_size=BATCH_SIZE,
        img_size=IMG_SIZE,
        use_weighted_sampler=USE_WEIGHTED_SAMPLER,
        num_workers=NUM_WORKERS
    )
    print("\n✓ Data loaded successfully!")
    print(f"  Classes: {class_names}")
    print(f"  Training images: {len(train_loader.dataset)}")
    print(f"  Validation images: {len(val_loader.dataset)}")

    detected_num_classes = len(class_names)
    if NUM_CLASSES is None:
        NUM_CLASSES = detected_num_classes
    elif NUM_CLASSES != detected_num_classes:
        print(f"\n⚠️  NUM_CLASSES ({NUM_CLASSES}) does not match detected classes "
              f"({detected_num_classes}). Using data-detected value.")
        NUM_CLASSES = detected_num_classes
    print(f"  Total classes: {NUM_CLASSES}")

    # ==================== Create Model ====================
    print("\nCreating model...")
    if USE_TRANSFER_LEARNING:
        print(f"  Model Type: Transfer Learning ({MODEL_NAME})")
        model = create_model(
            num_classes=NUM_CLASSES,
            input_channels=3,
            pretrained=PRETRAINED,
            model_type='transfer',
            model_name=MODEL_NAME,
            freeze_backbone=FREEZE_BACKBONE
        )
    else:
        print("  Model Type: Custom CNN (From Scratch)")
        model = create_model(
            num_classes=NUM_CLASSES,
            input_channels=3,
            pretrained=False,
            model_type='custom'
        )

    model = model.to(device)
    # Freezing is only requested from create_model in the transfer-learning branch.
    backbone_frozen = USE_TRANSFER_LEARNING and FREEZE_BACKBONE
    backbone_unfrozen = not backbone_frozen
    if backbone_frozen:
        print("  Backbone layers initially frozen.")

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")

    # ==================== Loss Function ====================
    if USE_WEIGHTED_LOSS and class_weights:
        weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
        criterion = nn.CrossEntropyLoss(weight=weights)
        print("\n✓ Weighted Loss enabled")
        print(f"  Weights: {dict(zip(class_names, class_weights))}")
    else:
        criterion = nn.CrossEntropyLoss()

    # ==================== Optimizer ====================
    # All parameters are registered up front (frozen ones included) so that the
    # backbone starts training as soon as requires_grad is switched on later.
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

    # ==================== Learning Rate Scheduler ====================
    scheduler = None
    if USE_SCHEDULER:
        if SCHEDULER_TYPE == 'plateau':
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', factor=0.5, patience=5
            )
        elif SCHEDULER_TYPE == 'cosine':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=NUM_EPOCHS, eta_min=1e-6
            )
        elif SCHEDULER_TYPE == 'step':
            scheduler = optim.lr_scheduler.StepLR(
                optimizer, step_size=15, gamma=0.5
            )

        # Only claim the scheduler is enabled when one was actually built.
        if scheduler is not None:
            print(f"\n✓ Learning Rate Scheduler enabled ({SCHEDULER_TYPE})")
        else:
            print(f"\n⚠️  Unknown SCHEDULER_TYPE '{SCHEDULER_TYPE}'; "
                  f"training without a scheduler.")

    # ==================== Training History ====================
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'val_auc_roc': [],
        'val_f1_score': [],
        'val_precision': [],
        'val_recall': []
    }

    best_val_acc = 0.0
    best_val_auc = 0.0

    # ==================== Training Loop ====================
    print("\n" + "="*70)
    print("Starting Training...")
    print("="*70)

    for epoch in range(NUM_EPOCHS):
        print(f"\n{'='*70}")
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        print(f"{'='*70}")

        if AUTO_UNFREEZE and not backbone_unfrozen and (epoch + 1) >= UNFREEZE_EPOCH:
            print(f"--> Epoch {epoch+1}: Unfreezing backbone layers.")
            for param in model.parameters():
                param.requires_grad = True
            backbone_unfrozen = True

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Validate
        val_loss, val_acc, val_preds, val_labels, val_probs, metrics = validate(
            model, val_loader, criterion, device
        )

        # Update LR (ReduceLROnPlateau needs the monitored value)
        if scheduler is not None:
            if SCHEDULER_TYPE == 'plateau':
                scheduler.step(val_loss)
            else:
                scheduler.step()

        # Save History
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_auc_roc'].append(metrics['auc_roc'])
        history['val_f1_score'].append(metrics['f1_score'])
        history['val_precision'].append(metrics['precision'])
        history['val_recall'].append(metrics['recall'])

        # Print Results
        current_lr = optimizer.param_groups[0]['lr']
        print(f"\nTrain Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print(f"Val AUC-ROC: {metrics['auc_roc']:.4f}, Val F1-Score: {metrics['f1_score']:.4f}")
        print(f"Val Precision: {metrics['precision']:.4f}, Val Recall: {metrics['recall']:.4f}")
        print(f"Learning Rate: {current_lr:.6f}")

        # Save Best Model: a new best in either AUC-ROC or accuracy overwrites
        # the single checkpoint file.
        save_model = False
        if metrics['auc_roc'] > best_val_auc:
            best_val_auc = metrics['auc_roc']
            save_model = True
            print(f"✓ AUC-ROC improved! (AUC: {best_val_auc:.4f})")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_model = True
            print(f"✓ Accuracy improved! (Acc: {best_val_acc:.4f})")

        if save_model:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_auc_roc': metrics['auc_roc'],
                'class_names': class_names,
                'model_type': 'transfer' if USE_TRANSFER_LEARNING else 'custom',
                'model_name': model_tag
            }, best_model_path)
            print(f"  Model saved to {best_model_path}")

    # ==================== Final Evaluation ====================
    # Everything below uses the predictions of the LAST epoch, which is not
    # necessarily the epoch stored in the best checkpoint.
    print("\n" + "="*70)
    print("Final Evaluation")
    print("="*70)

    # Classification Report
    print("\nClassification Report:")
    print(classification_report(val_labels, val_preds, target_names=class_names))

    # Plotting
    print("\nPlotting graphs...")
    plot_training_history(history, save_path=os.path.join(SAVE_DIR, 'training_history.png'))
    plot_confusion_matrix(val_labels, val_preds, class_names,
                          save_path=os.path.join(SAVE_DIR, 'confusion_matrix.png'))
    plot_roc_curve(val_labels, val_probs, class_names,
                   save_path=os.path.join(SAVE_DIR, 'roc_curve.png'))

    # Summary
    print("\n" + "="*70)
    print("Final Result Summary:")
    print("="*70)
    print(f"Best Accuracy: {best_val_acc:.4f}")
    print(f"Best AUC-ROC: {best_val_auc:.4f}")
    print(f"Final F1-Score: {metrics['f1_score']:.4f}")
    print(f"Final Precision: {metrics['precision']:.4f}")
    print(f"Final Recall: {metrics['recall']:.4f}")
    print("="*70)
    print("\n✓ Training completed successfully!")
    print(f"  Best model saved: {best_model_path}")

In [None]:
# Entry point: run the training pipeline only when this code is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()

![WhatsApp Image 2025-12-30 at 1.39.29 PM (1).jpeg](<attachment:WhatsApp Image 2025-12-30 at 1.39.29 PM (1).jpeg>)

### **Eğitim Sonuçlarının Analizi**

Eğitim sürecine ait grafikleri incelediğimizde şu sonuçları görüyoruz:

* **Loss (Kayıp) Grafiği:** Hem Training hem de Validation loss değerleri düzenli bir şekilde düşüyor. En önemlisi, iki çizgi birbirine yakın seyrediyor; bu da modelin ezberlemediğini (overfitting yapmadığını) ve veriyi gerçekten öğrendiğini gösteriyor.
* **Accuracy (Doğruluk) Grafiği:** Modelimizin başarısı her epoch'ta artarak devam etmiş ve sonuçta **%95** civarında yüksek bir doğruluk oranına ulaşmıştır. Grafikteki dalgalanmaların azalması, eğitimin istikrarlı bir şekilde tamamlandığına işaret ediyor.

![alt text](<WhatsApp Image 2025-12-30 at 1.39.29 PM-1.jpeg>)

### **Confusion Matrix (Hata Matrisi) Değerlendirmesi**

Modelin hangi sınıflarda ne kadar başarılı olduğunu daha detaylı görmek için Confusion Matrix'i inceledik:

* **Genel Başarı:** Matrisin köşegeni (diagonal) üzerindeki koyu renkli kutular, doğru tahmin sayılarımızın yüksek olduğunu gösteriyor.
* **Sınıf Bazlı Analiz:**
    * Model, **Tüberküloz** ve **Normal** sınıflarındaki tüm görüntüleri hatasız (%100) sınıflandırmıştır.
    * Hata payı oldukça düşük; model yalnızca **1 COVID-19** vakasını Tüberküloz ile ve **1 Pnömoni** vakasını Normal ile karıştırmıştır.
* **Sonuç:** Toplamda sadece 2 yanlış tahminimiz var. Bu da geliştirdiğimiz DenseNet modelinin medikal sınıflandırma için gayet güvenilir çalıştığını kanıtlıyor.