In [1]:
import esm
import torch
from time import time

from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
import itertools
import os
import string
from pathlib import Path

import numpy as np
from Bio import SeqIO, Phylo
import pandas as pd
from scipy.spatial.distance import squareform, pdist, cdist
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from torch.utils.data import Dataset, DataLoader

# Gradient tracking is enabled explicitly; the training cell further down calls loss.backward().
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x1ea35724700>

In [2]:
# Characters that remove_insertions() strips from a sequence: every lowercase
# ASCII letter plus "." and "*". Mapping a character to None in a translation
# table deletes it.
deletekeys = {symbol: None for symbol in string.ascii_lowercase}
deletekeys.update({".": None, "*": None})
translation = str.maketrans(deletekeys)

def read_sequence(filename: str) -> Tuple[str, str]:
    """Return ``(description, sequence)`` for the first record of a FASTA/MSA file.

    Only the first record yielded by ``SeqIO.parse`` is consumed; the rest of
    the file is not read. The sequence is returned as a plain ``str``.
    """
    records = SeqIO.parse(filename, "fasta")
    first_record = next(records)
    return first_record.description, str(first_record.seq)

def remove_insertions(sequence: str) -> str:
    """Strip insertion characters from an aligned sequence.

    Every character listed in the module-level ``translation`` table is
    deleted; all other characters are kept in their original order.
    """
    stripped = sequence.translate(translation)
    return stripped

