In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and print the full path of every file in it.
input_root = '/kaggle/input'
for dirname, _, filenames in os.walk(input_root):
    for filename in filenames:
        full_path = os.path.join(dirname, filename)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/aae/pytorch/default/1/best_aae_model.pt
/kaggle/input/aae/pytorch/default/1/final_aae_model.pt
/kaggle/input/moviesandsongs/song_embeddings.parquet
/kaggle/input/moviesandsongs/movies_embeddings.parquet


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.metrics import mean_squared_error
from sklearn.metrics.pairwise import cosine_similarity

In [5]:
class Encoder(nn.Module):
    """Feed-forward encoder that maps an input vector to a latent code.

    Layer stack:
        Linear(input_dim, 512) -> BatchNorm1d(512) -> LeakyReLU(0.2)
        Linear(512, 384)       -> BatchNorm1d(384) -> LeakyReLU(0.2)
        Linear(384, latent_dim)   (no activation on the output)

    Args:
        input_dim: Size of the last dimension of the input.
        latent_dim: Size of the produced latent code.
    """

    def __init__(self, input_dim=768, latent_dim=256):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay stable
        # for saved checkpoints to keep loading.
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)

        self.fc2 = nn.Linear(512, 384)
        self.bn2 = nn.BatchNorm1d(384)

        self.fc3 = nn.Linear(384, latent_dim)

    def forward(self, x):
        """Return the latent code for ``x``.

        Assumes ``x`` is shaped (batch, input_dim) -- TODO confirm against callers.
        """
        hidden = x
        # Both hidden stages share the same Linear -> BatchNorm -> LeakyReLU pattern.
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            hidden = F.leaky_relu(norm(linear(hidden)), negative_slope=0.2)
        return self.fc3(hidden)

In [7]:
class Decoder(nn.Module):
    """Feed-forward decoder that maps a latent code back to the input space.

    Layer stack:
        Linear(latent_dim, 384) -> BatchNorm1d(384) -> LeakyReLU(0.2)
        Linear(384, 512)        -> BatchNorm1d(512) -> LeakyReLU(0.2)
        Linear(512, output_dim) -> sigmoid

    Args:
        latent_dim: Size of the last dimension of the latent code.
        output_dim: Size of the reconstructed vector.
    """

    def __init__(self, latent_dim=256, output_dim=768):
        super().__init__()
        # Attribute names are the state_dict keys; keep them stable for checkpoints.
        self.fc1 = nn.Linear(latent_dim, 384)
        self.bn1 = nn.BatchNorm1d(384)

        self.fc2 = nn.Linear(384, 512)
        self.bn2 = nn.BatchNorm1d(512)

        self.fc3 = nn.Linear(512, output_dim)

    def forward(self, z):
        """Return the reconstruction for latent code ``z``.

        Assumes ``z`` is shaped (batch, latent_dim) -- TODO confirm against callers.
        """
        stage_one = self.bn1(self.fc1(z))
        stage_one = F.leaky_relu(stage_one, negative_slope=0.2)

        stage_two = self.bn2(self.fc2(stage_one))
        stage_two = F.leaky_relu(stage_two, negative_slope=0.2)

        # NOTE(review): sigmoid confines every output component to [0, 1], while the
        # embeddings displayed in this notebook contain negative components, which
        # such an output cannot reproduce -- confirm this activation is intended.
        output_logits = self.fc3(stage_two)
        return torch.sigmoid(output_logits)

In [8]:
class Discriminator(nn.Module):
    """Feed-forward critic that scores latent codes with a single raw logit.

    Layer stack:
        Linear(latent_dim, 256) -> BatchNorm1d(256) -> LeakyReLU(0.2)
        Linear(256, 128)        -> BatchNorm1d(128) -> LeakyReLU(0.2)
        Linear(128, 1)             (raw logit, no sigmoid applied here)

    Args:
        latent_dim: Size of the last dimension of the latent code.
    """

    def __init__(self, latent_dim=256):
        super().__init__()
        # Attribute names are the state_dict keys; keep them stable for checkpoints.
        self.fc1 = nn.Linear(latent_dim, 256)
        self.bn1 = nn.BatchNorm1d(256)

        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)

        self.fc3 = nn.Linear(128, 1)

    def forward(self, z):
        """Return one logit per latent code in ``z``.

        Assumes ``z`` is shaped (batch, latent_dim) -- TODO confirm against callers.
        """
        features = F.leaky_relu(self.bn1(self.fc1(z)), negative_slope=0.2)
        features = F.leaky_relu(self.bn2(self.fc2(features)), negative_slope=0.2)
        return self.fc3(features)

