Matrix structure:
```
++++   ++++   ++++   ++++  ++++
+X0+   +X1+   +X2+   +X3+  +X4+
++++   ++++   ++++   ++++  ++++

X0 - books x word emb
X1 - books x googleNet cnn emb
X2 - books x inceptionNet cnn emb
X3 - books x resnet cnn emb
X4 - books x VGG cnn emb

E0 - books
E1 - word emb
E2 - googleNet cnn emb
E3 - inceptionNet cnn emb
E4 - resnet cnn emb
E5 - VGG cnn emb
```

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.functional as F
import pickle

In [2]:
# Locations of the pre-computed "books x feature" tensors. Every file lives
# directly inside `dataset_folder`, which already ends with a path separator.
dataset_folder = '../dataset/'
X0_file = f"{dataset_folder}word2vec_emb_tensor.pkl"          # books x word emb
X1_file = f"{dataset_folder}googleNet_cnn_emb_tensor.pkl"     # books x googleNet cnn emb
X2_file = f"{dataset_folder}inceptionNet_cnn_emb_tensor.pkl"  # books x inceptionNet cnn emb
X3_file = f"{dataset_folder}cnn_resnet_emb_tensor.pkl"        # books x resnet cnn emb
X4_file = f"{dataset_folder}vgg_cnn_emb_tensor.pkl"           # books x VGG cnn emb

In [3]:
# Load each relation matrix from disk and echo its shape right away, so a
# missing or malformed file is noticed at the matrix it belongs to.
# NOTE(review): torch.load unpickles the file contents -- only point these
# paths at files you produced yourself.
X0 = torch.load(X0_file)  # books x word emb
print(X0.shape)
X1 = torch.load(X1_file)  # books x googleNet cnn emb
print(X1.shape)
X2 = torch.load(X2_file)  # books x inceptionNet cnn emb
print(X2.shape)
X3 = torch.load(X3_file)  # books x resnet cnn emb
print(X3.shape)
X4 = torch.load(X4_file)  # books x VGG cnn emb
print(X4.shape)

torch.Size([5000, 100])
torch.Size([5000, 1000])
torch.Size([5000, 1000])
torch.Size([5000, 1000])
torch.Size([5000, 1000])


