In [1]:
# TF_CPP_MIN_LOG_LEVEL has to be in the environment *before* TensorFlow is
# imported; setting it afterwards does not silence the native start-up logs.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np 
import matplotlib.pyplot as plt 
from tqdm import tqdm
import cv2
import seaborn as sns
import tensorflow as tf

from keras.models import Sequential, Model
from keras.layers import Dense, Flatten, Conv2D, Reshape, Input, Conv2DTranspose
from keras.layers import Activation, LeakyReLU, BatchNormalization, Dropout, Resizing
from keras.losses import BinaryCrossentropy
from tensorflow.keras.applications import VGG16

import warnings
warnings.filterwarnings('ignore')

# Prefer the legacy Adam implementation when this TensorFlow build ships it,
# otherwise fall back to the regular Keras optimizer.
try:
    from tensorflow.keras.optimizers.legacy import Adam
except ImportError:
    from keras.optimizers import Adam

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


In [2]:
# --- Global hyper-parameters for the first experiment ---
NOISE_DIM = 150          # length of the latent noise vector
BATCH_SIZE = 16
STEPS_PER_EPOCH = 1000
EPOCHS = 10
SEED = 40

# Image geometry
WIDTH = 128
HEIGHT = 128
CHANNELS = 3

# Adam with learning rate 2e-4 and beta_1 = 0.5
OPTIMIZER = Adam(learning_rate=0.0002, beta_1=0.5)


In [3]:
# Earlier dataset location, kept for reference:
#MAIN_DIR = "../input/brain-mri-images-for-brain-tumor-detection/yes"
# Root folder of the dataset used by this notebook.
MAIN_DIR = "../input/pc-data-dataset-gen"

In [4]:
# Loading and Preprocessing the Images
def load_images(folder, size=(128, 128)):
    """Load every decodable image in ``folder`` as a resized grayscale array.

    Parameters
    ----------
    folder : str
        Directory whose direct entries are tried as image files.
    size : tuple of int, optional
        ``(width, height)`` handed to ``cv2.resize``. Defaults to ``(128, 128)``.

    Returns
    -------
    imgs : numpy.ndarray
        The grayscale images that could be decoded, stacked along the first
        axis; an empty array when nothing could be read.
    labels : numpy.ndarray
        One label (always ``1``) per loaded image.

    Notes
    -----
    Entries OpenCV cannot decode (sub-directories, archives, corrupt files)
    are skipped. Only ``cv2.error`` is swallowed; any other exception
    propagates instead of being silently ignored.
    """
    target = 1
    imgs = []

    # Sorted so the result does not depend on the file system's listing order.
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports an unreadable file by returning None.
            continue
        try:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = cv2.resize(img, size)
        except cv2.error:
            continue
        imgs.append(img)

    imgs = np.array(imgs)
    labels = np.full(len(imgs), target)

    return imgs, labels


In [5]:
import zipfile

In [6]:
# Try to read images directly from MAIN_DIR and show the resulting shapes.
# NOTE(review): the recorded output is ((0,), (0,)), i.e. nothing was loaded --
# presumably because MAIN_DIR only holds the train/test zip archives that a
# later cell extracts. Confirm whether this cell is still needed.
data, labels = load_images(MAIN_DIR)
data.shape, labels.shape

((0,), (0,))

In [7]:
from keras.preprocessing.image import ImageDataGenerator

In [8]:
# Unpack both dataset archives into the current working directory
# (test first, then train).
for archive_name in ("test.zip", "train.zip"):
    with zipfile.ZipFile("../input/pc-data-dataset-gen/" + archive_name, "r") as z:
        z.extractall(".")


In [9]:
# Augmentation / preprocessing options shared by the train and test iterators.
augmentation_options = dict(
    width_shift_range=0.1,
    height_shift_range=0.1,
    rescale=1/255,          # pixel values are multiplied by 1/255
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode="nearest",
)
image_gen = ImageDataGenerator(**augmentation_options)


In [10]:
batch_size = 32

# Settings common to both directory iterators.
flow_options = dict(
    target_size=(128, 128),
    batch_size=batch_size,
    class_mode="categorical",
)

# Test iterator first, then train (same order as the recorded "Found ..." output).
test_image_gen = image_gen.flow_from_directory("./test", **flow_options)
train_image_gen = image_gen.flow_from_directory("./train", **flow_options)


Found 1051 images belonging to 10 classes.
Found 8412 images belonging to 10 classes.


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Preview ten test images in a 2 x 5 grid: the first image of each freshly
# drawn batch, titled with its class name.
class_labels = [class_name for class_name in test_image_gen.class_indices.keys()]

plt.figure(figsize=(20, 8))
for i in range(10):
    axs = plt.subplot(2, 5, i + 1)
    batch = test_image_gen.next()
    batch_images, batch_one_hot = batch[0], batch[1]
    image = batch_images[0]                # first image of the batch
    label = np.argmax(batch_one_hot[0])    # index of its one-hot class label
    plt.imshow(image, cmap="gray")
    plt.title(class_labels[label])
    plt.axis('off')
    axs.set_xticklabels([])
    axs.set_yticklabels([])
    plt.subplots_adjust(wspace=None, hspace=None)
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Preview ten *training* images in a 2 x 5 grid: the first image of each
# freshly drawn batch, titled with its class name.
class_labels = list(train_image_gen.class_indices.keys())

plt.figure(figsize=(20, 8))
for i in range(10):
    axs = plt.subplot(2, 5, i + 1)
    # Draw from the training iterator so the images match the class list
    # above (this cell previously sampled the test iterator by mistake).
    batch = train_image_gen.next()
    image = batch[0][0]  # Get the first image in the batch
    label = np.argmax(batch[1][0])  # Get the class label index
    plt.imshow(image, cmap="gray")
    plt.title(class_labels[label])
    plt.axis('off')
    axs.set_xticklabels([])
    axs.set_yticklabels([])
