In [1]:
import numpy as np
from pathlib import Path
from esm import pretrained, FastaBatchedDataset
from torch.utils.data import  DataLoader

## Create fasta batched dataset from entire train fasta file

In [43]:
# Cell settings: per-batch token budget and the ESM checkpoint whose alphabet is needed.
tokens_per_batch = 4096  # the length limit and the token budget currently coincide
esm_model_name = 'esm2_t33_650M_UR50D'

# Every training sequence, read from the combined training FASTA.
data = FastaBatchedDataset.from_file(
    Path('../dev/data/full_data_curation/PSALM_1b_Tr_ALL.fasta')
)

# Only the alphabet is kept; the model returned alongside it is discarded.
_, alphabet = pretrained.load_model_and_alphabet(esm_model_name)

# Batches of dataset indices under the token budget (grouped by sequence length).
batches = data.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

## Find the seq lengths in each batch

In [42]:
# Sequence lengths per batch, with the same nesting as `batches`
# (num_batches x num_seqs_per_batch): total[b][k] is the length of the k-th
# sequence in batch b. Element 1 of a dataset item is measured here
# (presumably the sequence string, element 0 being the label).
#
# The former `ctr += 1` was removed: `ctr` was never initialised, so the cell
# raised NameError on a fresh kernel, and the counter was never read anyway.
total = [[len(data[j][1]) for j in batch] for batch in batches]

## Difference between longest and shortest sequence in each batch

In [32]:
# Spread of sequence lengths inside each batch: longest minus shortest.
diffs = [np.max(lengths) - np.min(lengths) for lengths in total]

## Flatten lengths and indices to find the optimal splitting strategy across GPUs

In [39]:
# Flattened, position-aligned views over all batches: entry k of flat_lengths
# is the length recorded for the dataset index stored at entry k of flat_indices.
flat_lengths = [length for b in range(len(batches)) for length in total[b]]
flat_indices = [index for b in range(len(batches)) for index in batches[b]]

In [46]:
# Total number of sequences summed over every batch.
len(flat_lengths)

1099995

In [47]:
# Number of batches (one list of lengths was stored per batch).
len(total)

133292

## Find frequency of different sized batches

In [48]:
from collections import defaultdict

# Histogram of batch sizes: freqs[k] counts the batches holding exactly k sequences.
freqs = defaultdict(int)

for batch_lengths in total:
    batch_size = len(batch_lengths)
    freqs[batch_size] += 1

## Test Merging hmmscan dict

In [3]:
import hmmscan_utils as hu
from timeit import default_timer as timer

# Merge the parsed hmmscan results of every shard into one dictionary and time it.
# NOTE(review): range(1, 100) visits shard_1 .. shard_99 only — confirm that no
# shard_0 / shard_100 file is expected.
full_dict = {}
start = timer()
for i in range(1, 100):
    hmm_dict = hu.parse_hmmscan_results(f'../dev/data/tr_scan_PSALM_1b/shard_{i}_tr.txt', score_threshold=0)
    # In-place merge. Same result as `full_dict | hmm_dict` (a later shard wins
    # on a duplicate key) but without copying the whole accumulated dictionary
    # on every iteration.
    full_dict.update(hmm_dict)
end = timer()
print(f'Total time is {end-start:.3f}')

Total time is 1448.970


In [5]:
import pickle

# Persist the merged hmmscan dictionary to disk.
# (Plain string literal: the path has no placeholders, so no f-prefix is needed.)
with open('../dev/data/full_data_curation/all_scans2.pkl', 'wb') as f:
    pickle.dump(full_dict, f)

In [6]:
import pickle

# Reload the pickled dictionary and report how many keys it holds.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only load
# pickles that were produced by this project.
with open('../dev/data/full_data_curation/all_scans2.pkl', 'rb') as f:
    all_scans = pickle.load(f)

# len() of a dict is its key count; no need to materialise the key list first.
print(len(all_scans))

1088060


## FastaBatch testing

In [2]:
# Reload the combined training FASTA for the batching experiments in this section.
data = FastaBatchedDataset.from_file(Path('../dev/data/full_data_curation/PSALM_1b_Tr_ALL.fasta')) # all train seqs

