# Plot latent-space lattices of decoded spectra

In [None]:
import os
import numpy as np
from specvae.model import BaseModel
import specvae.dataset as dt
import specvae.utils as utils

In [None]:
# Parameters
# Training-run directories to render; each one must contain ``model.pth``
# (see load_model). Uncomment alternative runs to include them in the loop.
model_dirs = [
    # --- HMDB runs ---
    # "d:\\Workspace\\SpecVAE\\.model\\HMDB\\betavae_capacity_nextron\\betavae_capacity_20-400-200-50-2-50-200-400-20_01 (24-12-2021_19-40-39)",
    # "d:\\Workspace\\SpecVAE\\.model\\HMDB\\betavae_capacity_nextron\\betavae_capacity_20-1600-2-1600-20_02 (24-12-2021_18-27-38)",
    # "d:\\Workspace\\SpecVAE\\.model\\HMDB\\betavae_capacity_nextron\\betavae_capacity_20-100-2-90-100-20_03 (24-12-2021_20-30-13)",
    # "d:\\Workspace\\SpecVAE\\.model\\HMDB\\betavae_capacity_nextron\\betavae_capacity_20-100-90-2-90-100-20_04 (24-12-2021_19-24-07)",
    # "d:\\Workspace\\SpecVAE\\.model\\HMDB\\betavae_capacity_nextron\\betavae_capacity_20-400-100-2-100-400-20_05 (24-12-2021_19-09-29)",
    # "d:\\Workspace\\SpecVAE\\.model\\HMDB\\betavae_capacity_nextron\\betavae_capacity_20-100-90-2-100-20_06 (24-12-2021_20-58-10)",

    # --- MoNA runs (active entry is the one stored under ``best``) ---
    # "d:\\Workspace\\SpecVAE\\.model\\MoNA\\betavae_capacity_nextron\\betavae_capacity_20-800-200-50-3-50-200-800-20_01 (24-12-2021_01-50-12)",
    # "d:\\Workspace\\SpecVAE\\.model\\MoNA\\betavae_capacity_nextron\\betavae_capacity_20-400-100-3-400-20_02 (24-12-2021_03-34-34)",
    "d:\\Workspace\\SpecVAE\\.model\\MoNA\\betavae_capacity_nextron\\best\\"
    "betavae_capacity_20-1600-3-1600-20_03 (24-12-2021_00-17-31)",
    # "d:\\Workspace\\SpecVAE\\.model\\MoNA\\betavae_capacity_nextron\\betavae_capacity_20-800-3-800-20_04 (24-12-2021_00-25-10)",
    # "d:\\Workspace\\SpecVAE\\.model\\MoNA\\betavae_capacity_nextron\\betavae_capacity_20-100-3-90-100-20_05 (24-12-2021_03-01-19)",
    # "d:\\Workspace\\SpecVAE\\.model\\MoNA\\betavae_capacity_nextron\\betavae_capacity_50-400-3-100-400-50_06 (24-12-2021_06-19-49)",
]


In [None]:
def load_model(path, use_cuda=False, filename='model.pth'):
    """Load a trained model from a run directory and switch it to eval mode.

    Args:
        path: Directory of a training run; the model file is expected at
            ``os.path.join(path, filename)``.
        use_cuda: Forwarded to ``utils.device``. Defaults to ``False`` (the
            previously hard-coded value), so existing calls are unaffected.
        filename: Name of the serialized model file inside ``path``.

    Returns:
        The object returned by ``BaseModel.load`` after ``.eval()`` was called.
    """
    # utils.device returns a pair; only the first element (the target device)
    # is needed here.
    device, _ = utils.device(use_cuda=use_cuda)
    model = BaseModel.load(os.path.join(path, filename), device)
    # Put the model in evaluation mode before it is used for decoding.
    model.eval()
    return model

In [None]:
import torch


