In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from dataloader import dataset
from unet import UNet
import os
import time

In [2]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

In [3]:
# Generate training, validation and test datasets
# random split
train_set_size = int(len(dataset)*0.6)
valid_set_size = int(len(dataset)*0.2)
test_set_size = len(dataset)-train_set_size-valid_set_size
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, [train_set_size, valid_set_size, test_set_size])

# start to load from customized DataSet object
# bring batch size, iterations and epochs together
batch_size = 5
train_loader = DataLoader(train_set,batch_size,shuffle=False,drop_last=True,pin_memory=False)
valid_loader = DataLoader(valid_set,batch_size,shuffle=False,drop_last=True,pin_memory=False)
test_loader = DataLoader(test_set,batch_size,shuffle=False,drop_last=True,pin_memory=False)

In [None]:
print(train_set_size, valid_set_size, test_set_size)

In [4]:
for x,y in train_set:
    print(x.size())
    print(y.size())
    break

torch.Size([1024, 1024, 3])
torch.Size([512, 512, 3])


In [5]:
model = UNet(in_channels=3,
            out_channels=3,
            n_blocks=4,
            start_filters=32,
            activation='relu',
            normalization='batch',
            conv_mode='same',
            dim=2)
# x = torch.randn(size=(3, 3, 1024,1024), dtype=torch.float32)
# with torch.no_grad():
#     out = model(x)
    
# print(f'Out: {out.shape}')

# outpushape: [batch, channel, H, W]
# inputshape: [batch, channel, 1024, 1024]

In [None]:
# x = torch.randn(size=(3, 3, 1024,1024), dtype=torch.float32)

print(x.view(x.shape[0]*x.shape[2]*x.shape[3], x.shape[1]).shape)

In [None]:
x = torch.randn(size=(5, 3, 1024,1024), dtype=torch.float32)
x = x.reshape(5*3*1024, 1024)
print(x.shape)

In [6]:
batch_size = 5
n_iters = int(train_set_size / batch_size)
epochs = 10
lr = 0.001
iterations = epochs * n_iters
step_size = 2*n_iters
model_name ="{}epochs_lr{}_step{}".format(epochs, lr, step_size)
save_PATH = './model_name'
if not os.path.exists(save_PATH):
    os.mkdir(save_PATH)

In [7]:
model = model.float()
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr)

In [8]:
def acc_fn(outputs, y):
    num_samples = len(y)
    equality_matrix = torch.eq(outputs, y)
    num_corr_pred = equality_matrix.sum()
    acc = num_corr_pred/num_samples
    return acc    

In [11]:
def train(model, train_dl, valid_dl, loss_fn, optimizer, acc_fn, epochs=1):
    start = time.time()
    # model.cuda()

    train_loss, valid_loss = [], []

    best_acc = 0.0

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train(True)  # Set trainind mode = true
                dataloader = train_dl
            else:
                model.train(False)  # Set model to evaluate mode
                dataloader = valid_dl

            running_loss = 0.0
            running_acc = 0.0

            step = 0

            # iterate over data
            for x, y in dataloader:
                # x = x.cuda()
                # y = y.cuda()
                x = torch.permute(x, (0, 3, 2, 1))
                step += 1

                # forward pass
                if phase == 'train':
                    # zero the gradients
                    optimizer.zero_grad()
                    outputs = model(x) 
                    outputs = torch.permute(outputs, (0,2,3,1)) #true y: [512,512,3]  outputs: [batch,512,512,3]
                    batch, h, w, channel = outputs.shape
                    outputs_loss = outputs.reshape(batch, h*w, channel)
                    y = y.reshape(batch, h*w, channel)
                    # outputs = torch.permute(outputs, (0, 2, 3, 1)).contiguous().view(-1, inp.size(1))
                    
                    # y = torch.permute(y, (0,3,2,1))
                    y = y.to(torch.float64)
                    outputs = outputs.to(torch.float64)
                    # print(outputs.shape)
                    # print(y.shape)
                    
                    print('outputs', outputs, outputs.shape)
                    #plt.imshow(outputs[0].detach().numpy())
                    #plt.imshow(outputs[1].detach().numpy())
                    #plt.imshow(outputs[2].detach().numpy())
                    #plt.imshow(outputs[3].detach().numpy())
                    #plt.imshow(outputs[4].detach().numpy())
                    
                    loss = loss_fn(outputs_loss, y)

                    # the backward pass frees the graph memory, so there is no 
                    # need for torch.no_grad in this training pass
                    loss.backward()
                    optimizer.step()
                    # scheduler.step()

                else:
                    with torch.no_grad():
                        outputs = model(x)
                        loss = loss_fn(outputs, y.long())

                # stats - whatever is the phase
                acc = acc_fn(outputs_loss, y)

                running_acc  += acc*dataloader.batch_size
                running_loss += loss*dataloader.batch_size 

                if step % 10 == 0:
                    # clear_output(wait=True)
                    print('Current step: {}  Loss: {}  Acc: {}  AllocMem (Mb): {}'.format(step, loss, acc, torch.cuda.memory_allocated()/1024/1024))
                    # print(torch.cuda.memory_summary())

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_acc / len(dataloader.dataset)

            print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))
            
            return train_loss, valid_loss 