def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Load every record of a FASTA-formatted MSA as ``(description, sequence)``.

    Each sequence is passed through ``remove_insertions`` before being stored,
    and records are returned in file order.
    """
    entries = []
    for record in SeqIO.parse(filename, "fasta"):
        cleaned = remove_insertions(str(record.seq))
        entries.append((record.description, cleaned))
    return entries

def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    """Greedily pick a subset of an MSA, always starting from its first entry.

    At each step the not-yet-chosen sequence whose mean Hamming distance to the
    already chosen sequences is largest (``mode="max"``) or smallest
    (``mode="min"``) is added. The selected entries are returned in their
    original MSA order.

    Args:
        msa: ``(description, sequence)`` tuples; sequences must share one length.
        num_seqs: Number of entries to keep.
        mode: ``"max"`` or ``"min"``.

    Returns:
        ``msa`` itself when it already has at most ``num_seqs`` entries,
        otherwise a new list with the selected entries.
    """
    assert mode in ("max", "min")
    if len(msa) <= num_seqs:
        return msa

    # One row of byte codes per sequence, so distances are computed column-wise.
    encoded = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)

    pick = np.argmax if mode == "max" else np.argmin
    every_index = np.arange(len(msa))
    chosen = [0]
    distance_rows = []  # one (1, len(msa)) row per chosen sequence, in selection order
    while len(chosen) < num_seqs:
        newest = chosen[-1:]
        distance_rows.append(cdist(encoded[newest], encoded, "hamming"))
        # Mean distance from each remaining candidate to everything chosen so far.
        candidate_scores = np.delete(np.concatenate(distance_rows), chosen, axis=1).mean(0)
        best_candidate = pick(candidate_scores)
        # Map the position among the remaining candidates back to an MSA index.
        chosen.append(np.delete(every_index, chosen)[best_candidate])

    return [msa[idx] for idx in sorted(chosen)]

def Seq_tuples_to_fasta(sequences, file_path, export_type = "fasta"):
    """Write ``(description, sequence)`` tuples to ``file_path`` with ``SeqIO.write``.

    The first element of each tuple is used as the record's id, name and
    description; the second is its sequence. ``export_type`` is the format name
    handed to ``SeqIO.write``. The file is opened in text write mode, so any
    existing content is replaced.
    """
    MSA_SeqRecords = []
    for entry in sequences:
        label = entry[0]
        MSA_SeqRecords.append(
            SeqRecord(Seq(entry[1]), id=label, name=label, description=label)
        )

    with open(f"{file_path}", "w") as output_handle:
        SeqIO.write(MSA_SeqRecords, output_handle, export_type)

In [3]:
# Full alignment for family PF00004, loaded as (description, sequence) tuples
# with insertion characters already stripped by read_msa().
MSA_filename = "./data/protein-families-msa-full/PF00004_full.fasta"
all_seqs = read_msa(MSA_filename)


In [5]:
def create_training_set(seqs_per_MSA, n_sampled_MSAs, p_mask, mask_idx, seed,
                        sequences=None, converter=None):
    """Build a masked-token training set from randomly sampled sub-alignments.

    For each of ``n_sampled_MSAs`` samples, ``seqs_per_MSA`` sequences are drawn
    from ``sequences``, tokenized with ``converter``, and every token except
    the first column is independently replaced by ``mask_idx`` with
    probability ``p_mask``.

    Args:
        seqs_per_MSA: Number of sequences drawn for each sampled MSA.
        n_sampled_MSAs: Number of MSAs to sample.
        p_mask: Probability of masking each token outside the first column.
        mask_idx: Token index written at masked positions.
        seed: Integer seed (or ``None``). Seeds NumPy's global RNG, which
            drives the sequence sampling, and a private torch generator, which
            drives the masking, so both are reproducible for a given seed.
        sequences: ``(description, sequence)`` tuples to sample from. Defaults
            to the module-level ``all_seqs``.
        converter: Callable taking ``[msa]`` and returning a 3-tuple whose last
            item is a token tensor indexed as ``[batch, row, column]``.
            Defaults to the module-level ``batch_converter``.

    Returns:
        ``(masked_MSAs, true_MSAs)``: two tensors stacked along a new leading
        dimension of size ``n_sampled_MSAs`` and moved to CUDA when available,
        otherwise CPU.
    """
    if sequences is None:
        sequences = all_seqs
    if converter is None:
        converter = batch_converter

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NumPy's global RNG drives the sequence sampling below.
    np.random.seed(seed)
    # The mask is drawn with torch, which np.random.seed does not affect; a
    # private generator makes the masks reproducible without touching torch's
    # global RNG state.
    mask_rng = torch.Generator()
    if seed is None:
        mask_rng.seed()
    else:
        mask_rng.manual_seed(int(seed))

    masked_MSAs = []
    true_MSAs = []

    for _ in range(n_sampled_MSAs):
        # NOTE(review): np.random.choice samples with replacement by default, so
        # one sampled MSA can contain the same sequence twice — confirm intended.
        sampled_ids = np.random.choice(len(sequences), seqs_per_MSA)
        sampled_MSA = [sequences[idx] for idx in sampled_ids]

        _, _, batch_tokens = converter([sampled_MSA])

        # Drop the batch dimension and split off the first token column so it
        # is never masked.
        starting_tokens = batch_tokens[0, :, :1]
        residue_tokens = batch_tokens[0, :, 1:]

        # Boolean masking keeps the token dtype and works for any mask_idx
        # (uint8 arithmetic would overflow for indices above 255).
        hide = torch.rand(residue_tokens.shape, generator=mask_rng) <= p_mask
        masked_residues = residue_tokens.masked_fill(hide.to(residue_tokens.device), mask_idx)

        true_MSAs.append(torch.cat((starting_tokens, residue_tokens), dim=-1))
        masked_MSAs.append(torch.cat((starting_tokens, masked_residues), dim=-1))

    masked_MSAs = torch.stack(masked_MSAs, dim=0).to(device)
    true_MSAs = torch.stack(true_MSAs, dim=0).to(device)

    return masked_MSAs, true_MSAs

    

    

In [6]:
class MSADataset(Dataset):
    """Dataset that serves (masked, true) token pairs by index.

    Both inputs are indexed along their first dimension; the dataset length is
    taken from ``masked_tokens``.
    """

    def __init__(self, masked_tokens, true_tokens):
        """Keep references to the masked inputs and their unmasked targets."""
        self.masked_tokens, self.true_tokens = masked_tokens, true_tokens

    def __len__(self):
        """Number of items, i.e. the size of ``masked_tokens``' first dimension."""
        n_items = self.masked_tokens.shape[0]
        return n_items

    def __getitem__(self, idx):
        """Return the ``(masked, true)`` pair stored at ``idx``."""
        return self.masked_tokens[idx], self.true_tokens[idx]

