In [None]:
! pip install SimpleITK




In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [27]:
class ResidualBlock(nn.Module):
    """Residual unit: (conv3x3x3 -> InstanceNorm -> ReLU -> conv3x3x3 -> InstanceNorm) + skip, then ReLU.

    Args:
        in_c: number of input channels.
        out_c: number of output channels. When it differs from ``in_c`` the
            skip path is a 1x1x1 conv + InstanceNorm projection; otherwise it
            is an identity.

    Both 3x3x3 convolutions use padding=1, so spatial size is preserved.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Registration order (conv1, norm1, conv2, norm2, shortcut) is kept
        # stable so state_dict keys and parameter order match checkpoints.
        self.conv1 = nn.Conv3d(in_c, out_c, kernel_size=3, padding=1)
        self.norm1 = nn.InstanceNorm3d(out_c)
        self.conv2 = nn.Conv3d(out_c, out_c, kernel_size=3, padding=1)
        self.norm2 = nn.InstanceNorm3d(out_c)
        if in_c == out_c:
            self.shortcut = nn.Identity()
        else:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_c, out_c, kernel_size=1),
                nn.InstanceNorm3d(out_c),
            )

    def forward(self, x):
        skip = self.shortcut(x)
        hidden = F.relu(self.norm1(self.conv1(x)))
        hidden = self.norm2(self.conv2(hidden))
        return F.relu(hidden + skip)


class VNet(nn.Module):
    """3D encoder-decoder segmentation network with residual blocks and
    sigmoid attention gates on the skip connections.

    Three encoder stages (each followed by 2x max-pooling), a residual
    bottleneck followed by Dropout3d(0.5), and three decoder stages. Each
    decoder stage upsamples 2x with a transposed convolution, gates the matching
    encoder features with a 1-channel attention map, concatenates the two,
    and fuses them with a ResidualBlock.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels produced by the final 1x1x1 convolution. No
            activation is applied, so the output is raw logits.
        base_channels: width of the first stage. Stage widths are base,
            2*base, 4*base and 8*base (bottleneck). The default of 64 gives
            the original 64/128/256/512 layout with unchanged module names,
            so existing checkpoints still load.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        c1, c2, c3, c4 = (base_channels * m for m in (1, 2, 4, 8))

        # Encoder
        self.enc1 = ResidualBlock(in_channels, c1)
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = ResidualBlock(c1, c2)
        self.pool2 = nn.MaxPool3d(2)
        self.enc3 = ResidualBlock(c2, c3)
        self.pool3 = nn.MaxPool3d(2)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            ResidualBlock(c3, c4),
            nn.Dropout3d(0.5)
        )

        # Decoder: each ResidualBlock takes 2x channels because it receives
        # the concatenation [upsampled features, gated skip features].
        self.up3 = nn.ConvTranspose3d(c4, c3, 2, stride=2)
        self.dec3 = ResidualBlock(2 * c3, c3)
        self.up2 = nn.ConvTranspose3d(c3, c2, 2, stride=2)
        self.dec2 = ResidualBlock(2 * c2, c2)
        self.up1 = nn.ConvTranspose3d(c2, c1, 2, stride=2)
        self.dec1 = ResidualBlock(2 * c1, c1)

        # Output
        self.final = nn.Conv3d(c1, out_channels, 1)

        # Attention gates: 1x1x1 conv + sigmoid -> single-channel weight map
        # in (0, 1) that rescales the skip features voxel-wise.
        self.attention3 = nn.Sequential(
            nn.Conv3d(c3, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention2 = nn.Sequential(
            nn.Conv3d(c2, 1, kernel_size=1),
            nn.Sigmoid()
        )
        self.attention1 = nn.Sequential(
            nn.Conv3d(c1, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def _decode_stage(self, x, skip, up, attention, decoder):
        """Upsample ``x``, gate ``skip`` with ``attention``, concatenate and decode."""
        x = up(x)
        # Max-pooling floors odd sizes, so the 2x transposed conv can come
        # back one voxel short of the skip tensor; resize to match.
        if x.size()[2:] != skip.size()[2:]:
            x = F.interpolate(x, size=skip.size()[2:], mode='trilinear', align_corners=False)
        skip = skip * attention(skip)
        return decoder(torch.cat([x, skip], dim=1))

    def forward(self, x):
        """Segment a batched volume ``x`` of shape (N, in_channels, D, H, W).

        Returns logits of shape (N, out_channels, D, H, W).
        """
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))

        # Bottleneck
        b = self.bottleneck(self.pool3(e3))

        # Decoder with attention-gated skip connections
        d3 = self._decode_stage(b, e3, self.up3, self.attention3, self.dec3)
        d2 = self._decode_stage(d3, e2, self.up2, self.attention2, self.dec2)
        d1 = self._decode_stage(d2, e1, self.up1, self.attention1, self.dec1)

        # Final 1x1x1 projection to out_channels (logits)
        return self.final(d1)

In [29]:
import torch
import os
import SimpleITK as sitk
import numpy as np

# Load the trained model
model_path = "best_model1.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture (VNet is defined in an earlier cell) and load the
# trained weights. The constructor arguments must match those used in training,
# otherwise load_state_dict reports missing/unexpected keys or shape mismatches.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model = VNet(in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()  # Eval mode: disables the bottleneck's Dropout3d

# Paths. The Windows path must be a raw string: in a normal literal "\s" and
# "\c" are invalid escape sequences (a SyntaxWarning on Python 3.12+).
input_folder = r"D:\segm\ct_patch_dataset_for_training"  # Folder containing the .mhd input volumes
output_folder = "output"  # Folder for the predicted masks (same file names as the inputs)
os.makedirs(output_folder, exist_ok=True)

def preprocess_mha(mha_path):
    """Read an image volume with SimpleITK and turn it into a model input.

    Args:
        mha_path: path to any image file SimpleITK can read (e.g. .mhd/.mha).

    Returns:
        (tensor, image): a float32 tensor with batch and channel axes
        prepended (for a 3D image: (1, 1, D, H, W)), min-max scaled to
        roughly [0, 1] and moved to the global ``device``; plus the original
        SimpleITK image, kept so its metadata can be copied onto the output.
    """
    image = sitk.ReadImage(mha_path)
    volume = sitk.GetArrayFromImage(image).astype(np.float32)

    # Per-volume min-max scaling; the epsilon keeps a constant volume from
    # dividing by zero.
    lo = volume.min()
    hi = volume.max()
    volume = (volume - lo) / (hi - lo + 1e-8)

    # Prepend batch and channel axes.
    tensor = torch.tensor(volume)[None, None]
    return tensor.to(device), image

def save_mha(output_tensor, reference_image, save_path):
    """Write a model output tensor to disk as an image aligned with a reference.

    Args:
        output_tensor: tensor whose leading batch and channel axes (if of
            size 1) are squeezed away; it is copied to the CPU before
            conversion.
        reference_image: SimpleITK image whose spacing, origin and direction
            are copied onto the result; its size must match the array.
        save_path: destination file; SimpleITK infers the format from the
            extension.
    """
    squeezed = output_tensor.squeeze(0)
    squeezed = squeezed.squeeze(0)
    volume = squeezed.cpu().numpy()

    # Wrap the array as an image and inherit the reference geometry.
    result = sitk.GetImageFromArray(volume)
    result.CopyInformation(reference_image)

    sitk.WriteImage(result, save_path)
    print(f"Saved: {save_path}")

# Segment every .mhd volume in input_folder and write each binary mask to
# output_folder under the same file name. Files are visited in sorted order
# because os.listdir order is arbitrary; the extension check ignores case.
for filename in sorted(os.listdir(input_folder)):
    if not filename.lower().endswith(".mhd"):
        continue
    input_path = os.path.join(input_folder, filename)
    output_path = os.path.join(output_folder, filename)

    # Preprocess the input
    input_tensor, ref_image = preprocess_mha(input_path)

    # Perform inference: the model returns logits, so apply sigmoid and
    # threshold at 0.5 to get a {0, 1} mask.
    with torch.no_grad():
        logits = model(input_tensor)
        output = (torch.sigmoid(logits) > 0.5).float()

    # Save output with the reference image's metadata
    save_mha(output, ref_image, output_path)



Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.108197895896446896160048741492_0.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.109002525524522225658609808059_0.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.109002525524522225658609808059_1.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.111172165674661221381920536987_0.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.124154461048929153767743874565_0.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.126264578931778258890371755354_0.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.128023902651233986592378348912_0.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.129055977637338639741695800950_0.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.130438550890816550994739120843_0.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.134996872583497382954024478441_0.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.134996872583497382954024478441_1.mhd
Saved: output\1.3.6.1.4.1.14519.5.2.1.6279.6001.13499687258349738