In [18]:
%pip install biopython


Note: you may need to restart the kernel to use updated packages.


In [20]:
import os 
import numpy as np
import pandas as pd
import scipy
import sklearn.metrics as skmetrics

# plotting
import matplotlib.pyplot as plt
import seaborn as sns

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import lightning as L

import torchmetrics
from torchmetrics.regression import PearsonCorrCoef

In [24]:
# The 20 standard amino acids (one-letter codes); the position in this string
# is the column index used by the one-hot encoding.
aa_alphabet = 'ACDEFGHIKLMNPQRSTVWY'
# Lookup table: one-letter amino acid code -> column index.
aa_to_int = dict(zip(aa_alphabet, range(len(aa_alphabet))))

def one_hot_encode(sequence):
    """One-hot encode an amino acid sequence.

    input:
        sequence: string of one-letter amino acid codes. The special string
            "0" is treated as a padding marker and yields a single all-zero row.
    output:
        float tensor of shape (len(sequence), len(aa_alphabet)) with a single 1
        per row in the column of the corresponding amino acid.

    A character that is not in aa_alphabet raises KeyError.
    """
    encoding = torch.zeros(len(sequence), len(aa_alphabet))
    # padding marker: leave the (1, 20) matrix all zero
    if sequence == "0":
        return encoding
    for row, residue_code in enumerate(sequence):
        # switch on the column that belongs to this amino acid
        encoding[row, aa_to_int[residue_code]] = 1
    return encoding
# sequence data, comes already batched, so treat accordingly in dataloader (batch_size=1)
class SequenceData(Dataset):
    """Dataset with one item per wild type; each item bundles all of its point mutations.

    Every item is a dict of three tensors with a leading dimension of size 1:
        "sequence": one-hot encoded wild-type sequence, zero-padded to at least
                    `min_size` rows,
        "mask":     1 at every (position, amino acid) for which a label exists,
        "labels":   the label value at those entries, 0 elsewhere.
    """

    def __init__(self, csv_file, label_col="ddG_ML", min_size=72):
        """
        Initializes the dataset.
        input:
            csv_file: path to the relevant data file, eg. "/home/data/mega_train.csv".
                The columns mut_type, WT_name, wt_seq and `label_col` are read.
            label_col: name of the column holding the regression target.
            min_size: sequences shorter than this are padded with all-zero rows
                up to this length (longer sequences are left unchanged).
        """
        self.min_size = min_size
        self.label_col = label_col

        df = pd.read_csv(csv_file, sep=",")
        # only keep mutation rows; .copy() so the column assignments below act on
        # an independent frame instead of a slice of the frame that was read
        df = df[df.mut_type != "wt"].copy()
        # mut_type is parsed as <from><1-based position><to>
        df["mutation_pos"] = df["mut_type"].apply(lambda x: int(x[1:-1]) - 1)  # make position start at zero
        df["mutation_to"] = df["mut_type"].apply(lambda x: aa_to_int[x[-1]])  # give numerical label to mutation

        # group by wild type: every column becomes a list with one entry per mutation
        self.df = df.groupby("WT_name").agg(list)
        # get wild type names
        self.wt_names = self.df.index.values

        # precompute one-hot encoding for faster training
        self.encoded_seqs = {}
        for wt_name in self.wt_names:
            # all rows of one wild type carry the same wt_seq, so take the first
            seq = self.df.loc[wt_name]["wt_seq"][0]
            encoded = one_hot_encode(seq)
            n_pad = self.min_size - len(encoded)
            if n_pad > 0:
                # append all padding rows in one concatenation
                padding = torch.zeros(n_pad, encoded.shape[1])
                encoded = torch.cat((encoded, padding), 0)
            self.encoded_seqs[wt_name] = encoded

    def __len__(self):
        """Number of wild types (not the number of individual mutations)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the encoded sequence, measurement mask and label tensor of wild type `idx`."""
        # get the wild type name
        wt_name = self.wt_names[idx]
        # get the correct row
        mut_row = self.df.loc[wt_name]
        # get the wt sequence in one hot encoding
        sequence_encoding = self.encoded_seqs[wt_name]
        seq_len, n_aa = sequence_encoding.shape

        # create mask and target tensors
        mask = torch.zeros((1, seq_len, n_aa))  # will be 1 where we have a measurement
        target = torch.zeros((1, seq_len, n_aa))  # label values
        # all mutations from df
        positions = torch.tensor(mut_row["mutation_pos"])
        amino_acids = torch.tensor(mut_row["mutation_to"])
        # get the labels
        labels = torch.tensor(mut_row[self.label_col])

        # Mutations whose position lies outside the (padded) sequence are ignored.
        in_range = (positions >= 0) & (positions < seq_len)
        positions = positions[in_range]
        amino_acids = amino_acids[in_range]
        labels = labels[in_range]

        # fill all measured entries at once instead of scanning the mutation
        # list once per sequence position
        mask[0, positions, amino_acids] = 1  # one where we have data
        # cast so the indexed assignment does not depend on the label column dtype
        target[0, positions, amino_acids] = labels.to(target.dtype)

        # returns encoded sequence, mask and target sequence
        return {"sequence": sequence_encoding[None, :, :].float(), "mask": mask, "labels": target}

