# Evaluation

> Functions to help with evaluating the outputs of generative models.

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.utils import *

In [None]:
#| default_exp Evaluation

In [None]:
#| export
import torch
import scanpy as sc
import numpy as np
from fastcore.utils import *
import wandb
import scib




class Visualize:
    """Namespace for plotting helpers used when evaluating the model."""
    

In [None]:
#| export
@patch_to(Visualize)
def plot_embeddings(scdata, color_key_list, basis_list, device, show=False, log=True):
    """Draw one scanpy embedding plot per (basis, color key) pair, optionally logging each to wandb.

    Args:
        scdata: data object handed to `sc.pl.embedding` (presumably AnnData — confirm).
        color_key_list: keys passed one at a time as `color=`.
        basis_list: embedding bases passed one at a time as `basis=`; this is the outer loop.
        device: not used by this function.
        show: forwarded unchanged to scanpy's `show=` argument.
        log: when True, each plot is logged to wandb under the key
            "plot_embeddings_<basis>_<color_key>".
    """
    # Bases form the outer loop and color keys the inner one.
    pairs = ((basis, color_key) for basis in basis_list for color_key in color_key_list)
    for basis, color_key in pairs:
        axes = sc.pl.embedding(scdata, basis=basis, color=color_key, wspace=0.3, show=show)
        if not log:
            continue
        # NOTE(review): scanpy appears to return the Axes only when show=False; with show=True
        # it likely returns None, which wandb.Image cannot consume — confirm before combining
        # show=True with log=True.
        wandb.log({f"plot_embeddings_{basis}_{color_key}": wandb.Image(axes)})

In [None]:
#| export
@patch_to(Visualize)
def plot_umaps(scdata, color_key_list, rep_list, device, show=False, log=True):
    """Compute a UMAP for each representation in `rep_list` and plot it once per color key.

    Args:
        scdata: data object passed to scanpy (presumably AnnData — confirm).
        color_key_list: keys passed one at a time as `color=` to `sc.pl.umap`.
        rep_list: representation names passed as `use_rep=` to `sc.pp.neighbors`.
        device: not used by this function.
        show: forwarded unchanged to scanpy's `show=` argument.
        log: when True, each plot is logged to wandb under the key "UMAP_<rep>_<color_key>".
    """
    for rep in rep_list:
        # Return values are not needed. Without `copy=`, scanpy presumably stores the neighbor
        # graph and UMAP coordinates on `scdata` itself, so each `rep` replaces the previous
        # UMAP — confirm.
        sc.pp.neighbors(scdata, use_rep=rep)
        sc.tl.umap(scdata)
        for color_key in color_key_list:
            axes = sc.pl.umap(scdata, color=color_key, wspace=0.3, show=show)
            if log:
                # NOTE(review): with show=True scanpy likely returns None instead of the Axes,
                # which wandb.Image cannot consume — confirm before combining show and log.
                wandb.log({"UMAP_{}_{}".format(rep, color_key): wandb.Image(axes)})

In [None]:
#| export
class Inferance:
    """Namespace for model-inference helpers (the class name's spelling is kept for compatibility)."""

