In [1]:
import torch
from torchvision.transforms import v2
from torch.utils.data import DataLoader

import project.utils
%reload project
from project.data.luna_dataset import Luna16Dataset
from project.models.vnet import VNet

In [2]:
model = VNet(num_classes=1)

In [3]:
print(torch.version.cuda)

12.4


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
model.to(device)

VNet(
  (input_block): VNet_input_block(
    (conv1): Conv3d(1, 16, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
    (bn1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): PReLU(num_parameters=16)
  )
  (down_block1): VNet_down_block(
    (down_conv): Conv3d(16, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): PReLU(num_parameters=32)
    (convs): Sequential(
      (0): Conv_in_stage(
        (conv1): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): PReLU(num_parameters=1)
      )
      (1): Conv_in_stage(
        (conv1): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   

In [6]:
from torch import nn
from torchvision.transforms import v2
from torchvision import tv_tensors
from project.config import PROJECT_ROOT

class PadDepthTransform(nn.Module):
    """Pad the depth dimension of the input tensor to be divisible by 16."""
    def forward(self, img, mask) -> tuple[torch.Tensor, torch.Tensor]:
        # Check if the number of depth dimensions is odd
        if img.shape[1] % 16 != 0:
            # Create a zero-filled padding with the same height and width
            padding = torch.zeros(1, 16 - img.shape[1] % 16, *img.shape[2:], device=img.device, dtype=img.dtype)
            # Concatenate the padding to the tensor
            img = torch.cat([img, padding], dim=1)
            mask = torch.cat([mask, padding], dim=1)
        return tv_tensors.Image(img), tv_tensors.Mask(mask)
    
class Normalize3D(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)
        
    def forward(self, img, mask) -> tuple[torch.Tensor, torch.Tensor]:
        img = (img - self.mean) / self.std
        return img, mask
    
class Resize3d(nn.Module):
    def __init__(self, size: tuple[int, int, int]):
        super().__init__()
        self.size = size
        
    def forward(self, img: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        img = nn.functional.interpolate(img.unsqueeze(0).float(), size=self.size, mode="trilinear", align_corners=False).squeeze(0).long()
        mask = nn.functional.interpolate(mask.unsqueeze(0).float(), size=self.size, mode="nearest").squeeze(0).long()
        return img, mask

transforms = v2.Compose([
    Resize3d(size=(224, 224, 224)),
    PadDepthTransform(),
    v2.ToDtype({tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others": None}, scale=True),
    Normalize3D(mean=-0.023, std=0.026),
])

luna16 = Luna16Dataset(root=PROJECT_ROOT / "data/luna16", transforms=transforms, train=True)
luna16_base = Luna16Dataset(root=PROJECT_ROOT / "data/luna16", transforms=None, train=True)

In [7]:
resize = Resize3d((256, 256, 256))
resize(*luna16_base[0])[0].shape

torch.Size([1, 256, 256, 256])

In [8]:
transforms(*luna16_base[0])

(tensor([[[[0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           ...,
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846]],
 
          [[0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           ...,
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846]],
 
          [[0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.8846, 0.8846,  ..., 0.8846, 0.8846, 0.8846],
           [0.8846, 0.88

In [9]:
len(luna16.masks)

712

In [10]:
luna16[0][0].shape, luna16[0][1].shape

(torch.Size([1, 224, 224, 224]), torch.Size([1, 224, 224, 224]))

In [11]:
luna_train_loader = DataLoader(luna16, batch_size=1, shuffle=False)  # Batch size has to be 1 because each image has different depth

In [13]:
for i, (data, target) in enumerate(luna_train_loader):
    # print(data.shape, data.dtype)
    # print(target.shape, target.dtype)
    print(target.unique())
    if i > 5:
        break


tensor([0])
tensor([  0, 255])
tensor([  0, 255])
tensor([  0, 255])
tensor([0])
tensor([  0, 255])
tensor([0])


In [14]:
# means = []
# stds = []
# for i, (data, target) in enumerate(luna_train_loader):
#     mean = data.mean()
#     std = data.std()
#     means.append(mean)
#     stds.append(std)
    
# print(f"Mean: {torch.stack(means).mean()}")
# print(f"Std: {torch.stack(stds).mean()}")

# output
# 

In [15]:
from project.models.vnet import VNet
from project.models.unet3d import UNet3D
# import segmentation_models_3D as sm

# model = VNet(num_classes=1)
model = UNet3D(n_channels=1, n_classes=1)
# model = sm.FPN(
#     'densenet121',
#     classes=1,
#     activation='sigmoid'
# )

In [16]:
# Number of parameters
sum(p.numel() for p in model.parameters() if p.requires_grad)

12946785

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Train loop

import wandb
import torch.nn as nn
import torch.optim as optim

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

luna_test_loader = DataLoader(Luna16Dataset(root=PROJECT_ROOT / "data" / "luna16", transforms=transforms, train=False), batch_size=1, shuffle=False)

# run = None
run = wandb.init(
    project="luna16",
    group=None,
    job_type="train",
    config={
        "model": "UNet3D",
        "optimizer": optimizer.__class__.__name__,
        "scheduler": scheduler.__class__.__name__,
        "lr": 0.001,
        "batch_size": 1,
        "epochs": 10,
    }
)

model.train().to(device)
for epoch in range(10):
    for i, (data, target) in enumerate(luna_train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        target = target.float() # BCEWithLogitsLoss expects float targets
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}")
        if run:
            run.log({"epoch": epoch, "batch": i, "train/loss": loss.item()})
    scheduler.step()
    
    losses = []
    for i, (data, target) in enumerate(luna_test_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        target = target.float()
        loss = criterion(output, target)
        losses.append(loss.item())
        
    test_loss = sum(losses) / len(losses)  # Average loss over each image
    print(f"Epoch: {epoch}, Test loss: {test_loss}")
    if run:
        run.log({"epoch": epoch, "test/loss": test_loss})

if run:  
    run.finish()


wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: papetoast (papetoast-org1). Use `wandb login --relogin` to force relogin


Epoch: 0, Batch: 0, Loss: 0.8855414986610413
Epoch: 0, Batch: 1, Loss: 0.798898458480835
Epoch: 0, Batch: 2, Loss: 0.7426631450653076
Epoch: 0, Batch: 3, Loss: 0.6973922848701477
Epoch: 0, Batch: 4, Loss: 0.674547553062439
Epoch: 0, Batch: 5, Loss: 0.6578269600868225
Epoch: 0, Batch: 6, Loss: 0.6408712267875671
Epoch: 0, Batch: 7, Loss: 0.6274880170822144
Epoch: 0, Batch: 8, Loss: 0.6131066083908081
Epoch: 0, Batch: 9, Loss: 0.5927512049674988
Epoch: 0, Batch: 10, Loss: 0.5830245614051819
Epoch: 0, Batch: 11, Loss: 0.5767033100128174
Epoch: 0, Batch: 12, Loss: 0.5765429735183716
Epoch: 0, Batch: 13, Loss: 0.5502112507820129
Epoch: 0, Batch: 14, Loss: 0.5404123067855835
Epoch: 0, Batch: 15, Loss: 0.5301836133003235
Epoch: 0, Batch: 16, Loss: 0.5238696336746216
Epoch: 0, Batch: 17, Loss: 0.5150948166847229
Epoch: 0, Batch: 18, Loss: 0.5092343091964722
Epoch: 0, Batch: 19, Loss: 0.5143136978149414
Epoch: 0, Batch: 20, Loss: 0.5264184474945068
Epoch: 0, Batch: 21, Loss: 0.49556928873062134