<a href="https://colab.research.google.com/github/DoanJ7313/Python-Tutorials-Semnani-2024/blob/main/2025/3D_GAN_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torchsummary import summary
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from scipy.ndimage import map_coordinates
"""
Implementation based on original paper NeurIPS 2016 https://papers.nips.cc/paper/6096-learning-a-probabilistic-latent-space-of-object-shapes-via-3d-generative-adversarial-modeling.pdf
"""


'\nImplementation based on original paper NeurIPS 2016 https://papers.nips.cc/paper/6096-learning-a-probabilistic-latent-space-of-object-shapes-via-3d-generative-adversarial-modeling.pdf\n'

## Discriminator

In [2]:
class Discriminator(torch.nn.Module):
    def __init__(self, in_channels=3, dim=64, out_conv_channels=512):
        super(Discriminator, self).__init__()
        conv1_channels = int(out_conv_channels / 8)
        conv2_channels = int(out_conv_channels / 4)
        conv3_channels = int(out_conv_channels / 2)
        self.out_conv_channels = out_conv_channels
        self.out_dim = int(dim / 16)

        self.conv1 = nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels, out_channels=conv1_channels, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(conv1_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv3d(
                in_channels=conv1_channels, out_channels=conv2_channels, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(conv2_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv3d(
                in_channels=conv2_channels, out_channels=conv3_channels, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(conv3_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.Conv3d(
                in_channels=conv3_channels, out_channels=out_conv_channels, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(out_conv_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.out = nn.Sequential(
            nn.Linear(out_conv_channels * self.out_dim * self.out_dim * self.out_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten and apply linear + sigmoid
        x = x.view(-1, self.out_conv_channels * self.out_dim * self.out_dim * self.out_dim)
        x = self.out(x)
        return x


## Generator

In [3]:
class Generator(torch.nn.Module):
    def __init__(self, in_channels=512, out_dim=64, out_channels=1, noise_dim=200, activation="sigmoid"):
        super(Generator, self).__init__()
        self.in_channels = in_channels
        self.out_dim = out_dim
        self.in_dim = int(out_dim / 16)
        conv1_out_channels = int(self.in_channels / 2.0)
        conv2_out_channels = int(conv1_out_channels / 2)
        conv3_out_channels = int(conv2_out_channels / 2)

        self.linear = torch.nn.Linear(noise_dim, in_channels * self.in_dim * self.in_dim * self.in_dim)

        self.conv1 = nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=in_channels, out_channels=conv1_out_channels, kernel_size=(4, 4, 4),
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(conv1_out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=conv1_out_channels, out_channels=conv2_out_channels, kernel_size=(4, 4, 4),
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(conv2_out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=conv2_out_channels, out_channels=conv3_out_channels, kernel_size=(4, 4, 4),
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm3d(conv3_out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.ConvTranspose3d(
                in_channels=conv3_out_channels, out_channels=out_channels, kernel_size=(4, 4, 4),
                stride=2, padding=1, bias=False
            )
        )
        if activation == "sigmoid":
            self.out = torch.nn.Sigmoid()
        else:
            self.out = torch.nn.Tanh()

    def project(self, x):
        """
        projects and reshapes latent vector to starting volume
        :param x: latent vector
        :return: starting volume
        """
        return x.view(-1, self.in_channels, self.in_dim, self.in_dim, self.in_dim)

    def forward(self, x):
        x = self.linear(x)
        x = self.project(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return self.out(x)


## Test

In [8]:
def test_gan3d(print_summary=True):
    noise_dim = 200 # latent space vector dim
    in_channels = 512 # convolutional channels
    dim = 64  # cube volume
    model_generator = Generator(in_channels=512, out_dim=dim, out_channels=1, noise_dim=noise_dim)
    noise = torch.rand(1, noise_dim)
    generated_volume = model_generator(noise)
    print("Generator output shape", generated_volume.shape)
    model_discriminator = Discriminator(in_channels=1, dim=dim, out_conv_channels=in_channels)
    out = model_discriminator(generated_volume)
    print("Discriminator output", out.item())
    if print_summary:
      print("\n\nGenerator summary\n\n")
      summary(model_generator, (1, noise_dim))
      print("\n\nDiscriminator summary\n\n")
      summary(model_discriminator, (1,dim,dim,dim))
    return generated_volume

generated_volume = test_gan3d()

Generator output shape torch.Size([1, 1, 64, 64, 64])
Discriminator output 0.4900020658969879


Generator summary




RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [None]:
def get_oblique_slice(volume, angle_deg=45, plane='xy'):
    angle_rad = np.deg2rad(angle_deg)
    size = volume.shape[0]
    coords = np.linspace(-size // 2, size // 2, size)

    # Build coordinate grid for oblique slice
    if plane == 'xy':
        x = coords
        y = coords
        xx, yy = np.meshgrid(x, y)
        # Rotate coordinates
        x_rot = xx * np.cos(angle_rad) - yy * np.sin(angle_rad)
        y_rot = xx * np.sin(angle_rad) + yy * np.cos(angle_rad)
        z = np.zeros_like(x_rot)
    elif plane == 'xz':
        x = coords
        z = coords
        xx, zz = np.meshgrid(x, z)
        x_rot = xx * np.cos(angle_rad) - zz * np.sin(angle_rad)
        z_rot = xx * np.sin(angle_rad) + zz * np.cos(angle_rad)
        y = np.zeros_like(x_rot)
        y_rot = y
    elif plane == 'yz':
        y = coords
        z = coords
        yy, zz = np.meshgrid(y, z)
        y_rot = yy * np.cos(angle_rad) - zz * np.sin(angle_rad)
        z_rot = yy * np.sin(angle_rad) + zz * np.cos(angle_rad)
        x = np.zeros_like(y_rot)
        x_rot = x
    else:
        raise ValueError("Invalid plane: choose from 'xy', 'xz', 'yz'")

    # Shift coordinates back into valid volume index space
    center = np.array(volume.shape) // 2
    sample_coords = np.vstack([
        (x_rot + center[0]).flatten(),
        (y_rot + center[1]).flatten(),
        (z_rot + center[2]).flatten()
    ])

    oblique_slice = map_coordinates(volume, sample_coords, order=1, mode='nearest')
    return oblique_slice.reshape(size, size)

def show_slices(volume):
    volume = volume.squeeze().detach().cpu().numpy()
    mid = volume.shape[0] // 2

    fig, axes = plt.subplots(1, 4, figsize=(16, 4))

    axes[0].imshow(volume[mid, :, :], cmap="gray")
    axes[0].set_title("Axial (XY)")

    axes[1].imshow(volume[:, mid, :], cmap="gray")
    axes[1].set_title("Coronal (XZ)")

    axes[2].imshow(volume[:, :, mid], cmap="gray")
    axes[2].set_title("Sagittal (YZ)")

    oblique = get_oblique_slice(volume, angle_deg=45, plane='xy')
    axes[3].imshow(oblique, cmap="gray")
    axes[3].set_title("Oblique 45° XY")

    plt.tight_layout()
    plt.show()


def plot_voxels(volume, threshold=0.5):
    volume = volume.squeeze().detach().cpu().numpy()
    filled = volume > threshold
    x, y, z = np.where(filled)
    fig = go.Figure(data=go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(
            size=2,
            color=volume[filled],  # Optional: grayscale intensity
            colorscale='Viridis',
            opacity=0.7
        )
    ))
    fig.update_layout(
        scene=dict(aspectmode='data'),
        title="3D Voxel Plot"
    )
    fig.show()

In [None]:
show_slices(generated_volume)

In [None]:
plot_voxels(generated_volume)