# In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim.lr_scheduler
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import pydicom
from pathlib import Path
from typing import Dict, List, Tuple
from collections import defaultdict
from sklearn.metrics import confusion_matrix, log_loss
import timm
import warnings
warnings.filterwarnings('ignore')

# Import from our modules
from data_visualization import LumbarSpineVisualizer
from pattern_analysis import PatternAnalyzer
from preprocessing_pipeline import (
    preprocess_image,
    save_processed_data,
    process_fold_data,
    create_stratified_folds
)
from classification_model import (
    LumbarSpineDataset,
    AttentionBlock,
    LumbarClassifier,
    train_epoch as train_classification,
    validate as validate_classification
)
from regression_model import (
    LumbarSpineRegDataset,
    LumbarRegressor,
    WeightedL1Loss,
    train_epoch as train_regression,
    validate as validate_regression
)
from evaluation_metrics import (
    compute_competition_metric,
    evaluate_model,
    evaluate_regression_model
)
from prediction_pipeline import (
    OptimizedPredictionPipeline,
    predict_preprocessed_samples
)
from advanced_analysis import AdvancedAnalysis

class LumbarSpineAnalysis:
    """End-to-end driver for the lumbar spine workflow.

    On construction the three training CSV tables under ``base_path`` are
    loaded. Each later stage is a separate method: preprocessing, training
    of the classification and regression models, analysis of the results,
    and prediction on new DICOM files.
    """

    # File names shared between the stage that writes them and the stage
    # that reads them back, kept in one place so they cannot drift apart.
    CLASSIFICATION_CHECKPOINT = 'best_model.pth'
    REGRESSION_CHECKPOINT = 'best_regression_model.pth'
    TRAIN_PROCESSED_PATH = 'train_processed.npy'
    VAL_PROCESSED_PATH = 'val_processed.npy'

    # Per-class weights handed to CrossEntropyLoss.
    # NOTE(review): presumably ordered by increasing severity to offset
    # class imbalance -- confirm against the label encoding.
    CLASS_WEIGHTS = (1.0, 4.75, 12.29)

    def __init__(self, base_path: str):
        """Select the compute device and load the data tables.

        Args:
            base_path: Directory containing ``train.csv``,
                ``train_label_coordinates.csv`` and
                ``train_series_descriptions.csv``.
        """
        self.base_path = Path(base_path)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Must be self.device: a bare `device` is not defined in this scope.
        print(f"Using device: {self.device}")

        # Initialize components
        self.initialize_components()

    def initialize_components(self):
        """Load the CSV tables, create the visualizer, reset model slots."""
        # Load data
        self.train_df = pd.read_csv(self.base_path / 'train.csv')
        self.coords_df = pd.read_csv(self.base_path / 'train_label_coordinates.csv')
        self.series_df = pd.read_csv(self.base_path / 'train_series_descriptions.csv')

        # Create visualizer
        self.visualizer = LumbarSpineVisualizer()

        # Models and pipeline are created later by the train/analyze stages.
        self.classification_model = None
        self.regression_model = None
        self.prediction_pipeline = None

    def preprocess_data(self, fold: int = 0):
        """Preprocess one fold and save both splits to disk.

        Args:
            fold: Index into the folds returned by
                ``create_stratified_folds``. Defaults to the first fold.

        Returns:
            Tuple ``(train_samples, val_samples)`` as returned by
            ``process_fold_data``. The training split is processed with
            ``augment=True``, the validation split with ``augment=False``.
        """
        print("Creating stratified folds...")
        folds = create_stratified_folds(self.train_df)
        selected = folds[fold]

        # Arguments common to both splits.
        shared = dict(
            base_path=self.base_path,
            coords_df=self.coords_df,
            series_df=self.series_df,
            train_df=self.train_df,
        )

        print("Processing training data...")
        train_samples = process_fold_data(
            study_ids=selected['train'], augment=True, **shared
        )

        print("Processing validation data...")
        val_samples = process_fold_data(
            study_ids=selected['val'], augment=False, **shared
        )

        # Save processed data
        save_processed_data(train_samples, self.TRAIN_PROCESSED_PATH)
        save_processed_data(val_samples, self.VAL_PROCESSED_PATH)

        return train_samples, val_samples

    @staticmethod
    def _build_loaders(dataset_cls, train_samples, val_samples, batch_size, num_workers):
        """Wrap both sample sets in ``dataset_cls`` and return (train, val) loaders."""
        train_dataset = dataset_cls(train_samples, augment=True)
        val_dataset = dataset_cls(val_samples, augment=False)
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
        val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )
        return train_loader, val_loader

    @staticmethod
    def _build_optimizer(model, num_epochs):
        """Return the AdamW optimizer and cosine schedule used by both trainers."""
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=num_epochs,  # one full cosine decay over the whole run
            eta_min=1e-6,
        )
        return optimizer, scheduler

    def train_classification_model(self, train_samples, val_samples, *,
                                   num_epochs: int = 30, batch_size: int = 32,
                                   num_workers: int = 4):
        """Train the classifier and checkpoint the best epoch.

        The model's ``state_dict`` is written to ``CLASSIFICATION_CHECKPOINT``
        whenever the validation accuracy improves.

        Args:
            train_samples: Samples for the training dataset.
            val_samples: Samples for the validation dataset.
            num_epochs: Number of epochs; also the scheduler's ``T_max``.
            batch_size: Batch size for both loaders.
            num_workers: DataLoader worker processes for both loaders.
        """
        print("\nTraining Classification Model...")

        train_loader, val_loader = self._build_loaders(
            LumbarSpineDataset, train_samples, val_samples, batch_size, num_workers
        )

        # Initialize model
        self.classification_model = LumbarClassifier().to(self.device)

        # Training parameters
        criterion = nn.CrossEntropyLoss(
            weight=torch.tensor(self.CLASS_WEIGHTS).to(self.device)
        )
        optimizer, scheduler = self._build_optimizer(
            self.classification_model, num_epochs
        )

        # Start below any possible accuracy so the first epoch always writes
        # a checkpoint; starting at 0 with a strict ">" could leave no file
        # for analyze_results() to load.
        best_val_acc = float('-inf')
        for epoch in range(num_epochs):
            train_loss, train_acc = train_classification(
                self.classification_model, train_loader, criterion, optimizer, self.device
            )
            val_loss, val_acc = validate_classification(
                self.classification_model, val_loader, criterion, self.device
            )
            scheduler.step()

            print(f"Epoch {epoch+1}/{num_epochs}:")
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(
                    self.classification_model.state_dict(),
                    self.CLASSIFICATION_CHECKPOINT,
                )

    def train_regression_model(self, train_samples, val_samples, *,
                               num_epochs: int = 30, batch_size: int = 32,
                               num_workers: int = 4):
        """Train the regressor and checkpoint the best epoch.

        The model's ``state_dict`` is written to ``REGRESSION_CHECKPOINT``
        whenever the validation loss decreases.

        Args:
            train_samples: Samples for the training dataset.
            val_samples: Samples for the validation dataset.
            num_epochs: Number of epochs; also the scheduler's ``T_max``.
            batch_size: Batch size for both loaders.
            num_workers: DataLoader worker processes for both loaders.
        """
        print("\nTraining Regression Model...")

        train_loader, val_loader = self._build_loaders(
            LumbarSpineRegDataset, train_samples, val_samples, batch_size, num_workers
        )

        # Initialize model
        self.regression_model = LumbarRegressor().to(self.device)

        # Training parameters
        criterion = WeightedL1Loss()
        optimizer, scheduler = self._build_optimizer(
            self.regression_model, num_epochs
        )

        # Training loop
        best_val_loss = float('inf')
        for epoch in range(num_epochs):
            train_loss = train_regression(
                self.regression_model, train_loader, criterion, optimizer, self.device
            )
            val_loss, val_mae = validate_regression(
                self.regression_model, val_loader, criterion, self.device
            )
            scheduler.step()

            print(f"Epoch {epoch+1}/{num_epochs}:")
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.4f}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(
                    self.regression_model.state_dict(),
                    self.REGRESSION_CHECKPOINT,
                )

    def analyze_results(self, val_samples):
        """Reload the best checkpoints and run the advanced analyses.

        Args:
            val_samples: Accepted for interface compatibility; not read
                here. The analyzer is given ``VAL_PROCESSED_PATH`` instead.
        """
        # Allow analysis straight from saved checkpoints when the models
        # were not trained in this session (the slots are still None).
        if self.classification_model is None:
            self.classification_model = LumbarClassifier().to(self.device)
        if self.regression_model is None:
            self.regression_model = LumbarRegressor().to(self.device)

        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        self.classification_model.load_state_dict(
            torch.load(self.CLASSIFICATION_CHECKPOINT, map_location=self.device)
        )
        self.regression_model.load_state_dict(
            torch.load(self.REGRESSION_CHECKPOINT, map_location=self.device)
        )

        # Create prediction pipeline (built around the classifier only)
        self.prediction_pipeline = OptimizedPredictionPipeline(
            self.classification_model,
            self.device
        )

        # Create analyzer
        analyzer = AdvancedAnalysis(
            self.prediction_pipeline,
            self.VAL_PROCESSED_PATH
        )

        # Run analyses
        print("\nRunning Statistical Analysis...")
        analyzer.analyze_prediction_patterns(num_samples=500)

        print("\nAnalyzing Challenging Cases...")
        analyzer.analyze_challenging_cases(num_cases=5)

    def predict_new_cases(self, image_paths: List[str]):
        """Predict and plot a result for each DICOM file in ``image_paths``.

        Every image is queried with the fixed example condition
        'Spinal Canal Stenosis' at level 'L4_L5'.

        Args:
            image_paths: Paths of DICOM files to read.

        Raises:
            RuntimeError: If no prediction pipeline exists yet, i.e.
                ``analyze_results`` has not been run.
        """
        if self.prediction_pipeline is None:
            raise RuntimeError(
                "Prediction pipeline is not initialized; "
                "call analyze_results() before predict_new_cases()."
            )

        for image_path in image_paths:
            # Load and preprocess image
            dcm = pydicom.dcmread(image_path)
            image = dcm.pixel_array
            processed_image = preprocess_image(image)

            # Make prediction
            # NOTE(review): the tensor is neither cast nor moved to
            # self.device here -- presumably the pipeline does that; confirm.
            prediction = self.prediction_pipeline.predict(
                torch.from_numpy(processed_image).unsqueeze(0),
                condition='Spinal Canal Stenosis',  # Example condition
                level='L4_L5'  # Example level
            )

            # Visualize results
            self.visualizer.plot_prediction(
                image,
                prediction,
                title=f"Prediction for {Path(image_path).name}"
            )

def main():
    """Run the whole workflow: preprocess, train both models, analyze."""
    pipeline = LumbarSpineAnalysis(
        '/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification'
    )

    # Build the train/validation splits once; both trainers reuse them.
    train_split, val_split = pipeline.preprocess_data()

    # Classification first, then regression, on identical splits.
    trainers = (
        pipeline.train_classification_model,
        pipeline.train_regression_model,
    )
    for fit in trainers:
        fit(train_split, val_split)

    # Evaluate the best checkpoints written during training.
    pipeline.analyze_results(val_split)

    # Optional follow-up: predict on new DICOM images, e.g.
    #   pipeline.predict_new_cases(['path/to/image1.dcm', 'path/to/image2.dcm'])

# Run the full workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()