In [1]:
import torch
import sidechainnet as scn

## Load CASP7 data as pytorch tensors

In [2]:
data = scn.load(casp_version=7, with_pytorch="dataloaders", 
                seq_as_onehot=True, aggregate_model_input=False,
               batch_size=16)

SidechainNet(7, 30) was not found in ./sidechainnet_data.
Downloading from https://pitt.box.com/shared/static/hjblmbwei2dkwhfjatttdmamznt1k9ef.pkl


Downloading file chunks (estimated): 19885chunk [01:04, 306.15chunk/s]                        


Downloaded SidechainNet to ./sidechainnet_data\sidechainnet_casp7_30.pkl.
SidechainNet was loaded from ./sidechainnet_data\sidechainnet_casp7_30.pkl.


## Creates features for a batch of sequences

In [3]:
def get_seq_features(batch):
    '''
    Take a batch of sequence info and return the sequence (one-hot),
    evolutionary info and (phi, psi, omega) angles per position, 
    as well as position mask.
    Also return the distance matrix, and distance mask.
    '''
    str_seqs = batch.str_seqs # seq in str format
    seqs = batch.seqs # seq in one-hot format
    int_seqs = batch.int_seqs # seq in int format
    masks = batch.msks # which positions are valid
    lengths = batch.lengths # seq length
    evos = batch.evos # PSSM / evolutionary info
    angs = batch.angs[:,:,0:2] # torsion angles: phi, psi
    
    # use coords to create distance matrix from c-beta
    # except use c-alpha for G
    # coords[:, 4, :] is c-beta, and coords[:, 1, :] is c-alpha
    coords = batch.crds # seq coord info (all-atom)
    batch_xyz = []
    for i in range(coords.shape[0]):
        xyz = []
        xyz = [coords[i][cpos+4,:] 
                if masks[i][cpos//14] and str_seqs[i][cpos//14] != 'G'
                else coords[i][cpos+1,:]
                for cpos in range(0, coords[i].shape[0]-1, 14)]
        batch_xyz.append(torch.stack(xyz))
    batch_xyz = torch.stack(batch_xyz)
    # now create pairwise distance matrix
    dmats = torch.cdist(batch_xyz, batch_xyz)
    # create matrix mask (0 means i,j invalid)
    dmat_masks = torch.einsum('bi,bj->bij', masks, masks)
    
    return seqs, evos, angs, masks, dmats, dmat_masks

## Now iterate through the train data

Do the same for data['test'] and data['valid-10']

In [4]:
for batch in data['train']:
    seqs, evos, angs, masks, dmats, dmat_masks = get_seq_features(batch)
    print(seqs.shape, evos.shape, angs.shape, masks.shape)
    print(dmats.shape, dmat_masks.shape)
    
    '''
    Now write code to create the crops from each protein in the batch
    to create the final batch of crops along with true "labels" 
    for the distances and angles.
    
    dmat should be discretized into 64 bins, and angs into 1296 bins, 
    as described in lecture, to create true labels/bins.
    
    The input to you model should only be the seqs, and evos features.
    Read sidechainnet github page for more info.
    '''
    
    break # remove this! does only one batch for illustration

AttributeError: Can't pickle local object 'get_collate_fn.<locals>.collate_fn'