In [3]:
def filter_batches(data, keys):
    """Drop every sequence whose ID is not contained in ``keys``.

    The ID of a sequence is the first whitespace-delimited token of its label.
    ``data.sequence_strs`` and ``data.sequence_labels`` are replaced by
    filtered lists, so ``data`` itself is modified in place.

    Args:
        data: object exposing parallel ``sequence_labels`` and
            ``sequence_strs`` lists.
        keys: container of IDs to keep; only membership (``in``) is used, so a
            set or a dict keys view gives constant-time lookups.

    Returns:
        The same ``data`` object, after filtering.
    """
    # Indices to remove are held in a set: the comprehensions below test every
    # index for membership, which was quadratic when this was a list.
    bad_idxs = {
        idx
        for idx, seq_name in enumerate(data.sequence_labels)
        if seq_name.split()[0] not in keys
    }

    data.sequence_strs = [x for i, x in enumerate(data.sequence_strs) if i not in bad_idxs]
    data.sequence_labels = [x for i, x in enumerate(data.sequence_labels) if i not in bad_idxs]

    return data

In [4]:
import pickle

# Load the pickled hmmscan dictionary; its keys are used to filter the dataset.
# NOTE(review): this reads all_scans.pkl, while the merge section above wrote
# and re-read all_scans2.pkl — confirm which file is intended.
# NOTE: pickle.load can execute arbitrary code embedded in the file — only load
# pickles that were produced by this project.
with open('../dev/data/full_data_curation/all_scans.pkl', 'rb') as f:
    hmm_dict = pickle.load(f)

In [5]:
# Cell settings: per-batch token budget and the (small) ESM checkpoint to load.
tokens_per_batch = 4096  # the length limit and the token budget currently coincide
esm_model_name = 'esm2_t6_8M_UR50D'

# Keep only the sequences whose ID is a key of the hmmscan dictionary.
data = filter_batches(data, hmm_dict.keys())

# Only the alphabet is kept; the model returned alongside it is discarded.
_, alphabet = pretrained.load_model_and_alphabet(esm_model_name)

In [6]:
from ml_utils import DistributedBatchSampler

def get_full_dataloader(data, rank, num_gpus, tokens_per_batch=4096, max_seq_length=4096):
    """Build the DataLoader serving one rank's share of the token-budgeted batches.

    Relies on the notebook-level ``alphabet`` (loaded in an earlier cell) to
    create the collate function.

    Args:
        data: dataset exposing ``sequence_strs`` and ``get_batch_indices``
            (a ``FastaBatchedDataset`` in this notebook).
        rank: forwarded to ``DistributedBatchSampler`` (presumably the index
            of the current process/GPU — confirm against ``ml_utils``).
        num_gpus: forwarded to ``DistributedBatchSampler`` (presumably the
            total number of ranks — confirm against ``ml_utils``).
        tokens_per_batch: token budget handed to ``get_batch_indices``.
            Defaults to 4096, the value that used to be hard-coded.
        max_seq_length: upper bound on the truncation length handed to the
            batch converter. Defaults to 4096, the value that used to be
            hard-coded.

    Returns:
        A ``DataLoader`` over ``data`` driven by the distributed batch sampler.
    """
    # Truncate at the longest sequence present, but never beyond the limit.
    # default=0 keeps max() from raising on an empty dataset.
    longest = max((len(seq) for seq in data.sequence_strs), default=0)
    truncation_seq_length = min(longest, max_seq_length)

    batches = data.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    distributed_batch_sampler = DistributedBatchSampler(batches, rank, num_gpus)

    data_loader = DataLoader(
        data,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=distributed_batch_sampler,
        pin_memory=False,
        # batch_sampler is mutually exclusive with shuffling, so this must stay
        # False; ordering is decided entirely by the sampler.
        shuffle=False
    )

    return data_loader

In [7]:
# One loader per rank of a simulated two-GPU run (rank 0 first, then rank 1).
dl0, dl1 = (get_full_dataloader(data, rank, 2) for rank in range(2))

In [8]:
# Length of the rank-0 loader (presumably its number of batches, as reported by the batch sampler).
len(dl0)

65474

In [9]:
# Length of the rank-1 loader, for comparison with rank 0.
len(dl1)

65474

In [10]:
# Exhaust the rank-1 loader so that `stuff` is left bound to its final batch,
# then display that batch. (The enumerate index was never used.)
for stuff in dl1:
    pass
