# Lab 3: CATH Dataset & DataLoader

- Complete the following tasks
- Save and submit your Jupyter notebook

## Problem Definition
This lab is part of a series in which we tackle one open problem with various deep learning models. The problem is to classify a protein into its CATH superfamily, given its sequence (or structure). To understand the relevant terminology and its biological significance, I recommend reading about the CATH database before you begin coding. Without a thorough understanding of the provided data, one might design a model that produces meaningless results.

The objective of this lab is to build a data flow (pipeline) for future experiments. To do so, first understand the provided data; then decide which data serve as the input of the deep learning models and which serve as the labels the models should predict. Note that you may not need all of the provided data to design a valid data flow. In the subsequent labs, you will design CNN, Transformer, and GNN models to address the CATH superfamily classification problem.

## CATH reference:
- https://www.cathdb.info/wiki/doku/?id=faq

## H5PY reference:
- https://docs.h5py.org/en/stable/quick.html

## PyTorch reference:
- https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

In [1]:
# load CATH database from the hdf5 file with h5py
import h5py
import numpy as np

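def load(fn):
    # open read-only; the context manager closes the file on exit
    with h5py.File(fn, 'r') as f:
        seq = f['node_seq'][()]  # residue tokens of all proteins, concatenated
        idx = f['node_idx'][()]  # offsets: protein i is seq[idx[i]:idx[i+1]]
        lab = f['label'][()]     # one label row per protein
    return seq, idx, lab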

%time data = load('/data/cath/hdf5/seq1024.hdf5')

CPU times: user 138 ms, sys: 723 ms, total: 861 ms
Wall time: 293 ms


In [2]:
# check loaded data by showing some key properties
print('type:', type(data))
print('#seq:', data[0].shape, data[0].min(), data[0].max())
print('#idx:', data[1].shape, data[1].min(), data[1].max())
print('#lab:', data[2].shape, data[2].min(0), data[2].max(0))

type: <class 'tuple'>
#seq: (476621235,) 0 20
#idx: (3924784,) 0 476621235
#lab: (3924783, 5) [ 0  1 10  4  1] [ 6629     6   180  4200 12820]
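The shapes above suggest a flat, CSR-like layout: `node_seq` concatenates the residue tokens of all proteins, `node_idx` holds n + 1 offsets (its last value, 476621235, equals the length of `node_seq`), and `label` has one 5-column row per protein. The toy sketch below uses made-up arrays, not the real file, to show how protein i would be recovered as `seq[idx[i]:idx[i+1]]`; the helper `get_protein` is hypothetical.

```python
import numpy as np

# Toy stand-ins for f['node_seq'], f['node_idx'] and f['label'] (made-up values).
seq = np.array([3, 7, 1, 0, 5, 9, 2, 4, 8], dtype=np.int8)  # all residues, concatenated
idx = np.array([0, 3, 5, 9], dtype=np.int64)                # n + 1 offsets for n = 3 proteins
lab = np.array([[0, 1, 10, 8, 10],
                [1, 2, 40, 50, 10],
                [2, 3, 30, 70, 240]], dtype=np.int16)       # one label row per protein

def get_protein(i):
    """Return the residue tokens and the label row of protein i."""
    return seq[idx[i]:idx[i + 1]], lab[i]

# The same invariants that ProteinDataset asserts on the real data.
assert idx[-1] == len(seq)
assert len(lab) == len(idx) - 1

lengths = np.diff(idx)  # per-protein sequence lengths
print(lengths)          # [3 2 4]
print(get_protein(1))
```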


# Lab Requirements

Implement the following cells to build a dataset for the CATH database, split the dataset into training and validation subsets, collate a mini-batch with appropriate padding (if necessary), and finally build dataloaders for the training and validation datasets.

In [3]:
# implement the ProteinDataset for CATH database
from torch.utils.data import Dataset

class ProteinDataset(Dataset):
    def __init__(self, dataset, mapping=None):
        super().__init__()
        if isinstance(dataset, tuple):  # raw data: (seq, idx, lab)
            self.seq = dataset[0]
            self.idx = dataset[1]
            self.lab = dataset[2]
            self.map = np.arange(len(self.lab), dtype=np.int64)  # identity mapping
            assert len(self.seq) == self.idx[-1]
            assert len(self.lab) == len(self.idx) - 1
        else:  # structured data: an existing ProteinDataset plus a subset mapping
            assert mapping is not None
            self.seq = dataset.seq
            self.idx = dataset.idx
            self.lab = dataset.lab
            # compose with the parent's map so a subset of a subset still points into
            # the full arrays (identical result when the parent is the full dataset)
            self.map = dataset.map[np.asarray(mapping)]
            assert np.max(self.map) < len(self.lab)

    # self.map selects a subset of samples; seq/idx/lab still hold the full data
    def __getitem__(self, idx):
        idx_ = self.map[idx]
        seq = self.seq[self.idx[idx_]:self.idx[idx_+1]]
        seq = np.concatenate([[21], seq, [22]])  # add start (21) and end (22) tokens; returns a new array
        lab = self.lab[idx_].copy()  # copy the row instead of returning a view into self.lab
        return seq, lab

    def __len__(self):
        return len(self.map)  # the subset size is the size of the map

dataset = ProteinDataset(data)
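The `map` attribute turns a subset into a view: the arrays stay shared and only the index map changes. If a subset is ever built from another subset (for example, a small debugging slice of a training subset), the child's positions refer to positions in the parent, so the maps must be composed as `parent_map[child_map]`. A NumPy-only toy illustration with made-up labels:

```python
import numpy as np

labels = np.array([10, 11, 12, 13, 14, 15])   # toy "full" label array shared by all subsets

full_map = np.arange(len(labels))             # identity map of the full dataset
train_map = full_map[np.array([5, 1, 3, 0])]  # a subset: positions in the full dataset
debug_map = train_map[np.array([2, 0])]       # subset of the subset, composed through train_map

print(labels[debug_map])  # [13 15]: items 2 and 0 of the training subset
# Without composing, labels[np.array([2, 0])] would wrongly give [12 10].
```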

In [9]:
# implement a few test cases for trainset and validset
print(dataset[365365])

(array([21, 14,  6, 10,  3, 14,  5,  3,  4,  1, 17, 15, 15,  4, 10, 14, 12,
        5, 10,  1, 14,  4, 14, 17,  9,  1,  1, 10, 14, 17, 14, 18,  7,  1,
        5, 17,  3, 15,  2, 19,  3, 10,  2,  8,  9,  6, 14, 13,  6,  1, 15,
        5, 16, 15,  6,  4,  4,  1,  2, 10, 17, 12,  2, 18,  3, 15,  5, 10,
        3, 16, 16, 10,  5,  8, 18,  9, 16, 10,  4,  4, 15,  9,  6,  6, 22]), array([592,   1,  10, 287, 810], dtype=int16))


In [4]:
# randomly split dataset into two parts: a training set and a validation set
from sklearn.model_selection import train_test_split

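datamap = np.arange(len(dataset), dtype=np.int64)  # indices of all samples (int64 is safe; int32 would also fit ~3.9M)
# test_size=1024*75 = 76800 samples, i.e. exactly 75 full validation batches of 1024
trainmap, validmap = train_test_split(datamap, test_size=1024*75, random_state=7)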
trainset=ProteinDataset(dataset,trainmap)
validset=ProteinDataset(dataset,validmap)
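Because both subsets share the full arrays and differ only in their maps, a cheap sanity check is that the two maps are disjoint and together cover every index exactly once. The hypothetical helper below checks this with NumPy only; on the real data one could call it as `check_split(trainmap, validmap, len(dataset))`.

```python
import numpy as np

def check_split(trainmap, validmap, n):
    """Return True if trainmap and validmap partition range(n) with no overlap."""
    both = np.concatenate([trainmap, validmap])
    return bool(len(both) == n
                and np.intersect1d(trainmap, validmap).size == 0
                and np.array_equal(np.sort(both), np.arange(n)))

# Toy example: a shuffled split of 10 indices into 7 train / 3 valid.
perm = np.random.default_rng(7).permutation(10)
print(check_split(perm[:7], perm[7:], 10))  # True
print(check_split(perm[:7], perm[6:], 10))  # False: perm[6] is in both parts
```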

In [5]:
# longest sequence (including start/end tokens) over both subsets
max_len = max(len(seq) for subset in (validset, trainset) for seq, _ in subset)
max_len

401
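The loop above calls `__getitem__` once for each of the ~3.9M proteins. Since `node_idx` stores offsets, the same statistic can be read off the offsets directly: the longest raw sequence is `np.diff(idx).max()`, plus 2 for the start and end tokens. Because the two subsets together cover the whole dataset, `max_token_length(data[1])` should reproduce the value above. A sketch with a hypothetical helper on toy offsets:

```python
import numpy as np

def max_token_length(idx, extra_tokens=2):
    """Longest sequence length from CSR-style offsets, plus start/end tokens."""
    return int(np.diff(idx).max()) + extra_tokens

toy_idx = np.array([0, 3, 5, 9])
print(max_token_length(toy_idx))  # 6 (longest raw length 4, plus 2)
```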

In [6]:
# design a data structure for model training and collate a mini-batch into the data structure
# hint: what is the input of the model? what is the label to be predicted?
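# When samples have features of different lengths, a custom collate_fn is needed to combine them into one batch.
# collate_fn receives the list of samples of one batch and returns the batched data.
# It may pad, truncate, or otherwise transform the samples. Here, instead of padding every sequence to a
# common length, all sequences are concatenated into one flat array whose total length is padded up to a
# multiple of bucketsize.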
from torch.utils.data import DataLoader

batchsize = 1024

def collate_fn(batch, bucketsize=batchsize*6):
    seq = [i[0] for i in batch] + [[0] * (bucketsize - 1)]
    seq = np.concatenate(seq)
    seq = seq[:len(seq)//bucketsize*bucketsize] # truncate to a multiple of bucketsize; this never drops real data
    # msk is True for real residues, False for the start/end tokens and the padding
    msk = [[False] + [True] * (len(i[0]) - 2) + [False] for i in batch] + [[False] * (bucketsize - 1)]
    msk = np.concatenate(msk)[:len(seq)]
    # idx gives, for each position in seq, the index of its sample in the batch (padding gets len(batch))
    idx = [[i] * len(j[0]) for i, j in enumerate(batch)] + [[len(batch)] * (bucketsize - 1)]
    idx = np.concatenate(idx)[:len(seq)]
    # ptr holds prefix sums of the per-sample lengths (including tokens 21 and 22); the final entry len(seq) closes the padding
    ptr = [0] + [len(i[0]) for i in batch]
    ptr = np.append(np.cumsum(ptr), [len(seq)]) 
    # lab holds the label row of each sample; i[1] is the second value returned by __getitem__
    lab = [i[1] for i in batch]
    lab = np.array(lab)
    return seq.astype(np.int16), msk, idx.astype(np.int16), ptr.astype(np.int32), lab.astype(np.int16)

trainloader = DataLoader(trainset, batch_size=batchsize, shuffle=True, drop_last=True, collate_fn=collate_fn, num_workers=6)
validloader = DataLoader(validset, batch_size=batchsize, shuffle=False, drop_last=False, collate_fn=collate_fn, num_workers=6)
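`collate_fn` returns a packed layout: one flat token array plus `ptr` offsets, which suits GNN-style models. The upcoming CNN and Transformer labs may instead want a dense `[batch, max_len]` array with a padding mask. Below is a minimal sketch of that conversion, assuming the `seq`/`ptr` outputs described above. The helper name `packed_to_dense` and the pad value 0 are illustrative choices, not part of the lab.

```python
import numpy as np

def packed_to_dense(seq, ptr, pad=0):
    """Convert a packed batch (flat seq + offsets ptr) into a dense, padded 2-D array.

    ptr is expected to have len(batch) + 2 entries, as produced by collate_fn:
    ptr[:-1] are the sample boundaries and the last entry ends the padding segment.
    Returns (dense, mask); mask is True wherever a sample token sits (start/end
    tokens included, unlike collate_fn's msk).
    """
    bounds = ptr[:-1]          # drop the trailing padding boundary
    lengths = np.diff(bounds)  # per-sample lengths, including start/end tokens
    dense = np.full((len(lengths), lengths.max()), pad, dtype=seq.dtype)
    mask = np.zeros(dense.shape, dtype=bool)
    for b, (start, n) in enumerate(zip(bounds[:-1], lengths)):
        dense[b, :n] = seq[start:start + n]
        mask[b, :n] = True
    return dense, mask

# Toy packed batch: samples [21, 3, 22] and [21, 5, 7, 22], padded to length 8.
seq = np.array([21, 3, 22, 21, 5, 7, 22, 0], dtype=np.int16)
ptr = np.array([0, 3, 7, 8], dtype=np.int32)
dense, mask = packed_to_dense(seq, ptr)
print(dense)
```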

In [11]:
# implement a few test cases for trainloader and validloader
from tqdm.notebook import tqdm

for batch in tqdm(trainloader): 
    print('seq:',batch[0].shape)
    print('msk:',batch[1].shape)
    print('idx:',batch[2].shape)
    print('ptr:',batch[3].shape)
    print('lab:',batch[4].shape)
    print(batch)
    break

  0%|          | 0/3780 [00:00<?, ?it/s]

seq: (129024,)
msk: (129024,)
idx: (129024,)
ptr: (1026,)
lab: (1024, 5)
(array([21,  1,  6, ...,  0,  0,  0], dtype=int16), array([False,  True,  True, ..., False, False, False]), array([   0,    0,    0, ..., 1024, 1024, 1024], dtype=int16), array([     0,     48,    149, ..., 124369, 124482, 129024], dtype=int32), array([[ 6438,     6,    10,   250,  3100],
       [ 1383,     1,    20,    58,  1290],
       [ 4905,     3,    40,    50, 10890],
       ...,
       [ 4383,     3,    30,   890,    10],
       [ 4478,     3,    30,  1330,    90],
       [ 3609,     3,    20,   170,    20]], dtype=int16))
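The printed `ptr` makes the bucket padding easy to check by hand. The last sample ends at 124482, and with `bucketsize = 6 * 1024 = 6144` the flat array is rounded up to the next multiple: ceil(124482 / 6144) × 6144 = 21 × 6144 = 129024, which is exactly the `seq` length shown. Appending `bucketsize - 1` zeros and truncating to a multiple of `bucketsize` is the integer form of that ceiling, as this small sketch (hypothetical helper `padded_length`) checks:

```python
import math

def padded_length(total_len, bucketsize):
    """Length after collate_fn's trick: append bucketsize - 1 pad tokens, then
    truncate to a multiple of bucketsize. Equals ceil(total_len / bucketsize) * bucketsize."""
    return (total_len + bucketsize - 1) // bucketsize * bucketsize

bucketsize = 6 * 1024
print(padded_length(124482, bucketsize))            # 129024
print(math.ceil(124482 / bucketsize) * bucketsize)  # 129024
```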


In [8]:
# test case: per-sample lengths (including start/end tokens) recovered from ptr;
# the last printed value is the length of the padding segment
from tqdm.notebook import tqdm

for batch in tqdm(validloader): 
    for i in range(1, len(batch[3])):
        print(batch[3][i]-batch[3][i-1])
    break

  0%|          | 0/75 [00:00<?, ?it/s]

88
81
132
189
173
148
94
133
41
114
181
140
69
292
75
256
169
98
266
364
71
157
189
146
197
172
98
204
95
253
96
141
104
51
70
173
126
85
54
75
85
115
172
117
177
74
132
137
109
46
281
161
183
135
317
97
42
71
113
83
202
100
391
118
133
111
225
173
127
221
163
112
62
96
48
188
301
83
61
195
74
113
174
204
228
133
65
263
122
93
110
54
179
83
102
108
112
99
102
105
116
62
166
151
69
147
77
95
46
275
51
64
84
226
117
237
127
75
155
85
53
142
123
114
80
195
137
99
193
109
100
82
248
70
142
150
200
311
238
60
175
76
99
43
135
76
113
119
46
66
72
257
52
235
55
49
124
130
329
100
41
186
325
234
89
371
55
73
181
66
108
188
54
119
68
50
122
270
61
118
181
150
71
50
44
121
119
51
116
112
64
101
204
92
66
129
72
83
189
62
74
105
77
147
393
107
95
44
72
139
286
63
51
206
71
66
93
125
127
97
99
70
69
115
237
88
114
177
79
141
283
100
98
57
90
55
109
100
81
182
226
78
165
103
294
83
140
47
99
119
132
112
70
86
90
204
261
146
200
106
91
89
138
42
66
74
195
161
51
79
99
49
144
124
71
72
102
131
99
175