<a href="https://colab.research.google.com/github/Yikes23/Brandix-MarketHub/blob/main/POC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
import torch
import collections
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.distributions import Normal
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

np.random.seed(0)

### Loading and preprocessing MNIST dataset

In [2]:
def fetch_mnist():
  trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
  dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True, transform=trans_mnist)
  dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True, transform=trans_mnist)
  return dataset_train, dataset_test

### VAE components

In [74]:
def create_encoder(input_dim, latent_dim):
    encoder = nn.Sequential(
        nn.Linear(input_dim, 128),
        nn.ReLU(inplace=False),
        nn.Linear(128, 64),
        nn.ReLU(inplace=False)
    )
    mean = nn.Linear(64, latent_dim)
    logvar = nn.Linear(64, latent_dim)
    return encoder, mean, logvar

def create_decoder(latent_dim, output_dim):
    return nn.Sequential(
        nn.Linear(latent_dim, 64),
        nn.ReLU(inplace=False),
        nn.Linear(64, 128),
        nn.ReLU(inplace=False),
        nn.Linear(128, output_dim)
    )

# Reparameterization trick
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

### GAN Components

In [4]:
def create_generator(noise_dim, output_dim):
    return nn.Sequential(
        nn.Linear(noise_dim, 128),
        nn.ReLU(),
        nn.Linear(128, 256),
        nn.ReLU(),
        nn.Linear(256, output_dim),
        nn.Tanh()
    )

def create_discriminator(input_dim):
    return nn.Sequential(
        nn.Linear(input_dim, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 1),
        nn.Sigmoid()
    )

## CNN Model

In [77]:
# # Local CNN model
# class CNN(nn.Module):
#   # Designed for MNIST dataset (grayscale images)
#   # Input channels: 1 (grayscale)
#     def __init__(self, config):
#         super(CNN, self).__init__()
#         self.conv1 = nn.Conv2d(config['num_channels'], 10, kernel_size=5)
#         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
#         self.conv2_drop = nn.Dropout2d()
#         self.fc1 = nn.Linear(320, 50)
#         self.fc2 = nn.Linear(50, config['num_classes'])

#     def forward(self, x):
#         x = nn.functional.relu(nn.functional.max_pool2d(self.conv1(x), 2))
#         x = nn.functional.relu(nn.functional.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
#         x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
#         x = nn.functional.relu(self.fc1(x))
#         x = nn.functional.dropout(x, training=self.training)
#         x = self.fc2(x)
#         return nn.functional.log_softmax(x, dim=1)

