In [3]:
# ============================================================================
# Cell 1: Setup
# ============================================================================
# Standard library
import os
import sys

# Make the parent of the current working directory (normally the repository
# root when the notebook is launched from its own folder) importable, so the
# project-local `src.*` packages below resolve.
project_root = os.path.abspath(os.pardir)
sys.path.insert(0, project_root)

# Third-party
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms

# Project-local — must come after the sys.path tweak above.
from src.utils.config_loader import ConfigLoader
from src.models import LeNet5
from src.data.dataset import ChestXrayDataset
from src.training.trainer import Trainer

print("‚úÖ Imports successful!")

‚úÖ Imports successful!


In [None]:
# ============================================================================
# Cell 2: Load configs
# ============================================================================
# Load the LeNet model config and the shared data config through the project's
# ConfigLoader, then echo the settings the rest of the notebook relies on.
loader = ConfigLoader(project_root)

model_config = loader.load_model_config('lenet')
data_config = loader.load_data_config()

_rule = "=" * 70
print(_rule)
print("üìã CONFIGURATION")
print(_rule)

_model_section = model_config['model']
print(f"\nModel: {_model_section['name']}")
print(f"  Type: {_model_section['type']}")
print(f"  Description: {_model_section['description']}")

print("\nTraining:")
_train_section = model_config['training']
print(f"  Epochs: {_train_section['num_epochs']}")
print(f"  Batch Size: {_train_section['batch_size']}")
print(f"  Learning Rate: {_train_section['learning_rate']}")
print(f"  Weighted Loss: {_train_section['loss'].get('weighted', False)}")

print("\nData:")
_data_section = data_config['data']
_loader_section = model_config['data']
print(f"  Root: {_data_section['root_dir']}")
print(f"  Classes: {_data_section['classes']}")
print(f"  Workers: {_loader_section['num_workers']}")
print(f"  Prefetch: {_loader_section['prefetch_factor']}")
print(_rule)

In [None]:
# ============================================================================
# Cell 3: Prepare data transforms
# ============================================================================
# The model config names which normalization preset (from the data config) to use.
norm_type = model_config['data']['normalization']
norm_preset = data_config['data']['normalization'][norm_type]
mean, std = norm_preset['mean'], norm_preset['std']

print(f"Normalization: {norm_type}")
print(f"  Mean: {mean}")
print(f"  Std: {std}")

image_size = data_config['data']['image_size']
aug_cfg = data_config['data']['augmentation']

