# In [None]:
import tensorflow as tf
import numpy as np

class BiFPN(tf.keras.Model):
    """Two-branch feature-fusion block (a simplified take on BiFPN).

    One branch drops a scale (conv -> 2x max-pool) and the other rises a
    scale (2x upsample -> conv). Each branch is then resampled back to the
    input resolution, so the two can be concatenated on the channel axis.

    Args:
        num_channels: Filters in each branch's convolution. The output has
            ``2 * num_channels`` channels.

    Note:
        Input height and width must be even. The down branch max-pools by 2
        (which floors odd sizes) before upsampling by 2.
    """

    def __init__(self, num_channels):
        super().__init__()

        # 'same' padding stops the 3x3 conv from shrinking the map. The last
        # resampling layer of each branch restores the input resolution.
        # Without this the branches ended at ~1/2x and ~2x the input size
        # and Concatenate raised a shape error.
        # NOTE(review): attribute names are kept as-is (they key saved
        # weights), even though this "topdown" branch is the downsampling one.
        self.topdown = tf.keras.Sequential([
            tf.keras.layers.Conv2D(num_channels, (3, 3), padding='same',
                                   activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.UpSampling2D((2, 2)),
        ])

        self.bottomup = tf.keras.Sequential([
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2D(num_channels, (3, 3), padding='same',
                                   activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
        ])

        # Default axis=-1: concatenates along the channel axis.
        self.concat = tf.keras.layers.Concatenate()

    def call(self, inputs):
        """Return both branches of ``inputs`` concatenated channel-wise.

        Args:
            inputs: Feature map, presumably (batch, height, width, channels)
                in Keras' default channels-last layout. TODO: confirm.

        Returns:
            Tensor with the same spatial size as ``inputs`` and
            ``2 * num_channels`` channels.
        """
        topdown_path = self.topdown(inputs)
        bottomup_path = self.bottomup(inputs)
        return self.concat([topdown_path, bottomup_path])

class Generator(tf.keras.Model):
    """Encoder / BiFPN / decoder network mapping a 256x256x1 image to a
    256x256x1 sigmoid map.

    All convolutions use 'same' padding, so only the pooling and upsampling
    layers change the spatial size:

    - Encoder: 256 -> 128 -> 64 -> 32, with 128 channels at the end.
    - BiFPN: keeps 32x32 and outputs 256 channels.
    - Decoder: 32 -> 64 -> 128 -> 256, ending in one sigmoid channel.

    With the former 'valid' padding the output was 224x224. That did not
    match the 256x256 input the Discriminator and ``test_cgan`` expect.
    """

    def __init__(self):
        super().__init__()

        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu',
                                   input_shape=(256, 256, 1)),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
        ])

        self.bifpn = BiFPN(128)

        self.decoder = tf.keras.Sequential([
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
            tf.keras.layers.UpSampling2D((2, 2)),
            tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
            # One sigmoid channel: per-pixel values in (0, 1).
            tf.keras.layers.Conv2D(1, (3, 3), padding='same', activation='sigmoid'),
        ])

    def call(self, inputs):
        """Run encoder -> BiFPN -> decoder on a batch of (256, 256, 1) images."""
        x = self.encoder(inputs)
        x = self.bifpn(x)
        return self.decoder(x)

class Discriminator(tf.keras.Model):
    """CNN classifier scoring 256x256x1 images with one sigmoid output.

    The encoder uses 'valid' convolutions: 256 -> 254 -> 127 -> 125 -> 62 ->
    60 -> 30, giving 30x30x128 features. These are flattened into a Dense
    head, so the input size must stay 256x256.
    """

    def __init__(self):
        super().__init__()

        # Layers are referenced through tf.keras.layers: the bare Conv2D /
        # MaxPooling2D names were never imported and raised NameError.
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu',
                                   input_shape=(256, 256, 1)),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
        ])

        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        # Single sigmoid unit: the score is in (0, 1).
        self.dense2 = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, inputs):
        """Return a (batch, 1) tensor of sigmoid scores for ``inputs``."""
        x = self.encoder(inputs)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)