stuff

(['G3HAC6.1 G3HAC6_CRIGR Titin {ECO:0000313|EMBL:EGW08606.1}'],
 ['MSCKESPHVGVSTTTVSSVAVQAGSSKIVIAIVKCGKWVQLQLAESQPNLLEIGSSQDETKKLLRDHELLLAKLKWQKQSCANRFKSQRLLNSHTVCANFFPMRKKASRKEERELNPGEDLPVAAIDYVPPELLERSLVLLNKSQQLTDFIEKFKCDGSNVNSELIQGAQSSCLKIDSLLELLQDRRRQLDKYLQQQRQELAQVLQLCLWDQQENQVTCWFQKTIRDLQEESLGSSLSDNKELIHKHEDLVVRAKEWNSTIEKLKSQALKILLSKDFTEKEHLQLSNQKLNRLQDEFGRLMVERKTWLTMANDFFNSANKALRELVKVEDELVEERRKTEIGRRKAFDVLGKVEAYLKLLKSEGLSLPVLAARHEELHREIKDATADALQKGGSLISQVDSCSSQVTGIHEMMECIQKRVDHLSEQCTAHKEFALKKQQLAASMEGSLRKVEMSIQEIRPVLSNTLDVGSSPSESEKILNKYLELDIQAKETAHILEAAARIMTEKNELELREVALLSFKAKWLEEELNILGRSISYRSQVLQTYVAFQKSSEEVEEQLQRLKAFYLTEIPQKDEDEAEVKYWSNSAERRWQLFLKKSFLTQDLGLECLNLINMAKENEILNVKNEMRIMKNIMEKQTTGREELSHLRMAWYLKATEGKPGRQQWEAFKEKLKKTTHNVKLLHEVLMPISALDLGGNLQTMSDLRRRWNAMKPQLQQLHDEVQHIMKEWEVLGGQGAPLKEKSEQLKDLIHLHERQRERIQDYEKLLYKTVQFHQVKEENCSNISAENLQQQLEVLELESRNWSANAKECERVLSCSLEYCSTRDEINELKESFKDIKKKFNNLKFNYSKKNEKSRNLKTLQYQIQQVDMYAEKIQALRKKMEKVNNKTSDSFLSYPSNKVNILLEAMEDLQKHVNDFEKVVIDYKMNLDLTE

In [21]:
# Length of the first raw sequence string in the final batch (element 1 of the batch holds the strings).
len(stuff[1][0])

36026

In [15]:
# Shape of the third batch element (the tensor produced by the batch converter).
# The recorded 4098 columns are the 4096 limit plus two extra positions
# (presumably BOS/EOS tokens — confirm).
stuff[2].shape

torch.Size([1, 4098])

## Checking train and validation fasta for duplicate labels

In [2]:
# Load the training-split FASTA (no error is recorded for this cell).
data = FastaBatchedDataset.from_file(Path('../dev/data/full_data_curation/PSALM_1b_train.fasta'))

In [3]:
# Load the validation-split FASTA. In the recorded run this raised
# "AssertionError: Found duplicate sequence labels", which the cells below investigate.
data = FastaBatchedDataset.from_file(Path('../dev/data/full_data_curation/PSALM_1b_validation.fasta'))

AssertionError: Found duplicate sequence labels

In [5]:
from Bio import SeqIO

# Record id of every entry in the validation FASTA, in file order.
label_list = [
    record.id
    for record in SeqIO.parse('../dev/data/full_data_curation/PSALM_1b_validation.fasta', 'fasta')
]

In [6]:
# Number of records read from the validation FASTA.
len(label_list)

4277

In [7]:
# Convert to a NumPy array so np.unique and boolean masking can be used below.
# Note: this rebinds `label_list` from a Python list to an ndarray.
label_list = np.array(label_list)

In [8]:
# Sanity check: one-dimensional array with one entry per record.
label_list.shape

(4277,)

In [10]:
# Count how often each distinct label occurs ...
unique_list, counts = np.unique(label_list, return_counts=True)
# ... and display the labels that occur more than once.
unique_list[counts>1]

array(['A0A0Q0VNJ1.1/15-284'], dtype='<U29')

`A0A0Q0VNJ1.1/15-284` is the duplicated sequence label; it has since been removed from the validation set.