In [1]:
import pandas as pd
import yaml
import numpy as np
import os
import time

with open('../../hyperparams.yml') as f:
    configs=yaml.safe_load(f)
    
with open('../../data/dataset_config.yaml') as f:
    dataset_configs=yaml.safe_load(f)

data_dir=configs['data_dir']
raw_files_dir='raw/clusters/'
max_length=dataset_configs['sequence_length']
aa_vocabulary=dataset_configs['aa_vocabulary']


destination_dir='dataset/unsupervised_large_clusters/'

In [2]:
import h5py
hdf5_file = h5py.File(data_dir+destination_dir+"dataset.hdf5", "w")
sequences_encoded = hdf5_file.create_dataset("sequences", (5000000, 512), chunks=(512, 512), dtype='i8')

# Soft Dataset

In [3]:
clusters_refs=pd.DataFrame()
sequences_ids=pd.DataFrame()

pointer=0
for filename in os.listdir(data_dir+raw_files_dir):
    print('\nProcessing', filename)
    
    data=pd.read_csv(data_dir+raw_files_dir+filename)
    
    #compute length and remove too long seqs
    data=data[data['sequence'].str.len()<(max_length-2)]
    
    data=data.groupby('cluster_ref').first().reset_index()

    batch_size=len(data)
    print(batch_size)
    
    clusters_refs=pd.concat([clusters_refs, data['cluster_ref']], ignore_index=True)
    sequences_ids=pd.concat([sequences_ids, data['entry_id']], ignore_index=True)
    
    
    batch_encoded=np.zeros((len(data), max_length), dtype=np.int8)
    
    seq_idx=0
    for row in data.itertuples():
        aa_seq=row.sequence
        
        seq_encoded=np.zeros((max_length), dtype=np.int8)
        
        seq_encoded[0]=aa_vocabulary['<BOS>']

        for aa_idx, aa in enumerate(aa_seq):
            if aa not in aa_vocabulary:
                aa_token=aa_vocabulary['X']
            else:
                aa_token=aa_vocabulary[aa]
                
            seq_encoded[aa_idx+1]=aa_token
        
        seq_encoded[aa_idx+2]=aa_vocabulary['<EOS>']
        
        batch_encoded[seq_idx]=seq_encoded
        seq_idx += 1
        
    sequences_encoded[pointer:pointer+batch_size]=batch_encoded

    pointer+=batch_size

print('Total number of records:', pointer)
print('Total number of cluster references:', len(clusters_refs))
print('Total number of sequence references:', len(sequences_ids))


Processing shard_92.csv
4362

Processing shard_71.csv
4588

Processing shard_15.csv
4289

Processing shard_59.csv
4589

Processing shard_30.csv
4634

Processing shard_75.csv
4610

Processing shard_54.csv
4504

Processing shard_67.csv
4519

Processing shard_79.csv
3913

Processing shard_99.csv
4606

Processing shard_51.csv
4434

Processing shard_38.csv
4820

Processing shard_53.csv
3809

Processing shard_98.csv
4805

Processing shard_42.csv
4833

Processing shard_60.csv
4469

Processing shard_64.csv
4660

Processing shard_76.csv
4636

Processing shard_39.csv
4580

Processing shard_41.csv
4544

Processing shard_65.csv
4361

Processing shard_19.csv
7775

Processing shard_34.csv
3014

Processing shard_17.csv
4526

Processing shard_106.csv
4667

Processing shard_102.csv
4392

Processing shard_112.csv
4648

Processing shard_8.csv
4672

Processing shard_24.csv
3911

Processing shard_40.csv
4664

Processing shard_47.csv
4286

Processing shard_27.csv
3970

Processing shard_83.csv
4241

Process

# Complete dataset

In [None]:
clusters_refs=pd.DataFrame()
sequences_ids=pd.DataFrame()