In [9]:
class AdversarialAutoencoder:
    """Adversarial autoencoder: an encoder/decoder pair plus a latent-space discriminator.

    The encoder and decoder are optimised to reconstruct their input (MSE), and the
    encoder is additionally optimised so that the discriminator labels its codes as
    "real". The discriminator is optimised to separate encoder codes ("fake") from
    samples of a standard normal prior ("real").

    Args:
        input_dim: Size of the vectors being autoencoded.
        latent_dim: Size of the latent code.
        device: Device the three networks are moved to. The default expression is
            evaluated once, when the class is defined.
    """

    def __init__(self, input_dim=768, latent_dim=256, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.device = device

        # Initialize networks
        self.encoder = Encoder(input_dim, latent_dim).to(device)
        self.decoder = Decoder(latent_dim, input_dim).to(device)
        self.discriminator = Discriminator(latent_dim).to(device)

        # Initialize optimizers (the discriminator uses a 10x smaller learning rate)
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=0.001)
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(), lr=0.001)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=0.0001)

        # Loss functions
        self.reconstruction_loss = nn.MSELoss()
        self.adversarial_loss = nn.BCEWithLogitsLoss()

    def train_step(self, x_batch):
        """Run one autoencoder update followed by one discriminator update.

        Args:
            x_batch: Batch of input vectors; its first dimension is the batch size.

        Returns:
            dict with float entries ``'reconstruction_loss'``, ``'generator_loss'``
            and ``'discriminator_loss'`` for this batch.
        """
        batch_size = x_batch.size(0)
        x_batch = x_batch.to(self.device)

        # Target tensors, created directly on the target device (no host->device copy).
        real_target = torch.ones(batch_size, 1, device=self.device)
        fake_target = torch.zeros(batch_size, 1, device=self.device)

        # ---- Train autoencoder (reconstruction + fooling the discriminator) ----
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        z = self.encoder(x_batch)
        x_reconstructed = self.decoder(z)

        recon_loss = self.reconstruction_loss(x_reconstructed, x_batch)

        # The encoder acts as the generator: it wants its codes labelled "real".
        gen_loss = self.adversarial_loss(self.discriminator(z), real_target)

        ae_loss = recon_loss + gen_loss

        # This backward pass also deposits gradients on the discriminator's
        # parameters; they are cleared by zero_grad() below before its own update.
        ae_loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

        # ---- Train discriminator ----
        self.discriminator_optimizer.zero_grad()

        # Prior sample is drawn with the default (CPU) generator and then moved, so
        # a plain torch.manual_seed() keeps controlling it regardless of device.
        z_prior = torch.randn(batch_size, self.latent_dim).to(self.device)

        # Re-encode with the freshly updated encoder. No graph is needed because
        # the encoder must not receive gradients from the discriminator loss.
        with torch.no_grad():
            z_encoded = self.encoder(x_batch)

        real_loss = self.adversarial_loss(self.discriminator(z_prior), real_target)
        fake_loss = self.adversarial_loss(self.discriminator(z_encoded), fake_target)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        self.discriminator_optimizer.step()

        return {
            'reconstruction_loss': recon_loss.item(),
            'generator_loss': gen_loss.item(),
            'discriminator_loss': d_loss.item()
        }

    def train(self, data_loader, epochs=100):
        """Train all three networks and report per-epoch average losses.

        Args:
            data_loader: Iterable yielding ``(x_batch, _)`` pairs; the second
                element of each pair is ignored.
            epochs: Number of passes over ``data_loader``.

        Returns:
            List with one dict per epoch holding the mean of each loss reported
            by ``train_step`` over that epoch's batches.

        Raises:
            ValueError: If ``data_loader`` yields no batches during an epoch
                (this previously surfaced as a ZeroDivisionError while averaging).
        """
        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        training_history = []

        for epoch in range(epochs):
            epoch_losses = {'reconstruction_loss': 0, 'generator_loss': 0, 'discriminator_loss': 0}
            batch_count = 0

            for x_batch, _ in data_loader:
                step_losses = self.train_step(x_batch)

                for key in epoch_losses:
                    epoch_losses[key] += step_losses[key]
                batch_count += 1

            if batch_count == 0:
                raise ValueError("data_loader yielded no batches; cannot compute epoch losses")

            # Calculate average losses for the epoch
            for key in epoch_losses:
                epoch_losses[key] /= batch_count

            training_history.append(epoch_losses)

            print(f"Epoch [{epoch+1}/{epochs}] - "
                  f"Recon Loss: {epoch_losses['reconstruction_loss']:.4f}, "
                  f"Gen Loss: {epoch_losses['generator_loss']:.4f}, "
                  f"Disc Loss: {epoch_losses['discriminator_loss']:.4f}")

        return training_history

    def encode(self, x):
        """Return latent codes for ``x`` without tracking gradients.

        Switches the encoder to eval mode and leaves it there; the returned
        tensor lives on ``self.device``.
        """
        self.encoder.eval()
        with torch.no_grad():
            x = x.to(self.device)
            z = self.encoder(x)
        return z

    def decode(self, z):
        """Return reconstructions for latent codes ``z`` without tracking gradients.

        Switches the decoder to eval mode and leaves it there; the returned
        tensor lives on ``self.device``.
        """
        self.decoder.eval()
        with torch.no_grad():
            z = z.to(self.device)
            x_reconstructed = self.decoder(z)
        return x_reconstructed

    def reconstruct(self, x):
        """Encode then decode ``x`` without tracking gradients.

        Switches the encoder and decoder to eval mode and leaves them there.
        """
        self.encoder.eval()
        self.decoder.eval()
        with torch.no_grad():
            x = x.to(self.device)
            z = self.encoder(x)
            x_reconstructed = self.decoder(z)
        return x_reconstructed

    def save_model(self, path):
        """Write the three networks' state dicts to ``path`` (optimizer state is not saved)."""
        torch.save({
            'encoder_state_dict': self.encoder.state_dict(),
            'decoder_state_dict': self.decoder.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict()
        }, path)

    def load_model(self, path):
        """Load the three networks' state dicts from a file written by ``save_model``.

        ``map_location`` remaps stored tensors onto ``self.device`` so a checkpoint
        saved on a GPU can still be loaded on a CPU-only runtime. ``weights_only``
        restricts unpickling to tensors and plain containers.
        """
        checkpoint = torch.load(path, map_location=self.device, weights_only=True)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
        self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