In [None]:
#| export
@patch_to(Inferance)
def get_embeddings_VAEGAN(VAEGAN, dataloader, device, return_labels=False):
    """Run `VAEGAN` over every batch of `dataloader` and collect the latent means `mu`.

    Args:
        VAEGAN: model whose forward pass returns the 4-tuple `(x_hat, y_hat, mu, logvar)`.
        dataloader: iterable yielding `(batch, label)` pairs; `batch` must support `.to(device)`.
        device: device each batch is moved to before the forward pass.
        return_labels: if True, also concatenate and return the labels (as NumPy arrays).

    Returns:
        The concatenated `mu` outputs as an `np.ndarray`, or `(embeddings, labels)` when
        `return_labels` is True.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    # NOTE(review): the model is not switched to eval(); call VAEGAN.eval() beforehand if it
    # contains dropout/batch-norm layers and deterministic embeddings are wanted.
    embeddings, labels = [], []
    with torch.no_grad():
        for batch, label in dataloader:
            _, _, mu, _ = VAEGAN(batch.to(device))
            embeddings.append(mu.cpu().numpy())
            # Labels are only copied off the device when the caller asks for them.
            if return_labels:
                labels.append(label.cpu().numpy())
    if not embeddings:
        raise ValueError("dataloader yielded no batches; nothing to embed")
    embeddings = np.concatenate(embeddings)
    if return_labels:
        return embeddings, np.concatenate(labels)
    return embeddings

In [None]:
#| export
@patch_to(Inferance)
def get_embeddings_VAEGAN_NEG_BI(VAEGAN, dataloader, device, return_labels=False):
    """Run the negative-binomial `VAEGAN` over `dataloader` and collect the latent means `mu`.

    Args:
        VAEGAN: model whose forward pass returns the 6-tuple
            `(x_hat, y_hat, mu, logvar, h_r, h_p)`.
        dataloader: iterable yielding `(batch, label)` pairs; `batch` must support `.to(device)`.
        device: device each batch is moved to before the forward pass.
        return_labels: if True, also concatenate and return the labels (as NumPy arrays).

    Returns:
        The concatenated `mu` outputs as an `np.ndarray`, or `(embeddings, labels)` when
        `return_labels` is True.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    # NOTE(review): the model is not switched to eval(); call VAEGAN.eval() beforehand if it
    # contains dropout/batch-norm layers and deterministic embeddings are wanted.
    embeddings, labels = [], []
    with torch.no_grad():
        for batch, label in dataloader:
            _, _, mu, _, _, _ = VAEGAN(batch.to(device))
            embeddings.append(mu.cpu().numpy())
            # Labels are only copied off the device when the caller asks for them.
            if return_labels:
                labels.append(label.cpu().numpy())
    if not embeddings:
        raise ValueError("dataloader yielded no batches; nothing to embed")
    embeddings = np.concatenate(embeddings)
    if return_labels:
        return embeddings, np.concatenate(labels)
    return embeddings

In [None]:
#| export
@patch_to(Inferance)
def decode_embeddings_VAEGAN(VAEGAN, embeddings, device):
    """Decode a set of latent vectors with `VAEGAN.decode` and return the result as NumPy.

    Args:
        VAEGAN: model assumed to expose `decode(z)` returning a tensor (the NEG_BI variant in
            this module uses the same method name) — confirm for this model.
        embeddings: NumPy array (or tensor) of latent vectors; all of them are decoded in a
            single call, so very large inputs may need chunking by the caller.
        device: device the latent vectors are moved to before decoding.

    Returns:
        The decoded output as an `np.ndarray`.
    """
    with torch.no_grad():
        # as_tensor accepts both NumPy arrays and tensors and places the data on `device`.
        z = torch.as_tensor(embeddings, device=device)
        # Decode the full latent set in one forward call; the decoder needs `z` as its input.
        x_hat = VAEGAN.decode(z)
        return x_hat.cpu().numpy()

In [None]:
#| export
@patch_to(Inferance)
def decode_embeddings_VAEGAN_NEG_BI(VAEGAN, embeddings, device):
    """Decode latent vectors one row at a time with `VAEGAN.decode` and stack the results.

    Args:
        VAEGAN: model exposing `decode(z)`; each call's result must support `.cpu().numpy()`.
        embeddings: NumPy array of latent vectors; it is iterated row by row (first axis).
        device: device the latent vectors are moved to before decoding.

    Returns:
        `np.array` built from the list of per-row decoded outputs.
    """
    with torch.no_grad():
        latent = torch.from_numpy(embeddings).to(device)
        # Each row is decoded in its own forward call; rows are already on `device`.
        decoded_rows = [VAEGAN.decode(row).cpu().numpy() for row in latent]
        return np.array(decoded_rows)

In [None]:
class batch:
    # Empty placeholder class; this cell has no `#| export` marker, so it stays notebook-only.
    ...

In [None]:
#| hide
# Export the `#| export` cells of the project's notebooks into library modules.
import nbdev

nbdev.nbdev_export()