In [6]:
import sys
sys.path.append('../../../')
from models.unsupervised.vae.model import VAE, Decoder, Encoder
from sklearn.feature_selection import mutual_info_regression
import numpy as np
import pickle as pkl
import torch

def compute_mig(true_latents, learned_latents):
    num_latents = true_latents.shape[1]
    mi_matrix = np.zeros((num_latents, num_latents))
    
    for i in range(num_latents):
        for j in range(num_latents):
            mi_matrix[i, j] = mutual_info_regression(true_latents[:, i].reshape(-1, 1), learned_latents[:, j])[0]
    
    mi_sorted = np.sort(mi_matrix, axis=1)
    gaps = mi_sorted[:, -1] - mi_sorted[:, -2]
    mig_score = np.mean(gaps)
    
    return mig_score
from sklearn.ensemble import RandomForestClassifier

def compute_factorvae_score(true_latents, learned_latents):
    num_latents = true_latents.shape[1]
    accuracy_scores = []
    
    for i in range(num_latents):
        classifier = RandomForestClassifier()
        classifier.fit(learned_latents, true_latents[:, i])
        accuracy = classifier.score(learned_latents, true_latents[:, i])
        accuracy_scores.append(accuracy)
        
    factorvae_score = np.mean(accuracy_scores)
    
    return factorvae_score



In [8]:
model_name = "vae"
size = "disentangled_3"
epoch = 1800


models_path = f"../../../models/unsupervised/{model_name}/saved_models"
res_q_25, res_med, res_q_75 = [], [], []
with open(f'{models_path}/{size}/dataset.pkl', 'rb') as f:
	dataset = pkl.load(f)

features = [64, 32, 16, 8 ,4]
encoder = Encoder(in_features=32, features=features, out_features=2)
decoder = Decoder(in_features=2, features=list(reversed(features)), out_features=32)
vae = VAE(encoder, decoder)

vae.load_state_dict(torch.load(f"{models_path}/{size}/model_{epoch}.pth"))

<All keys matched successfully>

In [17]:
true_latents = dataset.y.detach().numpy()
mu, log_var =  vae.encoder(dataset.X)
learned_latents = vae.reparameterize(mu, log_var).detach().numpy()

In [18]:

# Compute MIG and FactorVAE Score
mig_score = compute_mig(true_latents, learned_latents)
factorvae_score = compute_factorvae_score(true_latents, learned_latents)

print("MIG Score:", mig_score)
print("FactorVAE Score:", factorvae_score)


ValueError: Unknown label type: 'continuous'