<a href="https://colab.research.google.com/github/Vasu050/Brain-Tumour-Segmentation/blob/main/Untitled.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Authenticate against Kaggle before any dataset download.
# The captured output of this cell shows an interactive login widget, so no
# credentials are stored in the notebook itself.
import kagglehub
kagglehub.login()


VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


In [None]:

 # Install the NIfTI reader and download the Kaggle datasets this notebook works on.
 # NOTE(review): the package version is not pinned, and `%pip` is the usual way to
 # target the running kernel's environment — consider `%pip install -q nibabel==<version>`.
 # NOTE(review): the code lines in this cell start with one leading space; this
 # presumably relies on the notebook front end tolerating indented cells — confirm
 # before exporting the notebook to a plain .py script.
 !pip install -q nibabel
# If you were planning to use MONAI offline install:
# !pip install -q /kaggle/input/monai-offline/monai-1.1.0-202212191849-py3-none-any.whl

# --- Dataset Download (if using Kaggle Hub) ---
# Make sure these paths match where Kaggle puts the data, usually /kaggle/input/
# NOTE(review): the two paths returned below are not read again by the visible
# cells; the configuration cell hard-codes /kaggle/input/... locations instead.
 import kagglehub
 awsaf49_brats20_dataset_training_validation_path = kagglehub.dataset_download('awsaf49/brats20-dataset-training-validation')
 animelover72_monai_offline_path = kagglehub.dataset_download('animelover72/monai-offline') # Still unused in this script
 print('Data source import complete.')


Downloading from https://www.kaggle.com/api/v1/datasets/download/animelover72/monai-offline?dataset_version_number=1...


100%|██████████| 959M/959M [00:13<00:00, 73.6MB/s]

Extracting files...





Data source import complete.


In [None]:
import os
import gc
import random
import numpy as np
import nibabel as nib
from tqdm.notebook import tqdm # Use notebook tqdm for better display in Kaggle
from pathlib import Path
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [None]:

# --- Configuration ---
# Every tunable value of the notebook lives here; later cells only read these names.

# Raw BraTS 2020 folders (one sub-directory per patient).
BASE_TRAIN_PATH = Path("/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData")
BASE_VAL_PATH = Path("/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData") # Adjust if validation data is elsewhere

# Where the preprocessed .npz volumes and the model checkpoint are written.
PROCESSED_TRAIN_DIR = Path("/kaggle/working/processed_train_data")
PROCESSED_VAL_DIR = Path("/kaggle/working/processed_val_data")
MODEL_SAVE_PATH = Path("/kaggle/working/unet3d_brats.pth")

MODALITIES = ['flair', 't1', 't1ce', 't2']
CROP_SIZE = (128, 128, 128) # Increased crop size for better context, adjust based on memory
# CROP_SIZE = (64, 64, 64) # Original size if 128x128x128 causes memory issues
BATCH_SIZE = 1 # Keep batch size 1 if full volumes/large crops are used
CHUNK_SIZE = 20 # Process N patients before clearing memory
TOTAL_EPOCHS = 1 # Number of full passes over the entire dataset
LEARNING_RATE = 1e-4
SEED = 42 # Seed for file shuffling, random crops and weight initialisation
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook draws from so that a Restart & Run All repeats
# the same shuffles, crops and initial weights.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Using device: {DEVICE}")

# Create output directories. parents=True makes this work for any configured
# location, not only for paths directly below /kaggle/working.
PROCESSED_TRAIN_DIR.mkdir(parents=True, exist_ok=True)
PROCESSED_VAL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)

Using device: cpu


In [None]:


# --- Helper Functions ---
def normalize(volume):
    """Z-score a volume using the statistics of its strictly positive voxels.

    The mean and standard deviation are measured on voxels ``> 0`` only, but
    the resulting shift and scale are applied to *every* voxel, so voxels that
    were zero do not stay zero.

    Args:
        volume: numpy array of intensities.

    Returns:
        A new normalized array, or the input object itself (unchanged) when it
        has no positive voxel or when the foreground standard deviation is not
        above 1e-5.
    """
    foreground = volume[volume > 0]
    if foreground.size == 0:
        # Nothing but background: leave the data untouched.
        return volume

    mu = np.mean(foreground)
    sigma = np.std(foreground)
    if not sigma > 1e-5:
        # (Near-)constant foreground: dividing would blow up, keep the input.
        return volume

    # Optional clipping of extreme values could be added here, e.g.
    # np.clip(result, -5, 5).
    return (volume - mu) / sigma

