Generator(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm3): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (upsample1): Upsample(size=(64, 64), mode='bilinear')
)


Generator(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (batchnorm3): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (upsample1): Upsample(size=(64, 64), mode='bilinear')
)

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.model_selection import RandomizedSearchCV
from torch.utils.data import DataLoader, TensorDataset

class Encoder(nn.Module):
    """Convolutional encoder mapping an RGB image to a ``(3, 64, 64)`` code.

    Three ``conv3x3 -> ReLU -> 2x2 max-pool`` stages halve each spatial
    dimension three times (so H x W becomes H/8 x W/8).  The resulting feature
    map is flattened and projected by a linear layer to ``3 * 64 * 64`` values.
    ``forward`` also returns the activation after the second stage; the AAE
    feeds it to its discriminator.

    Args:
        conv1_channels: output channels of the first convolution.
        conv2_channels: output channels of the second convolution; this is
            also the channel count of the returned ``hidden`` tensor.
        conv3_channels: output channels of the third convolution.
        input_height: height of the input images (default 384).
        input_width: width of the input images (default 64).

    Raises:
        ValueError: if either input dimension is smaller than 8, which would
            leave an empty feature map after the three pools.
    """

    def __init__(self, conv1_channels, conv2_channels, conv3_channels,
                 input_height=384, input_width=64):
        super(Encoder, self).__init__()
        if input_height < 8 or input_width < 8:
            raise ValueError("input_height and input_width must be at least 8")

        # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(3, conv1_channels, kernel_size=3, stride=1, padding=1)
        # Shared 2x2 / stride-2 pool, applied after every conv (halves H and W).
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(conv2_channels, conv3_channels, kernel_size=3, stride=1, padding=1)

        # The flatten size must match the conv3 output after three floor-halvings:
        # 48 x 8 for the default 384 x 64 input.
        feat_h = input_height // 8
        feat_w = input_width // 8
        self.fc = nn.Linear(conv3_channels * feat_h * feat_w, 3 * 64 * 64)

    def forward(self, x):
        """Encode ``x`` of shape ``(B, 3, H, W)``.

        Returns:
            A tuple ``(code, hidden)``.  ``code`` has shape ``(B, 3, 64, 64)``.
            ``hidden`` is the second-stage activation with shape
            ``(B, conv2_channels, H/4, W/4)``.
        """
        x = self.pool(torch.relu(self.conv1(x)))  # (B, conv1_channels, H/2, W/2)
        x = self.pool(torch.relu(self.conv2(x)))  # (B, conv2_channels, H/4, W/4)
        hidden = x  # used by the AAE as the discriminator's input
        x = self.pool(torch.relu(self.conv3(x)))  # (B, conv3_channels, H/8, W/8)
        x = x.view(x.size(0), -1)  # flatten to (B, conv3_channels * H/8 * W/8)
        x = self.fc(x)
        return x.view(x.size(0), 3, 64, 64), hidden

