In [1]:
pip install torch torchvision torchaudio matplotlib pillow nibabel numpy torchsummary

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms

from torch.utils.data import Dataset, DataLoader
import torchsummary
import nibabel as nib
import os
import numpy as np
import matplotlib.pyplot as plt

print(torch.__version__)

2.6.0+cu118


In [3]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('CUDA is available. Using GPU.')
else:
    device = torch.device('cpu')
    print('CUDA is not available. Using CPU.')
total_memory = torch.cuda.get_device_properties(device).total_memory
print(f'Total GPU memory: {total_memory / (1024**3):.2f} GB')

CUDA is available. Using GPU.
Total GPU memory: 24.00 GB


In [4]:
scan_path = 'E:/FLARE23/FLARE23_1501-2000'
label_path = 'E:/FLARE23/MICCAI-FLARE23/labelsTr2200'

import os
import numpy as np
import nibabel as nib
from torch.utils.data import Dataset
from torchvision import transforms

class AbdominalDataset(Dataset):
    def __init__(self, scan_dir, label_dir, transforms, multiple, sample_size):
        self.scan_dir = scan_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.multiple = multiple
        self.sample_size = sample_size
        self.prepare_files()

    def prepare_files(self):
        self.scan_idx = []
        self.label_idx = []

        # Process scan files
        for scan_name in os.listdir(self.scan_dir):
            try:
                # Extract the integer part of the filename
                scan_id = int(scan_name.split('_')[1])
                self.scan_idx.append(scan_id)
            except ValueError:
                # Skip files that don't have an integer as the first part of the filename
                print(f"Skipping invalid scan file: {scan_name}")
                continue

        # Process label files
        for label_name in os.listdir(self.label_dir):
            try:
                # Extract the integer part of the filename
                label_id = int(label_name.split('_')[0])
                self.label_idx.append(label_id)
            except ValueError:
                # Skip files that don't have an integer as the first part of the filename
                print(f"Skipping invalid label file: {label_name}")
                continue

        # Ensure that the scan and label indices are sorted
        self.scan_idx.sort()
        self.label_idx.sort()

    def __len__(self):
        return len(self.scan_idx)

    def addZeroPadding(self, scan, label):
        target_depth = (scan.shape[2] // self.multiple + 1) * self.multiple
        pad_depth = target_depth - scan.shape[2]
        pad_front = pad_depth // 2
        pad_back = pad_depth - pad_front

        scan_padded = np.pad(scan, ((0, 0), (0, 0), (pad_front, pad_back)), 'constant', constant_values=(0, 0))
        label_padded = np.pad(label, ((0, 0), (0, 0), (pad_front, pad_back)), 'constant', constant_values=(0, 0))

        return scan_padded, label_padded

    def randSpatialSample(self, scan, label):
        start_x = np.random.randint(0, scan.shape[0] - self.sample_size[0] + 1)
        start_y = np.random.randint(0, scan.shape[1] - self.sample_size[1] + 1)
        start_z = np.random.randint(0, scan.shape[2] - self.sample_size[2] + 1)

        scan_sampled = scan[start_x:start_x + self.sample_size[0], start_y:start_y + self.sample_size[1], start_z:start_z + self.sample_size[2]]
        label_sampled = label[start_x:start_x + self.sample_size[0], start_y:start_y + self.sample_size[1], start_z:start_z + self.sample_size[2]]

        return scan_sampled, label_sampled

    def __getitem__(self, idx):
        scan_loc = os.path.join(self.scan_dir, str(self.scan_idx[idx]) + '.nii.gz')
        label_loc = os.path.join(self.label_dir, str(self.label_idx[idx]) + '.nii.gz')

        scan_nii, label_nii = nib.load(scan_loc), nib.load(label_loc)
        scan, label = scan_nii.get_fdata(), label_nii.get_fdata()

        if scan.shape[2] % self.multiple != 0:
            scan, label = self.addZeroPadding(scan, label)

        rs_scan, rs_label = self.randSpatialSample(scan, label)

        if self.transforms:
            scan, label = self.transforms(rs_scan), self.transforms(rs_label)

        return scan, label

# Define the transforms
abdominal_transforms = transforms.Compose([
    transforms.ToTensor()
])

# Create the dataset
abdominal_dataset = AbdominalDataset(scan_dir=scan_path, label_dir=label_path, transforms=abdominal_transforms, multiple=16, sample_size=(512, 512, 16))

print(f"There are {len(abdominal_dataset)} scans and annotations in the dataset.")

Skipping invalid label file: 20230507-fix
Skipping invalid label file: labelsTr2200
Skipping invalid label file: labelsTr2200.zip
There are 500 scans and annotations in the dataset.


In [5]:
class MaskedAutoencoder3D(nn.Module):
    def __init__(self, in_channels, out_channels, hidden_dim):
        super(MaskedAutoencoder3D, self).__init__()

        # Encoder part (Conv3D Layers with BatchNorm3D)
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, hidden_dim, kernel_size=4, stride=2, padding=1),  # Output: (hidden_dim, 8, 256, 256)
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(),

            nn.Conv3d(hidden_dim, hidden_dim * 2, kernel_size=4, stride=2, padding=1),  # Output: (hidden_dim*2, 4, 128, 128)
            nn.BatchNorm3d(hidden_dim * 2),
            nn.ReLU(),

            nn.Conv3d(hidden_dim * 2, hidden_dim * 4, kernel_size=4, stride=2, padding=1),  # Output: (hidden_dim*4, 2, 64, 64)
            nn.BatchNorm3d(hidden_dim * 4),
            nn.ReLU(),

            # Removed some more convolutions to keep dimensions higher
            nn.Conv3d(hidden_dim * 4, hidden_dim * 8, kernel_size=4, stride=2, padding=1),  # Output: (hidden_dim*8, 1, 32, 32)
            nn.BatchNorm3d(hidden_dim * 8),
            nn.ReLU()
        )

        # Decoder part (ConvTranspose3D Layers with BatchNorm3D)
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(hidden_dim * 8, hidden_dim * 4, kernel_size=4, stride=2, padding=1),  # Output: (hidden_dim*4, 2, 64, 64)
            nn.BatchNorm3d(hidden_dim * 4),
            nn.ReLU(),

            nn.ConvTranspose3d(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=2, padding=1),  # Output: (hidden_dim*2, 4, 128, 128)
            nn.BatchNorm3d(hidden_dim * 2),
            nn.ReLU(),

            nn.ConvTranspose3d(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1),  # Output: (hidden_dim, 8, 256, 256)
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(),

            nn.ConvTranspose3d(hidden_dim, out_channels, kernel_size=4, stride=2, padding=1),  # Output: (out_channels, 16, 512, 512)
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# creating the model
in_channels = 1
out_channels = 1
hidden_dim = 32

