In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Alignment file parsed by fasta(); data() treats its first record as the wildtype.
data_path = '/content/gdrive/MyDrive/data/BLAT_ECOLX_hmmerbit_plmc_n5_m30_f50_t0.2_r24-286_id100_b105.a2m'
# Label file parsed by labels(); one label per line, taken from the text after the last ':'.
labels_path = '/content/gdrive/MyDrive/data/BLAT_ECOLX_hmmerbit_plmc_n5_m30_f50_t0.2_r24-286_id100_b105_LABELS.a2m'
# CSV read by mutants(); it uses the 'mutant' column as key and the '2500' column as value.
mutations_path = '/content/gdrive/MyDrive/data/BLAT_ECOLX_Ranganathan2015.csv'

In [None]:
# Sampling weights consumed by data() (via the module-level name `weights`).
# presumably produced offline by seq_weights() — TODO confirm provenance.
data_path_weights = '/content/gdrive/MyDrive/AML/WEIGHTS2.txt'

# Read the file exactly once inside a context manager so the handle is always
# closed (the previous version opened it a second time and never closed it).
with open(data_path_weights, 'r') as weights_file:
    content_list = weights_file.readlines()

# Full file text, kept under its original name for any cell that still uses it.
new_weights = ''.join(content_list)

# Only the first line is used; it is a comma-separated list of floats.
my_list = content_list[0].split(",")
# The last field is deliberately dropped, as before.
# presumably the line ends with a trailing comma — verify against the file.
weights = [float(field) for field in my_list[:-1]]

# NOTE(review): looks like the effective number of sequences that goes with
# these weights; it is hard-coded rather than derived — confirm the value.
Neff = 2780.0254087972826

# misc.py

In [None]:
# This module loads and prepares the data

import torch, time, sys, re
import pandas as pd
from torch.nn import functional as F
from torch.utils.data import DataLoader
import numpy as np

# Symbols that may appear in a trimmed sequence; a symbol's position in this
# string is the class index used for its one-hot encoding.
ALPHABET = 'ACDEFGHIKLMNPQRSTVWXYZ-'
# Reverse lookup: symbol -> index into ALPHABET.
SEQ2IDX = {symbol: index for index, symbol in enumerate(ALPHABET)}


def fasta(file_path):
    """Parse the subset of the FASTA format used by the alignment file.

    https://en.wikipedia.org/wiki/FASTA_format

    The first record's header must look like ``>NAME/OFFSET``; every later
    header must look like ``>UR_UP|ACCESSION|NAME/OFFSET``. ``NAME`` must
    contain an underscore: the part after the first one is stored as the
    taxonomy. Lines that are not headers are appended to the sequence of the
    most recent record. Blank lines are ignored.

    Args:
        file_path: path of the text file to read.

    Returns:
        pandas.DataFrame with one row per record and the columns
        ``ur_up_``, ``accession``, ``entry_name``, ``offset``, ``taxonomy``
        and ``sequence`` (``ur_up_``/``accession`` are None for the first
        record).

    Raises:
        ValueError: if a header does not split into the expected fields.
        IndexError: if sequence text precedes the first header, or a name
            has no underscore.
    """

    print(f"Parsing fasta '{file_path}'")
    data = {
        'ur_up_': [], 'accession': [],
        'entry_name': [], 'offset': [],
        'taxonomy': [], 'sequence': []
    }

    with open(file_path, 'r') as f:
        for i, line in enumerate(f):
            line = line.strip()

            if line.startswith('>'):
                key = line[1:]

                # The first record is identified by being the first header
                # seen, not by its line number, so leading blank lines no
                # longer change how it is parsed.
                if not data['entry_name']:
                    name, offset = key.split("/")
                    ur_up_, acc = None, None
                else:
                    ur_up_, acc, name_offset = key.split("|")
                    name, offset = name_offset.split('/')

                data['ur_up_'].append(ur_up_)
                data['accession'].append(acc)
                data['entry_name'].append(name)
                data['offset'].append(offset)
                data['sequence'].append('')
                data['taxonomy'].append(name.split('_')[1])
            elif line:
                # Sequence text may span several lines; blank lines are
                # skipped (indexing line[0] on them used to raise IndexError).
                data['sequence'][-1] += line

            if i and (i % 50000 == 0):
                print(f"Reached: {i}")

    return pd.DataFrame(data=data)


def labels(labels_file, labels=None):
    """Parse the labels file into a pandas Series.

    Each line contributes one label: the text after the last ':' on that
    line, with surrounding whitespace removed.

    Args:
        labels_file: path of the text file to read.
        labels: optional list to append to (it is extended in place). When
            omitted a fresh list is created on every call; the previous
            shared ``[]`` default kept growing across calls.

    Returns:
        pandas.Series holding the given labels followed by the parsed ones.
    """
    if labels is None:
        labels = []

    print(f"Parsing labels '{labels_file}'")
    with open(labels_file, 'r') as f:
        for line in f:
            labels.append(line.split(':')[-1].strip())
    return pd.Series(labels)


def trim(full_sequences, focus_columns, sequences=None):
    """Trim the sequences down to the focus columns.

    For every input sequence, '.' is replaced by '-', the characters at the
    positions listed in `focus_columns` are picked in that order, and each
    picked character is upper-cased.

    Args:
        full_sequences: iterable of strings.
        focus_columns: iterable of integer positions to keep.
        sequences: optional list to append to (it is extended in place).
            When omitted a fresh list is created on every call; the previous
            shared ``[]`` default kept growing across calls.

    Returns:
        pandas.Series holding the given entries followed by one trimmed
        string per input sequence.

    Raises:
        IndexError: if a focus column lies beyond the end of a sequence.
    """
    if sequences is None:
        sequences = []

    for seq in full_sequences:
        seq = seq.replace('.', '-')
        sequences.append(''.join(seq[idx].upper() for idx in focus_columns))
    return pd.Series(sequences)