plt.tight_layout()
plt.show()


In [None]:
# Report the per-image shape of each iterator (test first, then train);
# image_shape is left holding the training iterator's value.
for split_gen in (test_image_gen, train_image_gen):
    image_shape = split_gen.image_shape
    print("Image shape:", image_shape)

In [None]:
# Alias the directory iterators under the X_* names.
# NOTE: later cells rebind X_test / X_train to image arrays drawn from these
# iterators, so these names do not stay iterators for long.
X_test = test_image_gen
X_train = train_image_gen


In [None]:
## Generate 20 random numbers to index images from data

In [None]:
# Seed NumPy's global RNG, then draw 20 integer indices from [0, 155).
np.random.seed(SEED)
idxs = np.random.randint(low=0, high=155, size=20)
print(idxs)

In [None]:
# Each extracted split should list its class sub-folders.
for split_dir in ('./train', './test'):
    print(os.listdir(split_dir))


In [None]:
# Number of classes each iterator discovered.
for split_name, split_gen in (("train", train_image_gen), ("test", test_image_gen)):
    print(f"Number of classes in {split_name} set: {split_gen.num_classes}")


In [None]:
# Pull one batch from the training iterator and report the shape of its image array.
batch = train_image_gen.next()
batch_images = batch[0]
print("Batch shape:", batch_images.shape)  # expected: (batch_size, 128, 128, 3)


In [None]:
batch = test_image_gen.next()
print("Batch shape:", batch[0].shape)  # Batch shape should be (batch_size, 128, 128, 3)


In [None]:
# Draw a single training batch and split it into images and one-hot labels.
batch = train_image_gen.next()
X_train, y_train = batch[0], batch[1]

# X_train / y_train now hold one batch that can be used directly.


In [None]:
# Draw a single test batch and split it into images and one-hot labels.
batch = test_image_gen.next()
X_test, y_test = batch[0], batch[1]


In [None]:
# One full pass over the training iterator: len(train_image_gen) batches.
train_image_batches, train_label_batches = [], []
for _ in range(len(train_image_gen)):
    batch = train_image_gen.next()
    train_image_batches.append(batch[0])
    train_label_batches.append(batch[1])

# Join the per-batch arrays along the first axis.
X_train_all = np.concatenate(train_image_batches)
y_train_all = np.concatenate(train_label_batches)

print("Shape of X_train_all:", X_train_all.shape)
print("Shape of y_train_all:", y_train_all.shape)


In [None]:
X_test_all = []
y_test_all = []

# Collect one full pass (len(test_image_gen) batches) through the TEST data.
for _ in range(len(test_image_gen)):
    # Read from the test iterator; this loop previously pulled batches from
    # train_image_gen, so the "test" arrays actually held training images.
    batch = test_image_gen.next()
    X_test_all.append(batch[0])
    y_test_all.append(batch[1])

# Convert lists to numpy arrays
X_test_all = np.concatenate(X_test_all)
y_test_all = np.concatenate(y_test_all)

print("Shape of X_test_all:", X_test_all.shape)
print("Shape of y_test_all:", y_test_all.shape)


In [None]:
# Hyper-parameters for this stage (they replace the values set in the first
# configuration cell).
NOISE_DIM = 100  
BATCH_SIZE = 32

STEPS_PER_EPOCH = 2000
EPOCHS = 10
SEED = 40
# The collected images are RGB (the cells below reshape them to 3 channels),
# so CHANNELS must be 3 here, not 1.
WIDTH, HEIGHT, CHANNELS = 128, 128, 3

# Adam with learning rate 2e-4 and beta_1 = 0.5
OPTIMIZER = Adam(learning_rate=0.0002, beta_1=0.5)


In [None]:
# Scale the collected training images to [-1, 1] and fix their shape.
# The image generator was created with rescale=1/255, so X_train_all already
# holds values in [0, 1]. The previous (x - 127.5) / 127.5 formula assumes raw
# 0-255 pixels and would squash everything to roughly -1.
X_train = X_train_all.astype(np.float32) * 2.0 - 1.0
X_train = X_train.reshape(-1, WIDTH, HEIGHT, 3)  # Keep the 3 color channels for RGB

print("Shape of X_train after reshaping:", X_train.shape)


In [None]:
# Scale the collected test images to [-1, 1] and fix their shape.
# The image generator was created with rescale=1/255, so X_test_all already
# holds values in [0, 1]. The previous (x - 127.5) / 127.5 formula assumes raw
# 0-255 pixels and would squash everything to roughly -1.
X_test = X_test_all.astype(np.float32) * 2.0 - 1.0
X_test = X_test.reshape(-1, WIDTH, HEIGHT, 3)  # Keep the 3 color channels for RGB

print("Shape of X_test after reshaping:", X_test.shape)


In [None]:
import numpy as np

# Persist the training images collected from the generator so that later cells
# can reload them with np.load('X_train_all.npy').
# X_train_all must NOT be replaced by placeholder data here (it used to be
# overwritten with np.random.randn(...), which made the GAN train on noise).
np.save('X_train_all.npy', X_train_all)
print(X_train_all.shape)

In [None]:
import os
os.environ['TF_DISABLE_LAYOUT_OPTIMIZER'] = '1'  # Disable layout optimizer

from keras.models import Sequential, Model
from keras.layers import Dense, Flatten, Conv2D, Reshape, Input, Conv2DTranspose, UpSampling2D
from keras.layers import Activation, LeakyReLU, BatchNormalization, Dropout, Add
from keras.losses import BinaryCrossentropy
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from scipy.linalg import sqrtm
from sklearn.metrics import precision_score, recall_score, f1_score

# Reload the training images saved earlier
X_train = np.load('X_train_all.npy')

