In [4]:
import tensorflow as tf
import matplotlib.pyplot as plt
import os
from keras.layers import Conv3D,Conv2D,Dense,Flatten,Dropout,LeakyReLU,Conv2DTranspose,BatchNormalization,Reshape,Embedding,Concatenate,Input,ReLU
from keras.losses import SparseCategoricalCrossentropy,BinaryCrossentropy
from keras.utils import plot_model
from keras.optimizers import Adam,SGD,RMSprop
from pathlib import Path
import keras
import numpy as np

# Environment check: devices visible to TensorFlow (CPU/GPU) and the installed TF version.
(tf.config.list_physical_devices(), tf.__version__)

([PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')],
 '2.16.2')

In [None]:
def preprocess(x, y):
    """Shift pixel values from [0, 1] to [-1, 1]; pass the label ``y`` through unchanged.

    Assumes ``x`` has already been scaled to [0, 1] (e.g. by ``rescale``); the
    [-1, 1] range matches the generator's tanh output.
    """
    return (x - 0.5) * 2, y


def rescale(x, y):
    """Divide pixel values by 255 (assumes 8-bit [0, 255] input); pass ``y`` through unchanged."""
    return x / 255.0, y


def display_model(x, y):
    """Plot model ``x`` with shapes, layer names and activations to ``{y}.png``.

    Args:
        x: The Keras model to draw.
        y: Output file stem; the image is written to ``f'{y}.png'``.

    Returns:
        Whatever ``plot_model`` returns.
    """
    return plot_model(x, show_shapes=True, show_layer_names=True,
                      to_file=f'{y}.png', show_layer_activations=True)

In [6]:
no_of_class: int = 10  # number of distinct class labels (Embedding vocabulary size / classifier outputs)

In [None]:
# Generator configuration:
#   latent_dim  -> (height, width, channels) the merged noise+label vector is reshaped
#                  into before upsampling (7 * 7 * 512 = 25088 values).
#   noise_shape -> length of the 1-D noise vector fed to the generator's noise input.
latent_dim, noise_shape = np.array([7, 7, 512]), 512

# base upsampling/convo Transpose layers
class upsample_class():
    """Callable builder for the generator's upsampling blocks.

    Each call appends ``Conv2DTranspose -> [BatchNormalization] -> [ReLU]`` to
    ``input`` and suffixes the layer names (``conv_i``, ``norm_i``, ``relu_i``)
    with a per-instance counter. Use a single instance per model: a second
    instance restarts at 1 and would reuse the same layer names.
    """

    def __init__(self) -> None:
        self.count = 0  # number of blocks built so far; used as the layer-name suffix

    def __call__(self,filters:int,use_norm:bool,use_drop:bool,input:keras.Layer,
                 kernel_size=(5,5),strides=2) -> keras.Layer:
        """Apply one upsampling block to ``input`` and return the output tensor.

        Args:
            filters: Number of output channels of the transposed convolution.
            use_norm: If True, add a BatchNormalization after the convolution.
            use_drop: Despite its name, this adds a **ReLU** activation (no
                dropout is applied). The name is kept so existing keyword
                callers keep working.
            input: Keras tensor to upsample.
            kernel_size: Transposed-convolution kernel size (default ``(5, 5)``).
            strides: Transposed-convolution stride (default 2; with
                ``padding='same'`` this multiplies height and width by 2).

        Returns:
            The resulting Keras tensor.
        """
        self.count += 1

        # use_bias=False, presumably because a following BatchNorm supplies the offset.
        x = Conv2DTranspose(filters,kernel_size,strides=strides,padding='same',
                            name=f'conv_{self.count}',use_bias=False)(input)
        if use_norm:
            x = BatchNormalization(name=f'norm_{self.count}')(x)
        if use_drop:  # historical flag name: enables the ReLU activation
            x = ReLU(name=f'relu_{self.count}')(x)
        return x
    

# make generator (Input -> noise,label : Output -> img)
def make_generator(no_of_class:int,noise_shape:int,latent_dim:tuple[int]) -> keras.Model:
    """Build the conditional (AC-GAN style) generator.

    The label is embedded and projected, the noise vector is projected, and the
    two projections are concatenated and reshaped to ``latent_dim``. The result
    is upsampled by two stride-2 blocks and projected to one channel with tanh.

    Args:
        no_of_class: Number of distinct class labels (Embedding vocabulary size).
        noise_shape: Length of the 1-D noise vector.
        latent_dim: (height, width, channels) of the initial feature map, given
            as a tuple or 1-D array of ints.

    Returns:
        An uncompiled ``keras.Model`` named 'AC_Generator'. Its inputs are
        ``[noise_input, label_input]`` and its output is an image of shape
        ``(4 * latent_dim[0], 4 * latent_dim[1], 1)`` with values in [-1, 1].
    """
    # Keras layers expect Python ints / tuples, not numpy scalars or arrays.
    spatial = tuple(int(d) for d in latent_dim)
    total_units = int(np.prod(spatial))
    label_units = total_units // 2
    # Give the remainder to the noise branch so the concatenation always holds
    # exactly total_units values (an odd total previously broke the Reshape).
    noise_units = total_units - label_units

    label_input = Input(shape=(1,),name='label_input')
    l = Embedding(no_of_class,128,name='label_embedding')(label_input)
    l = Dense(label_units,name='label_dense')(l)
    l = Flatten(name='label_flatten')(l)  # drop the length-1 sequence axis from the Embedding

    noise_input = Input(shape=(noise_shape,),name='noise_input')
    n = Dense(noise_units,name='noise_dense')(noise_input)
    n = Flatten(name='noise_flatten')(n)  # already 2-D; kept so layer names stay unchanged

    merge = Concatenate(name='concatenate')([n,l])
    x = Reshape(spatial,name='Merged_label')(merge)
    x = BatchNormalization(name='Merged_norm')(x)

    upsample_block = upsample_class()

    # In upsample_class the `use_drop` flag adds the ReLU. With it off, the whole
    # generator had no nonlinearity before the final tanh.
    x = upsample_block(filters=128,use_norm=True,use_drop=True,input=x)
    x = upsample_block(filters=64,use_norm=True,use_drop=True,input=x)

    # tanh is deliberately the LAST op so outputs stay within [-1, 1]. A
    # BatchNormalization after it (formerly 'final_norm') rescaled and shifted
    # the outputs out of that range.
    x = Conv2DTranspose(1,(3,3),strides=1,padding='same', name='final',activation='tanh')(x)

    generator = keras.Model([noise_input,label_input],x,name='AC_Generator')

    return generator

# Build a throwaway generator purely to print its layer summary.
make_generator(
    no_of_class=no_of_class,
    noise_shape=noise_shape,
    latent_dim=latent_dim,
).summary()

In [12]:
class discriminator_class_3d():
    """Callable builder for the 3-D discriminator's downsampling blocks.

    Every call stacks ``Conv3D (stride 2) -> LeakyReLU -> Dropout(0.1)`` on the
    given tensor. Layer names carry a per-instance counter
    (``Convo_i``, ``Leaky_i``, ``Dropout_i``).
    """

    def __init__(self) -> None:
        self.count = 0  # blocks built so far; doubles as the layer-name suffix

    def __call__(self, filters: int, input: tf.keras.layers.Layer) -> tf.keras.layers.Layer:
        """Apply one block with ``filters`` Conv3D channels to ``input`` and return the output tensor."""
        self.count += 1
        idx = self.count
        stages = (
            Conv3D(filters=filters, kernel_size=(5, 5, 4), strides=2, padding='same', name=f'Convo_{idx}'),
            LeakyReLU(name=f'Leaky_{idx}'),
            Dropout(0.1, name=f'Dropout_{idx}'),
        )
        out = input
        for stage in stages:
            out = stage(out)
        return out

# Make the discriminator (Input-> img, label : Output -> class, real/fake)
def make_discriminator_3d(img: tuple[int], no_of_class: int) -> keras.Model:
    """Build the conditional 3-D discriminator with a class head and a critic head.

    The class label is embedded and projected to one value per voxel. That
    projection is stacked with the single-channel input volume as a second
    channel and passed through three stride-2 Conv3D blocks.

    Args:
        img: The three spatial dimensions of the input volume, e.g. ``(64, 64, 64)``.
        no_of_class: Number of distinct class labels.

    Returns:
        An uncompiled ``keras.Model`` named 'AC_Discriminator_3D'. Its inputs are
        ``[image_input, label_input]`` and its outputs are
        ``[class_probabilities (softmax over no_of_class), real_score (linear)]``.

    Raises:
        ValueError: If ``img`` does not have exactly three dimensions.
    """
    if len(img) != 3:
        raise ValueError(f'img must have exactly 3 spatial dimensions; got {tuple(img)!r}')
    spatial = tuple(int(d) for d in img)  # Keras expects Python ints, not numpy scalars
    volume_shape = spatial + (1,)  # single-channel volume

    label_input = Input(shape=(1,), name='label_input')

    # NOTE(review): for IMG_SIZE=(64, 64, 64) this Dense maps 128 -> 262144 units
    # (~33.5M weights). Consider a smaller projection followed by upsampling.
    l = Embedding(no_of_class, 128, name='label_embedding')(label_input)
    l = Dense(int(np.prod(spatial)), name='label_dense')(l)
    l = Reshape(volume_shape, name='label_reshape')(l)

    img_input = Input(shape=volume_shape, name='image_input')

    merge = Concatenate(name='concatenate')([img_input, l])  # last (channel) axis -> 2 channels

    discriminator_block = discriminator_class_3d()

    x = discriminator_block(64, merge)
    x = discriminator_block(256, x)
    x = discriminator_block(512, x)

    x = Flatten(name='flatten')(x)

    # Classes are mutually exclusive, so the auxiliary classifier uses softmax
    # (a distribution over classes) rather than independent sigmoids.
    label_based = Dense(no_of_class, activation='softmax', name='label_predict')(x)
    real_based = Dense(1, name='real_predict')(x)  # WGAN critic score: no activation

    # `Model` was never imported in this notebook; use the keras namespace explicitly.
    discriminator = keras.Model([img_input, label_input], [label_based, real_based], name='AC_Discriminator_3D')

    return discriminator

IMG_SIZE = (64, 64, 64)  # spatial shape of the single-channel 3-D input volume
# `no_of_class` comes from the configuration cell above. It is deliberately not
# re-assigned here, so the generator and discriminator cannot silently disagree
# on the number of classes.
make_discriminator_3d(IMG_SIZE, no_of_class).summary()