In [None]:
def _resolve_nifti(patient_dir, stem):
    """Return the path of ``<stem>.nii`` or ``<stem>.nii.gz`` in patient_dir, or None."""
    for suffix in (".nii", ".nii.gz"):
        candidate = patient_dir / f"{stem}{suffix}"
        if candidate.exists():
            return candidate
    return None


def preprocess_and_save(base_path, save_dir, modalities, is_train=True):
    """Convert every patient folder under ``base_path`` into one compressed .npz file.

    For each patient the modality volumes ``<patient_id>_<modality>.nii`` (or
    ``.nii.gz``) are loaded as float32, transposed with axes (2, 0, 1),
    normalized with ``normalize`` and stacked along a new first axis. The
    segmentation ``<patient_id>_seg`` is reduced to a binary mask (any label
    greater than 0 becomes 1). Both arrays are written to
    ``save_dir/<patient_id>.npz`` under the keys ``input`` and ``label``.

    Args:
        base_path: ``pathlib.Path`` of the directory holding one sub-directory
            per patient.
        save_dir: ``pathlib.Path`` of the output directory (created if missing).
        modalities: iterable of modality suffixes, e.g. ``['flair', 't1']``.
        is_train: when True a missing segmentation is an error for that
            patient; when False an all-zero placeholder mask is stored instead.

    Returns:
        None. A failure on one patient is printed and that patient is skipped,
        so a single broken case does not abort the whole run.
    """
    print(f"Processing data from: {base_path}")
    save_dir.mkdir(parents=True, exist_ok=True)
    patient_dirs = sorted(d for d in base_path.iterdir() if d.is_dir())

    for patient_dir in tqdm(patient_dirs, desc=f"Preprocessing {'Train' if is_train else 'Validation'}"):
        patient_id = patient_dir.name
        volumes = input_stack = label = None
        try:
            volumes = []
            # Load modalities
            for mod in modalities:
                nii_path = _resolve_nifti(patient_dir, f"{patient_id}_{mod}")
                if nii_path is None:
                    raise FileNotFoundError(f"Missing modality {mod} for {patient_id}")

                img_data = nib.load(str(nii_path)).get_fdata().astype(np.float32)
                # Move the last NIfTI axis to the front; the rest of the
                # notebook calls the resulting layout (Depth, Height, Width).
                img_data = np.transpose(img_data, (2, 0, 1))
                volumes.append(normalize(img_data))

            input_stack = np.stack(volumes) # Shape: (C, D, H, W)

            # Load segmentation mask (may legitimately be absent outside training)
            seg_path = _resolve_nifti(patient_dir, f"{patient_id}_seg")
            if seg_path is not None:
                seg_data = nib.load(str(seg_path)).get_fdata().astype(np.uint8)
                seg_data = np.transpose(seg_data, (2, 0, 1)) # (D, H, W)
                # Collapse all tumour sub-labels into one class (Tumor=1, Background=0)
                label = (seg_data > 0).astype(np.uint8)
            elif is_train:
                # Training data MUST have segmentation
                raise FileNotFoundError(f"Missing segmentation for training case {patient_id}")
            else:
                # NOTE: this placeholder carries no ground truth; any metric
                # computed against it later is not a real score.
                print(f"Segmentation not found for {patient_id}. Creating zero mask.")
                label = np.zeros(input_stack.shape[1:], dtype=np.uint8) # (D, H, W)

            # Save to .npz file
            np.savez_compressed(save_dir / f"{patient_id}.npz", input=input_stack, label=label)

        except Exception as e:
            print(f"Skipping {patient_id} due to error: {e}")
        finally:
            # Drop the large arrays before the next patient is loaded.
            # Rebinding to None is safe whichever branch ran; the previous
            # `del` of names that were never assigned (e.g. the segmentation
            # variables when no mask exists) raised and produced a bogus
            # "Skipping" message after the file had already been saved.
            volumes = input_stack = label = None
            gc.collect()


