# Generative AI for Ophthalmological Image Synthesis

## Project Overview

This project aims to develop a generative AI model for creating synthetic ophthalmological images based on the Brazilian Multilabel Ophthalmological Dataset (BRSET). My goal is to generate high-quality, diverse images that could potentially be used for augmenting datasets, improving model training, and advancing research in ophthalmology.

## Methodology

My approach is inspired by the paper "Using generative AI to investigate medical imagery models and datasets" (Lang et al., 2024). I will implement a multistep process:

1. **Image Classification**: Train a deep learning classifier on the BRSET dataset to predict various ophthalmological conditions.

2. **Generative Model**: Develop a StyleGAN2-based generative model, incorporating guidance from my trained classifier.

3. **Attribute Discovery**: Use the trained generator to identify and visualize key attributes that influence the classifier's predictions.

4. **Analysis and Interpretation**: Examine the generated images and attributes to gain insights into the model's understanding of ophthalmological features.
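
As a rough illustration of step 3, the sketch below ranks latent coordinates by how much a perturbation changes a classifier's output. It is only a toy under stated assumptions: `toy_generator`, `toy_classifier`, `LATENT_DIM`, and `N_PIXELS` are hypothetical NumPy stand-ins, not the StyleGAN2 generator or the BRSET classifier, and the procedure is a simplification rather than the exact method of Lang et al.

```python
import numpy as np

rng = np.random.default_rng(0)

LATENT_DIM = 8   # hypothetical latent size, far smaller than a real generator's
N_PIXELS = 16    # hypothetical flattened "image" size

# Toy stand-ins: a fixed linear "generator" and a logistic "classifier".
G = rng.normal(size=(LATENT_DIM, N_PIXELS))
w = rng.normal(size=N_PIXELS)

def toy_generator(z):
    """Map latent vectors of shape (n, LATENT_DIM) to flattened images."""
    return np.tanh(z @ G)

def toy_classifier(x):
    """Return a probability in (0, 1) for each flattened image."""
    return 1.0 / (1.0 + np.exp(-(x @ w)))

def rank_latent_attributes(n_samples=256, step=1.0):
    """Rank latent coordinates by mean absolute change in classifier output."""
    z = rng.normal(size=(n_samples, LATENT_DIM))
    base = toy_classifier(toy_generator(z))
    effects = np.zeros(LATENT_DIM)
    for k in range(LATENT_DIM):
        z_shifted = z.copy()
        z_shifted[:, k] += step  # perturb one coordinate at a time
        shifted = toy_classifier(toy_generator(z_shifted))
        effects[k] = np.mean(np.abs(shifted - base))
    order = np.argsort(effects)[::-1]  # most influential coordinate first
    return order, effects

order, effects = rank_latent_attributes()
print("Latent coordinates ranked by influence:", order)
```

In the real pipeline the coordinates with the largest effect would then be visualized by generating images with and without the perturbation.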

## Project Goals

- Create a high-performance classifier for ophthalmological conditions using the BRSET dataset.
- Implement a StyleGAN2-based generator capable of producing realistic eye images.
- Discover and visualize attributes that are important for classifying various eye conditions.
- Generate synthetic images that could potentially be used to augment existing datasets.

## Ethical Considerations

While this project aims to advance medical imaging research, we must be mindful of the ethical implications of generating synthetic medical data. All generated images should be clearly labeled as synthetic and not used for diagnostic purposes without extensive validation.

## Getting Started

This notebook will guide you through the implementation of each step in my methodology. Let's begin by setting up our environment and loading the BRSET dataset.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Google Drive base directory
BASE_DIR = '/content/drive/MyDrive/'

# Local Base Directory
# BASE_DIR = './'

In [None]:
# Standard libraries
import os
import random
import warnings
import numpy as np
from numpy.random import normal
import pandas as pd
from tqdm import tqdm

# Deep learning and image processing
import tensorflow as tf
from tensorflow import keras
import cv2
from PIL import Image

# TensorFlow and Keras modules
from tensorflow.keras import layers, models, optimizers, losses, metrics, backend, applications
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.mixed_precision import global_policy
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.utils import custom_object_scope
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping, TensorBoard, CSVLogger


# Scikit-learn for data splitting and evaluation metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils import class_weight

# Stats libraries for statistical analysis
from scipy.stats import pearsonr

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns

# Suppress warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
SEED = 12
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

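# Enable memory growth on every visible GPU so TensorFlow does not reserve all GPU memory up front
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

# Eager execution is already the default in TensorFlow 2.x; this call is kept for compatibility
tf.compat.v1.enable_eager_execution()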

# Display all outputs in a cell
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Print environment information
print("Num GPUs Available:", len(tf.config.experimental.list_physical_devices('GPU')))
print("TensorFlow version:", tf.__version__)
print("Keras version:", tf.keras.__version__)
print("Eager execution enabled:", tf.executing_eagerly())

### Load in the labels dataset

In [None]:
# Import the data labels 

# Local Labels path
#label_path = '/Volumes/Extreme SSD/a-brazilian-multilabel-ophthalmological-dataset-brset-1.0.0/labels.csv'

# Google Drive Labels path 
label_path = '/content/drive/MyDrive/a-brazilian-multilabel-ophthalmological-dataset-brset-1.0.0/labels.csv'

labels = pd.read_csv(label_path)

# Display the first few rows of the data labels
labels.head()

### Load in images and inspect the data

In [None]:
# Define the path to the fundus photos

# Local image path
#IMAGE_PATH = '/Volumes/Extreme SSD/a-brazilian-multilabel-ophthalmological-dataset-brset-1.0.0/fundus_photos/'

# Google drive image path 
IMAGE_PATH = '/content/drive/MyDrive/a-brazilian-multilabel-ophthalmological-dataset-brset-1.0.0/fundus_photos/'

def load_and_preprocess_image(image_id, target_size=(224, 224)):
    """
    Load and preprocess a fundus photo given its image_id.
    
    Args:
    image_id (str): The ID of the image to load.
    target_size (tuple): The target size to resize the image to.
    
    Returns:
    numpy.array: The preprocessed image as a numpy array.
    """
    # Construct the full path to the image
    image_path = os.path.join(IMAGE_PATH, f"{image_id}.jpg")
    
    # Load the image
    img = load_img(image_path, target_size=target_size)
    
    # Convert the image to a numpy array
    img_array = img_to_array(img)
    
    # Normalize the image
    img_array = img_array / 255.0
    
    return img_array

def load_batch_of_images(image_ids, batch_size=32):
    """
    Load and preprocess a batch of fundus photos.
    
    Args:
    image_ids (list): List of image IDs to load.
    batch_size (int): Number of images to load at once.
    
    Returns:
    numpy.array: A batch of preprocessed images.
    """
    images = []
    for i in range(0, len(image_ids), batch_size):
        batch_ids = image_ids[i:i+batch_size]
        # Use image_id rather than id to avoid shadowing the built-in id()
        batch_images = [load_and_preprocess_image(image_id) for image_id in batch_ids]
        images.extend(batch_images)
    return np.array(images)

# Load the first 100 images 
first_100_image_ids = labels['image_id'].iloc[:100].tolist()
batch_of_images = load_batch_of_images(first_100_image_ids)

print(f"Batch of images shape: {batch_of_images.shape}")

# Display a grid of the first 16 images
plt.figure(figsize=(20, 20))
for i in range(16):
    plt.subplot(4, 4, i+1);
    plt.imshow(batch_of_images[i]);
    plt.axis('off');
    plt.title(f"Image ID: {first_100_image_ids[i]}");
plt.tight_layout();
plt.show();

## EDA and dataset preparation/cleaning

### Labels data EDA

In [None]:
labels.info()

In [None]:
# List of binary diagnosis columns
binary_diagnoses = ['diabetic_retinopathy', 'macular_edema', 'scar', 'nevus', 'amd',
                    'vascular_occlusion', 'hypertensive_retinopathy', 'drusens',
                    'hemorrhage', 'retinal_detachment', 'myopic_fundus', 'increased_cup_disc']

# Calculate the counts and percentages across all images (each row of labels is one image)
diagnosis_counts = labels[binary_diagnoses].sum().sort_values(ascending=False)
total_images = len(labels)
diagnosis_percentages = (diagnosis_counts / total_images) * 100

plt.figure(figsize=(15, 8))
ax = diagnosis_counts.plot(kind='bar')

# Add count and percentage labels on top of each bar
for i, (count, percentage) in enumerate(zip(diagnosis_counts, diagnosis_percentages)):
    ax.text(i, count, f'N={count}\n({percentage:.1f}%)', 
            ha='center', va='bottom')

plt.title('Distribution of Diagnoses')
plt.xlabel('Diagnosis (N = count, % = percentage of all images)')
plt.ylabel('Count')
plt.yticks(range(0, 3001, 500))
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show();

The most frequent diagnosis in the dataset is increased cup-to-disc ratio (CDR), followed by drusen and diabetic retinopathy.

In [None]:
# Correlation matrix of numerical columns and binary diagnoses
numerical_columns = ['patient_age', 'patient_sex', 'exam_eye', 'DR_SDRG', 'DR_ICDR', 
                     'focus', 'iluminaton', 'image_field', 'artifacts']
correlation_matrix = labels[numerical_columns + binary_diagnoses].corr()
plt.figure(figsize=(20, 16))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', linewidths=0.5, fmt='.2f')
plt.title('Correlation Matrix')
plt.tight_layout()
plt.show();

In [None]:
# Print out top correlations in the correlation matrix
def print_top_correlations(correlation_matrix, n=20):
    # Unstack the correlation matrix into a Series indexed by (var1, var2)
    correlations = correlation_matrix.unstack()

    # Sort correlations in descending order of absolute value
    correlations = correlations.abs().sort_values(ascending=False)

    # Create a set to keep track of pairs that were already printed
    seen_pairs = set()

    print(f"Top {n} Correlation Pairs:")
    count = 0
    for (var1, var2), correlation in correlations.items():
        # Skip self correlations (the diagonal) by name rather than by value,
        # so genuinely perfect correlations between different columns are kept
        if var1 == var2:
            continue

        pair = frozenset([var1, var2])

        if pair not in seen_pairs:
            print(f"{var1} - {var2}: {correlation_matrix.loc[var1, var2]:.4f}")
            seen_pairs.add(pair)
            count += 1

            if count == n:
                break

correlation_matrix = labels[numerical_columns + binary_diagnoses].corr()
print_top_correlations(correlation_matrix, n=10)

The correlation analysis of the Brazilian Multilabel Ophthalmological Dataset (BRSET) reveals several interesting relationships between various ophthalmological parameters and diagnoses.

1. **DR_SDRG - DR_ICDR (0.9853)**: 
   This extremely high correlation is expected as both are classification systems for diabetic retinopathy (DR). The Scottish Diabetic Retinopathy Grading Scheme (SDRG) and the International Clinical Diabetic Retinopathy (ICDR) scale are closely aligned in their assessment of DR severity.

2. **DR_ICDR - diabetic_retinopathy (0.9173)** and **diabetic_retinopathy - DR_SDRG (0.9103)**:
   These strong correlations indicate that both grading systems (ICDR and SDRG) are highly predictive of the presence of diabetic retinopathy. This validates the consistency between the binary classification (presence/absence) and the more detailed grading scales.

3. **diabetic_retinopathy - macular_edema (0.5611)**:
   This moderate positive correlation suggests that patients with diabetic retinopathy are more likely to also have macular edema. This is clinically significant as macular edema is a common complication of diabetic retinopathy.

4. **macular_edema - DR_SDRG (0.5406)** and **macular_edema - DR_ICDR (0.5337)**:
   These correlations further support the relationship between the severity of diabetic retinopathy (as measured by both scales) and the presence of macular edema. As the severity of DR increases, the likelihood of macular edema also increases.

5. **patient_age - drusens (0.2179)**:
   This weak positive correlation suggests that drusen (small yellow or white accumulations of extracellular material in the retina) are more common in older patients. This aligns with clinical knowledge, as drusen are often associated with age-related macular degeneration (AMD).

6. **vascular_occlusion - hemorrhage (0.1816)**:
   This weak positive correlation indicates a relationship between vascular occlusions and hemorrhages in the retina. This makes clinical sense, as occlusions can lead to bleeding in the affected blood vessels.

7. **patient_age - amd (0.1284)**:
   The weak positive correlation between age and age-related macular degeneration (AMD) is expected, as AMD is more prevalent in older populations.

8. **drusens - DR_SDRG (-0.0976)**:
   This very weak negative correlation might suggest a slight inverse relationship between the presence of drusen and the severity of diabetic retinopathy. However, given the low correlation coefficient, this relationship is likely not clinically significant and would require further investigation to determine whether it is meaningful.

These correlations provide valuable insights into the relationships between various ophthalmological conditions and patient characteristics in the BRSET dataset. They highlight the interconnected nature of diabetic retinopathy, macular edema, and age-related eye conditions. These findings can inform feature selection for machine learning models and guide further clinical research into the progression and comorbidities of retinal diseases.
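
When reading the coefficients between two binary diagnosis columns, it can help to remember that the Pearson correlation of two 0/1 variables is the phi coefficient of their 2x2 contingency table. The sketch below checks this on made-up toy values (the arrays `a` and `b` are illustrative only and are not taken from BRSET).

```python
import numpy as np

# Toy binary indicators (made-up values, NOT taken from BRSET)
a = np.array([1, 1, 1, 0, 0, 0, 0, 0, 1, 0])
b = np.array([1, 1, 0, 0, 0, 0, 0, 1, 1, 0])

# Pearson correlation, as computed by DataFrame.corr() by default
pearson = np.corrcoef(a, b)[0, 1]

# Cells of the 2x2 contingency table
n11 = np.sum((a == 1) & (b == 1))
n10 = np.sum((a == 1) & (b == 0))
n01 = np.sum((a == 0) & (b == 1))
n00 = np.sum((a == 0) & (b == 0))

# Phi coefficient: (n11*n00 - n10*n01) / sqrt(product of the four marginals)
phi = (n11 * n00 - n10 * n01) / np.sqrt(
    (n11 + n10) * (n01 + n00) * (n11 + n01) * (n10 + n00)
)

print(f"Pearson: {pearson:.4f}, phi: {phi:.4f}")
```

Because phi depends on how often the two conditions co-occur relative to their marginal frequencies, rare conditions can show small coefficients even when they are clinically related.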

In [None]:
# Distribution of camera types
plt.figure(figsize=(10, 6))
labels['camera'].value_counts().plot(kind='bar')
plt.title('Distribution of Camera Types')
plt.xlabel('Camera')
plt.ylabel('Count')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show();

In [None]:
# Distribution of image quality
plt.figure(figsize=(10, 6))
labels['quality'].value_counts().plot(kind='bar')
plt.title('Distribution of Image Quality')
plt.xlabel('Quality')
plt.ylabel('Count')
plt.xticks(rotation=0)
plt.tight_layout()
plt.show();

Based on the plot above we can see there is a subset of images labeled as inadequate. We will remove these images from the dataset as they are not useful for training our classifier.

In [None]:
# Shape before dropping inadequate images
labels.shape

In [None]:
# Drop the inadequate quality images (copy so later column assignments act on an independent DataFrame)
labels = labels[labels['quality'] != 'Inadequate'].copy()

In [None]:
# Shape after dropping inadequate quality images
labels.shape

In [None]:
# Distribution of DR severity (DR_ICDR)
plt.figure(figsize=(10, 6))
labels['DR_ICDR'].value_counts().sort_index().plot(kind='bar')
plt.title('Distribution of Diabetic Retinopathy Severity (ICDR scale)')
plt.xlabel('Severity')
plt.ylabel('Count')
plt.xticks(rotation=0)
plt.tight_layout()
plt.show();

In [None]:
# Age distribution
plt.figure(figsize=(10, 6))
sns.histplot(data=labels, x='patient_age', kde=True)
plt.title('Distribution of Patient Age')
plt.xlabel('Age')
plt.ylabel('Count')
plt.tight_layout()
plt.show();

In [None]:
# Diabetes duration distribution
plt.figure(figsize=(10, 6))
labels['diabetes_time_y'] = pd.to_numeric(labels['diabetes_time_y'], errors='coerce')
sns.histplot(data=labels, x='diabetes_time_y', kde=True)
plt.title('Distribution of Diabetes Duration')
plt.xlabel('Years')
plt.ylabel('Count')
plt.tight_layout()
plt.show();

In [None]:
# Relationship between age and diabetes duration
plt.figure(figsize=(10, 6))
sns.scatterplot(data=labels, x='patient_age', y='diabetes_time_y')
plt.title('Relationship between Patient Age and Diabetes Duration')
plt.xlabel('Patient Age')
plt.ylabel('Diabetes Duration (years)')
plt.tight_layout()
plt.show();

In [None]:
# Distribution of patient sex
plt.figure(figsize=(8, 6))
labels['patient_sex'].map({1: 'Male', 2: 'Female'}).value_counts().plot(kind='bar')
plt.title('Distribution of Patient Sex')
plt.xlabel('Sex')
plt.ylabel('Count')
plt.xticks(rotation=0)
plt.tight_layout()
plt.show();

In [None]:
# Distribution of examined eye
plt.figure(figsize=(8, 6))
labels['exam_eye'].map({1: 'Right', 2: 'Left'}).value_counts().plot(kind='bar')
plt.title('Distribution of Examined Eye')
plt.xlabel('Eye')
plt.ylabel('Count')
plt.xticks(rotation=0)
plt.tight_layout()
plt.show();

In [None]:
print("Summary Statistics:")
labels[numerical_columns + binary_diagnoses].describe()

In [None]:
print("\nMissing Values:")
print(labels.isnull().sum())

As seen above, the diabetes_time_y and insulin (`insuline`) columns have many missing values, so we will drop them. The comorbidities column also has many missing values and should be inspected. Patient age has some missing values, which we will impute with the median.
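
The drop-or-impute decision can also be written as a rule on the fraction of missing values per column. The sketch below is one way to do that on a made-up toy frame; the `MAX_MISSING_FRACTION` cutoff of 0.5 is an assumption for illustration and is not a value used elsewhere in this notebook.

```python
import numpy as np
import pandas as pd

# Toy frame with made-up values (NOT taken from BRSET)
toy = pd.DataFrame({
    "patient_age": [61.0, np.nan, 47.0, 70.0, 55.0],
    "diabetes_time_y": [np.nan, np.nan, 10.0, np.nan, np.nan],
})

MAX_MISSING_FRACTION = 0.5  # hypothetical cutoff: drop columns missing more than this

# Fraction of missing values per column
missing_fraction = toy.isna().mean()

# Drop columns that are mostly missing
to_drop = missing_fraction[missing_fraction > MAX_MISSING_FRACTION].index.tolist()
cleaned = toy.drop(columns=to_drop)

# Impute the remaining gaps with each column's median
for col in cleaned.columns:
    if cleaned[col].isna().any():
        cleaned[col] = cleaned[col].fillna(cleaned[col].median())

print("Dropped:", to_drop)
print(cleaned)
```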

In [None]:
# Impute the missing ages with the median age
median_age = labels['patient_age'].median()
labels['patient_age'] = labels['patient_age'].fillna(median_age)

# Drop the diabetes_time_y column and insulin column 
labels = labels.drop(columns=['diabetes_time_y', 'insuline'])

In [None]:
labels['comorbidities'].value_counts()

Based on the comorbidities column, it looks like the large values are not meaningful, so we will drop the column.

In [None]:
# Drop the comorbidities column
labels = labels.drop(columns=['comorbidities'])

In [None]:
labels.isna().sum()

In [None]:
labels.shape

### Images EDA 

In [None]:
# Calculate mean, std, min, and max pixel values across initial 100 images
mean_pixel_value = np.mean(batch_of_images)
std_pixel_value = np.std(batch_of_images)
min_pixel_value = np.min(batch_of_images)
max_pixel_value = np.max(batch_of_images)

print(f"Mean pixel value: {mean_pixel_value:.4f}")
print(f"Std dev of pixel values: {std_pixel_value:.4f}")
print(f"Min pixel value: {min_pixel_value:.4f}")
print(f"Max pixel value: {max_pixel_value:.4f}")

In [None]:
# Plot histogram of pixel intensities
plt.figure(figsize=(10, 6))
plt.hist(batch_of_images.ravel(), bins=50, range=(0, 1))
plt.title("Histogram of Pixel Intensities")
plt.xlabel("Pixel Intensity")
plt.ylabel("Frequency")
plt.show();

In [None]:
# Separate color channels
red_channel = batch_of_images[:, :, :, 0]
green_channel = batch_of_images[:, :, :, 1]
blue_channel = batch_of_images[:, :, :, 2]

# Plot histograms for each channel
plt.figure(figsize=(15, 5))

plt.subplot(131)
plt.hist(red_channel.ravel(), bins=50, color='red', alpha=0.7)
plt.title("Red Channel")

plt.subplot(132)
plt.hist(green_channel.ravel(), bins=50, color='green', alpha=0.7)
plt.title("Green Channel")

plt.subplot(133)
plt.hist(blue_channel.ravel(), bins=50, color='blue', alpha=0.7)
plt.title("Blue Channel")

plt.tight_layout()
plt.show();

In [None]:
# Image channel correlations
r_g_corr = pearsonr(red_channel.ravel(), green_channel.ravel())[0]
r_b_corr = pearsonr(red_channel.ravel(), blue_channel.ravel())[0]
g_b_corr = pearsonr(green_channel.ravel(), blue_channel.ravel())[0]

print(f"Correlation between Red and Green channels: {r_g_corr:.4f}")
print(f"Correlation between Red and Blue channels: {r_b_corr:.4f}")
print(f"Correlation between Green and Blue channels: {g_b_corr:.4f}")

In [None]:
# Calculate brightness 
brightness = np.mean(batch_of_images, axis=3)

plt.figure(figsize=(10, 6))
plt.hist(brightness.ravel(), bins=50)
plt.title("Histogram of Image Brightness")
plt.xlabel("Brightness")
plt.ylabel("Frequency")
plt.show();

In [None]:
# Calculate contrast
contrast = np.std(batch_of_images, axis=3)

plt.figure(figsize=(10, 6))
plt.hist(contrast.ravel(), bins=50)
plt.title("Histogram of Image Contrast")
plt.xlabel("Contrast")
plt.ylabel("Frequency")
plt.show();

# Model Building 

The model will focus on predicting diabetic retinopathy. Diabetic retinopathy is a common complication of diabetes and a leading cause of blindness in adults. Early detection and treatment are crucial for preventing vision loss. The model will be trained on the BRSET dataset, which contains retinal fundus images labeled with various ophthalmological conditions, including diabetic retinopathy.

## Data Pipeline

In [None]:
CLASSIFIER = 'diabetic_retinopathy'

In [None]:
labels[CLASSIFIER].value_counts()

To balance the dataset, we keep every image that is positive for diabetic retinopathy and randomly sample the same number of normal images. Here, a normal image is one with no positive label in any of the diagnosis columns, including the other column.

In [None]:
# Get all rows that are positive for the target condition (CLASSIFIER == 1)
classifier_positive = labels[labels[CLASSIFIER] == 1]

# Get the count of positive cases
condition_positive = len(classifier_positive)

# Create a mask for normal images (all 0 in the binary diagnosis columns)
normal_mask = (
    (labels['diabetic_retinopathy'] == 0) &
    (labels['macular_edema'] == 0) &
    (labels['scar'] == 0) &
    (labels['nevus'] == 0) &
    (labels['amd'] == 0) &
    (labels['vascular_occlusion'] == 0) &
    (labels['hypertensive_retinopathy'] == 0) &
    (labels['drusens'] == 0) &
    (labels['hemorrhage'] == 0) &
    (labels['retinal_detachment'] == 0) &
    (labels['myopic_fundus'] == 0) &
    (labels['increased_cup_disc'] == 0) &
    (labels['other'] == 0)
)

# Get all normal images
normal_images = labels[normal_mask]

# Randomly sample the same number of normal images as there are positive cases
sampled_normal = normal_images.sample(n=condition_positive, random_state=12)

# Combine the positive cases and the sampled normal images
subset_df = pd.concat([classifier_positive, sampled_normal])

# Shuffle the combined dataframe
subset_df = subset_df.sample(frac=1, random_state=12).reset_index(drop=True)

print(f"Total number of normal (no diagnosis) images {len(normal_images)}")
print(f"Sampled normal: {len(sampled_normal)}")
print(f"{CLASSIFIER} positive: {len(classifier_positive)}")
print(f"Total samples in combined subset: {len(subset_df)}")

In [None]:
subset_df.head()

In [None]:
subset_df = subset_df.drop(columns = ['camera', 'nationality', 'other', 'quality', 'patient_id'], errors = 'ignore')

In [None]:
subset_df.columns

In [None]:
subset_df[CLASSIFIER].value_counts()

In [None]:
subset_df[CLASSIFIER].value_counts(normalize=True)

In [None]:
binary_diagnoses = [CLASSIFIER]

IMAGE_SIZE = (224, 224)

def parse_image(filename, labels):
    image = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, IMAGE_SIZE)  # Resize to 224x224
    image = tf.cast(image, tf.float32) / 255.0  # Normalize 0-1
    return image, labels

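def augment_image(image, labels):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    # Brightness and contrast shifts can push pixels outside [0, 1], so clip back
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, labels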

def create_dataset(labels_df, batch_size=32, shuffle=True, augment=False):
    
    filenames = labels_df['image_id'].apply(lambda x: os.path.join(IMAGE_PATH, x + '.jpg')).tolist()
    
    # Get labels
    labels = labels_df[binary_diagnoses].values.astype(np.float32).tolist()

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    
    # Parse images and labels
    dataset = dataset.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    
    if augment:
        dataset = dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
    
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000)
    
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    
    return dataset

In [None]:
train_df, val_df = train_test_split(subset_df, test_size=0.2, random_state=12)

BATCH_SIZE = 16

# Create datasets
train_dataset = create_dataset(train_df, batch_size=BATCH_SIZE, shuffle=True, augment=True)
val_dataset = create_dataset(val_df, batch_size=BATCH_SIZE, shuffle=False, augment=False)

# Inspect the number of batches in the training and validation datasets
print(f"\nNumber of batches in training dataset: {tf.data.experimental.cardinality(train_dataset)}")
print(f"Number of batches in validation dataset: {tf.data.experimental.cardinality(val_dataset)}")

# Inspect the first batch of the training dataset
for images, labels_batch in train_dataset.take(1):
    print(f"\nShape of the image batch: {images.shape}")
    print(f"Shape of the labels batch: {labels_batch.shape}")
    print(f"Sample labels from the first image: {labels_batch[0]}")

In [None]:
# Inspect the first batch of the training dataset
for images, labels_batch in train_dataset.take(1):
    print(f"\nShape of the image batch: {images.shape}")
    print(f"Shape of the labels batch: {labels_batch.shape}")
    print("\nLabels for each image in the batch:")
    # Use a distinct loop variable so the global labels DataFrame is not overwritten
    for i, image_labels in enumerate(labels_batch):
        print(f"Image {i+1}: {image_labels.numpy()}")

# Print unique label combinations
print("\nUnique label combinations in the batch:")
unique_labels = np.unique(labels_batch.numpy(), axis=0)
for label_combo in unique_labels:
    print(label_combo)

# Count of each label
print("\nCount of each label in the batch:")
label_counts = np.sum(labels_batch.numpy(), axis=0)
for i, count in enumerate(label_counts):
    print(f"{binary_diagnoses[i]}: {count}")

## Image classifier

We will start with a simple convolutional classifier and then train a second classifier based on the VGG16 architecture.

**Simple Classifier**

In [None]:
classification_model = models.Sequential([
    layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(128, 3, activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(128, 3, activation='relu'),
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(64, activation='relu'),
    layers.Dense(1, activation='sigmoid')
])

# Optimizer for model 
optimizer = Adam(learning_rate=0.0001)
classification_model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])


# Define callbacks

# Learning rate scheduler
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.0000001)

# Early stopping
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
    verbose=1
)

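# Callbacks passed to fit(). Note that early_stopping is defined above but is
# not included here, so training runs for the full number of epochs.
callbacks = [
    reduce_lr,
]

# Train the model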
history = classification_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=150,
    callbacks=callbacks
)

classification_model.save(f'{BASE_DIR}classification_model.keras')

In [None]:
# Plot training & validation accuracy values
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

# Plot training & validation loss values
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

plt.tight_layout()
plt.show();

**VGG16 Model**

In [None]:
# Create a VGG16-style model, trained from scratch, with a sigmoid output for binary classification
def create_vgg16(input_shape=(224, 224, 3), num_classes=1):
    model = models.Sequential([
        # Block 1
        layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=input_shape),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2), strides=(2, 2)),
        
        # Block 2
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2), strides=(2, 2)),
        
        # Block 3
        layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2), strides=(2, 2)),
        
        # Block 4
        layers.Conv2D(512, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(512, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(512, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2), strides=(2, 2)),
        
        # Block 5
        layers.Conv2D(512, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(512, (3, 3), activation='relu', padding='same'),
        layers.Conv2D(512, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2), strides=(2, 2)),
        
        # Classification block
        layers.Flatten(),
        layers.Dense(4096, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(4096, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='sigmoid' if num_classes == 1 else 'softmax')
    ])
    
    return model

# Create the model
vgg_model = create_vgg16()

# Compile the model
optimizer = Adam(learning_rate=0.0001)
vgg_model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

# Define callbacks
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.0000001)
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
    verbose=1
)

# early_stopping is defined above but not passed to fit(), so all epochs are run
callbacks = [reduce_lr]

# Train the model
history = vgg_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=150,
    callbacks=callbacks
)

# Save the trained model
vgg_model.save(f'{BASE_DIR}vgg_model.keras')

In [None]:
# Plot training & validation accuracy values
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

# Plot training & validation loss values
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')

plt.tight_layout()
plt.show();
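
Accuracy curves alone do not show how the errors split between missed cases and false alarms, which matters for a screening task such as diabetic retinopathy detection. The sketch below computes sensitivity and specificity from sigmoid outputs with NumPy. `binary_metrics` is a hypothetical helper and the toy arrays are made up; they are not results from either classifier above.

```python
import numpy as np

def binary_metrics(y_true, y_prob, threshold=0.5):
    """Compute accuracy, sensitivity, and specificity from sigmoid outputs."""
    y_true = np.asarray(y_true).ravel().astype(int)
    y_pred = (np.asarray(y_prob).ravel() >= threshold).astype(int)

    # Cells of the confusion matrix
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))

    return {
        "accuracy": (tp + tn) / len(y_true),
        "sensitivity": tp / (tp + fn) if (tp + fn) else float("nan"),
        "specificity": tn / (tn + fp) if (tn + fp) else float("nan"),
        "confusion": [[tn, fp], [fn, tp]],  # rows: true 0/1, columns: predicted 0/1
    }

# Made-up labels and predicted probabilities for illustration only
toy_true = [1, 1, 1, 0, 0, 0]
toy_prob = [0.9, 0.6, 0.4, 0.2, 0.7, 0.1]

results = binary_metrics(toy_true, toy_prob)
print(results)
```

With the objects defined in this notebook, the same helper could presumably be fed the validation labels and the output of `predict` on `val_dataset`, since that dataset is built without shuffling.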

## StylexGenerator 

In [None]:
class StylexGenerator(tf.keras.Model):
    def __init__(self, latent_dim, img_shape, **kwargs):
        super(StylexGenerator, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.img_shape = img_shape
        self.model = models.Sequential([
            layers.Dense(256, input_dim=latent_dim),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(momentum=0.8),
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(momentum=0.8),
            layers.Dense(1024),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(momentum=0.8),
            # Sigmoid keeps generated pixels in [0, 1], matching the normalized real images
            layers.Dense(int(tf.math.reduce_prod(img_shape)), activation='sigmoid'),
            layers.Reshape(img_shape)
        ])

    def call(self, z):
        return self.model(z)
    
    def get_config(self):
        config = super(StylexGenerator, self).get_config()
        config.update({
            "latent_dim": self.latent_dim,
            "img_shape": self.img_shape
        })
        return config
    
    @classmethod
    def from_config(cls, config):
        return cls(**config)

class StylexDiscriminator(tf.keras.Model):
    def __init__(self, img_shape, **kwargs):
        super(StylexDiscriminator, self).__init__(**kwargs)
        self.img_shape = img_shape
        self.model = models.Sequential([
            layers.Flatten(input_shape=img_shape),
            layers.Dense(512),
            layers.LeakyReLU(alpha=0.2),
            layers.Dense(256),
            layers.LeakyReLU(alpha=0.2),
            layers.Dense(1, activation='sigmoid')
        ])

    def call(self, img):
        return self.model(img)
    
    def get_config(self):
        config = super(StylexDiscriminator, self).get_config()
        config.update({
            "img_shape": self.img_shape
        })
        return config
    
    @classmethod
    def from_config(cls, config):
        return cls(**config)

# Define loss functions
def generator_loss(fake_output):
    return tf.keras.losses.binary_crossentropy(tf.ones_like(fake_output), fake_output)

def discriminator_loss(real_output, fake_output):
    real_loss = tf.keras.losses.binary_crossentropy(tf.ones_like(real_output), real_output)
    fake_loss = tf.keras.losses.binary_crossentropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def classifier_loss(labels, cls_outputs):
    return tf.keras.losses.binary_crossentropy(labels, cls_outputs)

# Define the train_step function
@tf.function
def train_step(real_images, labels, generator, discriminator, classifier, gen_optimizer, disc_optimizer, threshold=0.5):
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal([batch_size, latent_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = tf.reduce_mean(generator_loss(fake_output))
        disc_loss = tf.reduce_mean(discriminator_loss(real_output, fake_output))

        # Classifier guidance loss. Use the classifier's soft outputs: thresholding
        # them is non-differentiable and would pass no gradient to the generator.
        cls_outputs = classifier(generated_images, training=False)
        c_loss = tf.reduce_mean(classifier_loss(labels, cls_outputs))

        gen_total_loss = gen_loss + c_loss

    gradients_of_generator = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    gen_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss, c_loss

# Setup model training
img_shape = (224, 224, 3)  
latent_dim = 100

generator = StylexGenerator(latent_dim, img_shape)
discriminator = StylexDiscriminator(img_shape)

gen_optimizer = tf.keras.optimizers.Adam(1e-4)
disc_optimizer = tf.keras.optimizers.Adam(1e-4)


user_model = tf.keras.models.load_model(f'{BASE_DIR}classification_model.keras')
dataset = train_dataset

def generate_and_save_images(generator, epoch, num_examples=16):
    # Generate noise for the input
    noise = tf.random.normal([num_examples, generator.latent_dim])
    
    # Generate images
    generated_images = generator(noise, training=False)

    # Rescale images to [0, 1] 
    generated_images = (generated_images + 1) / 2.0 if generated_images.numpy().min() < 0 else generated_images
    
    # Plot the generated images
    fig = plt.figure(figsize=(4, 4))

    for i in range(num_examples):
        plt.subplot(4, 4, i+1)
        plt.imshow(generated_images[i])
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(f'{BASE_DIR}generated_images_epoch_{epoch}.png')
    plt.close(fig)

    print(f"Images saved for epoch {epoch}")

epochs = 200
checkpoint_interval = 50

for epoch in range(epochs):
    gen_losses = []
    disc_losses = []
    cls_losses = []

    
    for image_batch, label_batch in tqdm(dataset, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch'):
        gen_loss, disc_loss, c_loss = train_step(image_batch, label_batch, generator, discriminator, user_model, gen_optimizer, disc_optimizer)
        gen_losses.append(gen_loss.numpy())
        disc_losses.append(disc_loss.numpy())
        cls_losses.append(c_loss.numpy())
    
    print(f'Epoch {epoch + 1}, Gen Loss: {np.mean(gen_losses):.4f}, Disc Loss: {np.mean(disc_losses):.4f}, Classifier Loss: {np.mean(cls_losses):.4f}')

    # Save checkpoint models
    if (epoch + 1) % checkpoint_interval == 0:
        generator.save(f'{BASE_DIR}stylex_generator_epoch_{epoch+1}.keras')
        discriminator.save(f'{BASE_DIR}stylex_discriminator_epoch_{epoch+1}.keras')

    # Generate and save sample images
    if (epoch + 1) % 10 == 0:
        generate_and_save_images(generator, epoch + 1)

# Save the final models
generator.save(f'{BASE_DIR}stylex_generator.keras')
discriminator.save(f'{BASE_DIR}stylex_discriminator.keras')
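
Classifier guidance only steers the generator if the guidance loss changes smoothly with the classifier's output. The sketch below is a pure-Python illustration, not part of the training code: it compares a finite-difference gradient of binary cross-entropy computed on a soft probability with the same loss computed on a thresholded prediction. The helper names (`soft_loss`, `hard_loss`, `finite_diff`) are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce(label, prob, eps=1e-7):
    # Binary cross-entropy for one probability, clipped for numerical safety.
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))

def soft_loss(logit, label):
    # Loss on the classifier's soft probability.
    return bce(label, sigmoid(logit))

def hard_loss(logit, label, threshold=0.5):
    # Thresholding turns the probability into a step function of the logit.
    return bce(label, 1.0 if sigmoid(logit) > threshold else 0.0)

def finite_diff(fn, x, label, h=1e-4):
    # Central finite difference as a stand-in for the gradient.
    return (fn(x + h, label) - fn(x - h, label)) / (2.0 * h)

soft_grad = finite_diff(soft_loss, 1.0, 1.0)
hard_grad = finite_diff(hard_loss, 1.0, 1.0)
print(f"soft gradient: {soft_grad:.4f}, hard gradient: {hard_grad:.4f}")
```

Away from the threshold itself, the thresholded loss is piecewise constant, so its gradient is zero and it cannot guide the generator.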

In [None]:
stylegan2_train, stylegan2_val = train_test_split(labels, test_size=0.2, random_state=12)

BATCH_SIZE = 16

# Create datasets
stylegan2_train_dataset = create_dataset(stylegan2_train, batch_size=BATCH_SIZE, shuffle=True, augment=True)
stylegan2_val_dataset = create_dataset(stylegan2_val, batch_size=BATCH_SIZE, shuffle=False, augment=False)

# Inspect the number of batches in the training and validation datasets
print(f"\nNumber of batches in training dataset: {tf.data.experimental.cardinality(stylegan2_train_dataset)}")
print(f"Number of batches in validation dataset: {tf.data.experimental.cardinality(stylegan2_val_dataset)}")

# Inspect the first batch of the training dataset
for images, labels_batch in stylegan2_train_dataset.take(1):
    print(f"\nShape of the image batch: {images.shape}")
    print(f"Shape of the labels batch: {labels_batch.shape}")
    print(f"Sample labels from the first image: {labels_batch[0]}")

Train the StyleGAN2 model on the entire BRSET dataset, not just the diabetic retinopathy images. This helps the generator produce synthetic images that are representative of the whole dataset.

## StyleGAN2 Implementation
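
The generator in the next cell reshapes a dense projection to a 7×7×512 feature map and applies `UpSampling2D` after each of its five style blocks. A quick way to confirm that this reaches the 224×224 resolution used elsewhere in the notebook is to work through the arithmetic. The helpers below (`output_resolution`, `required_base_size`) are hypothetical and assume the default upsampling factor of 2.

```python
def output_resolution(base_size, num_upsamples, factor=2):
    # Each upsampling layer multiplies height and width by `factor`.
    return base_size * factor ** num_upsamples

def required_base_size(target_size, num_upsamples, factor=2):
    # Inverse check: the target must be divisible by factor ** num_upsamples.
    scale = factor ** num_upsamples
    if target_size % scale != 0:
        raise ValueError(
            f"{target_size} is not divisible by {scale}; "
            "change the number of blocks or the target size"
        )
    return target_size // scale

print(output_resolution(7, 5))       # resolution produced by a 7x7 start and 5 upsamplings
print(required_base_size(224, 5))    # starting size needed to reach 224 with 5 upsamplings
```

The same check is useful when changing the number of style blocks, since the starting feature map size has to change with it.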

In [None]:
# StyleGAN2 Generator and Discriminator
class AdaIN(layers.Layer):
    def __init__(self, **kwargs):
        super(AdaIN, self).__init__(**kwargs)

    def build(self, input_shape):
        content_shape, style_shape = input_shape
        self.channels = content_shape[-1]
        self.style_scale = self.add_weight(name="style_scale", shape=(style_shape[-1], self.channels), initializer="random_normal")
        self.style_bias = self.add_weight(name="style_bias", shape=(style_shape[-1], self.channels), initializer="random_normal")
        super(AdaIN, self).build(input_shape)

    def call(self, inputs):
        content, style = inputs
        mean, var = tf.nn.moments(content, axes=[1, 2], keepdims=True)
        normalized = (content - mean) / tf.sqrt(var + 1e-8)
        
        style = tf.expand_dims(style, axis=1)
        style = tf.expand_dims(style, axis=1)
        
        scale = tf.matmul(style, self.style_scale)
        bias = tf.matmul(style, self.style_bias)
        
        return scale * normalized + bias

    def get_config(self):
        config = super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

class StyleBlock(layers.Layer):
    def __init__(self, filters, kernel_size, **kwargs):
        super(StyleBlock, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.conv = layers.Conv2D(filters, kernel_size, padding="same", use_bias=False)
        self.adain = AdaIN()
        self.activation = layers.LeakyReLU(0.2)

    def call(self, inputs):
        x, w = inputs
        x = self.conv(x)
        x = self.adain([x, w])
        return self.activation(x)

    def get_config(self):
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

class MappingNetwork(keras.Model):
    def __init__(self, latent_dim, n_layers=8, **kwargs):
        super(MappingNetwork, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.n_layers = n_layers
        self.layers_list = []
        for _ in range(n_layers):
            self.layers_list.append(layers.Dense(latent_dim, activation='leaky_relu'))
        self.layers_list.append(layers.Dense(latent_dim))

    def call(self, inputs):
        x = inputs
        for layer in self.layers_list:
            x = layer(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update({
            "latent_dim": self.latent_dim,
            "n_layers": self.n_layers
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

class StyleGAN2Generator(keras.Model):
    def __init__(self, latent_dim, **kwargs):
        super(StyleGAN2Generator, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.mapping = MappingNetwork(latent_dim)
        
        self.input_dense = layers.Dense(7 * 7 * 512)
        
        self.conv_blocks = [
            StyleBlock(512, 3),
            StyleBlock(256, 3),
            StyleBlock(128, 3),
            StyleBlock(64, 3),
            StyleBlock(32, 3),
        ]
        
        self.to_rgb = layers.Conv2D(3, 1, padding="same", activation="tanh")

    def call(self, inputs):
        w = self.mapping(inputs)
        
        x = self.input_dense(w)
        x = layers.Reshape((7, 7, 512))(x)
        
        for block in self.conv_blocks:
            x = block([x, w])
            x = layers.UpSampling2D()(x)
        
        return self.to_rgb(x)

    def get_config(self):
        config = super().get_config()
        config.update({
            "latent_dim": self.latent_dim,
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

class StyleGAN2Discriminator(keras.Model):
    def __init__(self, **kwargs):
        super(StyleGAN2Discriminator, self).__init__(**kwargs)
        self.conv_blocks = [
            layers.Conv2D(64, 3, strides=2, padding="same"),
            layers.Conv2D(128, 3, strides=2, padding="same"),
            layers.Conv2D(256, 3, strides=2, padding="same"),
            layers.Conv2D(512, 3, strides=2, padding="same"),
            layers.Conv2D(512, 3, strides=2, padding="same"),
        ]
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(512, activation='leaky_relu')
        self.dense2 = layers.Dense(1)

    def call(self, inputs):
        if isinstance(inputs, tuple):
            x = inputs[0]  # Take only the image, ignore the label
        else:
            x = inputs
        
        for block in self.conv_blocks:
            x = block(x)
            x = layers.LeakyReLU(0.2)(x)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)

    def get_config(self):
        config = super().get_config()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

# Loss functions
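def generator_loss(fake_output):
    # The StyleGAN2 discriminator returns raw logits (its final Dense layer has
    # no activation), so the loss must be computed with from_logits=True.
    return tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(tf.ones_like(fake_output), fake_output, from_logits=True))

def discriminator_loss(real_output, fake_output):
    # Reduce each term to a scalar so real and fake batches may differ in size.
    real_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(tf.ones_like(real_output), real_output, from_logits=True))
    fake_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(tf.zeros_like(fake_output), fake_output, from_logits=True))
    return real_loss + fake_loss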

# Training step
@tf.function
def train_step(real_images, generator, discriminator, gen_optimizer, disc_optimizer, batch_size, latent_dim):
    if isinstance(real_images, tuple):
        real_images = real_images[0]  # Take only the images, ignore the labels
    
    noise = tf.random.normal([batch_size, latent_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    gen_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return tf.reduce_mean(gen_loss), tf.reduce_mean(disc_loss)

# Setup and training
latent_dim = 100
batch_size = 16
img_shape = (224, 224, 3)
checkpoint_interval = 50

generator = StyleGAN2Generator(latent_dim)
discriminator = StyleGAN2Discriminator()

# Create the optimizers (Adam with beta_1=0.0 and beta_2=0.99)
gen_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
disc_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.0, beta_2=0.99, epsilon=1e-8)


epochs = 500
for epoch in range(epochs):
    total_gen_loss = 0.0
    total_disc_loss = 0.0
    num_batches = 0
    
    # Iterate over the batched tf.data dataset, not the label DataFrame split
    progress_bar = tqdm(stylegan2_train_dataset, desc=f'Epoch {epoch + 1}/{epochs}')
    
    for batch in progress_bar:
        gen_loss, disc_loss = train_step(batch, generator, discriminator, gen_optimizer, disc_optimizer, batch_size, latent_dim)
        total_gen_loss += gen_loss.numpy()
        total_disc_loss += disc_loss.numpy()
        num_batches += 1
        
        # Update progress bar description with current losses
        progress_bar.set_postfix({
            'Gen Loss': f'{gen_loss.numpy():.4f}',
            'Disc Loss': f'{disc_loss.numpy():.4f}'
        })
    
    avg_gen_loss = total_gen_loss / num_batches
    avg_disc_loss = total_disc_loss / num_batches
    
    print(f'\nEpoch {epoch + 1}, Avg Gen Loss: {avg_gen_loss:.4f}, Avg Disc Loss: {avg_disc_loss:.4f}')

    # Save checkpoint models
    if (epoch + 1) % checkpoint_interval == 0:
        generator.save(f'{BASE_DIR}stylegan2_generator_epoch_{epoch+1}.keras')
        discriminator.save(f'{BASE_DIR}stylegan2_discriminator_epoch_{epoch+1}.keras')

    if (epoch + 1) % 10 == 0:
        # Generate and save sample images
        noise = tf.random.normal([1, latent_dim])
        generated_images = generator(noise, training=False)
        # Save the generated image
        plt.imshow(generated_images[0] * 0.5 + 0.5)  # Rescale from [-1, 1] to [0, 1]
        plt.axis('off')
        plt.savefig(f'{BASE_DIR}generated_image_epoch_{epoch+1}.png')
        plt.close()
        

generator.save(f'{BASE_DIR}stylegan2_generator.keras')
discriminator.save(f'{BASE_DIR}stylegan2_discriminator.keras')
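
The StyleGAN2 discriminator above ends in `Dense(1)` with no activation, so its outputs are unbounded logits rather than probabilities. Binary cross-entropy has to be told which of the two it receives; Keras exposes this through the `from_logits` argument of `binary_crossentropy`. The pure-Python sketch below compares the numerically stable logit form with the probability form. The function names are illustrative and not part of the notebook.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_from_prob(label, prob, eps=1e-7):
    # Probability form: expects a value in (0, 1); anything else is clipped.
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))

def bce_from_logit(label, logit):
    # Stable logit form: max(z, 0) - z * y + log(1 + exp(-|z|)).
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

# The two forms agree when the probability is sigmoid(logit).
for z in (-3.0, 0.0, 2.5):
    for y in (0.0, 1.0):
        print(z, y, bce_from_logit(y, z), bce_from_prob(y, sigmoid(z)))

# Feeding a logit into the probability form silently clips it instead.
mistaken = bce_from_prob(1.0, 2.5)
correct = bce_from_logit(1.0, 2.5)
print(mistaken, correct)
```

When a logit is passed where a probability is expected, clipping hides the error: the loss collapses toward zero for confident outputs and the discriminator receives almost no training signal.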

## Extract Features

In [None]:
def extract_attributes(generator, classifier, num_samples, latent_dim, num_attributes, batch_size=100):
    attributes = []
    for i in range(latent_dim):
        noise = tf.random.normal([num_samples, latent_dim])
        base_imgs = generator(noise)
        base_preds = classifier(base_imgs)

        pred_diffs = []
        for start in range(0, num_samples, batch_size):
            end = min(start + batch_size, num_samples)
            noise_mod = noise[start:end].numpy()
            noise_mod[:, i] += 0.1  # Small perturbation
            mod_imgs = generator(noise_mod)
            mod_preds = classifier(mod_imgs)

            pred_diff = tf.reduce_mean(tf.abs(mod_preds - base_preds[start:end]))
            pred_diffs.append(pred_diff.numpy())

        avg_pred_diff = np.mean(pred_diffs)
        attributes.append((i, avg_pred_diff))
    
    attributes.sort(key=lambda x: x[1], reverse=True)
    return [attr[0] for attr in attributes[:num_attributes]]

def visualize_attribute(generator, attribute_idx, latent_dim):
    noise = tf.random.normal([1, latent_dim])
    base_img = generator(noise)
    
    noise_mod = noise.numpy()
    noise_mod[0, attribute_idx] += 1  # Increase attribute
    mod_img = generator(noise_mod)
    
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(base_img[0].numpy() * 0.5 + 0.5)  # Rescale from [-1, 1] to [0, 1]
    plt.title("Original Image")
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(mod_img[0].numpy() * 0.5 + 0.5)  # Rescale from [-1, 1] to [0, 1]
    plt.title(f"Modified Image (Attribute {attribute_idx})")
    plt.axis('off')
    
    plt.show()
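
The ranking idea in `extract_attributes` can be sanity-checked on a toy problem where the answer is known in advance. The sketch below swaps the generator and classifier for hypothetical stand-ins (`toy_generator`, `toy_classifier`), perturbs one latent dimension at a time by the same 0.1 step, and ranks dimensions by the mean absolute change in the prediction. Unlike the function above, it reuses a single set of samples for every dimension, which is an assumption made here so that the scores are directly comparable.

```python
import math
import random

# Hypothetical stand-ins: the "generator" scales each latent dimension by a
# fixed weight, and the "classifier" squashes the sum of the result.
WEIGHTS = [0.0, 2.0, 0.1, -3.0, 0.5]

def toy_generator(z):
    return [w * v for w, v in zip(WEIGHTS, z)]

def toy_classifier(x):
    return 1.0 / (1.0 + math.exp(-sum(x)))

def rank_latent_dims(generator, classifier, latent_dim, num_samples=200, step=0.1, seed=0):
    rng = random.Random(seed)
    samples = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(num_samples)]
    base_preds = [classifier(generator(z)) for z in samples]

    scores = []
    for i in range(latent_dim):
        total = 0.0
        for z, base in zip(samples, base_preds):
            z_mod = list(z)
            z_mod[i] += step  # Small perturbation of one latent dimension
            total += abs(classifier(generator(z_mod)) - base)
        scores.append((i, total / num_samples))

    # Most influential dimensions first.
    scores.sort(key=lambda item: item[1], reverse=True)
    return scores

ranking = rank_latent_dims(toy_generator, toy_classifier, latent_dim=len(WEIGHTS))
top_dims = [idx for idx, _ in ranking[:2]]
print(ranking)
```

The dimensions with the largest absolute weights come out on top, and the dimension with zero weight scores exactly zero, which is the behaviour the ranking is meant to capture.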

#### StyleGAN2 Features

In [None]:
# Load the models
custom_objects = {
    'StyleGAN2Generator': StyleGAN2Generator,
    'StyleGAN2Discriminator': StyleGAN2Discriminator,
    'AdaIN': AdaIN,
    'StyleBlock': StyleBlock,
    'MappingNetwork': MappingNetwork
}

with keras.utils.custom_object_scope(custom_objects):
    loaded_generator = keras.models.load_model(f'{BASE_DIR}stylegan2_generator.keras')
    loaded_discriminator = keras.models.load_model(f'{BASE_DIR}stylegan2_discriminator.keras')

print("Models loaded successfully!")


# Use the generator that was just loaded from disk
latent_dim = loaded_generator.latent_dim
num_samples = 1000
num_attributes = 10
classifier = keras.models.load_model(f'{BASE_DIR}vgg_model.keras')
top_attributes = extract_attributes(loaded_generator, classifier, num_samples, latent_dim, num_attributes)

for attr in top_attributes:
    visualize_attribute(loaded_generator, attr, latent_dim)

#### Stylex Features

In [None]:
# Extract from the Stylex model and show the top attributes
# Load the model using custom_object_scope
with tf.keras.utils.custom_object_scope({'StylexGenerator': StylexGenerator, 'StylexDiscriminator': StylexDiscriminator}):
    generator = tf.keras.models.load_model(f'{BASE_DIR}stylex_generator.keras')
classifier = tf.keras.models.load_model(f'{BASE_DIR}classification_model.keras')

latent_dim = 100
num_samples = 1000
num_attributes = 10

top_attributes = extract_attributes(generator, classifier, num_samples, latent_dim, num_attributes)

for attr in top_attributes:
    visualize_attribute(generator, attr, latent_dim)

# Conclusion