def test_cgan(generator, X_test, y_true=None, threshold=0.5, eps=1e-7):
    """Score a generator's binarized predictions with overlap metrics.

    Args:
        generator: Object with a ``predict`` method (e.g. the ``Generator``
            model) that maps ``X_test`` to per-pixel probabilities.
        X_test: Batch of inputs for ``generator.predict``, batch axis first.
        y_true: Binary ground-truth masks with the same shape as the
            predictions. Defaults to ``X_test`` itself, which is what earlier
            versions compared against. Pass real masks when available.
        threshold: Probabilities strictly above this count as foreground.
            The default 0.5 matches the former ``np.round`` binarization for
            values in [0, 1].
        eps: Added to every denominator, so an image with an empty mask or
            an empty prediction contributes 0 instead of NaN.

    Returns:
        Tuple ``(dice, sensitivity, specificity, f1_score)``.

        - ``dice``, ``sensitivity`` and ``specificity`` are per-image scores
          (summed over axes 1 and 2) averaged over the remaining axes.
        - ``f1_score`` is the harmonic mean of the averaged sensitivity and
          averaged precision.
    """
    if y_true is None:
        y_true = X_test
    y_true = np.asarray(y_true, dtype=np.float64)

    # Generate predictions using the generator network and binarize them.
    y_pred = np.asarray(generator.predict(X_test))
    y_pred_binary = (y_pred > threshold).astype(np.float64)

    # Per-image confusion counts over the two spatial axes (batch is axis 0).
    sum_axes = (1, 2)
    true_pos = np.sum(y_pred_binary * y_true, axis=sum_axes)
    pred_count = np.sum(y_pred_binary, axis=sum_axes)
    true_count = np.sum(y_true, axis=sum_axes)
    true_neg = np.sum((1 - y_pred_binary) * (1 - y_true), axis=sum_axes)
    neg_count = np.sum(1 - y_true, axis=sum_axes)

    dice = np.mean(2 * true_pos / (pred_count + true_count + eps))
    sensitivity = np.mean(true_pos / (true_count + eps))
    specificity = np.mean(true_neg / (neg_count + eps))
    # Precision was used for F1 but never computed (NameError before).
    precision = np.mean(true_pos / (pred_count + eps))

    f1_score = 2 * sensitivity * precision / (sensitivity + precision + eps)

    return dice, sensitivity, specificity, f1_score


# The name starts with ``test_``, but this is an evaluation helper that needs
# real arguments. Stop pytest from collecting it as a test.
test_cgan.__test__ = False



# In [None]:
# Define the generator and discriminator models
generator = Generator()
discriminator = Discriminator()

# Compile the models. The loop below trains the generator through `gan`, not
# through this compile; it is kept in case other cells rely on it.
generator.compile(loss='binary_crossentropy', optimizer='adam')
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

# The generator outputs images, not real/fake scores, so it cannot be fit
# directly against (batch, 1) labels. Instead, train it through a stacked
# generator -> discriminator model with the discriminator frozen.
# In tf.keras a model's `trainable` state is fixed when it is compiled. So:
#   - the discriminator compiled above still learns in its own
#     train_on_batch calls;
#   - `gan` only updates the generator's weights.
discriminator.trainable = False
gan = tf.keras.Sequential([generator, discriminator])
gan.compile(loss='binary_crossentropy', optimizer='adam')

# Train the model. Each iteration trains on single batches, not a full pass
# over the data.
# NOTE(review): assumes `train_dataset` / `test_dataset` (defined in another
# cell) are iterators whose `.next()` returns an image batch shaped like the
# models' (N, 256, 256, 1) input. Confirm.
for epoch in range(100):
    # Generate a batch of synthetic images
    synthetic_images = generator.predict(train_dataset.next())

    # Train the discriminator on real (label 1) and synthetic (label 0)
    # images. Labels are sized from each batch, since a batch may hold
    # fewer than 32 images.
    real_images = train_dataset.next()
    discriminator.train_on_batch(real_images, np.ones((len(real_images), 1)))
    discriminator.train_on_batch(
        synthetic_images, np.zeros((len(synthetic_images), 1)))

    # Train the generator: push the frozen discriminator towards "real".
    generator_inputs = train_dataset.next()
    gan.train_on_batch(generator_inputs, np.ones((len(generator_inputs), 1)))

    # Save the model weights
    generator.save_weights('generator.h5')
    discriminator.save_weights('discriminator.h5')

# Test the model
generator.load_weights('generator.h5')
discriminator.load_weights('discriminator.h5')

# Generate synthetic images from the test dataset
synthetic_images = generator.predict(test_dataset.next())

# Evaluate the synthetic images using the discriminator
discriminator_scores = discriminator.predict(synthetic_images)

# Fraction of synthetic images the discriminator scores as real (> 0.5),
# i.e. how often the generator fools it. This is not the discriminator's
# accuracy on these fakes; that would be the fraction scored <= 0.5.
# The variable name is kept for any later cells that read it.
accuracy = np.mean(discriminator_scores > 0.5)

print('Fraction of synthetic images classified as real:', accuracy)