In [None]:

# --- Run Preprocessing ---
# Convert the training split (segmentation required) and then the validation
# split (segmentation optional) into .npz files.
for raw_dir, out_dir, needs_labels in (
    (BASE_TRAIN_PATH, PROCESSED_TRAIN_DIR, True),
    (BASE_VAL_PATH, PROCESSED_VAL_DIR, False),
):
    preprocess_and_save(raw_dir, out_dir, MODALITIES, is_train=needs_labels)
print("Preprocessing finished.")


Processing data from: /kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData


Preprocessing Train:   0%|          | 0/369 [00:00<?, ?it/s]

Skipping BraTS20_Training_355 due to error: Missing segmentation for training case BraTS20_Training_355
Processing data from: /kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData


Preprocessing Validation:   0%|          | 0/125 [00:00<?, ?it/s]

Segmentation not found for BraTS20_Validation_001. Creating zero mask.
Skipping BraTS20_Validation_001 due to error: cannot access local variable 'seg_nii' where it is not associated with a value
Segmentation not found for BraTS20_Validation_002. Creating zero mask.
Skipping BraTS20_Validation_002 due to error: cannot access local variable 'seg_nii' where it is not associated with a value
Segmentation not found for BraTS20_Validation_003. Creating zero mask.
Skipping BraTS20_Validation_003 due to error: cannot access local variable 'seg_nii' where it is not associated with a value
Segmentation not found for BraTS20_Validation_004. Creating zero mask.
Skipping BraTS20_Validation_004 due to error: cannot access local variable 'seg_nii' where it is not associated with a value
Segmentation not found for BraTS20_Validation_005. Creating zero mask.
Skipping BraTS20_Validation_005 due to error: cannot access local variable 'seg_nii' where it is not associated with a value
Segmentation not fou

In [None]:
# --- Dataset Class ---
class BrainTumorDataset3D(Dataset):
    """Dataset of preprocessed .npz volumes that yields fixed-size 3D crops.

    Each file must contain an ``input`` array of shape (C, D, H, W) and a
    ``label`` array of shape (D, H, W).

    Args:
        npz_files: sequence of paths to .npz files.
        crop_size: ``(depth, height, width)`` of the crop; each entry is
            clamped to the volume size when the volume is smaller.
        is_train: True selects a random crop position, False a centred crop.
        cache_size: maximum number of decoded volumes kept in memory.
            Defaults to 0 (no caching) because one full multi-modal volume is
            large; raise it only when the same dataset object is iterated
            several times and memory allows.
    """

    def __init__(self, npz_files, crop_size, is_train=True, cache_size=0):
        self.npz_files = npz_files
        self.crop_size = crop_size
        self.is_train = is_train
        self.cache_size = cache_size
        self.cache = {} # filepath -> (input array, label array)

    def __len__(self):
        """Number of volumes in the dataset."""
        return len(self.npz_files)

    def _crop_3d(self, img, label):
        """Cut the same window out of ``img`` (C, D, H, W) and ``label`` (D, H, W)."""
        _, D, H, W = img.shape
        cd, ch, cw = self.crop_size

        # Ensure crop size is not larger than image dimensions
        cd = min(cd, D)
        ch = min(ch, H)
        cw = min(cw, W)

        if self.is_train:
            # Random crop
            d = random.randint(0, D - cd) if D > cd else 0
            h = random.randint(0, H - ch) if H > ch else 0
            w = random.randint(0, W - cw) if W > cw else 0
        else:
            # Center crop
            d = (D - cd) // 2
            h = (H - ch) // 2
            w = (W - cw) // 2

        img_cropped = img[:, d:d+cd, h:h+ch, w:w+cw]
        label_cropped = label[d:d+cd, h:h+ch, w:w+cw]

        return img_cropped, label_cropped

    def _load_arrays(self, filepath):
        """Return ``(input, label)`` arrays for one file, using the cache when enabled."""
        cached = self.cache.get(filepath)
        if cached is not None:
            return cached

        # np.load on an .npz returns a lazy archive object that keeps the file
        # open and decodes an array on every key access. Read both arrays
        # once and close the archive so no file handle is left behind.
        with np.load(filepath) as archive:
            arrays = (archive['input'], archive['label'])

        if len(self.cache) < self.cache_size:
            self.cache[filepath] = arrays
        return arrays

    def __getitem__(self, idx):
        """Return ``(image, label)`` float32 tensors of shapes (C, d, h, w) and (1, d, h, w)."""
        image, label = self._load_arrays(self.npz_files[idx])

        # Crop first, convert afterwards: only the crop is copied, and cached
        # arrays are never modified.
        image_cropped, label_cropped = self._crop_3d(image, label)
        image_cropped = image_cropped.astype(np.float32)
        label_cropped = label_cropped.astype(np.uint8)

        # Add channel dimension to label: (D,H,W) -> (1,D,H,W)
        label_cropped = np.expand_dims(label_cropped, axis=0)

        image_tensor = torch.tensor(image_cropped, dtype=torch.float32)
        # The loss multiplies labels with probabilities, so labels are float too.
        label_tensor = torch.tensor(label_cropped, dtype=torch.float32)

        return image_tensor, label_tensor