# Min-max rescale X_train into [-1, 1]
X_train_min = X_train.min()
X_train_max = X_train.max()
X_train = 2 * (X_train - X_train_min) / (X_train_max - X_train_min) - 1
print(f"X_train range after normalization: [{X_train.min():.2f}, {X_train.max():.2f}]")

# Fix the NumPy and TensorFlow seeds for reproducibility
np.random.seed(40)
tf.random.set_seed(40)

# Training constants
NOISE_DIM = 150
BATCH_SIZE = 64
STEPS_PER_EPOCH = 132
EPOCHS = 50
IMG_WIDTH = 128
IMG_HEIGHT = 128
CHANNELS = 3
G_LR = 0.0006   # generator learning rate
D_LR = 0.0001   # discriminator learning rate

# Headless, average-pooled Inception V3 used as the FID feature extractor
inception_model = InceptionV3(
    weights='imagenet',
    include_top=False,
    pooling='avg',
    input_shape=(299, 299, 3),
)

# FID Calculation (unchanged)
def calculate_fid(real_images, generated_images):
    """Return the Fréchet Inception Distance between two image batches.

    Each batch is resized to 299x299, shifted with ``(x + 1) * 127.5``
    (i.e. [-1, 1] -> [0, 255]), run through ``preprocess_input`` and embedded
    with the module-level ``inception_model``. The distance is computed from
    the mean and covariance of the two feature sets.
    """
    def _inception_features(images):
        resized = tf.image.resize(images, (299, 299))
        prepared = preprocess_input((resized + 1) * 127.5)
        return inception_model.predict(prepared, verbose=0)

    def _mean_and_cov(features):
        return np.mean(features, axis=0), np.cov(features, rowvar=False)

    mu_real, cov_real = _mean_and_cov(_inception_features(real_images))
    mu_gen, cov_gen = _mean_and_cov(_inception_features(generated_images))

    mean_gap = mu_real - mu_gen
    cov_sqrt = sqrtm(cov_real.dot(cov_gen))
    # sqrtm can return a complex matrix; only its real part is used.
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real

    return mean_gap.dot(mean_gap) + np.trace(cov_real + cov_gen - 2 * cov_sqrt)

# Enhanced Generator
def build_generator():
    """Build the generator: NOISE_DIM noise vector -> 128x128xCHANNELS image.

    A dense layer is reshaped to 8x8x2048 and upsampled 8 -> 16 -> 32 -> 64
    -> 128, with one residual shortcut at 16x16, a 64-filter refinement
    convolution and a tanh output convolution.
    """
    def _bn_leaky(tensor):
        # BatchNormalization followed by LeakyReLU(0.2)
        tensor = BatchNormalization()(tensor)
        return LeakyReLU(alpha=0.2)(tensor)

    latent = Input(shape=(NOISE_DIM,))

    # Dense projection, reshaped to an 8x8 feature map
    h = Dense(8 * 8 * 2048, input_dim=NOISE_DIM)(latent)
    h = Reshape((8, 8, 2048))(h)
    h = _bn_leaky(h)

    # 16x16: nearest-neighbour upsampling plus a residual block
    h = UpSampling2D()(h)
    skip = Conv2D(1024, (1,1), padding='same')(h)
    h = Conv2D(1024, (3,3), padding='same')(h)
    h = _bn_leaky(h)
    h = Add()([h, skip])

    # 32x32: transpose convolution, followed by dropout
    h = Conv2DTranspose(512, (4,4), strides=2, padding='same')(h)
    h = _bn_leaky(h)
    h = Dropout(0.2)(h)

    # 64x64 and 128x128: two further stride-2 transpose convolutions
    for filters in (256, 128):
        h = Conv2DTranspose(filters, (4,4), strides=2, padding='same')(h)
        h = _bn_leaky(h)

    # Refinement convolution at full resolution
    h = Conv2D(64, (3,3), padding='same')(h)
    h = _bn_leaky(h)

    # tanh output image
    image = Conv2D(CHANNELS, (3,3), padding='same', activation='tanh')(h)

    return Model(latent, image)

# Discriminator (unchanged)
def build_discriminator():
    """Build the discriminator: image -> single sigmoid score.

    Three spectrally-normalised stride-2 3x3 convolutions (32, 64 and 128
    filters), each followed by LeakyReLU(0.2) and Dropout(0.3), then Flatten
    and a one-unit sigmoid Dense layer.
    """
    model = Sequential()
    for block_index, filters in enumerate((32, 64, 128)):
        conv_kwargs = dict(strides=2, padding='same')
        if block_index == 0:
            # Only the first convolution is given the input shape.
            conv_kwargs['input_shape'] = (IMG_WIDTH, IMG_HEIGHT, CHANNELS)
        model.add(tfa.layers.SpectralNormalization(Conv2D(filters, (3,3), **conv_kwargs)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.3))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    return model

# Build and compile models
# The discriminator is compiled on its own first (binary cross-entropy loss,
# accuracy metric, Adam with learning rate D_LR).
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', 
                     optimizer=Adam(learning_rate=D_LR, beta_1=0.5, beta_2=0.999), 
                     metrics=['accuracy'])

# Combined model: noise -> generator -> discriminator score.
generator = build_generator()
z = Input(shape=(NOISE_DIM,))
img = generator(z)
# trainable is switched off AFTER the discriminator's own compile and BEFORE
# the combined model is compiled; the order of these statements matters.
# NOTE(review): presumably the usual Keras GAN idiom, where `gan` only updates
# the generator while the separately compiled discriminator still trains --
# confirm for the installed Keras version.
discriminator.trainable = False
validity = discriminator(img)
gan = Model(z, validity)
gan.compile(loss='binary_crossentropy', 
            optimizer=Adam(learning_rate=G_LR, beta_1=0.5, beta_2=0.999))

