# PML Project

Date: 2022-01-15

Author: Jiajun He

Content: IWAE DeepSequence

In [1]:
!wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_1_b0.5_labeled.fasta

--2022-01-17 06:12:23--  https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_1_b0.5_labeled.fasta
Resolving sid.erda.dk (sid.erda.dk)... 130.225.104.13
Connecting to sid.erda.dk (sid.erda.dk)|130.225.104.13|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2441075 (2.3M)
Saving to: ‘BLAT_ECOLX_1_b0.5_labeled.fasta’


2022-01-17 06:12:24 (3.05 MB/s) - ‘BLAT_ECOLX_1_b0.5_labeled.fasta’ saved [2441075/2441075]



In [1]:
# Mount Google Drive at /content/drive (Colab only); the checkpoint paths used
# further down in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# parsing the FASTA file, codes from https://colab.research.google.com/github/wouterboomsma/pml_vae_project/blob/main/protein_vae_data_processing.ipynb
import os
import re
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd

# FASTA parser requires Biopython
try:
    from Bio import SeqIO
except:
    !pip install biopython
    from Bio import SeqIO
    
# Retrieve protein alignment file (skipped when a local copy already exists)
if not os.path.exists('BLAT_ECOLX_1_b0.5_labeled.fasta'):
    !wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_1_b0.5_labeled.fasta
        
# Retrieve file with experimental measurements (skipped when a local copy already exists)
if not os.path.exists('BLAT_ECOLX_Ranganathan2015.csv'):
    !wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_Ranganathan2015.csv
        
# Options
batch_size = 16

# Sequence alphabet: every character that may appear in an encoded sequence,
# listed in the order of its integer code.
aa1 = "ACDEFGHIKLMNPQRSTVWYXZ-"

# Mapping from amino acids to integers (the position of each character in `aa1`)
aa1_to_index = {letter: code for code, letter in enumerate(aa1)}

# Labels kept as separate classes; get_data maps every other label to 'Other'.
phyla = [
    'Acidobacteria',
    'Actinobacteria',
    'Bacteroidetes',
    'Chloroflexi',
    'Cyanobacteria',
    'Deinococcus-Thermus',
    'Firmicutes',
    'Fusobacteria',
    'Proteobacteria',
    'Other',
]

def get_data(data_filename, calc_weights=False, weights_similarity_threshold=0.8):
    '''Create dataset from FASTA filename.

    Every record is upper-cased, '.' is treated as the gap '-', and each
    character is encoded with `aa1_to_index`. The label is the text inside the
    first pair of square brackets of the record description; labels that are
    not listed in `phyla` are replaced by 'Other'.

    Args:
        data_filename: path of the FASTA alignment (all sequences must have the
            same length, since they are stacked into one matrix).
        calc_weights: anything other than ``False`` also computes one
            re-weighting factor per sequence.
        weights_similarity_threshold: two sequences count as neighbours when
            the fraction of matching residues exceeds this value.

    Returns:
        (dataset, weights): a TensorDataset of (encoded sequence, label index)
        pairs, and either ``None`` or a 1-D tensor with one weight per
        sequence, equal to 1 / (number of neighbours, the sequence itself
        included). The weights live on the GPU when one is available.
    '''
    labels = []
    seqs = []
    label_re = re.compile(r'\[([^\]]*)\]')
    for record in SeqIO.parse(data_filename, "fasta"):
        seqs.append(np.array([aa1_to_index[aa] for aa in str(record.seq).upper().replace('.', '-')]))

        label = label_re.search(record.description).group(1)
        # Only use most common classes
        if label not in phyla:
            label = 'Other'
        labels.append(label)

    seqs = torch.from_numpy(np.vstack(seqs))
    labels = np.array(labels)

    # Only the integer label of each sequence is stored in the dataset.
    _, phyla_idx = np.unique(labels, return_inverse=True)

    dataset = torch.utils.data.TensorDataset(*[seqs, torch.from_numpy(phyla_idx)])

    weights = None
    if calc_weights is not False:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Experiencing memory issues on colab for this code because pytorch doesn't
        # allow one_hot directly to bool. Splitting in two and then merging.
        # num_classes is passed explicitly: when it is inferred, each half uses
        # its own largest index, so the halves can end up with different widths
        # (and an empty half cannot be encoded at all).
        num_classes = len(aa1_to_index)
        half = len(seqs) // 2
        one_hot = torch.cat([
            F.one_hot(seqs[:half].long(), num_classes=num_classes).bool(),
            F.one_hot(seqs[half:].long(), num_classes=num_classes).bool(),
        ]).to(device)
        assert(len(seqs) == len(one_hot))

        # Codes above 19 are 'X', 'Z' and '-': they never count as a match and
        # do not contribute to a sequence's length.
        is_residue = seqs <= 19
        one_hot[(~is_residue).to(device)] = 0
        flat_one_hot = one_hot.flatten(1).float()
        # clamp: a sequence without any residue would otherwise divide by zero
        seq_lengths = is_residue.sum(1).unsqueeze(-1).clamp(min=1).to(device)

        weights = []
        weight_batch_size = 1000  # rows of the similarity matrix computed at once
        for start in range(0, seqs.size(0), weight_batch_size):
            stop = start + weight_batch_size
            # number of positions at which two sequences carry the same residue
            similarities = torch.mm(flat_one_hot[start:stop], flat_one_hot.T)
            neighbours = (similarities / seq_lengths[start:stop]).gt(weights_similarity_threshold).sum(1)
            # clamp: keeps the weight finite when nothing passes the threshold
            weights.append(1.0 / neighbours.clamp(min=1).float())

        weights = torch.cat(weights)

    return dataset, weights


