# Lattice Latent Variable Model (Lattice-LVM) Example

## Loading data

The functions below load gerbil spectrograms and create dataloaders for use in the model. I know gerbils are not birds, but I named the functions and file before I knew we'd be running this on gerbils, and they happen to work for gerbil data unchanged.

In [1]:
exp = 112

In [2]:
!uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu129


Using Python 3.11.12 environment at: C:\Users\gg3065\code\vocalizations\.venv
Audited 2 packages in 9ms


In [6]:
from data.bird_data import load_gerbils,bird_data
from torch.utils.data import DataLoader
import numpy as np
import os
import torch

# Because I'm using Windows, I can't use Miles's lambdas: worker processes are spawned, so the transform must be picklable
def spec_to_tensor(x: np.ndarray) -> torch.Tensor:
    # x shape: H x W numpy array
    return torch.from_numpy(x).to(torch.float32).unsqueeze(0)

# this directory contains .hdf5 files with spectrograms from your audio
# data_path = '/mnt/home/mmartinez/ceph/data/gerbil/gily'
data_path = fr"\\sanesstorage.cns.nyu.edu\archive\ginosar\Processed_data\Audio\{exp}"
# os.cpu_count() can return None, and os.sched_getaffinity is not available on Windows
n_workers = max((os.cpu_count() or 1) - 1, 1)
print(n_workers)

(train_fns,test_fns),(train_ids,test_ids),specs_per_file = load_gerbils(data_path,specs_per_file=100,families=[1],test_size=0.2,seed=92,check=True)
### specs_per_file: how many spectrograms are in each .hdf5 file
### families: the family numbers to load (I just set the data you sent to family 1)
### test_size: the fraction of the data that remains unseen during training
### seed: random seed, used to maintain reproducibility
### check: whether to verify that every file has 100 vocalizations

train_dataset = bird_data(train_fns, train_ids,specs_per_file=specs_per_file,transform=spec_to_tensor,conditional=False)
# *GILY*: this is Miles's version, but I can't use a lambda with multiprocessing on Windows
# train_dataset = bird_data(train_fns,train_ids,specs_per_file=specs_per_file,transform=lambda x: torch.from_numpy(x).to(torch.float32).unsqueeze(0), conditional=False)

### transform is applied to each spectrogram before it is returned. It looks a little weird because of how I saved the spectrograms:
### it converts the H x W numpy array to a float32 tensor and adds a leading channel dimension.
### conditional determines whether we condition the model on other variables (fm, entropy, syllable length, etc.)

test_dataset = bird_data(test_fns, test_ids,specs_per_file=specs_per_file,transform=spec_to_tensor, conditional=False)
# *GILY*: this is Miles's version, but I can't use a lambda with multiprocessing on Windows
# test_dataset = bird_data(test_fns,test_ids,specs_per_file=specs_per_file,transform=lambda x: torch.from_numpy(x).to(torch.float32).unsqueeze(0), conditional=False)

train_loader = DataLoader(train_dataset,batch_size=64,num_workers=n_workers,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=64,num_workers=n_workers,shuffle=False)

31
loading family1


100%|██████████| 2081/2081 [00:51<00:00, 40.22it/s]
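
The named `spec_to_tensor` function replaces the lambda because Windows starts `DataLoader` worker processes with the spawn method, which pickles the dataset (including its `transform`) to send it to each worker. The sketch below shows the underlying constraint using only the standard library. `can_pickle` is a hypothetical helper written for this illustration and is not part of this repo.

```python
import math
import pickle


def can_pickle(obj) -> bool:
    """Return True if obj survives pickle.dumps, which spawn-based workers require."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# A lambda has no importable name, so pickle cannot serialize it by reference.
lambda_ok = can_pickle(lambda x: x * 2)

# A function that can be looked up by name in its module (here math.sqrt) pickles fine.
named_ok = can_pickle(math.sqrt)

print(f"lambda picklable: {lambda_ok}, named function picklable: {named_ok}")
```

The same check can be run on any transform before handing it to a `DataLoader` with `num_workers > 0`.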


## Defining a model

Now we can set up our model for training. This is done by selecting a grid over our latent space and a model architecture:

In [8]:
from models.sampling import gen_fib_basis,gen_korobov_basis
from models.utils import get_decoder_arch
from models.qmc_base import QMCLVM
import torch

latent_dim=2 # sets our latent dimension
### With two latent dimensions, use gen_fib_basis for the grid over the latent space.
### With more than two dimensions, use gen_korobov_basis instead. It requires additional arguments;
### see help(gen_korobov_basis) for good argument values.

latent_grid = gen_fib_basis(m=15) # m determines both the size of our grid and spacing of points
## if you want to plot this, you will need to plot (latent_grid % 1) instead of latent_grid


dataset = 'gerbil_ava' # used for getting a pre-selected decoder architecture
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') # use gpu if possible

decoder = get_decoder_arch(dataset_name=dataset,latent_dim=latent_dim)
### get_decoder_arch returns one of a fixed set of architectures. The decoders are just nn.Sequential instances
### (layers strung together), so if you want to play around with your own, you can build one with nn.Sequential.
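
To give a feel for what a two-dimensional lattice over the latent space looks like, here is a minimal sketch of the textbook rank-1 Fibonacci lattice: with N = F_m points, point i is (i/N, (i·F_{m-1} mod N)/N). This is an assumption about the kind of grid `gen_fib_basis` produces, not a copy of its implementation. The helper names `fibonacci` and `fib_lattice`, and the indexing convention F_1 = F_2 = 1, are chosen here for illustration only.

```python
def fibonacci(k: int) -> int:
    """k-th Fibonacci number, using the convention F_1 = F_2 = 1."""
    a, b = 1, 1
    for _ in range(k - 2):
        a, b = b, a + b
    return b if k >= 2 else a


def fib_lattice(m: int) -> list[tuple[float, float]]:
    """Rank-1 Fibonacci lattice with F_m points in the unit square [0, 1)^2 (m >= 2)."""
    n = fibonacci(m)      # number of grid points
    g = fibonacci(m - 1)  # second component of the generating vector (1, F_{m-1})
    # The modulo wraps every point back into the unit square.
    return [(i / n, (i * g % n) / n) for i in range(n)]


points = fib_lattice(8)
print(len(points), points[:3])
```

Increasing `m` grows the number of points and shrinks the spacing between them at the same time, which is the trade-off the `m` argument above controls.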

In [4]:
from train.losses import binary_evidence,binary_lp,gaussian_evidence,gaussian_lp
model = QMCLVM(latent_dim=latent_dim,device=device,decoder=decoder)

## binary evidence
qmc_loss_func = binary_evidence # I used this for training models, but we can also use gaussian (what the VAE uses)
qmc_lp = binary_lp

## gaussian evidence
# qmc_loss_func = lambda samples,data: gaussian_evidence(samples,data,var=0.01) # var is the inverse of precision from the VAE
# qmc_lp = lambda samples,data: gaussian_lp(samples,data,var=0.01)
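
As a rough picture of what an evidence-style loss measures, the sketch below approximates the evidence for one sample by averaging the likelihood over the decoded grid points, log p(x) ≈ log (1/N) Σ_i p(x | z_i), using log-sum-exp for numerical stability. This is an assumption about the general idea behind `binary_evidence`, not the code in `train.losses`. The functions `bernoulli_log_lik` and `log_evidence` and the toy inputs are hypothetical.

```python
import math


def bernoulli_log_lik(x, p, eps=1e-7):
    """Sum of Bernoulli log-likelihoods of pixels x under decoder means p."""
    total = 0.0
    for xi, pi in zip(x, p):
        pi = min(max(pi, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += xi * math.log(pi) + (1.0 - xi) * math.log(1.0 - pi)
    return total


def log_evidence(x, decoded):
    """Log of the mean likelihood over grid points, computed with log-sum-exp."""
    log_liks = [bernoulli_log_lik(x, p) for p in decoded]
    top = max(log_liks)
    log_sum = top + math.log(sum(math.exp(ll - top) for ll in log_liks))
    return log_sum - math.log(len(log_liks))


# Toy example: one 2-pixel "spectrogram" and the decoder output at two grid points.
x = [1.0, 0.0]
decoded = [[0.9, 0.1], [0.5, 0.5]]
toy_log_evidence = log_evidence(x, decoded)
print(toy_log_evidence)
```

A training loss built this way would be the negative of this quantity, averaged over the batch.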

## Training the model

In [7]:
from train.train import train_loop
nEpochs=10
#### to speed up training, you can decrease grid size (decrease m) at the expense of model performance,
#### or increase batch size  
model,opt,losses = train_loop(model,train_loader,latent_grid.to(device),qmc_loss_func,
                              nEpochs=nEpochs,verbose=False,conditional=False)
        

  0%|          | 0/10 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [9]:
from train.train import test_epoch
model.eval()
with torch.no_grad():
    test_losses = test_epoch(model,test_loader,latent_grid.to(device),qmc_loss_func,conditional=False)


  0%|          | 0/652 [00:09<?, ?it/s]


RuntimeError: DataLoader worker (pid(s) 34660, 49844, 44176, 45944, 56260, 21356) exited unexpectedly

## Basic visualization

In [10]:
import numpy as np
import matplotlib.pyplot as plt
from plotting.visualize import format_plot_axis
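def qmc_train_plot(qmc_train_losses,qmc_test_losses,save_fn='',show=False):
    """Plot training evidence per update, plus the test mean +/- std at a final 'Test' tick.

    If show is False the figure is written to save_fn, so pass a valid filename in that case.
    """
    qmc_train_losses,qmc_test_losses = np.array(qmc_train_losses),np.array(qmc_test_losses)

    ax = plt.gca()
    N = len(qmc_train_losses)

    # the losses are negated so that the curve shows model evidence
    ax.plot(-qmc_train_losses,label='Model evidence',color='tab:blue')

    # place the test point a little to the right of the last training update
    test_x = N + N//5
    xax = list(ax.get_xticks())
    if xax[-1] >= test_x:
        xax = xax[:-1]  # drop a tick that would collide with the 'Test' tick
    xticklabels = xax + ['Test']
    xax += [test_x]
    ax.errorbar(test_x,np.nanmean(-qmc_test_losses),yerr=np.nanstd(-qmc_test_losses),capsize=6,linestyle='',color='tab:blue')
    ax = format_plot_axis(ax,ylabel='Model evidence',xlabel='update number',xticks=xax,xticklabels=xticklabels)
    if show:
        plt.show()
    else:
        plt.savefig(save_fn)

    plt.close()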

In [11]:
from plotting.visualize import model_grid_plot # qmc_train_plot is defined in the cell above; importing it here would shadow that definition
qmc_train_plot(losses,test_losses,show=True)
model_grid_plot(model,n_samples_dim=10,origin='lower',cm='inferno') # increase n_samples_dim for a denser plot that takes longer to show

NameError: name 'losses' is not defined

## Embedding



In [None]:
test_embeddings,test_labels = model.embed_data(latent_grid.to(device),test_loader,qmc_lp, embed_type='rqmc',n_samples=5)
## The embedding types are:
## argmax (the maximum-likelihood grid point for each sample in latent space)
## posterior (sum over grid points, weighted by likelihood)
## rqmc (the same weighted sum computed on randomly shifted grids, averaged over n_samples grid shifts)

### Unfortunately, at this point test_labels are just the family ID (so all 1 for this dataset). I haven't implemented an easy way to
### label by location, file number, etc. yet. Will be doing that soon!

In [None]:
import matplotlib.pyplot as plt
from plotting.visualize import format_plot_axis
ax = plt.gca()
ax.scatter(test_embeddings[:,0],test_embeddings[:,1],s=1,alpha=0.5)
ax = format_plot_axis(ax,xlim=(0,1),ylim=(0,1),xlabel='Latent dim 1',ylabel='Latent dim 2',title='Latent embeddings of test dataset')
plt.show()
plt.close()