# Training Function (increased generator updates)
def train_gan(epochs, batch_size, steps_per_epoch, X_train):
    """Train the GAN and collect per-epoch metrics.

    Uses the module-level ``generator``, ``discriminator`` and ``gan`` models
    together with ``calculate_fid``, ``generate_images`` and ``plot_metrics``.

    Parameters
    ----------
    epochs : int
        Number of epochs to run.
    batch_size : int
        Number of samples per generator / discriminator update.
    steps_per_epoch : int
        Training steps per epoch; must be at least 1.
    X_train : numpy.ndarray
        Real images, indexed along the first axis.
        NOTE(review): assumed to be scaled to [-1, 1] to match the generator's
        tanh output -- confirm against the caller.

    Returns
    -------
    tuple of list
        ``(d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions,
        recalls)`` with one entry per epoch. Accuracies, precisions and
        recalls are percentages.

    Raises
    ------
    ValueError
        If ``steps_per_epoch`` is smaller than 1.
    """
    if steps_per_epoch < 1:
        raise ValueError("steps_per_epoch must be at least 1")

    # Smoothed targets, drawn once and reused for every step.
    real_label_smooth = np.random.uniform(0.8, 1.0, (batch_size, 1))
    fake_label_smooth = np.zeros((batch_size, 1)) + 0.1

    os.makedirs("gan_progress", exist_ok=True)
    d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls = [], [], [], [], [], [], []

    for epoch in range(epochs):
        # epoch_d_loss only receives steps in which the discriminator was
        # actually updated, so no sentinel value has to be filtered out later.
        epoch_d_loss, epoch_g_loss, epoch_acc_real, epoch_acc_fake = [], [], [], []
        progress_bar = tqdm(range(steps_per_epoch), desc=f"Epoch {epoch+1}/{epochs}")

        for step in progress_bar:
            # Train Generator (3 updates for more power)
            for _ in range(3):
                noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))
                g_loss = gan.train_on_batch(noise, real_label_smooth)

            # Train Discriminator
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            real_imgs = X_train[idx]
            # Fakes are generated from the noise of the last generator update.
            fake_imgs = generator.predict(noise, verbose=0)

            real_preds = discriminator.predict(real_imgs, verbose=0)
            fake_preds = discriminator.predict(fake_imgs, verbose=0)

            # Mean sigmoid output on real images; gates the update below.
            mean_real_score = np.mean(real_preds)

            d_loss = 0
            if step % 3 == 0 or mean_real_score > 0.9:
                d_loss_real = discriminator.train_on_batch(real_imgs, real_label_smooth)[0]
                d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_label_smooth)[0]
                d_loss = 0.5 * (d_loss_real + d_loss_fake)
                epoch_d_loss.append(d_loss)

            # Accuracy = share of samples on the correct side of the 0.5
            # threshold (the same threshold used for precision/recall below),
            # rather than the mean sigmoid output.
            acc_r_display = np.mean(real_preds > 0.5) * 100
            acc_f_display = np.mean(fake_preds <= 0.5) * 100

            epoch_g_loss.append(g_loss)
            epoch_acc_real.append(acc_r_display)
            epoch_acc_fake.append(acc_f_display)

            progress_bar.set_postfix({
                'D Loss': f"{d_loss:.4f}",
                'G Loss': f"{g_loss:.4f}",
                'Acc Real': f"{acc_r_display:.2f}%",
                'Acc Fake': f"{acc_f_display:.2f}%"
            })

        # Compute averages
        avg_d_loss = float(np.mean(epoch_d_loss)) if epoch_d_loss else 0.0
        avg_g_loss = np.mean(epoch_g_loss)
        avg_acc_r = np.mean(epoch_acc_real)
        avg_acc_f = np.mean(epoch_acc_fake)

        # Calculate FID against the last real batch of the epoch
        noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))
        gen_imgs = generator.predict(noise, verbose=0)
        fid = calculate_fid(real_imgs, gen_imgs)

        # Calculate Precision and Recall on the last step's predictions.
        # Predictions are flattened so they match the 1-D label vector.
        real_labels = np.ones(batch_size)
        fake_labels = np.zeros(batch_size)
        all_labels = np.concatenate([real_labels, fake_labels])
        all_preds = np.concatenate([real_preds, fake_preds]).ravel() > 0.5
        precision = precision_score(all_labels, all_preds, zero_division=0)
        recall = recall_score(all_labels, all_preds, zero_division=0)

        d_losses.append(avg_d_loss)
        g_losses.append(avg_g_loss)
        acc_real.append(avg_acc_r)
        acc_fake.append(avg_acc_f)
        fid_scores.append(fid)
        precisions.append(precision * 100)
        recalls.append(recall * 100)

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Discriminator Loss: {avg_d_loss:.4f}")
        print(f"  Generator Loss: {avg_g_loss:.4f}")
        print(f"  Real Image Accuracy: {avg_acc_r:.2f}%")
        print(f"  Fake Image Accuracy: {avg_acc_f:.2f}%")
        print(f"  FID Score: {fid:.2f}")
        print(f"  Precision: {precision*100:.2f}%")
        print(f"  Recall: {recall*100:.2f}%\n")

        generate_images(epoch)
        if epoch % 5 == 0:
            plot_metrics(d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls)

    return d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls

# Visualization Functions (unchanged)
def plot_metrics(d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls):
    """Draw a 2x2 dashboard: losses, accuracies, FID, and precision/recall.

    Every argument is a per-epoch sequence; each panel plots its curves
    against the epoch index.
    """
    # (title, y-axis label, [(series, legend label), ...]) for each panel
    panels = [
        ("Training Losses", "Loss",
         [(d_losses, 'Discriminator Loss'), (g_losses, 'Generator Loss')]),
        ("Classification Accuracy", "Accuracy (%)",
         [(acc_real, 'Real Accuracy'), (acc_fake, 'Fake Accuracy')]),
        ("Fréchet Inception Distance", "FID",
         [(fid_scores, 'FID Score')]),
        ("Precision and Recall", "Percentage (%)",
         [(precisions, 'Precision'), (recalls, 'Recall')]),
    ]

    plt.figure(figsize=(15, 10))
    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for series, legend_label in curves:
            plt.plot(series, label=legend_label)
        plt.title(title)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()
    plt.tight_layout()
    plt.show()

