# Debug notebook
This notebook is used to inspect the project's classes (the data module and the model) and to check that they work as expected.

In [1]:
from pytorch_lightning import Trainer
from data_loading import MRIDataModule
from pathlib import Path
from model import *

# Path to the dataset on disk.
# NOTE(review): `data_dir` is not used in this cell — MRIDataModule is configured only
# via `dataset="wmh"`. Confirm whether the data module should receive this path.
data_dir = Path("data/WMH")

# Initialize the DataModule
batch_size = 4
data_module = MRIDataModule(dataset="wmh", batch_size=batch_size)

# Image modalities fed to the network. `in_channels` is derived from this list so the
# two values cannot drift apart (presumably one input channel per key — confirm in LitUNet).
input_keys = ["flair", "t1"]

# Instantiate the LightningModule
model = LitUNet(
    n_dims=3,
    input_keys=input_keys,
    label_key="WMH",
    in_channels=len(input_keys),
    out_channels=1,
    base_channels=8,
    depth=3,
    use_transpose=False,
    use_normalization=True,
    learning_rate=1e-3,
)

# Trainer configured for a single-epoch run.
# NOTE: `trainer.fit(...)` is not called in this cell; the trainer is only constructed.
trainer = Trainer(max_epochs=1, accelerator="auto")

  from .autonotebook import tqdm as notebook_tqdm
Attribute 'final_activation' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['final_activation'])`.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [2]:
data_module.setup()

# Inspect a single batch instead of iterating over the whole loader. `batch` is
# reused by the visualization cells below.
batch = next(iter(data_module.train_dataloader()))

# A previous run failed with "TypeError: list indices must be integers or slices,
# not str", i.e. the loader yielded a list rather than a dict.
# NOTE(review): presumably a list of dicts (e.g. several crops per sample collated
# position-wise); the first entry is taken here — confirm against MRIDataModule's
# transforms / collate function.
if isinstance(batch, (list, tuple)) and batch and isinstance(batch[0], dict):
    batch = batch[0]

print(batch['flair'].shape)  # Should output: [batch_size, 1, 128, 128, 128]

Loading dataset: 100%|██████████| 48/48 [00:15<00:00,  3.05it/s]
Loading dataset: 100%|██████████| 6/6 [00:01<00:00,  3.47it/s]
Loading dataset: 100%|██████████| 7/7 [00:02<00:00,  3.19it/s]


TypeError: list indices must be integers or slices, not str

# Extracting 2D slices from the 3D volumes (first step towards a 2D dataset)

In [None]:
def vis_slice(batch, batch_number, key, channel=0):
    """Plot the middle slice of one volume along each of its three spatial axes.

    Args:
        batch: mapping from key to an indexable array/tensor. The indexing below
            assumes a 5D layout (batch, channel, d1, d2, d3) — TODO confirm against
            the data module.
        batch_number: index of the sample within the batch to display.
        key: entry of ``batch`` to display (e.g. "flair" or "seg").
        channel: channel index to display; defaults to 0 (the previous behaviour).

    Note:
        Relies on ``plt`` (matplotlib.pyplot) being in scope; no import of it is
        visible in this notebook.
    """
    volume = batch[key][batch_number, channel]

    # Each spatial axis gets its own middle index, so non-cubic volumes show the
    # true centre slice (and never index past the end) on every view.
    mid = [size // 2 for size in volume.shape[-3:]]

    views = [
        (volume[mid[0], :, :], 'first dim'),
        (volume[:, mid[1], :], 'second dim'),
        (volume[:, :, mid[2]], 'third dim'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (plane, title) in zip(axes, views):
        ax.imshow(plane, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle(f"{key} (sample {batch_number})")

    plt.show()

In [None]:
# Plot every sample actually present in the batch. The batch can hold fewer than
# `batch_size` samples (e.g. a final, smaller batch), so its length is used instead.
for batch_number in range(len(batch["flair"])):
    vis_slice(batch, batch_number, "flair")

In [None]:
# Show the label volume of the first sample.
# NOTE(review): the model above uses label_key="WMH", while this cell reads "seg";
# confirm which key the data module actually provides.
vis_slice(batch, batch_number=0, key="seg")