# Improved Model Training - Target: 97% Accuracy

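This notebook trains the improved model using advanced techniques (focal loss, mixup augmentation, weighted sampling), with the goal of raising accuracy from the 82% baseline to the 97% target.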


In [None]:
import json
import sys
from pathlib import Path

import torch

# Add model directory to path
sys.path.append(str(Path('../model')))

from train_improved import train_improved_model


In [None]:
# Choose the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if has_cuda else torch.device('cpu')
print(f"Using device: {device}")

# On a CUDA machine, also report which GPU (index 0) and CUDA build are in use.
if has_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"CUDA Version: {torch.version.cuda}")


In [None]:
# Training parameters for improved accuracy.
# These names are passed to train_improved_model in the training cell, so they
# must keep their current spelling.
base_path = '../data/raw/DiabeticRetinopathyDataset'  # relative to the notebook's working directory
num_epochs = 100  # More epochs for better convergence
batch_size = 32
learning_rate = 0.0001  # Lower LR for stable training
patience = 10  # Higher early-stopping patience

# Techniques listed in the configuration summary. Only focal loss, mixup and
# weighted sampling are switched on explicitly by this notebook; the rest are
# presumably handled inside train_improved_model -- TODO confirm.
advanced_techniques = (
    "Focal Loss for class imbalance",
    "Mixup augmentation",
    "Weighted random sampling",
    "Advanced data augmentation",
    "OneCycleLR scheduler",
    "Gradient clipping",
)

print("Improved Training Configuration:")
print(f"  Epochs: {num_epochs}")
print(f"  Batch Size: {batch_size}")
print(f"  Learning Rate: {learning_rate}")
print(f"  Early Stopping Patience: {patience}")
print("\nAdvanced Techniques:")
for technique in advanced_techniques:
    # \u2713 is the check mark. It is written as an escape so the source stays
    # ASCII and cannot be re-garbled by an encoding round trip.
    print(f"  \u2713 {technique}")


In [None]:
# Run training with the settings from the configuration cell.
# The option set is assembled first so it can be inspected before the run starts.
training_options = dict(
    # data and schedule
    base_path=base_path,
    num_epochs=num_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    patience=patience,
    # runtime and output location
    device=device,
    save_dir='../model',
    # optional techniques, all enabled for this run
    use_focal_loss=True,
    use_mixup=True,
    use_weighted_sampler=True,
)
model, history = train_improved_model(**training_options)


In [None]:
# Display results
BASELINE_ACC = 0.82  # accuracy of the earlier model, used as the comparison point
TARGET_ACC = 0.97    # accuracy goal for this notebook


def format_results_report(history_data, baseline_acc=BASELINE_ACC, target_acc=TARGET_ACC):
    """Build the text summary of a finished training run.

    Args:
        history_data: Mapping with the keys 'best_val_acc', 'test_acc' and
            'misdiagnosis_rate'. Each value is multiplied by 100 for display,
            so they are treated as fractions.
        baseline_acc: Accuracy (as a fraction) to compare the test accuracy against.
        target_acc: Accuracy (as a fraction) that counts as reaching the goal.

    Returns:
        The multi-line report as a single string, ready to be printed.

    Raises:
        KeyError: If one of the expected keys is missing from ``history_data``.
    """
    test_acc = history_data['test_acc']
    rule = "=" * 60
    # The '+' format flag prints the sign itself, so a result below the
    # baseline reads "-7.00%" instead of "+-7.00%".
    improvement = (test_acc - baseline_acc) * 100

    lines = [
        "",
        rule,
        "Improved Training Results",
        rule,
        f"Best Validation Accuracy: {history_data['best_val_acc']*100:.2f}%",
        f"Test Accuracy: {test_acc*100:.2f}%",
        f"Misdiagnosis Rate: {history_data['misdiagnosis_rate']*100:.2f}%",
        "",
        f"Improvement over baseline: {improvement:+.2f}%",
        f"Target: {target_acc*100:g}% | Current: {test_acc*100:.2f}%",
        "",
    ]

    # Emoji are written as escapes (party popper / rising chart) so the source
    # stays ASCII and cannot be re-garbled by an encoding round trip.
    if test_acc >= target_acc:
        lines.append("\U0001F389 Target accuracy achieved!")
    else:
        remaining = (target_acc - test_acc) * 100
        lines.append(f"\U0001F4C8 Need {remaining:.2f}% more improvement")
    lines.append(rule)
    return "\n".join(lines)


history_path = Path('../model/training_history_improved.json')
if history_path.exists():
    with open(history_path, 'r', encoding='utf-8') as f:
        history_data = json.load(f)
    print(format_results_report(history_data))
else:
    # Say so explicitly instead of leaving the cell output empty.
    print(f"No training history found at {history_path}; run the training cell first.")
