In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/name_mapping_validation_data.csv
/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/survival_evaluation.csv
/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_084/BraTS20_Validation_084_flair.nii
/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_084/BraTS20_Validation_084_t2.nii
/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_084/BraTS20_Validation_084_t1ce.nii
/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_084/BraTS20_Validation_084_t1.nii
/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_Br

In [3]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import nibabel as nib
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

In [4]:
class BraTSDataset(Dataset):
    def __init__(self, root_dir, transform=None, slice_range=(60, 100)):  # Added slice range
        """
        Args:
            root_dir (string): Directory with all the BraTS data
            transform (callable, optional): Optional transform to be applied
            slice_range (tuple): Range of slices to use (for memory efficiency)
        """
        self.root_dir = root_dir
        self.transform = transform
        self.slice_range = slice_range
        self.samples = []
        
        print(f"Initializing dataset from: {root_dir}")
        
        patient_dirs = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        print(f"Found {len(patient_dirs)} patient directories")
        
        for patient_dir in tqdm(patient_dirs, desc="Loading patients"):
            patient_path = os.path.join(root_dir, patient_dir)
            files = os.listdir(patient_path)
            
            modality_files = {
                'flair': None,
                't1': None,
                't1ce': None,
                't2': None,
                'seg': None
            }
            
            for file in files:
                file_lower = file.lower()
                file_path = os.path.join(patient_path, file)
                
                if not file_lower.endswith('.nii'):
                    continue
                    
                if 'flair.' in file_lower:
                    modality_files['flair'] = file_path
                elif 't1ce.' in file_lower:
                    modality_files['t1ce'] = file_path
                elif 't1.' in file_lower and 't1ce.' not in file_lower:
                    modality_files['t1'] = file_path
                elif 't2.' in file_lower:
                    modality_files['t2'] = file_path
                elif 'seg.' in file_lower:
                    modality_files['seg'] = file_path
            
            if all(modality_files.values()):
                self.samples.append(modality_files)
        
        print(f"\nSuccessfully loaded {len(self.samples)} complete samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        sample = self.samples[idx]
        
        try:
            # Load all modalities
            data = {}
            for modality, filepath in sample.items():
                nib_img = nib.load(filepath)
                # Load only the selected slice range
                img_data = nib_img.get_fdata()[..., self.slice_range[0]:self.slice_range[1]]
                # Reduce memory usage by converting to float32
                data[modality] = img_data.astype(np.float32)
            
            # Stack all modalities
            image = np.stack([
                data['flair'],
                data['t1'],
                data['t1ce'],
                data['t2']
            ], axis=0)
            
            # Normalize the image data
            image = (image - image.mean()) / (image.std() + 1e-8)
            
            # Convert to torch tensor
            image = torch.from_numpy(image).float()
            mask = torch.from_numpy(data['seg']).long()
            
            if self.transform:
                image = self.transform(image)
                mask = self.transform(mask)
                
            return {'image': image, 'mask': mask}
            
        except Exception as e:
            print(f"Error loading sample {idx}: {str(e)}")
            raise

In [5]:
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        before_pool = x
        x = self.pool(x)
        return x, before_pool

In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upconv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()
        
    def forward(self, x, skip_connection):
        x = self.upconv(x)
        x = torch.cat([x, skip_connection], dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x

In [7]:
class VAEGAN3D(nn.Module):
    def __init__(self, in_channels=4, latent_dim=256):
        super(VAEGAN3D, self).__init__()
        
        # Encoder
        self.enc1 = EncoderBlock(in_channels, 64)
        self.enc2 = EncoderBlock(64, 128)
        self.enc3 = EncoderBlock(128, 256)
        
        # VAE components
        self.fc_mu = nn.Linear(256 * 8 * 8 * 8, latent_dim)
        self.fc_var = nn.Linear(256 * 8 * 8 * 8, latent_dim)
        
        # Decoder
        self.fc_decoder = nn.Linear(latent_dim, 256 * 8 * 8 * 8)
        self.dec3 = DecoderBlock(512, 128)
        self.dec2 = DecoderBlock(256, 64)
        self.dec1 = DecoderBlock(128, 32)
        
        self.final_conv = nn.Conv3d(32, in_channels, kernel_size=1)
        
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
        
    def forward(self, x):
        # Encoding
        x1, skip1 = self.enc1(x)
        x2, skip2 = self.enc2(x1)
        x3, skip3 = self.enc3(x2)
        
        # Flatten
        batch_size = x.size(0)
        x_flat = x3.view(batch_size, -1)
        
        # VAE
        mu = self.fc_mu(x_flat)
        log_var = self.fc_var(x_flat)
        z = self.reparameterize(mu, log_var)
        
        # Decoding
        x = self.fc_decoder(z)
        x = x.view(batch_size, 256, 8, 8, 8)
        
        x = self.dec3(x, skip3)
        x = self.dec2(x, skip2)
        x = self.dec1(x, skip1)
        
        x = self.final_conv(x)
        
        return x, mu, log_var


In [8]:
def train_model(model, train_loader, optimizer, device, num_epochs=100):
    model.train()
    criterion = nn.MSELoss()
    kl_weight = 0.01  # Weight for KL divergence loss
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        
        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            images = batch['image'].to(device)
            
            # Forward pass
            reconstructed, mu, log_var = model(images)
            
            # Reconstruction loss
            recon_loss = criterion(reconstructed, images)
            
            # KL divergence
            kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            
            # Total loss
            loss = recon_loss + kl_weight * kl_loss
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
        epoch_loss = running_loss / len(train_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')


In [9]:
class MemoryEfficientVAEGAN3D(nn.Module):
    def __init__(self, in_channels=4, latent_dim=128):
        super(MemoryEfficientVAEGAN3D, self).__init__()
        
        self.enc1 = EncoderBlock(in_channels, 32)
        self.enc2 = EncoderBlock(32, 64)
        self.enc3 = EncoderBlock(64, 128)
        
        # Calculate flattened dimension size
        self.flatten_size = 128 * 5 * 5 * 5  # Adjusted based on input size
        
        # Update linear layer dimensions
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, self.flatten_size)
        
        self.dec3 = DecoderBlock(256, 64)
        self.dec2 = DecoderBlock(128, 32)
        self.dec1 = DecoderBlock(64, 16)
        
        self.final_conv = nn.Conv3d(16, in_channels, kernel_size=1)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Print shape at each step for debugging
        batch_size = x.size(0)
        
        x1, skip1 = self.enc1(x)
        x2, skip2 = self.enc2(x1)
        x3, skip3 = self.enc3(x2)
        
        # Flatten with correct dimensions
        x_flat = x3.view(batch_size, -1)
        
        # Print shape for verification
        print(f"Flattened shape: {x_flat.shape}")
        
        mu = self.fc_mu(x_flat)
        log_var = self.fc_var(x_flat)
        z = self.reparameterize(mu, log_var)
        
        # Reshape back to 3D
        x = self.fc_decoder(z)
        x = x.view(batch_size, 128, 5, 5, 5)  # Adjusted dimensions
        
        x = self.dec3(x, skip3)
        x = self.dec2(x, skip2)
        x = self.dec1(x, skip1)
        
        x = self.final_conv(x)
        
        return x, mu, log_var

In [10]:
def main():
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Enable memory efficient options
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    
    # Dataset and DataLoader with reduced batch size
    dataset_path = '/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'
    
    try:
        # Create dataset with slice range
        dataset = BraTSDataset(
            root_dir=dataset_path,
            slice_range=(60, 100)  # Only use 40 slices instead of full volume
        )
        print(f"Dataset size: {len(dataset)}")
        
        # Reduced batch size and num_workers
        train_loader = DataLoader(
            dataset,
            batch_size=1,  # Reduced from 2
            shuffle=True,
            num_workers=2,  # Reduced from 4
            pin_memory=True if torch.cuda.is_available() else False
        )
        
        # Test the data loader
        print("\nTesting data loader with one batch...")
        sample_batch = next(iter(train_loader))
        print(f"Sample batch shapes - Image: {sample_batch['image'].shape}, Mask: {sample_batch['mask'].shape}")
        
        # Initialize memory-efficient model
        model = MemoryEfficientVAEGAN3D().to(device)
        
        # Use gradient checkpointing if available
        if hasattr(model, 'enc1'):
            for module in [model.enc1, model.enc2, model.enc3]:
                if hasattr(module, 'checkpoint'):
                    module.checkpoint = True
        
        optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Reduced learning rate
        
        print("Starting training...")
        train_model(model, train_loader, optimizer, device)
        
    except Exception as e:
        print(f"Error during initialization: {str(e)}")
        raise

if __name__ == '__main__':
    main()

Using device: cuda
Initializing dataset from: /kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData
Found 369 patient directories


Loading patients: 100%|██████████| 369/369 [00:00<00:00, 2091.87it/s]


Successfully loaded 368 complete samples
Dataset size: 368

Testing data loader with one batch...





Sample batch shapes - Image: torch.Size([1, 4, 240, 240, 40]), Mask: torch.Size([1, 240, 240, 40])
Starting training...


Epoch 1/100:   0%|          | 0/368 [00:09<?, ?it/s]

Flattened shape: torch.Size([1, 576000])
Error during initialization: mat1 and mat2 shapes cannot be multiplied (1x576000 and 16000x128)





RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x576000 and 16000x128)

In [36]:
import os
dataset_path = '/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'
example_patient = 'BraTS20_Training_083'
print("Files in patient folder:", os.listdir(os.path.join(dataset_path, example_patient)))

Files in patient folder: ['BraTS20_Training_083_flair.nii', 'BraTS20_Training_083_t1.nii', 'BraTS20_Training_083_seg.nii', 'BraTS20_Training_083_t2.nii', 'BraTS20_Training_083_t1ce.nii']