In [25]:
# Build one dataset per split.
dataset_train = SequenceData(csv_file='project_data/project_data/mega_train.csv')
dataset_val = SequenceData(csv_file='project_data/project_data/mega_val.csv')
dataset_test = SequenceData(csv_file='project_data/project_data/mega_test.csv')

# Each dataset item already holds all mutations of one wild type, so one item
# is treated as one batch (batch_size=1). Only the training split is shuffled.
dataloader_train = DataLoader(dataset=dataset_train, batch_size=1, shuffle=True)
dataloader_val = DataLoader(dataset=dataset_val, batch_size=1, shuffle=False)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=1, shuffle=False)

In [44]:


# Build neighbour features for the first training wild type.
# The first four characters of its WT_name are passed to downloadStructure as the PDB id.
# NOTE: downloadStructure, removeHeteroatoms and getNClosestAAs are defined in
# cells further down, so those cells must have been run before this one.
structure = downloadStructure(dataset_train.df.index[0][:4])
# first chain of the first model, with hetero residues detached (also prints its sequence)
chain = removeHeteroatoms(structure)

# number of nearest residues kept per residue
nClosestAAs = 10

# one row per residue, one entry per neighbour; each entry holds the
# len(aa_alphabet)-wide one-hot code of the neighbour followed by its distance
aaDistances = np.zeros((len(chain), nClosestAAs, len(aa_alphabet)+1))
# Iterate over each residue in the chain
for index, residue in enumerate(chain):
    # NOTE: the current residue is not passed in; getNClosestAAs picks up the
    # loop variable `residue` from the notebook's global namespace, so this
    # variable must keep exactly this name.
    distances = getNClosestAAs(chain, nClosestAAs = nClosestAAs)
    # np.append flattens the (1, 20) one-hot row and appends the distance -> 21 values
    aaDistances[index] = [np.append(one_hot_encode(aa[0]),aa[1]) for aa in distances]
    
# `distances` still holds the neighbour list of the last residue of the loop
print(distances)
print(aaDistances.shape)

Structure exists: '/home/course/ProteinStabilityNN/a3/1a32.cif' 
Structure exists: '/home/course/ProteinStabilityNN/a3/1a32.cif' 
LTQERKREIIEQFKVHENDTGSPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRRLLAYLRNKDVARYREIVEKLGL
[('L', 0.0), ('G', 3.7983356), ('V', 5.2198067), ('L', 5.580389), ('E', 6.4865713), ('K', 7.9120116), ('I', 8.589744), ('T', 8.869416), ('R', 9.62728), ('A', 10.435477)]
(85, 10, 21)




In [37]:
from Bio.PDB import PDBList, PDBParser, MMCIFParser

def downloadStructure(name):
    """Fetch a structure from the PDB and parse it.

    input:
        name: PDB id of the structure, eg. "1A32".
    output:
        the structure object returned by MMCIFParser.get_structure.

    The file is requested exactly once. The mmCIF format is requested
    explicitly because the result is parsed with MMCIFParser.
    """
    pdb_id = name
    # retrieve_pdb_file returns the local path of the (possibly cached) file
    pdb_file = PDBList().retrieve_pdb_file(pdb_id, file_format="mmCif")

    return MMCIFParser().get_structure(pdb_id, pdb_file)


In [40]:
from Bio.PDB import NeighborSearch
from Bio.Data.IUPACData import protein_letters_3to1