def generate_images(epoch, num_images=9):
    """Sample images from the module-level ``generator`` and save them as a grid.

    Parameters
    ----------
    epoch : int
        Zero-based epoch index; the figure title and file name use ``epoch + 1``.
    num_images : int, optional
        How many images to sample. The grid is the smallest square that holds
        them (3x3 for the default of 9).

    The figure is written to ``gan_progress/epoch_<epoch+1>.png``, shown, and
    then closed.
    """
    noise = np.random.normal(0, 1, (num_images, NOISE_DIM))
    gen_imgs = generator.predict(noise, verbose=0)
    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale to [0, 1] for visualization
    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    gen_imgs = (np.clip(gen_imgs, 0.0, 1.0) * 255).astype(np.uint8)

    # Smallest square grid that fits num_images (a fixed 3x3 grid fails for > 9).
    grid_side = max(1, int(np.ceil(np.sqrt(num_images))))

    # Make sure the output folder exists when this is called outside train_gan.
    os.makedirs("gan_progress", exist_ok=True)

    fig = plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i+1)
        plt.imshow(gen_imgs[i])
        plt.axis('off')
    plt.suptitle(f"Epoch {epoch+1}", fontsize=20)
    plt.savefig(f"gan_progress/epoch_{epoch+1}.png")
    plt.show()
    plt.close(fig)

# Training Execution
# `history` is the 7-tuple returned by train_gan:
# (d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls).
history = train_gan(epochs=EPOCHS, 
                    batch_size=BATCH_SIZE, 
                    steps_per_epoch=STEPS_PER_EPOCH, 
                    X_train=X_train)

# Save models
# Each model is written to its own HDF5 file in the working directory.
generator.save('generator_model.h5')
discriminator.save('discriminator_model.h5')

gan.save('gan_model.h5')

In [21]:
# Enhanced Generator
def build_generator():
    """Build the generator: NOISE_DIM noise vector -> 128x128xCHANNELS image.

    A dense layer is reshaped to 8x8x2048 and upsampled 8 -> 16 -> 32 -> 64
    -> 128, with one residual shortcut at 16x16, a 64-filter refinement
    convolution and a tanh output convolution.
    """
    def _bn_leaky(tensor):
        # BatchNormalization followed by LeakyReLU(0.2)
        tensor = BatchNormalization()(tensor)
        return LeakyReLU(alpha=0.2)(tensor)

    latent = Input(shape=(NOISE_DIM,))

    # Dense projection, reshaped to an 8x8 feature map
    h = Dense(8 * 8 * 2048, input_dim=NOISE_DIM)(latent)
    h = Reshape((8, 8, 2048))(h)
    h = _bn_leaky(h)

    # 16x16: nearest-neighbour upsampling plus a residual block
    h = UpSampling2D()(h)
    skip = Conv2D(1024, (1,1), padding='same')(h)
    h = Conv2D(1024, (3,3), padding='same')(h)
    h = _bn_leaky(h)
    h = Add()([h, skip])

    # 32x32: transpose convolution, followed by dropout
    h = Conv2DTranspose(512, (4,4), strides=2, padding='same')(h)
    h = _bn_leaky(h)
    h = Dropout(0.2)(h)

    # 64x64 and 128x128: two further stride-2 transpose convolutions
    for filters in (256, 128):
        h = Conv2DTranspose(filters, (4,4), strides=2, padding='same')(h)
        h = _bn_leaky(h)

    # Refinement convolution at full resolution
    h = Conv2D(64, (3,3), padding='same')(h)
    h = _bn_leaky(h)

    # tanh output image
    image = Conv2D(CHANNELS, (3,3), padding='same', activation='tanh')(h)

    return Model(latent, image)

# Discriminator (unchanged)
def build_discriminator():
    """Build the discriminator: image -> single sigmoid score.

    Three spectrally-normalised stride-2 3x3 convolutions (32, 64 and 128
    filters), each followed by LeakyReLU(0.2) and Dropout(0.3), then Flatten
    and a one-unit sigmoid Dense layer.
    """
    model = Sequential()
    for block_index, filters in enumerate((32, 64, 128)):
        conv_kwargs = dict(strides=2, padding='same')
        if block_index == 0:
            # Only the first convolution is given the input shape.
            conv_kwargs['input_shape'] = (IMG_WIDTH, IMG_HEIGHT, CHANNELS)
        model.add(tfa.layers.SpectralNormalization(Conv2D(filters, (3,3), **conv_kwargs)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.3))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    return model

# Build and compile models
# The discriminator is compiled on its own first (binary cross-entropy loss,
# accuracy metric, Adam with learning rate D_LR).
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', 
                     optimizer=Adam(learning_rate=D_LR, beta_1=0.5, beta_2=0.999), 
                     metrics=['accuracy'])

# Combined model: noise -> generator -> discriminator score.
generator = build_generator()
z = Input(shape=(NOISE_DIM,))
img = generator(z)
# trainable is switched off AFTER the discriminator's own compile and BEFORE
# the combined model is compiled; the order of these statements matters.
# NOTE(review): presumably the usual Keras GAN idiom, where `gan` only updates
# the generator while the separately compiled discriminator still trains --
# confirm for the installed Keras version.
discriminator.trainable = False
validity = discriminator(img)
gan = Model(z, validity)
gan.compile(loss='binary_crossentropy', 
            optimizer=Adam(learning_rate=G_LR, beta_1=0.5, beta_2=0.999))
