In [134]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam

In [135]:
#Get the iris dataset
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target

In [136]:
le = LabelEncoder()
y = le.fit_transform(y)

min_max = MinMaxScaler()
X = min_max.fit_transform(X)

In [137]:
def create_generator(input_dim):
    model = Sequential()
    model.add(Dense(10, input_dim=input_dim, activation='relu'))
    model.add(Dense(4, activation='sigmoid'))
    return model

def create_discriminator(input_dim):
    model = Sequential()
    model.add(Dense(10, input_dim=input_dim, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    return model


In [138]:
#Create the generator and discriminator
discriminator = create_discriminator(4)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])

generator = create_generator(4)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [139]:
#create the gan
discriminator.trainable=False
gan_input = Sequential()
gan_input.add(generator)
gan_input.add(discriminator)
gan_input.compile(loss='binary_crossentropy', optimizer=Adam())

In [140]:
# 8. Train the GAN
step_list = []
loss_list_discriminator = []
loss_list_generator = []
def train_gan(epochs, batch_size):

    discriminator = create_discriminator(4)
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])

    generator = create_generator(4)
    
    for e in range(epochs):
        # Generate random noise
        noise = np.random.rand(batch_size, 4)
        generated_data = generator.predict(noise)
        
        # Get a random batch of real data
        idx = np.random.randint(0, X.shape[0], batch_size)
        real_data = X[idx]
        
        # Combine real and generated data
        combined_data = np.concatenate([real_data, generated_data])
        labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])
        
        # Train the discriminator using fit()
        discriminator.trainable = True
        
        # Ensure that the fit() method is used correctly
        d_loss = discriminator.fit(combined_data, labels, epochs=1, batch_size=batch_size, verbose=0)

        # Train the generator
        noise = np.random.rand(batch_size, 4)
        labels_gan = np.ones((batch_size, 1))  # Labels for generated data
        discriminator.trainable = False
        
        # Ensure that the fit() method for the GAN is also called correctly
        g_loss = gan_input.fit(noise, labels_gan, epochs=1, batch_size=batch_size, verbose=0)

        step_list.append(e)
        loss_list_discriminator.append(d_loss)
        loss_list_generator.append(g_loss)

        if e % 100 == 0:
            print(f'Epoch {e}, Discriminator Loss: {d_loss.history["loss"][0]}, Generator Loss: {g_loss.history["loss"][0]}')


In [141]:
#Generate Samples
def generate_samples(n_samples):
    noise = np.random.rand(n_samples, 4)
    generated_samples = generator.predict(noise)
    return min_max.inverse_transform(generated_samples)

In [142]:
train_gan(epochs=1000, batch_size=32)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 62ms/step
Epoch 0, Discriminator Loss: 0.7071341276168823, Generator Loss: 0.7168493270874023
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m 

In [146]:
generate_samples(50)

[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step


array([[4.311921 , 4.386488 , 4.112524 , 2.4954512],
       [4.3099804, 4.3859286, 4.1608696, 2.4964316],
       [4.3054795, 4.3916326, 4.426365 , 2.498314 ],
       [4.3023033, 4.396476 , 4.362448 , 2.4992611],
       [4.304157 , 4.3936396, 4.364174 , 2.49879  ],
       [4.3200436, 4.377276 , 4.143878 , 2.4929054],
       [4.305508 , 4.3915663, 4.261358 , 2.4981623],
       [4.3223996, 4.370605 , 4.204626 , 2.4933655],
       [4.3088355, 4.3889875, 4.1449056, 2.4966533],
       [4.3041215, 4.393437 , 4.3791456, 2.4987264],
       [4.3044376, 4.3932004, 4.5352883, 2.4987125],
       [4.3060894, 4.3920074, 4.3363185, 2.498043 ],
       [4.3065214, 4.3905168, 4.25206  , 2.4979317],
       [4.3017464, 4.397184 , 4.4157667, 2.4995096],
       [4.3075485, 4.3888583, 4.4638095, 2.497653 ],
       [4.3118014, 4.3833814, 4.2000623, 2.4963124],
       [4.30321  , 4.3957796, 4.358423 , 2.4990394],
       [4.330431 , 4.364387 , 4.154143 , 2.4895425],
       [4.3054175, 4.3921666, 4.3526344, 2.498

In [149]:
print(X)
print(y)

[[0.22222222 0.625      0.06779661 0.04166667]
 [0.16666667 0.41666667 0.06779661 0.04166667]
 [0.11111111 0.5        0.05084746 0.04166667]
 [0.08333333 0.45833333 0.08474576 0.04166667]
 [0.19444444 0.66666667 0.06779661 0.04166667]
 [0.30555556 0.79166667 0.11864407 0.125     ]
 [0.08333333 0.58333333 0.06779661 0.08333333]
 [0.19444444 0.58333333 0.08474576 0.04166667]
 [0.02777778 0.375      0.06779661 0.04166667]
 [0.16666667 0.45833333 0.08474576 0.        ]
 [0.30555556 0.70833333 0.08474576 0.04166667]
 [0.13888889 0.58333333 0.10169492 0.04166667]
 [0.13888889 0.41666667 0.06779661 0.        ]
 [0.         0.41666667 0.01694915 0.        ]
 [0.41666667 0.83333333 0.03389831 0.04166667]
 [0.38888889 1.         0.08474576 0.125     ]
 [0.30555556 0.79166667 0.05084746 0.125     ]
 [0.22222222 0.625      0.06779661 0.08333333]
 [0.38888889 0.75       0.11864407 0.08333333]
 [0.22222222 0.75       0.08474576 0.08333333]
 [0.30555556 0.58333333 0.11864407 0.04166667]
 [0.22222222 