In [97]:
def training_simulation(config, dataset):
    device = torch.device(config['device'])
    input_dim = output_dim = config['input_dim']
    num_epochs = int(config['num_epochs'])

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    # Initialize models
    encoder_base, mean_layer, logvar_layer = create_encoder(input_dim, config['latent_dim'])
    decoder_model = create_decoder(config['latent_dim'], output_dim)
    gen_model = create_generator(config['noise_dim'], output_dim)
    disc_model = create_discriminator(input_dim)

    # Move models to device
    models = {
        'encoder': encoder_base.to(device),
        'mean': mean_layer.to(device),
        'logvar': logvar_layer.to(device),
        'decoder': decoder_model.to(device),
        'generator': gen_model.to(device),
        'discriminator': disc_model.to(device)
    }

    ensemble_weights = nn.Parameter(torch.tensor([0.5, 0.5], device=device))

    optimizers = {
        'vae': optim.Adam(
            list(models['encoder'].parameters()) +
            list(models['decoder'].parameters()) +
            list(models['mean'].parameters()) +
            list(models['logvar'].parameters()),
            lr=config['lr']
        ),
        'gen': optim.Adam(models['generator'].parameters(), lr=config['lr']),
        'disc': optim.Adam(models['discriminator'].parameters(), lr=config['lr']),
        'weights': optim.Adam([ensemble_weights], lr=config['lr'])
    }

    def vae_loss(real_data, vae_output, mu, logvar):
        recon = nn.functional.mse_loss(vae_output, real_data, reduction='sum')
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return (recon + kl_div) / real_data.size(0)

    def compute_outputs(real_data, noise):
        # VAE outputs
        encoded = models['encoder'](real_data)
        mu = models['mean'](encoded)
        logvar = models['logvar'](encoded)
        z = reparameterize(mu, logvar)
        vae_output = models['decoder'](z)

        # GAN output - Remove .clone() as it's unnecessary
        gan_output = models['generator'](noise)

        # Ensemble output with softmax weights
        weights = nn.functional.softmax(ensemble_weights, dim=0)

        # Create ensemble output without detaching weights
        e_output = weights[0] * vae_output + weights[1] * gan_output

        return {
            'vae_output': vae_output,
            'gan_output': gan_output,
            'ensemble_output': e_output,
            'mu': mu,
            'logvar': logvar,
            'weights': weights
        }

    def compute_disc_loss(real_data, fake_data):
        real_preds = models['discriminator'](real_data)
        # Create a new tensor for fake data predictions instead of detaching
        fake_data_pred = models['discriminator'](fake_data)

        # Compute losses using mean directly
        real_loss = torch.mean(torch.log(real_preds + 1e-8))
        fake_loss = torch.mean(torch.log(1 - fake_data_pred + 1e-8))

        return -(real_loss + fake_loss)

    def compute_gen_loss(fake_data):
        fake_preds = models['discriminator'](fake_data)
        return -torch.mean(torch.log(fake_preds + 1e-8))

    # Training loop
    for epoch in range(num_epochs):
        for batch_idx, (data, _) in enumerate(train_loader):
            real_data = data.view(data.size(0), -1).to(device)
            noise = torch.randn(real_data.size(0), config['noise_dim']).to(device)

            # Get all outputs
            outputs = compute_outputs(real_data, noise)

            # VAE update
            optimizers['vae'].zero_grad()
            vae_loss_val = vae_loss(real_data, outputs['vae_output'], outputs['mu'], outputs['logvar'])
            vae_loss_val.backward(retain_graph = True)
            optimizers['vae'].step()

            # Discriminator update
            optimizers['disc'].zero_grad()
            d_loss = compute_disc_loss(real_data, outputs['gan_output'])
            d_loss.backward(retain_graph = True)
            optimizers['disc'].step()

            # Generator update
            optimizers['gen'].zero_grad()
            g_loss = compute_gen_loss(outputs['gan_output'])
            g_loss.backward(retain_graph = True)
            optimizers['gen'].step()

            # Ensemble weights update
            optimizers['weights'].zero_grad()

            # Recompute outputs for ensemble update to avoid gradient issues
            fresh_outputs = compute_outputs(real_data, noise)
            ensemble_reconstr_loss = nn.functional.mse_loss(fresh_outputs['ensemble_output'], real_data)
            weight_reg = torch.abs(fresh_outputs['weights'].sum() - 1.0)
            e_loss = ensemble_reconstr_loss + weight_reg

            e_loss.backward()
            optimizers['weights'].step()

            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{batch_idx}]')
                print(f'VAE Loss: {vae_loss_val.item():.4f}, '
                      f'G Loss: {g_loss.item():.4f}, '
                      f'D Loss: {d_loss.item():.4f}')
                print(f'Ensemble Loss: {e_loss.item():.4f}')
                print(f'Weights: VAE={fresh_outputs["weights"][0].item():.2f}, '
                      f'GAN={fresh_outputs["weights"][1].item():.2f}\n')

    return models, ensemble_weights

In [98]:
train_dataset, test_dataset = fetch_mnist()

config = {

    'input_dim': train_dataset.data.shape[1] * train_dataset.data.shape[2],
    'latent_dim': 20,
    'noise_dim': 100,
    'lr': 0.0001,
    'num_channels': 1,
    'num_classes': len(train_dataset.classes),
    'img_size': 28,
    'num_epochs': 10,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    }
# cnn = CNN(config=config).to(config['device'])
# cnn.train()

training_simulation(config, train_dataset)


Epoch [1/10], Step [0]
VAE Loss: 827.2452, G Loss: 0.6595, D Loss: 1.3672
Ensemble Loss: 1.0462
Weights: VAE=0.50, GAN=0.50

Epoch [1/10], Step [100]
VAE Loss: 631.7201, G Loss: 0.6033, D Loss: 1.2252
Ensemble Loss: 0.7998
Weights: VAE=0.50, GAN=0.50

Epoch [1/10], Step [200]
VAE Loss: 569.8990, G Loss: 1.1731, D Loss: 0.5698
Ensemble Loss: 0.7730
Weights: VAE=0.51, GAN=0.49

Epoch [1/10], Step [300]
VAE Loss: 517.7458, G Loss: 1.7815, D Loss: 0.3402
Ensemble Loss: 0.7036
Weights: VAE=0.52, GAN=0.48

Epoch [1/10], Step [400]
VAE Loss: 513.9258, G Loss: 1.8790, D Loss: 0.3258
Ensemble Loss: 0.6804
Weights: VAE=0.52, GAN=0.48

Epoch [1/10], Step [500]
VAE Loss: 458.2980, G Loss: 1.5891, D Loss: 0.5061
Ensemble Loss: 0.6221
Weights: VAE=0.53, GAN=0.47

Epoch [1/10], Step [600]
VAE Loss: 435.8460, G Loss: 1.7354, D Loss: 0.3581
Ensemble Loss: 0.6081
Weights: VAE=0.53, GAN=0.47

Epoch [1/10], Step [700]
VAE Loss: 391.4464, G Loss: 2.9015, D Loss: 0.0864
Ensemble Loss: 0.5540
Weights: VAE=0.

KeyboardInterrupt: 