In [2]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
ISICDATA = '/content/drive/MyDrive/ISIC'

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
from skimage import io
from PIL import Image
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

DataSet Implementation

In [5]:
class FamilyHistoryDataSet(torch.utils.data.Dataset):
  def __init__(self, csv_file, root_dir, transforms=None):
    self.annotations = pd.read_csv(os.path.join(root_dir, csv_file))
    self.root_dir = root_dir
    self.transforms = transforms

  def __len__(self):
    return len(self.annotations)

  def __getitem__(self, index):
    img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
    image = Image.open(img_path)
    y_label = torch.tensor(int(self.annotations.iloc[index, 1]))

    if self.transforms:
      image = self.transforms(image)

    return (image, y_label)

  def get_splits(self, splits=[0.8, 0.2]):
    train_split = round(len(self.annotations)*splits[0])
    test_split = len(self.annotations) - train_split
    return (train_split, test_split)

Hyperparameters, Dataloader & Split

In [14]:
learning_rate = 1e-3
batch_size = 64
num_epochs = 1
img_crop_size = 85
# model params
n_classes = 2
in_features = 3
# data
ISIC_MEAN = [1.2721, 0.3341, -0.0479]
ISIC_STD = [0.2508, 0.2654, 0.3213]

dataset = FamilyHistoryDataSet(csv_file= 'family_history.csv', root_dir = ISICDATA,
                               transforms=T.Compose(
                                  [T.CenterCrop(img_crop_size),
                                   T.ToTensor(),
                                   T.Normalize(ISIC_MEAN, ISIC_STD)]))

# Test on 1/5 of data
train_split, test_split = dataset.get_splits()
train_set, test_set = torch.utils.data.random_split(dataset, [train_split, test_split])
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=True, pin_memory=True)


Model

In [15]:
class CNN(nn.Module):
    def __init__(self, n_classes, in_features):
        super().__init__()
        self.conv1 = nn.Conv2d(in_features, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(6480, 32)
        self.fc2 = nn.Linear(32, n_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=True)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = CNN(n_classes, in_features)
model.to(device)

CNN(
  (conv1): Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=6480, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=2, bias=True)
)

Loss + Optimization Algorithm

Training

