# Lab 6: CATH Dataset & DataLoader (with structure data)

- Complete the following tasks
- Save and submit your Jupyter notebook

## Problem Definition
This lab is part of a series of labs that tackle one open problem with various deep learning models: classifying a protein into its CATH superfamily, given its sequence (or structure). To understand the relevant terminology and its biological significance, I recommend reading about the CATH database first and only then beginning to code. Without a thorough understanding of the provided data, one might design a model that produces meaningless results.

The objective of this lab is to create a data flow for future experiments. To achieve this, you must first understand the provided data. Next, select which data serve as the input of the deep learning models and which serve as the labels for model predictions. Note that you may not need all of the provided data to design a valid data flow. In the subsequent labs, you will design CNN, Transformer, and GNN models to address the CATH superfamily classification problem.

## CATH reference:
- https://www.cathdb.info/wiki/doku/?id=faq

## H5PY reference:
- https://docs.h5py.org/en/stable/quick.html

## PyTorch reference:
- https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

In [1]:
# load CATH database from the hdf5 file with h5py
import h5py

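def load(fn):
    # open read-only; [()] reads each whole dataset into memory as a NumPy array
    with h5py.File(fn, 'r') as f:
        seq = f['node_seq'][()]        # residue types (1..20) of all proteins, concatenated
        node_pos = f['node_pos'][()]   # coordinates of 5 atoms per residue, shape (N, 5, 3)
        node_idx = f['node_idx'][()]   # per-protein offsets into seq/node_pos
        edge_nho = f['edge_nho'][()]   # (2, E) hydrogen-bond edges, local residue indices
        edge_idx = f['edge_idx'][()]   # per-protein offsets into the columns of edge_nho
        lab = f['label'][()]           # (#proteins, 5) labels
        print(f.keys())
    return seq, node_pos, node_idx, edge_nho, edge_idx, lab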

%time data = load('../../cath/hdf5/struct256.hdf5')

<KeysViewHDF5 ['edge_idx', 'edge_nho', 'label', 'node_idx', 'node_pos', 'node_seq']>
CPU times: user 109 ms, sys: 983 ms, total: 1.09 s
Wall time: 586 ms


In [3]:
# check loaded data by showing some key properties
print('#seq:', data[0].shape, data[0].min(), data[0].max())
print('#node_pos:', data[1].shape, data[1].min(), data[1].max())
print('#node_idx:', data[2].shape, data[2].min(), data[2].max())
print('#edge_nho:', data[3].shape, data[3].min(), data[3].max())
print('#edge_idx:', data[4].shape, data[4].min(), data[4].max())
print('#lab:', data[5].shape, data[5].min(0), data[5].max(0))

#seq: (30922943,) 1 20
#node_pos: (30922943, 5, 3) -965.0 1041.0
#node_idx: (215800,) 0 30922943
#edge_nho: (2, 22016065) 0 399
#edge_idx: (215800,) 0 22016065
#lab: (215799, 5) [ 0  1 10  4  1] [ 6629     6   170  4200 12820]
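The shapes above show that the proteins are stored back to back: `node_seq` and `node_pos` hold the residues of all 215,799 proteins concatenated, and `node_idx` (215,800 entries, from 0 up to 30,922,943, the total number of residues) marks where each protein starts, so protein `i` occupies `node_idx[i]:node_idx[i+1]`. `edge_idx` plays the same role for the columns of `edge_nho`. The sketch below illustrates this offset scheme on hypothetical toy arrays (the `toy_*` names and `get_protein` are made up for illustration, not part of the dataset).

```python
# Hypothetical toy data in the same layout as the HDF5 file:
# residues of 3 proteins concatenated, plus offset arrays with one extra entry.
toy_seq = [16, 16, 13, 4, 20, 17, 9, 9, 8]   # residue types, 9 residues in total
toy_node_idx = [0, 3, 5, 9]                  # protein i spans [node_idx[i], node_idx[i+1])

toy_edge_nho = [[0, 1, 0, 2, 3],             # edge sources (local residue indices)
                [2, 2, 1, 0, 1]]             # edge targets (local residue indices)
toy_edge_idx = [0, 2, 3, 5]                  # protein i owns edge columns [edge_idx[i], edge_idx[i+1])


def get_protein(i):
    """Return (residues, edges) of protein i by slicing with the offset arrays."""
    n0, n1 = toy_node_idx[i], toy_node_idx[i + 1]
    e0, e1 = toy_edge_idx[i], toy_edge_idx[i + 1]
    residues = toy_seq[n0:n1]
    edges = [toy_edge_nho[0][e0:e1], toy_edge_nho[1][e0:e1]]
    return residues, edges


num_proteins = len(toy_node_idx) - 1        # one fewer than the number of offsets
for i in range(num_proteins):
    residues, edges = get_protein(i)
    print(i, len(residues), residues, edges)
```

With the real data, the last offset equals the total number of residues, which is what the `assert len(self.seq) == self.node_idx[-1]` in `ProteinDataset.__init__` checks.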


In [2]:
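# slice out the first protein: its residues span node_idx[0]=0 .. node_idx[1],
# its hydrogen-bond edges span edge_idx[0]=0 .. edge_idx[1]
seq=data[0][0:data[2][1]]
node_pos=data[1][0:data[2][1]]
x=data[3][0][0:data[4][1]]   # edge source residues
y=data[3][1][0:data[4][1]]   # edge target residues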
print(seq.shape)
print(seq)
print(node_pos.shape)
print(node_pos)
print(x.shape)
print(x)
print(y)

(42,)
[16 16 13  4 20 17  9  9  8  4 12 10  2  1 11  6  5  3 15 12  1 18  8 18
  1 10 16 16  9 16 19  3 18  4 17  1 17  4 10 10 10 16]
(42, 5, 3)
[[[ 1.2359e+01  4.9727e+00 -1.3703e+01]
  [ 1.2766e+01  4.2422e+00 -1.4906e+01]
  [ 1.2109e+01  4.7266e+00 -1.6203e+01]
  [ 1.0875e+01  4.7656e+00 -1.6312e+01]
  [ 1.2461e+01  2.7441e+00 -1.4719e+01]]

 [[ 1.2945e+01  5.1016e+00 -1.7172e+01]
  [ 1.2500e+01  5.5312e+00 -1.8516e+01]
  [ 1.1977e+01  6.9375e+00 -1.8734e+01]
  [ 1.0766e+01  7.1562e+00 -1.8766e+01]
  [ 1.1469e+01  4.5508e+00 -1.9062e+01]]

 [[ 1.2883e+01  7.9062e+00 -1.8938e+01]
  [ 1.2477e+01  9.2969e+00 -1.9188e+01]
  [ 1.1617e+01  9.4062e+00 -2.0438e+01]
  [ 1.0719e+01  1.0242e+01 -2.0516e+01]
  [ 1.3812e+01  1.0016e+01 -1.9328e+01]]

 [[ 1.1891e+01  8.5625e+00 -2.1438e+01]
  [ 1.1141e+01  8.5547e+00 -2.2688e+01]
  [ 9.6875e+00  8.1406e+00 -2.2469e+01]
  [ 8.7656e+00  8.7656e+00 -2.3000e+01]
  [ 1.1758e+01  7.5898e+00 -2.3703e+01]]

 [[ 9.5000e+00  7.0742e+00 -2.1688e+01]
  [ 8.

# Lab Requirements

Implement the following cells to build a dataset for the CATH database, split it into training and validation subsets, collate mini-batches with appropriate padding (if necessary), and finally build dataloaders for the training and validation datasets.

NEW: The 3D coordinates of the C_alpha atoms (i.e. the representatives of the amino acids) are provided in this lab, and you are required to construct a contact graph for each protein structure. Specifically, each node represents an amino acid, and an edge is added between every pair of amino acids whose distance is smaller than 8 Å. The data structure should be compatible with torch_geometric.loader.DataLoader.
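The contact graph described above can be built from a pairwise distance matrix. The following is a minimal NumPy sketch, assuming the C_alpha coordinates of one protein are given as an `(L, 3)` array (in this dataset, `node_pos[:, 1]` of one protein) and that the cutoff of 8 is in Å; the function name `contact_graph` is hypothetical. It returns edges in the `(2, E)` layout that torch_geometric expects for `edge_index`, with both directions of every contact and no self-loops.

```python
import numpy as np


def contact_graph(ca_pos, cutoff=8.0):
    """Hypothetical helper: C_alpha contact graph of one protein.

    ca_pos: (L, 3) array of C_alpha coordinates.
    Returns (edge_index, edge_len): a (2, E) int64 array of directed edges
    (i, j) with i != j and distance < cutoff, and the (E,) edge lengths.
    """
    ca_pos = np.asarray(ca_pos, dtype=np.float64)
    diff = ca_pos[:, None, :] - ca_pos[None, :, :]   # (L, L, 3) pairwise differences
    dist = np.sqrt((diff ** 2).sum(axis=-1))          # (L, L) Euclidean distances
    mask = dist < cutoff                              # contacts, symmetric by construction
    np.fill_diagonal(mask, False)                     # drop self-loops
    src, dst = np.nonzero(mask)                       # (i, j) pairs in row-major order
    edge_index = np.stack([src, dst]).astype(np.int64)
    return edge_index, dist[src, dst]


# Three residues on a line at x = 0, 5 and 12 (toy coordinates):
# 0-1 (5.0) and 1-2 (7.0) are contacts, 0-2 (12.0) is not.
pos = [[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [12.0, 0.0, 0.0]]
edge_index, edge_len = contact_graph(pos)
print(edge_index)   # [[0 1 1 2]
                    #  [1 0 2 1]]
print(edge_len)     # [5. 5. 7. 7.]
```

Inside a dataset's `get`, the result could be converted with `torch.from_numpy(edge_index)` and passed as `edge_index` to `Data`; `torch.cdist` computes the same distance matrix directly on tensors. Note that the `ProteinDataset` implemented below builds its edges from the provided hydrogen bonds plus peptide bonds rather than from this distance cutoff.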


In [6]:
# implement the ProteinDataset for CATH database
import torch as pt
from torch_geometric.data import Data
from torch_geometric.data import Dataset


class ProteinDataset(Dataset):
    def __init__(self, dataset, mapping=None):
        super(ProteinDataset, self).__init__()
        if isinstance(dataset, tuple): # raw data
            self.seq = pt.tensor(dataset[0], dtype=pt.int32)
            self.node_pos = pt.tensor(dataset[1], dtype=pt.float32)
            self.node_idx = pt.tensor(dataset[2])
            self.edge_nho = pt.tensor(dataset[3], dtype=pt.int32)
            self.edge_idx = pt.tensor(dataset[4])
            self.lab = pt.tensor(dataset[5])
            self.map = pt.arange(len(self.lab), dtype=pt.int32) # identity mapping
            assert len(self.seq) == self.node_idx[-1]
            assert len(self.lab) == len(self.node_idx) - 1
            assert len(self.lab) == len(self.edge_idx) - 1
        else: # structured data
            self.seq = dataset.seq
            self.node_pos = dataset.node_pos
            self.node_idx = dataset.node_idx
            self.edge_nho = dataset.edge_nho
            self.edge_idx = dataset.edge_idx
            self.lab = dataset.lab
            self.map = mapping
            assert self.map is not None
            assert pt.max(self.map) < len(self.lab)
    # self.map maps subset positions to protein indices; the underlying data is always the full dataset
    
    def get(self, idx):
        idx_ = self.map[idx]  # index of this sample in the full data
        n0, n1 = self.node_idx[idx_], self.node_idx[idx_+1]
        e0, e1 = self.edge_idx[idx_], self.edge_idx[idx_+1]
        seq = self.seq[n0:n1]
        len_seq = len(seq)
        # node features: one-hot residue type in columns 0..19 (types 1..20), N-CB distance in column 20
        seq_ = pt.zeros(len_seq, 21, dtype=pt.float32)
        seq_[pt.arange(len_seq), seq.long() - 1] = 1.0
        # shape: (len_seq, 5, 3); atoms 0: N, 1: C-alpha, 2: C, 3: O, 4: C-beta
        node_pos = self.node_pos[n0:n1]
        # The N-CA-CB angle reflects the residue's spatial conformation; since bond lengths
        # are fixed, the N-CB distance encodes that angle.
        seq_[:, 20] = pt.sqrt(pt.sum((node_pos[:, 0] - node_pos[:, 4])**2, dim=1))

        # connectivity (hydrogen bonds), local residue indices of this protein
        edge_nho = self.edge_nho[:, e0:e1].long()
        # connectivity (peptide bonds): residue i -> residue i+1
        edge_tai = pt.stack((pt.arange(0, len_seq-1), pt.arange(1, len_seq)), dim=0)
        # all connectivity: hydrogen-bond edges first, then peptide-bond edges
        edge_idx = pt.cat((edge_nho, edge_tai), dim=1)
        # edge lengths
        # hydrogen bond: attraction between the amide H and the carbonyl O;
        # the amide N coordinate stands in for the H coordinate
        nho_attr = pt.sqrt(pt.sum((node_pos[edge_nho[0], 0] - node_pos[edge_nho[1], 3])**2, dim=1))
        # peptide bond: carbonyl C of residue i to amide N of residue i+1
        tai_attr = pt.sqrt(pt.sum((node_pos[edge_tai[0], 2] - node_pos[edge_tai[1], 0])**2, dim=1))
        edge_attr = pt.cat((nho_attr, tai_attr))[:, None]

        # label: one-hot over the 6630 values of the first label column (max value 6629)
        lab = self.lab[idx_]
        one_hot = pt.zeros(6630, dtype=pt.float32)
        one_hot[lab[0]] = 1.0
        # pos holds the C-alpha coordinates
        graph = Data(x=seq_, edge_index=edge_idx, edge_attr=edge_attr, y=one_hot, pos=node_pos[:, 1])

        return graph
        

    def len(self):
        return len(self.map)  # the size of the subset is the size of the map


dataset = ProteinDataset(data)
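As a sanity check on the atom order assumed in `get` (0: N, 1: C-alpha, 2: C, 3: O, 4: C-beta), the sketch below recomputes by hand, with only the standard library, the N-CB node feature of residue 0 and the peptide-bond edge length from residue 0 to residue 1 of protein 0. The coordinates are copied from the `node_pos` printout earlier, which is rounded, so the results are approximate.

```python
import math

# Coordinates copied from the printed node_pos of protein 0 (rounded in the printout).
res0 = {
    "N":  (12.359, 4.9727, -13.703),
    "CA": (12.766, 4.2422, -14.906),
    "C":  (12.109, 4.7266, -16.203),
    "O":  (10.875, 4.7656, -16.312),
    "CB": (12.461, 2.7441, -14.719),
}
res1_N = (12.945, 5.1016, -17.172)

n_cb = math.dist(res0["N"], res0["CB"])   # node feature stored in x[:, 20]
c_n = math.dist(res0["C"], res1_N)        # peptide-bond length stored in edge_attr
n_ca = math.dist(res0["N"], res0["CA"])   # backbone N-CA bond, for comparison
c_o = math.dist(res0["C"], res0["O"])     # carbonyl C=O bond, for comparison

print(f"N-CB {n_cb:.3f}  C-N {c_n:.3f}  N-CA {n_ca:.3f}  C=O {c_o:.3f}")
```

The C-N, N-CA and C=O distances come out close to textbook bond lengths (roughly 1.33 Å, 1.46 Å and 1.23 Å), which supports both the atom order and the interpretation of the coordinates in Å.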

In [7]:
# implement a few test cases for dataset
print(dataset[0])
print(dataset[0].edge_index)

torch.Size([33])
Data(x=[42, 21], edge_index=[2, 74], edge_attr=[74, 1], y=[6630], pos=[42, 3])
torch.Size([33])
tensor([[ 5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 16, 15, 21, 22, 23, 24, 25, 26,
         27, 28, 28, 30, 29, 31, 32, 35, 36, 37, 38, 39, 40, 41, 41,  0,  1,  2,
          3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
         21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
         39, 40],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 11, 11, 12, 17, 18, 19, 20, 21, 22,
         23, 24, 25, 25, 26, 28, 30, 31, 32, 33, 34, 35, 37, 37, 38,  1,  2,  3,
          4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
         22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
         40, 41]])
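In the printed `edge_index`, the 33 hydrogen-bond edges come first and the 41 peptide-bond edges (0→1, ..., 40→41) follow, each stored in one direction only. By default, message-passing layers in torch_geometric send messages from `edge_index[0]` to `edge_index[1]`, so if information should flow both ways each edge also needs its reverse. `torch_geometric.utils.to_undirected` exists for this; the NumPy sketch below only illustrates the underlying idea on a toy edge list (the helper name `symmetrize` is hypothetical), copying each edge attribute to the reversed edge.

```python
import numpy as np


def symmetrize(edge_index, edge_attr):
    """Hypothetical helper: add the reverse of every edge and copy its attribute.

    edge_index: (2, E) int array, edge_attr: (E, F) array.
    Duplicate (i, j) pairs are removed, keeping the first occurrence.
    """
    rev = edge_index[::-1]                                  # (j, i) for every (i, j)
    both = np.concatenate([edge_index, rev], axis=1)        # (2, 2E)
    attr = np.concatenate([edge_attr, edge_attr], axis=0)   # (2E, F)
    # drop duplicates, e.g. when (i, j) and (j, i) were both present already
    _, keep = np.unique(both.T, axis=0, return_index=True)
    keep = np.sort(keep)                                    # preserve the original order
    return both[:, keep], attr[keep]


edge_index = np.array([[0, 1, 2],
                       [1, 2, 0]])
edge_attr = np.array([[1.3], [1.3], [2.9]])
ei, ea = symmetrize(edge_index, edge_attr)
print(ei)           # [[0 1 2 1 2 0]
                    #  [1 2 0 0 1 2]]
print(ea.ravel())   # [1.3 1.3 2.9 1.3 1.3 2.9]
```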


In [29]:
# randomly split the dataset into three parts: a training set, a validation set, and a test set
from sklearn.model_selection import train_test_split

datamap = pt.arange(len(dataset), dtype=pt.int32)  # 32 is not enough
trainmap,validmap=train_test_split(datamap,test_size=1024*30,random_state=7)
validmap, testmap=train_test_split(validmap,test_size=1024*10,random_state=7)
trainset=ProteinDataset(dataset,trainmap)
validset=ProteinDataset(dataset,validmap)
testset=ProteinDataset(dataset,testmap)
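With 215,799 proteins, the first call holds out `1024*30` = 30,720 proteins and the second call splits them into 20,480 validation and 10,240 test proteins, leaving 185,079 for training. Each subset is a `ProteinDataset` that only stores its own index map and shares the tensors of the full dataset. The sketch below reproduces the same three-way split sizes with only the standard library; it is an alternative to `train_test_split` (so the permutation differs), and `split_indices` is a hypothetical name.

```python
import random


def split_indices(n, n_holdout, n_test, seed=7):
    """Hypothetical helper: shuffle 0..n-1 and cut it into train/valid/test index lists."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)        # local RNG, leaves the global state untouched
    holdout, train = idx[:n_holdout], idx[n_holdout:]
    test, valid = holdout[:n_test], holdout[n_test:]
    return train, valid, test


train, valid, test = split_indices(215799, n_holdout=1024 * 30, n_test=1024 * 10)
print(len(train), len(valid), len(test))   # 185079 20480 10240
```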

In [30]:
# design a data structure for model training and collate mini-batches into it
# hint: what is the input of the model? what is the label to be predicted?
from torch_geometric.loader import DataLoader

batchsize = 1024
trainloader = DataLoader(trainset, batch_size=batchsize, shuffle=True, drop_last=True, num_workers=6)
validloader = DataLoader(validset, batch_size=batchsize, shuffle=False, drop_last=False, num_workers=6)
testloader = DataLoader(testset, batch_size=batchsize, shuffle=False, drop_last=False, num_workers=6)
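When torch_geometric's `DataLoader` collates graphs, node-level tensors such as `x` and `pos` are concatenated along the first dimension and `edge_index` is shifted by the node offsets. The same concatenation applies to `y`, so the 6,630-dim one-hot label of each graph becomes one flat vector of 1024 × 6630 = 6,789,120 entries, which is the `y=[6789120]` in the batch output below. The NumPy sketch here only mimics that concatenation (it does not call PyG) and shows two ways a training loop could get per-graph labels back; the helper `one_hot` and the dummy labels are assumptions.

```python
import numpy as np

NUM_CLASSES = 6630     # size of the one-hot label built in ProteinDataset.get
BATCH_SIZE = 1024


def one_hot(label, num_classes=NUM_CLASSES):
    """One-hot vector shaped like y in ProteinDataset.get (shape [num_classes])."""
    y = np.zeros(num_classes, dtype=np.float32)
    y[label] = 1.0
    return y


labels = np.arange(BATCH_SIZE) * 7 % NUM_CLASSES   # dummy class ids for one batch

# Concatenating 1-D per-graph labels along dim 0, as the collate step does for y:
flat_y = np.concatenate([one_hot(l) for l in labels], axis=0)
print(flat_y.shape)    # (6789120,)

# Option 1: reshape back to (batch, classes) before computing the loss.
recovered = flat_y.reshape(-1, NUM_CLASSES).argmax(axis=1)

# Option 2: store the class id itself as a length-1 vector per graph;
# concatenation then yields one id per graph, ready for a cross-entropy loss.
index_y = np.concatenate([np.array([l]) for l in labels], axis=0)
print(index_y.shape)   # (1024,)
```

In PyTorch terms, option 1 keeps the current data structure and uses `batch.y.view(-1, 6630)`; option 2 would mean storing something like `y=lab[0:1].long()` in `get` (a suggestion, not the lab's code), which also avoids building a 6.8-million-entry float tensor per batch.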

In [31]:
# implement a few test cases for trainloader and validloader

# for batch in trainloader: pass
# for batch in validloader: pass
for batch in testloader: 
    print(batch)

DataBatch(x=[145774, 21], edge_index=[2, 248452], edge_attr=[248452, 1], y=[6789120], pos=[145774, 3], batch=[145774], ptr=[1025])
DataBatch(x=[145318, 21], edge_index=[2, 247709], edge_attr=[247709, 1], y=[6789120], pos=[145318, 3], batch=[145318], ptr=[1025])
DataBatch(x=[149623, 21], edge_index=[2, 255091], edge_attr=[255091, 1], y=[6789120], pos=[149623, 3], batch=[149623], ptr=[1025])
DataBatch(x=[144300, 21], edge_index=[2, 245462], edge_attr=[245462, 1], y=[6789120], pos=[144300, 3], batch=[144300], ptr=[1025])
DataBatch(x=[148747, 21], edge_index=[2, 253535], edge_attr=[253535, 1], y=[6789120], pos=[148747, 3], batch=[148747], ptr=[1025])
DataBatch(x=[150065, 21], edge_index=[2, 256058], edge_attr=[256058, 1], y=[6789120], pos=[150065, 3], batch=[150065], ptr=[1025])
DataBatch(x=[150580, 21], edge_index=[2, 256914], edge_attr=[256914, 1], y=[6789120], pos=[150580, 3], batch=[150580], ptr=[1025])
DataBatch(x=[143313, 21], edge_index=[2, 244619], edge_attr=[244619, 1], y=[6789120