In [1]:
import torch
from torchvision.transforms import v2
from torch.utils.data import DataLoader

import project.utils
# reload project
from project.data.luna_dataset import Luna16Dataset
from project.models.vnet import VNet

In [2]:
model = VNet(num_classes=1)

In [3]:
print(torch.version.cuda)

12.4


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
model.to(device)

VNet(
  (input_block): VNet_input_block(
    (conv1): Conv3d(1, 16, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
    (bn1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): PReLU(num_parameters=16)
  )
  (down_block1): VNet_down_block(
    (down_conv): Conv3d(16, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): PReLU(num_parameters=32)
    (convs): Sequential(
      (0): Conv_in_stage(
        (conv1): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): PReLU(num_parameters=1)
      )
      (1): Conv_in_stage(
        (conv1): Conv3d(32, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
        (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   

In [6]:
from torch import nn
from torchvision.transforms import v2
from torchvision import tv_tensors
from project.config import PROJECT_ROOT

class PadDepthTransform(nn.Module):
    """Pad the depth dimension of the input tensor to be divisible by 16."""
    def forward(self, img, mask) -> tuple[torch.Tensor, torch.Tensor]:
        # Check if the number of depth dimensions is odd
        if img.shape[1] % 16 != 0:
            # Create a zero-filled padding with the same height and width
            padding = torch.zeros(1, 16 - img.shape[1] % 16, *img.shape[2:], device=img.device, dtype=img.dtype)
            # Concatenate the padding to the tensor
            img = torch.cat([img, padding], dim=1)
            mask = torch.cat([mask, padding], dim=1)
        return tv_tensors.Image(img), tv_tensors.Mask(mask)
    
class Normalize3D(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)
        
    def forward(self, img, mask) -> tuple[torch.Tensor, torch.Tensor]:
        img = (img - self.mean) / self.std
        return img, mask
    
class Resize3d(nn.Module):
    def __init__(self, size: tuple[int, int, int]):
        super().__init__()
        self.size = size
        
    def forward(self, img: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        img = nn.functional.interpolate(img.unsqueeze(0).float(), size=self.size, mode="trilinear", align_corners=False).squeeze(0).long()
        mask = nn.functional.interpolate(mask.unsqueeze(0).float(), size=self.size, mode="nearest").squeeze(0).long()
        return img, mask

transforms = v2.Compose([
    Resize3d(size=(224, 224, 224)),
    v2.ToDtype({tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.int64, "others": None}),
    Normalize3D(mean=-790.1, std=889.6),
])

luna16 = Luna16Dataset(root=PROJECT_ROOT / "data/luna16", transforms=transforms, train=True)
luna16_base = Luna16Dataset(root=PROJECT_ROOT / "data/luna16", transforms=None, train=True)

In [7]:
luna16[0][0].min(), luna16[0][0].max(), luna16[0][0].float().std()

(tensor(-2.5111), tensor(2.7530), tensor(1.2607))

In [8]:
luna16_base[0][0].float().std()

tensor(1126.5992)

In [9]:
(luna16[12][1].float() / 255).unique()

tensor([0., 1.])

In [10]:
resize = Resize3d((256, 256, 256))
resize(*luna16_base[0])[0].shape

torch.Size([1, 256, 256, 256])

In [11]:
transforms(*luna16_base[0])

(tensor([[[[-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           ...,
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111]],
 
          [[-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           ...,
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           [-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111]],
 
          [[-2.5111, -2.5111, -2.5111,  ..., -2.5111, -2.5111, -2.5111],
           [-

In [12]:
len(luna16.masks)

712

In [13]:
luna16[0][0].shape, luna16[0][1].shape

(torch.Size([1, 224, 224, 224]), torch.Size([1, 224, 224, 224]))

In [14]:
luna_train_loader = DataLoader(luna16, batch_size=1, shuffle=False)  # Batch size has to be 1 because each image has different depth

In [15]:
for i, (data, target) in enumerate(luna_train_loader):
    # print(data.shape, data.dtype)
    # print(target.shape, target.dtype)
    print(target.unique())
    if i > 5:
        break


tensor([0])
tensor([  0, 255])
tensor([  0, 255])
tensor([  0, 255])
tensor([0])
tensor([  0, 255])
tensor([0])


In [16]:
# means = []
# stds = []
# base_loader = DataLoader(luna16_base, batch_size=1, shuffle=False)
# for i, (data, target) in enumerate(base_loader):
#     data = data.float()
#     mean = data.mean()
#     std = data.std()
#     means.append(mean)
#     stds.append(std)
    
# print(f"Mean: {torch.stack(means).mean()}")
# print(f"Std: {torch.stack(stds).mean()}")


In [17]:
%reload project
from project.models.vnet import VNet
from project.models.unet3d import UNet3D
# import segmentation_models_3D as sm

model = VNet(num_classes=2)
# model = UNet3D(n_channels=1, n_classes=1)
# model = sm.FPN(
#     'densenet121',
#     classes=1,
#     activation='sigmoid'
# )

In [18]:
# Number of parameters
sum(p.numel() for p in model.parameters() if p.requires_grad)

22558182

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [20]:
def mask_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    mask = mask.unsqueeze(0)
    mask_onehot = torch.zeros((num_classes, *mask.shape[1:]), device=mask.device)
    mask_onehot.scatter_(0, mask, 1)
    return mask_onehot

mask_to_onehot(luna16[0][1], 1).shape

torch.Size([1, 1, 224, 224, 224])

In [None]:
# Train loop

import wandb
import torch.nn as nn
import torch.optim as optim

# criterion = nn.BCEWithLogitsLoss()
criterion = nn.BCELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

luna_train_loader = DataLoader(luna16, batch_size=1, shuffle=True)
luna_test_loader = DataLoader(Luna16Dataset(root=PROJECT_ROOT / "data" / "luna16", transforms=transforms, train=False), batch_size=1, shuffle=False)

# run = None
run = wandb.init(
    project="luna16",
    group=None,
    job_type="train",
    config={
        "model": "VNet",
        "loss_function": criterion.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "scheduler": scheduler.__class__.__name__,
        "lr": 0.001,
        "batch_size": 1,
        "epochs": 10,
    }
)

model.train().to(device)
for epoch in range(10):
    for i, (data, target) in enumerate(luna_train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        target = target / 255 # BCELoss expects float targets
        print(output[:, 0:1, :, :, :].shape)
        loss = criterion(output[:, 0:1, :, :, :], target)
        loss.backward()
        optimizer.step()
        print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}")
        if run:
            run.log({"epoch": epoch, "batch": i, "train/loss": loss.item()})
    scheduler.step()
    
    losses = []
    for i, (data, target) in enumerate(luna_test_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        target = target / 255  # BCELoss expects float targets
        loss = criterion(output, target.squeeze(0))
        losses.append(loss.item())
        
    test_loss = sum(losses) / len(losses)  # Average loss over each image
    print(f"Epoch: {epoch}, Test loss: {test_loss}")
    if run:
        run.log({"epoch": epoch, "test/loss": test_loss})

if run:  
    run.finish()


wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: papetoast (papetoast-org1). Use `wandb login --relogin` to force relogin


torch.Size([1, 1, 224, 224, 224])
Epoch: 0, Batch: 0, Loss: 0.6507939696311951
torch.Size([1, 1, 224, 224, 224])


In [19]:
torch.save(model.state_dict(), "unet3d_0.071.pth")

In [22]:
api = wandb.Api()
run = api.run("papetoast-org1/luna16/n8x2ixoy")

In [23]:
run.config["model"] = "UNet3D"
run.update()

In [24]:
model.__class__

project.models.unet3d.UNet3D

In [25]:
from matplotlib import pyplot as plt
from ipywidgets import interact


def visualize_ct_slices(ct_array):
    """
    Visualize slices of a 3D CT scan interactively.
    
    Parameters:
    - ct_array (numpy array): 3D array representing the CT scan.
    """
    if ct_array is None:
        print("No CT array to visualize.")
        return

    def show_slice(slice_idx):
        plt.figure(figsize=(6, 6))
        plt.imshow(ct_array[slice_idx], cmap='gray')
        plt.title(f"Slice {slice_idx + 1}/{ct_array.shape[0]}")
        plt.axis('off')
        plt.show()
    
    interact(show_slice, slice_idx=(0, ct_array.shape[0] - 1))
    

In [68]:
model.eval()
luna_test = Luna16Dataset(root=PROJECT_ROOT / "data" / "luna16", transforms=None, train=False)
inputs, target = luna_test[0]
# output = model(inputs.unsqueeze(0).to(device)).squeeze(0).squeeze(0).detach().cpu().numpy()

In [69]:
inputs

Image([[[[-1020, -1008,  -982,  ...,  -950,  -977, -1006],
         [-1013, -1016,  -997,  ...,  -965,  -978,  -980],
         [-1002, -1013, -1017,  ...,  -966,  -969,  -973],
         ...,
         [ -965,  -953,  -897,  ...,  -415,  -579,  -736],
         [-1000,  -917,  -860,  ...,  -913,  -968,  -994],
         [ -936,  -840,  -877,  ..., -1012, -1016, -1024]],

        [[ -992, -1001,  -973,  ...,  -975,  -982,  -996],
         [ -985,  -979,  -972,  ...,  -950,  -993, -1015],
         [-1019,  -998,  -969,  ...,  -989, -1004, -1015],
         ...,
         [ -895,  -894,  -875,  ...,  -504,  -621,  -783],
         [ -908,  -865,  -824,  ...,  -967, -1017,  -987],
         [ -876,  -814,  -844,  ..., -1024, -1024, -1024]],

        [[ -972,  -978,  -988,  ..., -1008, -1005, -1013],
         [ -955,  -964,  -970,  ...,  -989,  -979,  -978],
         [ -962,  -963,  -954,  ...,  -980,  -989,  -997],
         ...,
         [ -981,  -842,  -805,  ...,  -443,  -673,  -944],
         [

In [39]:
output.shape

(224, 224, 224)

In [61]:
import numpy as np


inputs = (inputs.squeeze(0).squeeze(0).numpy() * 0.026 + 0.023) * 255

In [66]:
inputs.astype(np.uint8).std()

0.0

In [None]:
target = target.squeeze(0).squeeze(0).numpy()

In [None]:
visualize_ct_slices(inputs.astype(np.uint8))

interactive(children=(IntSlider(value=111, description='slice_idx', max=223), Output()), _dom_classes=('widget…