def encode(sequences):
    """One-hot encode sequences into a single float tensor.

    Every symbol is mapped to its class index through SEQ2IDX and one-hot
    encoded over len(ALPHABET) classes. Each per-sequence matrix is
    transposed before stacking, so the result is laid out as
    (number of sequences, len(ALPHABET), sequence length). Progress and
    timing are printed.

    Raises:
        KeyError: if a symbol is not in SEQ2IDX.
    """
    t0 = time.time()
    print(f"Generating {len(sequences)} 1-hot encodings")
    n_classes = len(ALPHABET)
    encoded = [
        F.one_hot(torch.tensor([SEQ2IDX[symbol] for symbol in seq]), n_classes).t().float()
        for seq in sequences
    ]
    r = torch.stack(encoded)
    print(f"Generating {len(sequences)} 1-hot encodings. Took {round(time.time() - t0, 3)}s", r.shape)
    return r


def mutants(df):
    """Build the single-substitution mutant sequences listed in the mutations CSV.

    Reads the CSV at the module-level `mutations_path`, keeps its '2500'
    column as the measured value (keyed by the 'mutant' column), and applies
    every mutation to the wildtype, which is the first row of `df`. Keys are
    parsed as letters + digits + letters (FROM, position, TO). Mutations that
    hit a lowercase wildtype residue are skipped. Each mutant is reduced to
    the wildtype's uppercase columns.

    Side effects: rebinds the module-level names `mdf`, `offset` and
    `wt_full`; prints a warning for mismatching or skipped mutations.

    Args:
        df: DataFrame whose first row provides 'offset' (e.g. '24-286') and
            'sequence' for the wildtype.

    Returns:
        pandas.DataFrame with the columns 'mutation', 'sequence', 'value'.
    """
    global mdf, offset, wt_full

    col = '2500'  # name of the column of our interest.
    mdf = pd.read_csv(mutations_path)
    mdf = pd.DataFrame(data={'value': mdf[col].values}, index=mdf['mutant'].values)
    wt_row = df.iloc[0]  # wildtype row in df
    wt_off = wt_row['offset']  # wildtype offset, e.g. '24-286'
    offset = int(wt_off.split('-')[0])  # left-side offset, e.g. 24
    wt_full = wt_row['sequence']
    focus_columns = [idx for idx, char in enumerate(wt_full) if char.isupper()]

    reg_co = re.compile("([a-zA-Z]+)([0-9]+)([a-zA-Z]+)")
    # Named `records` so the function's own name is not shadowed.
    records = {'mutation': [], 'sequence': [], 'value': []}

    for k, v in mdf.iterrows():
        v = v['value']
        parsed = reg_co.match(k)
        if parsed is None:
            # Previously this crashed with AttributeError on None.groups().
            print("WARNING: Unparseable mutation name, skipped:", k)
            continue

        _from, _index, _to = parsed.groups()
        _index = int(_index) - offset

        if not 0 <= _index < len(wt_full):
            # A negative index would silently wrap around to the end of the
            # wildtype and mutate the wrong residue; too large an index would
            # raise IndexError. Both are reported and skipped.
            print("WARNING: Mutation outside the wildtype range, skipped:", k)
            continue

        if wt_full[_index].islower():
            continue  # we skip the lowercase residues

        if wt_full[_index] != _from:
            print("WARNING: Mutation sequence mismatch:", k, "full wt index:", _index)

        mutant = wt_full[:_index] + _to + wt_full[_index + 1:]
        mutant_trimmed = [mutant[idx] for idx in focus_columns]

        records['mutation'].append(k)
        records['sequence'].append(''.join(mutant_trimmed))
        records['value'].append(v)
    return pd.DataFrame(data=records)


def hamming_distance(a, b):
    """Count the positions at which `a` and `b` differ.

    The inputs are paired with zip(), so when their lengths differ the
    surplus tail of the longer one is not compared.
    """
    return sum(1 for left, right in zip(a, b) if left != right)

def normalize(v):
  """Divide `v` by its norm (np.linalg.norm); a zero-norm `v` is returned unchanged."""
  length = np.linalg.norm(v)
  return v if length == 0 else v / length

def min_max(v):
  """Rescale `v` linearly so that its minimum maps to 0 and its maximum to 1.

  When every value is identical the range is zero; an all-zero float array
  of the same shape is returned instead of dividing by zero (which used to
  produce NaN for every entry).

  Raises:
    ValueError: if `v` is empty (raised by np.min).
  """
  low = np.min(v)
  span = np.max(v) - low
  if span == 0:
    return np.zeros(np.shape(v), dtype=float)

  return (v - low) / span

def stand(v):
  """Standardize `v`: subtract its mean and divide by its standard deviation.

  When the standard deviation (np.std) is zero the centered values are
  returned without the division, which used to yield NaN.
  """
  centered = v - np.average(v)
  spread = np.std(v)
  if spread == 0:
    return centered

  return centered / spread