In [7]:
import torch
import torch.nn as nn

# Device configuration (kept as a plain string; assigned once instead of twice).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyper-parameters
num_epochs = 5
batch_size = 1
learning_rate = 1e-4
weight_decay = 1e-4
# lr_scheduler: str = "warmup_linear"
# warmup_steps: int = 16000
adam_betas = (0.9, 0.999)
max_steps: int = 1000000

seqs_per_MSA = 100
n_sampled_MSAs = 1000
p_mask = 0.1
log_every = 1  # print the loss every `log_every` optimisation steps

torch.cuda.empty_cache()

model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
model = model.to(device)
batch_converter = alphabet.get_batch_converter()
mask_idx = alphabet.tok_to_idx["<mask>"]

# Freeze every parameter whose name does not start with "lm_head".
for name, param in model.named_parameters():
    if not name.startswith("lm_head"):
        param.requires_grad = False

# NOTE(review): presumably `lm_head.weight` is registered under another name
# (e.g. tied to the token embedding) and was frozen by the loop above, hence
# the explicit re-enable — confirm.
model.lm_head.weight.requires_grad = True

masked_MSAs, true_MSAs = create_training_set(seqs_per_MSA=seqs_per_MSA, n_sampled_MSAs= n_sampled_MSAs, p_mask=p_mask, mask_idx=mask_idx, seed=42)

MSAs_Dataset = MSADataset(masked_MSAs,true_MSAs)
train_dataloader = DataLoader(MSAs_Dataset, batch_size = batch_size, shuffle = True)

criterion = nn.CrossEntropyLoss()
# Only hand the optimizer the parameters that are actually being trained.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate, weight_decay=weight_decay, betas=adam_betas)

# Train the model
n_total_steps = len(train_dataloader)
for epoch in range(num_epochs):
    # Batch variables get their own names so the full-dataset tensors
    # `masked_MSAs` / `true_MSAs` are not rebound and deleted by the loop.
    for i, (masked_batch, true_batch) in enumerate(train_dataloader):

        masked_pos = masked_batch == mask_idx
        if not masked_pos.any():
            # Nothing was masked in this batch: the loss would be NaN, so skip it.
            continue

        logits = model(masked_batch)["logits"]

        # The loss is computed on masked positions only. The dataset tensors
        # were already moved to the device by create_training_set.
        loss = criterion(logits[masked_pos], true_batch[masked_pos])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % log_every == 0:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')

        # Release per-batch references before the next forward pass.
        del masked_batch, true_batch, logits

# with torch.no_grad():
#     n_correct = 0
#     n_samples = 0
#     for images, labels in test_loader:
#         images = images.reshape(-1, 28*28).to(device)
#         labels = labels.to(device)
#         outputs = model(images)
#         # max returns (value ,index)
#         _, predicted = torch.max(outputs.data, 1)
#         n_samples += labels.size(0)
#         n_correct += (predicted == labels).sum().item()

#     acc = 100.0 * n_correct / n_samples
#     print(f'Accuracy of the network on the 10000 test images: {acc} %')

Epoch [1/5], Step [1/1000], Loss: 1.0809
Epoch [1/5], Step [2/1000], Loss: 1.0954
Epoch [1/5], Step [3/1000], Loss: 1.0653
Epoch [1/5], Step [4/1000], Loss: 1.0506
Epoch [1/5], Step [5/1000], Loss: 1.0405
Epoch [1/5], Step [6/1000], Loss: 1.0230
Epoch [1/5], Step [7/1000], Loss: 0.9821
Epoch [1/5], Step [8/1000], Loss: 1.0436


KeyboardInterrupt: 