In [12]:
train_loss, valid_loss = train(model, train_loader, valid_loader, loss_fn, opt, acc_fn, epochs)

Epoch 0/9
----------
outputs tensor([[[[-5.8505e-01, -1.1378e+00,  1.4093e+00],
          [ 1.2273e+00,  1.9088e-01,  1.5955e+00],
          [-9.9683e-01,  1.5805e-01,  1.0955e+00],
          ...,
          [-9.9308e-01,  2.0188e+00,  5.5407e-01],
          [-3.1343e+00,  4.3586e+00,  1.2731e+00],
          [-9.0895e-01,  5.0278e+00, -5.2059e+00]],

         [[-8.3801e-01, -2.8667e-01,  5.6198e-01],
          [ 9.3627e-01,  1.1941e+00,  6.3337e-01],
          [ 3.9506e-02,  4.1331e-01,  9.6568e-01],
          ...,
          [-4.2324e-01,  5.8929e-01,  4.3264e-02],
          [-3.0517e+00, -1.4178e+00, -2.4595e+00],
          [ 1.8903e+00,  5.4882e-01,  2.8891e+00]],

         [[-1.2074e+00, -1.5344e+00,  1.3114e-02],
          [ 1.4168e-01,  1.5370e+00,  6.0560e-01],
          [ 5.5301e-01,  5.0754e-01,  9.0780e-01],
          ...,
          [ 2.9280e-01, -5.6453e-02,  5.6029e-01],
          [-1.2258e+00,  1.2263e+00, -2.5831e+00],
          [ 1.9738e+00, -2.1050e-01, -1.2759e+00]],

  

outputs tensor([[[[-4.3284e-02, -2.9173e-01,  1.2375e+00],
          [ 1.1514e-01, -4.6910e-01,  1.3611e+00],
          [-7.2576e-01,  1.2378e+00,  1.0118e+00],
          ...,
          [-1.6640e-01, -1.0759e-01,  1.2702e+00],
          [-5.2362e-01,  3.1889e-01,  1.3962e+00],
          [-1.2719e-01,  7.8630e-03,  5.0667e-01]],

         [[ 4.0836e-01, -1.5536e+00,  1.0312e+00],
          [ 1.3717e+00,  1.2009e-01,  1.3830e+00],
          [-2.7454e-01, -3.5555e-01,  8.9975e-01],
          ...,
          [ 1.5018e+00,  7.7916e-01,  7.5234e-01],
          [-6.1641e-01,  1.4460e+00,  3.1860e-01],
          [ 3.2103e+00,  9.1638e-01,  8.2042e-01]],

         [[ 5.1141e-02,  5.4269e-01,  6.0577e-02],
          [ 1.7658e-01,  1.4535e+00,  1.0018e+00],
          [ 1.0314e+00,  1.0282e+00,  7.6391e-01],
          ...,
          [ 8.0919e-01,  9.0320e-01,  1.4864e+00],
          [-9.8085e-01,  1.6679e-01,  1.2851e+00],
          [ 1.3613e+00,  9.0692e-02,  5.1384e-02]],

         ...,

        

outputs tensor([[[[-7.5157e-01, -2.2254e+00,  1.1690e+00],
          [-1.8992e+00,  4.9812e-01,  1.8153e+00],
          [-5.9745e-01,  7.0572e-01,  1.6837e+00],
          ...,
          [-8.8400e-01, -8.3291e-01,  9.8652e-01],
          [-1.3968e+00, -2.0997e-01,  1.2097e+00],
          [-4.9766e-01,  8.2164e-01, -1.4245e+00]],

         [[-6.2028e-01, -1.7856e+00,  2.4366e-01],
          [ 1.7031e+00, -7.4378e-01,  1.4144e+00],
          [-3.5323e-01, -6.6269e-01,  3.1606e-01],
          ...,
          [ 1.0699e+00, -7.9975e-01,  3.3898e-01],
          [-6.0093e-01, -1.6591e+00,  3.6844e-02],
          [ 2.2774e+00,  1.5588e-01,  1.9341e+00]],

         [[ 9.8410e-01,  1.0557e-02,  7.4741e-01],
          [ 1.2718e+00,  9.7967e-02,  5.6120e-01],
          [ 5.4894e-01, -1.9375e-01,  1.3535e+00],
          ...,
          [ 1.0003e+00, -6.6289e-01,  1.9790e+00],
          [ 4.6208e-01, -4.6404e-01,  4.4503e-01],
          [ 1.1806e+00,  1.1847e+00, -6.1336e-01]],

         ...,

        

outputs tensor([[[[-5.2026e-01, -9.1005e-01,  1.5757e+00],
          [ 6.9622e-01, -7.8411e-02,  9.5940e-01],
          [-9.8251e-01, -1.0984e+00,  9.6973e-01],
          ...,
          [-1.6827e+00, -9.5753e-01,  9.8759e-02],
          [-3.3406e-01, -1.0172e+00,  1.2824e+00],
          [-1.7263e+00,  3.4716e-01, -2.2699e+00]],

         [[-7.7410e-02, -7.6716e-01,  3.4278e-01],
          [-2.2523e-01, -1.0129e+00,  2.4633e-01],
          [ 2.1343e+00,  3.8324e-01,  5.0914e-01],
          ...,
          [-7.5993e-01, -8.7411e-01, -6.1805e-01],
          [-1.1287e+00, -1.2403e+00, -7.3124e-01],
          [ 6.5866e-01,  1.7975e+00,  1.2730e+00]],

         [[ 7.7157e-01, -1.6984e-01,  2.7727e-01],
          [ 1.9672e+00,  1.1515e-01, -8.7456e-01],
          [ 1.3287e-01,  1.5645e+00,  1.3267e+00],
          ...,
          [-9.0723e-01, -1.8799e+00,  4.0426e-01],
          [ 2.2619e-01, -2.1397e+00,  1.6799e-01],
          [ 2.1066e+00,  3.8282e-01, -9.3104e-01]],

         ...,

        

outputs tensor([[[[-2.8032e+00, -9.3115e+00, -2.2731e+00],
          [-3.6443e+00,  1.7423e+00,  1.2176e+00],
          [-5.0003e+00,  4.3075e-01,  1.8261e+00],
          ...,
          [-2.9134e+00, -1.9983e+00,  1.1232e+00],
          [-5.4696e+00, -1.6695e-01,  1.5066e+00],
          [-7.9989e+00, -1.5340e-01, -7.0900e+00]],

         [[ 9.8069e-02, -5.7148e+00, -3.6118e+00],
          [-1.7972e+00, -9.6751e+00, -1.0132e+01],
          [-3.9976e+00, -3.5400e+00, -2.3645e+00],
          ...,
          [-5.2817e+00, -6.7049e+00, -2.9338e+00],
          [-6.5624e+00, -8.1832e+00, -5.2671e+00],
          [-4.2054e+00, -1.6501e+00, -7.9063e-02]],

         [[ 3.5840e-01, -4.0754e+00, -3.2471e+00],
          [-1.0190e+01, -7.3998e+00, -1.0416e+01],
          [-2.4024e+00, -7.5130e+00, -2.8422e+00],
          ...,
          [-2.8968e+00, -4.3649e+00, -1.3433e+00],
          [-2.3552e+00, -9.8526e+00, -3.6941e+00],
          [-2.6787e+00, -2.7591e+00, -1.8203e+00]],

         ...,

        

outputs tensor([[[[-9.9646e-01, -4.0734e+00, -2.6037e-01],
          [-1.3291e+00,  2.6608e-01,  5.6271e-01],
          [-2.0373e+00, -8.1481e-01,  1.3730e+00],
          ...,
          [-2.6768e+00, -8.8204e-01,  1.0620e+00],
          [-2.4946e+00, -5.2452e-01,  1.2378e+00],
          [-3.0392e+00,  2.5253e-01, -3.4780e+00]],

         [[-9.5020e-02, -4.2589e+00, -2.3126e+00],
          [ 9.8942e-01, -2.7794e+00, -2.8756e+00],
          [-1.2935e+00,  4.7374e-01, -8.0156e-02],
          ...,
          [-2.8794e+00, -2.5166e+00, -1.1739e+00],
          [-3.7177e+00, -3.1441e+00, -3.3454e+00],
          [-1.7787e+00, -3.6209e-02,  1.2694e+00]],

         [[ 1.6143e+00, -2.0971e+00, -1.0137e+00],
          [-8.4333e-01, -1.0894e+00, -4.1902e+00],
          [-8.8146e-01, -1.0927e+00,  8.5397e-01],
          ...,
          [-2.0741e+00, -1.9837e+00, -1.4711e+00],
          [-7.1173e-01, -3.8887e+00,  4.5880e-01],
          [ 8.9842e-01, -1.9974e+00, -1.4961e+00]],

         ...,

        

outputs tensor([[[[-3.9927e+00, -1.1895e+01, -3.5175e+00],
          [-6.0651e+00,  5.2868e-01,  1.8704e+00],
          [-5.7281e+00, -1.5714e+00,  1.5605e+00],
          ...,
          [-4.1005e+00, -2.9570e+00,  2.3486e-01],
          [-5.9084e+00, -2.4525e+00,  1.2496e+00],
          [-8.3396e+00, -1.2031e+00, -4.2909e+00]],

         [[-3.1884e-01, -6.4662e+00, -5.8194e+00],
          [-1.7738e+00, -1.2941e+01, -1.5530e+01],
          [-5.5118e+00, -6.4045e+00, -4.7407e+00],
          ...,
          [-6.1121e+00, -6.7912e+00, -3.0403e+00],
          [-9.6501e+00, -7.6150e+00, -6.9377e+00],
          [-5.8674e+00, -1.6870e+00,  2.5375e+00]],

         [[-2.1158e+00, -5.4683e+00, -5.4169e+00],
          [-1.1229e+01, -1.0602e+01, -1.7380e+01],
          [-4.9407e+00, -9.2554e+00, -4.3498e+00],
          ...,
          [-3.3810e+00, -6.1326e+00, -4.7113e-01],
          [-3.6217e+00, -8.7985e+00,  1.5952e-01],
          [-2.1829e+00, -4.7891e+00, -1.4365e+00]],

         ...,

        

outputs tensor([[[[-4.6434e-01, -3.7082e+00, -3.4926e-01],
          [-1.5882e+00,  7.3947e-01,  2.8380e-01],
          [-3.2358e+00, -6.2499e-01,  7.4279e-01],
          ...,
          [ 3.7281e-01, -1.6209e-01,  1.0034e+00],
          [-6.4168e-01, -2.4224e-01,  4.3198e-01],
          [-3.4307e-04, -7.1980e-01, -4.6700e-03]],

         [[-8.6675e-01, -2.6065e+00, -2.5340e+00],
          [ 9.4246e-01, -2.5089e+00, -4.8264e+00],
          [-2.1505e+00, -8.4737e-01, -5.7347e-01],
          ...,
          [ 1.3611e+00,  1.0926e+00,  1.0736e+00],
          [ 3.4474e-01,  2.5327e-01,  9.3915e-01],
          [ 2.3221e+00,  9.9189e-01,  1.2616e+00]],

         [[ 1.9014e-01, -1.1111e+00, -1.5239e+00],
          [-2.5169e+00, -1.6408e+00, -6.1689e+00],
          [-2.3829e+00, -3.0913e+00, -1.6288e+00],
          ...,
          [-4.9227e-01,  1.2478e+00,  1.5013e-01],
          [ 1.8489e+00,  1.4774e+00,  6.5997e-01],
          [ 1.8697e+00,  1.1624e+00,  6.7041e-01]],

         ...,

        

outputs tensor([[[[-3.6378e-01, -1.4077e+00,  1.1711e+00],
          [ 6.4539e-01,  2.5299e-01,  9.3547e-01],
          [ 3.5280e-01,  3.0837e-01,  1.4409e+00],
          ...,
          [ 3.3706e-01,  2.6637e-01,  7.1249e-01],
          [ 8.7036e-02,  2.3496e-01,  1.0392e+00],
          [ 1.1591e+00,  2.8972e-01,  6.7855e-01]],

         [[-4.8885e-01, -4.3641e-01, -2.9532e-01],
          [ 1.4055e+00,  4.8818e-01,  1.1411e+00],
          [ 1.3232e+00,  1.0695e+00,  2.0730e+00],
          ...,
          [ 2.5602e+00,  1.3651e+00,  1.1389e+00],
          [ 1.7027e+00,  1.3845e+00,  1.4441e+00],
          [ 1.9065e+00,  1.2463e+00,  3.4794e-01]],

         [[-3.3751e-02, -3.5692e-01,  2.4275e-01],
          [ 5.9298e-01,  4.1956e-01,  9.2612e-01],
          [ 1.0109e+00,  9.1336e-01,  9.9144e-01],
          ...,
          [ 2.3807e+00,  1.7067e+00,  5.4886e-01],
          [ 8.8623e-01,  1.1757e+00,  6.8001e-01],
          [ 1.4854e+00,  2.3997e-01,  1.8351e-02]],

         ...,

        

outputs tensor([[[[-3.8271e-01, -2.2116e+00,  4.5771e-01],
          [-1.0017e+00, -8.4963e-02,  4.6972e-02],
          [-1.2446e+00, -3.4466e-01,  8.2070e-01],
          ...,
          [ 2.5816e-01, -1.3723e-01,  1.9843e-01],
          [-1.6042e+00, -8.3394e-01,  1.0443e-01],
          [-3.4165e-01, -1.0268e+00,  2.5218e-01]],

         [[-7.9596e-01, -1.7285e+00, -7.2501e-01],
          [ 1.0356e+00, -1.3918e+00, -2.4212e+00],
          [-5.0273e-01, -3.2476e-01, -3.6107e-01],
          ...,
          [ 9.4058e-01, -5.4449e-01,  1.4867e+00],
          [-1.0498e-01, -2.1972e-01,  9.7261e-01],
          [ 1.7544e+00, -1.1896e+00,  1.1664e+00]],

         [[-7.5289e-01, -3.1326e-01, -2.5814e-01],
          [-1.2154e+00,  7.1191e-01, -2.8250e+00],
          [-5.8395e-01, -5.8823e-01,  4.6894e-01],
          ...,
          [-8.8188e-01,  7.1354e-01, -7.3040e-01],
          [-1.6595e-01, -1.0175e+00,  5.4404e-01],
          [ 8.3249e-01, -1.1762e+00,  1.4393e+00]],

         ...,

        

Current step: 10  Loss: 1008042.7164358139  Acc: 0.0  AllocMem (Mb): 0.0
outputs tensor([[[[-3.3488e+00, -9.7410e+00, -1.2963e+00],
          [-3.7755e+00,  2.7221e-01,  1.1559e+00],
          [-3.1712e+00, -1.0527e+00,  2.2804e-01],
          ...,
          [-2.5067e+00, -1.0372e+00,  7.7753e-01],
          [-4.7462e+00, -1.3810e+00,  3.6900e-01],
          [-4.3763e+00, -1.6166e+00, -2.4807e+00]],

         [[-3.2464e+00, -6.5335e+00, -2.6949e+00],
          [-1.8744e+00, -5.4826e+00, -5.3767e+00],
          [-2.9449e+00, -3.8485e+00, -2.4440e+00],
          ...,
          [-1.0986e+00, -1.2638e+00, -8.4933e-01],
          [-2.1388e+00, -1.9019e+00, -2.0023e+00],
          [-5.1913e-01, -2.3761e-01,  2.3176e+00]],

         [[-1.4475e-01, -3.7463e+00, -1.6167e+00],
          [-4.1769e+00, -4.0536e+00, -9.9920e+00],
          [-3.1960e+00, -3.3235e+00,  1.0320e+00],
          ...,
          [-3.0471e-01, -7.8999e-01,  4.7404e-01],
          [-1.1919e+00, -1.2337e+00, -2.0401e-01],
   

outputs tensor([[[[-3.6089e+00, -8.1317e+00, -1.1129e+00],
          [-4.3493e+00, -3.2686e-01,  1.3627e+00],
          [-3.0361e+00, -1.8190e+00,  1.4039e+00],
          ...,
          [-6.6697e-01, -3.5359e-01,  1.1863e+00],
          [-1.6451e+00, -6.4367e-01,  2.6582e-01],
          [-1.0447e+00, -8.5828e-01, -3.0555e-01]],

         [[-2.3631e+00, -6.0536e+00, -1.2679e+00],
          [-1.7436e+00, -7.3621e+00, -7.4343e+00],
          [-3.6674e+00, -3.4249e+00, -8.5804e-01],
          ...,
          [ 1.2704e+00,  7.2555e-01,  1.0586e+00],
          [-1.0003e+00, -6.5445e-01, -1.2262e-01],
          [ 1.0363e+00,  5.4243e-02,  2.4493e+00]],

         [[-7.4607e-01, -3.4242e+00, -1.1922e+00],
          [-6.5909e+00, -5.5375e+00, -1.0577e+01],
          [-2.1380e+00, -4.9389e+00, -1.0205e+00],
          ...,
          [ 1.1898e-01,  5.8987e-01,  9.6102e-01],
          [ 4.3649e-01,  1.3527e-01,  8.8665e-01],
          [ 1.3942e+00,  5.6414e-01,  7.0636e-01]],

         ...,

        

outputs tensor([[[[-2.9809e+00, -6.6875e+00, -1.2459e+00],
          [-4.6329e+00, -5.7075e-01,  1.7375e+00],
          [-2.8137e+00, -1.3435e+00,  1.7276e+00],
          ...,
          [-5.1482e-01, -3.0109e-01,  1.6183e+00],
          [-7.9708e-01, -8.8568e-01,  1.0200e+00],
          [-3.4343e-01, -7.3069e-01, -2.0692e-01]],

         [[-3.4706e+00, -6.1748e+00, -4.7943e-02],
          [-1.6763e+00, -6.1519e+00, -4.3216e+00],
          [-2.8887e+00, -4.1932e+00,  4.0518e-02],
          ...,
          [ 1.2694e+00,  1.9586e-01,  8.5539e-01],
          [-7.4458e-01,  1.8240e-01,  5.9624e-01],
          [ 1.3182e+00,  1.2549e+00,  6.2376e-01]],

         [[-1.3959e+00, -2.2462e+00,  1.0458e+00],
          [-4.3548e+00, -6.1045e+00, -7.8903e+00],
          [-2.2691e+00, -4.8085e+00,  7.1599e-01],
          ...,
          [ 4.8096e-01,  1.0700e+00,  9.9497e-01],
          [ 5.9966e-01,  9.8696e-01, -3.9677e-01],
          [ 8.2970e-01,  5.9845e-01,  1.2377e-01]],

         ...,

        

outputs tensor([[[[-1.0939e+00, -3.5993e+00, -4.7479e-01],
          [-6.5242e-01, -2.9268e-01,  5.2842e-01],
          [-2.0204e+00, -1.4536e+00,  8.0757e-01],
          ...,
          [-2.9838e-01, -9.2746e-02,  1.6739e+00],
          [-9.0621e-01, -4.7034e-01,  1.0459e+00],
          [ 3.2884e-01, -2.8447e-01, -7.6597e-02]],

         [[-2.2713e+00, -1.9637e+00, -1.9856e+00],
          [ 3.8560e-01, -8.9903e-01, -2.9988e+00],
          [-1.5845e+00, -8.6694e-01, -4.0110e-01],
          ...,
          [ 1.0981e+00,  4.8032e-01,  7.7737e-01],
          [-1.2026e-01,  8.3708e-01, -3.4078e-02],
          [ 1.5372e+00,  9.4855e-01,  3.6874e-01]],

         [[-1.2679e+00, -1.6537e+00, -1.6171e+00],
          [-1.3712e+00, -1.8455e-01, -3.0485e+00],
          [-1.6189e+00, -2.5627e+00, -3.3842e-01],
          ...,
          [ 3.1644e-01,  1.0790e+00, -1.5686e-01],
          [ 6.1744e-01,  4.3824e-01,  9.5445e-01],
          [ 5.8242e-01,  7.5404e-01, -5.7033e-01]],

         ...,

        

KeyboardInterrupt: 