# Adversarial Autoencoder


In [3]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


In [2]:
# Seed PyTorch's global random number generator so repeated runs are reproducible.
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x77fd5a6acf10>

In [None]:
# Encoder network
class Encoder(nn.Module):
    """Feed-forward encoder that maps an input vector to a latent code.

    Layer widths are input_dim -> 512 -> 384 -> latent_dim. Each of the two
    hidden layers is followed by BatchNorm1d and LeakyReLU (slope 0.2); the
    final layer is a plain linear projection with no activation.

    Args:
        input_dim: Number of features in each input vector.
        latent_dim: Number of features in the produced latent code.
    """

    def __init__(self, input_dim=768, latent_dim=256):
        super().__init__()
        # Attribute names double as state_dict key prefixes, so they are kept
        # stable for checkpoint compatibility.
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 384)
        self.bn2 = nn.BatchNorm1d(384)
        self.fc3 = nn.Linear(384, latent_dim)

    def forward(self, x):
        """Return the latent code for the batch ``x``."""
        hidden = x
        # Both hidden stages share the same linear -> batch-norm -> LeakyReLU shape.
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = F.leaky_relu(norm(linear(hidden)), negative_slope=0.2)
        return self.fc3(hidden)

In [None]:
# Decoder network
class Decoder(nn.Module):
    """Feed-forward decoder that maps a latent code back to the input space.

    Layer widths are latent_dim -> 384 -> 512 -> output_dim. Each of the two
    hidden layers is followed by BatchNorm1d and LeakyReLU (slope 0.2); the
    output layer is passed through a sigmoid.

    Args:
        latent_dim: Number of features in each latent code.
        output_dim: Number of features in each reconstructed vector.
    """

    def __init__(self, latent_dim=256, output_dim=768):
        super().__init__()
        # Attribute names double as state_dict key prefixes, so they are kept
        # stable for checkpoint compatibility.
        self.fc1 = nn.Linear(latent_dim, 384)
        self.bn1 = nn.BatchNorm1d(384)
        self.fc2 = nn.Linear(384, 512)
        self.bn2 = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, output_dim)

    def forward(self, z):
        """Return the reconstruction for the batch of latent codes ``z``."""
        hidden = z
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = F.leaky_relu(norm(linear(hidden)), negative_slope=0.2)
        # NOTE(review): the sigmoid bounds every output to (0, 1); confirm the
        # training targets are scaled to that range.
        output_logits = self.fc3(hidden)
        return torch.sigmoid(output_logits)