model = MaskedAutoencoder3D(in_channels, out_channels, hidden_dim)

torchsummary.summary(model.cuda(), (in_channels, 16, 512, 512))  # input shape: (in_channels, depth, height, width)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [-1, 32, 8, 256, 256]           2,080
       BatchNorm3d-2      [-1, 32, 8, 256, 256]              64
              ReLU-3      [-1, 32, 8, 256, 256]               0
            Conv3d-4      [-1, 64, 4, 128, 128]         131,136
       BatchNorm3d-5      [-1, 64, 4, 128, 128]             128
              ReLU-6      [-1, 64, 4, 128, 128]               0
            Conv3d-7       [-1, 128, 2, 64, 64]         524,416
       BatchNorm3d-8       [-1, 128, 2, 64, 64]             256
              ReLU-9       [-1, 128, 2, 64, 64]               0
           Conv3d-10       [-1, 256, 1, 32, 32]       2,097,408
      BatchNorm3d-11       [-1, 256, 1, 32, 32]             512
             ReLU-12       [-1, 256, 1, 32, 32]               0
  ConvTranspose3d-13       [-1, 128, 2, 64, 64]       2,097,280
      BatchNorm3d-14       [-1, 128, 2,

In [None]:
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(abdominal_dataset, batch_size=batch_size)
#test_dataloader = DataLoader(abdominal_dataset, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

IndexError: list index out of range

In [10]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [None]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [12]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [13]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------


IndexError: list index out of range