# RGB Autoencoder GAN Experiment
## Exploring Color Information in Filters
### Objective:
Investigate how color information propagates into the first convolutional layer of an RGB autoencoder, compared with an autoencoder trained on a single (Red) channel. Train a discriminator to tell the first-layer filters of the two encoders apart, then set up a GAN-like training paradigm in which the RGB encoder is updated to fool it.


#### **Phase 1: Train RGB and Red Channel Autoencoders**

1. **Define Autoencoder Architectures:**
   - Basic autoencoder: two `Conv2D` layers, followed by a dense bottleneck (`Flatten` → `Dense(256)`) and a dense decoder reshaped back to the input shape.
   - One model trained on the RGB input (32x32x3 for CIFAR-10).
   - One model trained on the Red channel only (32x32x1).

2. **Data Preparation:**
   - Use the CIFAR-10 dataset or another RGB dataset, scaled to [0, 1].
   - Extract the Red channel as a single-channel image for the Red autoencoder.

3. **Training:**
   - Train both autoencoders separately.
   - Use mean absolute error (MAE) as the loss and the Adam optimizer.

4. **Save Filters:**
   - Extract and save the filters from the first `Conv2D` layer of each trained autoencoder for later use.
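The data preparation and loss in steps 2 and 3 can be checked on a small random batch before any training. This is a minimal NumPy sketch: the random array is a hypothetical stand-in for normalized CIFAR-10 images, and `mae` is a hypothetical re-implementation of the quantity the `'mae'` loss averages, not the Keras function itself.

```python
import numpy as np

# Hypothetical stand-in batch: 4 RGB images of 32x32 with values in [0, 1],
# shaped like the normalized CIFAR-10 arrays (N, H, W, C).
rng = np.random.default_rng(0)
x = rng.random((4, 32, 32, 3)).astype("float32")

# Slicing with 0:1 keeps the channel axis, so the Red input stays 4D (N, 32, 32, 1),
# which is what a Conv2D input expects; indexing with 0 would drop that axis.
x_red = x[..., 0:1]
x_red_dropped = x[..., 0]
print(x_red.shape, x_red_dropped.shape)  # (4, 32, 32, 1) (4, 32, 32)

def mae(y_true, y_pred):
    """Mean absolute error averaged over every pixel, channel, and sample."""
    return float(np.mean(np.abs(y_true - y_pred)))

# A reconstruction that is off by 0.1 everywhere has an MAE of 0.1.
print(round(mae(x_red, x_red + 0.1), 6))  # 0.1
```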

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10
import numpy as np

# Load CIFAR-10 dataset
(x_train, _), (x_test, _) = cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Extract Red channel for Red autoencoder
x_train_red = x_train[..., 0:1]
x_test_red = x_test[..., 0:1]

# Autoencoder Model
def build_autoencoder(input_shape):
    input_layer = Input(shape=input_shape)
    x = Conv2D(32, (3, 3), activation='relu', padding='same')(input_layer)
    x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dense(np.prod(input_shape), activation='sigmoid')(x)
    x = Reshape(input_shape)(x)
    return Model(input_layer, x)

# Train RGB Autoencoder
rgb_autoencoder = build_autoencoder((32, 32, 3))
rgb_autoencoder.compile(optimizer='adam', loss='mae')
rgb_autoencoder.fit(x_train, x_train, epochs=20, batch_size=128, validation_data=(x_test, x_test))

# Train Red Channel Autoencoder
red_autoencoder = build_autoencoder((32, 32, 1))
red_autoencoder.compile(optimizer='adam', loss='mae')
red_autoencoder.fit(x_train_red, x_train_red, epochs=20, batch_size=128, validation_data=(x_test_red, x_test_red))

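# Extract Filters: layers[0] is the InputLayer, so layers[1] is the first Conv2D.
# get_weights()[0] is its kernel, shaped (3, 3, in_channels, 32).
rgb_filters = rgb_autoencoder.layers[1].get_weights()[0]
red_filters = red_autoencoder.layers[1].get_weights()[0]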

# Save Filters
np.save('rgb_filters.npy', rgb_filters)
np.save('red_filters.npy', red_filters)
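Each saved kernel has the layout Keras uses for channels_last `Conv2D` weights, `(kernel_h, kernel_w, in_channels, filters)`. The sketch below uses zero arrays as hypothetical stand-ins for the saved `.npy` files to show that the two kernels have the same number of filters but a different number of weights per filter (27 vs 9), which matters once both sets are fed to a single discriminator.

```python
import numpy as np

# Keras stores a channels_last Conv2D kernel as (kernel_h, kernel_w, in_channels, filters).
# Zero arrays stand in for the trained weights; only the shapes matter here.
rgb_kernel = np.zeros((3, 3, 3, 32), dtype="float32")  # first layer of the RGB autoencoder
red_kernel = np.zeros((3, 3, 1, 32), dtype="float32")  # first layer of the Red autoencoder

for name, k in [("RGB", rgb_kernel), ("Red", red_kernel)]:
    kh, kw, in_ch, n_filters = k.shape
    per_filter = kh * kw * in_ch
    print(f"{name}: {n_filters} filters x {per_filter} weights = {k.size} total")
# RGB: 32 filters x 27 weights = 864 total
# Red: 32 filters x 9 weights = 288 total
```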

#### **Phase 2: Build and Train the Discriminator**

1. **Discriminator Architecture:**
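   - Input: One first-layer filter per sample, flattened to a vector. Because the RGB kernels have 3 input channels and the Red kernels have 1, the two sets must be brought to the same length before they share one discriminator (for example, by comparing only the red-input slice of each RGB filter).
   - Layers: A simple feedforward neural network (MLP) with a few dense layers and ReLU activations.
   - Output: Binary classification (`RGB` or `Red` filter).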

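2. **Dataset Creation:**
   - Label the filters: RGB filters as `1` and Red filters as `0`.
   - Shuffle and split into training and validation sets. With 32 filters per model this is only 64 samples, so validation accuracy on the held-out split will be noisy.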

3. **Training:**
   - Train the discriminator to classify filters correctly.
   - Use binary cross-entropy loss and an Adam optimizer.
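The discriminator's loss in step 3 is the mean binary cross-entropy over labeled filters. A small NumPy re-implementation (a hypothetical helper, not the Keras loss object) shows how the loss rewards predictions on the correct side of 0.5 under the RGB = `1`, Red = `0` labeling:

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped away from 0 and 1 for stability."""
    y_true = np.asarray(y_true, dtype="float64")
    y_pred = np.clip(np.asarray(y_pred, dtype="float64"), eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))

# Labels follow the convention above: RGB filter = 1, Red filter = 0.
labels = [1, 0]
leans_right = [0.9, 0.2]  # discriminator leans the correct way on both filters
leans_wrong = [0.2, 0.9]  # discriminator leans the wrong way on both filters

print(round(binary_crossentropy(labels, leans_right), 4))  # 0.1643
print(round(binary_crossentropy(labels, leans_wrong), 4))  # 1.956
```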

In [None]:
# Load Filters
rgb_filters = np.load('rgb_filters.npy')
red_filters = np.load('red_filters.npy')

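# Prepare Dataset
# Kernels are shaped (3, 3, in_channels, 32): (3, 3, 3, 32) for RGB and (3, 3, 1, 32) for Red,
# so the two sets do not share a flat input size as-is.
# Assumption: compare only the red-input slice of each RGB filter with each Red filter,
# which gives every sample 3 * 3 = 9 weights.
def filters_to_vectors(kernel):
    # (3, 3, in_channels, n_filters) -> (n_filters, 9): one row per filter, red-input slice only
    return kernel[:, :, 0, :].reshape(9, -1).T

rgb_vecs = filters_to_vectors(rgb_filters)
red_vecs = filters_to_vectors(red_filters)
filters = np.concatenate([rgb_vecs, red_vecs])
labels = np.array([1] * len(rgb_vecs) + [0] * len(red_vecs))
indices = np.random.permutation(len(labels))
filters = filters[indices]
labels = labels[indices]

# Split Dataset (80% train / 20% validation)
split_idx = int(0.8 * len(labels))
train_filters, val_filters = filters[:split_idx], filters[split_idx:]
train_labels, val_labels = labels[:split_idx], labels[split_idx:]

# Build Discriminator: small MLP on flattened filters
def build_discriminator(input_shape):
    input_layer = Input(shape=input_shape)
    x = Dense(64, activation='relu')(input_layer)
    x = Dense(32, activation='relu')(x)
    x = Dense(1, activation='sigmoid')(x)
    return Model(input_layer, x)

discriminator = build_discriminator((filters.shape[1],))  # 9 weights per filter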
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
discriminator.fit(train_filters, train_labels, epochs=20, batch_size=64, validation_data=(val_filters, val_labels))
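When a kernel is turned into discriminator samples, each row should hold the weights of exactly one filter. Because the filter index is the last axis of a Keras `Conv2D` kernel, that axis has to be moved to the front before flattening a full multi-channel kernel. A small NumPy check with a hypothetical 4-filter kernel:

```python
import numpy as np

# Hypothetical kernel with 3x3 spatial size, 3 input channels, and 4 filters,
# filled with distinct values so each weight can be traced.
kernel = np.arange(3 * 3 * 3 * 4, dtype="float32").reshape(3, 3, 3, 4)

# Move the filter axis (last) to the front, then flatten the rest: one row per filter.
per_filter = np.moveaxis(kernel, -1, 0).reshape(kernel.shape[-1], -1)
print(per_filter.shape)  # (4, 27)

# Row i holds exactly the weights of filter i.
for i in range(kernel.shape[-1]):
    assert np.array_equal(per_filter[i], kernel[..., i].ravel())

# A direct reshape has the right shape but mixes filters within each row,
# because the filter axis varies fastest in memory.
naive = kernel.reshape(kernel.shape[-1], -1)
print(np.array_equal(naive[0], kernel[..., 0].ravel()))  # False
```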

#### **Phase 3: Set Up GAN Paradigm**

1. **Generator Setup:**
   - The RGB autoencoder acts as the generator: the weights of its first `Conv2D` layer are adjusted iteratively so that the discriminator classifies them as Red filters.

2. **Training Process:**
   - **Step 1:** Freeze the discriminator.
   - **Step 2:** Fine-tune the first convolutional layer of the RGB autoencoder so that the discriminator can no longer tell its filters from the Red filters.
   - **Step 3:** Alternate: retrain the discriminator on the updated RGB filters, then update the generator again.

3. **Loss Functions:**
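   - Generator loss: Binary cross-entropy between the discriminator's predictions on the RGB filters and the Red label (`0`), so minimizing it pushes those predictions toward "Red".
   - Discriminator loss: Binary cross-entropy for distinguishing RGB filters (`1`) from Red filters (`0`).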
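Written out, with $w$ a flattened first-layer filter, $D(w) \in (0, 1)$ the discriminator's predicted probability that $w$ is an RGB filter, $N$ the number of labeled filters, and labels $y_i = 1$ (RGB) and $y_i = 0$ (Red), the two losses are:

$$
\mathcal{L}_D = -\frac{1}{N}\sum_{i=1}^{N}\Big[y_i \log D(w_i) + (1 - y_i)\log\big(1 - D(w_i)\big)\Big]
$$

$$
\mathcal{L}_G = -\frac{1}{N_{\text{RGB}}}\sum_{w \in \mathcal{W}_{\text{RGB}}} \log\big(1 - D(w)\big)
$$

$\mathcal{L}_G$ is the binary cross-entropy of the RGB filters against target $0$. Minimizing it drives $D(w) \to 0$, i.e. toward the Red label; a target of $1$ would instead reward the RGB filters for looking more like RGB filters.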

In [None]:
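# GAN Losses and Training Loop
# The generator's "output" is the first-layer kernel itself, not that layer's activations,
# so a custom tf.GradientTape loop updates the kernel directly. The discriminator stays
# frozen during the generator step because only the kernel receives gradients.
# Assumption (matching a 9-weight discriminator input): compare the red-input slice
# of each RGB filter with each Red filter.
conv1 = rgb_autoencoder.layers[1]                      # first Conv2D; kernel is (3, 3, 3, 32)
red_vecs = red_filters[:, :, 0, :].reshape(9, -1).T    # (32, 9), fixed Red filters
bce = tf.keras.losses.BinaryCrossentropy()
g_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

for epoch in range(20):
    # Train Discriminator on the current RGB filters (label 1) vs the Red filters (label 0)
    rgb_vecs = conv1.get_weights()[0][:, :, 0, :].reshape(9, -1).T
    d_x = np.concatenate([rgb_vecs, red_vecs])
    d_y = np.array([1] * len(rgb_vecs) + [0] * len(red_vecs))
    d_loss, d_acc = discriminator.train_on_batch(d_x, d_y)

    # Train Generator: push the discriminator to label the RGB filters as Red (target 0)
    with tf.GradientTape() as tape:
        rgb_slice = tf.transpose(tf.reshape(conv1.kernel[:, :, 0, :], (9, -1)))  # (32, 9)
        preds = discriminator(rgb_slice, training=False)
        g_loss = bce(tf.zeros_like(preds), preds)
    grads = tape.gradient(g_loss, [conv1.kernel])
    g_optimizer.apply_gradients(zip(grads, [conv1.kernel]))

    print(f"Epoch {epoch+1}, D Loss: {d_loss:.4f}, D Acc: {d_acc:.4f}, G Loss: {float(g_loss):.4f}")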
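The generator update in this loop is a gradient step on binary cross-entropy against the Red label. To see why that step lowers the discriminator's RGB score, the sketch below swaps the MLP for a hypothetical linear discriminator $D(w) = \sigma(a \cdot w + b)$, for which $-\log(1 - D(w))$ is the softplus of $a \cdot w + b$ and the gradient has the closed form $D(w)\,a$. The vectors `a`, `w` and the bias `b` are illustrative assumptions, not trained values.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical linear discriminator D(w) = sigmoid(a . w + b), standing in for the MLP;
# it outputs the probability that a 9-weight filter vector w is an RGB filter.
rng = np.random.default_rng(0)
a = rng.normal(size=9)
b = 0.5
w = rng.normal(size=9)  # one hypothetical RGB filter (3x3 slice, flattened)

def generator_loss(w):
    # BCE against the Red label 0: -log(1 - sigmoid(z)) = softplus(z), computed stably.
    return np.logaddexp(0.0, a @ w + b)

# The derivative of softplus(z) is sigmoid(z), so the gradient w.r.t. w is D(w) * a.
grad = sigmoid(a @ w + b) * a

# Check the analytic gradient against central finite differences.
h = 1e-6
numeric = np.array([(generator_loss(w + h * e) - generator_loss(w - h * e)) / (2 * h)
                    for e in np.eye(9)])
print(np.allclose(grad, numeric, atol=1e-6))  # True

# One gradient step moves w along -a, lowering the discriminator's RGB score.
lr = 0.1
w_new = w - lr * grad
print(sigmoid(a @ w_new + b) < sigmoid(a @ w + b))  # True
```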

### Next Steps:
- Visualize filter differences using heatmaps or PCA.
- Analyze discriminator accuracy and generator convergence.
- Document findings and adjust hyperparameters if needed.
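For the first next step, PCA on the flattened filters gives a 2-D view of whether the two filter sets separate. A minimal NumPy sketch, assuming the filters have already been flattened to one row per filter; the random arrays are hypothetical placeholders for the saved weights, and `pca_2d` is an illustrative helper:

```python
import numpy as np

def pca_2d(x):
    """Project rows of x onto their top two principal components via SVD of the centered data."""
    centered = x - x.mean(axis=0)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T, s

# Hypothetical stand-ins for 32 flattened RGB filters and 32 flattened Red filters (9 weights each).
rng = np.random.default_rng(0)
rgb_vecs = rng.normal(loc=0.3, size=(32, 9))
red_vecs = rng.normal(loc=0.0, size=(32, 9))

coords, singular_values = pca_2d(np.concatenate([rgb_vecs, red_vecs]))
rgb_2d, red_2d = coords[:32], coords[32:]
print(coords.shape)  # (64, 2)
```

`rgb_2d` and `red_2d` can then be drawn as two scatter series; clear separation along the first component would suggest the discriminator has an easy task.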