In [None]:
class Discriminator(nn.Module):
    """Feed-forward critic that scores latent codes with a single raw logit.

    Layer widths are latent_dim -> 256 -> 128 -> 1. Each of the two hidden
    layers is followed by BatchNorm1d and LeakyReLU (slope 0.2). The output
    is an unnormalised logit; no sigmoid is applied here.

    Args:
        latent_dim: Number of features in each latent code.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        # Attribute names double as state_dict key prefixes, so they are kept
        # stable for checkpoint compatibility.
        self.fc1 = nn.Linear(latent_dim, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, z):
        """Return one logit per row of the latent batch ``z``."""
        hidden = z
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = F.leaky_relu(norm(linear(hidden)), negative_slope=0.2)
        return self.fc3(hidden)

In [None]:
# Define the full Adversarial Autoencoder
class AdversarialAutoencoder:
    """Adversarial autoencoder: encoder, decoder and a latent-space discriminator.

    The encoder/decoder pair is trained to reconstruct its input, while the
    encoder is additionally trained to make its latent codes look, to the
    discriminator, like samples drawn from a standard normal prior.

    Args:
        input_dim: Number of features in each input vector.
        latent_dim: Number of features in each latent code.
        device: Device on which the networks and batches are placed.
    """

    def __init__(self, input_dim=768, latent_dim=256, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.device = device

        # Initialize networks
        self.encoder = Encoder(input_dim, latent_dim).to(device)
        self.decoder = Decoder(latent_dim, input_dim).to(device)
        self.discriminator = Discriminator(latent_dim).to(device)

        # Initialize optimizers (the discriminator uses a 10x smaller learning rate)
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=0.001)
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(), lr=0.001)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=0.0001)

        # Loss functions; the adversarial loss consumes raw discriminator logits
        self.reconstruction_loss = nn.MSELoss()
        self.adversarial_loss = nn.BCEWithLogitsLoss()

    def train_step(self, x_batch):
        """Run one optimisation step on a single batch.

        The step has two phases: (1) update encoder and decoder on the sum of
        the reconstruction loss and the generator loss, then (2) update the
        discriminator to separate prior samples from encoded samples.

        Args:
            x_batch: Batch of input vectors; moved to ``self.device`` here.

        Returns:
            Dict with float entries ``reconstruction_loss``,
            ``generator_loss`` and ``discriminator_loss``.
        """
        batch_size = x_batch.size(0)
        x_batch = x_batch.to(self.device)

        # Discriminator targets: 1 for prior ("real") codes, 0 for encoded ("fake") codes
        real_target = torch.ones(batch_size, 1, device=self.device)
        fake_target = torch.zeros(batch_size, 1, device=self.device)

        # ---- Phase 1: train the autoencoder ----
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        # Encode and decode the input
        z = self.encoder(x_batch)
        x_reconstructed = self.decoder(z)

        # Compute reconstruction loss
        recon_loss = self.reconstruction_loss(x_reconstructed, x_batch)

        # Generator loss: the encoder is rewarded when the discriminator
        # labels its codes as "real"
        gen_loss = self.adversarial_loss(self.discriminator(z), real_target)

        # Total autoencoder loss
        ae_loss = recon_loss + gen_loss

        # Backpropagate and update encoder/decoder only. Gradients that reach
        # the discriminator here are discarded by the zero_grad() below.
        ae_loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

        # ---- Phase 2: train the discriminator ----
        self.discriminator_optimizer.zero_grad()

        # Sample latent vectors from the prior distribution (standard Gaussian)
        z_prior = torch.randn(batch_size, self.latent_dim).to(self.device)

        # Re-encode with the freshly updated encoder; no graph is needed
        # because the encoder is not trained in this phase
        with torch.no_grad():
            z_encoded = self.encoder(x_batch)

        # Compute discriminator loss
        real_loss = self.adversarial_loss(self.discriminator(z_prior), real_target)
        fake_loss = self.adversarial_loss(self.discriminator(z_encoded), fake_target)
        d_loss = (real_loss + fake_loss) / 2

        # Backpropagate and update parameters
        d_loss.backward()
        self.discriminator_optimizer.step()

        return {
            'reconstruction_loss': recon_loss.item(),
            'generator_loss': gen_loss.item(),
            'discriminator_loss': d_loss.item()
        }

    def train(self, data_loader, epochs=100):
        """Train all three networks and print the mean losses after each epoch.

        Args:
            data_loader: Iterable yielding ``(x_batch, label)`` pairs; the
                second element is ignored.
            epochs: Number of passes over ``data_loader``.

        Returns:
            List with one dict of epoch-averaged losses per epoch.
        """
        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        training_history = []

        for epoch in range(epochs):
            epoch_losses = {'reconstruction_loss': 0, 'generator_loss': 0, 'discriminator_loss': 0}
            batch_count = 0

            for x_batch, _ in data_loader:
                step_losses = self.train_step(x_batch)

                for key in epoch_losses:
                    epoch_losses[key] += step_losses[key]
                batch_count += 1

            # Calculate average losses for the epoch
            for key in epoch_losses:
                epoch_losses[key] /= batch_count

            training_history.append(epoch_losses)

            print(f"Epoch [{epoch+1}/{epochs}] - "
                  f"Recon Loss: {epoch_losses['reconstruction_loss']:.4f}, "
                  f"Gen Loss: {epoch_losses['generator_loss']:.4f}, "
                  f"Disc Loss: {epoch_losses['discriminator_loss']:.4f}")

        return training_history

    def encode(self, x):
        """Return latent codes for ``x`` without tracking gradients.

        Switches the encoder to eval mode and leaves it there.
        """
        self.encoder.eval()
        with torch.no_grad():
            x = x.to(self.device)
            z = self.encoder(x)
        return z

    def decode(self, z):
        """Return reconstructions for latent codes ``z`` without tracking gradients.

        Switches the decoder to eval mode and leaves it there.
        """
        self.decoder.eval()
        with torch.no_grad():
            z = z.to(self.device)
            x_reconstructed = self.decoder(z)
        return x_reconstructed

    def reconstruct(self, x):
        """Encode then decode ``x`` without tracking gradients.

        Switches encoder and decoder to eval mode and leaves them there.
        """
        self.encoder.eval()
        self.decoder.eval()
        with torch.no_grad():
            x = x.to(self.device)
            z = self.encoder(x)
            x_reconstructed = self.decoder(z)
        return x_reconstructed

    def save_model(self, path):
        """Write the state dicts of all three networks to ``path``.

        Optimizer state is not saved.
        """
        torch.save({
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict()
        }, path)

    def load_model(self, path):
        """Restore the three networks from a checkpoint written by ``save_model``."""
        # map_location remaps stored tensors onto this instance's device, so a
        # checkpoint saved on GPU can be loaded on a CPU-only machine.
        checkpoint = torch.load(path, map_location=self.device)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
        self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])


In [None]:
# Load the movie and song embedding tables from parquet files under `root`.
# NOTE(review): `root` is not defined in any cell visible in this notebook;
# it must be set to the directory holding the parquet files before this runs.
movies_path = os.path.join(root, 'movies_embeddings.parquet')
songs_path = os.path.join(root, 'song_embeddings.parquet')

movies_df = pd.read_parquet(movies_path, engine='pyarrow')
songs_df = pd.read_parquet(songs_path, engine='pyarrow')


In [None]:
# Preview the first five rows of the movie table.
movies_df.head(n=5)

In [None]:
# Preview the first five rows of the song table.
songs_df.head(n=5)

In [None]:
# Add a 'label' column marking the source table: movies get 1, songs get 0.
movies_df['label'], songs_df['label'] = 1, 0

# Reduce both tables to just the 'embedding' and 'label' columns.
movies_df, songs_df = (
    frame[['embedding', 'label']] for frame in (movies_df, songs_df)
)


In [None]:
# Vertically stack (concatenate) the movie and song tables into one frame.
# The labelled frames prepared earlier are named movies_df / songs_df.
combined_data = pd.concat([movies_df, songs_df], axis=0, ignore_index=True)

# Shuffle the combined DataFrame (fixed random_state keeps the order repeatable)
# and renumber the rows from zero.
combined_data = combined_data.sample(frac=1, random_state=42).reset_index(drop=True)

# Extract combined embeddings and labels as NumPy arrays.
# Assuming that the embedding column stores a list of numbers for each row.
combined_embeddings = np.array(combined_data['embedding'].tolist())
combined_labels = np.array(combined_data['label'].tolist())

# Display both shapes as the cell output.
combined_embeddings.shape, combined_labels.shape

In [None]:
# Convert embeddings and labels to float32 tensors; both names are rebound
# from their previous values to the new tensors.
combined_embeddings, combined_labels = (
    torch.tensor(values, dtype=torch.float32)
    for values in (combined_embeddings, combined_labels)
)

# Pair each embedding with its label and serve shuffled mini-batches of 64.
combined_dataset = TensorDataset(combined_embeddings, combined_labels)
combined_dataloader = DataLoader(dataset=combined_dataset, shuffle=True, batch_size=64)