## Load model weight

In [4]:
import torch
import os
from evodiff.pretrained import OA_DM_640M

checkpoint = OA_DM_640M()
model, collater, tokenizer, scheme = checkpoint

model.load_state_dict(torch.load('models/epoch1_lr0.0001_accum8_warmup20_weight-decay1e-05/model_epoch0_accu36.5882.pth'))


from evodiff.generate import generate_oaardm

seq_len = 100
tokeinzed_sample, generated_sequence = generate_oaardm(model, tokenizer, seq_len, batch_size=1, device='cpu')
print("Generated sequence:", generated_sequence)


100%|██████████| 100/100 [00:27<00:00,  3.60it/s]

Generated sequence: ['MTMAHFVEKYEAGTKPAICSNMFFVRLVDAGVPKDQIEYLGERLGIYPPLQVIQEAIQKEDTIEDAVGKTVSIEGGKLKGNQYKKTMESLVEIASYIKKS']





## Dataset

In [2]:
import os
import numpy as np
from Bio import SeqIO

root = '../../data'
data_dir = os.path.join(root,'ec_species')
a3m_dir = os.path.join(data_dir, 'scaffolding-msas')
files = [file for file in os.listdir(a3m_dir) if file.endswith('.a3m')]

# the length of sequenc in .a3m
detergent_depths = np.array([], dtype=int)
#TODO I'm not sure but it may be the maximum ammount of gap in .a3m
detergent_gap_depths = np.array([], dtype=int)
# the ammount of sequences in .a3m
detergent_lengths = np.array([], dtype=int)


for filename in files:

    input_file = os.path.join(a3m_dir, filename)
    records = list(SeqIO.parse(input_file, 'fasta-2line'))
    length = len(records[0].seq)
    depth = len(records)        
    gap_depth = max([len(record.seq) - record.seq.count('-') for record in records])


    detergent_depths = np.append(detergent_depths,depth)
    detergent_gap_depths = np.append(detergent_gap_depths,gap_depth)
    detergent_lengths = np.append(detergent_lengths,length)
    

np.savez(os.path.join(a3m_dir,'detergent_depths.npz'), array=detergent_depths)
np.savez(os.path.join(a3m_dir,'detergent_gap_depths.npz'), array=detergent_gap_depths)
np.savez(os.path.join(a3m_dir,'detergent_lengths.npz'), array=detergent_lengths)


# print(f'detergent_gap_depths: {np.load(os.path.join(a3m_dir,"detergent_gap_depths.npz"))["array"]}')
# print(f'detergent_depths: {np.load(os.path.join(data_dir,"detergent_depths.npz"))["array"]}')
# print(f'detergent_lengths: {np.load(os.path.join(data_dir,"detergent_lengths.npz"))["array"]}')