class Decoder(nn.Module):
    """Decode a ``(3, 64, 64)`` code back into a ``(3, H, W)`` image.

    A linear layer expands the code to a ``(conv3_channels, H/8, W/8)``
    feature map.  Three stride-2 transposed convolutions (kernel 4, padding 1)
    then each double the spatial size exactly, giving ``(3, H, W)``.  The
    final ``tanh`` bounds outputs to ``[-1, 1]``.

    Args:
        conv3_channels: channel count of the reshaped bottleneck feature map.
        output_height: height of the reconstructed image (default 384).
        output_width: width of the reconstructed image (default 64).

    Raises:
        ValueError: if an output dimension is not a positive multiple of 8,
            since three exact doublings cannot produce it.
    """

    def __init__(self, conv3_channels, output_height=384, output_width=64):
        super(Decoder, self).__init__()
        if (output_height <= 0 or output_width <= 0
                or output_height % 8 or output_width % 8):
            raise ValueError("output_height and output_width must be positive multiples of 8")

        self.conv3_channels = conv3_channels
        self.feat_h = output_height // 8
        self.feat_w = output_width // 8

        self.fc = nn.Linear(3 * 64 * 64, conv3_channels * self.feat_h * self.feat_w)
        # ConvTranspose2d(k=4, s=2, p=1): out = (in - 1) * 2 - 2 + 4 = 2 * in.
        self.deconv1 = nn.ConvTranspose2d(conv3_channels, 64, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Decode ``x`` (any shape with ``3 * 64 * 64`` values per sample).

        Returns:
            A tensor of shape ``(B, 3, output_height, output_width)``.
        """
        x = x.view(x.size(0), -1)  # (B, 3 * 64 * 64)
        x = self.fc(x)
        # Reshape with the configured channel count and keep the batch dim explicit.
        x = x.view(x.size(0), self.conv3_channels, self.feat_h, self.feat_w)
        x = torch.relu(self.deconv1(x))  # (B, 64, H/4, W/4)
        x = torch.relu(self.deconv2(x))  # (B, 32, H/2, W/2)
        # NOTE(review): the tanh range [-1, 1] presumes the training images are
        # normalised to [-1, 1] for the MSE reconstruction loss; confirm.
        x = torch.tanh(self.deconv3(x))  # (B, 3, H, W)
        return x

class Discriminator(nn.Module):
    """Binary classifier over the encoder's intermediate ``hidden`` activation.

    Applies ``conv3x3 (64 filters) -> ReLU -> 2x2 max-pool``, then three fully
    connected layers (``-> 200 -> 50 -> 1``).  A sigmoid maps the output to a
    probability in ``[0, 1]``, suitable for ``nn.BCELoss``.

    Args:
        input_channels: channel count of the incoming feature map.
        hidden_height: height of the incoming feature map (default 96, the
            encoder's second-stage height for a 384-pixel-high input).
        hidden_width: width of the incoming feature map (default 16).
    """

    def __init__(self, input_channels, hidden_height=96, hidden_width=16):
        super(Discriminator, self).__init__()
        self.conv = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # After the single 2x2 pool the map is (64, hidden_height // 2, hidden_width // 2).
        self.fc1 = nn.Linear(64 * (hidden_height // 2) * (hidden_width // 2), 200)
        self.fc2 = nn.Linear(200, 50)
        self.fc3 = nn.Linear(50, 1)

    def forward(self, x):
        """Return a ``(B, 1)`` tensor of probabilities for ``x`` of shape ``(B, C, h, w)``."""
        x = self.pool(torch.relu(self.conv(x)))  # (B, 64, h/2, w/2)
        x = x.view(x.size(0), -1)  # flatten
        x = torch.relu(self.fc1(x))  # (B, 200)
        x = torch.relu(self.fc2(x))  # (B, 50)
        x = torch.sigmoid(self.fc3(x))  # (B, 1)
        return x

class AAE(nn.Module):
    """Adversarial autoencoder: encoder, decoder and latent discriminator.

    The three sub-modules are registered in the order encoder, decoder,
    discriminator.  ``forward`` runs only the encoder and decoder; the
    discriminator is driven separately by the training loop.
    """

    def __init__(self, conv1_channels, conv2_channels, conv3_channels):
        super().__init__()
        self.encoder = Encoder(conv1_channels, conv2_channels, conv3_channels)
        self.decoder = Decoder(conv3_channels)
        # The discriminator consumes the encoder's ``hidden`` output, whose
        # channel count is conv2_channels.
        self.discriminator = Discriminator(conv2_channels)

    def forward(self, x):
        """Return ``(reconstruction, hidden)`` for input ``x``."""
        code, intermediate = self.encoder(x)
        return self.decoder(code), intermediate

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def train(model, dataloader, epochs, lr_ae, lr_d):
    """Train an adversarial autoencoder and return it.

    Each batch does two steps:

    1. Autoencoder step: minimise MSE reconstruction loss plus a BCE term
       that pushes the discriminator's output on the encoder's ``hidden``
       activation towards 1.
    2. Discriminator step: train the discriminator to output 0 on the encoder's
       ``hidden`` activation and 1 on standard-normal samples of the same shape.

    Args:
        model: module exposing ``encoder`` (returning ``(code, hidden)``),
            ``decoder`` and ``discriminator`` sub-modules.
        dataloader: iterable of batches whose first element is the input tensor.
        epochs: number of passes over ``dataloader``.
        lr_ae: Adam learning rate for the encoder and decoder parameters.
        lr_d: Adam learning rate for the discriminator parameters.

    Returns:
        The trained model (moved to the module-level ``device``).

    Raises:
        ValueError: if ``dataloader`` yields no batches during an epoch.
    """
    model = model.to(device)
    model.train()
    optimizer_ae = optim.Adam(
        list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=lr_ae
    )
    optimizer_d = optim.Adam(model.discriminator.parameters(), lr=lr_d)
    reconstruction_loss = nn.MSELoss()
    adversarial_loss = nn.BCELoss()

    for epoch in range(epochs):
        # Accumulate detached tensors so the GPU is only synchronised once per epoch.
        ae_total = 0.0
        d_total = 0.0
        n_batches = 0
        for batch in dataloader:
            real_data = batch[0].to(device)
            batch_size = real_data.size(0)
            prior_labels = torch.ones(batch_size, 1, device=device)  # "looks like the prior"
            encoded_labels = torch.zeros(batch_size, 1, device=device)  # "came from the encoder"

            # --- Autoencoder step ---
            optimizer_ae.zero_grad()
            encoded, hidden = model.encoder(real_data)
            decoded = model.decoder(encoded)
            rec_loss = reconstruction_loss(decoded, real_data)
            # The encoder tries to make its hidden activation look like prior samples.
            g_loss = adversarial_loss(model.discriminator(hidden), prior_labels)
            ae_loss = rec_loss + g_loss
            ae_loss.backward()
            optimizer_ae.step()

            # --- Discriminator step ---
            # zero_grad also discards the gradients ae_loss.backward() left on D.
            optimizer_d.zero_grad()
            # Re-encode with the just-updated encoder; no graph is needed since
            # only the discriminator is optimised here.
            with torch.no_grad():
                _, hidden = model.encoder(real_data)
            d_encoded_loss = adversarial_loss(model.discriminator(hidden), encoded_labels)
            # NOTE(review): the encoder's hidden activation comes from ReLU + max-pool
            # and is therefore non-negative, while this N(0, 1) prior is not; the
            # discriminator can separate them by sign alone.
            prior = torch.randn_like(hidden)
            d_prior_loss = adversarial_loss(model.discriminator(prior), prior_labels)
            d_loss = (d_encoded_loss + d_prior_loss) / 2
            d_loss.backward()
            optimizer_d.step()

            ae_total = ae_total + ae_loss.detach()
            d_total = d_total + d_loss.detach()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("dataloader yielded no batches")
        # Report the epoch mean rather than the last batch's loss.
        print(f'Epoch [{epoch+1}/{epochs}], AE Loss: {float(ae_total) / n_batches:.4f}, '
              f'D Loss: {float(d_total) / n_batches:.4f}')

    return model

In [None]:
def objective(trial):
    """Optuna objective: sample AAE hyper-parameters, train briefly and score.

    Samples the three conv channel counts, log-uniform learning rates for the
    autoencoder and the discriminator, and a batch size.  It then trains an
    ``AAE`` for 10 epochs and returns ``evaluate_model``'s score.

    Args:
        trial: the ``optuna.trial.Trial`` supplied by ``study.optimize``.

    Returns:
        Whatever ``evaluate_model(trained_model, test_data)`` returns.
        NOTE(review): ``evaluate_model``, ``test_data`` and ``dataset`` are not
        defined in this section; they are presumably defined elsewhere in the
        notebook — confirm.
    """
    conv1_channels = trial.suggest_int('conv1_channels', 16, 64)
    conv2_channels = trial.suggest_int('conv2_channels', 32, 128)
    conv3_channels = trial.suggest_int('conv3_channels', 64, 256)
    # suggest_loguniform is deprecated since Optuna 3.0; suggest_float with
    # log=True samples the same log-uniform distribution.
    lr_ae = trial.suggest_float('lr_ae', 1e-5, 1e-2, log=True)
    lr_d = trial.suggest_float('lr_d', 1e-5, 1e-2, log=True)
    batch_size = trial.suggest_categorical('batch_size', [32, 64, 128])

    model = AAE(conv1_channels, conv2_channels, conv3_channels).to(device)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    trained_model = train(model, dataloader, epochs=10, lr_ae=lr_ae, lr_d=lr_d)

    # Evaluate the model (you need to implement this based on your criteria)
    score = evaluate_model(trained_model, test_data)

    return score

import optuna

# Number of hyper-parameter configurations Optuna will try.
N_TRIALS = 100

# Higher objective scores are better, so the study maximises.
study = optuna.create_study(direction='maximize')
study.optimize(func=objective, n_trials=N_TRIALS)

# Hyper-parameters of the highest-scoring trial.
best_params = study.best_trial.params

In [None]:
# Rebuild the best architecture found by the search and train it for longer.
conv_channels = (
    best_params['conv1_channels'],
    best_params['conv2_channels'],
    best_params['conv3_channels'],
)
best_model = AAE(*conv_channels).to(device)

dataloader = DataLoader(dataset, batch_size=best_params['batch_size'], shuffle=True)

final_model = train(
    best_model,
    dataloader,
    epochs=50,
    lr_ae=best_params['lr_ae'],
    lr_d=best_params['lr_d'],
)