In [18]:
def basic_training_loop(loader, model, loss_func, optimizer, device):
    for batch_idx, (data, labels) in tqdm(enumerate(loader), total=len(loader), leave=False):
        data = data.to(device)
        labels = labels.to(device)

        scores = model(data)
        loss = loss_func(scores, labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch_idx > 20:
          break

Evaluation

In [13]:
def basic_validation(loader, model, loss_func, device):
    num_correct, num_samples = 0, 0
    test_loss = 0
    model.eval()

    with torch.no_grad():
        for batch_idx, (x, y) in tqdm(enumerate(loader), total=len(loader), leave=False):
            print(x)
            print(y)
            x = x.to(device=device)
            y = y.to(device=device)

            pred = model(x)
            test_loss += loss_func(pred, y).item()
            _, predictions = pred.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
            # if test_loss != 0:
            #     break
        print(
            f"Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}"
        )

    model.train()

In [17]:
class OptimizationLoop:
    def __init__(self, params) -> None:
        self.n_epochs = params['n_epochs']
        self.training = params['train_loop']
        self.validation = params['validation_loop']
        
        self.model = params['model']
        self.train_loader = params['train_loader']
        self.test_loader = params['test_loader']
        self.loss_func = params['loss']
        self.optimizer = params['optim']
        self.device = params['device']

    def optimize(self) -> None:
        for _ in range(self.n_epochs):
            self.training(
               self.train_loader, self.model,
               self.loss_func, self.optimizer, self.device)
            self.validation(
                self.test_loader, self.model,
                self.loss_func, self.device)

In [19]:
params = {
    'n_epochs': num_epochs,
    'train_loop': basic_training_loop,
    'validation_loop': basic_validation,
    'model': model,
    'train_loader': train_loader,
    'test_loader': test_loader,
    'loss': nn.CrossEntropyLoss(),
    'optim': optim.SGD(model.parameters(), lr=learning_rate),
    'device': device
}
optim_loop = OptimizationLoop(params)
optim_loop.optimize()

  6%|▌         | 1/17 [01:08<18:18, 68.63s/it]

tensor([[[[-1.1631, -1.1475, -1.2257,  ..., -1.0849, -1.0849, -1.0849],
          [-1.1006, -1.0849, -1.2100,  ..., -1.0849, -1.0849, -1.0849],
          [-1.1631, -1.1006, -1.1787,  ..., -1.2100, -1.2100, -1.2257],
          ...,
          [-1.2100, -1.2100, -1.2413,  ..., -1.1944, -1.1787, -1.1475],
          [-1.2100, -1.2257, -1.2413,  ..., -1.1944, -1.1631, -1.1318],
          [-1.2257, -1.2413, -1.2569,  ..., -1.1944, -1.1631, -1.1162]],

         [[ 1.3417,  1.3565,  1.3417,  ...,  1.5929,  1.6373,  1.6077],
          [ 1.3565,  1.3713,  1.3565,  ...,  1.5043,  1.5634,  1.5338],
          [ 1.2974,  1.3565,  1.3861,  ...,  1.3861,  1.3861,  1.3713],
          ...,
          [ 1.3861,  1.3861,  1.3565,  ...,  1.4008,  1.4156,  1.4452],
          [ 1.3861,  1.3713,  1.3565,  ...,  1.4008,  1.4304,  1.4599],
          [ 1.3713,  1.3565,  1.3417,  ...,  1.4008,  1.4304,  1.4747]],

         [[ 2.1874,  2.1996,  2.1508,  ...,  2.3460,  2.3827,  2.3582],
          [ 2.2118,  2.2240,  

 12%|█▏        | 2/17 [02:13<16:35, 66.36s/it]

tensor([[[[-1.2726, -1.2569, -1.2882,  ..., -1.3195, -1.4446, -1.5697],
          [-1.2726, -1.2569, -1.2569,  ..., -1.4446, -1.4289, -1.4133],
          [-1.2882, -1.3038, -1.2569,  ..., -1.4758, -1.4602, -1.4289],
          ...,
          [-1.3507, -1.3507, -1.3195,  ..., -1.2726, -1.3038, -1.3351],
          [-1.3351, -1.3507, -1.3195,  ..., -1.3195, -1.3507, -1.3664],
          [-1.3195, -1.3351, -1.3195,  ..., -1.3820, -1.3977, -1.3977]],

         [[ 1.4304,  1.4452,  1.4156,  ...,  1.3270,  1.2087,  1.0905],
          [ 1.4304,  1.4452,  1.4008,  ...,  1.1940,  1.2087,  1.2235],
          [ 1.4156,  1.4008,  1.4156,  ...,  1.1644,  1.1792,  1.2087],
          ...,
          [ 1.2974,  1.2974,  1.2974,  ...,  1.3565,  1.3270,  1.2974],
          [ 1.3122,  1.2974,  1.2974,  ...,  1.3122,  1.2826,  1.2679],
          [ 1.3270,  1.3122,  1.2974,  ...,  1.2531,  1.2383,  1.2383]],

         [[ 2.0287,  2.0409,  2.0043,  ...,  1.8944,  1.7968,  1.6992],
          [ 2.0165,  2.0287,  

 18%|█▊        | 3/17 [03:18<15:18, 65.63s/it]

tensor([[[[-1.9293, -1.9449, -1.9449,  ..., -1.7416, -1.7729, -1.8042],
          [-1.9449, -1.9606, -1.9606,  ..., -1.7573, -1.7886, -1.8042],
          [-1.8824, -1.8511, -1.9449,  ..., -1.7260, -1.7573, -1.7573],
          ...,
          [-1.7573, -1.7729, -1.7416,  ..., -1.6166, -1.6009, -1.6009],
          [-1.8042, -1.8198, -1.7729,  ..., -1.6478, -1.6322, -1.6166],
          [-1.8198, -1.8355, -1.8042,  ..., -1.6947, -1.6635, -1.6322]],

         [[ 0.4995,  0.4847,  0.5143,  ...,  0.8098,  0.7802,  0.7507],
          [ 0.4847,  0.4699,  0.4995,  ...,  0.7950,  0.7655,  0.7507],
          [ 0.5438,  0.5734,  0.5143,  ...,  0.8246,  0.7950,  0.7950],
          ...,
          [ 0.8098,  0.7950,  0.8098,  ...,  0.9280,  0.9428,  0.9428],
          [ 0.7655,  0.7507,  0.7802,  ...,  0.8985,  0.9132,  0.9280],
          [ 0.7507,  0.7359,  0.7507,  ...,  0.8541,  0.8837,  0.9132]],

         [[ 0.8692,  0.8570,  0.8692,  ...,  1.2354,  1.2109,  1.1865],
          [ 0.8570,  0.8448,  

 24%|██▎       | 4/17 [04:22<14:07, 65.16s/it]

tensor([[[[-1.4758, -1.5071, -1.5384,  ..., -1.3820, -1.3977, -1.4446],
          [-1.5071, -1.5071, -1.5384,  ..., -1.3664, -1.3820, -1.4133],
          [-1.4602, -1.5540, -1.4915,  ..., -1.3977, -1.4133, -1.4289],
          ...,
          [-1.4602, -1.4446, -1.4289,  ..., -1.4602, -1.4758, -1.4915],
          [-1.3977, -1.3977, -1.3664,  ..., -1.4602, -1.4915, -1.5227],
          [-1.3820, -1.3977, -1.3664,  ..., -1.4758, -1.5071, -1.5227]],

         [[ 1.0019,  1.0167,  0.9871,  ...,  1.2235,  1.2087,  1.1644],
          [ 0.9723,  0.9723,  0.9871,  ...,  1.2383,  1.2235,  1.1940],
          [ 1.0167,  0.9280,  1.0314,  ...,  1.2087,  1.1940,  1.1792],
          ...,
          [ 1.2531,  1.2679,  1.2531,  ...,  1.1940,  1.1792,  1.1644],
          [ 1.3122,  1.3122,  1.2679,  ...,  1.1940,  1.1644,  1.1349],
          [ 1.3270,  1.3122,  1.2679,  ...,  1.1792,  1.1496,  1.1201]],

         [[ 2.2362,  2.2362,  2.1874,  ...,  2.3338,  2.3216,  2.2850],
          [ 2.2118,  2.2118,  

 29%|██▉       | 5/17 [05:26<12:56, 64.70s/it]

tensor([[[[-1.0849, -1.0849, -1.0849,  ..., -1.0849, -1.0849, -1.0849],
          [-1.0849, -1.0849, -1.0849,  ..., -1.0849, -1.0849, -1.0849],
          [-1.0849, -1.0849, -1.0849,  ..., -1.0849, -1.0849, -1.0849],
          ...,
          [-1.0849, -1.0849, -1.0849,  ..., -1.1162, -1.1006, -1.0849],
          [-1.0849, -1.0849, -1.0849,  ..., -1.1006, -1.1006, -1.0849],
          [-1.0849, -1.0849, -1.0849,  ..., -1.0849, -1.0849, -1.0849]],

         [[ 1.8441,  1.8441,  1.8441,  ...,  1.8146,  1.8146,  1.8146],
          [ 1.8441,  1.8441,  1.8441,  ...,  1.8146,  1.8146,  1.8146],
          [ 1.8293,  1.8293,  1.8441,  ...,  1.7998,  1.7998,  1.7998],
          ...,
          [ 1.8441,  1.8441,  1.8293,  ...,  1.7407,  1.7555,  1.7702],
          [ 1.8441,  1.8441,  1.8293,  ...,  1.7555,  1.7555,  1.7702],
          [ 1.8441,  1.8441,  1.8146,  ...,  1.7702,  1.7702,  1.7702]],

         [[ 2.5169,  2.5169,  2.5169,  ...,  2.4437,  2.4437,  2.4437],
          [ 2.5169,  2.5169,  

 35%|███▌      | 6/17 [06:30<11:50, 64.58s/it]

tensor([[[[-3.5398, -3.5398, -3.5085,  ..., -3.6805, -3.6962, -3.7118],
          [-3.5398, -3.5242, -3.5555,  ..., -3.6962, -3.7275, -3.7587],
          [-3.6336, -3.6180, -3.5711,  ..., -3.7431, -3.7431, -3.7275],
          ...,
          [-4.2122, -4.2122, -4.2591,  ..., -3.9933, -3.9933, -3.9933],
          [-4.2122, -4.1965, -4.2434,  ..., -3.9933, -3.9933, -3.9933],
          [-4.2122, -4.1965, -4.2278,  ..., -3.9933, -3.9933, -3.9776]],

         [[-0.4166, -0.4166, -0.4905,  ..., -0.6383, -0.6530, -0.6678],
          [-0.4166, -0.4018, -0.5348,  ..., -0.6530, -0.6826, -0.7121],
          [-0.5053, -0.4905, -0.5644,  ..., -0.6826, -0.6826, -0.6678],
          ...,
          [-0.8303, -0.8303, -0.8747,  ..., -0.6974, -0.6974, -0.6974],
          [-0.8303, -0.8156, -0.8599,  ..., -0.6974, -0.6974, -0.6974],
          [-0.8303, -0.8156, -0.8451,  ..., -0.6974, -0.6974, -0.6826]],

         [[ 1.3574,  1.3574,  1.2964,  ...,  1.1011,  1.0889,  1.0767],
          [ 1.3574,  1.3696,  

 41%|████      | 7/17 [07:36<10:49, 64.91s/it]

tensor([[[[-2.1013, -2.0544, -2.0387,  ..., -2.0231, -2.0544, -2.0856],
          [-2.0544, -2.0231, -1.8667,  ..., -1.9918, -2.0075, -2.0544],
          [-1.9136, -1.9606, -2.0700,  ..., -2.0700, -2.0856, -2.1169],
          ...,
          [-1.9136, -1.9606, -1.8824,  ..., -1.7729, -1.8198, -1.8980],
          [-1.9449, -2.0075, -1.8824,  ..., -1.8198, -1.8667, -1.9136],
          [-1.9762, -2.0231, -1.9606,  ..., -1.8667, -1.8980, -1.9293]],

         [[ 0.4552,  0.4995,  0.4404,  ...,  0.3370,  0.3074,  0.2779],
          [ 0.4552,  0.4847,  0.6029,  ...,  0.3665,  0.3517,  0.3074],
          [ 0.5586,  0.5438,  0.4108,  ...,  0.3074,  0.2926,  0.2631],
          ...,
          [ 0.4995,  0.4552,  0.3517,  ...,  0.5882,  0.5438,  0.4699],
          [ 0.4995,  0.4404,  0.3517,  ...,  0.5438,  0.4995,  0.4552],
          [ 0.4699,  0.4256,  0.3074,  ...,  0.4995,  0.4699,  0.4404]],

         [[ 0.6983,  0.7349,  0.7349,  ...,  0.5030,  0.4786,  0.4542],
          [ 0.7349,  0.7593,  

 47%|████▋     | 8/17 [08:39<09:39, 64.41s/it]

tensor([[[[-1.1162, -1.1318, -1.0849,  ..., -1.1006, -1.1318, -1.1631],
          [-1.0849, -1.0849, -1.0849,  ..., -1.0849, -1.1318, -1.1631],
          [-1.1006, -1.1162, -1.0849,  ..., -1.1318, -1.1475, -1.1475],
          ...,
          [-1.1631, -1.1631, -1.0849,  ..., -1.1006, -1.1162, -1.1475],
          [-1.1631, -1.1631, -1.0849,  ..., -1.1006, -1.1162, -1.1162],
          [-1.1787, -1.1787, -1.1006,  ..., -1.0849, -1.0849, -1.1162]],

         [[ 1.7850,  1.7702,  1.7850,  ...,  1.7998,  1.7702,  1.7111],
          [ 1.7850,  1.7850,  1.7850,  ...,  1.8146,  1.7702,  1.7111],
          [ 1.7702,  1.7555,  1.7850,  ...,  1.7702,  1.7555,  1.7259],
          ...,
          [ 1.7555,  1.7555,  1.7850,  ...,  1.9180,  1.9032,  1.8737],
          [ 1.7555,  1.7555,  1.7702,  ...,  1.9180,  1.9032,  1.8589],
          [ 1.7259,  1.7259,  1.7555,  ...,  1.9180,  1.9032,  1.8737]],

         [[ 2.4681,  2.4559,  2.4803,  ...,  2.4803,  2.4559,  2.3949],
          [ 2.4803,  2.4803,  

 53%|█████▎    | 9/17 [09:43<08:33, 64.15s/it]

tensor([[[[-1.2413, -1.2257, -1.2882,  ..., -1.1631, -1.1787, -1.1787],
          [-1.2257, -1.2100, -1.2726,  ..., -1.1944, -1.1944, -1.1944],
          [-1.3195, -1.3351, -1.2100,  ..., -1.2100, -1.2100, -1.2257],
          ...,
          [-1.2569, -1.2569, -1.3351,  ..., -1.4133, -1.5071, -1.6166],
          [-1.2569, -1.2569, -1.3351,  ..., -1.3507, -1.4289, -1.5071],
          [-1.2569, -1.2569, -1.2882,  ..., -1.4133, -1.4289, -1.4446]],

         [[ 1.6668,  1.6816,  1.6373,  ...,  1.7702,  1.7555,  1.7555],
          [ 1.6520,  1.6668,  1.6520,  ...,  1.7407,  1.7407,  1.7407],
          [ 1.5634,  1.5486,  1.6964,  ...,  1.7407,  1.7407,  1.7259],
          ...,
          [ 1.5634,  1.5634,  1.5338,  ...,  1.4452,  1.3565,  1.2531],
          [ 1.5634,  1.5634,  1.5338,  ...,  1.5043,  1.4304,  1.3565],
          [ 1.5486,  1.5486,  1.5338,  ...,  1.4747,  1.4599,  1.4452]],

         [[ 2.4925,  2.5047,  2.3827,  ...,  2.6146,  2.6024,  2.6024],
          [ 2.5047,  2.5169,  

 59%|█████▉    | 10/17 [10:45<07:24, 63.53s/it]

tensor([[[[-1.5071, -1.4758, -1.5697,  ..., -1.4915, -1.4915, -1.4758],
          [-1.5071, -1.4758, -1.5384,  ..., -1.4915, -1.4915, -1.4915],
          [-1.5227, -1.5384, -1.5384,  ..., -1.4602, -1.4289, -1.3977],
          ...,
          [-1.6791, -1.6791, -1.6478,  ..., -1.5384, -1.5384, -1.5384],
          [-1.6635, -1.6635, -1.6947,  ..., -1.5384, -1.5384, -1.5384],
          [-1.6635, -1.6635, -1.6947,  ..., -1.5384, -1.5384, -1.5384]],

         [[ 1.5486,  1.5782,  1.4895,  ...,  1.6225,  1.6225,  1.6373],
          [ 1.5486,  1.5782,  1.5190,  ...,  1.6225,  1.6225,  1.6225],
          [ 1.4452,  1.4304,  1.4304,  ...,  1.6225,  1.6520,  1.6816],
          ...,
          [ 0.9871,  0.9871,  1.0167,  ...,  1.3713,  1.3713,  1.3713],
          [ 1.0019,  1.0019,  0.9723,  ...,  1.3713,  1.3713,  1.3713],
          [ 1.0019,  1.0019,  0.9723,  ...,  1.3713,  1.3713,  1.3713]],

         [[ 2.3582,  2.3827,  2.3094,  ...,  2.4071,  2.4071,  2.4193],
          [ 2.3582,  2.3827,  

 65%|██████▍   | 11/17 [11:48<06:19, 63.26s/it]

tensor([[[[-1.4133, -1.4289, -1.4133,  ..., -1.2882, -1.3038, -1.3195],
          [-1.4446, -1.4602, -1.4446,  ..., -1.3038, -1.3038, -1.3195],
          [-1.4758, -1.5227, -1.4133,  ..., -1.3820, -1.3820, -1.3820],
          ...,
          [-1.3038, -1.3038, -1.3351,  ..., -1.3195, -1.3351, -1.3664],
          [-1.2882, -1.2882, -1.3038,  ..., -1.3351, -1.3664, -1.3820],
          [-1.3038, -1.3038, -1.3038,  ..., -1.3664, -1.3820, -1.3977]],

         [[ 1.3417,  1.3270,  1.2974,  ...,  1.5043,  1.4895,  1.4747],
          [ 1.2679,  1.2531,  1.2679,  ...,  1.4895,  1.4895,  1.4747],
          [ 1.2383,  1.1940,  1.2826,  ...,  1.4156,  1.4156,  1.4156],
          ...,
          [ 1.4156,  1.4156,  1.4008,  ...,  1.4304,  1.4156,  1.3861],
          [ 1.4008,  1.4008,  1.4304,  ...,  1.4156,  1.3861,  1.3713],
          [ 1.3861,  1.3861,  1.4304,  ...,  1.3861,  1.3713,  1.3565]],

         [[ 2.0531,  2.0409,  2.0043,  ...,  2.3460,  2.3338,  2.3216],
          [ 2.0043,  1.9921,  

 71%|███████   | 12/17 [12:54<05:20, 64.13s/it]

tensor([[[[-2.5391e+00, -2.5704e+00, -2.4766e+00,  ..., -2.8206e+00,
           -2.8049e+00, -2.7893e+00],
          [-2.5235e+00, -2.5547e+00, -2.5235e+00,  ..., -2.8206e+00,
           -2.8049e+00, -2.7893e+00],
          [-2.4609e+00, -2.4609e+00, -2.4766e+00,  ..., -2.7893e+00,
           -2.8049e+00, -2.8049e+00],
          ...,
          [-1.6791e+00, -1.6478e+00, -1.6947e+00,  ..., -2.4296e+00,
           -2.4453e+00, -2.4453e+00],
          [-1.7416e+00, -1.7260e+00, -1.6791e+00,  ..., -2.3046e+00,
           -2.3358e+00, -2.3827e+00],
          [-1.7886e+00, -1.8042e+00, -1.6947e+00,  ..., -2.2420e+00,
           -2.2733e+00, -2.3202e+00]],

         [[-1.5065e-01, -1.8020e-01, -1.2110e-01,  ..., -3.4274e-01,
           -3.2796e-01, -3.1319e-01],
          [-9.1545e-02, -1.2110e-01, -1.6543e-01,  ..., -3.4274e-01,
           -3.2796e-01, -3.1319e-01],
          [-3.2441e-02, -3.2441e-02, -1.2110e-01,  ..., -3.1319e-01,
           -3.2796e-01, -3.2796e-01],
          ...,
     

 76%|███████▋  | 13/17 [14:02<04:21, 65.25s/it]

tensor([[[[-1.4133, -1.3820, -1.4289,  ..., -1.4602, -1.4602, -1.4602],
          [-1.3977, -1.3820, -1.5540,  ..., -1.4446, -1.4602, -1.4602],
          [-1.4289, -1.4446, -1.5227,  ..., -1.4289, -1.3977, -1.3664],
          ...,
          [-1.4758, -1.5384, -1.4602,  ..., -1.4446, -1.4133, -1.3977],
          [-1.4289, -1.4915, -1.3977,  ..., -1.5071, -1.4758, -1.4602],
          [-1.4289, -1.4758, -1.3507,  ..., -1.5540, -1.5227, -1.5071]],

         [[ 1.1792,  1.2087,  1.2383,  ...,  1.0610,  1.0610,  1.0610],
          [ 1.1940,  1.2087,  1.1644,  ...,  1.0758,  1.0610,  1.0610],
          [ 1.1644,  1.1496,  1.1940,  ...,  1.0905,  1.1201,  1.1496],
          ...,
          [ 1.1792,  1.1201,  1.0462,  ...,  1.0610,  1.0905,  1.1053],
          [ 1.2087,  1.1496,  1.1053,  ...,  1.0019,  1.0314,  1.0462],
          [ 1.1792,  1.1349,  1.1496,  ...,  0.9871,  1.0167,  1.0314]],

         [[ 1.7846,  1.8090,  1.8822,  ...,  1.6137,  1.6137,  1.6137],
          [ 1.7968,  1.8090,  

 82%|████████▏ | 14/17 [15:05<03:14, 64.75s/it]

tensor([[[[-1.6166, -1.6166, -1.5853,  ..., -1.6322, -1.6322, -1.5853],
          [-1.5853, -1.6009, -1.5853,  ..., -1.7104, -1.6947, -1.6478],
          [-1.5853, -1.5697, -1.5853,  ..., -1.6478, -1.5853, -1.5540],
          ...,
          [-1.6635, -1.6947, -1.6791,  ..., -1.6478, -1.6166, -1.6009],
          [-1.6635, -1.6947, -1.6791,  ..., -1.6478, -1.6322, -1.6009],
          [-1.6478, -1.6947, -1.6947,  ..., -1.6791, -1.6478, -1.6009]],

         [[ 0.9723,  0.9723,  0.9132,  ...,  0.8837,  0.8985,  0.9428],
          [ 0.9576,  0.9428,  0.9132,  ...,  0.8098,  0.8393,  0.8837],
          [ 0.9723,  0.9871,  0.9132,  ...,  0.8689,  0.9280,  0.9723],
          ...,
          [ 0.8541,  0.8246,  0.8393,  ...,  0.9280,  0.9576,  0.9723],
          [ 0.8541,  0.8246,  0.8393,  ...,  0.9280,  0.9428,  0.9723],
          [ 0.8541,  0.8098,  0.8098,  ...,  0.8985,  0.9280,  0.9723]],

         [[ 1.5771,  1.5771,  1.5527,  ...,  1.4306,  1.4551,  1.4917],
          [ 1.5527,  1.5405,  

 88%|████████▊ | 15/17 [16:06<02:07, 63.56s/it]

tensor([[[[-1.3977, -1.3664, -1.4133,  ..., -1.7104, -1.7260, -1.7104],
          [-1.4915, -1.3664, -1.4133,  ..., -1.8667, -1.7573, -1.6635],
          [-1.4133, -1.4289, -1.3820,  ..., -1.7104, -1.7416, -1.7886],
          ...,
          [-1.5071, -1.5071, -1.4289,  ..., -1.4602, -1.4602, -1.4758],
          [-1.4915, -1.4915, -1.3351,  ..., -1.4133, -1.4289, -1.4446],
          [-1.4602, -1.4602, -1.3038,  ..., -1.4602, -1.4758, -1.4602]],

         [[ 1.1940,  1.2235,  1.1496,  ...,  0.9576,  0.9428,  0.9576],
          [ 1.1053,  1.1792,  1.1644,  ...,  0.8098,  0.9132,  1.0019],
          [ 1.1792,  1.1201,  1.1940,  ...,  0.9576,  0.9280,  0.8837],
          ...,
          [ 1.2235,  1.2235,  1.2679,  ...,  1.2974,  1.2974,  1.2826],
          [ 1.2383,  1.2383,  1.3565,  ...,  1.3417,  1.3270,  1.3122],
          [ 1.2383,  1.2383,  1.3713,  ...,  1.2974,  1.2826,  1.2974]],

         [[ 1.9921,  2.0165,  1.8700,  ...,  1.7602,  1.7480,  1.7602],
          [ 1.9189,  1.9799,  

 94%|█████████▍| 16/17 [17:09<01:03, 63.32s/it]

tensor([[[[-2.9456, -2.9769, -2.9144,  ..., -2.0544, -1.8511, -2.2889],
          [-2.9613, -2.9925, -2.8987,  ..., -2.1795, -1.7260, -2.2107],
          [-2.9144, -2.8987, -2.8049,  ..., -1.8824, -2.0387, -2.3827],
          ...,
          [-2.5391, -2.4453, -2.6798,  ..., -2.3515, -2.3202, -2.2889],
          [-2.5704, -2.4922, -2.5547,  ..., -2.3671, -2.3358, -2.2889],
          [-2.5860, -2.5391, -2.5704,  ..., -2.3827, -2.3515, -2.2889]],

         [[-0.4757, -0.5053, -0.5644,  ...,  0.3517,  0.5586,  0.1005],
          [-0.5348, -0.5644, -0.5496,  ...,  0.2483,  0.6325,  0.1449],
          [-0.4905, -0.4757, -0.4905,  ...,  0.4847,  0.3370, -0.0324],
          ...,
          [-0.3280, -0.2393, -0.3427,  ..., -0.1950, -0.1654, -0.1359],
          [-0.3132, -0.2393, -0.2245,  ..., -0.2098, -0.1802, -0.1359],
          [-0.3132, -0.2689, -0.1950,  ..., -0.1950, -0.1654, -0.1359]],

         [[ 0.6251,  0.6007,  0.5030,  ...,  1.3452,  1.4795,  1.0889],
          [ 0.5885,  0.5641,  

                                               

tensor([[[[-1.9293, -1.9449, -1.9918,  ..., -2.2733, -2.3046, -2.3358],
          [-1.9293, -1.9449, -1.9918,  ..., -2.2733, -2.2889, -2.3202],
          [-1.8980, -1.8980, -1.9762,  ..., -2.3358, -2.2576, -2.2264],
          ...,
          [-1.9918, -1.9449, -2.0075,  ..., -1.9449, -1.8980, -1.8667],
          [-2.0075, -1.9449, -2.0075,  ..., -1.9449, -1.9136, -1.8824],
          [-2.0075, -1.9762, -2.0075,  ..., -1.9606, -1.9293, -1.9136]],

         [[ 0.2926,  0.2779,  0.2335,  ..., -0.1654, -0.1950, -0.2245],
          [ 0.2926,  0.2779,  0.2335,  ..., -0.1506, -0.1654, -0.1950],
          [ 0.3222,  0.3222,  0.2483,  ..., -0.2098, -0.1359, -0.1063],
          ...,
          [ 0.1892,  0.2335,  0.2483,  ...,  0.2188,  0.2631,  0.2926],
          [ 0.1744,  0.2335,  0.2483,  ...,  0.2188,  0.2483,  0.2779],
          [ 0.1892,  0.2188,  0.2483,  ...,  0.1892,  0.2188,  0.2483]],

         [[ 0.6861,  0.6739,  0.6373,  ...,  0.4664,  0.4420,  0.4176],
          [ 0.6861,  0.6739,  

