In [2]:
from crops import *
import torch

In [3]:
train_domains = np.loadtxt('../../data/our_input/train_domains.csv', dtype='O')

#X = torch.load(f'../../data/our_input/tensors/{train_domains[0]}_X.pt')
#Y = torch.load(f'../../data/our_input/Y_tensors/{train_domains[0]}_Y.pt')

X = torch.load(f'../../data/our_input/tensors/4nb5B02_X.pt')
Y = torch.load(f'../../data/our_input/Y_tensors/4nb5B02_Y.pt')
X.shape

torch.Size([569, 64, 64])

In [18]:
def pad_1d(input_1d, crop_size=64, random_state=1):
    """
    Pads 1D input
    
    Input:
        input_1d    : 1D torch tensor
        random_state: int
        
    Output:
        padded1D: list of tuples of padded 1D (i, j) inputs for each crop 
    """
    L = len(input_1d)
    crop_indices = make_crop_indices(L, random_state=random_state)
    padded = []
    
    if L < 64:
        np.random.seed(random_state)
        offset_range = np.arange(crop_size - L + 1)
        i_offset, j_offset = np.random.choice(offset_range), np.random.choice(offset_range)
        
        cropped_i = torch.cat((
            torch.zeros(i_offset, dtype=torch.long),
            input_1d.to(torch.long),
            torch.zeros(crop_size - L - i_offset, dtype=torch.long)
        ))
        
        cropped_j = torch.cat((
            torch.zeros(j_offset, dtype=torch.long),
            input_1d.to(torch.long),
            torch.zeros(crop_size - L - j_offset, dtype=torch.long)
        ))
        padded.append((cropped_i.view(1, 64), cropped_j.view(1, 64)))
    else:
        for ci in crop_indices:
            i0, imax = ci[0][0], ci[1][0]
            j0, jmax = ci[0][1], ci[1][1]
            padding = ci[2]
            
            # crop
            cropped_i, cropped_j = input_1d[i0:imax], input_1d[j0:jmax]
            
            # pad
            
            i_padding, j_padding = crop_size - (imax - i0), crop_size - (jmax - j0)
            
            if padding == 'topleft':
                # pad both from left
                cropped_i = torch.cat((torch.zeros(i_padding, dtype=torch.long), cropped_i.to(torch.long)))
                cropped_j = torch.cat((torch.zeros(j_padding, dtype=torch.long), cropped_j.to(torch.long)))
                
            elif padding == 'top':
                # pad only i from left
                cropped_i = torch.cat((torch.zeros(i_padding, dtype=torch.long), cropped_i.to(torch.long)))
                cropped_j = cropped_j.to(torch.long)
            elif padding == 'topright':
                # pad both i from left and j from right
                cropped_i = torch.cat((torch.zeros(i_padding, dtype=torch.long), cropped_i.to(torch.long)))
                cropped_j = torch.cat((cropped_j.to(torch.long), torch.zeros(j_padding, dtype=torch.long)))
                
            elif padding == 'left':
                # pad only j from left
                cropped_i = cropped_i.to(torch.long)
                cropped_j = torch.cat((torch.zeros(j_padding, dtype=torch.long), cropped_j.to(torch.long)))
                
            elif padding is None:
                cropped_i = cropped_i.to(torch.long)
                cropped_j = cropped_j.to(torch.long)
            
            elif padding == 'right':
                # pad only j from right
                cropped_i = cropped_i.to(torch.long)
                cropped_j = torch.cat((cropped_j.to(torch.long), torch.zeros(j_padding, dtype=torch.long)))
                
            elif padding == 'bottomleft':
                # i from right and j from left
                cropped_i = torch.cat((cropped_i.to(torch.long), torch.zeros(i_padding, dtype=torch.long)))
                cropped_j = torch.cat((torch.zeros(j_padding, dtype=torch.long), cropped_j.to(torch.long)))
                
            elif padding == 'bottom':
                # pad only i from right
                cropped_i = torch.cat((cropped_i.to(torch.long), torch.zeros(i_padding, dtype=torch.long)))
                cropped_j = cropped_j.to(torch.long)
            elif padding == 'bottomright':
                # pad both from right
                cropped_i = torch.cat((cropped_i.to(torch.long), torch.zeros(i_padding, dtype=torch.long)))
                cropped_j = torch.cat((cropped_j.to(torch.long), torch.zeros(j_padding, dtype=torch.long)))
                
            padded.append((cropped_i.view(1, 64), cropped_j.view(1, 64)))
    return padded   

