In [2]:
import logging
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import nibabel as nib

In [3]:
# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Dataset ---
class HeartDataset(Dataset):
    def __init__(self, data_dir, cases, transform=None, patch_size=(96, 96, 96)):
        self.data_dir = data_dir
        self.cases = cases
        self.transform = transform
        self.patch_size = patch_size

    def __len__(self):
        return len(self.cases)

    def __getitem__(self, idx):
        case = self.cases[idx]
        try:
            img_path = os.path.join(self.data_dir, 'imagesTr', f"{case}.nii.gz")
            label_path = os.path.join(self.data_dir, 'labelsTr', f"{case}.nii.gz")
            
            img = nib.load(img_path).get_fdata().astype(np.float32)
            label = nib.load(label_path).get_fdata().astype(np.int64)
            
            img = np.clip(img, -1000, 1000)
            img = (img + 1000) / 2000.0
            img = np.expand_dims(img, axis=0)
            
            img = torch.from_numpy(img.copy())
            label = torch.from_numpy(label.copy())

            if self.patch_size:
                img, label = self.atrium_aware_crop(img, label, self.patch_size)
            
            return img, label

        except Exception as e:
            logger.error(f"Error loading case {case}: {e}")
            dummy_img = torch.zeros((1,) + self.patch_size)
            dummy_label = torch.zeros(self.patch_size, dtype=torch.long)
            return dummy_img, dummy_label

    def atrium_aware_crop(self, img, label, patch_size):
        d, h, w = img.shape[1:]
        pd, ph, pw = patch_size

        # Try to find a crop containing left atrium (label == 1)
        for _ in range(10):  # Increased attempts for better atrium sampling
            start_d = np.random.randint(0, max(1, d - pd + 1))
            start_h = np.random.randint(0, max(1, h - ph + 1))
            start_w = np.random.randint(0, max(1, w - pw + 1))

            label_crop = label[start_d:start_d+pd, start_h:start_h+ph, start_w:start_w+pw]
            atrium_ratio = (label_crop == 1).float().mean().item()
            
            # Accept if we have some atrium voxels (lowered threshold for heart segmentation)
            if atrium_ratio > 0.0005:
                break
        else:
            # If no atrium found after 10 attempts, use center crop
            start_d = max(0, (d - pd) // 2)
            start_h = max(0, (h - ph) // 2)
            start_w = max(0, (w - pw) // 2)

        img_crop = img[:, start_d:start_d+pd, start_h:start_h+ph, start_w:start_w+pw]
        label_crop = label[start_d:start_d+pd, start_h:start_h+ph, start_w:start_w+pw]
        return img_crop, label_crop


In [4]:
# --- UNet3D (unchanged) ---
class UNet3D(nn.Module):
    def __init__(self, in_channels=1, out_channels=2, features=[32, 64, 128, 256]):
        super(UNet3D, self).__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

        for feature in features:
            self.encoder.append(self.conv_block(in_channels, feature))
            in_channels = feature

        self.bottleneck = self.conv_block(features[-1], features[-1] * 2)

        for feature in reversed(features):
            self.decoder.append(nn.ConvTranspose3d(feature * 2, feature, kernel_size=2, stride=2))
            self.decoder.append(self.conv_block(feature * 2, feature))

        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        skip_connections = []
        for down in self.encoder:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.decoder), 2):
            x = self.decoder[idx](x)
            skip = skip_connections[idx // 2]
            if x.shape != skip.shape:
                x = F.interpolate(x, size=skip.shape[2:], mode='trilinear', align_corners=False)
            x = self.decoder[idx + 1](torch.cat((skip, x), dim=1))
        return self.final_conv(x)

# --- Loss Functions ---
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        pred = F.softmax(pred, dim=1)
        pred_flat = pred[:, 1].contiguous().view(-1)
        target_flat = (target == 1).float().view(-1)
        intersection = (pred_flat * target_flat).sum()
        dice = (2. * intersection + self.smooth) / (pred_flat.sum() + target_flat.sum() + self.smooth)
        return 1 - dice

class CombinedLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        return self.ce(pred, target) + self.dice(pred, target)

# --- Dice Metrics ---
def dice_coefficient(pred, target, num_classes=2):
    pred = F.softmax(pred, dim=1)
    dice_scores = []
    for cls in range(num_classes):
        pred_cls = (pred[:, cls] > 0.5).float()
        target_cls = (target == cls).float()
        intersection = (pred_cls * target_cls).sum()
        union = pred_cls.sum() + target_cls.sum()
        dice = (2. * intersection / union) if union > 0 else 1.0
        dice_scores.append(dice.item() if hasattr(dice, 'item') else float(dice))
    return dice_scores

In [5]:
# --- Trainer ---
def train_epoch(model, loader, optimizer, criterion, device, scaler):
    model.train()
    total_loss, all_dices = 0, []
    for data, target in tqdm(loader, desc="Training"):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device):
            output = model(data)
            loss = criterion(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        dice = dice_coefficient(output, target)
        all_dices.append(dice)
        left_atrium_ratio = (target == 1).sum().item() / target.numel()
        logger.info(f"Left atrium voxels: {left_atrium_ratio*100:.4f}%")
    return total_loss / len(loader), np.mean(all_dices, axis=0)

def validate_epoch(model, loader, criterion, device):
    model.eval()
    total_loss, all_dices = 0, []
    with torch.no_grad():
        for data, target in tqdm(loader, desc="Validation"):
            data, target = data.to(device), target.to(device)
            with torch.amp.autocast(device_type=device):
                output = model(data)
                loss = criterion(output, target)
            total_loss += loss.item()
            dice = dice_coefficient(output, target)
            all_dices.append(dice)
            left_atrium_ratio = (target == 1).sum().item() / target.numel()
            logger.info(f"[VAL] Left atrium voxels: {left_atrium_ratio*100:.4f}%")
    return total_loss / len(loader), np.mean(all_dices, axis=0)

In [6]:
def save_checkpoint(model, optimizer, scheduler, scaler, epoch, loss, dice, save_dir, is_best=False, is_epoch=False):
    """Save checkpoint with different naming conventions"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': loss,
        'dice': dice
    }
    
    if is_best:
        torch.save(checkpoint, os.path.join(save_dir, 'best_model.pth'))
        logger.info("Best model saved.")
    
    if is_epoch:
        torch.save(checkpoint, os.path.join(save_dir, f'model_epoch_{epoch}_{dice}.pth'))
        logger.info(f"Epoch {epoch} model saved.")
    
    # Always save as latest checkpoint for resuming
    torch.save(checkpoint, os.path.join(save_dir, 'latest_checkpoint.pth'))

def load_checkpoint(checkpoint_path, model, optimizer, scheduler, scaler):
    """Load checkpoint and return starting epoch"""
    if os.path.exists(checkpoint_path):
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location='cpu',weights_only=False)
        
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        
        start_epoch = checkpoint['epoch'] + 1
        best_dice = checkpoint.get('dice', [0, 0])[1]  # Get left atrium dice
        
        logger.info(f"Resumed from epoch {checkpoint['epoch']}, best dice: {best_dice:.4f}")
        return start_epoch, best_dice
    else:
        logger.info("No checkpoint found, starting from scratch")
        return 0, 0

In [6]:
# --- Main ---
def main():
    config = {
        'data_dir': r"C:\Users\dell\Desktop\HEART\Task02_Heart",
        'batch_size': 1,
        'learning_rate': 1e-4,
        'num_epochs': 10,
        'patch_size': (96, 96, 96),
        'val_split': 0.2,
        'save_dir': './checkpoints',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'resume_from': './checkpoints/latest_checkpoint.pth',  # Auto-resume from latest
        'save_every_epoch': True  # Save model at every epoch
    }

    os.makedirs(config['save_dir'], exist_ok=True)

    with open(os.path.join(config['data_dir'], 'dataset.json'), 'r') as f:
        dataset_info = json.load(f)
    train_cases = [os.path.basename(entry['image']).replace('.nii.gz', '') for entry in dataset_info['training']]
    train_cases, val_cases = train_test_split(train_cases, test_size=config['val_split'], random_state=42)

    train_dataset = HeartDataset(config['data_dir'], train_cases, patch_size=config['patch_size'])
    val_dataset = HeartDataset(config['data_dir'], val_cases, patch_size=config['patch_size'])

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

    model = UNet3D().to(config['device'])
    criterion = CombinedLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
    scaler = GradScaler()

    # Load checkpoint if resuming
    start_epoch, best_dice = load_checkpoint(config['resume_from'], model, optimizer, scheduler, scaler)

    logger.info(f"Starting training from epoch {start_epoch}")
    logger.info(f"Target: Left atrium segmentation")

    for epoch in range(start_epoch, config['num_epochs']):
        logger.info(f"Epoch {epoch+1}/{config['num_epochs']}")
        train_loss, train_dice = train_epoch(model, train_loader, optimizer, criterion, config['device'], scaler)
        val_loss, val_dice = validate_epoch(model, val_loader, criterion, config['device'])
        scheduler.step(val_loss)
        
        logger.info(f"Train Loss: {train_loss:.4f}, Dice (BG/Left Atrium): {train_dice}")
        logger.info(f"Val Loss: {val_loss:.4f}, Dice (BG/Left Atrium): {val_dice}")

        # Save every epoch if requested
        if config['save_every_epoch']:
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_epoch=True)

        # Save best model based on left atrium dice (index 1)
        if val_dice[1] > best_dice:
            best_dice = val_dice[1]
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_best=True)
        
        # Always save latest checkpoint for resuming
        save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                       config['save_dir'])

    logger.info(f"Training completed. Best left atrium dice: {best_dice:.4f}")

if __name__ == '__main__':
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)
    main()

  scaler = GradScaler()
INFO:__main__:No checkpoint found, starting from scratch
INFO:__main__:Starting training from epoch 0
INFO:__main__:Target: Left atrium segmentation
INFO:__main__:Epoch 1/10
Training:   0%|          | 0/16 [00:00<?, ?it/s]INFO:__main__:Left atrium voxels: 1.7393%
Training:   6%|▋         | 1/16 [00:01<00:15,  1.04s/it]INFO:__main__:Left atrium voxels: 2.7523%
Training:  12%|█▎        | 2/16 [00:01<00:11,  1.22it/s]INFO:__main__:Left atrium voxels: 0.6887%
Training:  19%|█▉        | 3/16 [00:02<00:09,  1.43it/s]INFO:__main__:Left atrium voxels: 0.6947%
Training:  25%|██▌       | 4/16 [00:02<00:07,  1.50it/s]INFO:__main__:Left atrium voxels: 0.4105%
Training:  31%|███▏      | 5/16 [00:03<00:07,  1.55it/s]INFO:__main__:Left atrium voxels: 0.1212%
Training:  38%|███▊      | 6/16 [00:04<00:06,  1.63it/s]INFO:__main__:Left atrium voxels: 5.8179%
Training:  44%|████▍     | 7/16 [00:04<00:05,  1.73it/s]INFO:__main__:Left atrium voxels: 0.4183%
Training:  50%|█████     |

In [9]:
# --- Main ---
def main():
    config = {
        'data_dir': r"C:\Users\dell\Desktop\HEART\Task02_Heart",
        'batch_size': 1,
        'learning_rate': 1e-4,
        'num_epochs': 25,
        'patch_size': (96, 96, 96),
        'val_split': 0.2,
        'save_dir': './checkpoints',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'resume_from': './checkpoints/latest_checkpoint.pth',  # Auto-resume from latest
        'save_every_epoch': True  # Save model at every epoch
    }

    os.makedirs(config['save_dir'], exist_ok=True)

    with open(os.path.join(config['data_dir'], 'dataset.json'), 'r') as f:
        dataset_info = json.load(f)
    train_cases = [os.path.basename(entry['image']).replace('.nii.gz', '') for entry in dataset_info['training']]
    train_cases, val_cases = train_test_split(train_cases, test_size=config['val_split'], random_state=42)

    train_dataset = HeartDataset(config['data_dir'], train_cases, patch_size=config['patch_size'])
    val_dataset = HeartDataset(config['data_dir'], val_cases, patch_size=config['patch_size'])

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

    model = UNet3D().to(config['device'])
    criterion = CombinedLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
    scaler = GradScaler()

    # Load checkpoint if resuming
    start_epoch, best_dice = load_checkpoint(config['resume_from'], model, optimizer, scheduler, scaler)

    logger.info(f"Starting training from epoch {start_epoch}")
    logger.info(f"Target: Left atrium segmentation")

    for epoch in range(start_epoch, config['num_epochs']):
        logger.info(f"Epoch {epoch+1}/{config['num_epochs']}")
        train_loss, train_dice = train_epoch(model, train_loader, optimizer, criterion, config['device'], scaler)
        val_loss, val_dice = validate_epoch(model, val_loader, criterion, config['device'])
        scheduler.step(val_loss)
        
        logger.info(f"Train Loss: {train_loss:.4f}, Dice (BG/Left Atrium): {train_dice}")
        logger.info(f"Val Loss: {val_loss:.4f}, Dice (BG/Left Atrium): {val_dice}")

        # Save every epoch if requested
        if config['save_every_epoch']:
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_epoch=True)

        # Save best model based on left atrium dice (index 1)
        if val_dice[1] > best_dice:
            best_dice = val_dice[1]
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_best=True)
        
        # Always save latest checkpoint for resuming
        save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                       config['save_dir'])

    logger.info(f"Training completed. Best left atrium dice: {best_dice:.4f}")

if __name__ == '__main__':
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)
    main()

  scaler = GradScaler()
INFO:__main__:Loading checkpoint from ./checkpoints/latest_checkpoint.pth
INFO:__main__:Resumed from epoch 9, best dice: 0.4128
INFO:__main__:Starting training from epoch 10
INFO:__main__:Target: Left atrium segmentation
INFO:__main__:Epoch 11/25
Training:   0%|          | 0/16 [00:00<?, ?it/s]INFO:__main__:Left atrium voxels: 3.6875%
Training:   6%|▋         | 1/16 [00:00<00:10,  1.49it/s]INFO:__main__:Left atrium voxels: 6.0720%
Training:  12%|█▎        | 2/16 [00:01<00:10,  1.39it/s]INFO:__main__:Left atrium voxels: 1.4808%
Training:  19%|█▉        | 3/16 [00:02<00:09,  1.40it/s]INFO:__main__:Left atrium voxels: 0.1512%
Training:  25%|██▌       | 4/16 [00:02<00:07,  1.51it/s]INFO:__main__:Left atrium voxels: 5.3312%
Training:  31%|███▏      | 5/16 [00:03<00:07,  1.49it/s]INFO:__main__:Left atrium voxels: 4.2653%
Training:  38%|███▊      | 6/16 [00:04<00:06,  1.50it/s]INFO:__main__:Left atrium voxels: 2.8549%
Training:  44%|████▍     | 7/16 [00:04<00:05,  1.55

In [7]:
# --- Main ---
def main():
    config = {
        'data_dir': r"C:\Users\dell\Desktop\HEART\Task02_Heart",
        'batch_size': 1,
        'learning_rate': 1e-4,
        'num_epochs': 35,
        'patch_size': (96, 96, 96),
        'val_split': 0.2,
        'save_dir': './checkpoints',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'resume_from': './checkpoints/latest_checkpoint.pth',  # Auto-resume from latest
        'save_every_epoch': True  # Save model at every epoch
    }

    os.makedirs(config['save_dir'], exist_ok=True)

    with open(os.path.join(config['data_dir'], 'dataset.json'), 'r') as f:
        dataset_info = json.load(f)
    train_cases = [os.path.basename(entry['image']).replace('.nii.gz', '') for entry in dataset_info['training']]
    train_cases, val_cases = train_test_split(train_cases, test_size=config['val_split'], random_state=42)

    train_dataset = HeartDataset(config['data_dir'], train_cases, patch_size=config['patch_size'])
    val_dataset = HeartDataset(config['data_dir'], val_cases, patch_size=config['patch_size'])

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

    model = UNet3D().to(config['device'])
    criterion = CombinedLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
    scaler = GradScaler()

    # Load checkpoint if resuming
    start_epoch, best_dice = load_checkpoint(config['resume_from'], model, optimizer, scheduler, scaler)

    logger.info(f"Starting training from epoch {start_epoch}")
    logger.info(f"Target: Left atrium segmentation")

    for epoch in range(start_epoch, config['num_epochs']):
        logger.info(f"Epoch {epoch+1}/{config['num_epochs']}")
        train_loss, train_dice = train_epoch(model, train_loader, optimizer, criterion, config['device'], scaler)
        val_loss, val_dice = validate_epoch(model, val_loader, criterion, config['device'])
        scheduler.step(val_loss)
        
        logger.info(f"Train Loss: {train_loss:.4f}, Dice (BG/Left Atrium): {train_dice}")
        logger.info(f"Val Loss: {val_loss:.4f}, Dice (BG/Left Atrium): {val_dice}")

        # Save every epoch if requested
        if config['save_every_epoch']:
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_epoch=True)

        # Save best model based on left atrium dice (index 1)
        if val_dice[1] > best_dice:
            best_dice = val_dice[1]
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_best=True)
        
        # Always save latest checkpoint for resuming
        save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                       config['save_dir'])

    logger.info(f"Training completed. Best left atrium dice: {best_dice:.4f}")

if __name__ == '__main__':
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)
    main()

  scaler = GradScaler()
INFO:__main__:Loading checkpoint from ./checkpoints/latest_checkpoint.pth
INFO:__main__:Resumed from epoch 24, best dice: 0.0700
INFO:__main__:Starting training from epoch 25
INFO:__main__:Target: Left atrium segmentation
INFO:__main__:Epoch 26/35
Training:   0%|          | 0/16 [00:00<?, ?it/s]INFO:__main__:Left atrium voxels: 0.4136%
Training:   6%|▋         | 1/16 [00:01<00:16,  1.09s/it]INFO:__main__:Left atrium voxels: 0.0982%
Training:  12%|█▎        | 2/16 [00:01<00:11,  1.22it/s]INFO:__main__:Left atrium voxels: 0.8104%
Training:  19%|█▉        | 3/16 [00:02<00:09,  1.32it/s]INFO:__main__:Left atrium voxels: 2.6825%
Training:  25%|██▌       | 4/16 [00:03<00:08,  1.35it/s]INFO:__main__:Left atrium voxels: 5.1713%
Training:  31%|███▏      | 5/16 [00:03<00:07,  1.45it/s]INFO:__main__:Left atrium voxels: 1.6198%
Training:  38%|███▊      | 6/16 [00:04<00:06,  1.49it/s]INFO:__main__:Left atrium voxels: 4.9302%
Training:  44%|████▍     | 7/16 [00:05<00:06,  1.4

In [8]:
# --- Main ---
def main():
    config = {
        'data_dir': r"C:\Users\dell\Desktop\HEART\Task02_Heart",
        'batch_size': 1,
        'learning_rate': 1e-4,
        'num_epochs': 50,
        'patch_size': (96, 96, 96),
        'val_split': 0.25,
        'save_dir': './checkpoints',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'resume_from': './checkpoints/latest_checkpoint.pth',  # Auto-resume from latest
        'save_every_epoch': True  # Save model at every epoch
    }

    os.makedirs(config['save_dir'], exist_ok=True)

    with open(os.path.join(config['data_dir'], 'dataset.json'), 'r') as f:
        dataset_info = json.load(f)
    train_cases = [os.path.basename(entry['image']).replace('.nii.gz', '') for entry in dataset_info['training']]
    train_cases, val_cases = train_test_split(train_cases, test_size=config['val_split'], random_state=42)

    train_dataset = HeartDataset(config['data_dir'], train_cases, patch_size=config['patch_size'])
    val_dataset = HeartDataset(config['data_dir'], val_cases, patch_size=config['patch_size'])

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

    model = UNet3D().to(config['device'])
    criterion = CombinedLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
    scaler = GradScaler()

    # Load checkpoint if resuming
    start_epoch, best_dice = load_checkpoint(config['resume_from'], model, optimizer, scheduler, scaler)

    logger.info(f"Starting training from epoch {start_epoch}")
    logger.info(f"Target: Left atrium segmentation")

    for epoch in range(start_epoch, config['num_epochs']):
        logger.info(f"Epoch {epoch+1}/{config['num_epochs']}")
        train_loss, train_dice = train_epoch(model, train_loader, optimizer, criterion, config['device'], scaler)
        val_loss, val_dice = validate_epoch(model, val_loader, criterion, config['device'])
        scheduler.step(val_loss)
        
        logger.info(f"Train Loss: {train_loss:.4f}, Dice (BG/Left Atrium): {train_dice}")
        logger.info(f"Val Loss: {val_loss:.4f}, Dice (BG/Left Atrium): {val_dice}")

        # Save every epoch if requested
        if config['save_every_epoch']:
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_epoch=True)

        # Save best model based on left atrium dice (index 1)
        if val_dice[1] > best_dice:
            best_dice = val_dice[1]
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_best=True)
        
        # Always save latest checkpoint for resuming
        save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                       config['save_dir'])

    logger.info(f"Training completed. Best left atrium dice: {best_dice:.4f}")

if __name__ == '__main__':
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)
    main()

  scaler = GradScaler()
INFO:__main__:Loading checkpoint from ./checkpoints/latest_checkpoint.pth
INFO:__main__:Resumed from epoch 34, best dice: 0.7788
INFO:__main__:Starting training from epoch 35
INFO:__main__:Target: Left atrium segmentation
INFO:__main__:Epoch 36/50
Training:   0%|          | 0/15 [00:00<?, ?it/s]INFO:__main__:Left atrium voxels: 6.5795%
Training:   7%|▋         | 1/15 [00:01<00:15,  1.13s/it]INFO:__main__:Left atrium voxels: 0.5337%
Training:  13%|█▎        | 2/15 [00:01<00:11,  1.15it/s]INFO:__main__:Left atrium voxels: 2.1766%
Training:  20%|██        | 3/15 [00:02<00:09,  1.27it/s]INFO:__main__:Left atrium voxels: 4.2661%
Training:  27%|██▋       | 4/15 [00:03<00:08,  1.34it/s]INFO:__main__:Left atrium voxels: 0.9393%
Training:  33%|███▎      | 5/15 [00:03<00:06,  1.45it/s]INFO:__main__:Left atrium voxels: 1.8065%
Training:  40%|████      | 6/15 [00:04<00:06,  1.47it/s]INFO:__main__:Left atrium voxels: 0.4492%
Training:  47%|████▋     | 7/15 [00:04<00:05,  1.5

In [9]:
# --- Main ---
def main():
    config = {
        'data_dir': r"C:\Users\dell\Desktop\HEART\Task02_Heart",
        'batch_size': 1,
        'learning_rate': 1e-4,
        'num_epochs': 65,
        'patch_size': (96, 96, 96),
        'val_split': 0.3,
        'save_dir': './checkpoints',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'resume_from': './checkpoints/best_model.pth',  # Auto-resume from latest
        'save_every_epoch': True  # Save model at every epoch
    }

    os.makedirs(config['save_dir'], exist_ok=True)

    with open(os.path.join(config['data_dir'], 'dataset.json'), 'r') as f:
        dataset_info = json.load(f)
    train_cases = [os.path.basename(entry['image']).replace('.nii.gz', '') for entry in dataset_info['training']]
    train_cases, val_cases = train_test_split(train_cases, test_size=config['val_split'], random_state=42)

    train_dataset = HeartDataset(config['data_dir'], train_cases, patch_size=config['patch_size'])
    val_dataset = HeartDataset(config['data_dir'], val_cases, patch_size=config['patch_size'])

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

    model = UNet3D().to(config['device'])
    criterion = CombinedLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10)
    scaler = GradScaler()

    # Load checkpoint if resuming
    start_epoch, best_dice = load_checkpoint(config['resume_from'], model, optimizer, scheduler, scaler)

    logger.info(f"Starting training from epoch {start_epoch}")
    logger.info(f"Target: Left atrium segmentation")

    for epoch in range(start_epoch, config['num_epochs']):
        logger.info(f"Epoch {epoch+1}/{config['num_epochs']}")
        train_loss, train_dice = train_epoch(model, train_loader, optimizer, criterion, config['device'], scaler)
        val_loss, val_dice = validate_epoch(model, val_loader, criterion, config['device'])
        scheduler.step(val_loss)
        
        logger.info(f"Train Loss: {train_loss:.4f}, Dice (BG/Left Atrium): {train_dice}")
        logger.info(f"Val Loss: {val_loss:.4f}, Dice (BG/Left Atrium): {val_dice}")

        # Save every epoch if requested
        if config['save_every_epoch']:
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_epoch=True)

        # Save best model based on left atrium dice (index 1)
        if val_dice[1] > best_dice:
            best_dice = val_dice[1]
            save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                          config['save_dir'], is_best=True)
        
        # Always save latest checkpoint for resuming
        save_checkpoint(model, optimizer, scheduler, scaler, epoch, val_loss, val_dice, 
                       config['save_dir'])

    logger.info(f"Training completed. Best left atrium dice: {best_dice:.4f}")

if __name__ == '__main__':
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)
    main()

  scaler = GradScaler()
INFO:__main__:Loading checkpoint from ./checkpoints/best_model.pth
INFO:__main__:Resumed from epoch 42, best dice: 0.8587
INFO:__main__:Starting training from epoch 43
INFO:__main__:Target: Left atrium segmentation
INFO:__main__:Epoch 44/65
Training:   0%|          | 0/14 [00:00<?, ?it/s]INFO:__main__:Left atrium voxels: 0.1324%
Training:   7%|▋         | 1/14 [00:01<00:13,  1.07s/it]INFO:__main__:Left atrium voxels: 0.0904%
Training:  14%|█▍        | 2/14 [00:01<00:09,  1.24it/s]INFO:__main__:Left atrium voxels: 4.2683%
Training:  21%|██▏       | 3/14 [00:02<00:08,  1.31it/s]INFO:__main__:Left atrium voxels: 0.5290%
Training:  29%|██▊       | 4/14 [00:03<00:07,  1.40it/s]INFO:__main__:Left atrium voxels: 4.3338%
Training:  36%|███▌      | 5/14 [00:03<00:06,  1.48it/s]INFO:__main__:Left atrium voxels: 5.3249%
Training:  43%|████▎     | 6/14 [00:04<00:05,  1.58it/s]INFO:__main__:Left atrium voxels: 2.8746%
Training:  50%|█████     | 7/14 [00:04<00:04,  1.61it/s]I