# Print the layer / parameter overview of the combined model.
gan.summary()


Model: "model_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_8 (InputLayer)        [(None, 150)]             0         
                                                                 
 model_6 (Functional)        (None, 128, 128, 3)       51868035  
                                                                 
 sequential_5 (Sequential)   (None, 1)                 126241    
                                                                 
Total params: 51,994,276
Trainable params: 51,859,971
Non-trainable params: 134,305
_________________________________________________________________


The next cell is optional — run it only if you want to.


In [None]:
import os
os.environ['TF_DISABLE_LAYOUT_OPTIMIZER'] = '1'  # Disable layout optimizer

from keras.models import Sequential, Model
from keras.layers import Dense, Flatten, Conv2D, Reshape, Input, Conv2DTranspose, UpSampling2D
from keras.layers import Activation, LeakyReLU, BatchNormalization, Dropout, Add
from keras.losses import BinaryCrossentropy
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from scipy.linalg import sqrtm
from sklearn.metrics import precision_score, recall_score, f1_score

# Reload the training images saved earlier
X_train = np.load('X_train_all.npy')

# Min-max rescale X_train into [-1, 1]
X_train_min = X_train.min()
X_train_max = X_train.max()
X_train = 2 * (X_train - X_train_min) / (X_train_max - X_train_min) - 1
print(f"X_train range after normalization: [{X_train.min():.2f}, {X_train.max():.2f}]")

# Fix the NumPy and TensorFlow seeds for reproducibility
np.random.seed(40)
tf.random.set_seed(40)

# Training constants
NOISE_DIM = 150
BATCH_SIZE = 64
STEPS_PER_EPOCH = 50
EPOCHS = 50
IMG_WIDTH = 128
IMG_HEIGHT = 128
CHANNELS = 3
G_LR = 0.0006    # generator learning rate
D_LR = 0.00015   # discriminator learning rate
EARLY_STOPPING_PATIENCE = 10     # non-improving FID evaluations tolerated
EARLY_STOPPING_MIN_DELTA = 1.0   # FID must drop by more than this to count

# Headless, average-pooled Inception V3 used as the FID feature extractor
inception_model = InceptionV3(
    weights='imagenet',
    include_top=False,
    pooling='avg',
    input_shape=(299, 299, 3),
)

# FID Calculation
def calculate_fid(real_images, generated_images):
    """Return the Fréchet Inception Distance between two image batches.

    Each batch is resized to 299x299, shifted with ``(x + 1) * 127.5``
    (i.e. [-1, 1] -> [0, 255]), run through ``preprocess_input`` and embedded
    with the module-level ``inception_model``. The distance is computed from
    the mean and covariance of the two feature sets.
    """
    def _inception_features(images):
        resized = tf.image.resize(images, (299, 299))
        prepared = preprocess_input((resized + 1) * 127.5)
        return inception_model.predict(prepared, verbose=0)

    def _mean_and_cov(features):
        return np.mean(features, axis=0), np.cov(features, rowvar=False)

    mu_real, cov_real = _mean_and_cov(_inception_features(real_images))
    mu_gen, cov_gen = _mean_and_cov(_inception_features(generated_images))

    mean_gap = mu_real - mu_gen
    cov_sqrt = sqrtm(cov_real.dot(cov_gen))
    # sqrtm can return a complex matrix; only its real part is used.
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = cov_sqrt.real

    return mean_gap.dot(mean_gap) + np.trace(cov_real + cov_gen - 2 * cov_sqrt)

# Simplified Generator
def build_generator():
    """Build the generator: NOISE_DIM noise vector -> 128x128xCHANNELS image.

    A dense layer is reshaped to 8x8x2048 and upsampled 8 -> 16 -> 32 -> 64
    -> 128, with one residual shortcut at 16x16 and a tanh output convolution.
    """
    def _bn_leaky(tensor):
        # BatchNormalization followed by LeakyReLU(0.2)
        tensor = BatchNormalization()(tensor)
        return LeakyReLU(alpha=0.2)(tensor)

    latent = Input(shape=(NOISE_DIM,))

    # Dense projection, reshaped to an 8x8 feature map
    h = Dense(8 * 8 * 2048)(latent)
    h = Reshape((8, 8, 2048))(h)
    h = _bn_leaky(h)

    # 16x16: nearest-neighbour upsampling plus a residual block
    h = UpSampling2D()(h)
    skip = Conv2D(1024, (1,1), padding='same')(h)
    h = Conv2D(1024, (3,3), padding='same')(h)
    h = _bn_leaky(h)
    h = Add()([h, skip])

    # 32x32: transpose convolution, followed by dropout
    h = Conv2DTranspose(512, (4,4), strides=2, padding='same')(h)
    h = _bn_leaky(h)
    h = Dropout(0.2)(h)

    # 64x64 and 128x128: two further stride-2 transpose convolutions
    for filters in (256, 128):
        h = Conv2DTranspose(filters, (4,4), strides=2, padding='same')(h)
        h = _bn_leaky(h)

    # tanh output image
    image = Conv2D(CHANNELS, (3,3), padding='same', activation='tanh')(h)

    return Model(latent, image)

# Discriminator
def build_discriminator():
    """Build the discriminator: image -> single sigmoid score.

    Three spectrally-normalised stride-2 3x3 convolutions (32, 64 and 128
    filters), each followed by LeakyReLU(0.2) and Dropout(0.3), then Flatten
    and a one-unit sigmoid Dense layer.
    """
    model = Sequential()
    for block_index, filters in enumerate((32, 64, 128)):
        conv_kwargs = dict(strides=2, padding='same')
        if block_index == 0:
            # Only the first convolution is given the input shape.
            conv_kwargs['input_shape'] = (IMG_WIDTH, IMG_HEIGHT, CHANNELS)
        model.add(tfa.layers.SpectralNormalization(Conv2D(filters, (3,3), **conv_kwargs)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.3))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    return model

