In [1]:
####### Project #######
# Node features: 1) one-hot amino acid, 2) backbone dihedral angles
# Edge features: 1) Ca-Ca distance, 2) trRosetta 6D orientation features

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



In [1]:
# Import python libraries (pytorch, numpy, etc -- whatever you need to import)
import torch 
import torch.nn as nn
import numpy as np
import re
import scipy

In [2]:
# Hyper-parameters and run configuration

torch.set_num_threads(4)  # use at most 4 CPU threads so a run on the shared master node doesn't grab every core

SEED = 0  # seeds torch and numpy RNGs (weight init, shuffling, random crops) for reproducible runs
torch.manual_seed(SEED)
np.random.seed(SEED)

BATCH_SIZE = 16  # number of proteins (samples) per mini-batch
CROP_SIZE = 64  # proteins longer than this are cropped to a CROP_SIZE x CROP_SIZE window (shorter batches crop to their shortest member)
LR = 0.001  # learning rate (Q. what would happen if you have too high/too small learning rates?)
NUM_EPOCHS = 20
device = "cuda:0" if torch.cuda.is_available() else "cpu"  # first GPU if one is visible, otherwise CPU

In [5]:
# List the dataset home directory (IPython shell escape; only runs inside Jupyter/IPython).
!ls /public_data/ml/CATH40/CATH40-20JUN08/

ccm  get_CCMraw.py  out  split_train_test.py  task_s  train_s
err  msa	    pdb  tar_s		      test_s


In [3]:
class CustomDataset(torch.utils.data.Dataset):
    """Contact-map dataset over CATH40 domains (MSA + CCM + PDB per domain).

    Each item is ``(feat_2d, Cb_contact, mask)`` as torch tensors:
      - feat_2d:    (L, L, C_ccm + 44 + 44) pairwise input features
      - Cb_contact: (L, L) binary Cb-Cb contact map (Ca for Gly / missing Cb)
      - mask:       (L, L) 1.0 where both residues are resolved in the PDB
    where L is the length of the MSA query sequence.
    """

    # 20 standard amino acids, then unknown (X, index 20) and gap (-, index 21).
    AA_ORDER = "ARNDCQEGHILKMFPSTWYVX-"

    def __init__(self, list_name, root="/public_data/ml/CATH40/CATH40-20JUN08/", list_dir="."):
        """Read the domain-ID list and record where the per-domain files live.

        Args:
            list_name: file inside ``list_dir`` with one domain ID per line.
            root: dataset home directory containing ``msa/``, ``ccm/`` and ``pdb/``.
            list_dir: directory holding ``list_name`` (current directory by default).
        """
        self.root = root  # home directory for dataset
        self.root2 = list_dir  # directory holding the domain-ID list
        with open("%s/%s" % (self.root2, list_name)) as fp:
            # Skip blank lines (e.g. a trailing newline) so they don't become bogus IDs.
            self.domIDs = [line.strip() for line in fp if line.strip()]

        self.msa_dir = "%s/msa" % self.root  # <domain>.a3m multiple sequence alignments
        self.ccm_dir = "%s/ccm" % self.root  # <domain>.npy raw CCM (coevolution) features
        self.pdb_dir = "%s/pdb" % self.root  # <domain>.pdb ground-truth structures

    def __len__(self):
        """Total number of samples (domains) in the dataset."""
        return len(self.domIDs)

    def read_MSA(self, msa_fn):
        """Parse an a3m MSA into a one-hot query and a sequence profile.

        Header ('>') and comment ('#') lines are skipped and lowercase letters
        (a3m insertion states) are removed, so every aligned row has the query's
        length. Characters outside ``AA_ORDER`` are encoded as unknown (X).

        Args:
            msa_fn: path to the a3m file; the first sequence is the query.

        Returns:
            seq_onehot: (L, 22) float one-hot encoding of the query sequence.
            seq_prof:   (L, 22) float per-column frequency of each symbol.

        Raises:
            ValueError: if the file has no sequences or the aligned rows differ in length.
        """
        seqs = []
        with open(msa_fn, "r") as f:
            for line in f:
                # strip() (not [:-1]) so a final line without '\n' keeps its last residue
                line = line.strip()
                if not line or line.startswith(">") or line.startswith("#"):
                    continue
                seqs.append(re.sub("[a-z]", "", line).upper())

        if not seqs:
            raise ValueError("No sequences found in MSA file %s" % msa_fn)
        L = len(seqs[0])
        bad_rows = [i for i, s in enumerate(seqs) if len(s) != L]
        if bad_rows:
            raise ValueError("MSA rows %s in %s differ in length from the query (%d)" % (bad_rows[:5], msa_fn, L))

        n_sym = len(self.AA_ORDER)
        # Byte -> symbol index lookup; every byte not in AA_ORDER maps to unknown (X).
        lut = np.full(256, self.AA_ORDER.index("X"), dtype=np.int64)
        for i, ch in enumerate(self.AA_ORDER):
            lut[ord(ch)] = i
        raw = np.frombuffer("".join(seqs).encode("ascii", "replace"), dtype=np.uint8)
        msa = lut[raw].reshape(len(seqs), L)  # (Nseq, L) integer-encoded MSA
        Nseq = msa.shape[0]

        seq_onehot = np.eye(n_sym)[msa[0]]  # one-hot encoded query sequence
        # Count symbols per column with bincount instead of materialising an
        # (Nseq, L, 22) one-hot tensor, which is huge for deep alignments.
        flat = (msa + n_sym * np.arange(L)[np.newaxis, :]).ravel()
        counts = np.bincount(flat, minlength=n_sym * L).reshape(L, n_sym)
        seq_prof = counts / float(Nseq)  # sequence profile
        return seq_onehot, seq_prof

    def read_pdb(self, pdb_fn, L, cutoff=8.0):
        """Build the Cb-Cb contact map and validity mask for one domain.

        Residues are placed on the query sequence at ``resSeq - first_resSeq``,
        i.e. this assumes PDB numbering is contiguous with the MSA query and
        starts at its first residue -- TODO confirm against the dataset.
        Gly, and any residue whose Cb is missing, is represented by its Ca.
        Only the first alternate location of each atom is used.

        Args:
            pdb_fn: input PDB file to parse.
            L: length of the query protein.
            cutoff: contact distance threshold (same units as the PDB coordinates).

        Returns:
            Cb_contact: (L, L) float array, 1.0 where the pair distance < cutoff.
            mask: (L, L) float array, 1.0 only where both residues are resolved,
                so missing regions can be excluded from the loss.
        """
        # (chain, resSeq, iCode) -> {"resName": ..., "CA": xyz, "CB": xyz}, in file order
        residues = {}
        with open(pdb_fn) as fp:
            for line in fp:
                if not line.startswith("ATOM"):
                    continue
                atm_name = line[12:16].strip()
                if atm_name not in ("CA", "CB"):
                    continue
                key = (line[21], int(line[22:26]), line[26])
                atoms = residues.setdefault(key, {"resName": line[17:20].strip()})
                if atm_name not in atoms:  # ignore later alternate locations
                    atoms[atm_name] = (float(line[30:38]), float(line[38:46]), float(line[46:54]))

        Cb_contact = np.zeros([L, L], dtype=float)
        mask = np.zeros([L, L], dtype=float)
        if not residues:
            print(f"Fault: no CA/CB atoms found in {pdb_fn}")
            return Cb_contact, mask

        first_id = next(iter(residues))[1]
        pos = np.array([resi - first_id for (_, resi, _) in residues], dtype=int)
        xyz = np.array([
            atoms["CB"] if ("CB" in atoms and atoms["resName"] != "GLY") else atoms.get("CA", atoms.get("CB"))
            for atoms in residues.values()
        ], dtype=float)

        # Keep the first residue for each sequence position that lies inside [0, L).
        _, first_idx = np.unique(pos, return_index=True)
        keep = first_idx[(pos[first_idx] >= 0) & (pos[first_idx] < L)]
        if len(keep) < len(residues):
            print(f"Fault: {len(residues) - len(keep)} residue(s) in {pdb_fn} could not be placed on the length-{L} query sequence")

        p = pos[keep]
        coords = xyz[keep]
        dist_map = np.sqrt(((coords[:, np.newaxis, :] - coords[np.newaxis, :, :]) ** 2).sum(axis=-1))
        Cb_contact[np.ix_(p, p)] = (dist_map < cutoff).astype(float)
        mask[np.ix_(p, p)] = 1.0
        return Cb_contact, mask

    def __getitem__(self, idx):
        """Fetch one sample: ``(feat_2d, Cb_contact, mask)`` for domain ``idx``."""
        domain = self.domIDs[idx]
        msa_fn = "%s/%s.a3m" % (self.msa_dir, domain)  # msa file for selected domain
        ccm_fn = "%s/%s.npy" % (self.ccm_dir, domain)  # raw ccm file for selected domain
        pdb_fn = "%s/%s.pdb" % (self.pdb_dir, domain)  # pdb file for selected domain

        # 1. MSA -> query one-hot + sequence profile, concatenated into a 1D feature (L, 22+22)
        seq_onehot, seq_prof = self.read_MSA(msa_fn)
        feat_1d = np.concatenate([seq_onehot, seq_prof], axis=1)

        # 2. Tile to pairwise features: tile_x[i, j] = feat_1d[i], tile_y[i, j] = feat_1d[j]
        L = feat_1d.shape[0]
        tile_x = np.tile(feat_1d[:, np.newaxis, :], (1, L, 1))
        tile_y = np.tile(feat_1d[np.newaxis, :, :], (L, 1, 1))

        # 3. Raw CCM (coevolution) features + tiled 1D features.
        # ccm is assumed to be (L, L, C_ccm) -- 441 channels per the original layout; TODO confirm.
        ccm = np.load(ccm_fn)
        feat_2d = torch.from_numpy(np.concatenate((ccm, tile_x, tile_y), axis=-1))  # (L, L, C_ccm+44+44)

        # 4. Labels: Cb contacts (Ca for Gly / missing Cb) and the valid-pair mask
        Cb_contact, mask = self.read_pdb(pdb_fn, L)
        return feat_2d, torch.from_numpy(Cb_contact), torch.from_numpy(mask)

FileNotFoundError: [Errno 2] No such file or directory: './train_s'

In [43]:
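# You need a rule for collating a batch when proteins have different lengths:
# here every example is cropped to L_min = min(CROP_SIZE, shortest protein in the batch) so the tensors can be stacked.
# Below is an example implementation; feel free to modify it or use it as-is.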
def collate_batch(batch, crop_size=None):
    """Crop every example in ``batch`` to a common square size and stack them.

    Args:
        batch: list of ``(input_feat, label, mask)`` tuples shaped (L, L, C),
            (L, L) and (L, L); L may differ between examples.
        crop_size: maximum crop size; defaults to the global ``CROP_SIZE``.

    Returns:
        ``(inputs, labels, masks)`` stacked tensors shaped (B, L_min, L_min, C),
        (B, L_min, L_min) and (B, L_min, L_min), where
        ``L_min = min(crop_size, shortest L in the batch)``.

    Row and column windows are drawn independently, so a crop may cover an
    off-diagonal region of the pair map (the same window is applied to the
    features, label and mask of an example).
    """
    if crop_size is None:
        crop_size = CROP_SIZE
    L_min = min(crop_size, min(input_feat.shape[0] for input_feat, _, _ in batch))

    b_input, b_label, b_mask = [], [], []
    for input_feat, label, mask in batch:
        L = input_feat.shape[0]
        # Window start in [0, L - L_min] inclusive (randint's upper bound is exclusive,
        # hence +1 -- otherwise the last residue could never be inside a crop).
        # torch's RNG is reseeded in every DataLoader worker, so workers draw different crops.
        start_1 = int(torch.randint(0, L - L_min + 1, (1,)))
        start_2 = int(torch.randint(0, L - L_min + 1, (1,)))
        rows = slice(start_1, start_1 + L_min)
        cols = slice(start_2, start_2 + L_min)
        b_input.append(input_feat[rows, cols])
        b_label.append(label[rows, cols])
        b_mask.append(mask[rows, cols])
    return torch.stack(b_input), torch.stack(b_label), torch.stack(b_mask)

In [45]:
# Data loaders for the training set and the validation set.
# NOTE(review): list files are read relative to the current directory ("./train_s", "./val_s");
# an earlier run failed with FileNotFoundError for './train_s' -- confirm the working directory.
# Pinned memory only helps for host->GPU copies; skip it on CPU-only machines.
PIN_MEMORY = torch.cuda.is_available()

trainset = CustomDataset("train_s")
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=PIN_MEMORY, shuffle=True, collate_fn=collate_batch)

validset = CustomDataset("val_s")
# Validate on the held-out validation set, not on the training set.
validloader = torch.utils.data.DataLoader(validset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=PIN_MEMORY, shuffle=False, collate_fn=collate_batch)



In [47]:
# Training pipeline:
# 1. Training data -- prepared
# 2. Data loader -- processes input features & labels
# 3. Deep learning model
# 4. Optimizer
# 5. Loss function

# Get the AI model to be trained.
# NOTE(review): ContactPredictor is not defined anywhere in this notebook -- define or import it
# in an earlier cell. The loss below assumes it returns one logit per residue pair, shaped
# (B, L, L) or (B, 1, L, L) -- TODO confirm.
model = ContactPredictor().to(device)

# Contacts are binary labels, so use per-pair binary cross-entropy on logits.
# reduction="none" keeps the per-pair loss map so the mask really selects valid pairs
# (with a mean-reduced scalar loss, multiplying by the mask has no effect).
criterion = nn.BCEWithLogitsLoss(reduction="none")

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)


def masked_loss(outputs, labels, masks):
    """Mean BCE-with-logits loss over valid (mask == 1) residue pairs only."""
    if outputs.dim() == labels.dim() + 1 and outputs.shape[1] == 1:
        outputs = outputs.squeeze(1)  # (B, 1, L, L) -> (B, L, L)
    per_pair = criterion(outputs, labels)
    # clamp avoids 0/0 = NaN for a crop that contains no resolved pairs
    return (per_pair * masks).sum() / masks.sum().clamp(min=1.0)


def to_device(inputs, labels, masks):
    """Cast a collated batch to float32 and move it to ``device``."""
    # Conv layers expect (Batch, feature dimension, H, W); the dataset yields (B, H, W, C).
    inputs = inputs.float().permute(0, 3, 1, 2).to(device, non_blocking=True)
    labels = labels.float().to(device, non_blocking=True)
    masks = masks.float().to(device, non_blocking=True)
    return inputs, labels, masks


for i_epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    for inputs, labels, masks in trainloader:
        inputs, labels, masks = to_device(inputs, labels, masks)
        loss = masked_loss(model(inputs), labels, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.detach()  # stay on device; avoid a host sync every step
    train_loss = float(train_loss) / max(len(trainloader), 1)
    print("Train", i_epoch, train_loss)

    # Check validation loss
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():  # no gradients needed for validation
        for inputs, labels, masks in validloader:
            inputs, labels, masks = to_device(inputs, labels, masks)
            valid_loss += masked_loss(model(inputs), labels, masks)
    valid_loss = float(valid_loss) / max(len(validloader), 1)
    print("valid", i_epoch, valid_loss)

    # torch.save needs a file path ('.' is a directory and cannot be written to).
    torch.save({
                'epoch': i_epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': valid_loss,
                }, "checkpoint_epoch%02d.pt" % i_epoch)

    


139 64 102 113 torch.Size([139, 139]) torch.Size([64, 64])
257 64 114 168 torch.Size([257, 257]) torch.Size([64, 64])
170 64 149 110 torch.Size([170, 170]) torch.Size([64, 64])
153 64 93 113 torch.Size([153, 153]) torch.Size([64, 64])
371 64 101 234 torch.Size([371, 371]) torch.Size([64, 64])
133 64 131 125 torch.Size([133, 133]) torch.Size([64, 64])
132 64 123 70 torch.Size([132, 132]) torch.Size([64, 64])
362 64 183 117 torch.Size([362, 362]) torch.Size([64, 64])
182 64 146 105 torch.Size([182, 182]) torch.Size([64, 64])
213 64 204 66 torch.Size([213, 213]) torch.Size([64, 64])
161 64 68 119 torch.Size([161, 161]) torch.Size([64, 64])
143 64 111 67 torch.Size([143, 143]) torch.Size([64, 64])
163 64 114 109 torch.Size([163, 163]) torch.Size([64, 64])
201 64 124 130 torch.Size([201, 201]) torch.Size([64, 64])
202 64 143 177 torch.Size([202, 202]) torch.Size([64, 64])
160 64 103 87 torch.Size([160, 160]) torch.Size([64, 64])
251 64 203 210 torch.Size([251, 251]) torch.Size([64, 64])
329

Exception ignored in: 

torch.Size([64, 64])

<function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>







182

Traceback (most recent call last):


 

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__


64

    

 

self._shutdown_workers()

138




 162

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers
    

 

if w.is_alive():

torch.Size([182, 182])




 

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    

torch.Size([64, 64])


assert self._parent_pid == os.getpid(), 'can only test a child process'


147

AssertionError

 64

: 

 

can only test a child process

74 




96 torch.Size([147, 147])

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>

 


Traceback (most recent call last):


torch.Size([64, 64])


  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    

166 

self._shutdown_workers()

64




 

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers


152

    

 

if w.is_alive():


120 torch.Size([166, 166])

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    

 

assert self._parent_pid == os.getpid(), 'can only test a child process'


torch.Size([64, 64])

AssertionError




: 

262 

can only test a child process

64




 91 

Exception ignored in: 

113

<function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>

 


Traceback (most recent call last):


torch.Size([262, 262])

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__


 torch.Size([64, 64])

    




self._shutdown_workers()

141


  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers


 64

    

 

if w.is_alive():


131 

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    

137

assert self._parent_pid == os.getpid(), 'can only test a child process'

 torch.Size([141, 141]) 




torch.Size([64, 64])

AssertionError: 




can only test a child process

162




 64 137

Exception ignored in: 

 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>

160




 

Traceback (most recent call last):


torch.Size([162, 162])

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__


 

    self._shutdown_workers()

torch.Size([64, 64])





146

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers


 

    

64

if w.is_alive():

 




69

  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/multiprocessing/process.py", line 160, in is_alive


 

    

73

assert self._parent_pid == os.getpid(), 'can only test a child process'


 

AssertionError

torch.Size([146, 146])

: 

 

can only test a child process

torch.Size([64, 64])





202 64 134 121 torch.Size([202, 202]) torch.Size([64, 64])
392 64 365 88 torch.Size([392, 392]) torch.Size([64, 64])
190 64 18964 80 torch.Size([190, 190]) torch.Size([64, 64])
208 64 162 81 torch.Size([208, 208]) torch.Size([64, 64])
254 64 90 230 torch.Size([254, 254]) torch.Size([64, 64]) 
14564  64138  64115  138torch.Size([189, 189])  torch.Size([145, 145])torch.Size([64, 64]) 
221torch.Size([64, 64]) 
64 209 155 torch.Size([221, 221]) torch.Size([64, 64])
180 64 68 154 torch.Size([180, 180]) torch.Size([64, 64])
286 64 110 185 torch.Size([286, 286]) torch.Size([64, 64])
172 64 98 149 torch.Size([172, 172]) torch.Size([64, 64])
191 64 101 87 torch.Size([191, 191]) torch.Size([64, 64])
147 64 65 105 torch.Size([147, 147]) torch.Size([64, 64])
138 64 102 73 torch.Size([138, 138]) torch.Size([64, 64])
180 64 173 169 torch.Size([180, 180]) torch.Size([64, 64])
397 64 181 160 torch.Size([397, 397]) torch.Size([64, 64])
172 64 128 114 torch.Size([172, 172]) torch.Size([64, 64])
167 64 

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>
Traceback (most recent call last):
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers
    if w.is_alive():
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>
Traceback (most recent call last):
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/

150 64 147 128 torch.Size([150, 150]) torch.Size([64, 64])
388 64 275 74 torch.Size([388, 388]) torch.Size([64, 64])
147 64 83 136 torch.Size([147, 147]) torch.Size([64, 64])
128 64 94 89 torch.Size([128, 128]) torch.Size([64, 64])
291 64 123 73 torch.Size([291, 291]) torch.Size([64, 64])
252 64 68 177 torch.Size([252, 252]) torch.Size([64, 64])
458 64 384 446 torch.Size([458, 458]) torch.Size([64, 64])
237 64 209 171 torch.Size([237, 237]) torch.Size([64, 64])
188 64 163 182 torch.Size([188, 188]) torch.Size([64, 64])
391 64 77 295 torch.Size([391, 391]) torch.Size([64, 64])
215 64 133 169 torch.Size([215, 215]) torch.Size([64, 64])
264 64 134 125 torch.Size([264, 264]) torch.Size([64, 64])
220 64 181 130 torch.Size([220, 220]) torch.Size([64, 64])
279 64 78 103 torch.Size([279, 279]) torch.Size([64, 64])
149 64 142 105 torch.Size([149, 149]) torch.Size([64, 64])
233 64 220 125 torch.Size([233, 233]) torch.Size([64, 64])


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>
Traceback (most recent call last):
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers
    if w.is_alive():
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>
Traceback (most recent call last):
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/

173 64 159 107 torch.Size([173, 173]) torch.Size([64, 64])
160 64 146 151 torch.Size([160, 160]) torch.Size([64, 64])
197 64 73 174 torch.Size([197, 197]) torch.Size([64, 64])
140 64 88 107 torch.Size([140, 140]) torch.Size([64, 64])
181 64 160 162 torch.Size([181, 181]) torch.Size([64, 64])
151 64 134 129 torch.Size([151, 151]) torch.Size([64, 64])
141 64 125 140 torch.Size([141, 141]) torch.Size([64, 64])
200 64 139 185 torch.Size([200, 200]) torch.Size([64, 64])
161 64 147 129 torch.Size([161, 161]) torch.Size([64, 64])
134 64 65 100 torch.Size([134, 134]) torch.Size([64, 64])
532 64 202 512 torch.Size([532, 532]) torch.Size([64, 64])
139 64 79 97 torch.Size([139, 139]) torch.Size([64, 64])
153 64 106 114 torch.Size([153, 153]) torch.Size([64, 64])
137 64 104 70 torch.Size([137, 137]) torch.Size([64, 64])
152 64 68 140 torch.Size([152, 152]) torch.Size([64, 64])
142 64 102 68 torch.Size([142, 142]) torch.Size([64, 64])
228 64 103 194 torch.Size([228, 228]) torch.Size([64, 64])
133 6

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>
Traceback (most recent call last):
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers
    if w.is_alive():
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff066169d30>
Traceback (most recent call last):
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/opt/ohpc/pub/apps/anaconda3/lib/python3.9/site-packages/torch/utils/

346 64 279 345 torch.Size([346, 346]) torch.Size([64, 64])
196 64 143 158 torch.Size([196, 196]) torch.Size([64, 64])
315 64 65 212 torch.Size([315, 315]) torch.Size([64, 64])
151 64 111 145 torch.Size([151, 151]) torch.Size([64, 64])
309 64 235 86 torch.Size([309, 309]) torch.Size([64, 64])
152 64 142 81 torch.Size([152, 152]) torch.Size([64, 64])
137 64 86 73 torch.Size([137, 137]) torch.Size([64, 64])
212 64 138 144 torch.Size([212, 212]) torch.Size([64, 64])
160 64 153 144 torch.Size([160, 160]) torch.Size([64, 64])
236 64 184 67 torch.Size([236, 236]) torch.Size([64, 64])
136 64 132 124 torch.Size([136, 136]) torch.Size([64, 64])
162 64 148 67 torch.Size([162, 162]) torch.Size([64, 64])
188 64 98 112 torch.Size([188, 188]) torch.Size([64, 64])
216 64 94 156 torch.Size([216, 216]) torch.Size([64, 64])
132 64 114 127 torch.Size([132, 132]) torch.Size([64, 64])
138 64 70 108 torch.Size([138, 138]) torch.Size([64, 64])
206 64 105 138 torch.Size([206, 206]) torch.Size([64, 64])
367 64 

KeyboardInterrupt: 

In [None]:
# Debug: size of the most recent `masks` batch left over from the training cell
masks.size()

In [None]:
# Debug: size of the most recent `labels` batch left over from the training cell
labels.size()

In [None]:
# Debug: size of the most recent `inputs` batch (after the channel-first permute in the training cell)
inputs.size()

In [None]:
# Debug: display the model's module structure (its repr) as the cell output
model