In [10]:
# Report whether CUDA is usable. The device index and name are only queried when a
# GPU is present, because torch.cuda.current_device() / get_device_name() raise on
# CPU-only runtimes and would break a fresh "Run All".
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name(0))

True
0
Tesla P100-PCIE-16GB


In [11]:
# Inspect the top-level keys stored in the checkpoint.
# weights_only=True limits unpickling to tensors and plain containers (the file only
# holds state dicts, and AdversarialAutoencoder.load_model reads it the same way);
# map_location='cpu' lets the inspection run even when no GPU is attached.
checkpoint = torch.load(
    '/kaggle/input/aae/pytorch/default/1/best_aae_model.pt',
    map_location='cpu',
    weights_only=True,
)
print(checkpoint.keys())

  checkpoint = torch.load('/kaggle/input/aae/pytorch/default/1/best_aae_model.pt')


dict_keys(['encoder_state_dict', 'decoder_state_dict', 'discriminator_state_dict'])


In [12]:
# Build the model with its default dimensions and restore the saved "best" weights.
best_model_path = "/kaggle/input/aae/pytorch/default/1/best_aae_model.pt"
bestModel = AdversarialAutoencoder()
bestModel.load_model(best_model_path)

In [13]:
# Load the movie and song tables; each has an 'embedding' column (see the previews below).
movies_parquet_path = "/kaggle/input/moviesandsongs/movies_embeddings.parquet"
songs_parquet_path = "/kaggle/input/moviesandsongs/song_embeddings.parquet"
movies_df = pd.read_parquet(movies_parquet_path)
songs_df = pd.read_parquet(songs_parquet_path)

In [14]:
# Preview the first rows of the movie table.
movies_df.head()

Unnamed: 0,id,title,overview,genres,embedding
0,27205,Inception,Cobb a skilled thief who commits corporate esp...,"Action, Science Fiction, Adventure","[0.01589059643447399, 0.11273891478776932, -0...."
1,157336,Interstellar,The adventures of a group of explorers who mak...,"Adventure, Drama, Science Fiction","[0.037922028452157974, -0.005655079614371061, ..."
2,155,The Dark Knight,Batman raises the stakes in his war on crime W...,"Drama, Action, Crime, Thriller","[0.011266704648733139, 0.032755907624959946, -..."
3,19995,Avatar,In the 22nd century a paraplegic Marine is dis...,"Action, Adventure, Fantasy, Science Fiction","[0.01744804158806801, 0.03436880186200142, 0.0..."
4,24428,The Avengers,When an unexpected enemy emerges and threatens...,"Science Fiction, Action, Adventure","[0.027801260352134705, -0.019952325150370598, ..."


In [15]:
# Preview the first rows of the song table.
songs_df.head()

Unnamed: 0,title,tag,lyrics,embedding
0,Poor Poor Pitiful Me,country,Well I lay my head on the railroad track\nWait...,"[0.029062896966934204, 0.08223594725131989, -0..."
1,Cuckoos Nest,country,There is a thorn bush\nIn Outcolia\nThere is a...,"[0.0009067401406355202, -0.09515126794576645, ..."
2,Wedding Bells,country,I have the invitation that your sent me\nYou w...,"[0.04617173597216606, 0.013956493698060513, -0..."
3,Could Have Fooled Me,country,Im fading like the taillights\nOf a car that y...,"[0.015817370265722275, -0.0025993576273322105,..."
4,Shot of Glory,country,Its finally payday\nMeeting the boys at my pla...,"[-0.011555955745279789, 0.0511910654604435, 0...."


In [16]:
import matplotlib.pyplot as plt
import seaborn as sns
import random

<class 'torch.Tensor'>


In [17]:
# Stack each table's per-row embedding lists into one float32 tensor.
song_embeddings = torch.tensor(np.array(songs_df['embedding'].tolist()), dtype=torch.float32)
movie_embeddings = torch.tensor(np.array(movies_df['embedding'].tolist()), dtype=torch.float32)

# Encode every row with the loaded model and store each latent code as a plain
# Python list in a new "codes" column (this mutates songs_df / movies_df in place).
song_code_matrix = bestModel.encode(song_embeddings).cpu().detach().numpy()
songs_df["codes"] = [code_row.tolist() for code_row in song_code_matrix]

movie_code_matrix = bestModel.encode(movie_embeddings).cpu().detach().numpy()
movies_df["codes"] = [code_row.tolist() for code_row in movie_code_matrix]

In [20]:
# Preview the song table again to confirm the new "codes" column.
# NOTE(review): the codes shown in the recorded output have magnitudes in the
# thousands, far from the standard normal prior the discriminator is trained
# against -- worth confirming the checkpoint / input scaling.
songs_df.head()

Unnamed: 0,title,tag,lyrics,embedding,codes
0,Poor Poor Pitiful Me,country,Well I lay my head on the railroad track\nWait...,"[0.029062896966934204, 0.08223594725131989, -0...","[4983.107421875, -929.4005737304688, 28013.300..."
1,Cuckoos Nest,country,There is a thorn bush\nIn Outcolia\nThere is a...,"[0.0009067401406355202, -0.09515126794576645, ...","[6021.56787109375, 1232.96337890625, 18822.304..."
2,Wedding Bells,country,I have the invitation that your sent me\nYou w...,"[0.04617173597216606, 0.013956493698060513, -0...","[-64.68852996826172, -4480.0126953125, 24379.0..."
3,Could Have Fooled Me,country,Im fading like the taillights\nOf a car that y...,"[0.015817370265722275, -0.0025993576273322105,...","[4583.91845703125, -1101.9259033203125, 26836...."
4,Shot of Glory,country,Its finally payday\nMeeting the boys at my pla...,"[-0.011555955745279789, 0.0511910654604435, 0....","[8586.7255859375, -6141.3564453125, 30254.4980..."


In [19]:
# Preview the movie table again to confirm the new "codes" column.
movies_df.head()

