In [1]:
import torch
from torchvision.transforms import v2
from torch.utils.data import DataLoader

from project.data.luna_dataset import Luna16Dataset
from project.vnet import VNet

In [2]:
model = VNet(num_classes=1)


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
model.to(device)

VNet(
  (input_block): VNet_input_block(
    (conv1): Conv3d(1, 16, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
    (bn1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): PReLU(num_parameters=16)
  )
  (down_block1): VNet_down_block(
    (down_conv): Conv3d(16, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): PReLU(num_parameters=32)
    (convs): Sequential(
      (0): Conv_in_stage(
        (conv1): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): PReLU(num_parameters=1)
      )
      (1): Conv_in_stage(
        (conv1): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   

In [14]:
from torch import nn
from project.config import PROJECT_ROOT

class PadDepthTransform(nn.Module):
    def forward(self, img, label):
        # Check if the number of depth dimensions is odd
        if img.shape[1] % 2 != 0:
            # Create a zero-filled padding with the same height and width
            padding = torch.zeros(1, *img.shape[2:], device=img.device, dtype=img.dtype)
            # Concatenate the padding to the tensor
            tensor = torch.cat([img, padding], dim=1)
        return tensor

transforms = v2.Compose([
    v2.Resize((80, 80)),
    PadDepthTransform(),
    v2.ToDtype(torch.float32, scale=True),
])

luna16 = Luna16Dataset(root=PROJECT_ROOT / "data/luna16", transforms=transforms, train=True)
luna16_base = Luna16Dataset(root=PROJECT_ROOT / "data/luna16", transforms=None, train=True)

In [15]:
transforms(luna16_base[0])

AttributeError: 'tuple' object has no attribute 'shape'

In [7]:
len(luna16.masks)

712

In [8]:
luna16[0][0].shape, luna16[0][1].shape

(torch.Size([1, 121, 80, 80]), torch.Size([1, 121, 80, 80]))

In [9]:
luna_train_loader = DataLoader(luna16, batch_size=1, shuffle=True)  # Batch size has to be 1 because each image has different depth

In [10]:
for i, (data, target) in enumerate(luna_train_loader):
    print(data.shape, data.dtype)
    print(target.shape, target.dtype)
    break

torch.Size([1, 1, 280, 80, 80]) torch.float32
torch.Size([1, 1, 280, 80, 80]) torch.uint8


In [11]:
from project.unet3d import UNet3D

model = UNet3D(n_channels=1, n_classes=1, width_multiplier=0.5).to(device)

In [12]:
# Train loop

import torch.nn as nn
import torch.optim as optim

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)

for epoch in range(10):
    model.train()
    for i, (data, target) in enumerate(luna_train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}")


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 273 but got size 272 for tensor number 1 in the list.

In [13]:
data.shape

torch.Size([1, 1, 273, 80, 80])