detergent_gap_depths: [ 382  382  382 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645
 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645  968  968  968
  968  968  968  968  968  968  968  968  968  968  968  968  968  968
  968  968  968  968  968  968  968  968  968  968  968  968  968  968
  968  968  968  968  968  968  968  968  968  968  968  847  847  847
  847 2015 2015 2015 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645
 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645
 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645
 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645
 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645
 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645
 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645
 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645
 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645 1645

In [3]:
from sequence_models.constants import PROTEIN_ALPHABET, PAD, GAP
from sequence_models.utils import parse_fasta
from torch.utils.data import Dataset
from evodiff.utils import Tokenizer
import pandas as pd
import numpy as np

from scipy.spatial.distance import hamming, cdist

In [4]:
class A3MDataset(Dataset):
    """Build dataset for A3M data: MSA Absorbing Diffusion model"""

    def __init__(self, selection_type, n_sequences, max_seq_len, data_dir=None, min_depth=None):
        """
        Args:
            selection_type: str,
                MSA selection strategy of random or MaxHamming
            n_sequences: int,
                number of sequences to subsample down to
            max_seq_len: int,
                maximum MSA sequence length
            data_dir: str,
                if you have a specified data directory
        """
        alphabet = PROTEIN_ALPHABET
        self.tokenizer = Tokenizer(alphabet)
        self.alpha = np.array(list(alphabet))
        self.gap_idx = self.tokenizer.alphabet.index(GAP)

        # Get npz_data dir
        if data_dir is not None:
            self.data_dir = data_dir
        else:
            raise FileNotFoundError(data_dir)
        
        [print("Excluding", x) for x in os.listdir(self.data_dir) if x.endswith('.npz')]
        all_files = [x for x in os.listdir(self.data_dir) if x.endswith('.a3m')]
        all_files = sorted(all_files)
        print("unfiltered length", len(all_files))

        print(len(np.load(os.path.join(data_dir,'detergent_lengths.npz'))['array']))
        print(len(np.load(os.path.join(data_dir,'detergent_depths.npz'))['array']))
        print(len(np.load(os.path.join(data_dir,'detergent_depths.npz'))['array']))

        ## Filter based on depth (keep > 64 seqs/MSA)
        if not os.path.exists(os.path.join(data_dir,'detergent_lengths.npz')):
            raise Exception(f"Missing detergent_lengths.npz in {data_dir}")
        if not os.path.exists(os.path.join(data_dir,'detergent_depths.npz')):
            #get_msa_depth_openfold(data_dir, sorted(all_files), 'openfold_depths.npz')
            raise Exception(f"Missing detergent_depths.npz in {data_dir}")
        if min_depth is not None: # reindex, filtering out MSAs < min_depth
            _depths = np.load(os.path.join(data_dir,'detergent_depths.npz'))['array']
            depths = pd.DataFrame(_depths, columns=['depth'])
            depths = depths[depths['depth'] >= min_depth]
            keep_idx = depths.index

            _lengths = np.load(os.path.join(data_dir,'detergent_lengths.npz'))['array']
            lengths = np.array(_lengths)[keep_idx]
            all_files = np.array(all_files)[keep_idx]
            print(f"filter MSA depth > {min_depth}", len(all_files))


        # Re-filter based on high gap-contining rows
        if not os.path.exists(os.path.join(data_dir,'detergent_gap_depths.npz')):
            #get_sliced_gap_depth_openfold(data_dir, all_files, 'openfold_gap_depths.npz', max_seq_len=max_seq_len)
            raise Exception(f"Missing detergent_gap_depths.npz in {data_dir}")
        _gap_depths = np.load(os.path.join(data_dir,'detergent_gap_depths.npz'))['array']
        gap_depths = pd.DataFrame(_gap_depths, columns=['gapdepth'])
        gap_depths = gap_depths[gap_depths['gapdepth'] >= min_depth]
        filter_gaps_idx = gap_depths.index
        lengths = np.array(lengths)[[idx for idx in filter_gaps_idx if idx < (len(lengths) - 1)]]
        all_files = np.array(all_files)[[idx for idx in filter_gaps_idx if idx < (len(all_files)-1)]]
        print(f"filter rows with GAPs > {min_depth}", len(all_files))


        self.filenames = all_files  # IDs of samples to include
        self.lengths = lengths # pass to batch sampler
        self.n_sequences = n_sequences
        self.max_seq_len = max_seq_len
        self.selection_type = selection_type

    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        filename = self.filenames[idx]

        def read_files(data_dir, filename):
            """
            inputs:
                data_dir : path to directory with data
                filename: MSA name

            outputs:
                path: path to .a3m file
            """
            if os.path.exists(os.path.join(data_dir, filename)):
                path = os.path.join(data_dir, filename)
            else:
                raise Exception("Missing filepaths")
            return path


        path = read_files(self.data_dir, filename)
        parsed_msa = parse_fasta(path)

        aligned_msa = [[char for char in seq if (char.isupper() or char == '-') and not char == '.'] for seq in parsed_msa]
        aligned_msa = [''.join(seq) for seq in aligned_msa]

        tokenized_msa = [self.tokenizer.tokenizeMSA(seq) for seq in aligned_msa]
        tokenized_msa = np.array([l.tolist() for l in tokenized_msa])
        msa_seq_len = len(tokenized_msa[0])

        if msa_seq_len > self.max_seq_len:
            slice_start = np.random.choice(msa_seq_len - self.max_seq_len + 1)
            seq_len = self.max_seq_len
        else:
            slice_start = 0
            seq_len = msa_seq_len

        # Slice to 512
        sliced_msa_seq = tokenized_msa[:, slice_start: slice_start + self.max_seq_len]
        anchor_seq = sliced_msa_seq[0]  # This is the query sequence in MSA

        # slice out all-gap rows
        sliced_msa = [seq for seq in sliced_msa_seq if (list(set(seq)) != [self.gap_idx])]
        msa_num_seqs = len(sliced_msa)




        if msa_num_seqs < self.n_sequences:
            print("before for len", len(sliced_msa_seq))
            print("msa_num_seqs < self.n_sequences should not be called")
            print("tokenized msa shape", tokenized_msa.shape)
            print("tokenized msa depth", len(tokenized_msa))
            print("sliced msa depth", msa_num_seqs)
            print("used to set slice")
            print("msa_seq_len", msa_seq_len)
            print("self max seq len", self.max_seq_len)
            print(slice_start)
            import pdb; pdb.set_trace()
            output = np.full(shape=(self.n_sequences, seq_len), fill_value=self.tokenizer.pad_id)
            output[:msa_num_seqs] = sliced_msa
            print("msa num_seqs < self.n_sequences, indicates dataset not filtered properly")
            # raise Exception("msa num_seqs < self.n_sequences, indicates dataset not filtered properly")
        elif msa_num_seqs > self.n_sequences:
            if self.selection_type == 'random':
                random_idx = np.random.choice(msa_num_seqs - 1, size=self.n_sequences - 1, replace=False) + 1
                anchor_seq = np.expand_dims(anchor_seq, axis=0)
                output = np.concatenate((anchor_seq, np.array(sliced_msa)[random_idx.astype(int)]), axis=0)
            elif self.selection_type == "MaxHamming":
                output = [list(anchor_seq)]
                msa_subset = sliced_msa[1:]
                msa_ind = np.arange(msa_num_seqs)[1:]
                random_ind = np.random.choice(msa_ind)
                random_seq = sliced_msa[random_ind]
                output.append(list(random_seq))
                random_seq = np.expand_dims(random_seq, axis=0)
                msa_subset = np.delete(msa_subset, (random_ind - 1), axis=0)
                m = len(msa_ind) - 1
                distance_matrix = np.ones((self.n_sequences - 2, m))

                for i in range(self.n_sequences - 2):
                    curr_dist = cdist(random_seq, msa_subset, metric='hamming')
                    curr_dist = np.expand_dims(np.array(curr_dist), axis=0)  # shape is now (1,msa_num_seqs)
                    distance_matrix[i] = curr_dist
                    col_min = np.min(distance_matrix, axis=0)  # (1,num_choices)
                    max_ind = np.argmax(col_min)
                    random_ind = max_ind
                    random_seq = msa_subset[random_ind]
                    output.append(list(random_seq))
                    random_seq = np.expand_dims(random_seq, axis=0)
                    msa_subset = np.delete(msa_subset, random_ind, axis=0)
                    distance_matrix = np.delete(distance_matrix, random_ind, axis=1)
        else:
            output = sliced_msa

        output = [''.join(seq) for seq in self.alpha[output]]
        return output

In [5]:
from evodiff.utils import Tokenizer
import os

root = '../../data'
data_dir = os.path.join(root,'ec_species/scaffolding-msas')
n_sequences = 64
min_depth = 64
selection_type = 'MaxHamming'
max_seq_len = 1024

dataset =A3MDataset(selection_type, n_sequences, max_seq_len, data_dir=data_dir, min_depth=min_depth)
train_size = len(dataset)

random_ind = np.random.choice(train_size, size=(train_size - 10000 if train_size>1000 else train_size), replace=False)
# print("TRAIN SIZE:", train_size, random_ind)

print(dataset.__len__())
# print(dataset.__getitem__(0))

Excluding detergent_depths.npz
Excluding detergent_gap_depths.npz
Excluding detergent_lengths.npz
unfiltered length 934
934
934
934
filter MSA depth > 64 930
filter rows with GAPs > 64 929
929


## Collater

In [6]:
from sequence_models.collaters import MSAAbsorbingCollater
from torch.utils.data import Subset
from sequence_models.constants import MSA_ALPHABET

# ds_train = Subset(dataset, random_ind)


collater = MSAAbsorbingCollater(alphabet=MSA_ALPHABET, num_seqs=64)

data = collater.__call__(dataset)


before for len 5
msa_num_seqs < self.n_sequences should not be called
tokenized msa shape (5, 383)
tokenized msa depth 5
sliced msa depth 5
used to set slice
msa_seq_len 383
self max seq len 1024
0
> [0;32m/tmp/ipykernel_85550/3851817186.py[0m(134)[0;36m__getitem__[0;34m()[0m
[0;32m    132 [0;31m            [0mprint[0m[0;34m([0m[0mslice_start[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    133 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 134 [0;31m            [0moutput[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0mfull[0m[0;34m([0m[0mshape[0m[0;34m=[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mn_sequences[0m[0;34m,[0m [0mseq_len[0m[0;34m)[0m[0;34m,[0m [0mfill_value[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mtokenizer[0m[0;34m.[0m[0mpad_id[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    135 [0;31m            [0moutput