def removeHeteroatoms(structure):
    """Return the first chain of the first model with all hetero residues removed.

    input:
        structure: parsed structure; structure[0] must provide get_chains().
    output:
        the first chain of structure[0]. Residues whose id[0] is not ' ' are
        detached from it in place, so `structure` itself is modified.

    The one-letter sequence of the remaining residues is printed as a side effect.
    A residue name missing from protein_letters_3to1 raises KeyError.
    """
    first_model = structure[0]
    chain = next(iter(first_model.get_chains()))

    # Collect the ids first and detach afterwards, so the chain is not
    # modified while it is being iterated.
    hetero_ids = [res.id for res in chain if res.id[0] != ' ']
    for hetero_id in hetero_ids:
        chain.detach_child(hetero_id)

    # protein_letters_3to1 is keyed by capitalised names ("Ala"), residue names
    # come upper-case ("ALA"): lower-case everything after the first letter.
    one_letter_codes = []
    for res in chain:
        three_letter = res.get_resname()
        one_letter_codes.append(protein_letters_3to1[three_letter[0] + three_letter[1:].lower()])

    print(''.join(one_letter_codes))

    return chain


'1A32.pdb'

In [27]:
# Wild-type sequence of the first wild type in the training set.
# The frame is indexed by WT_name (strings), so the first row is selected
# positionally with .iloc; a plain [0] on a label-indexed Series relies on a
# deprecated positional fallback. The trailing [0] takes the first entry of
# the per-mutation list of sequences.
dataset_train.df["wt_seq"].iloc[0][0]

'SPEVQIAILTEQINNLNEHLRVHKKDHHSRRGLLKMVGKRRRLLAYLRNKDVARYREIVEKLG'

In [42]:
import numpy as np
from Bio.Data.IUPACData import protein_letters_3to1

def getNClosestAAs(chain, nClosestAAs = 10, maxRadius = 20.0, residue = None):
    """List the residues closest to `residue`, nearest first.

    input:
        chain: chain whose atoms are searched (must provide get_atoms()).
        nClosestAAs: exact length of the returned list.
        maxRadius: search radius around the CA atom of `residue`.
        residue: the residue whose neighbourhood is wanted. If omitted, the
            global variable `residue` is used, which is how existing calls
            without this argument work; passing it explicitly is preferred.
    output:
        list of exactly nClosestAAs tuples (one-letter code, CA-CA distance),
        sorted by increasing distance. If fewer residues are found within
        maxRadius, the list is filled up with ("0", 0) entries.

    Raises NameError if `residue` is neither passed nor defined globally.
    Raises KeyError if a residue has no "CA" atom or its name is missing
    from protein_letters_3to1.
    """
    if residue is None:
        # Backward compatibility: fall back to the global loop variable that
        # callers without the argument relied on.
        try:
            residue = globals()["residue"]
        except KeyError:
            raise NameError("name 'residue' is not defined") from None

    center_atom = residue["CA"]

    # Create a NeighborSearch object over all atoms of the chain
    ns = NeighborSearch(list(chain.get_atoms()))

    # All residues with at least one atom within maxRadius of the centre CA
    neighbors = ns.search(center_atom.coord, level="R", radius=maxRadius)
    # protein_letters_3to1 is keyed by capitalised names ("Ala"), residue names
    # come upper-case ("ALA"); subtracting two atoms gives their distance.
    distances = [
        (protein_letters_3to1[aa.get_resname()[0] + aa.get_resname()[1:].lower()], center_atom - aa["CA"])
        for aa in neighbors
    ]
    # sort by distance
    distances.sort(key=lambda x: x[1])
    # Ensure the result is always nClosestAAs long: pad with ("0", 0) if too
    # short, truncate otherwise
    if len(distances) < nClosestAAs:
        distances.extend([("0", 0)] * (nClosestAAs - len(distances)))
    else:
        distances = distances[:nClosestAAs]

    return distances




In [93]:
# View the NumPy feature array as a torch tensor. torch.from_numpy does not
# copy: the tensor shares memory with aaDistances and keeps its dtype
# (float64, the default of the np.zeros call that created the array).
torch.from_numpy(aaDistances)

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  3.8177],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  6.2736],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 10.3375],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  3.8177],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  3.8201],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  9.7444],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000, 10.5362],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0