# Build and compile models
# The discriminator is compiled on its own first (binary cross-entropy loss,
# accuracy metric, Adam with learning rate D_LR).
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', 
                     optimizer=Adam(learning_rate=D_LR, beta_1=0.5, beta_2=0.999), 
                     metrics=['accuracy'])

# Combined model: noise -> generator -> discriminator score.
generator = build_generator()
z = Input(shape=(NOISE_DIM,))
img = generator(z)
# trainable is switched off AFTER the discriminator's own compile and BEFORE
# the combined model is compiled; the order of these statements matters.
# NOTE(review): presumably the usual Keras GAN idiom, where `gan` only updates
# the generator while the separately compiled discriminator still trains --
# confirm for the installed Keras version.
discriminator.trainable = False
validity = discriminator(img)
gan = Model(z, validity)
gan.compile(loss='binary_crossentropy', 
            optimizer=Adam(learning_rate=G_LR, beta_1=0.5, beta_2=0.999))

# Training Function with Optimizations
def train_gan(epochs, batch_size, steps_per_epoch, X_train):
    """Train the GAN with FID-based early stopping and collect per-epoch metrics.

    Uses the module-level ``generator``, ``discriminator`` and ``gan`` models,
    the ``EARLY_STOPPING_PATIENCE`` / ``EARLY_STOPPING_MIN_DELTA`` constants,
    and ``calculate_fid``, ``generate_images`` and ``plot_metrics``.

    FID is evaluated on epochs where ``epoch % 5 == 0``, on the last epoch and
    when the patience counter is one short of the limit; other epochs repeat
    the previous FID value in the history. Patience therefore counts FID
    *evaluations*, not epochs. The best generator weights are restored only
    when early stopping triggers.

    Parameters
    ----------
    epochs : int
        Maximum number of epochs to run.
    batch_size : int
        Number of samples per generator / discriminator update.
    steps_per_epoch : int
        Training steps per epoch; must be at least 1.
    X_train : numpy.ndarray
        Real images, indexed along the first axis.
        NOTE(review): assumed to be scaled to [-1, 1] to match the generator's
        tanh output -- confirm against the caller.

    Returns
    -------
    tuple of list
        ``(d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions,
        recalls)`` with one entry per completed epoch. Accuracies, precisions
        and recalls are percentages.

    Raises
    ------
    ValueError
        If ``steps_per_epoch`` is smaller than 1.
    """
    if steps_per_epoch < 1:
        raise ValueError("steps_per_epoch must be at least 1")

    # Smoothed targets, drawn once and reused for every step.
    real_label_smooth = np.random.uniform(0.9, 1.0, (batch_size, 1))
    fake_label_smooth = np.zeros((batch_size, 1)) + 0.1

    os.makedirs("gan_progress", exist_ok=True)
    d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls = [], [], [], [], [], [], []

    best_fid = float('inf')
    patience_counter = 0
    best_generator_weights = None

    for epoch in range(epochs):
        # epoch_d_loss only receives steps in which the discriminator was
        # actually updated, so no sentinel value has to be filtered out later.
        epoch_d_loss, epoch_g_loss, epoch_acc_real, epoch_acc_fake = [], [], [], []
        progress_bar = tqdm(range(steps_per_epoch), desc=f"Epoch {epoch+1}/{epochs}")

        for step in progress_bar:
            # Train Generator (2 updates)
            for _ in range(2):
                noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))
                g_loss = gan.train_on_batch(noise, real_label_smooth)

            # Train Discriminator
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            real_imgs = X_train[idx]
            # Fakes are generated from the noise of the last generator update.
            fake_imgs = generator.predict(noise, verbose=0)

            real_preds = discriminator.predict(real_imgs, verbose=0)
            fake_preds = discriminator.predict(fake_imgs, verbose=0)

            # Mean sigmoid output on real images; gates the update below.
            mean_real_score = np.mean(real_preds)

            d_loss = 0
            if step % 2 == 0 or mean_real_score > 0.9:
                d_loss_real = discriminator.train_on_batch(real_imgs, real_label_smooth)[0]
                d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_label_smooth)[0]
                d_loss = 0.5 * (d_loss_real + d_loss_fake)
                epoch_d_loss.append(d_loss)

            # Accuracy = share of samples on the correct side of the 0.5
            # threshold (the same threshold used for precision/recall below),
            # rather than the mean sigmoid output.
            acc_r_display = np.mean(real_preds > 0.5) * 100
            acc_f_display = np.mean(fake_preds <= 0.5) * 100

            epoch_g_loss.append(g_loss)
            epoch_acc_real.append(acc_r_display)
            epoch_acc_fake.append(acc_f_display)

            progress_bar.set_postfix({
                'D Loss': f"{d_loss:.4f}",
                'G Loss': f"{g_loss:.4f}",
                'Acc Real': f"{acc_r_display:.2f}%",
                'Acc Fake': f"{acc_f_display:.2f}%"
            })

        # Compute averages
        avg_d_loss = float(np.mean(epoch_d_loss)) if epoch_d_loss else 0.0
        avg_g_loss = np.mean(epoch_g_loss)
        avg_acc_r = np.mean(epoch_acc_real)
        avg_acc_f = np.mean(epoch_acc_fake)

        # Calculate FID every 5 epochs or at stopping
        fid = None
        if epoch % 5 == 0 or epoch == epochs - 1 or patience_counter == EARLY_STOPPING_PATIENCE - 1:
            noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))
            gen_imgs = generator.predict(noise, verbose=0)
            fid = calculate_fid(real_imgs, gen_imgs)
            fid_scores.append(fid)
        else:
            # Carry the previous FID forward so the history keeps one entry per epoch.
            fid_scores.append(fid_scores[-1] if fid_scores else None)

        # Calculate Precision and Recall on the last step's predictions.
        # Predictions are flattened so they match the 1-D label vector.
        real_labels = np.ones(batch_size)
        fake_labels = np.zeros(batch_size)
        all_labels = np.concatenate([real_labels, fake_labels])
        all_preds = np.concatenate([real_preds, fake_preds]).ravel() > 0.5
        precision = precision_score(all_labels, all_preds, zero_division=0)
        recall = recall_score(all_labels, all_preds, zero_division=0)

        d_losses.append(avg_d_loss)
        g_losses.append(avg_g_loss)
        acc_real.append(avg_acc_r)
        acc_fake.append(avg_acc_f)
        precisions.append(precision * 100)
        recalls.append(recall * 100)

        fid_text = f"{fid:.2f}" if fid is not None else 'N/A'
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Discriminator Loss: {avg_d_loss:.4f}")
        print(f"  Generator Loss: {avg_g_loss:.4f}")
        print(f"  Real Image Accuracy: {avg_acc_r:.2f}%")
        print(f"  Fake Image Accuracy: {avg_acc_f:.2f}%")
        print(f"  FID Score: {fid_text}")
        print(f"  Precision: {precision*100:.2f}%")
        print(f"  Recall: {recall*100:.2f}%")
        print(f"  Patience Counter: {patience_counter}/{EARLY_STOPPING_PATIENCE}\n")

        # Early stopping logic (only on epochs where FID was evaluated)
        if fid is not None:
            if fid < best_fid - EARLY_STOPPING_MIN_DELTA:
                best_fid = fid
                patience_counter = 0
                best_generator_weights = generator.get_weights()
                print(f"New best FID: {best_fid:.2f}, weights saved.")
            else:
                patience_counter += 1
                print(f"No improvement in FID. Best FID remains {best_fid:.2f}.")

            if patience_counter >= EARLY_STOPPING_PATIENCE:
                if best_generator_weights is not None:
                    print(f"Early stopping triggered after {epoch+1} epochs. Restoring best weights.")
                    generator.set_weights(best_generator_weights)
                else:
                    # No FID ever counted as an improvement (e.g. every value
                    # was NaN), so there is nothing to restore.
                    print(f"Early stopping triggered after {epoch+1} epochs. No best weights recorded; keeping current weights.")
                break

        generate_images(epoch)
        if epoch % 5 == 0:
            plot_metrics(d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls)

    return d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls

# Visualization Functions
def plot_metrics(d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls):
    """Draw a 2x2 dashboard: losses, accuracies, FID, and precision/recall.

    Every argument is a per-epoch sequence. ``None`` entries in
    ``fid_scores`` are dropped before the FID curve is plotted.
    """
    fid_available = [f for f in fid_scores if f is not None]

    # (title, x-axis label, y-axis label, [(series, legend label), ...]) per panel
    panels = [
        ("Training Losses", "Epoch", "Loss",
         [(d_losses, 'Discriminator Loss'), (g_losses, 'Generator Loss')]),
        ("Classification Accuracy", "Epoch", "Accuracy (%)",
         [(acc_real, 'Real Accuracy'), (acc_fake, 'Fake Accuracy')]),
        ("Fréchet Inception Distance", "Epoch (every 5)", "FID",
         [(fid_available, 'FID Score')]),
        ("Precision and Recall", "Epoch", "Percentage (%)",
         [(precisions, 'Precision'), (recalls, 'Recall')]),
    ]

    plt.figure(figsize=(15, 10))
    for position, (title, x_label, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for series, legend_label in curves:
            plt.plot(series, label=legend_label)
        plt.title(title)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.legend()
    plt.tight_layout()
    plt.show()

def generate_images(epoch, num_images=9):
    """Sample images from the module-level ``generator`` and save them as a grid.

    Parameters
    ----------
    epoch : int
        Zero-based epoch index; the figure title and file name use ``epoch + 1``.
    num_images : int, optional
        How many images to sample. The grid is the smallest square that holds
        them (3x3 for the default of 9).

    The figure is written to ``gan_progress/epoch_<epoch+1>.png``, shown, and
    then closed.
    """
    noise = np.random.normal(0, 1, (num_images, NOISE_DIM))
    gen_imgs = generator.predict(noise, verbose=0)
    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale to [0, 1] for visualization
    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    gen_imgs = (np.clip(gen_imgs, 0.0, 1.0) * 255).astype(np.uint8)

    # Smallest square grid that fits num_images (a fixed 3x3 grid fails for > 9).
    grid_side = max(1, int(np.ceil(np.sqrt(num_images))))

    # Make sure the output folder exists when this is called outside train_gan.
    os.makedirs("gan_progress", exist_ok=True)

    fig = plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i+1)
        plt.imshow(gen_imgs[i])
        plt.axis('off')
    plt.suptitle(f"Epoch {epoch+1}", fontsize=20)
    plt.savefig(f"gan_progress/epoch_{epoch+1}.png")
    plt.show()
    plt.close(fig)

# Training Execution
# `history` is the 7-tuple returned by train_gan:
# (d_losses, g_losses, acc_real, acc_fake, fid_scores, precisions, recalls).
history = train_gan(epochs=EPOCHS, 
                    batch_size=BATCH_SIZE, 
                    steps_per_epoch=STEPS_PER_EPOCH, 
                    X_train=X_train)

# Save final models
generator.save('generator_model_final.h5')
discriminator.save('discriminator_model_final.h5')
gan.save('gan_model_final.h5')

# Save best model (based on FID)
# NOTE(review): this writes the generator's *current* weights. train_gan only
# calls generator.set_weights(best_generator_weights) when early stopping
# triggers, so unless training stopped early this file matches
# generator_model_final.h5 rather than the best-FID epoch.
generator.save('generator_model_best.h5')