# Assess VAE model
Perform a simple assessment of a trained VAE model: compute loss and reconstruction metrics on the train and test subsets, and compare them in per-metric bar charts.

In [1]:
import os

import torch

from vae import BaseVAE
import specvae.dataset as dt
import specvae.utils as utils

In [2]:
# Parameters
# NOTE(review): the bare "# Parameters" header looks like a papermill-injected
# parameters cell — confirm before editing these values by hand.
# Dataset identifier; passed to dt.load_data() in the data-loading cell.
dataset = "MoNA"
# Run name of the trained model; in the visible cells it is only used in the
# "Load model" log message.
model_name = "joint_vae_30-15-3-15-30 (17-11-2021_15-22-06)"
# Directory holding the serialized model; 'model.pth' is joined onto it when loading.
# NOTE(review): absolute, machine-specific Windows path.
model_dir = "d:\\Workspace\\SpecVAE\\.model\\MoNA\\joint_sigmoid\\joint_vae_30-15-3-15-30 (17-11-2021_15-22-06)"
# Not referenced by any cell visible in this notebook.
csv_filepath = "d:\\Workspace\\SpecVAE\\.model\\MoNA"


## Load model

In [3]:
# Select the compute device, requesting CUDA. Per the recorded output, the helper
# reports the GPU count and the device in use. `cpu` is not used in the visible cells.
device, cpu = utils.device(use_cuda=True)

GPU device count: 1
Device in use:  cuda:0


In [4]:
# Restore the trained model from <model_dir>/model.pth onto `device`.
print("Load model: %s..." % model_name)
model_path = os.path.join(model_dir, 'model.pth')
model = BaseVAE.load(model_path, device)
# Switch the module to evaluation mode; as the cell's last expression this also
# displays the module summary.
model.eval()

Load model: joint_vae_30-15-3-15-30 (17-11-2021_15-22-06)...