Unnamed: 0,id,title,overview,genres,embedding,codes
0,27205,Inception,Cobb a skilled thief who commits corporate esp...,"Action, Science Fiction, Adventure","[0.01589059643447399, 0.11273891478776932, -0....","[6462.3095703125, 561.6226806640625, 18281.753..."
1,157336,Interstellar,The adventures of a group of explorers who mak...,"Adventure, Drama, Science Fiction","[0.037922028452157974, -0.005655079614371061, ...","[4852.48681640625, -3510.31396484375, 21318.25..."
2,155,The Dark Knight,Batman raises the stakes in his war on crime W...,"Drama, Action, Crime, Thriller","[0.011266704648733139, 0.032755907624959946, -...","[2426.921630859375, 1094.772216796875, 12682.5..."
3,19995,Avatar,In the 22nd century a paraplegic Marine is dis...,"Action, Adventure, Fantasy, Science Fiction","[0.01744804158806801, 0.03436880186200142, 0.0...","[-7045.265625, -5515.25537109375, 16004.128906..."
4,24428,The Avengers,When an unexpected enemy emerges and threatens...,"Science Fiction, Action, Adventure","[0.027801260352134705, -0.019952325150370598, ...","[5416.56982421875, -5810.353515625, 19996.8769..."


In [21]:
random.seed(42)
def recomend(songs_encoded, movies, k=5):
    """Return the ``k`` rows of ``movies`` whose codes are most similar to a song code.

    Args:
        songs_encoded: Latent code of a single song (1-D sequence of numbers).
        movies: DataFrame with a ``"codes"`` column holding one equally sized
            code per row.
        k: Maximum number of rows to return.

    Returns:
        A copy of the selected rows of ``movies``, ordered by descending cosine
        similarity, with an added ``'similarity'`` column. ``movies`` itself is
        not modified. An empty ``movies`` yields an empty result.
    """
    if len(movies) == 0:
        # np.vstack raises on an empty sequence; return an empty, well-formed result.
        no_matches = movies.copy()
        no_matches['similarity'] = np.empty(0, dtype=float)
        return no_matches

    movie_codes = np.vstack(movies["codes"].values)
    similarities = cosine_similarity([songs_encoded], movie_codes)[0]
    top_indices = np.argsort(similarities)[::-1][:k]
    # Index the frame that was passed in (the original read the global movies_df,
    # silently ignoring the ``movies`` argument).
    recommended_movies = movies.iloc[top_indices].copy()
    recommended_movies['similarity'] = similarities[top_indices]
    return recommended_movies
    

In [24]:
# Seed the stdlib RNG so the same 5 song row positions are drawn on every run.
random.seed(42)
random_song_indices = random.sample(range(len(songs_df)), 5)
# Codes of the sampled songs; not read again within this cell.
random_song_codes = songs_df.iloc[random_song_indices]['codes'].values
# Collects, for each sampled song, the codes of its top-ranked movie.
recomended_movies_codes = []
for i in random_song_indices:
    random_song = songs_df.iloc[i]
    song_codes = random_song['codes']
    # Uses recomend's default k=5.
    recommendations = recomend(song_codes, movies_df)
    # Row 0 is the best match: recomend returns rows in descending similarity order.
    recomended_movies_codes.append(recommendations['codes'].values[0])
    # NOTE(review): the f-string below prints a doubled closing quote ('') after the title.
    print(f"Recommendations for song '{random_song['title']}'':")
    print(recommendations[['title', 'similarity']])
    print("-" * 30)

Recommendations for song '$?'':
                                                title  similarity
484114  Queens of the Stone Age: Live at Pinkpop 2008    0.923329
109055          Busta Rhymes - Everything Remains Raw    0.916486
433125  Garbage - Montreux Jazz Festival - 2005-07-03    0.914472
344479                     Chocolate and Cracker Orgy    0.907997
108114      Beastie Boys: The $kill$ To Pay The Bill$    0.904499
------------------------------
Recommendations for song 'Am I The Only One Who Cares'':
                         title  similarity
321049  You Make My Body Shake    0.922339
98363               Warm Broth    0.921366
326097  As Long As You're Here    0.915221
295983    Darling Teen Sluts 5    0.911001
531752              Jelly Baby    0.910888
------------------------------
Recommendations for song 'Crying Steel Guitar Waltz'':
                    title  similarity
475570         Surrender!    0.917457
185616       Tainted Love    0.916066
329959  Ashes a Video art 