def lattice(model, grid=(8, 8), zrange=(-3, 3), dim1=0, dim2=1, rdim_value=0.):
    """Decode a regular 2-D grid of latent points spanned by two latent dims.

    Latent codes start at zero. Latent dimension ``dim1`` takes ``grid[0]``
    evenly spaced values in ``zrange``; ``dim2`` takes ``grid[1]`` values in
    the same range. Rows are ordered row-major: row ``i * grid[1] + j`` holds
    the i-th ``dim1`` value and the j-th ``dim2`` value.

    Args:
        model: Object exposing ``latent_dim``, ``device`` and ``decode(z)``.
        grid: ``(points along dim1, points along dim2)``.
        zrange: ``(min, max)`` latent coordinate shared by both axes.
        dim1: Latent dimension varied along the first grid axis.
        dim2: Latent dimension varied along the second grid axis.
        rdim_value: When ``latent_dim > 2``, the value assigned to the first
            latent dimension that is neither ``dim1`` nor ``dim2``. Any
            further dimensions stay at 0.

    Returns:
        ``model.decode(z)`` for a ``(grid[0] * grid[1], latent_dim)`` batch
        ``z`` placed on ``model.device``. No gradients are recorded.

    Raises:
        ValueError: If ``dim1 == dim2``. A grid spanned by a single dimension
            is degenerate.
    """
    if dim1 == dim2:
        raise ValueError(f'dim1 and dim2 must differ, got {dim1} for both')
    n_rows, n_cols = grid[0], grid[1]
    with torch.no_grad():
        axis1 = torch.linspace(zrange[0], zrange[1], n_rows)
        axis2 = torch.linspace(zrange[0], zrange[1], n_cols)
        z = torch.zeros(n_rows * n_cols, model.latent_dim)
        # Equivalent to torch.meshgrid(axis1, axis2) with 'ij' indexing,
        # flattened. Building it directly avoids meshgrid's indexing-default
        # warning on newer PyTorch and the missing ``indexing`` kwarg on
        # older releases.
        z[:, dim1] = axis1.repeat_interleave(n_cols)
        z[:, dim2] = axis2.repeat(n_rows)
        if model.latent_dim > 2:
            remaining = list(range(model.latent_dim))
            remaining.remove(dim1)
            remaining.remove(dim2)
            # Only the first remaining dimension is pinned to rdim_value.
            z[:, remaining[0]] = rdim_value
        z = z.to(model.device)
        return model.decode(z)


In [None]:
import specvae.visualize as vis
import torchvision as tv
import itertools as it

# Plot settings shared by every lattice below:
#   resolution -- passed as ``resolution`` to ToDenseSpectrum and plot_spectra_
#   grid       -- number of lattice points along (dim1, dim2)
#   zrange     -- (min, max) latent coordinate covered by both grid axes
resolution, grid, zrange = 0.5, (32, 32), (-5.0, 5.0)

In [None]:
# Render latent-space lattices for every selected model.
#   MoNA: for every pair of latent dims, sweep the remaining dim over 11
#         evenly spaced values in ``zrange`` and save each lattice as SVG and
#         PNG under ../.img/lattice/.
#   HMDB: plot one lattice over latent dims 0 and 1 (displayed, not saved).
# NOTE(review): figures are never closed. With 33 large figures per MoNA model
# this may exhaust memory. If plot_spectra_ returns a matplotlib Figure,
# consider matplotlib.pyplot.close(fig) after saving.
lattice_dir = '../.img/lattice'
for model_path in model_dirs:
    model = load_model(model_path)
    max_mz = model.config['max_mz']
    # Maps decoder output back to dense spectra for plotting.
    revtrans = tv.transforms.Compose([
        dt.ToMZIntDeConcatAlt(max_num_peaks=model.config['max_num_peaks']),
        dt.Denormalize(intensity=model.config['normalize_intensity'], mass=model.config['normalize_mass'], max_mz=max_mz),
        dt.ToDenseSpectrum(resolution=resolution, max_mz=max_mz),
    ])
    print(model_path)
    dataset = model.config['dataset']
    if dataset == 'MoNA':
        # The output directory may not exist yet; create it before saving.
        os.makedirs(lattice_dir, exist_ok=True)
        # Derive dim pairs from the model instead of assuming exactly 3 dims.
        for dim1, dim2 in it.combinations(range(model.latent_dim), 2):
            # Sweep the remaining latent dim over the same range as the grid axes.
            for rvalue in np.linspace(zrange[0], zrange[1], num=11):
                print('dim1=%d, dim2=%d, rvalue=%f' % (dim1, dim2, rvalue))
                spectra = lattice(model, grid=grid, zrange=zrange, dim1=dim1, dim2=dim2, rdim_value=rvalue)
                fig, axs = vis.plot_spectra_(spectra, grid=grid, figsize=(68, 36), resolution=resolution, max_mz=max_mz, transform=revtrans, dpi=300)
                stem = f'{lattice_dir}/lattice-dim{dim1}-dim{dim2}-rvalue{rvalue}'
                fig.savefig(stem + '.svg')
                fig.savefig(stem + '.png')
    elif dataset == 'HMDB':
        dim1, dim2 = 0, 1
        print('dim1=%d, dim2=%d' % (dim1, dim2))
        spectra = lattice(model, grid=grid, zrange=zrange, dim1=dim1, dim2=dim2, rdim_value=0.)
        vis.plot_spectra_(spectra, grid=grid, figsize=(68, 36), resolution=resolution, max_mz=max_mz, transform=revtrans)
    else:
        # Previously unknown datasets were skipped silently.
        print(f'Skipping {model_path}: unsupported dataset {dataset!r}')