In [19]:
def _make_batches(X, Y, c=64, random_state=1):
    """Function should return input and output of shapes:
    input:  (crops, 675, 64, 64)
    output: (crops, 64, 64)
    """
    d_map, secondary, phi, psi = Y
    
    Ch, L = X.shape[0], X.shape[1]
    d_map = d_map.reshape((1, L, L))
    
    padded_sec = pad_1d(secondary, random_state=random_state)
    padded_phi = pad_1d(psi, random_state=random_state)
    padded_psi = pad_1d(psi, random_state=random_state)
    
    if L < c:
        output_batches = np.empty((1, c + 6, c))
        
        input_batches, output_batches0 = pad_crop(X, 'all', d_map, random_state=random_state)
               
        # add auxiliary losses to the output
        output_batches[:, :64, :] = output_batches0
        output_batches[:, 64:, :] = torch.cat((padded_sec[0][0], padded_sec[0][1], 
                              padded_phi[0][0], padded_phi[0][1],
                              padded_psi[0][0], padded_psi[0][1]
                             ))
        
        return torch.from_numpy(input_batches).to(torch.float32), torch.from_numpy(output_batches).to(torch.long)
    
    else:
        crop_indices = make_crop_indices(L, c=c, random_state=random_state)

        input_batches = np.empty((len(crop_indices), Ch, c, c))
        output_batches = np.empty((len(crop_indices), c + 6, c))

        for m in range(len(crop_indices)):
            (i0, j0), (i, j), padding = crop_indices[m]
            input_batches[m, :, :, :] = pad_crop(X[:, i0:i, j0:j], padding, c)
            output_batches[m, :64, :] = pad_crop(d_map[:, i0:i, j0:j], padding, c)
            
            # add auxiliary losses to the output 
            output_batches[m, 64:, :] = torch.cat((
                padded_sec[m][0], padded_sec[m][1], 
                padded_phi[m][0], padded_phi[m][1],
                padded_psi[m][0], padded_psi[m][1]
            ))      
        
        return torch.from_numpy(input_batches).to(torch.float32), torch.from_numpy(output_batches).to(torch.long)

In [20]:
d_map, secondary, phi, psi = Y

In [21]:
i, o = _make_batches(X, Y, random_state=0)

In [27]:
o[0, 69]

tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 23, 15, 16, 14,
        15, 14, 15, 15, 15, 15, 14, 14, 14, 15, 14, 15, 15, 15, 15, 15, 15, 15,
        14, 15, 15, 15, 16, 20, 15, 20, 35, 33])

In [29]:
o[0, 69:]

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 23, 15, 16, 14,
         15, 14, 15, 15, 15, 15, 14, 14, 14, 15, 14, 15, 15, 15, 15, 15, 15, 15,
         14, 15, 15, 15, 16, 20, 15, 20, 35, 33]])

In [175]:
%debug

> [0;32m<ipython-input-172-515c50a32455>[0m(21)[0;36m_make_batches[0;34m()[0m
[0;32m     19 [0;31m[0;34m[0m[0m
[0m[0;32m     20 [0;31m        [0;31m# add auxiliary losses to the output[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 21 [0;31m        padded_Y = torch.cat((torch.from_numpy(output_batches).to(torch.long),
[0m[0;32m     22 [0;31m                              [0mpadded_sec[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m,[0m [0mpadded_sec[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m                              [0mpadded_phi[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m,[0m [0mpadded_phi[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  padded_sec[0][0].shape


torch.Size([1, 1, 64])


ipdb>  c


In [13]:
o[0]

tensor([[ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        ...,
        [ 0,  0,  0,  ..., 20, 35, 33],
        [ 0,  0,  0,  ..., 20, 35, 33],
        [ 0,  0,  0,  ..., 20, 35, 33]])

In [104]:
X.shape

torch.Size([569, 136, 136])

In [192]:
s = pad_1d(secondary)

In [193]:
torch.cat((s[0][0], s[0][1])).shape

torch.Size([2, 64])

In [119]:
make_crop_indices(130)

array([[(0, 0), (33, 32), 'topleft'],
       [(0, 32), (33, 96), 'top'],
       [(0, 96), (33, 130), 'topright'],
       [(33, 0), (97, 32), 'left'],
       [(33, 32), (97, 96), None],
       [(33, 96), (97, 130), 'right'],
       [(97, 0), (130, 32), 'bottomleft'],
       [(97, 32), (130, 96), 'bottom'],
       [(97, 96), (130, 130), 'bottomright']], dtype=object)

In [113]:
i, o = make_batches(X, d_map)

In [125]:
o[0]

tensor([[ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        ...,
        [ 0,  0,  0,  ..., 14, 16, 10],
        [ 0,  0,  0,  ..., 10, 14, 11],
        [ 0,  0,  0,  ...,  7,  6,  5]])

In [143]:
torch.cat((o[0], s[0][0]))

tensor([[ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        [ 0,  0,  0,  ...,  0,  0,  0],
        ...,
        [ 0,  0,  0,  ..., 10, 14, 11],
        [ 0,  0,  0,  ...,  7,  6,  5],
        [ 0,  0,  0,  ...,  1,  1,  8]])

In [1]:
s[0][0].shape

NameError: name 's' is not defined

In [46]:
torch.cat((torch.zeros(8, dtype=torch.long), aa.to(torch.long)))

tensor([0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 4, 4, 4, 8, 7, 1, 1, 1, 1, 1, 1, 1, 1,
        6, 6, 8, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])