def seq_weights(df, theta):
  """Compute normalized per-sequence weights from pairwise Hamming distances.

  For every row i, the Hamming distances from df['trimmed'][i] to every row
  (itself included) are min-max scaled, and the row's raw weight is
  1 / (number of scaled distances below `theta`). The raw weights are then
  divided by their sum.

  Returns:
    list of floats, one per row of `df`, summing to 1.
  """
  trimmed = df['trimmed']
  n_seqs = df.shape[0]
  raw_weights = []

  for i in range(n_seqs):
      distances = [hamming_distance(trimmed[i], trimmed[j]) for j in range(n_seqs)]
      scaled = min_max(distances)
      close_count = sum([1 for d in scaled if d < theta])
      raw_weights.append(1/close_count)

  total = sum(raw_weights)
  return [w/total for w in raw_weights]

def data(batch_size=128, device='cpu'):
    """Load the alignment, labels and mutants and wrap the alignment in a DataLoader.

    Reads the module-level `data_path`, `labels_path` and `weights`
    (the latter is not computed here; see the commented-out call below).

    Args:
        batch_size: batch size handed to the DataLoader.
        device: device the encoded alignment is moved to. The mutants tensor
            is not moved by this function.

    Returns:
        (dataloader, df, mutants_tensor, mutants_df): the weighted-sampling
        DataLoader over the encoded alignment, the parsed DataFrame (with
        added 'label' and 'trimmed' columns), the encoded mutants, and the
        mutants DataFrame.
    """
    df = fasta(data_path)
    df['label'] = labels(labels_path)

    # First sequence in the dataframe/fasta file is our wildtype.
    wildtype_seq = df.sequence[0]

    # What wildtype column-positions are we confident about (uppercased chars)
    focus_columns = [idx for idx, char in enumerate(wildtype_seq) if char.isupper()]

    # Trim the full sequences according to the columns we are confident at
    df['trimmed'] = trim(df.sequence, focus_columns)

    # Unique amino acids are:
    # ''.join(set(''.join(df.trimmed.to_list())))

    dataset = encode(df.trimmed).to(device)

    #weights = seq_weights(df, theta=0.2)
    
    # Draws len(weights) indices per pass, weighted by the module-level `weights`.
    # NOTE(review): assumes len(weights) matches the number of encoded sequences — confirm.
    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights)) 
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    mutants_df = mutants(df)
    mutants_tensor = encode(mutants_df.sequence)

    return dataloader, df, mutants_tensor, mutants_df


# nice colors for the terminal
class c:
    """Escape-sequence constants used to colour the training log lines.

    Each attribute is a string that is prepended to printed text; ENDC is
    appended to reset the formatting.
    """
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'  # resets formatting
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


#if __name__ == "__main__":
#    dataloader, df, mutants_tensor, mutants_df = data()

# vae.py

In [None]:
# This VAE is as vanilla as it can be.
import torch

class VAE(torch.nn.Module):
    def __init__(self, **kwargs):
        super(VAE, self).__init__()
        self.hidden_size   = 64
        self.latent_size   = 2
        self.alphabet_size = kwargs['alphabet_size']
        self.seq_len       = kwargs['seq_len']
        self.input_size    = self.alphabet_size * self.seq_len

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, self.hidden_size),
            torch.nn.ReLU(),
        )

        # Latent space `mu` and `var`
        self.fc21 = torch.nn.Linear(self.hidden_size, self.latent_size)
        self.fc22 = torch.nn.Linear(self.hidden_size, self.latent_size)

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(self.latent_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.input_size),
        )

    def forward(self, x, rep=True):
        x = x.view(-1, self.input_size)                    # flatten
        x = self.encoder(x)                                # encode
        mu, logvar = self.fc21(x), self.fc22(x)            # branch mu, var

        if rep:                                            # reparameterize
            x = mu + torch.randn_like(mu) * (0.5*logvar).exp() 
        else:                                              # or don't 
            x = mu                                         

        x = self.decoder(x)                                # decode
        x = x.view(-1, self.alphabet_size, self.seq_len)   # squeeze back
        x = x.log_softmax(dim=1)                           # softmax

        return x, mu, logvar
    
    def loss(self, x_hat, true_x, mu, logvar, beta=0.5):
        RL = -(x_hat*true_x).sum(-1).sum(-1)                    # reconst. loss
        KL = -0.5 * (1 + logvar - mu**2 - logvar.exp()).sum(-1) # KL loss
        return RL + beta*KL, RL, KL

# train.py

In [None]:
import torch
import numpy as np
#from misc import data, c
from torch import optim
from scipy.stats import spearmanr
#from vae import VAE

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# data() is called without `device`, so the dataset stays on its default 'cpu'.
dataloader, df, mutants_tensor, mutants_df = data(batch_size = 64)

wildtype   = dataloader.dataset[0] # one-hot-encoded wildtype 
# Evaluation batch: the wildtype at index 0, followed by every mutant.
eval_batch = torch.cat([wildtype.unsqueeze(0), mutants_tensor])

# Model dimensions are read off one encoded sequence (alphabet size x sequence length).
args = {
    'alphabet_size': dataloader.dataset[0].shape[0],
    'seq_len':       dataloader.dataset[0].shape[1]
}

vae   = VAE(**args).to(device)
# Adam with the optimizer's default hyper-parameters.
opt   = optim.Adam(vae.parameters())

Parsing fasta '/content/gdrive/MyDrive/data/BLAT_ECOLX_hmmerbit_plmc_n5_m30_f50_t0.2_r24-286_id100_b105.a2m'
Parsing labels '/content/gdrive/MyDrive/data/BLAT_ECOLX_hmmerbit_plmc_n5_m30_f50_t0.2_r24-286_id100_b105_LABELS.a2m'
Generating 8403 1-hot encodings
Generating 8403 1-hot encodings. Took 0.922s torch.Size([8403, 23, 253])
Generating 4807 1-hot encodings
Generating 4807 1-hot encodings. Took 0.446s torch.Size([4807, 23, 253])