In [4]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder with a 128-unit hidden layer on each side.

    The encoder maps ``input_dim -> 128 -> embedding_dim`` and the decoder
    mirrors it back to ``input_dim``. ReLU is applied after every linear layer
    except the final decoder layer. The encoder output (before its ReLU) is
    kept on ``self.emb`` after every forward pass so callers can read the
    embedding that produced the returned reconstruction.
    """

    def __init__(self, input_dim, embedding_dim):
        super().__init__()
        # Encoder: input_dim -> 128 -> embedding_dim
        self.enc_linear1 = nn.Linear(in_features=input_dim, out_features=128)
        self.enc_linear2 = nn.Linear(in_features=128, out_features=embedding_dim)
        # Decoder: embedding_dim -> 128 -> input_dim
        self.dec_linear1 = nn.Linear(in_features=embedding_dim, out_features=128)
        self.dec_linear2 = nn.Linear(in_features=128, out_features=input_dim)
        # Embedding of the most recent forward pass; None until forward() runs.
        self.emb = None

    def forward(self, x):
        """Return the reconstruction of ``x`` and cache its embedding on ``self.emb``."""
        hidden = torch.relu(self.enc_linear1(x))
        code = self.enc_linear2(hidden)
        # Expose the raw (pre-activation) code as the embedding.
        self.emb = code
        decoded_hidden = torch.relu(self.dec_linear1(torch.relu(code)))
        # The reconstruction is what the training loss is computed on.
        return self.dec_linear2(decoded_hidden)
    

In [5]:
class matrix_factorization():
    """Jointly train one autoencoder per entity to factorize relation matrices.

    Every relation matrix shares the same row entity (``"E0"``) and has its own
    column entity. Each entity gets an ``Autoencoder`` whose input is the
    entity's feature matrix:

    * row entity: all of its relation matrices concatenated along dim 1;
    * column entity: its single relation matrix, transposed so that one row
      corresponds to one member of that entity.

    The training loss is the sum of the per-entity autoencoder reconstruction
    losses plus, for every relation matrix, the MSE between the matrix and
    ``row_embedding @ column_embedding.T``.

    ``matrix_entity_mapping`` maps entity name -> related matrices, e.g.
    ``{"E0": (X0, X1), "E1": X0, "E2": X1}``. The value for a column entity may
    be a bare tensor or a one-element tuple/list. Column entities are paired
    with reconstructed-matrix names ``"X0"``, ``"X1"``, ... in mapping order.
    """

    # Name of the entity that indexes the rows of every relation matrix.
    _ROW_ENTITY = "E0"

    def __init__(self, matrices, entity_list, matrix_entity_mapping, emb_dim):
        """Store the configuration; no model is built until ``init_autoencoders``.

        Args:
            matrices: Matrix names; stored as-is and not read by this class.
            entity_list: Entity names; stored as-is and not read by this class.
            matrix_entity_mapping: Entity name -> related matrices (see class
                docstring).
            emb_dim: Size of every entity embedding.
        """
        self.matrices = matrices
        self.entity_list = entity_list
        self.matrix_entity_mapping = matrix_entity_mapping # {"E0": ["X0", "X1"], "E1": ["X0"], "E2":["X1"]}
        self.emb_dim = emb_dim
        self.autoencoders = {} # {"E0": E0_autoencoder, "E1": E1_ae, ...}
        self.reconstructed_matrices = {} # {"X0": recon_X0, "X1": recon_X1, ...}
        self.embeddings = {} # {"E0": E0_emb, "E1": E1_emb, ...}
        self.concatenated_matrices = {} # {"E0": E0_features, "E1": E1_features, ...}
        self.optim = None
        self.criterion = nn.MSELoss()
        self.batch_size = 50
        self.convergence_threshold = 1e-4
        self.learning_rate = 0.00001
        self.epoch_count = 500

    @staticmethod
    def _single_matrix(entity, related):
        """Return the one relation matrix of a column entity."""
        if isinstance(related, (tuple, list)):
            if len(related) != 1:
                raise ValueError(
                    f"column entity {entity!r} must map to exactly one matrix, "
                    f"got {len(related)}"
                )
            return related[0]
        return related

    def init_autoencoders(self):
        """Build the feature matrix, autoencoder and shared optimizer.

        Prints the size of every entity's feature matrix, in mapping order.

        Raises:
            KeyError: If the row entity ``"E0"`` is missing from the mapping.
            ValueError: If a column entity maps to more than one matrix.
        """
        # Use the mapping given to the constructor, not a same-named global.
        mapping = self.matrix_entity_mapping
        if self._ROW_ENTITY not in mapping:
            raise KeyError(
                f"matrix_entity_mapping has no row entity {self._ROW_ENTITY!r}"
            )

        concatenated = {}
        autoencoders = {}
        for entity, related in mapping.items():
            if entity == self._ROW_ENTITY:
                # Row entity: place all of its matrices side by side.
                blocks = related if isinstance(related, (tuple, list)) else (related,)
                features = torch.cat(tuple(blocks), dim=1)
            else:
                # Column entity: transpose so each row is one entity member.
                features = torch.transpose(self._single_matrix(entity, related), 0, 1)
            print(features.size())
            concatenated[entity] = features
            autoencoders[entity] = Autoencoder(features.size(1), self.emb_dim)

        self.concatenated_matrices = concatenated
        self.autoencoders = autoencoders
        # A single optimizer updates every autoencoder from the joint loss.
        parameters = [
            param
            for autoencoder in autoencoders.values()
            for param in autoencoder.parameters()
        ]
        self.optim = torch.optim.SGD(parameters, lr=self.learning_rate)

    def _run_entity_pass(self, entity, order):
        """Push one entity through its autoencoder in shuffled mini-batches.

        Args:
            entity: Entity name.
            order: Permutation of the entity's row indices for this epoch.

        Returns:
            ``(mean mini-batch reconstruction loss, embedding matrix)`` where
            the embedding matrix has one row per entity member, stored at the
            member's original (unshuffled) index.
        """
        features = self.concatenated_matrices[entity]
        autoencoder = self.autoencoders[entity]
        embeddings = torch.zeros(features.size(0), self.emb_dim)
        batch_losses = []
        for start in range(0, features.size(0), self.batch_size):
            indices = order[start:start + self.batch_size]
            minibatch = features[indices]
            output = autoencoder(minibatch)
            # Scatter the batch embeddings back to their original rows; the
            # assignment stays on the autograd graph.
            embeddings[indices] = autoencoder.emb
            batch_losses.append(self.criterion(output, minibatch))
        if not batch_losses:
            raise ValueError(f"entity {entity!r} has no rows to train on")
        return sum(batch_losses) / len(batch_losses), embeddings

    def train_autoencoder(self):
        """Train all autoencoders jointly for up to ``epoch_count`` epochs.

        Each epoch runs every entity through its autoencoder, rebuilds every
        relation matrix from the entity embeddings, and takes one optimizer
        step on the summed loss. The loss is printed every 10 epochs.

        After epoch 100 training stops as soon as the loss improved by less
        than ``convergence_threshold`` compared with the previous epoch.
        NOTE(review): this also stops when the loss got worse; confirm that is
        intended.

        Raises:
            RuntimeError: If ``init_autoencoders`` has not been called.
        """
        if self.optim is None:
            raise RuntimeError("call init_autoencoders() before train_autoencoder()")

        column_entities = [e for e in self.autoencoders if e != self._ROW_ENTITY]
        targets = {
            entity: self._single_matrix(entity, self.matrix_entity_mapping[entity])
            for entity in column_entities
        }

        # Plain float: keeping loss tensors would keep their autograd history.
        previous_loss = None
        for epoch in range(self.epoch_count):
            # Fresh shuffle of every entity's rows each epoch.
            shuffled_indices = {
                entity: torch.randperm(features.size(0))
                for entity, features in self.concatenated_matrices.items()
            }

            avg_loss = {}
            ent_emb = {}
            for entity, order in shuffled_indices.items():
                avg_loss[entity], ent_emb[entity] = self._run_entity_pass(entity, order)
            aec_loss = sum(avg_loss.values())

            # Relation matrix i ~= row-entity embeddings @ column-entity embeddings^T
            row_emb = ent_emb[self._ROW_ENTITY]
            recon_terms = []
            for index, entity in enumerate(column_entities):
                recon = torch.matmul(row_emb, torch.transpose(ent_emb[entity], 0, 1))
                self.reconstructed_matrices[f"X{index}"] = recon
                recon_terms.append(self.criterion(recon, targets[entity]))
            total_loss = aec_loss + sum(recon_terms)

            self.optim.zero_grad()
            total_loss.backward()
            self.optim.step()

            loss_value = total_loss.item()
            if epoch % 10 == 0:
                print(f"Average loss for epoch {epoch} = {loss_value}")
            if (epoch > 100 and previous_loss is not None
                    and previous_loss - loss_value < self.convergence_threshold):
                print('Convergence!')
                break
            previous_loss = loss_value

    def get_embeddings(self):
        """Encode every entity's full feature matrix and return the embeddings.

        Returns:
            ``self.embeddings``: entity name -> embedding tensor with one row
            per entity member. The tensors carry no autograd history.
        """
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            for entity in self.matrix_entity_mapping:
                autoencoder = self.autoencoders[entity]
                autoencoder(self.concatenated_matrices[entity])
                self.embeddings[entity] = autoencoder.emb
        return self.embeddings
 

In [None]:
# One embedding size shared by every entity.
emb_dim = 50

# Relation-matrix names X0..X4 and entity names E0..E5.
matrices = [f"X{i}" for i in range(5)]
entity_list = [f"E{i}" for i in range(6)]

# E0 (books) appears in every matrix; each other entity owns exactly one.
matrix_entity_mapping = {
    "E0": (X0, X1, X2, X3, X4),
    "E1": X0,
    "E2": X1,
    "E3": X2,
    "E4": X3,
    "E5": X4,
}

# Build the per-entity autoencoders, train them jointly, then read out the
# learned embeddings.
model = matrix_factorization(
    matrices,
    entity_list,
    matrix_entity_mapping,
    emb_dim,
)
model.init_autoencoders()
model.train_autoencoder()
embeddings = model.get_embeddings()

torch.Size([5000, 4100])
torch.Size([100, 5000])
torch.Size([1000, 5000])
torch.Size([1000, 5000])
torch.Size([1000, 5000])
torch.Size([1000, 5000])
Average loss for epoch 0 = 25.892372131347656
Average loss for epoch 10 = 24.21759796142578
Average loss for epoch 20 = 23.493083953857422
Average loss for epoch 30 = 23.044540405273438
Average loss for epoch 40 = 22.731101989746094
Average loss for epoch 50 = 22.499862670898438
Average loss for epoch 60 = 22.32384490966797
Average loss for epoch 70 = 22.183021545410156
Average loss for epoch 80 = 22.06642723083496
Average loss for epoch 90 = 21.967226028442383
Average loss for epoch 100 = 21.881023406982422
Average loss for epoch 110 = 21.80487823486328
Average loss for epoch 120 = 21.73670196533203
Average loss for epoch 130 = 21.67487335205078
Average loss for epoch 140 = 21.618167877197266
Average loss for epoch 150 = 21.56584930419922
Average loss for epoch 160 = 21.517314910888672


In [None]:
# Persist the {entity name: embedding tensor} dict in the working directory.
torch.save(obj=embeddings, f="all_cnn_embeddings.pkl")

In [None]:
# with open('all_cnn_embeddings.pkl', 'wb') as handle:
#     pickle.dump(embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL)