In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import nibabel as nib
import numpy as np
import os
import glob
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
from tqdm.notebook import tqdm
import time
import copy
import csv
import pandas as pd
from datetime import datetime

In [None]:
# Pick the compute device once; later cells read this module-level `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report which GPU index 0 is and how much memory it exposes (bytes -> GiB).
    _gpu_mem_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {_gpu_mem_gib:.1f} GB")

In [None]:
class CSVLogger:
    """
    CSV logger for training metrics - similar to Keras CSVLogger.

    The column set is fixed by the first call to ``log``: the keys of that
    dictionary, sorted alphabetically, become the CSV columns. Every row is
    flushed to disk immediately so a crashed run still leaves a usable log.

    The logger can be used as a context manager; leaving the ``with`` block
    closes the file.
    """
    def __init__(self, filename='training_log.csv', separator=',', append=False):
        """
        Initialize CSV Logger

        Args:
            filename: Name of the CSV file
            separator: Separator character (default: comma)
            append: Whether to append to existing file or overwrite
        """
        self.filename = filename
        self.separator = separator
        self.append = append
        self.keys = None
        # The DictWriter is created lazily in log(), once the column names are known.
        self.writer = None
        # True only when we are extending a file that is already on disk.
        self.file_exists = bool(os.path.isfile(filename) and append)
        # An existing but empty file has no header yet, so it still needs one;
        # otherwise the appended rows could never be mapped back to column names.
        self._write_header = not (self.file_exists and os.path.getsize(filename) > 0)

        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)

        mode = 'a' if self.file_exists else 'w'
        # newline='' is required by the csv module so it controls line endings itself.
        self.file = open(filename, mode, newline='', encoding='utf-8')

    def log(self, logs):
        """
        Log metrics to CSV file

        Args:
            logs: Dictionary of metrics to log. Keys missing relative to the
                first logged row are written as empty cells.

        Raises:
            ValueError: If ``logs`` contains a key that was not present in the
                first logged row (raised by ``csv.DictWriter``).
        """
        if self.keys is None:
            self.keys = sorted(logs.keys())
            self.writer = csv.DictWriter(self.file, fieldnames=self.keys, delimiter=self.separator)
            if self._write_header:
                self.writer.writeheader()

        # Write the row and push it to disk right away
        self.writer.writerow(logs)
        self.file.flush()

    def close(self):
        """Close the CSV file (safe to call more than once)"""
        # getattr guards the case where __init__ failed before the file was opened.
        handle = getattr(self, 'file', None)
        if handle is not None and not handle.closed:
            handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

In [None]:
class Config:
    # Single place for every tunable of this notebook. All fields are plain class
    # attributes; the instance created right after the class is what the other
    # cells receive as `config`.

    # Data paths - UPDATE THESE PATHS
    # The three split directories are derived from data_root by joining
    # "train", "val" and "test" onto it.
    data_root = r"D:\Kananat\Data\training_dataset"  # Update this path
    train_dir = os.path.join(data_root, "train")
    val_dir = os.path.join(data_root, "val") 
    test_dir = os.path.join(data_root, "test")
    
    # Training parameters
    batch_size = 2  # Small batch size for 3D data due to memory constraints
    num_epochs = 100  # Upper bound; early stopping in train_model may end the run sooner
    learning_rate = 1e-4  # Passed to the Adam optimizer in train_model
    weight_decay = 1e-5  # Passed to the Adam optimizer in train_model
    
    # Model parameters
    input_size = (224, 224, 224)
    num_classes = 2
    
    # Training settings
    patience = 25  # Early stopping patience: epochs without a val-loss improvement before stopping
    save_best_model = True  # When True, train_model writes a checkpoint on every new best val loss
    model_save_path = r"C:\Users\kanan\Desktop\Project_TMJOA\3D_Pipeline\densenet121_tiny_nonrotate_augment\densenet121_tiny_best.pth"
    
    # CSV Logging
    csv_log_path = r"C:\Users\kanan\Desktop\Project_TMJOA\3D_Pipeline\densenet121_tiny_nonrotate_augment\densenet121_tiny_history.csv"
    # When True, train_model adds the best_val_loss, patience_counter and is_best_epoch
    # columns. Learning rate, epoch time and timestamp are logged regardless of this flag.
    log_detailed_metrics = True
    
    # Monitoring
    print_freq = 1  # train_epoch refreshes the progress-bar loss/accuracy every print_freq batches
    plot_losses = True  # NOTE(review): not read by any cell visible here - presumably consumed by a later plotting cell; confirm

    # Backbone
    # NOTE(review): this says "convnext_tiny", but the model cell builds 'densenet121' and
    # the output paths above are named densenet121_* - looks like a leftover from another
    # notebook. The value is not read by the visible cells; confirm before relying on it.
    backbone = "convnext_tiny"  # Backbone model to use