In [None]:
# rl  = Reconstruction loss
# kl  = Kullback-Leibler divergence loss
# cor = Spearman correlation to experimentally measured 
#       protein fitness according to eq.1 from paper
stats = { 'rl': [], 'kl': [], 'cor': [] }

for epoch in range(1):
    # Unsupervised training on the MSA sequences.
    vae.train()
    
    epoch_losses = { 'rl': [], 'kl': [] }
    for batch in dataloader:
        batch = batch.to(device)
        opt.zero_grad()
        x_hat, mu, logvar = vae(batch)
        loss, rl, kl      = vae.loss(x_hat, batch, mu, logvar)
        # vae.loss returns one value per sample; optimize the batch mean.
        loss.mean().backward()
        opt.step()
        epoch_losses['rl'].append(rl.mean().item())
        epoch_losses['kl'].append(kl.mean().item())

    # Evaluation on mutants
    vae.eval()
    eval_batch = eval_batch.to(device)
    # No gradients are needed here; without no_grad() an autograd graph was
    # built for the whole (wildtype + all mutants) batch on every epoch.
    with torch.no_grad():
        x_hat_eval, mu, logvar = vae(eval_batch, rep=False)
        # `elbos` holds the per-sequence loss RL + beta*KL from vae.loss, i.e.
        # lower is better; only |rho| is stored below, so the sign is irrelevant.
        elbos, _, _ = vae.loss(x_hat_eval, eval_batch, mu, logvar)
    diffs       = elbos[1:] - elbos[0] # log-ratio (first equation in the paper)
    cor, _      = spearmanr(mutants_df.value, diffs.detach().cpu())
    
    # Populate statistics 
    stats['rl'].append(np.mean(epoch_losses['rl']))
    stats['kl'].append(np.mean(epoch_losses['kl']))
    stats['cor'].append(np.abs(cor))

    to_print = [
        f"{c.HEADER}EPOCH %03d"          % epoch,
        f"{c.OKBLUE}RL=%4.4f"            % stats['rl'][-1], 
        f"{c.OKGREEN}KL=%4.4f"           % stats['kl'][-1], 
        f"{c.OKCYAN}|rho|=%4.4f{c.ENDC}" % stats['cor'][-1]
    ]
    print(" ".join(to_print))

