In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn

import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet3D(nn.Module):
    """3D U-Net for volumetric segmentation.

    The encoder applies ``num_layers`` double-convolution blocks, each followed
    by a 2x max-pool. The channel width starts at ``base_features`` and is
    multiplied by ``growth_rate`` (2) at every level. A bottleneck block
    follows. Then ``num_layers`` decoder stages each upsample with a stride-2
    transposed convolution, concatenate the matching encoder feature map (skip
    connection) and apply another double-convolution block. A 1x1x1
    convolution maps the result to ``num_output_channels``.

    Args:
        num_layers: Number of encoder/decoder levels (number of pooling steps).
        num_input_channels: Channels of the input volume.
        num_output_channels: Channels of the output volume.
        batchnorm: Insert ``BatchNorm3d`` after every convolution.
        final_activation: Apply ``torch.sigmoid`` to the output. Set this to
            ``False`` when training with a loss that expects logits, such as
            ``nn.BCEWithLogitsLoss``; otherwise the sigmoid is applied twice.
        dropout: ``Dropout3d`` probability after every conv/ReLU unit; 0
            disables dropout.
        base_features: Channel width of the first encoder level (default 16).

    Input is a 5D tensor ``(N, C, D, H, W)``. The output keeps the input's
    spatial size for any D/H/W. When a dimension is not divisible by
    ``2 ** num_layers``, the upsampled decoder tensor is zero-padded to the
    skip connection's size before concatenation.
    """

    def __init__(self, num_layers, num_input_channels, num_output_channels,
                 batchnorm=True, final_activation=True, dropout=0.0, base_features=16):
        super().__init__()
        self.num_layers = num_layers
        self.batchnorm = batchnorm
        self.final_activation = final_activation
        self.dropout_rate = dropout
        self.growth_rate = 2  # channel multiplier between consecutive levels

        self.encoder_layers = nn.ModuleList()
        self.decoder_layers = nn.ModuleList()
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.upsample = nn.ModuleList()

        in_channels = num_input_channels
        features = base_features

        # Encoder: one conv block per level, widening by growth_rate each time.
        for _ in range(num_layers):
            self.encoder_layers.append(self.conv_block(in_channels, features))
            in_channels = features
            features *= self.growth_rate

        # Bottleneck at the coarsest resolution.
        self.bottleneck = self.conv_block(in_channels, features)

        # Decoder: the transposed conv narrows to the width of the matching
        # encoder level. Concatenating that skip doubles the width again
        # before conv_block reduces it.
        for _ in range(num_layers):
            skip_channels = features // self.growth_rate
            self.upsample.append(
                nn.ConvTranspose3d(features, skip_channels, kernel_size=2, stride=2))
            self.decoder_layers.append(self.conv_block(2 * skip_channels, skip_channels))
            features = skip_channels

        # 1x1x1 projection to the requested number of output channels.
        self.output_layer = nn.Conv3d(features, num_output_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Build two ``Conv3d -> [BatchNorm3d] -> ReLU -> [Dropout3d]`` units.

        The first convolution maps ``in_channels`` to ``out_channels``; the
        second keeps ``out_channels``. Both use 3x3x3 kernels with padding 1,
        so spatial size is preserved. BatchNorm is included when
        ``self.batchnorm`` is set; dropout when ``self.dropout_rate > 0``.
        """
        layers = []
        for conv_in in (in_channels, out_channels):
            layers.append(nn.Conv3d(conv_in, out_channels, kernel_size=3, padding=1))
            if self.batchnorm:
                layers.append(nn.BatchNorm3d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            if self.dropout_rate > 0:
                layers.append(nn.Dropout3d(p=self.dropout_rate))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (N, C_in, D, H, W) to (N, C_out, D, H, W).

        Returns sigmoid probabilities when ``final_activation`` is set,
        otherwise raw logits.
        """
        skips = []

        # Encoder: keep each level's pre-pooling features for the skip connections.
        for encoder in self.encoder_layers:
            x = encoder(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder: consume the skips deepest-first.
        for upsample, decoder, skip in zip(self.upsample, self.decoder_layers, reversed(skips)):
            x = upsample(x)
            # MaxPool3d floors odd sizes, so the upsampled tensor can be one voxel
            # short of the skip. Pad it rather than cropping the skip, so the
            # output keeps the input's spatial size.
            x = self._match_spatial_size(x, skip.shape[2:])
            x = torch.cat([x, skip], dim=1)
            x = decoder(x)

        x = self.output_layer(x)
        if self.final_activation:
            x = torch.sigmoid(x)
        return x

    @staticmethod
    def _match_spatial_size(tensor, target_shape):
        """Symmetrically zero-pad the last three dims of ``tensor`` to ``target_shape``.

        A negative difference crops instead, because ``F.pad`` accepts
        negative padding.
        """
        _, _, d, h, w = tensor.shape
        td, th, tw = target_shape
        dd, dh, dw = td - d, th - h, tw - w
        if dd == 0 and dh == 0 and dw == 0:
            return tensor
        # F.pad takes (before, after) pairs starting from the last dimension (W, H, D).
        return F.pad(tensor, (dw // 2, dw - dw // 2,
                              dh // 2, dh - dh // 2,
                              dd // 2, dd - dd // 2))

    def center_crop(self, tensor, target_shape):
        """Crop the last three dims of a 5D ``tensor`` to ``target_shape``, centered."""
        _, _, d, h, w = tensor.shape
        td, th, tw = target_shape
        d1 = (d - td) // 2
        h1 = (h - th) // 2
        w1 = (w - tw) // 2
        return tensor[:, :, d1:d1+td, h1:h1+th, w1:w1+tw]

    def print_model(self):
        """Print the module tree."""
        print(self)




In [None]:
# Three-level 3D U-Net for single-channel volumes; print its layer layout.
model = UNet3D(num_layers=3, num_input_channels=1, num_output_channels=1,
               batchnorm=True, final_activation=True, dropout=0.1)
print(model)


UNet3D(
  (encoder_layers): ModuleList(
    (0): Sequential(
      (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Dropout3d(p=0.1, inplace=False)
      (4): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (5): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): Dropout3d(p=0.1, inplace=False)
    )
    (1): Sequential(
      (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Dropout3d(p=0.1, inplace=False)
      (4): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (5): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [None]:
def train(model, loader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module to optimize; switched to training mode here.
        loader: Iterable of ``(image, mask)`` batches, e.g. a ``DataLoader``.
            It does not need to support ``len()``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(output, mask)``.
        device: Device each batch is moved to. ``model`` must already be on
            this device.

    Returns:
        The average of ``loss.item()`` over all batches. Every batch is
        weighted equally, regardless of its size.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0  # counted rather than len(loader), so plain iterables work

    for image, mask in loader:
        image = image.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()
        output = model(image)

        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train() received an empty loader; there are no batches to train on.")
    return total_loss / num_batches


In [None]:
# Train the U-Net. BCEWithLogitsLoss applies the sigmoid internally, so the
# network must return raw logits (final_activation=False). With the default
# final_activation=True the sigmoid would be applied twice.
import glob

from torch.utils.data import DataLoader

from dataset import VolumeTIFFDataset  # same module the data-loading cell imports from

NUM_EPOCHS = 10
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = UNet3D(
    num_layers=3,
    num_input_channels=1,
    num_output_channels=1,
    batchnorm=True,
    final_activation=False,
    dropout=0.1,
).to(DEVICE)  # parameters must live on the same device as the batches
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()  # NOTE(review): expects mask values in [0, 1] — confirm for this data

# Default collation stacks the volumes in a batch, so with batch_size=2 all
# volumes presumably share one shape — TODO confirm.
train_dataset = VolumeTIFFDataset(
    sorted(glob.glob("data/images/*.tiff")),
    sorted(glob.glob("data/masks/*.tiff")),
)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

for epoch in range(NUM_EPOCHS):
    loss = train(model, train_loader, optimizer, criterion, device=DEVICE)
    print(f"Epoch {epoch}: loss = {loss:.4f}")


In [None]:
import os
import torch
from torch.utils.data import Dataset
import tifffile as tiff

class VolumeTIFFDataset(Dataset):
    """Paired image/mask volumes read from TIFF files.

    Each item is ``(image, mask)``: float32 tensors with a leading channel
    dimension added via ``unsqueeze(0)``. The image is min-max scaled to
    ``[0, 1]`` per volume; the mask is returned unscaled.

    Args:
        image_paths: Image file paths. Any iterable is accepted and is
            materialized into a list.
        mask_paths: Mask file paths, paired with ``image_paths`` by position.
        transform: Optional callable ``(image, mask) -> (image, mask)``,
            applied after conversion to tensors.

    Raises:
        ValueError: If the numbers of image and mask paths differ.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = list(image_paths)
        self.mask_paths = list(mask_paths)
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Got {len(self.image_paths)} image paths but {len(self.mask_paths)} mask paths; "
                "images and masks are paired by position, so the counts must match."
            )
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Volumes are presumably (D, H, W); this is not checked here.
        image = tiff.imread(self.image_paths[idx]).astype('float32')
        mask = tiff.imread(self.mask_paths[idx]).astype('float32')

        # Per-volume min-max scaling; the epsilon keeps a constant volume from
        # dividing by zero.
        lo, hi = image.min(), image.max()
        image = (image - lo) / (hi - lo + 1e-8)

        # from_numpy shares memory with these freshly created arrays (no extra copy).
        image = torch.from_numpy(image).unsqueeze(0)  # [1, D, H, W]
        mask = torch.from_numpy(mask).unsqueeze(0)

        if self.transform:
            image, mask = self.transform(image, mask)

        return image, mask


In [None]:
from torch.utils.data import DataLoader
from dataset import VolumeTIFFDataset
import glob

# Images and masks are paired by sorted filename order.
image_paths = sorted(glob.glob("data/images/*.tiff"))
mask_paths = sorted(glob.glob("data/masks/*.tiff"))

# Fail early with a clear message instead of a bare StopIteration or a silent
# mispairing further down.
assert image_paths, "No images matched data/images/*.tiff"
assert len(image_paths) == len(mask_paths), (
    f"{len(image_paths)} images vs {len(mask_paths)} masks; they are paired by sorted order"
)

dataset = VolumeTIFFDataset(image_paths, mask_paths)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

img, msk = next(iter(dataloader))
print("Image shape:", img.shape)
print("Mask shape:", msk.shape)
# The loss compares the model output (same spatial size as the input) to the mask.
assert img.shape == msk.shape, "image and mask volumes must have the same shape"