# Training pipeline: resize, random flip / rotation / brightness-contrast jitter,
# then tensor conversion and normalization.
# NOTE(review): horizontal flips mirror chest anatomy (e.g. heart side) — confirm
# this augmentation is intended for this dataset.
train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(aug_cfg['rotation_degrees']),
    transforms.ColorJitter(
        brightness=aug_cfg['color_jitter']['brightness'],
        contrast=aug_cfg['color_jitter']['contrast'],
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Validation pipeline: same resize / tensor / normalize steps, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

print("‚úÖ Transforms created")

In [None]:
# =============================================================================
# Cell 3.5: Calculate Class Weights
# =============================================================================
import torch
import numpy as np

# This cell runs before Cell 4 defines `train_dataset`, so referencing that name
# here raised NameError on a fresh "Restart & Run All". Build the training split
# locally with the same arguments Cell 4 uses, only to count samples per class.
# NOTE(review): assumes constructing ChestXrayDataset is cheap (e.g. a directory
# scan rather than loading every image into memory) — TODO confirm.
weights_dataset = ChestXrayDataset(
    root_dir=os.path.join(data_config['data']['root_dir'], 'train'),
    transform=train_transform
)

# Get the per-class sample counts
dist = weights_dataset.get_class_distribution()
classes = weights_dataset.classes

print("=" * 70)
print("üìä CLASS DISTRIBUTION & WEIGHTS")
print("=" * 70)

# Show the raw distribution
total_samples = sum(dist.values())
print("\nOriginal Distribution:")
for cls in classes:
    count = dist.get(cls, 0)
    percentage = (count / total_samples) * 100 if total_samples else 0.0
    print(f"  {cls:20s}: {count:5d} ({percentage:5.2f}%)")

# Compute inverse-frequency class weights: w_c = N / (K * n_c), where N is the
# total sample count, K the number of classes and n_c the count of class c.
# A class with no samples gets 0.0 instead of raising ZeroDivisionError.
class_weights = []
print("\nClass Weights (Inverse Frequency):")
for cls in classes:
    count = dist.get(cls, 0)
    weight = total_samples / (len(classes) * count) if count else 0.0
    class_weights.append(weight)
    print(f"  {cls:20s}: {weight:.4f}")

class_weights = torch.FloatTensor(class_weights)

print(f"\nWeights tensor: {class_weights}")
print(f"Sum of weights: {class_weights.sum():.4f}")
print("=" * 70)

In [None]:
# ============================================================================
# Cell 4: Create datasets
# ============================================================================
data_root = data_config['data']['root_dir']

# One dataset per split sub-directory; only the training split is augmented.
train_dataset = ChestXrayDataset(
    root_dir=os.path.join(data_root, 'train'),
    transform=train_transform,
)
val_dataset = ChestXrayDataset(
    root_dir=os.path.join(data_root, 'val'),
    transform=val_transform,
)

print(f"\n‚úÖ Train: {len(train_dataset)} samples")
print(f"‚úÖ Val: {len(val_dataset)} samples")

# Per-class sample counts for the training split
print("\nüìä Train Distribution:")
train_counts = train_dataset.get_class_distribution()
for class_name in train_counts:
    print(f"  {class_name:20s}: {train_counts[class_name]:5d}")

In [None]:
# ============================================================================
# Cell 5: Create dataloaders (OPTIMIZED)
# ============================================================================
data_opts = model_config['data']
num_workers = data_opts['num_workers']

# Keyword arguments shared by the train and validation loaders.
loader_kwargs = dict(
    batch_size=model_config['training']['batch_size'],
    num_workers=num_workers,
    pin_memory=data_opts['pin_memory'],
)
# `prefetch_factor` and `persistent_workers` only apply to multi-process loading:
# DataLoader raises ValueError when they are set while num_workers == 0, so pass
# them only when worker processes are actually used.
if num_workers > 0:
    loader_kwargs.update(
        prefetch_factor=data_opts['prefetch_factor'],
        persistent_workers=data_opts['persistent_workers'],
    )

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"‚úÖ Train loader: {len(train_loader)} batches")
print(f"‚úÖ Val loader: {len(val_loader)} batches")

In [None]:
# ============================================================================
# Cell 6: Create model
# ============================================================================
# Use the configured device only when CUDA is available; otherwise fall back to CPU.
device = model_config['device'] if torch.cuda.is_available() else 'cpu'

model = LeNet5(num_classes=model_config['model']['num_classes'])

# Honour the `training.loss.weighted` flag (echoed in Cell 2). Previously the flag
# was displayed but ignored, and the loss was always unweighted.
use_weighted_loss = bool(model_config['training']['loss'].get('weighted', False))
if use_weighted_loss:
    # Inverse-frequency weights w_c = N / (K * n_c), in `train_dataset.classes` order.
    # NOTE(review): assumes that order matches the label indices ChestXrayDataset
    # emits — TODO confirm. Empty classes get 0.0 instead of dividing by zero.
    class_counts = train_dataset.get_class_distribution()
    n_total = sum(class_counts.values())
    n_classes = len(train_dataset.classes)
    weight_values = []
    for class_name in train_dataset.classes:
        n_c = class_counts.get(class_name, 0)
        weight_values.append(n_total / (n_classes * n_c) if n_c else 0.0)
    if len(weight_values) != model_config['model']['num_classes']:
        raise ValueError(
            f"Dataset has {len(weight_values)} classes but the model expects "
            f"{model_config['model']['num_classes']}; cannot build a weighted loss."
        )
    # Put the weights on `device` (the same device handed to the Trainer) so they
    # presumably sit alongside the logits the loss will receive.
    loss_weights = torch.tensor(weight_values, dtype=torch.float32, device=device)
    criterion = nn.CrossEntropyLoss(weight=loss_weights)