# Persist the weights together with the statistics and the constructor
# arguments needed to rebuild the model.
torch.save({
    'state_dict': vae.state_dict(), 
    'stats':      stats,
    'args':       args,
}, "trained.model.pth")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        [[ 0.3014,  0.8824, -0.9075,  ...,  2.2877,  1.9633,  0.3922],
         [-4.6456, -2.4753, -3.0742,  ..., -2.8906, -1.0768, -2.3474],
         [ 0.4227,  0.5659, -4.7076,  ..., -0.7510, -3.3283, -4.9070],
         ...,
         [-4.6729, -2.1956, -1.4500,  ..., -0.9200, -4.2350,  2.8232],
         [-5.9248, -3.4167, -3.5359,  ..., -4.4695, -4.7759, -5.2614],
         [ 2.1352,  2.5337,  2.4989,  ...,  1.5915,  1.3225,  2.1551]],

        ...,

        [[ 0.3401,  0.8350, -0.7765,  ...,  2.0729,  1.8131,  0.4141],
         [-4.1375, -2.1637, -2.7087,  ..., -2.6164, -1.0287, -2.0407],
         [ 0.3826,  0.4935, -4.1700,  ..., -0.6708, -2.9958, -4.2908],
         ...,
         [-4.1140, -1.9578, -1.2783,  ..., -0.8632, -3.7213,  2.4645],
         [-5.2467, -3.0334, -3.1240,  ..., -3.8952, -4.1911, -4.6036],
         [ 1.8626,  2.1663,  2.1243,  ...,  1.3545,  1.1030,  1.8661]],

        [[ 2.2309,  2.2732, -0.3840, 

In [None]:
import matplotlib.pyplot as plt
# Reload the checkpoint written by the training cell ('state_dict', 'stats', 'args').
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model_dict = torch.load('trained.model.pth')

# Panel 1: reconstruction loss (left axis) and KL loss (right axis) per epoch.
plt.figure(figsize=(18,4))
plt.subplot(1,3,1)
plt.title("Loss statistics")
ax1 = plt.gca()
ax2 = ax1.twinx()
ax1.set_xlabel('EPOCH', c='C3')
ax1.tick_params(axis='x', labelcolor='C3')
ax1.set_ylabel('Reconstruction Loss (RL)', c='C0')
ax1.tick_params(axis='y', labelcolor='C0')
ax1.plot(model_dict['stats']['rl'], lw=2, c='C0')
ax2.set_ylabel('Kullback-Leibler divergence loss (KL)', c='C2')
ax2.tick_params(axis='y', labelcolor='C2')
ax2.plot(model_dict['stats']['kl'], lw=2, c='C2')
ax2.grid(False)

# Panel 2: |Spearman rho| per epoch against a fixed reference line.
plt.subplot(1,3,2)
plt.title(r"$|Spearman\ \rho|$ correlation to experimental data")
plt.xlabel('EPOCH', c='C3')
plt.tick_params(axis='x', labelcolor='C3')
plt.plot(model_dict['stats']['cor'], lw=2, c='C9', label="Our result")
plt.tick_params(axis='y', labelcolor='C9')
plt.axhline(y=0.74388, c='C6', lw=2, label=f'Paper result (without ensambling) ' + rf'$|\rho|={round(0.74388, 4)}$')
plt.legend()

# Panel 3: latent means of the sequences belonging to the five most frequent labels.
plt.subplot(1,3,3)
plt.title("Latent space")
mask = df['label'].isin(df['label'].value_counts()[:5].index) # We limit to top 5 classes only
# Rebuild the model on the CPU from the stored arguments; this rebinds the global `vae`.
vae = VAE(**model_dict['args'])
vae.load_state_dict(model_dict['state_dict'])
vae.eval()
_, mu, logvar = vae(dataloader.dataset[mask], rep=False)
# Latent dimensions become columns '1', '2', ...; the plot below uses '1' as x and '2' as y,
# so it relies on the latent space having at least two dimensions.
columns = [str(i+1) for i in range(mu.shape[1])] + ['label']
dfp = pd.DataFrame(data=np.c_[mu.detach().numpy(), df[mask]['label']], columns=columns)
dfp = dfp.set_index('1').groupby('label')['2']
dfp.plot(style='.', ms=2, alpha=0.5, legend=True);
plt.tight_layout()
# Trailing None keeps the cell from echoing a return value.
None

# vae in Pyro

In [None]:
!pip install pyro-ppl

Collecting pyro-ppl
[?25l  Downloading https://files.pythonhosted.org/packages/aa/7a/fbab572fd385154a0c07b0fa138683aa52e14603bb83d37b198e5f9269b1/pyro_ppl-1.6.0-py3-none-any.whl (634kB)
[K     |▌                               | 10kB 13.5MB/s eta 0:00:01[K     |█                               | 20kB 19.6MB/s eta 0:00:01[K     |█▌                              | 30kB 17.5MB/s eta 0:00:01[K     |██                              | 40kB 11.6MB/s eta 0:00:01[K     |██▋                             | 51kB 10.0MB/s eta 0:00:01[K     |███                             | 61kB 7.3MB/s eta 0:00:01[K     |███▋                            | 71kB 7.3MB/s eta 0:00:01[K     |████▏                           | 81kB 8.1MB/s eta 0:00:01[K     |████▋                           | 92kB 8.4MB/s eta 0:00:01[K     |█████▏                          | 102kB 9.1MB/s eta 0:00:01[K     |█████▊                          | 112kB 9.1MB/s eta 0:00:01[K     |██████▏                         | 122kB 9.1MB/s 

In [None]:
import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from torch.nn.functional import softplus

In [None]:
import torch

class Encoder(torch.nn.Module):
  """Single-hidden-layer encoder that maps a flattened input to (mu, logvar).

  Args:
    input_size: number of features after flattening.
    hidden_size: width of the hidden layer.
    latent_size: size of each of the two outputs.
  """

  def __init__(self, input_size, hidden_size, latent_size):
    super().__init__()
    self.input_size = input_size

    self.fc11 = torch.nn.Linear(input_size, hidden_size)
    self.relu = torch.nn.ReLU()

    # Latent space `mu` and log-variance heads
    self.fc21 = torch.nn.Linear(hidden_size, latent_size)
    self.fc22 = torch.nn.Linear(hidden_size, latent_size)

    # Not used by forward(); kept so existing attribute access keeps working.
    self.softplus = torch.nn.Softplus()

  def forward(self, x, rep=True):
    """Return the latent mean and log-variance for `x`.

    Args:
      x: tensor viewable as (-1, input_size).
      rep: accepted for backward compatibility and ignored. The encoder
        never returned a sampled code; the reparameterisation it used to
        compute was discarded and only consumed random numbers, so it has
        been removed.

    Returns:
      (mu, logvar), each of shape (batch, latent_size).
    """
    x = x.view(-1, self.input_size)                    # flatten
    x = self.relu(self.fc11(x))

    # branch mu, log-variance
    return self.fc21(x), self.fc22(x)

In [None]:
class Decoder(torch.nn.Module):
  """Single-hidden-layer decoder from a latent code to per-position log-probabilities.

  Args:
    input_size: number of output features before reshaping
      (expected to equal alphabet_size * seq_len).
    hidden_size: width of the hidden layer.
    latent_size: size of the latent code.
    alphabet_size: number of classes per position.
    seq_len: number of positions.
  """

  def __init__(self, input_size, hidden_size, latent_size, alphabet_size, seq_len):
    super().__init__()
    self.alphabet_size, self.seq_len = alphabet_size, seq_len

    self.fc31 = torch.nn.Linear(latent_size, hidden_size)
    self.relu = torch.nn.ReLU()
    self.fc32 = torch.nn.Linear(hidden_size, input_size)

  def forward(self, z):
    """Decode `z` into log-probabilities of shape (batch, alphabet_size, seq_len)."""
    hidden = self.relu(self.fc31(z))
    # Reshape the flat output to (batch, alphabet, position) ...
    scores = self.fc32(hidden).view(-1, self.alphabet_size, self.seq_len)
    # ... and normalize over the alphabet dimension.
    return scores.log_softmax(dim=1)

In [None]:
# This VAE is as vanilla as it can be.

def loss(self, x_hat, true_x, mu, logvar, beta=0.5):
  """Per-sample loss: reconstruction term plus `beta` times the KL term.

  `self` is accepted but not used. `x_hat` holds log-probabilities and
  `true_x` the matching one-hot targets; `mu`/`logvar` are the latent mean
  and log-variance.

  Returns:
    (RL + beta*KL, RL, KL), each reduced over all but the leading dimension.
  """
  # Negative log-probability of the true symbols, summed over the last two axes.
  per_row = (x_hat*true_x).sum(-1)
  RL = -per_row.sum(-1)
  # Closed-form KL between N(mu, exp(logvar)) and a standard normal.
  kl_terms = 1 + logvar - mu**2 - logvar.exp()
  KL = -0.5 * kl_terms.sum(-1)
  total = RL + beta*KL
  return total, RL, KL

In [None]:
class VAE(torch.nn.Module):
  """Pyro variant of the VAE with distributions over encoder and decoder parameters.

  `model` places Normal priors on every encoder/decoder parameter and scores
  the observed residues; `guide` holds learnable Normal posteriors (Pyro
  params) for the same parameters. Both sample the site "latent".

  Keyword Args:
    alphabet_size (int): number of classes per position (required).
    seq_len (int): number of positions (required).

  NOTE(review): `model` derives the "latent" distribution from the encoder
  rather than from a fixed prior — confirm this is intended.
  """

  def __init__(self, **kwargs):
    super(VAE, self).__init__()
    self.hidden_size   = 64
    self.latent_size   = 2
    self.alphabet_size = kwargs['alphabet_size']
    self.seq_len       = kwargs['seq_len']
    self.input_size    = self.alphabet_size * self.seq_len

    # create the encoder and decoder networks
    self.encoder = Encoder(self.input_size, self.hidden_size, self.latent_size)
    self.decoder = Decoder(self.input_size, self.hidden_size, self.latent_size, self.alphabet_size, self.seq_len)

  @staticmethod
  def _parameter_priors(module):
    """Build one Normal prior per parameter of `module`, keyed by parameter name."""
    priors = {}
    for name, par in module.named_parameters():
      if "weight" in name:
        # One random mean per weight matrix, drawn with std sqrt(2 / (rows + cols))
        # and broadcast to every entry; log-variance -5.
        mean = torch.normal(mean=torch.zeros(1), std=((2/(par.shape[0]+par.shape[1]))**(1/2))*torch.ones(1))
        logvar = -5
        priors[name] = dist.Normal(mean*torch.ones(*par.shape), (0.5*logvar*torch.ones(*par.shape)).exp())
      else:
        # Biases: zero mean, log-variance -10.
        logvar = -10
        priors[name] = dist.Normal(torch.zeros(*par.shape), (0.5*logvar*torch.ones(*par.shape)).exp())
    return priors

  def model(self, x):
    """Generative model: parameter priors, latent code, and residue likelihood.

    Args:
      x: one-hot batch viewable as (batch, alphabet_size, seq_len).
    """
    # Encoder priors
    priors_encoder = self._parameter_priors(self.encoder)
    bayesian_model_enc = pyro.random_module('bayesian_model_enc', self.encoder, priors_encoder) # Make this model and these priors a Pyro model
    sampled_model_enc = bayesian_model_enc()

    # Decoder priors. These used to be written into `priors_encoder`, which
    # left `priors_decoder` empty so the decoder received no priors at all.
    priors_decoder = self._parameter_priors(self.decoder)
    bayesian_model_dec = pyro.random_module('bayesian_model_dec', self.decoder, priors_decoder) # Make this model and these priors a Pyro model
    sampled_model_dec = bayesian_model_dec()

    # use the encoder to get the parameters of the latent distribution
    z_loc, logvar = sampled_model_enc(x)
    z_scale = softplus((0.5*logvar).exp())

    # sample the latent code z
    z = pyro.sample("latent", dist.Normal(z_loc, z_scale))

    # decode the latent code z into (batch, alphabet, position) log-probabilities
    loc = sampled_model_dec(z)

    # Categorical expects the classes on the last axis and integer class
    # indices as observations, so move the alphabet axis last and turn the
    # one-hot input into residue indices.
    logits = loc.permute(0, 2, 1)
    residues = x.view(-1, self.alphabet_size, self.seq_len).argmax(dim=1)
    pyro.sample("obs", dist.Categorical(logits=logits), obs=residues)

  @staticmethod
  def vr(name, *shape):
    """Return the Pyro param `name`, initialised from a standard normal of `shape`."""
    return pyro.param(name, torch.autograd.Variable(torch.randn(*shape), requires_grad=True))

  def guide(self, x):
    """Variational guide: learnable Normal posteriors over all parameters.

    Args:
      x: one-hot batch viewable as (batch, alphabet_size, seq_len).

    Returns:
      (sampled_model_enc, sampled_model_dec): the encoder and decoder with
      parameters sampled from the guide.
    """
    wr1 = self.encoder.fc11.weight.shape[0]
    wc1 = self.encoder.fc11.weight.shape[1]
    wr2 = self.encoder.fc21.weight.shape[0]
    wc2 = self.encoder.fc21.weight.shape[1]

    # Scales pass through softplus so they stay positive.
    priors_encoder = {
        'fc11.weight': dist.Normal(VAE.vr('fc11w_mu', wr1, wc1), F.softplus(VAE.vr("fc11w_sigma", wr1, wc1))),
        'fc11.bias': dist.Normal(VAE.vr('fc11b_mu', wr1), F.softplus(VAE.vr("fc11b_sigma", wr1))),
        'fc21.weight': dist.Normal(VAE.vr('fc21w_mu', wr2, wc2), F.softplus(VAE.vr("fc21w_sigma", wr2, wc2))),
        'fc21.bias': dist.Normal(VAE.vr('fc21b_mu', wr2), F.softplus(VAE.vr("fc21b_sigma", wr2))),
        'fc22.weight': dist.Normal(VAE.vr('fc22w_mu', wr2, wc2), F.softplus(VAE.vr("fc22w_sigma", wr2, wc2))),
        # Own scale parameter; this entry used to share "fc21b_sigma" with fc21.bias.
        'fc22.bias': dist.Normal(VAE.vr('fc22b_mu', wr2), F.softplus(VAE.vr("fc22b_sigma", wr2)))
    }

    bayesian_model_enc = pyro.random_module('bayesian_model_enc', self.encoder, priors_encoder) # Make this model and these priors a Pyro model
    sampled_model_enc = bayesian_model_enc()

    wr1 = self.decoder.fc31.weight.shape[0]
    wc1 = self.decoder.fc31.weight.shape[1]
    wr2 = self.decoder.fc32.weight.shape[0]
    wc2 = self.decoder.fc32.weight.shape[1]

    priors_decoder = {
        'fc31.weight': dist.Normal(VAE.vr('fc31w_mu', wr1, wc1), F.softplus(VAE.vr("fc31w_sigma", wr1, wc1))),
        'fc31.bias': dist.Normal(VAE.vr('fc31b_mu', wr1), F.softplus(VAE.vr("fc31b_sigma", wr1))),
        'fc32.weight': dist.Normal(VAE.vr('fc32w_mu', wr2, wc2), F.softplus(VAE.vr("fc32w_sigma", wr2, wc2))),
        'fc32.bias': dist.Normal(VAE.vr('fc32b_mu', wr2), F.softplus(VAE.vr("fc32b_sigma", wr2)))
    }

    bayesian_model_dec = pyro.random_module('bayesian_model_dec', self.decoder, priors_decoder) # Make this model and these priors a Pyro model
    sampled_model_dec = bayesian_model_dec()

    # use the encoder to get the parameters used to define q(z|x)
    z_loc, logvar = sampled_model_enc(x)
    z_scale = softplus((0.5*logvar).exp())

    # sample the latent code z
    pyro.sample("latent", dist.Normal(z_loc, z_scale))

    return sampled_model_enc, sampled_model_dec

# train in Pyro

In [None]:
import torch
import numpy as np
#from misc import data, c
from torch import optim
from scipy.stats import spearmanr
#from vae import VAE

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Re-run the full loading pipeline for the Pyro experiment (re-parses the files and re-encodes).
dataloader, df, mutants_tensor, mutants_df = data(batch_size = 64)

wildtype   = dataloader.dataset[0] # one-hot-encoded wildtype 
# Evaluation batch: the wildtype at index 0, followed by every mutant.
eval_batch = torch.cat([wildtype.unsqueeze(0), mutants_tensor])

# Model dimensions are read off one encoded sequence (alphabet size x sequence length).
args = {
    'alphabet_size': dataloader.dataset[0].shape[0],
    'seq_len':       dataloader.dataset[0].shape[1]
}

Parsing fasta '/content/gdrive/MyDrive/data/BLAT_ECOLX_hmmerbit_plmc_n5_m30_f50_t0.2_r24-286_id100_b105.a2m'
Parsing labels '/content/gdrive/MyDrive/data/BLAT_ECOLX_hmmerbit_plmc_n5_m30_f50_t0.2_r24-286_id100_b105_LABELS.a2m'
Generating 8403 1-hot encodings
Generating 8403 1-hot encodings. Took 0.989s torch.Size([8403, 23, 253])
Generating 4807 1-hot encodings
Generating 4807 1-hot encodings. Took 0.434s torch.Size([4807, 23, 253])


In [None]:
#from pyro.contrib.autoguide import AutoDiagonalNormal
# Build the Pyro VAE and wire its model/guide pair into stochastic variational inference.
vae   = VAE(**args)
# `Adam` here is pyro.optim.Adam, which takes a dict of optimizer arguments.
# NOTE(review): lr=1e-9 is small enough that parameters will barely move — confirm it is intended.
opt   = Adam({"lr": 0.000000001})
svi = SVI(vae.model, vae.guide, opt, loss=Trace_ELBO())

In [None]:
def train(svi, train_loader):
    """Run one epoch of SVI steps and return the loss per dataset element.

    Args:
        svi: object whose step(batch) performs one update and returns that
            batch's loss.
        train_loader: iterable of batches with a sized `dataset` attribute.

    Returns:
        Sum of all step losses divided by len(train_loader.dataset).
    """
    # Accumulate the per-batch losses, starting from a float zero.
    summed_loss = sum((svi.step(batch) for batch in train_loader), 0.)

    # Normalize by the dataset size, not by the number of batches.
    return summed_loss / len(train_loader.dataset)

In [None]:
def evaluate(svi, eval_batch):
    """Return svi.evaluate_loss(eval_batch) unchanged (no normalization is applied)."""
    return svi.evaluate_loss(eval_batch)

In [None]:
# Per-epoch histories; the losses are negated before being stored.
train_elbo = []
test_elbo = []
# training loop
for epoch in range(1):
    # train() returns the epoch loss divided by the dataset size.
    total_epoch_loss_train = train(svi, dataloader)
    train_elbo.append(-total_epoch_loss_train)

    # report test diagnostics
    # evaluate() returns the loss of the whole eval batch without normalization,
    # so the train and test numbers printed below are on different scales.
    total_epoch_loss_test = evaluate(svi, eval_batch)
    test_elbo.append(-total_epoch_loss_test)

    to_print = [
        f"{c.HEADER}EPOCH %03d" % epoch,
        f"{c.OKBLUE}TRAIN LOSS=%4.4f" % total_epoch_loss_train,
        f"{c.OKGREEN}TEST LOSS=%4.4f" % total_epoch_loss_test
    ]
    print(" ".join(to_print))



ValueError: ignored

In [None]:
class VAE(torch.nn.Module):
  def __init__(self, **kwargs):
    super(VAE, self).__init__()
    self.hidden_size   = 64
    self.latent_size   = 2
    self.alphabet_size = kwargs['alphabet_size']
    self.seq_len       = kwargs['seq_len']
    self.input_size    = self.alphabet_size * self.seq_len

    # create the encoder and decoder networks
    self.encoder = Encoder(self.input_size, self.hidden_size, self.latent_size)
    self.decoder = Decoder(self.input_size, self.hidden_size, self.latent_size, self.alphabet_size, self.seq_len)

  def model(self, x):
    priors_decoder = {} # Priors for the neural model
    for name, par in self.decoder.named_parameters():     # Loop over all neural network parameters
        priors_decoder[name] = dist.Normal(torch.zeros(*par.shape), torch.ones(*par.shape)).to_event() # Each parameter has a N(0, 1) prior
    
    bayesian_model_dec = pyro.random_module('bayesian_model_dec', self.decoder, priors_decoder) # Make this model and these priors a Pyro model
    sampled_model_dec = bayesian_model_dec()

    with pyro.plate("data", x.shape[0]):
      # setup hyperparameters for prior p(z)
      z_loc = x.new_zeros(torch.Size((x.shape[0], self.latent_size)))
      z_scale = x.new_ones(torch.Size((x.shape[0], self.latent_size)))
      # sample from prior (value will be sampled by guide when computing the ELBO)
      z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

      # decode the latent code z
      loc = sampled_model_dec(z)

      pyro.sample("obs", dist.Bernoulli(loc.reshape(-1, self.input_size)).to_event(1), obs=x.reshape(-1, self.input_size))  

  def guide(self, x):
    priors_encoder = {} # Priors for the neural model

    # Encoder weight distribution priors
    fc11w_mu = torch.randn_like(self.encoder.fc11.weight)
    fc11w_sigma = torch.randn_like(self.encoder.fc11.weight)
    fc11w_mu_param = pyro.param("fc11w_mu", fc11w_mu)
    fc11w_sigma_param = softplus(pyro.param("fc11w_sigma", fc11w_sigma))
    fc11w_prior = dist.Normal(loc=fc11w_mu_param, scale=fc11w_sigma_param)
    # Encoder bias distribution priors
    fc11b_mu = torch.randn_like(self.encoder.fc11.bias)
    fc11b_sigma = torch.randn_like(self.encoder.fc11.bias)
    fc11b_mu_param = pyro.param("fc11b_mu", fc11b_mu)
    fc11b_sigma_param = softplus(pyro.param("fc11b_sigma", fc11b_sigma))
    fc11b_prior = dist.Normal(loc=fc11b_mu_param, scale=fc11b_sigma_param)

    # Latent state weight distribution priors
    ## 1
    fc21w_mu = torch.randn_like(self.encoder.fc21.weight)
    fc21w_sigma = torch.randn_like(self.encoder.fc21.weight)
    fc21w_mu_param = pyro.param("fc21w_mu", fc21w_mu)
    fc21w_sigma_param = softplus(pyro.param("fc21w_sigma", fc21w_sigma))
    fc21w_prior = dist.Normal(loc=fc21w_mu_param, scale=fc21w_sigma_param)
    ## 2
    fc22w_mu = torch.randn_like(self.encoder.fc22.weight)
    fc22w_sigma = torch.randn_like(self.encoder.fc22.weight)
    fc22w_mu_param = pyro.param("fc22w_mu", fc22w_mu)
    fc22w_sigma_param = softplus(pyro.param("fc22w_sigma", fc22w_sigma))
    fc22w_prior = dist.Normal(loc=fc22w_mu_param, scale=fc22w_sigma_param)
    # Latent state bias distribution priors
    ## 1
    fc21b_mu = torch.randn_like(self.encoder.fc21.bias)
    fc21b_sigma = torch.randn_like(self.encoder.fc21.bias)
    fc21b_mu_param = pyro.param("fc21b_mu", fc21b_mu)
    fc21b_sigma_param = softplus(pyro.param("fc21b_sigma", fc21b_sigma))
    fc21b_prior = dist.Normal(loc=fc21b_mu_param, scale=fc21b_sigma_param)
    ## 2
    fc22b_mu = torch.randn_like(self.encoder.fc22.bias)
    fc22b_sigma = torch.randn_like(self.encoder.fc22.bias)
    fc22b_mu_param = pyro.param("fc22b_mu", fc22b_mu)
    fc22b_sigma_param = softplus(pyro.param("fc22b_sigma", fc22b_sigma))
    fc22b_prior = dist.Normal(loc=fc22b_mu_param, scale=fc22b_sigma_param)

    priors_encoder = {'fc11.weight': fc11w_prior, 'fc11.bias': fc11b_prior, 'fc21.weight': fc21w_prior, 'fc21.bias': fc21b_prior, 'fc22.weight': fc22w_prior, 'fc22.bias': fc22b_prior}
    
    bayesian_model_enc = pyro.random_module('bayesian_model_enc', self.encoder, priors_encoder) # Make this model and these priors a Pyro model
    sampled_model_enc = bayesian_model_enc()
    
    with pyro.plate("data", x.shape[0]):
      # use the encoder to get the parameters used to define q(z|x)
      z_loc, logvar = sampled_model_enc(x)
      z_scale = torch.exp(logvar)
      l = torch.zeros((x.shape[0], self.latent_size))
      print("------------------------------------------------------")
      # sample the latent code z
      pyro.sample("latent", dist.Normal(l, z_scale).to_event(1))