In [None]:
# --- Define 3D U-Net Model ---
class ConvBlock(nn.Module):
    """Two consecutive Conv3d(3x3x3, padding 1, no bias) -> BatchNorm3d -> ReLU stages.

    Spatial size is preserved; only the channel count changes from
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_options = dict(kernel_size=3, padding=1, bias=False)

        # First stage: in_channels -> out_channels
        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_options)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.ReLU()

        # Second stage: out_channels -> out_channels
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_options)
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Apply both conv/norm/activation stages in order."""
        stages = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
        )
        for conv, norm, activation in stages:
            x = activation(norm(conv(x)))
        return x

class EncoderBlock(nn.Module):
    """Encoder stage: a ConvBlock followed by 2x max-pooling.

    ``forward`` returns both the un-pooled features (kept for the skip
    connection) and the pooled features (fed to the next stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(skip_features, pooled_features)``."""
        skip_features = self.conv(x)
        return skip_features, self.pool(skip_features)

class DecoderBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, ConvBlock.

    Args:
        in_channels: channels of the tensor coming from the deeper level.
        skip_channels: channels of the skip tensor; the upsampling layer also
            outputs this many channels, so the ConvBlock sees
            ``2 * skip_channels`` input channels.
        out_channels: channels produced by the stage.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        # Use ConvTranspose3d for upsampling
        self.upconv = nn.ConvTranspose3d(in_channels, skip_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(skip_channels + skip_channels, out_channels) # Concatenated channels

    @staticmethod
    def _match_spatial(x, skip):
        """Zero-pad (or crop) the spatial dims of ``x`` to those of ``skip``."""
        pad = []
        # nn.functional.pad expects (before, after) pairs starting with the LAST dimension.
        for up_size, skip_size in zip(reversed(x.shape[2:]), reversed(skip.shape[2:])):
            diff = skip_size - up_size
            pad.extend([diff // 2, diff - diff // 2])
        return nn.functional.pad(x, pad)

    def forward(self, x, skip):
        """Upsample ``x``, concatenate it with ``skip`` on the channel axis and convolve."""
        x = self.upconv(x)
        # Stride-2 pooling floors odd sizes, so after upsampling ``x`` can be
        # one voxel short of ``skip`` (e.g. 77 -> 38 -> 76). Align the sizes
        # instead of letting torch.cat fail; this is a no-op when they match.
        if x.shape[2:] != skip.shape[2:]:
            x = self._match_spatial(x, skip)
        x = torch.cat([x, skip], dim=1) # Concatenate along channel dimension
        return self.conv(x)

class UNet3D(nn.Module):
    """3D U-Net with three encoder levels, a bottleneck and three decoder levels.

    Args:
        in_channels: number of input channels (one per MRI modality in this notebook).
        out_channels: number of output channels of the final 1x1x1 convolution.
        features: channel widths of the three encoder levels followed by the
            bottleneck; only the first four entries are used.

    ``forward`` returns raw logits (no sigmoid); the loss applies the sigmoid.
    """

    # The default is a tuple rather than a list so the shared default object
    # can never be mutated by accident.
    def __init__(self, in_channels=4, out_channels=1, features=(32, 64, 128, 256)):
        super().__init__()
        f0, f1, f2, f3 = features[0], features[1], features[2], features[3]

        # Encoder path
        self.enc1 = EncoderBlock(in_channels, f0)
        self.enc2 = EncoderBlock(f0, f1)
        self.enc3 = EncoderBlock(f1, f2)

        # Bottleneck
        self.bottleneck = ConvBlock(f2, f3)

        # Decoder path: DecoderBlock(incoming channels, skip channels, output channels)
        self.dec3 = DecoderBlock(f3, f2, f1)
        self.dec2 = DecoderBlock(f1, f1, f0)
        self.dec1 = DecoderBlock(f0, f0, f0)

        # Final convolution
        self.final_conv = nn.Conv3d(f0, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return logits."""
        s1, p1 = self.enc1(x)
        s2, p2 = self.enc2(p1)
        s3, p3 = self.enc3(p2)

        b = self.bottleneck(p3)

        # Each decoder level consumes the skip tensor of the matching encoder level.
        d3 = self.dec3(b, s3)
        d2 = self.dec2(d3, s2)
        d1 = self.dec1(d2, s1)

        # No sigmoid here: the loss works on logits.
        return self.final_conv(d1)


In [None]:

# --- Dice Loss ---
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation computed from logits.

    All elements of the batch are flattened together, so one Dice value is
    computed per batch rather than per sample.

    Args:
        smooth: constant added to numerator and denominator to keep the ratio
            defined when both prediction and target are empty.
    """

    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return ``1 - dice`` for ``logits`` and same-shaped 0/1 ``targets``."""
        probs = torch.sigmoid(logits) # Convert logits to probabilities

        # reshape (unlike view) also accepts non-contiguous tensors, e.g.
        # targets that were permuted or sliced before reaching the loss.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        return 1 - dice # Return loss (1 - Dice score)


In [None]:
# --- Training Setup ---
# Build the network on the configured device and, when a checkpoint from an
# earlier run is present, continue from its weights.
model = UNet3D(in_channels=4, out_channels=1)
model = model.to(DEVICE)
if MODEL_SAVE_PATH.exists():
    print(f"Loading model checkpoint from {MODEL_SAVE_PATH}")
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))

# Loss and optimiser used by the training loop.
criterion = DiceLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Optional: Add a learning rate scheduler
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Every preprocessed training volume, in a stable (sorted) order.
all_train_npz_files = sorted(PROCESSED_TRAIN_DIR.glob("*.npz"))
print(f"Found {len(all_train_npz_files)} training files.")


Found 368 training files.


In [None]:
# --- Training Loop ---
# The file list is walked in chunks of CHUNK_SIZE; a fresh Dataset/DataLoader
# is built per chunk and released afterwards to keep memory bounded.
num_chunks = (len(all_train_npz_files) + CHUNK_SIZE - 1) // CHUNK_SIZE

for epoch in range(TOTAL_EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{TOTAL_EPOCHS} ---")
    model.train()
    random.shuffle(all_train_npz_files) # Shuffle data files each epoch (in place)
    epoch_loss = 0.0      # Sum of per-batch losses over the epoch
    epoch_batches = 0     # Number of optimisation steps in the epoch
    processed_files = 0   # Number of samples seen in the epoch

    for chunk_idx, chunk_start in enumerate(range(0, len(all_train_npz_files), CHUNK_SIZE), start=1):
        chunk_files = all_train_npz_files[chunk_start : chunk_start + CHUNK_SIZE]

        print(f"  Loading chunk: {chunk_idx} / {num_chunks}")
        dataset = BrainTumorDataset3D(chunk_files, CROP_SIZE, is_train=True)
        # num_workers > 0 can speed up loading but might cause issues on Kaggle notebooks
        train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

        chunk_loss = 0.0
        for inputs, targets in tqdm(train_loader, desc=f"  Chunk {chunk_idx} Train", leave=False):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs) # Logits
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            chunk_loss += loss.item()
            processed_files += len(inputs) # Add batch size (usually 1)

        avg_chunk_loss = chunk_loss / len(train_loader) if len(train_loader) > 0 else 0
        epoch_loss += chunk_loss # Accumulate loss from all batches in the chunk
        epoch_batches += len(train_loader)
        print(f"  Chunk {chunk_idx} Avg Loss: {avg_chunk_loss:.4f}")

        # Clear memory after processing a chunk
        del dataset, train_loader, inputs, targets, outputs, chunk_files
        gc.collect()
        # Compare the device *type* so an indexed device such as cuda:0 is covered too.
        if DEVICE.type == 'cuda':
            torch.cuda.empty_cache()

    # epoch_loss is a sum of per-batch values, so average over batches; dividing
    # by the sample count would under-report the loss whenever BATCH_SIZE > 1.
    avg_epoch_loss = epoch_loss / epoch_batches if epoch_batches > 0 else 0
    print(f"--- Epoch {epoch+1} Average Loss: {avg_epoch_loss:.4f} ---")

    # Optional: Learning rate scheduler step
    # scheduler.step()

    # Save model checkpoint after each epoch
    print(f"Saving model checkpoint to {MODEL_SAVE_PATH}")
    torch.save(model.state_dict(), MODEL_SAVE_PATH)


print("\n--- Training Finished ---")



--- Epoch 1/1 ---
  Loading chunk: 1 / 19


  Chunk 1 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 1 Avg Loss: 0.8966
  Loading chunk: 2 / 19


  Chunk 2 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 2 Avg Loss: 0.8844
  Loading chunk: 3 / 19


  Chunk 3 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 3 Avg Loss: 0.8477
  Loading chunk: 4 / 19


  Chunk 4 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 4 Avg Loss: 0.8549
  Loading chunk: 5 / 19


  Chunk 5 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 5 Avg Loss: 0.8347
  Loading chunk: 6 / 19


  Chunk 6 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 6 Avg Loss: 0.8907
  Loading chunk: 7 / 19


  Chunk 7 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 7 Avg Loss: 0.8900
  Loading chunk: 8 / 19


  Chunk 8 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 8 Avg Loss: 0.8728
  Loading chunk: 9 / 19


  Chunk 9 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 9 Avg Loss: 0.8873
  Loading chunk: 10 / 19


  Chunk 10 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 10 Avg Loss: 0.8521
  Loading chunk: 11 / 19


  Chunk 11 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 11 Avg Loss: 0.9047
  Loading chunk: 12 / 19


  Chunk 12 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 12 Avg Loss: 0.8867
  Loading chunk: 13 / 19


  Chunk 13 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 13 Avg Loss: 0.8521
  Loading chunk: 14 / 19


  Chunk 14 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 14 Avg Loss: 0.8906
  Loading chunk: 15 / 19


  Chunk 15 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 15 Avg Loss: 0.8770
  Loading chunk: 16 / 19


  Chunk 16 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 16 Avg Loss: 0.8344
  Loading chunk: 17 / 19


  Chunk 17 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 17 Avg Loss: 0.8398
  Loading chunk: 18 / 19


  Chunk 18 Train:   0%|          | 0/20 [00:00<?, ?it/s]

  Chunk 18 Avg Loss: 0.8084
  Loading chunk: 19 / 19


  Chunk 19 Train:   0%|          | 0/8 [00:00<?, ?it/s]

  Chunk 19 Avg Loss: 0.8710
--- Epoch 1 Average Loss: 0.8670 ---
Saving model checkpoint to /kaggle/working/unet3d_brats.pth

--- Training Finished ---


In [None]:




# --- Validation ---
# Runs the model on centre crops of every preprocessed validation volume,
# reports the soft Dice loss/score per batch and plots the middle slice of the
# first three samples.
# NOTE: cases whose segmentation was missing at preprocessing time carry an
# all-zero placeholder label (see preprocess_and_save), so the numbers printed
# for such cases are not a real accuracy measure.
print("\n--- Starting Validation ---")
all_val_npz_files = sorted(PROCESSED_VAL_DIR.glob("*.npz"))
if len(all_val_npz_files) == 0:
    print("No validation files found. Skipping validation.")
else:
    print(f"Found {len(all_val_npz_files)} validation files.")
    val_dataset = BrainTumorDataset3D(all_val_npz_files, CROP_SIZE, is_train=False) # Use center crop
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    num_val_batches = len(val_loader)

    model.eval() # Set model to evaluation mode
    total_val_loss = 0.0
    val_dice_scores = [] # One soft Dice value per batch

    with torch.no_grad(): # Disable gradient calculations
        for batch_no, (images, labels) in enumerate(tqdm(val_loader, desc="Validation"), start=1):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(images) # Get logits

            # Validation loss
            loss = criterion(outputs, labels)
            loss_value = loss.item()
            total_val_loss += loss_value

            # Probabilities are easier to read than logits
            probs = torch.sigmoid(outputs)

            # Soft Dice score of this batch (same formula as the loss, not inverted)
            flat_probs = probs.view(-1)
            flat_labels = labels.view(-1)
            overlap = (flat_probs * flat_labels).sum()
            dice_score = (2. * overlap + criterion.smooth) / (flat_probs.sum() + flat_labels.sum() + criterion.smooth)
            dice_value = dice_score.item()
            val_dice_scores.append(dice_value)

            print(f"\nValidation Batch {batch_no}/{num_val_batches}")
            print(f"  Image shape: {images.shape}")
            print(f"  Label shape: {labels.shape}")
            print(f"  Output Logits shape: {outputs.shape}")
            print(f"  Output Probs shape: {probs.shape}")
            # Stats on probabilities (range 0-1) are more intuitive
            print(f"  Output Probs stats - min: {probs.min().item():.4f}, max: {probs.max().item():.4f}, mean: {probs.mean().item():.4f}")
            print(f"  Batch Dice Loss: {loss_value:.4f}")
            print(f"  Batch Dice Score: {dice_value:.4f}")

            # Visualize the middle depth slice of the first sample, first 3 batches only
            if batch_no <= 3:
                slice_idx = images.shape[2] // 2 # Middle slice (Depth dimension)
                flair_slice = images[0, 0, slice_idx].cpu().numpy() # Channel 0 = first modality (flair)
                truth_slice = labels[0, 0, slice_idx].cpu().numpy()
                prob_slice = probs[0, 0, slice_idx].cpu().numpy()
                panels = (
                    (flair_slice, "bone", f"Input Flair (Slice {slice_idx})"),
                    (truth_slice, "gray", "Ground Truth Mask"),
                    (prob_slice, "gray", "Predicted Mask (Prob)"),
                    ((prob_slice > 0.5).astype(float), "gray", "Predicted Mask (Thresholded > 0.5)"),
                )

                plt.figure(figsize=(18, 5))
                for column, (pixels, colormap, title) in enumerate(panels, start=1):
                    plt.subplot(1, 4, column)
                    plt.imshow(pixels, cmap=colormap)
                    plt.title(title)
                    plt.axis("off")

                plt.suptitle(f"Validation Sample {batch_no}")
                plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap
                plt.show()

    avg_val_loss = total_val_loss / len(val_loader) if len(val_loader) > 0 else 0
    avg_val_dice = np.mean(val_dice_scores) if val_dice_scores else 0
    print("\n--- Validation Finished ---")
    print(f"Average Validation Loss: {avg_val_loss:.4f}")
    print(f"Average Validation Dice Score: {avg_val_dice:.4f}")

    # Clear memory
    del val_dataset, val_loader, images, labels, outputs, probs
    gc.collect()
    if DEVICE == torch.device('cuda'):
        torch.cuda.empty_cache()