Load the data from the data folder and serve it in batches through a PyTorch DataLoader.

In [5]:
import os
import torch
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

from torch.utils.data.dataloader import DataLoader
from foldingdiff import modelling
from foldingdiff import datasets as dsets
import numpy as np

# Load the CATH angles dataset, zero-pad every per-protein angle matrix to a
# common number of rows, and iterate over the result in batches.
#
# Instance variables on the dataset object:
# 'trim_strategy', 'pad', 'min_length', 'pdbs_src', 'fnames', 'structures',
# 'cache_dir', 'rng', 'means', 'all_lengths', '_length_rng', 'feature_idx'

clean_dset = dsets.CathCanonicalAnglesOnlyDataset(pad=128, trim_strategy='randomcrop')


print("instance variables in clean_dset", clean_dset.__dict__.keys())
print("all lengths", clean_dset.all_lengths)

# Longest structure in the dataset; every matrix is padded up to this many
# rows so the DataLoader can stack them into one tensor per batch.
# `default=0` keeps the previous result (0) when there are no structures.
max_length = max(clean_dset.all_lengths, default=0)

dset_angles = []  # list of matrices, each matrix represents one protein
dset_fnames = []  # list of file names, aligned to dset_angles
for structure in clean_dset.structures:
    # np.nan_to_num returns a NEW array by default (copy=True); the result
    # must be assigned, otherwise the NaNs stay in the data that gets padded.
    structure_np = np.nan_to_num(structure["angles"].to_numpy())

    # Append rows at the bottom only; columns are left untouched.
    # NOTE(review): the original note expects shape (max_length, 9) after
    # padding -- the column count is not visible here, confirm.
    rows_to_add = max_length - structure_np.shape[0]
    structure_np_padded = np.pad(structure_np, pad_width=((0, rows_to_add), (0, 0)))

    dset_angles.append(structure_np_padded)
    dset_fnames.append(structure["fname"])

# shuffle=False keeps the batches in the same order as dset_fnames.
dl = DataLoader(dset_angles, batch_size=32, shuffle=False)
for batch_idx, data in enumerate(dl):
    print(data.size())


instance variables in clean_dset dict_keys(['trim_strategy', 'pad', 'min_length', 'pdbs_src', 'fnames', 'structures', 'cache_dir', 'rng', 'means', 'all_lengths', '_length_rng', 'feature_idx'])
all lengths [135, 192, 241, 157, 65, 230, 146, 197, 85, 63, 123, 139, 248, 173, 244, 101, 104, 157, 154, 130, 346, 79, 81, 307, 101, 85, 176, 123, 302, 383, 95, 201, 62, 116, 101, 113, 81, 121, 141, 126, 206, 73, 345, 226, 75, 154, 119, 99, 112, 150, 115, 97, 93, 210, 118, 78, 173, 69, 91, 52, 119, 171, 74, 86, 101, 138, 127, 113, 168, 141, 221, 307, 219, 155, 67, 125, 87, 216, 391, 144, 142, 216, 52, 88, 256, 102, 179, 205, 164, 488, 291, 103, 87, 74, 100, 85, 173, 135, 288, 126, 87, 196, 241, 126, 138, 165, 100, 520, 112, 175, 106, 139, 449, 106, 113, 147, 54, 128, 259, 98, 168, 41, 176, 80, 188, 65, 146, 205, 141, 48, 103, 193, 44, 152, 271, 208, 189, 201, 165, 221, 125, 108, 176, 141, 174, 145, 148, 129, 47, 121, 81, 193, 131, 123, 84, 105, 96, 97, 268, 142, 269, 94, 280, 480, 155, 80, 119, 1

In [4]:
# Inspect the first entry of the dataset's structure list to see which
# fields each structure carries.
structure_one = clean_dset.structures[0]

# The output below shows the fields: 'angles', 'coords', 'fname'.
print(structure_one.keys())

# Dump the raw coordinate array of this structure.
# NOTE(review): the output shows three values per row -- presumably one
# x/y/z position per residue; confirm against the dataset code.
print(structure_one['coords'])


dict_keys(['angles', 'coords', 'fname'])
[[ 7.5550e+00 -3.0420e+00  4.0292e+01]
 [ 9.3150e+00  2.5600e-01  3.9361e+01]
 [ 8.0100e+00  3.8300e+00  3.9430e+01]
 [ 9.0360e+00  4.8810e+00  3.5888e+01]
 [ 6.3950e+00  2.9440e+00  3.3935e+01]
 [ 3.7850e+00  3.5820e+00  3.6677e+01]
 [ 4.3480e+00  7.3710e+00  3.6455e+01]
 [ 3.6120e+00  7.1230e+00  3.2731e+01]
 [ 5.2900e-01  4.9350e+00  3.3477e+01]
 [-9.3300e-01  7.5650e+00  3.5809e+01]
 [-1.1600e-01  1.0432e+01  3.3452e+01]
 [-1.8110e+00  8.5870e+00  3.0643e+01]
 [-4.9200e+00  8.0360e+00  3.2766e+01]
 [-4.9510e+00  1.1670e+01  3.3968e+01]
 [-4.7530e+00  1.2968e+01  3.0407e+01]
 [-7.5150e+00  1.0671e+01  2.9127e+01]
 [-1.1072e+01  1.1900e+01  2.8571e+01]
 [-1.2132e+01  8.8870e+00  2.6488e+01]
 [-1.1349e+01  5.2860e+00  2.7377e+01]
 [-1.3988e+01 -3.6700e-01  2.7174e+01]
 [-1.3871e+01 -4.0250e+00  2.6446e+01]
 [-1.6795e+01 -6.1990e+00  2.7547e+01]
 [-1.7795e+01 -9.7980e+00  2.7007e+01]
 [-2.0331e+01 -9.7790e+00  2.4229e+01]
 [-2.2876e+01 -1.2040e+