pointer=0
for filename in os.listdir(data_dir+raw_files_dir):
    print('\nProcessing', filename)
    
    data=pd.read_csv(data_dir+raw_files_dir+filename)
    
    #compute length and remove too long seqs
    data=data[data['sequence'].str.len()<(max_length-2)]
    
    batch_size=len(data)
    print(batch_size)
    
    clusters_refs=pd.concat([clusters_refs, data['cluster_ref']], ignore_index=True)
    sequences_ids=pd.concat([sequences_ids, data['entry_id']], ignore_index=True)
    
    
    batch_encoded=np.zeros((len(data), max_length), dtype=np.int8)
    
    seq_idx=0
    for row in data.itertuples():
        aa_seq=row.sequence
        
        seq_encoded=np.zeros((max_length), dtype=np.int8)
        
        seq_encoded[0]=aa_vocabulary['<BOS>']

        for aa_idx, aa in enumerate(aa_seq):
            if aa not in aa_vocabulary:
                aa_token=aa_vocabulary['X']
            else:
                aa_token=aa_vocabulary[aa]
                
            seq_encoded[aa_idx+1]=aa_token
        
        seq_encoded[aa_idx+2]=aa_vocabulary['<EOS>']
        
        batch_encoded[seq_idx]=seq_encoded
        seq_idx += 1
        
    sequences_encoded[pointer:pointer+batch_size]=batch_encoded

    pointer+=batch_size

print('Total number of records:', pointer)
print('Total number of cluster references:', len(clusters_refs))
print('Total number of sequence references:', len(sequences_ids))

# Finalize

In [4]:
sequences_encoded.resize((pointer, max_length))
print(sequences_encoded.shape)
hdf5_file.close()

(505603, 512)


In [5]:
clusters_refs.columns=['cluster_ref']
clusters_refs.to_csv(data_dir+destination_dir+'clusters_refs.csv', index=False)

In [6]:
sequences_ids.columns=['sequence_id']
sequences_ids.to_csv(data_dir+destination_dir+'sequences_ids.csv', index=False)

# Split in training and test

In [7]:
dataset=h5py.File(data_dir+destination_dir+"dataset.hdf5")['sequences']

In [28]:
idxs=np.arange(dataset.shape[0])

In [29]:
test_ratio=0.2
test_size=int(pointer*test_ratio)
print('Train dataset:', pointer-test_size, 'Test dataset:', test_size)

Train dataset: 404483 Test dataset: 101120


In [47]:
val_file=h5py.File(data_dir+destination_dir+"validation_dataset.hdf5", "w")
validation_dataset=val_file.create_dataset("sequences", (5000000, 512), chunks=(512,512), dtype='i8')

train_file=h5py.File(data_dir+destination_dir+"train_dataset.hdf5", "w")
train_dataset=train_file.create_dataset("sequences", (5000000, 512), chunks=(512,512), dtype='i8')

In [48]:
train_pointer=0
test_pointer=0
for b_start in range(0, len(idxs), 4096):
    b_end=b_start+4096
    
    #dataset batch
    dataset_batch=dataset[b_start:b_end]
    
    batch_idxs=idxs[b_start:b_end]
    np.random.shuffle(batch_idxs)
    batch_idxs -= np.min(batch_idxs)
    test_idxs=np.sort(batch_idxs[:int(4096*test_ratio)]).tolist()
    train_idxs=np.sort(batch_idxs[int(4096*test_ratio):]).tolist()
    
    shifter_test=len(test_idxs)
    shifter_train=len(train_idxs)
    
    validation_dataset[test_pointer:test_pointer+shifter_test]=dataset_batch[test_idxs]
    train_dataset[train_pointer:train_pointer+shifter_train]=dataset_batch[train_idxs]
    
    test_pointer+=shifter_test
    train_pointer+=shifter_train
    
print(test_pointer)
print(train_pointer)

101556
404047


In [51]:
validation_dataset.resize((test_pointer, max_length))
train_dataset.resize((train_pointer, max_length))

In [52]:
print(validation_dataset.shape)
print(train_dataset.shape)

(101556, 512)
(404047, 512)


In [53]:
val_file.close()
train_file.close()