config = Config()

In [None]:
class Medical3DDataset(Dataset):
    """
    Folder-per-class dataset of 3D volumes stored as ``.nii.gz`` files.

    Every sub-directory of ``data_dir`` is one class; class indices follow the
    alphabetical order of the folder names. Each item is loaded, resized to
    ``target_size``, min-max scaled, optionally augmented, and returned as a
    ``(1, D, H, W)`` float32 tensor together with its integer label.
    """
    def __init__(self, data_dir, transform=None, target_size=(224, 224, 224)):
        """
        Dataset for 3D medical images in .nii.gz format

        Args:
            data_dir: Directory containing class folders
            transform: Data augmentation function (if provided); called with the
                resized, normalized numpy volume and expected to return a numpy volume
            target_size: Target size for resizing (224, 224, 224)
        """
        self.data_dir = data_dir
        self.transform = transform
        self.target_size = target_size
        self.samples = []

        # Class folders, sorted so the name -> index mapping is reproducible
        class_folders = sorted(
            d for d in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, d))
        )
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(class_folders)}

        # Collect all samples
        for class_name in class_folders:
            class_idx = self.class_to_idx[class_name]
            # glob.escape keeps folder names containing [, ], * or ? from being read as
            # patterns. glob's own ordering is filesystem dependent, so sort to keep the
            # index -> file mapping stable across runs and machines.
            pattern = os.path.join(glob.escape(os.path.join(data_dir, class_name)), "*.nii.gz")
            self.samples.extend((file_path, class_idx) for file_path in sorted(glob.glob(pattern)))

        print(f"Found {len(self.samples)} samples in {len(class_folders)} classes")
        print(f"Classes: {self.class_to_idx}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load one volume.

        Returns:
            (image, label): float32 tensor of shape (1, *target_size) and a
            scalar long tensor holding the class index.
        """
        file_path, label = self.samples[idx]

        # Request float32 directly so nibabel never materialises a float64 copy of the volume
        image = nib.load(file_path).get_fdata(dtype=np.float32)

        # Resize to the configured spatial size, then scale to [0, 1]
        image = self.resize_3d(image, self.target_size)
        image = self._min_max_normalize(image)

        # Apply data augmentation if provided. Augmentation can push values outside
        # [0, 1], so rescale afterwards; without a transform the volume is already
        # normalized and a second pass would be a no-op.
        if self.transform:
            image = np.asarray(self.transform(image), dtype=np.float32)
            image = self._min_max_normalize(image)

        # torch.from_numpy rejects negatively-strided arrays (e.g. after a flip),
        # so make the buffer contiguous before wrapping it.
        image = np.ascontiguousarray(image, dtype=np.float32)

        # Add channel dimension and convert to tensor
        image = torch.from_numpy(image).unsqueeze(0)

        return image, torch.tensor(label, dtype=torch.long)

    @staticmethod
    def _min_max_normalize(image):
        """Scale ``image`` linearly to [0, 1]; constant volumes are returned unchanged."""
        lo, hi = image.min(), image.max()
        if hi > lo:
            return (image - lo) / (hi - lo)
        return image

    def resize_3d(self, image, target_size):
        """
        Resize a volume to ``target_size`` using linear interpolation.

        Args:
            image: numpy array with as many dimensions as ``target_size`` has entries
            target_size: desired shape

        Returns:
            The resized array; the input itself when it already has the target shape.

        Raises:
            ValueError: If the rank of ``image`` does not match ``target_size``.
        """
        from scipy.ndimage import zoom

        target_size = tuple(target_size)
        if image.ndim != len(target_size):
            raise ValueError(
                f"Expected a {len(target_size)}D volume, got shape {tuple(image.shape)}"
            )

        # Nothing to interpolate when the volume is already the right size
        if tuple(image.shape) == target_size:
            return image

        zoom_factors = [target / current for target, current in zip(target_size, image.shape)]

        # Use order=1 for linear interpolation (good balance of speed and quality)
        return zoom(image, zoom_factors, order=1)

In [None]:
import volumentations

def get_augmentation(patch_size):
    """
    Build the volumentations pipeline used for training-time augmentation.

    Args:
        patch_size: Shape handed to the Resize step, so the pipeline output has
            this size regardless of what the crop step removed.

    Returns:
        A volumentations.Compose object; call it as ``pipeline(image=volume)``.
    """
    # Currently disabled alternatives, kept for reference:
    #   Rotate((-10, 10), (0, 0), (0, 0), p=0.5), and the same range on the 2nd / 3rd axis argument
    #   Flip(0, p=0.5), Flip(1, p=0.5), Flip(2, p=0.5)
    #   RandomRotate90((1, 2), p=0.5)
    #   RandomGamma(gamma_limit=(80, 120), p=0.2)
    crop_borders = volumentations.RandomCropFromBorders(crop_value=0.1, p=0.5)
    elastic = volumentations.ElasticTransform((0, 0.25), interpolation=2, p=0.1)
    # Resize is always applied and sits after the crop, restoring patch_size.
    restore_size = volumentations.Resize(patch_size, interpolation=1, resize_type=0, always_apply=True, p=1.0)
    noise = volumentations.GaussianNoise(var_limit=(0, 5), p=0.2)

    steps = [crop_borders, elastic, restore_size, noise]
    return volumentations.Compose(steps, p=1.0)

def augment_volume(volume):
    """
    Augment a single volume and return it at its original shape.

    A fresh pipeline is built per call because its Resize target is the shape
    of the incoming volume.
    """
    pipeline = get_augmentation(volume.shape)
    return pipeline(image=volume)["image"]

In [None]:
def create_data_loaders(config, train_transform=None, val_transform=None):
    """
    Create train, validation, and test data loaders.

    Args:
        config: Object providing train_dir, val_dir, test_dir and batch_size.
            If it also has input_size, that is used as the dataset target size
            (falls back to (224, 224, 224) otherwise).
        train_transform: Augmentation callable for the training split (optional)
        val_transform: Callable applied to both the validation and test splits (optional)

    Returns:
        (train_loader, val_loader, test_loader, class_to_idx) where class_to_idx
        is the mapping of the training dataset.
    """
    import warnings

    # Honour the configured volume size instead of silently relying on the dataset default
    target_size = tuple(getattr(config, 'input_size', (224, 224, 224)))

    # Create datasets
    train_dataset = Medical3DDataset(config.train_dir, transform=train_transform, target_size=target_size)
    val_dataset = Medical3DDataset(config.val_dir, transform=val_transform, target_size=target_size)
    test_dataset = Medical3DDataset(config.test_dir, transform=val_transform, target_size=target_size)

    # Each split derives its own folder -> index mapping; if they disagree, the same
    # integer label means different classes in different splits.
    for split_name, dataset in (('val', val_dataset), ('test', test_dataset)):
        if dataset.class_to_idx != train_dataset.class_to_idx:
            warnings.warn(
                f"Class mapping of the {split_name} split {dataset.class_to_idx} differs from "
                f"the train split {train_dataset.class_to_idx}; labels will not be comparable."
            )

    # Pinned host memory only helps when batches are copied to a GPU
    pin_memory = torch.cuda.is_available()

    def _make_loader(dataset, shuffle):
        # Only the shuffle flag differs between splits
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=pin_memory
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader, train_dataset.class_to_idx

# Build the three loaders. Only the training split gets augmentation
# (augment_volume); validation and test are left without a transform.
train_loader, val_loader, test_loader, class_to_idx = create_data_loaders(config, train_transform=augment_volume)

# Report how many batches each split yields per epoch.
for _split_name, _split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{_split_name} batches: {len(_split_loader)}")

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device, epoch, config):
    """
    Run one optimisation pass over train_loader.

    Args:
        model: Network to train (switched to train mode here)
        train_loader: Iterable of (inputs, labels) batches
        criterion: Loss callable taking (outputs, labels)
        optimizer: Optimizer stepping the model parameters
        device: Device the batches are moved to
        epoch: Zero-based epoch index, used only for the progress-bar caption
        config: Provides num_epochs (caption) and print_freq (progress refresh interval)

    Returns:
        (epoch_loss, epoch_acc): sample-weighted mean loss and accuracy as Python floats.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.num_epochs}')

    for step, (inputs, labels) in enumerate(progress_bar):
        inputs, labels = inputs.to(device), labels.to(device)
        batch_n = inputs.size(0)

        # Standard step: clear grads, forward, backward, update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean, so weight it by the batch size
        # to keep the running average correct when the last batch is smaller.
        predicted = torch.max(outputs, 1).indices
        loss_sum += loss.item() * batch_n
        n_correct += int((predicted == labels).sum().item())
        n_seen += batch_n

        # Refresh the running averages shown on the progress bar
        if step % config.print_freq == 0:
            progress_bar.set_postfix({
                'Loss': f'{loss_sum / n_seen:.4f}',
                'Acc': f'{n_correct / n_seen:.4f}'
            })

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
def validate_epoch(model, val_loader, criterion, device):
    """
    Evaluate the model on val_loader without tracking gradients.

    Args:
        model: Network to evaluate (switched to eval mode here)
        val_loader: Iterable of (inputs, labels) batches
        criterion: Loss callable taking (outputs, labels)
        device: Device the batches are moved to

    Returns:
        (epoch_loss, epoch_acc): sample-weighted mean loss and accuracy as Python floats.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc='Validating'):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_n = inputs.size(0)

            outputs = model(inputs)
            batch_loss = criterion(outputs, labels)
            predicted = torch.max(outputs, 1).indices

            # Weight the batch-mean loss by the batch size before accumulating
            loss_sum += batch_loss.item() * batch_n
            n_correct += int((predicted == labels).sum().item())
            n_seen += batch_n

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
def train_model(model, train_loader, val_loader, config):
    """
    Main training function with CSV logging, LR scheduling and early stopping.

    Each epoch runs train_epoch and validate_epoch, steps a ReduceLROnPlateau
    scheduler on the validation loss, appends one row to the CSV log and, on a
    new best validation loss, snapshots the weights (and optionally writes a
    checkpoint). Training stops early after config.patience epochs without
    improvement.

    Args:
        model: Network to train; moved to the module-level `device`
        train_loader: Training batches. If its `.dataset` exposes `class_to_idx`,
            that mapping is stored in the checkpoint (None otherwise).
        val_loader: Validation batches
        config: Provides learning_rate, weight_decay, num_epochs, patience,
            save_best_model, model_save_path, csv_log_path, log_detailed_metrics

    Returns:
        (model, history): the model with the best-validation-loss weights loaded,
        and a dict with the per-epoch lists train_losses, train_accs,
        val_losses, val_accs.
    """
    # Move model to device
    model = model.to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    # Learning rate scheduler: halve the LR after 5 epochs without val-loss improvement
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

    # Take the class mapping from the data actually being trained on instead of
    # relying on a notebook-level global that may be stale or undefined.
    class_mapping = getattr(getattr(train_loader, 'dataset', None), 'class_to_idx', None)

    # torch.save does not create missing directories
    if config.save_best_model:
        os.makedirs(os.path.dirname(config.model_save_path) or '.', exist_ok=True)

    # Training history
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    # Early stopping state
    best_val_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    patience_counter = 0

    # The context manager closes the CSV file even if an epoch raises
    with CSVLogger(config.csv_log_path, append=False) as csv_logger:
        print("Starting training...")
        print(f"Metrics will be logged to: {config.csv_log_path}")
        print("-" * 50)

        for epoch in range(config.num_epochs):
            epoch_start_time = time.time()

            # Train
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device, epoch, config)

            # Validate
            val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)

            # Read the learning rate BEFORE the scheduler may lower it, so the log
            # shows the rate this epoch was trained with.
            current_lr = optimizer.param_groups[0]['lr']

            # Update scheduler
            scheduler.step(val_loss)

            # Save history
            train_losses.append(train_loss)
            train_accs.append(train_acc)
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            epoch_time = time.time() - epoch_start_time

            # Update the early-stopping state first so every logged column describes
            # the state AFTER this epoch (previously patience_counter lagged one
            # epoch behind best_val_loss / is_best_epoch).
            is_best_epoch = val_loss < best_val_loss
            if is_best_epoch:
                best_val_loss = val_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1

            # Prepare metrics for CSV logging
            metrics = {
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_accuracy': train_acc,
                'val_loss': val_loss,
                'val_accuracy': val_acc,
                'learning_rate': current_lr,
                'epoch_time': epoch_time,
                'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }

            # Add additional metrics if enabled
            if config.log_detailed_metrics:
                metrics.update({
                    'best_val_loss': best_val_loss,
                    'patience_counter': patience_counter,
                    'is_best_epoch': is_best_epoch
                })

            # Log to CSV
            csv_logger.log(metrics)

            print(f'Epoch {epoch+1}/{config.num_epochs}:')
            print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
            print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')
            print(f'  Learning Rate: {current_lr:.2e}')
            print(f'  Time: {epoch_time:.2f}s')
            print('-' * 50)

            # Best model saving
            if is_best_epoch and config.save_best_model:
                torch.save({
                    'epoch': epoch,  # zero-based, unlike the 1-based 'epoch' column in the CSV
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_loss': best_val_loss,
                    'train_losses': train_losses,
                    'val_losses': val_losses,
                    'train_accs': train_accs,
                    'val_accs': val_accs,
                    'class_to_idx': class_mapping
                }, config.model_save_path)
                print(f'Best model saved to {config.model_save_path}')

            # Early stopping
            if patience_counter >= config.patience:
                print(f'Early stopping triggered after {epoch+1} epochs')
                break

    # Load best model weights
    model.load_state_dict(best_model_wts)

    return model, {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs
    }

In [None]:
import timm_3d
from torchinfo import summary

# Build a randomly initialised 3D DenseNet-121 classifier. The number of output
# classes comes from the shared config so this cell cannot drift from the rest
# of the notebook.
model = timm_3d.create_model(
    'densenet121',
    pretrained=False,
    num_classes=config.num_classes
)

# The dataset yields single-channel volumes, so replace the stem convolution with
# a 1-input-channel version. The 64 output channels / 7x7x7 kernel / stride 2 /
# padding 3 are presumably the stock DenseNet-121 stem - TODO confirm against timm_3d.
model.features.conv0 = nn.Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)

# print(model)

# Summarise with the batch shape that training will really use: config.batch_size
# volumes of 1 x config.input_size. A larger dummy batch only raises the risk of
# running out of memory before training starts.
summary(model, input_size=(config.batch_size, 1, *config.input_size))


In [None]:
# Run the full training loop. train_model returns the model with the best-validation-loss
# weights loaded, plus a dict of per-epoch train/val loss and accuracy lists.
trained_model, history = train_model(model, train_loader, val_loader, config)