# Begin training model

### Importing necessary libraries
- Taking as example/guide 3D-Organ repository -> https://github.com/lmcanavals/3D-ORGAN/tree/master

In [None]:
!python -m pip install tensorflow keras keras-complex

### Generator and Discriminator

In [None]:
from keras.models import Model
from keras.layers import Input, Conv2D, Conv2DTranspose, Flatten, Dense, Reshape, LeakyReLU, Concatenate
from keras.optimizers import Adam
import os
import numpy as np


In [42]:
from keras.models import Model
from keras.layers import Input, Dense, Reshape, Flatten, Activation, Conv3DTranspose, Conv3D
from keras.layers import Embedding, Lambda, Concatenate, Add, BatchNormalization
from keras.layers import GlobalAvgPool3D, Multiply, LeakyReLU

import keras.backend as K

def dense_layer(inp, f, act='relu', bn=True):
    initializer = act if act is not None else ''
    initializer = 'he_uniform' if initializer.find('relu') != -1 else 'glorot_uniform'
    out = Dense(f, use_bias=False, kernel_initializer=initializer)(inp)
    if bn: out = BatchNormalization()(out)
    
    if act == 'lrelu':
        out = LeakyReLU(alpha=0.2)(out)
    elif act is not None:
        out = Activation(act)(out)
    
    return out

def conv_layer(inp, f, k=4, s=2, p='same', act='relu', bn=True, transpose=False,
               se=False, se_ratio=16):
    initializer = act if act is not None else ''
    initializer = 'he_uniform' if initializer.find('relu') != -1 else 'glorot_uniform'
    fun = Conv3DTranspose if transpose else Conv3D
    out = fun(f, k, strides=s, padding=p, use_bias=False, kernel_initializer=initializer)(inp)
    if bn: out = BatchNormalization()(out)
    
    if act == 'lrelu':
        out = LeakyReLU(alpha=0.2)(out)
    elif act is not None:
        out = Activation(act)(out)

    # squeeze and excite
    if se:
        out_se = GlobalAvgPool3D()(out)
        r = f // se_ratio if (f // se_ratio) > 0 else 1
        out_se = Reshape((1, 1, f))(out_se)
        out_se = Dense(r, use_bias=False, kernel_initializer='he_uniform',
                       activation='relu')(out_se)
        out_se = Dense(f, use_bias=False, activation='sigmoid')(out_se)
        out = Multiply()([out, out_se])
    
    return out

def generator():
    input_layer = Input(shape=(700, 500, 3))  # Adjust dimensions based on your 2D images

    x = dense_layer(input_layer, 128)
    x = Reshape((700, 500, 128))(x)

    x = conv_layer(x, 64, transpose=True)
    x = conv_layer(x, 128, transpose=True)
    x = conv_layer(x, 256, transpose=True)

    output_layer = Conv3DTranspose(1, kernel_size=(3, 3, 3), strides=(2, 2, 2), activation='sigmoid', padding='same')(x)

    model = Model(inputs=input_layer, outputs=output_layer)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

def discriminator():
    input_2d = Input(shape=(700, 500, 3))
    input_3d = Input(shape=(32, 32, 32, 3))

    x = conv_layer(input_2d, 64)
    x = conv_layer(x, 128)
    x = conv_layer(x, 256)

    x = Flatten()(x)
    x = dense_layer(x, 1, act='sigmoid', bn=False)

    model = Model(inputs=[input_2d, input_3d], outputs=x)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model


In [44]:
# Function to load 2D image data from .npy files
def load_2d_data(folder_path):
    data = []
    for file_name in os.listdir(folder_path):
        if file_name.endswith(".npy"):
            file_path = os.path.join(folder_path, file_name)
            image = np.load(file_path)
            data.append(image)
    return np.array(data)

# Load 2D data
input_2d_data = load_2d_data('/home/dele/Documents/Machine_learning/TF/Train2D_Data/')

# Load 3D voxel data
input_3d_data = np.load('/home/dele/Documents/Machine_learning/TF/voxel_data.npy',allow_pickle=True)
# Combined model
input_2d = Input(shape=(700, 500, 3))
input_3d = Input(shape=(32, 32, 32, 3))

generated_3d = generator(input_2d)
validity = discriminator([input_2d, generated_3d])

# Compile the combined model
combined_model = Model([input_2d, input_3d], validity)
combined_model.compile(loss=['binary_crossentropy'], optimizer='adam', metrics=['accuracy'])

# Labels (assuming you have your labels ready)
labels = np.ones((len(input_2d_data), 1))

# Train the combined model
combined_model.fit([input_2d_data, input_3d_data], labels, epochs=5, batch_size=128)

TypeError: generator() takes 0 positional arguments but 1 was given