else:
    criterion = nn.CrossEntropyLoss()

# NOTE(review): Adam is hard-coded; `optimizer.type` from the config is only
# printed below, not consulted — confirm the config agrees.
optimizer = optim.Adam(
    model.parameters(),
    lr=model_config['training']['learning_rate'],
    weight_decay=model_config['training']['weight_decay'],
    betas=tuple(model_config['training']['optimizer']['betas'])
)

print("=" * 70)
print("ü§ñ MODEL SETUP")
print("=" * 70)
print(f"Model: {model_config['model']['name']}")
print(f"Device: {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"\nLoss: CrossEntropyLoss{' (class-weighted)' if use_weighted_loss else ''}")
print(f"Optimizer: {model_config['training']['optimizer']['type']}")
print(f"LR: {model_config['training']['learning_rate']}")
print("=" * 70)

In [None]:
# ============================================================================
# Cell 7: Train
# ============================================================================
import shutil
import time

checkpoint_dir = model_config['checkpoint']['save_dir']

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    checkpoint_dir=checkpoint_dir
)

print("\nüöÄ Starting training...")
# perf_counter is monotonic, so the measured duration is not skewed by system
# clock adjustments the way wall-clock time.time() can be.
start_time = time.perf_counter()

history = trainer.train(
    num_epochs=model_config['training']['num_epochs'],
    save_best=model_config['checkpoint']['save_best']
)

elapsed_time = time.perf_counter() - start_time
print(f"\n‚è±Ô∏è  Total training time: {elapsed_time/60:.1f} minutes")

# Rename the generic checkpoint to a model-type-specific filename.
# NOTE(review): assumes Trainer writes its best checkpoint as
# `best_model.pth` inside `checkpoint_dir` — confirm against Trainer.
old_path = os.path.join(checkpoint_dir, 'best_model.pth')
new_path = os.path.join(checkpoint_dir, f"{model_config['model']['type']}_best.pth")
if os.path.exists(old_path):
    shutil.move(old_path, new_path)
    print(f"‚úÖ Checkpoint: {new_path}")
else:
    # Previously this case was silent, hiding a missing/unsaved checkpoint.
    print(f"No checkpoint found at {old_path}; nothing renamed.")

In [None]:
# ============================================================================
# Cell 8: Visualize
# ============================================================================
import matplotlib.pyplot as plt

model_name = model_config['model']['name']

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# (history key suffix, title suffix, y-axis label) for the loss and accuracy panels.
panels = [('loss', 'Loss', 'Loss'), ('acc', 'Accuracy', 'Accuracy (%)')]
for ax, (key, title, ylabel) in zip(axes, panels):
    for split, label, marker in (('train', 'Train', 'o'), ('val', 'Val', 's')):
        values = history[f'{split}_{key}']
        # Plot against 1-based epoch numbers; passing the list alone would put
        # the first epoch at x=0.
        ax.plot(range(1, len(values) + 1), values,
                label=label, marker=marker, linewidth=2)
    ax.set_title(f"{model_name} - {title}", fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
# Anchor the output at `project_root` (set in Cell 1) rather than a cwd-relative
# "../" so the location does not depend on the kernel's current directory.
save_path = os.path.join(project_root, 'results', 'figures',
                         f"{model_config['model']['type']}_training.png")
os.makedirs(os.path.dirname(save_path), exist_ok=True)
fig.savefig(save_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úÖ Plot saved: {save_path}")
val_acc = history['val_acc']
if val_acc:
    print(f"\nüìä Best Val Accuracy: {max(val_acc):.2f}%")
else:
    # max() of an empty list raises ValueError (e.g. num_epochs == 0).
    print("\nNo validation accuracy recorded.")