In [3]:
import os

import numpy as np
import torch
from scipy.stats import spearmanr
from torch import optim

from misc import data, c
from vae2 import VAE # import the last version

def get_cor_ensamble(batch, mutants_values, model, ensambles = 512, rand = True):
    """Spearman correlation between measured mutant values and ensemble-averaged model scores.

    ``model.logp(batch, rand=rand)`` is evaluated ``ensambles`` times and the
    results are averaged. Entry 0 of each result is treated as the wildtype
    score and the remaining entries as mutant scores; every mutant is scored
    by its averaged value minus the averaged wildtype value.

    Args:
        batch: Input forwarded unchanged to ``model.logp``. Its first entry
            must be the wildtype, followed by the mutants.
        mutants_values: Measured values, one per mutant, in the same order as
            the mutants in ``batch``.
        model: Object exposing ``eval()`` and ``logp(batch, rand=...)``, the
            latter returning a tensor with one score per entry of ``batch``.
        ensambles: Number of ``logp`` evaluations to average (must be >= 1).
        rand: Forwarded to ``model.logp`` as its ``rand`` keyword.

    Returns:
        The Spearman rank correlation coefficient reported by
        ``scipy.stats.spearmanr``.

    Raises:
        ValueError: If ``ensambles`` is smaller than 1.

    Side effects:
        Calls ``model.eval()`` (the model is not switched back afterwards) and
        prints a progress line to stdout.
    """
    if ensambles < 1:
        # Previously surfaced as an opaque ZeroDivisionError after the loop.
        raise ValueError(f"ensambles must be a positive integer, got {ensambles}")

    model.eval()

    mt_elbos, wt_elbos = 0, 0

    # Scoring never back-propagates, so skip autograd bookkeeping entirely
    # instead of building a graph on every pass and detaching it afterwards.
    with torch.no_grad():
        for i in range(ensambles):
            if i and (i % 2 == 0):
                # "\r" keeps the progress indicator on a single console line.
                print(f"\tReached {i}/rand={rand}", " "*32, end="\r")

            elbos     = model.logp(batch, rand=rand).detach().cpu()
            wt_elbos += elbos[0]
            mt_elbos += elbos[1:]

    # Terminate the "\r" progress line.
    print()

    diffs  = (mt_elbos / ensambles) - (wt_elbos / ensambles)
    # Hand scipy a plain NumPy array rather than relying on implicit conversion.
    cor, _ = spearmanr(mutants_values, diffs.numpy())

    return cor

ModuleNotFoundError: No module named 'torch'

In [None]:
# Prefer the GPU when torch can see one; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader, df, mutants_tensor, mutants_df, neff = data(device=device, batch_size=64)

# Evaluation batch layout: the wildtype (dataset entry 0, one-hot encoded) comes
# first and all mutants follow, so index 0 is the reference when scoring.
wildtype   = dataloader.dataset[0]
eval_batch = torch.cat((wildtype[None], mutants_tensor.to(device)))

# Constructor arguments live in a dict so they can be stored next to the weights.
args = {
    'alphabet_len': wildtype.shape[0],
    'seq_len':      wildtype.shape[1],
    'neff':         neff
}

vae = VAE(**args).to(device)
opt = optim.Adam(vae.parameters())

# Per-epoch history:
#   rl  - reconstruction loss
#   klz - Kullback-Leibler divergence loss
#   klp - KL divergence loss for parameters
#   cor - Spearman correlation to experimentally measured protein fitness
#         (according to eq.1 from the paper)
stats = {name: [] for name in ('rl', 'klz', 'klp', 'cor')}

Parsing fasta 'data/BLAT_ECOLX_hmmerbit_plmc_n5_m30_f50_t0.2_r24-286_id100_b105.a2m'
Parsing labels 'data/BLAT_ECOLX_hmmerbit_plmc_n5_m30_f50_t0.2_r24-286_id100_b105_LABELS.a2m'
Generating 8403 1-hot encodings
Generating 8403 1-hot encodings. Took 0.793s torch.Size([8403, 23, 253])
Generating 4807 1-hot encodings
Generating 4807 1-hot encodings. Took 0.421s torch.Size([4807, 23, 253])


In [None]:
# Bare expression: the notebook renders the model's repr (its layers) as this cell's output.
vae