# Encode the alignment and compute one weight per sequence.
dataset, weights = get_data('BLAT_ECOLX_1_b0.5_labeled.fasta', calc_weights=True)
# Plain loader: shuffled mini-batches over the whole dataset.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Weighted loader: draws len(dataset) indices per pass, using `weights` as the sampling weights.
dataloader_weighted = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples=len(dataset)))

Collecting biopython
  Downloading biopython-1.79-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (2.3 MB)
[K     |████████████████████████████████| 2.3 MB 4.3 MB/s 
Installing collected packages: biopython
Successfully installed biopython-1.79
--2022-01-18 10:07:07--  https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_1_b0.5_labeled.fasta
Resolving sid.erda.dk (sid.erda.dk)... 130.225.104.13
Connecting to sid.erda.dk (sid.erda.dk)|130.225.104.13|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2441075 (2.3M)
Saving to: ‘BLAT_ECOLX_1_b0.5_labeled.fasta’


2022-01-18 10:07:11 (1.24 MB/s) - ‘BLAT_ECOLX_1_b0.5_labeled.fasta’ saved [2441075/2441075]

--2022-01-18 10:07:11--  https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_Ranganathan2015.csv
Resolving sid.erda.dk (sid.erda.dk)... 130.225.104.13
Connecting to sid.erda.dk (sid.erda.dk)|130.225.104.13|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1216640 (1.2M) [text/c

In [3]:
import torch
from torch.optim import Adam
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.stats import spearmanr

In [4]:
# Width of the decoder's last hidden layer. DeepSeq_Double reads this
# module-level name; a `global` statement is not needed (and does nothing)
# outside of a function.
H = 2000

In [5]:
class DeepSeq_Double(nn.Module):
    """VAE over integer-encoded sequences with a Bayesian, sparsity-gated decoder.

    The encoder maps a one-hot encoded sequence to a diagonal Gaussian over the
    latent code z. The decoder's last layer is built from four "global"
    parameters, each with a Gaussian variational posterior (mean ``*_loc`` and
    standard deviation ``softplus(_*_scale)``):

    * ``lambd_tilde`` -- scalar, used as ``softplus(lambd_tilde)``
    * ``C``           -- (q, E) matrix
    * ``W_tilde``     -- (L, E, H) weights
    * ``S_tilde``     -- (H / 4, L) gate logits, passed through a sigmoid and
      tiled four times along the hidden axis

    with q = 23 symbols, E = 40, L = sequence length and H = hidden width.
    ``forward`` returns an importance-weighted bound summed over the batch.

    Args:
        input_size: flattened one-hot width, i.e. L * 23.
        latent_size: dimensionality of z.
        device: stored on the instance as ``self.device``; the methods of this
            class do not read it.
        hidden_size: width H of the last hidden layer. ``None`` (default)
            uses the notebook-level global ``H``.
    """

    N_SYMBOLS = 23   # alphabet size q (one-hot width per position)
    EMBED_DIM = 40   # width E of the matrix C
    N_TILES = 4      # S_tilde holds H / 4 rows and is repeated 4 times

    def __init__(self, input_size, latent_size, device, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            hidden_size = H
        if input_size % self.N_SYMBOLS != 0:
            raise ValueError("input_size must be a multiple of %d" % self.N_SYMBOLS)
        if hidden_size % self.N_TILES != 0:
            raise ValueError("hidden_size must be a multiple of %d" % self.N_TILES)
        seq_len = input_size // self.N_SYMBOLS
        self.hidden_size = hidden_size

        # Encoder
        self.fc11 = nn.Linear(input_size, 1500) # 1500 is the number in the paper
        self.fc12 = nn.Linear(1500, 1500)
        self.fc131 = nn.Linear(1500, latent_size)  # mean of q(z|x)
        self.fc132 = nn.Linear(1500, latent_size)  # log standard deviation of q(z|x)

        # Decoder trunk
        self.fc21 = nn.Linear(latent_size, 100)
        self.fc22 = nn.Linear(100, hidden_size)

        # Posterior means of the global parameters
        self.lambd_tilde_loc = nn.Parameter(torch.tensor(0.))
        self.C_loc = nn.Parameter(torch.randn(self.N_SYMBOLS, self.EMBED_DIM))
        self.W_tilde_loc = nn.Parameter(torch.randn(seq_len, self.EMBED_DIM, hidden_size)) # L * E * H
        self.S_tilde_loc = nn.Parameter(torch.randn(hidden_size // self.N_TILES, seq_len) * 4 - 12)

        # Unconstrained posterior scales (softplus is applied where they are used)
        self._lambd_tilde_scale = nn.Parameter(torch.tensor(0.) - 4)
        self._C_scale = nn.Parameter(torch.zeros(self.N_SYMBOLS, self.EMBED_DIM) - 4)
        self._W_tilde_scale = nn.Parameter(torch.zeros(seq_len, self.EMBED_DIM, hidden_size) - 4) # L * E * H
        self._S_tilde_scale = nn.Parameter(torch.zeros(hidden_size // self.N_TILES, seq_len) - 4)

        # Output bias, one value per position and symbol
        self.b3 = nn.Parameter(torch.randn(seq_len, self.N_SYMBOLS))

        self.device = device

    @staticmethod
    def _draw(loc, raw_scale):
        """Reparameterized draw from Normal(loc, softplus(raw_scale))."""
        return loc + nn.functional.softplus(raw_scale) * torch.randn_like(loc)

    def _sample_global_params(self):
        """Draw (lambd_tilde, C, S_tilde, W_tilde) from their variational posteriors."""
        lambd_tilde = self._draw(self.lambd_tilde_loc, self._lambd_tilde_scale)
        C = self._draw(self.C_loc, self._C_scale)
        S_tilde = self._draw(self.S_tilde_loc, self._S_tilde_scale)
        W_tilde = self._draw(self.W_tilde_loc, self._W_tilde_scale)
        return lambd_tilde, C, S_tilde, W_tilde

    def _global_kl(self):
        """Sum of KL(posterior || prior) over the four global parameters."""
        Normal = torch.distributions.Normal
        softplus = nn.functional.softplus
        terms = [
            (self.lambd_tilde_loc, self._lambd_tilde_scale, Normal(0., 1.)),
            (self.C_loc, self._C_scale, Normal(0., 1.)),
            (self.W_tilde_loc, self._W_tilde_scale, Normal(0., 1.)),
            (self.S_tilde_loc, self._S_tilde_scale, Normal(-12.36, 4.)),
        ]
        return sum(
            torch.distributions.kl_divergence(Normal(loc, softplus(raw_scale)), prior).sum()
            for loc, raw_scale, prior in terms
        )

    def encoder(self, x):
        """Map integer sequences ``x`` (B * L) to the mean and standard deviation of q(z|x)."""
        x = nn.functional.one_hot(x, num_classes=self.N_SYMBOLS).float().reshape(x.shape[0], -1)
        hidden = torch.relu(self.fc11(x))
        hidden = torch.relu(self.fc12(hidden))
        z_mu = self.fc131(hidden)
        z_sd = torch.exp(self.fc132(hidden))

        return z_mu, z_sd

    def decoder(self, z, sample=True, lambd_tilde=None, C=None, S_tilde=None, W_tilde=None):
        """Map latent codes ``z`` (B * latent) to logits of shape B * L * q.

        With ``sample=False`` the posterior means of the global parameters are
        used and any parameter passed in is ignored. With ``sample=True`` the
        given parameters are used; every parameter left as ``None`` is drawn
        from its variational posterior.
        """
        hidden = torch.relu(self.fc21(z))
        hidden = torch.sigmoid(self.fc22(hidden))  # B * H

        if not sample:
            lambd_tilde = self.lambd_tilde_loc
            C = self.C_loc
            S_tilde = self.S_tilde_loc
            W_tilde = self.W_tilde_loc
        else:
            if lambd_tilde is None:
                lambd_tilde = self._draw(self.lambd_tilde_loc, self._lambd_tilde_scale)
            if C is None:
                C = self._draw(self.C_loc, self._C_scale)
            if S_tilde is None:
                S_tilde = self._draw(self.S_tilde_loc, self._S_tilde_scale)
            if W_tilde is None:
                W_tilde = self._draw(self.W_tilde_loc, self._W_tilde_scale)

        gate = torch.sigmoid(S_tilde)
        S = torch.cat([gate] * self.N_TILES, dim=0)  # H * L
        Sh = (hidden.unsqueeze(-1) * S).permute(0, 2, 1).unsqueeze(-1)  # B * L * H * 1
        # Evaluated left to right: (softplus(lambd) * C) @ W_tilde is L * q * H,
        # and multiplying by Sh gives B * L * q * 1.
        logits = nn.functional.softplus(lambd_tilde) * C @ W_tilde @ Sh
        return logits[:, :, :, 0] + self.b3  # B * L * q

    def forward(self, x, mc_samples_l=1, mc_samples_k=10, calculate_global_KL=True):
        """Importance-weighted bound for a batch of integer sequences ``x`` (B * L).

        The global parameters are drawn ``mc_samples_l`` times. For each draw,
        ``mc_samples_k`` latent codes are sampled and their importance weights
        are combined with a log-mean-exp; the resulting bounds are averaged
        over the ``mc_samples_l`` draws and summed over the batch.

        Note that ``mc_samples_l`` comes before ``mc_samples_k`` when the
        sample counts are passed positionally.

        Args:
            x: integer tensor of shape B * L.
            mc_samples_l: number of draws of the global parameters.
            mc_samples_k: number of latent draws per global draw.
            calculate_global_KL: when true, the KL divergence between the
                posteriors and priors of the global parameters is subtracted.
                NOTE(review): the full KL is subtracted once per call,
                whatever the batch size -- confirm this scaling is intended.

        Returns:
            (L_IWAE, z_mu, z_sd): the scalar bound and the encoder outputs.
        """
        z_mu, z_sd = self.encoder(x)

        # z distribution
        prior = torch.distributions.Normal(0., 1.)
        posterior = torch.distributions.Normal(z_mu, z_sd)

        bounds = []
        for _ in range(mc_samples_l):
            lambd_tilde, C, S_tilde, W_tilde = self._sample_global_params()
            log_w = []
            for _ in range(mc_samples_k):
                z = z_mu + z_sd * torch.randn_like(z_mu)
                x_logit = self.decoder(z, sample=True, lambd_tilde=lambd_tilde, C=C, S_tilde=S_tilde, W_tilde=W_tilde)
                log_px = torch.distributions.Categorical(logits=x_logit).log_prob(x).sum(dim=-1, keepdim=True)
                log_prior_ratio = (prior.log_prob(z) - posterior.log_prob(z)).sum(dim=-1, keepdim=True)
                log_w.append(log_px + log_prior_ratio)
            log_w = torch.cat(log_w, dim=-1)  # B * k

            # log-mean-exp over the k importance weights (logsumexp is numerically stable)
            log_k = log_w.new_tensor(float(mc_samples_k)).log()
            bounds.append(torch.logsumexp(log_w, dim=-1, keepdim=True) - log_k)
        per_sequence = torch.cat(bounds, dim=-1).sum(dim=-1) / mc_samples_l

        L_IWAE = per_sequence.sum()

        if calculate_global_KL:
            L_IWAE = L_IWAE - self._global_kl()

        return L_IWAE, z_mu, z_sd


In [6]:
def train(model, dl, optimizer, device):
    """Run one pass over `dl`, taking one optimizer step per batch.

    For every batch the first element is moved to `device`, the model is
    called with ``mc_samples_l=1, mc_samples_k=10``, and the negated first
    return value is minimized after clipping the gradient norm to 1.

    Returns:
        The mean of the per-batch values returned by the model.
    """
    batch_bounds = []
    for batch in dl:
        inputs = batch[0].to(device)
        optimizer.zero_grad()
        bound, _, _ = model(inputs, mc_samples_l=1, mc_samples_k=10)
        # maximizing the bound == minimizing its negative
        (-bound).backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        batch_bounds.append(bound.item())
    return sum(batch_bounds) / len(batch_bounds)

In [7]:
# Flattened one-hot width: sequence length (taken from one batch) times the 23 symbols.
input_size = next(iter(dataloader))[0].shape[1] * 23
# Use the GPU when there is one and fall back to the CPU otherwise (the same rule get_data applies).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
deepseq = DeepSeq_Double(input_size, 30, device=DEVICE).to(DEVICE)

In [11]:
# train the model

# The four posterior-scale parameters form their own group with a much smaller
# learning rate. The group must hold the live nn.Parameter objects: the entries
# of state_dict() are detached tensors that never receive gradients, so an
# optimizer built from them silently leaves the scales untouched.
scale_param_names = ["_lambd_tilde_scale", "_C_scale", "_W_tilde_scale", "_S_tilde_scale"]
scale_params = [deepseq.get_parameter(name) for name in scale_param_names]
global_params = [id(p) for p in scale_params]
base_params = [p for p in deepseq.parameters() if id(p) not in global_params]
optimizer = Adam([{'params': base_params},
                  {'params': scale_params, 'lr': 1e-7}],
                 lr=4e-4)

# NOTE: the evaluation below reads `experimental_data`, which is created by the
# cell that calls read_experimental_data further down in this notebook; run
# that cell before this one.
EPOCH = 201
ELBOs = []
for epoch in tqdm(range(EPOCH)):
    if epoch % 10 == 0:
        # Checkpoint the weights as they are before this epoch's training pass
        PATH = "/content/drive/MyDrive/PML-project/FinalModel/" + "DoubleVI-IWAE-%d.pkl"%epoch
        torch.save(deepseq.state_dict(),PATH)

        print(epoch, ":")
        # Evaluation only: no autograd graph is needed
        with torch.no_grad():
            # Check the reconstruct accuracy (make sure it learns normally)
            acc = []
            for i in range(len(dataset)):
                raw_sequence = dataset[i][0][np.newaxis, :].to(DEVICE)
                z_mu, _ = deepseq.encoder(raw_sequence)
                acc.append(torch.argmax(deepseq.decoder(z_mu, False), dim=-1) == raw_sequence)
            acc = np.mean(torch.cat(acc, dim=0).cpu().numpy())
            print("Reconstruct Accuracy:", acc)

            # Score every single mutant of the first sequence by the change in the
            # bound relative to the unmutated sequence (the positional 10 is mc_samples_l).
            raw_sequence = dataset[0][0][np.newaxis, :].to(DEVICE)
            experiment_value = []
            predicted_value = []
            log_x_wt_ELBO, _, _ = deepseq(raw_sequence, 10, calculate_global_KL=False)
            for (position, mutant_from), row in experimental_data.iterrows():
                assert aa1_to_index[mutant_from] == raw_sequence[0, position]
                # Series.items() replaces iteritems(), which pandas 2.0 removed
                for mutant_to, exp_value in row.items():
                    if mutant_to != mutant_from:
                        new_sequence = raw_sequence.clone()
                        new_sequence[0, position] = aa1_to_index[mutant_to]
                        experiment_value.append(exp_value)
                        log_x_mt_ELBO, _, _ = deepseq(new_sequence, 10, calculate_global_KL=False)
                        predicted_value.append((log_x_mt_ELBO - log_x_wt_ELBO).item())
        print(spearmanr(experiment_value, predicted_value))


    ELBOs.append(train(deepseq, dataloader_weighted, optimizer, DEVICE))


  0%|          | 0/201 [00:00<?, ?it/s]

0 :
Reconstruct Accuracy: 0.04430064974221657
SpearmanrResult(correlation=-0.01223989736512748, pvalue=0.38701261127070485)


  5%|▍         | 10/201 [49:49<8:06:02, 152.68s/it]

10 :
Reconstruct Accuracy: 0.587126727847009
SpearmanrResult(correlation=0.6155575611658645, pvalue=0.0)


 10%|▉         | 20/201 [1:39:39<7:40:27, 152.64s/it]

20 :
Reconstruct Accuracy: 0.6191848459407108
SpearmanrResult(correlation=0.6596583367604157, pvalue=0.0)


 15%|█▍        | 30/201 [2:29:28<7:14:35, 152.49s/it]

30 :
Reconstruct Accuracy: 0.6385413859228336
SpearmanrResult(correlation=0.6823821412658243, pvalue=0.0)


 20%|█▉        | 40/201 [3:19:19<6:49:32, 152.62s/it]

40 :
Reconstruct Accuracy: 0.6572377133572341
SpearmanrResult(correlation=0.6890700633266295, pvalue=0.0)


 25%|██▍       | 50/201 [4:09:11<6:23:58, 152.57s/it]

50 :
Reconstruct Accuracy: 0.6666348355673272
SpearmanrResult(correlation=0.6733190554617611, pvalue=0.0)


 30%|██▉       | 60/201 [4:59:11<5:59:53, 153.15s/it]

60 :
Reconstruct Accuracy: 0.6732481100082793
SpearmanrResult(correlation=0.6925045110803982, pvalue=0.0)


 35%|███▍      | 70/201 [5:49:12<5:33:47, 152.88s/it]

70 :
Reconstruct Accuracy: 0.6805424407117499
SpearmanrResult(correlation=0.6981726537581576, pvalue=0.0)


 40%|███▉      | 80/201 [6:39:10<5:08:02, 152.75s/it]

80 :
Reconstruct Accuracy: 0.6868600252451318
SpearmanrResult(correlation=0.694470069196371, pvalue=0.0)


 45%|████▍     | 90/201 [7:29:02<4:42:24, 152.65s/it]

90 :
Reconstruct Accuracy: 0.6907655557128259
SpearmanrResult(correlation=0.6972006827273134, pvalue=0.0)


 50%|████▉     | 100/201 [8:18:56<4:17:00, 152.68s/it]

100 :
Reconstruct Accuracy: 0.6949604744998963
SpearmanrResult(correlation=0.6944062502305448, pvalue=0.0)


 55%|█████▍    | 110/201 [9:08:47<3:51:34, 152.69s/it]

110 :
Reconstruct Accuracy: 0.6988281954384258
SpearmanrResult(correlation=0.696524293578406, pvalue=0.0)


 60%|█████▉    | 120/201 [9:58:39<3:25:57, 152.56s/it]

120 :
Reconstruct Accuracy: 0.7002794996732868
SpearmanrResult(correlation=0.7032776054423837, pvalue=0.0)


 65%|██████▍   | 130/201 [10:48:29<3:00:36, 152.62s/it]

130 :
Reconstruct Accuracy: 0.7024918418669763
SpearmanrResult(correlation=0.6876431937684752, pvalue=0.0)


 70%|██████▉   | 140/201 [11:38:19<2:35:09, 152.61s/it]

140 :
Reconstruct Accuracy: 0.7043028213664557
SpearmanrResult(correlation=0.6960253615865504, pvalue=0.0)


 75%|███████▍  | 150/201 [12:28:21<2:10:22, 153.37s/it]

150 :
Reconstruct Accuracy: 0.7073954469571084
SpearmanrResult(correlation=0.6963093625593267, pvalue=0.0)


 80%|███████▉  | 160/201 [13:18:25<1:44:42, 153.24s/it]

160 :
Reconstruct Accuracy: 0.7080546900297241
SpearmanrResult(correlation=0.6951541992462055, pvalue=0.0)


 85%|████████▍ | 170/201 [14:08:33<1:19:11, 153.28s/it]

170 :
Reconstruct Accuracy: 0.7095995486123903
SpearmanrResult(correlation=0.6995613732107933, pvalue=0.0)


 90%|████████▉ | 180/201 [14:58:36<53:33, 153.02s/it]

180 :
Reconstruct Accuracy: 0.7100842861657841
SpearmanrResult(correlation=0.7007616145916438, pvalue=0.0)


 95%|█████████▍| 190/201 [15:48:40<28:05, 153.19s/it]

190 :
Reconstruct Accuracy: 0.7113712643700447
SpearmanrResult(correlation=0.6954939085152613, pvalue=0.0)


100%|█████████▉| 200/201 [16:38:48<02:33, 153.58s/it]

200 :
Reconstruct Accuracy: 0.7121279396908926
SpearmanrResult(correlation=0.7043106428675208, pvalue=0.0)


100%|██████████| 201/201 [17:09:16<00:00, 307.25s/it]


In [12]:
# Continue training from epoch 201 up to (but excluding) EPOCH, reusing the
# model, optimizer and ELBOs list left by the previous training cell.
EPOCH = 261
for epoch in tqdm(range(201, EPOCH)):
    if epoch % 10 == 0:
        # Checkpoint the weights as they are before this epoch's training pass
        PATH = "/content/drive/MyDrive/PML-project/FinalModel/" + "DoubleVI-IWAE-%d.pkl"%epoch
        torch.save(deepseq.state_dict(),PATH)

        print(epoch, ":")
        # Evaluation only: no autograd graph is needed
        with torch.no_grad():
            # Check the reconstruct accuracy (make sure it learns normally)
            acc = []
            for i in range(len(dataset)):
                raw_sequence = dataset[i][0][np.newaxis, :].to(DEVICE)
                z_mu, _ = deepseq.encoder(raw_sequence)
                acc.append(torch.argmax(deepseq.decoder(z_mu, False), dim=-1) == raw_sequence)
            acc = np.mean(torch.cat(acc, dim=0).cpu().numpy())
            print("Reconstruct Accuracy:", acc)

            # Score every single mutant of the first sequence by the change in the
            # bound relative to the unmutated sequence (the positional 10 is mc_samples_l).
            raw_sequence = dataset[0][0][np.newaxis, :].to(DEVICE)
            experiment_value = []
            predicted_value = []
            log_x_wt_ELBO, _, _ = deepseq(raw_sequence, 10, calculate_global_KL=False)
            for (position, mutant_from), row in experimental_data.iterrows():
                assert aa1_to_index[mutant_from] == raw_sequence[0, position]
                # Series.items() replaces iteritems(), which pandas 2.0 removed
                for mutant_to, exp_value in row.items():
                    if mutant_to != mutant_from:
                        new_sequence = raw_sequence.clone()
                        new_sequence[0, position] = aa1_to_index[mutant_to]
                        experiment_value.append(exp_value)
                        log_x_mt_ELBO, _, _ = deepseq(new_sequence, 10, calculate_global_KL=False)
                        predicted_value.append((log_x_mt_ELBO - log_x_wt_ELBO).item())
        print(spearmanr(experiment_value, predicted_value))


    ELBOs.append(train(deepseq, dataloader_weighted, optimizer, DEVICE))

 15%|█▌        | 9/60 [19:55<1:52:48, 132.71s/it]

210 :
Reconstruct Accuracy: 0.7135128348809388
SpearmanrResult(correlation=0.6987506692066767, pvalue=0.0)


 32%|███▏      | 19/60 [1:10:20<1:44:50, 153.43s/it]

220 :
Reconstruct Accuracy: 0.7149917691563433
SpearmanrResult(correlation=0.6966625104972404, pvalue=0.0)


 48%|████▊     | 29/60 [2:00:41<1:19:16, 153.42s/it]

230 :
Reconstruct Accuracy: 0.7157872234814626
SpearmanrResult(correlation=0.704826964215937, pvalue=0.0)


 65%|██████▌   | 39/60 [2:50:42<53:40, 153.35s/it]

240 :
Reconstruct Accuracy: 0.7169767694374911
SpearmanrResult(correlation=0.6949836911201619, pvalue=0.0)


 82%|████████▏ | 49/60 [3:40:58<28:13, 153.94s/it]

250 :
Reconstruct Accuracy: 0.718972918682367
SpearmanrResult(correlation=0.6979752224352139, pvalue=0.0)


 98%|█████████▊| 59/60 [4:31:07<02:33, 153.31s/it]

260 :
Reconstruct Accuracy: 0.717826999106144
SpearmanrResult(correlation=0.6975778061261392, pvalue=0.0)


100%|██████████| 60/60 [5:01:32<00:00, 301.54s/it]


In [8]:
# Restore the checkpoint written at epoch 230. map_location places the tensors
# on the device in use, so a checkpoint saved from a GPU also loads on a CPU-only machine.
PATH = "/content/drive/MyDrive/PML-project/FinalModel/" + "DoubleVI-IWAE-%d.pkl"%230
deepseq.load_state_dict(torch.load(PATH, map_location=DEVICE))

<All keys matched successfully>

In [None]:
# Check the reconstruct accuracy (make sure it learns normally):
# fraction of positions at which the decoder's most likely symbol, decoded from
# the posterior mean of z, equals the input symbol.
acc = []
# Evaluation only: no autograd graph is needed
with torch.no_grad():
    for i in range(len(dataset)):
        raw_sequence = dataset[i][0][np.newaxis, :].to(DEVICE)
        z_mu, _ = deepseq.encoder(raw_sequence)
        acc.append(torch.argmax(deepseq.decoder(z_mu, False), dim=-1) == raw_sequence)
acc = np.mean(torch.cat(acc, dim=0).cpu().numpy())
print("Reconstruct Accuracy:", acc)

In [9]:
! wget https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_Ranganathan2015.csv

# Read in the experimental data, code by Wouter from https://colab.research.google.com/github/wouterboomsma/pml_vae_project/blob/main/protein_vae_data_processing.ipynb

def read_experimental_data(filename, alignment_data, measurement_col_name = '2500', sequence_offset=0):
    '''Load per-mutant measurements from a csv file into a table indexed by position.

    Each value of the 'mutant' column is read as <from><position><to>: first
    character, integer in the middle, last character. The position found in
    the first row of the file is mapped to index `sequence_offset` of the
    first sequence in `alignment_data`, and every row is checked to name the
    character that this sequence has at the mapped index.

    Args:
        filename: csv file (anything pandas.read_csv accepts).
        alignment_data: indexable collection whose first item is a
            (sequence, label) pair; only the sequence is used.
        measurement_col_name: column holding the measurement. In our case,
            this is the one called 2500.
        sequence_offset: used in case there is an overall offset between the
            indices in the two files.

    Returns:
        DataFrame indexed by ('pos', 'WT') with one column per target
        character.

    Raises:
        AssertionError: when a row disagrees with the alignment, or when the
            same mutant occurs twice.
    '''
    measurements = pd.read_csv(filename, delimiter=',', usecols=['mutant', measurement_col_name])

    # Only the sequence of the first alignment entry is needed.
    wt_sequence, _wt_label = alignment_data[0]

    first_position = None
    by_position = {}
    for _row_id, record in measurements.iterrows():
        code = record['mutant']
        mutant_from, position, mutant_to = code[:1], int(code[1:-1]), code[-1:]

        # The first row anchors the numbering (keeps index gaps in the
        # experimental data intact).
        if first_position is None:
            first_position = position

        # Corresponding position in our alignment
        seq_position = position - first_position + sequence_offset

        # Both inputs have to agree on the character at this position.
        assert mutant_from == aa1[wt_sequence[seq_position]]

        slot = by_position.setdefault(seq_position, {})

        # Check that there is only a single experimental value for mutant
        assert mutant_to not in slot

        slot.update({
            'pos': seq_position,
            'WT': mutant_from,
            mutant_to: record[measurement_col_name],
        })

    return pd.DataFrame(by_position).T.set_index(['pos', 'WT'])
        
        
# Build the ('pos', 'WT')-indexed table of measurements, checked against the
# first sequence of the alignment.
experimental_data = read_experimental_data("BLAT_ECOLX_Ranganathan2015.csv", dataset)
# For each of the entries in the dataframe above, you should calculate
# the corresponding difference in ELBO from your VAE, and then finally
# calculate a Spearman correlation between the two.

# # You can iterate over all experimental values like this:
# for (position, mutant_from), row in experimental_data.iterrows():
#     print(position, mutant_from)   # mutant from is the wild type (wt)
#     for mutant_to, exp_value in row.items():   # iteritems() was removed in pandas 2.0
#         print("\t", mutant_to, exp_value) 

--2022-01-18 10:07:39--  https://sid.erda.dk/share_redirect/a5PTfl88w0/BLAT_ECOLX_Ranganathan2015.csv
Resolving sid.erda.dk (sid.erda.dk)... 130.225.104.13
Connecting to sid.erda.dk (sid.erda.dk)|130.225.104.13|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1216640 (1.2M) [text/csv]
Saving to: ‘BLAT_ECOLX_Ranganathan2015.csv.1’


2022-01-18 10:07:42 (737 KB/s) - ‘BLAT_ECOLX_Ranganathan2015.csv.1’ saved [1216640/1216640]



In [10]:
# Final evaluation: score every single mutant of the first sequence by the
# change in the bound relative to the unmutated sequence, then rank-correlate
# the scores with the measurements. The positional 100 is mc_samples_l.
raw_sequence = dataset[0][0][np.newaxis, :].to(DEVICE)
experiment_value = []
predicted_value = []
with torch.no_grad():
    log_x_wt_ELBO, _, _ = deepseq(raw_sequence, 100, calculate_global_KL=False)
    for (position, mutant_from), row in tqdm(experimental_data.iterrows()):
        assert aa1_to_index[mutant_from] == raw_sequence[0, position]
        # Series.items() replaces iteritems(), which pandas 2.0 removed
        for mutant_to, exp_value in row.items():
            if mutant_to != mutant_from:
                new_sequence = raw_sequence.clone()
                new_sequence[0, position] = aa1_to_index[mutant_to]
                experiment_value.append(exp_value)
                log_x_mt_ELBO, _, _ = deepseq(new_sequence, 100, calculate_global_KL=False)
                predicted_value.append((log_x_mt_ELBO - log_x_wt_ELBO).item())
print(spearmanr(experiment_value, predicted_value))

263it [4:32:19, 62.13s/it]

SpearmanrResult(correlation=0.7243124956759219, pvalue=0.0)



