In [24]:
import torch
from torch.utils.data import Dataset
import os
import numpy as np

class PETCTNPYDataset(Dataset):
    """PET / CT / reconstruction volumes plus a label volume, stored as ``.npy`` files.

    For every file ``<id>_pet.npy`` in ``npy_dir`` the dataset loads the sibling
    files ``<id>_ct.npy``, ``<id>_recon.npy`` and ``<id>_label.npy`` from the
    same directory.

    Args:
        npy_dir: Directory containing the ``.npy`` files.
        transform: Optional callable; when truthy it is applied to the stacked
            image tensor and, separately, to the label tensor.

    Each item is ``(image, label)``:
        image: float32 tensor with the PET, CT and recon volumes stacked (in that
            order) along a new leading axis, so its shape is ``(3, *volume_shape)``.
        label: int64 tensor built from ``<id>_label.npy``.
    """

    _PET_SUFFIX = '_pet.npy'

    def __init__(self, npy_dir, transform=None):
        self.npy_dir = npy_dir
        # Strip the PET suffix rather than keeping a fixed number of '_'-separated parts,
        # so ids with any number of underscores resolve to the right sibling files.
        # Sorted so the index -> case mapping does not depend on os.listdir order.
        self.image_ids = sorted(
            f[:-len(self._PET_SUFFIX)]
            for f in os.listdir(npy_dir)
            if f.endswith(self._PET_SUFFIX)
        )
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def _load(self, image_id, kind):
        # Load the volume ``<image_id>_<kind>.npy`` from the dataset directory.
        return np.load(os.path.join(self.npy_dir, f'{image_id}_{kind}.npy'))

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        # np.stack builds one contiguous array up front; torch.tensor on a Python list of
        # ndarrays is much slower. The three volumes must share a shape (as before).
        image = torch.as_tensor(
            np.stack([self._load(image_id, kind) for kind in ('pet', 'ct', 'recon')]),
            dtype=torch.float32,
        )
        label = torch.tensor(self._load(image_id, 'label'), dtype=torch.long)

        if self.transform:
            # NOTE(review): the same callable is invoked twice, independently; a random
            # augmentation would not be applied identically to image and label.
            image = self.transform(image)
            label = self.transform(label)

        return image, label

In [25]:
import torch.nn as nn
import torch.nn.functional as F

class UNet3D(nn.Module):
    """Small 3-D encoder/decoder that outputs a per-voxel probability map.

    Two max-pool stages downsample by 4x; two trilinear upsampling stages bring the
    features back to the input resolution; a 1x1x1 convolution maps them to
    ``out_channels`` and a sigmoid turns them into probabilities.

    NOTE(review): despite the name there are no encoder->decoder skip connections,
    so fine spatial detail must be recovered from the bottleneck alone. Adding them
    would change the decoder input widths (and break existing checkpoints).

    Args:
        in_channels: Number of input channels (default 3, e.g. PET, CT, recon).
        out_channels: Number of output channels (default 1).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.pool = nn.MaxPool3d(2)
        self.dec3 = self.conv_block(256, 128)
        self.dec2 = self.conv_block(128, 64)
        self.dec1 = nn.Conv3d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two ``Conv3d(3x3x3, padding=1) -> BatchNorm3d -> ReLU`` stages.

        ``padding=1`` keeps the spatial size unchanged; only the channel count changes.
        """
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, D, H, W) to probabilities of shape
        (N, out_channels, D, H, W) — the same spatial size as the input."""
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        # Upsample to the exact size of the matching encoder level instead of a fixed
        # factor of 2, so odd sizes (floored by max-pooling) still round-trip.
        dec3 = self.dec3(F.interpolate(enc3, size=enc2.shape[2:], mode='trilinear', align_corners=True))
        dec2 = self.dec2(F.interpolate(dec3, size=enc1.shape[2:], mode='trilinear', align_corners=True))
        # Two pools are undone by two upsamplings: dec2 is already at input resolution,
        # so the 1x1x1 head is applied without a third upsampling.
        dec1 = self.dec1(dec2)
        return torch.sigmoid(dec1)

In [27]:
import torch.optim as optim
from torch.utils.data import DataLoader

# Paths to your dataset
npy_dir = 'Dataset700_PET_CT_Recon_npy'

# Seed torch so shuffling and weight initialisation are reproducible across kernel restarts.
torch.manual_seed(0)

# Create dataset and dataloader
dataset = PETCTNPYDataset(npy_dir)
if len(dataset) == 0:
    raise FileNotFoundError(f"no '*_pet.npy' files found in {npy_dir!r}")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
print('loaded dataset')

# Initialize model, loss function, and optimizer.
# The model is moved to the target device before the optimizer is built over its parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet3D().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
print('loaded model')

# Training loop
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        # BCELoss requires a float target with exactly the output's shape. The dataset
        # yields int64 labels; they are assumed to be binary masks lacking the channel
        # axis the model output has — TODO confirm label shape/values.
        targets = labels.float()
        if targets.dim() == outputs.dim() - 1:
            targets = targets.unsqueeze(1)
        if targets.shape != outputs.shape:
            raise ValueError(
                f'prediction shape {tuple(outputs.shape)} does not match '
                f'target shape {tuple(targets.shape)}'
            )
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}')

loaded dataset
loaded model
cuda


KeyboardInterrupt: 

In [None]:
def evaluate_model(model, dataloader, device=None, threshold=0.5):
    """Compute, print and return the mean per-batch Dice score of ``model``.

    Args:
        model: Module whose output is a probability map; it is put in eval mode.
        dataloader: Sized iterable of ``(images, labels)`` batches.
        device: Device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has no parameters).
        threshold: Probability above which a voxel counts as foreground.

    Returns:
        The Dice score averaged over batches, as a Python float.

    Raises:
        ValueError: if ``dataloader`` is empty or predictions and labels have
            incompatible shapes.

    Labels are assumed to be binary masks (0/1) — TODO confirm for this dataset.
    If both prediction and label of a batch are empty, that batch scores 0.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError('cannot evaluate on an empty dataloader')

    model.eval()
    dice_total = 0.0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = (outputs > threshold).float()
            targets = labels.float()
            # Labels may lack the channel axis the model output has; align them so the
            # product below is elementwise instead of broadcasting to (B, B, ...).
            if targets.dim() == preds.dim() - 1:
                targets = targets.unsqueeze(1)
            if targets.shape != preds.shape:
                raise ValueError(
                    f'prediction shape {tuple(preds.shape)} does not match '
                    f'label shape {tuple(targets.shape)}'
                )
            intersection = (preds * targets).sum()
            dice_total += (2. * intersection / (preds.sum() + targets.sum() + 1e-6)).item()

    mean_dice = dice_total / n_batches
    print(f'Dice Score: {mean_dice:.4f}')
    return mean_dice

# Create a validation dataset and dataloader from the same .npy layout used for training.
# TODO(review): set this to the actual held-out split of the preprocessed data.
val_npy_dir = 'Dataset700_PET_CT_Recon_npy_val'
val_dataset = PETCTNPYDataset(val_npy_dir)
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# Evaluate the model
evaluate_model(model, val_dataloader)