VAE(
  (fc1): Linear(in_features=5819, out_features=64, bias=True)
  (fc1h): Linear(in_features=64, out_features=30, bias=True)
  (fc21): Linear(in_features=30, out_features=2, bias=True)
  (fc22): Linear(in_features=30, out_features=2, bias=True)
  (fc3): Linear(in_features=2, out_features=32, bias=False)
  (fc3h): Linear(in_features=32, out_features=64, bias=False)
  (W): Linear(in_features=16, out_features=16192, bias=False)
  (C): Linear(in_features=23, out_features=16, bias=False)
  (S): Linear(in_features=253, out_features=8, bias=False)
)

In [None]:
for epoch in range(30):
    # --- Unsupervised training pass over the MSA sequences -----------------
    # train() switches the module to training mode, see
    # https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch
    vae.train()

    epoch_losses = {name: [] for name in ('rl', 'klp', 'klz')}
    for batch in dataloader:
        # zero_grad / backward / step explained at
        # https://discuss.pytorch.org/t/what-step-backward-and-zero-grad-do/33301/2
        opt.zero_grad()
        x_hat, mu, logvar  = vae(batch)
        loss, rl, klz, klp = vae.loss(x_hat, batch, mu, logvar)
        # TODO (Stefan): is 'retain_graph' actually required here? - ask yevgen
        loss.mean().backward(retain_graph=True)
        opt.step()

        # Record this batch's loss terms as plain Python floats.
        for name, term in (('rl', rl.mean()), ('klp', klp.mean()), ('klz', klz)):
            epoch_losses[name].append(term.item())

    # --- Evaluation on the mutants ------------------------------------------
    vae.eval()
    # An earlier variant scored the mutants with a single vae.logp(eval_batch)
    # pass (log-ratio, first equation in the paper); the ensemble-averaged
    # helper below replaced it.
    cor = get_cor_ensamble(eval_batch, mutants_df.value, vae, ensambles=16, rand=True)

    # --- Populate statistics --------------------------------------------------
    for name in ('rl', 'klz', 'klp'):
        stats[name].append(np.mean(epoch_losses[name]))
    stats['cor'].append(np.abs(cor))

    # --- One colourised summary line per epoch --------------------------------
    report = (
        f"{c.HEADER}EPOCH %03d"          % epoch,
        f"{c.OKBLUE}RL=%4.4f"            % stats['rl'][-1],   # reconstruction loss
        f"{c.OKGREEN}KLZ=%4.4f"          % stats['klz'][-1],  # KL loss
        f"{c.OKCYAN}|rho|=%4.4f{c.ENDC}" % stats['cor'][-1],  # correlation (we want to maximise this)
    )
    print(" ".join(report))

	Reached 14/rand=True                                 
[95mEPOCH 000 [94mRL=778.7021 [92mKLZ=0.1985 [96m|rho|=0.4288[0m

[95mEPOCH 001 [94mRL=748.3698 [92mKLZ=0.0257 [96m|rho|=0.4620[0m

[95mEPOCH 002 [94mRL=712.2460 [92mKLZ=0.0307 [96m|rho|=0.5206[0m

[95mEPOCH 003 [94mRL=670.7138 [92mKLZ=0.0251 [96m|rho|=0.5455[0m

[95mEPOCH 004 [94mRL=628.1255 [92mKLZ=0.0167 [96m|rho|=0.5637[0m

[95mEPOCH 005 [94mRL=589.5910 [92mKLZ=0.0115 [96m|rho|=0.5766[0m

[95mEPOCH 006 [94mRL=558.5648 [92mKLZ=0.0081 [96m|rho|=0.5817[0m

[95mEPOCH 007 [94mRL=538.0320 [92mKLZ=0.0061 [96m|rho|=0.5864[0m

[95mEPOCH 008 [94mRL=527.0403 [92mKLZ=0.0043 [96m|rho|=0.5911[0m

[95mEPOCH 009 [94mRL=519.9695 [92mKLZ=0.0040 [96m|rho|=0.5948[0m

[95mEPOCH 010 [94mRL=514.7305 [92mKLZ=0.0039 [96m|rho|=0.5973[0m

[95mEPOCH 011 [94mRL=511.7554 [92mKLZ=0.0042 [96m|rho|=0.5998[0m

[95mEPOCH 012 [94mRL=512.0990 [92mKLZ=0.0040 [96m|rho|=0.6011[0m

[95mEPOCH 013 [94mRL

In [None]:
# Persist the trained weights together with the training history and the
# constructor arguments used to build the model.
model_path = "models/full_paper.model.pth"

# torch.save does not create missing parent directories; make sure the target
# folder exists so the whole training run is not lost on the final step.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

torch.save({
    'state_dict': vae.state_dict(),
    'stats':      stats,
    'args':       args,
}, model_path)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=35838f82-2ce6-4453-9bd2-2d87a43af151' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>