JointVAESigmoid(
  (encoder_): Sequential(
    (en_lin_1): Linear(in_features=30, out_features=15, bias=True)
    (en_lin_batchnorm_1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (en_act_1): ReLU()
  )
  (en_mu): Linear(in_features=15, out_features=3, bias=True)
  (en_mu_batchnorm): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (en_log_var): Linear(in_features=15, out_features=3, bias=True)
  (en_log_var_batchnorm): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (decoder): Sequential(
    (de_lin_1): Linear(in_features=3, out_features=15, bias=True)
    (de_lin_batchnorm_1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (de_act_1): ReLU()
    (de_lin_2): Linear(in_features=15, out_features=30, bias=True)
    (de_act_2): Sigmoid()
  )
  (loss): JointVAESigmoidCriterium(
    (gnll): GaussianNLLLoss()
  )
)

In [5]:
# Display the configuration stored with the loaded model (layer layout,
# preprocessing transform, dataset settings).
model.config

{'layer_config': array([[30, 15,  3],
        [ 3, 15, 30]]),
 'limit': 2500,
 'latent_spec': {'cont': 3},
 'temperature': 0.5,
 'cont_capacity': [0.0, 3.0, 10000, 5.0],
 'disc_capacity': [0.0, 1.0, 10000, 10],
 'dropout': 0.0,
 'transform': Compose(
     <dataset.SplitSpectrum object at 0x0000017DE8976908>
     <dataset.TopNPeaks object at 0x0000017DE8976988>
     <dataset.FilterPeaks object at 0x0000017DE8976A08>
     <dataset.Normalize object at 0x0000017DE868EE88>
     <dataset.ToMZIntConcatAlt object at 0x0000017DE6C36848>
 ),
 'input_columns': ['spectrum'],
 'types': [torch.float32],
 'dataset': 'MoNA',
 'max_mz': 2500,
 'min_intensity': 0.1,
 'max_num_peaks': 15,
 'normalize_intensity': True,
 'normalize_mass': True,
 'n_samples': -1}

## Load dataset for evaluation

In [6]:
# Reuse the input columns and dtypes stored with the model so the evaluation data
# is prepared according to the model's own configuration.
input_columns = model.config['input_columns']
types = model.config['types']

# Build loaders for the train/valid/test splits plus dataset metadata.
# NOTE(review): the positional arguments -1, int(1e7) and True are not named here;
# presumably n_samples (-1 = all), batch size (large enough for one batch per
# split) and a boolean flag — confirm against dt.load_data's signature.
train_loader, valid_loader, test_loader, metadata = dt.load_data(
    dataset, model.transform, -1, int(1e7), True, device, input_columns, types)

Load train data
Load and transform...
Progress: 5%
Progress: 10%
Progress: 15%
Progress: 20%
Progress: 25%
Progress: 30%
Progress: 35%
Progress: 40%
Progress: 45%
Progress: 50%
Progress: 55%
Progress: 60%
Progress: 65%
Progress: 70%
Progress: 75%
Progress: 80%
Progress: 85%
Progress: 90%
Progress: 95%
Convert data to pytorch tensors...
Load valid data
Load and transform...
Progress: 5%
Progress: 10%
Progress: 15%
Progress: 20%
Progress: 25%
Progress: 30%
Progress: 35%
Progress: 40%
Progress: 45%
Progress: 50%
Progress: 55%
Progress: 60%
Progress: 65%
Progress: 70%
Progress: 75%
Progress: 80%
Progress: 85%
Progress: 90%
Progress: 95%
Convert data to pytorch tensors...
Load test data
Load and transform...
Progress: 5%
Progress: 10%
Progress: 15%
Progress: 20%
Progress: 25%
Progress: 30%
Progress: 35%
Progress: 40%
Progress: 45%
Progress: 50%
Progress: 55%
Progress: 60%
Progress: 65%
Progress: 70%
Progress: 75%
Progress: 80%
Progress: 85%
Progress: 90%
Progress: 95%
Convert data to pytorch tensors...

## Evaluate model

In [7]:
def compute_simple_metrics(X, Xr, z, latent_dist):
    """Compute loss and reconstruction metrics for one batch.

    Uses the notebook-level ``model`` to evaluate its loss and the helpers in
    ``specvae.metrics`` to compare inputs with reconstructions. If a
    ``MemoryError`` is raised, the batch is cut to its first 90% of rows and
    the whole computation is retried.

    Args:
        X: Input batch (2-D, rows are samples).
        Xr: Reconstruction of ``X`` with matching rows.
        z: Latent codes. Accepted for interface compatibility; not used.
        latent_dist: Latent distribution parameters passed unchanged to
            ``model.loss.forward_``.

    Returns:
        dict with keys ``'loss'``, ``'recon'``, ``'kldiv'`` (the three values
        returned by ``model.loss.forward_``), ``'cos_sim'`` and ``'dist_eu'``.

    Raises:
        MemoryError: If the computation still runs out of memory and the batch
            cannot be shrunk any further.
    """
    import specvae.metrics as mcs
    while True:
        try:
            loss, recon, kldiv = model.loss.forward_(Xr, X, latent_dist)
            X_, Xr_ = X.data.cpu().numpy(), Xr.data.cpu().numpy()
            return {
                'loss': loss.item(),
                'recon': recon.item(),
                'kldiv': kldiv.item(),
                'cos_sim': mcs.cos_sim(X_, Xr_),
                'dist_eu': mcs.euclidean_distance(X_, Xr_),
            }
        except MemoryError:
            n = int(0.9 * X.shape[0])
            if n < 1:
                # Nothing left to drop: re-raise instead of retrying forever
                # on an empty batch (int(0.9 * 0) stays 0).
                raise
            print("Memory error, try again with a smaller set n=", n)
            # NOTE(review): only X and Xr are truncated; latent_dist still
            # describes the full batch, so loss/kldiv on a retry may mix
            # subset and full-batch terms — confirm against the loss
            # implementation.
            X, Xr = X[:n,:], Xr[:n,:]

### Train subset

In [8]:
# Evaluate on the first batch yielded by the training loader.
X_train, ids_train = next(iter(train_loader))
# Inference only: disable autograd so no computation graph is retained for the batch.
with torch.no_grad():
    X_recon_train, z_train, latent_dist_train = model.forward_(X_train)
train_metrics = compute_simple_metrics(X_train, X_recon_train, z_train, latent_dist_train)
train_metrics

Memory error, try again with smaller set n= 98246
Memory error, try again with smaller set n= 88421
Memory error, try again with smaller set n= 79578
Memory error, try again with smaller set n= 71620
Memory error, try again with smaller set n= 64458


{'loss': 6.296037673950195,
 'recon': 6.140259265899658,
 'kldiv': 2.968844175338745,
 'cos_sim': 0.9873850147576717,
 'dist_eu': 0.19230925}

### Test subset

In [9]:
# Evaluate on the first batch yielded by the test loader.
X_test, ids_test = next(iter(test_loader))
# Inference only: disable autograd so no computation graph is retained for the batch.
with torch.no_grad():
    X_recon_test, z_test, latent_dist_test = model.forward_(X_test)
test_metrics = compute_simple_metrics(X_test, X_recon_test, z_test, latent_dist_test)
test_metrics

{'loss': 6.725614070892334,
 'recon': 5.001116752624512,
 'kldiv': 2.6551005840301514,
 'cos_sim': 0.9911310901380566,
 'dist_eu': 0.15641002}

### Compare Train and Test metrics

In [10]:
def compare_models_metrics(metrics1, metrics2, name1="Test", name2="Train"):
    """Show one bar chart per metric comparing two metric dictionaries.

    Args:
        metrics1: Mapping of metric name to scalar value; its keys decide
            which metrics are plotted.
        metrics2: Mapping that provides a value for every key of ``metrics1``.
        name1: Legend label for ``metrics1`` (its lower-cased form is used as
            the bar category). Defaults to ``"Test"``.
        name2: Legend label for ``metrics2``. Defaults to ``"Train"``.

    Raises:
        KeyError: If ``metrics2`` lacks a metric present in ``metrics1``.
    """
    import plotly.graph_objects as go
    # Validate up front so a mismatch fails before any figure is displayed.
    missing = [mname for mname in metrics1 if mname not in metrics2]
    if missing:
        raise KeyError("metrics2 is missing metrics: %s" % ", ".join(map(str, missing)))
    for mname, value1 in metrics1.items():
        fig = go.Figure(data=[
            go.Bar(name=name1, x=[name1.lower()], y=[value1]),
            go.Bar(name=name2, x=[name2.lower()], y=[metrics2[mname]]),
        ])
        fig.update_layout(autosize=False, width=500, height=500,
                          title="%s [%s vs. %s]" % (mname, name1, name2))
        fig.show()


In [11]:
# Plot every metric for the test and train subsets side by side.
# Argument order matters: the function labels its first argument "Test" and its second "Train".
compare_models_metrics(test_metrics, train_metrics)