In [None]:
import os
import time
import random
from collections import namedtuple

from tqdm import tnrange
import numpy as np
import cv2
from parse import parse

import torch
from torch.utils import data
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms

from sdcdup.utils import get_datetime_now
from sdcdup.utils import channel_shift
from sdcdup.utils import to_hls
from sdcdup.utils import to_bgr
from sdcdup.utils import create_dataset_from_tiles_and_truth
from sdcdup.utils import even_split

from test_friend_circles import SDCImageContainer
from datasets import create_dataset_from_truth
from dupnet import create_loss_and_optimizer
from dupnet import save_checkpoint
from dupnet import DupCNN

%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Directory layout for the raw inputs and the derived per-tile artifacts.
ship_dir = "data/input/"
train_768_dir = os.path.join("data", "train_768")
train_256_dir, train_image_dir = (
    os.path.join(ship_dir, sub_dir) for sub_dir in ("train_256", "train_768")
)

# Pickled per-image tile-property grids, plus the duplicate ground-truth file.
(image_md5hash_grids_file,
 image_bm0hash_grids_file,
 image_cm0hash_grids_file,
 image_greycop_grids_file,
 image_entropy_grids_file,
 image_issolid_grids_file,
 duplicate_truth_file) = (
    os.path.join("data", file_name)
    for file_name in (
        "image_md5hash_grids.pkl",
        "image_bm0hash_grids.pkl",
        "image_cm0hash_grids.pkl",
        "image_greycop_grids.pkl",
        "image_entropy_grids.pkl",
        "image_issolid_grids.pkl",
        "duplicate_truth.txt",
    )
)

# Create the 256x256 tile directory if it is not there yet.
os.makedirs(train_256_dir, exist_ok=True)

In [None]:
class RandomHorizontalFlip:
    """Mirror a numpy image left-to-right with probability ``p``.

    Args:
        p (float): chance that a single call flips its input. Defaults to 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """Return ``img`` mirrored about its vertical axis, or ``img`` unchanged.

        Args:
            img (numpy.ndarray): image array handed to ``cv2.flip``
                (presumably H x W x C as produced by ``np.dstack`` — confirm).

        Returns:
            numpy.ndarray: a flipped copy (``cv2.flip(img, 1)``) with
            probability ``p``, otherwise the very same ``img`` object.
        """
        # cv2.flip returns a contiguous copy, which ToTensor can consume directly
        # (a negative-stride ``img[:, ::-1]`` view would not be accepted).
        return cv2.flip(img, 1) if np.random.random() < self.p else img

    def __repr__(self):
        return f'{type(self).__name__}(p={self.p})'


class RandomTransformC4:
    """Rotate a tensor by a random multiple of 90 degrees over dims 1 and 2.

    For a channels-first tensor (C x H x W) dims 1 and 2 are the H x W plane.

    Args:
        with_identity (bool): if True the 0-degree rotation is one of the
            candidates; if False only 90, 180 and 270 degrees are drawn.
    """

    def __init__(self, with_identity=True):
        self.with_identity = with_identity
        first_turn = 0 if with_identity else 1
        self.n90s = tuple(range(first_turn, 4))

    def __call__(self, img):
        """Return ``img`` rotated by ``k * 90`` degrees, ``k`` drawn uniformly from ``self.n90s``.

        Args:
            img (torch.Tensor): tensor with at least 3 dimensions.

        Returns:
            torch.Tensor: the rotated tensor.
        """
        quarter_turns = random.choice(self.n90s)
        return torch.rot90(img, quarter_turns, dims=(1, 2))

    def __repr__(self):
        return f'{type(self).__name__}(with_identity={self.with_identity})'


In [None]:
# Index -> HLS channel name, used when sampling which channel to colour-shift.
idx_chan_map = dict(enumerate(('H', 'L', 'S')))

# One sampled augmentation: idx3 selects the tile-reading strategy in
# Dataset.get_data_pair, chan is an HLS channel name, gain a signed shift.
ImgAugs = namedtuple('ImgAugs', 'idx3 chan gain')

class Dataset(data.Dataset):
    """Dataset of (stacked tile pair, is-duplicate label) samples.

    Each entry of ``img_overlaps`` is unpacked as
    ``(img1_id, img2_id, idx1, idx2, is_dup)``. ``idx1``/``idx2`` index the
    3x3 grid of 256-pixel tiles listed in ``self.ij``. A tile is obtained
    either by cropping the full image ``train_768_dir/<img_id>`` or by loading
    the pre-cut tile ``train_256_dir/<base>_<idx><ext>``. Which strategy is
    used is decided per sample by the ``idx3`` field of an ``ImgAugs``.

    For ``train_or_valid == 'valid'`` one augmentation per sample is drawn at
    construction time and reused on every access. Otherwise a fresh
    augmentation is drawn on every ``__getitem__``.

    NOTE(review): ``get_data_pair`` reads a module-level ``sdcic`` object
    (created in a later notebook cell). It uses it to relabel non-duplicate
    pairs whose tiles share an md5 hash, so ``sdcic`` must exist before any
    sample is fetched.

    Args:
        img_overlaps: sequence of ``(img1_id, img2_id, idx1, idx2, is_dup)``.
        train_768_dir: directory containing the full-size images.
        train_256_dir: directory containing the pre-cut 256x256 tiles.
        train_or_valid: ``'valid'`` freezes the augmentations; any other value
            draws them per access ('test' is not handled specially yet).
        image_transform: callable applied to the stacked H x W x 6 float array.
        in_shape, out_shape: stored on the instance; not used by this class.
    """
    def __init__(self, img_overlaps, train_768_dir, train_256_dir, train_or_valid, image_transform,
                 in_shape=(6, 256, 256),
                 out_shape=(1,)):

        """Initialization"""
        self.img_overlaps = img_overlaps
        self.train_768_dir = train_768_dir
        self.train_256_dir = train_256_dir
        # TODO: handle case if train_or_valid == 'test'
        self.valid = train_or_valid == 'valid'
        self.image_transform = image_transform
        # Tile index (0..8) -> (row, col) position in the 3x3 grid.
        self.ij = ((0, 0), (0, 1), (0, 2),
                   (1, 0), (1, 1), (1, 2),
                   (2, 0), (2, 1), (2, 2))

        self.in_shape = in_shape
        self.out_shape = out_shape
        # Maximum absolute colour-shift gain per HLS channel.
        self.hls_limits = {'H': 5, 'L': 10, 'S': 10}
        if self.valid:
            # Freeze one augmentation per sample so validation is repeatable
            # with respect to read strategy and colour shift.
            self.img_augs = [self.get_random_augmentation() for _ in self.img_overlaps]

    def __len__(self):
        """Denotes the total number of samples"""
        return len(self.img_overlaps)

    def __getitem__(self, index):
        """Return ``(X, y)`` for sample ``index`` (see ``get_data_pair``)."""
        if self.valid:
            img_aug = self.img_augs[index]
        else:
            img_aug = self.get_random_augmentation()
        X, y = self.get_data_pair(self.img_overlaps[index], img_aug)
        return X, y

    def get_random_augmentation(self):
        """Draw one ``ImgAugs``.

        Returns:
            ImgAugs: ``idx3`` in 0..3 with weights 0.4/0.1/0.1/0.4, a uniformly
            chosen HLS channel name, and a non-zero gain whose magnitude is in
            ``1..hls_limits[chan]`` and whose sign is random.
        """
        idx3 = np.random.choice(4, p=[0.4, 0.1, 0.1, 0.4])

        hls_idx = np.random.choice(3)
        hls_chan = idx_chan_map[hls_idx]
        hls_gain = np.random.choice(self.hls_limits[hls_chan]) + 1
        hls_gain = hls_gain if np.random.random() > 0.5 else -1 * hls_gain

        return ImgAugs(idx3, hls_chan, hls_gain)

    def color_shift(self, img, chan, gain):
        """Shift channel ``chan`` of ``img`` by ``gain`` in HLS space; return BGR."""
        hls = to_hls(img)
        hls_shifted = channel_shift(hls, chan, gain)
        return to_bgr(hls_shifted)

    def get_tile(self, img, ij, sz=256):
        """Crop the ``sz`` x ``sz`` tile at grid index ``ij`` (0..8) from ``img``."""
        i, j = self.ij[ij]
        return img[i * sz:(i + 1) * sz, j * sz:(j + 1) * sz, :]

    def read_from_large(self, img_id, ij):
        """Read the full image ``img_id`` and crop tile ``ij`` from it."""
        img = self._imread(os.path.join(self.train_768_dir, img_id))
        return self.get_tile(img, ij)

    def read_from_small(self, img_id, ij):
        """Read the pre-cut tile ``ij`` of ``img_id`` from ``train_256_dir``."""
        return self._imread(os.path.join(self.train_256_dir, self._small_tile_name(img_id, ij)))

    @staticmethod
    def _small_tile_name(img_id, ij):
        """Build the tile filename ``<base>_<ij><ext>`` for image ``img_id``."""
        # splitext only splits off the last extension, so ids with extra dots
        # (or none at all) are handled instead of failing to unpack.
        filebase, fileext = os.path.splitext(img_id)
        return f'{filebase}_{ij}{fileext}'

    @staticmethod
    def _imread(path):
        """``cv2.imread`` that raises instead of silently returning ``None``."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'cv2.imread could not read {path!r}')
        return img

    def get_data_pair(self, img_overlap, img_aug):
        """Load, label and stack the two tiles described by ``img_overlap``.

        Args:
            img_overlap: ``(img1_id, img2_id, idx1, idx2, is_dup)``.
            img_aug: ``ImgAugs(idx3, chan, gain)``. ``idx3`` picks how each tile
                is read (crop of the large image vs. pre-cut small tile, or a
                colour-shifted copy); ``chan``/``gain`` parameterise the shift.

        Returns:
            tuple: ``(X, y)``. ``X`` is ``image_transform`` applied to the two RGB
            float tiles (scaled to [0, 1]) stacked depth-wise in random order.
            ``y`` is a float32 array ``[is_dup]``.
        """

        # diff img_id (img1_id != img2_id), random tile from overlap, where is_dup == 1 (from duplicate_truth.txt)
            # img1_[i,j], img2_[k,l], 1, exact or fuzzy
            # img1_[i,j], tile2_kl, 1, exact or fuzzy
            # tile1_ij, img2_[k,l], 1, exact or fuzzy
            # tile1_ij, tile2_kl, 1, exact or fuzzy

        # same img_id (img1_id == img2_id), same tile (ij == kl)
            # img1_[i,j], img1_[i,j], 1, exact
            # img1_[i,j], tile1_ij, 1, fuzzy
            # tile1_ij, img1_[i,j], 1, fuzzy
            # tile1_ij, tile1_ij, 1, exact

        # same img_id (img1_id == img2_id), diff tile (ij != kl)
            # img1_[i,j], img1_[k,l], 0, similar but different
            # img1_[i,j], tile1_kl, 0, similar but different
            # tile1_ij, img1_[k,l], 0, similar but different
            # tile1_ij, tile1_kl, 0, similar but different

        # diff img_id (img1_id != img2_id), same tile (ij == kl)
            # img1_[i,j], img2_[i,j], 0, very different
            # img1_[i,j], tile2_ij, 0, very different
            # tile1_ij, img2_[i,j], 0, very different
            # tile1_ij, tile2_ij, 0, very different

        # diff img_id (img1_id != img2_id), diff tile (ij != kl)
            # img1_[i,j], img2_[k,l], 0, very different
            # img1_[i,j], tile2_kl, 0, very different
            # tile1_ij, img2_[k,l], 0, very different
            # tile1_ij, tile2_kl, 0, very different

        # use image_md5hash_grids.pkl for equal image id pairs (img1_id == img2_id)
        #--------------------------------------------------------------------
        # ij == kl? | tile1? | tile2? | shift? | is_dup?
        #--------------------------------------------------------------------
        #   yes     |  768   |  768   |   yes  |    yes agro color shift
        #   yes     |  768   |  768   |    no  |    yes
        #   yes     |  768   |  256   |    no  |    yes
        #   yes     |  256   |  768   |    no  |    yes
        #   yes     |  256   |  256   |   yes  |    yes agro color shift
        #   yes     |  256   |  256   |    no  |    yes
        #    no     |  768   |  768   |   yes  |     no
        #    no     |  768   |  768   |    no  |     no
        #    no     |  256   |  256   |   yes  |     no
        #    no     |  256   |  256   |    no  |     no

        # use duplicate_truth.txt for unequal image id pairs (img1_id != img2_id)
        # NOTE: Be sure to use the overlap_map when comparing ij and kl
        #--------------------------------------------------------------------
        # ij == kl? | tile1? | tile2? | shift? | is_dup?
        #--------------------------------------------------------------------
        #   yes     |  768   |  768   |   yes  |    yes small color shift
        #   yes     |  768   |  768   |    no  |    yes
        #   yes     |  768   |  256   |    no  |    yes
        #   yes     |  256   |  768   |    no  |    yes
        #   yes     |  256   |  256   |   yes  |    yes
        #   yes     |  256   |  256   |    no  |    yes
        #    no     |  768   |  768   |   yes  |     no
        #    no     |  768   |  768   |    no  |     no
        #    no     |  256   |  256   |   yes  |     no
        #    no     |  256   |  256   |    no  |     no

        img1_id, img2_id, idx1, idx2, is_dup = img_overlap
        idx3, chan, gain = img_aug

        if img1_id == img2_id:
            if is_dup:  # ij == kl
                if idx3 == 0:
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.color_shift(tile1, chan, gain)
                elif idx3 == 1:
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.read_from_small(img2_id, idx2)
                elif idx3 == 2:
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.read_from_large(img2_id, idx2)
                elif idx3 == 3:
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.color_shift(tile1, chan, gain)
                else:
                    raise ValueError(f'unexpected read strategy idx3={idx3!r}')
            else:  # ij != kl
                # These 4 have pretty much the same effect.
                # The last one (2 256x256 tiles vs 1 768x768 tile) is the fastest.
                if idx3 == 0: # fast: one read, two crops
                    img = self._imread(os.path.join(self.train_768_dir, img1_id))
                    tile1 = self.get_tile(img, idx1)
                    tile2 = self.get_tile(img, idx2)
                elif idx3 == 1: # slowest
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.read_from_small(img2_id, idx2)
                elif idx3 == 2: # slowest
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.read_from_large(img2_id, idx2)
                elif idx3 == 3: # fastest
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.read_from_small(img2_id, idx2)
                else:
                    raise ValueError(f'unexpected read strategy idx3={idx3!r}')
        else: # img1_id != img2_id
            if is_dup:  # tiles overlap according to the truth data
                if idx3 == 0: # slowest
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.read_from_large(img2_id, idx2)
                elif idx3 == 1: # slow
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.read_from_small(img2_id, idx2)
                elif idx3 == 2: # slow
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.read_from_large(img2_id, idx2)
                elif idx3 == 3: # fast
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.read_from_small(img2_id, idx2)
                else:
                    raise ValueError(f'unexpected read strategy idx3={idx3!r}')
            else:  # non-duplicate pair: always crop from the large images
                tile1 = self.read_from_large(img1_id, idx1)
                tile2 = self.read_from_large(img2_id, idx2)

        # Pairs labelled 0 whose tiles have identical md5 hashes are really
        # duplicates; relabel them (relies on the module-level ``sdcic``).
        if is_dup == 0 and sdcic.tile_md5hash_grids[img1_id][idx1] == sdcic.tile_md5hash_grids[img2_id][idx2]:
            i, j = self.ij[idx1]
            k, l = self.ij[idx2]
            print(f'algo={idx3}; {img1_id} {idx1} -> ({i},{j}); {img2_id}, {idx2} -> ({k},{l}); correcting... {is_dup} -> 1')
            is_dup = 1

        tile1 = cv2.cvtColor(tile1, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        tile2 = cv2.cvtColor(tile2, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.

        # Random stacking order so the model cannot rely on which tile comes first.
        X = np.dstack([tile1, tile2]) if np.random.random() < 0.5 else np.dstack([tile2, tile1])
        X = self.image_transform(X)
        y = np.array([is_dup], dtype=np.float32)
        return X, y

In [None]:
# Pickle paths of the per-tile property grids, in the positional order that
# preprocess_image_properties expects (md5, bm0, cm0, greycop, entropy, issolid).
_property_grid_files = (
    image_md5hash_grids_file,
    image_bm0hash_grids_file,
    image_cm0hash_grids_file,
    image_greycop_grids_file,
    image_entropy_grids_file,
    image_issolid_grids_file,
)

# Image container over train_image_dir. Dataset.get_data_pair reads its
# tile_md5hash_grids attribute.
sdcic = SDCImageContainer(train_image_dir)
sdcic.preprocess_image_properties(*_property_grid_files)
del _property_grid_files

In [None]:
# CUDA for PyTorch
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Let cuDNN benchmark convolution algorithms (input sizes are fixed here).
torch.backends.cudnn.benchmark = True

# Index -> HLS channel name, used when sampling which channel to colour-shift.
idx_chan_map = dict(enumerate(('H', 'L', 'S')))

# One sampled augmentation: read strategy, HLS channel name, signed gain.
ImgAugs = namedtuple('ImgAugs', 'idx3 chan gain')

class Dataset(data.Dataset):
    """Dataset of (stacked tile pair, is-duplicate label) samples.

    Each entry of ``img_overlaps`` is unpacked as
    ``(img1_id, img2_id, idx1, idx2, is_dup)``. The image ids are file paths
    of the form ``<root>/images_768/r<row>_c<col>.jpg``. ``idx1``/``idx2``
    index the 3x3 grid of 256-pixel tiles listed in ``self.ij``. A tile is
    obtained either by cropping the large image or by loading the
    corresponding pre-cut tile from ``<root>/images_256``.

    For ``train_or_valid == 'valid'`` one augmentation per sample is drawn at
    construction time and reused. Otherwise a fresh one is drawn on every
    ``__getitem__``.

    Args:
        img_overlaps: sequence of ``(img1_id, img2_id, idx1, idx2, is_dup)``.
        train_or_valid: ``'valid'`` freezes the augmentations; any other value
            draws them per access ('test' is not handled specially yet).
        image_transform: callable applied to the stacked H x W x 6 float array.
        in_shape, out_shape: stored on the instance; not used by this class.
    """
    def __init__(self, img_overlaps, train_or_valid, image_transform,
                 in_shape=(6, 256, 256),
                 out_shape=(1,)):

        """Initialization"""
        self.img_overlaps = img_overlaps
        # TODO: handle case if train_or_valid == 'test'
        self.valid = train_or_valid == 'valid'
        self.image_transform = image_transform
        # Tile index (0..8) -> (row, col) position in the 3x3 grid.
        self.ij = ((0, 0), (0, 1), (0, 2),
                   (1, 0), (1, 1), (1, 2),
                   (2, 0), (2, 1), (2, 2))

        self.in_shape = in_shape
        self.out_shape = out_shape
        # Maximum absolute colour-shift gain per HLS channel.
        self.hls_limits = {'H': 5, 'L': 10, 'S': 10}
        if self.valid:
            # Freeze one augmentation per sample so validation is repeatable
            # with respect to read strategy and colour shift.
            self.img_augs = [self.get_random_augmentation() for _ in self.img_overlaps]

    def __len__(self):
        """Denotes the total number of samples"""
        return len(self.img_overlaps)

    def __getitem__(self, index):
        """Return ``(X, y)`` for sample ``index`` (see ``get_data_pair``)."""
        if self.valid:
            img_aug = self.img_augs[index]
        else:
            img_aug = self.get_random_augmentation()
        X, y = self.get_data_pair(self.img_overlaps[index], img_aug)
        return X, y

    def get_random_augmentation(self):
        """Draw one ``ImgAugs``.

        Returns:
            ImgAugs: ``idx3`` in 0..3 with weights 0.3/0.2/0.2/0.3, a uniformly
            chosen HLS channel name, and a non-zero gain whose magnitude is in
            ``1..hls_limits[chan]`` and whose sign is random.
        """
        # Other weightings tried: [0.4, 0.1, 0.1, 0.4] and [0.1, 0.05, 0.05, 0.8].
        p = [0.3, 0.2, 0.2, 0.3]
        idx3 = np.random.choice(4, p=p)

        hls_idx = np.random.choice(3)
        hls_chan = idx_chan_map[hls_idx]
        hls_gain = np.random.choice(self.hls_limits[hls_chan]) + 1
        hls_gain = hls_gain if np.random.random() > 0.5 else -1 * hls_gain

        return ImgAugs(idx3, hls_chan, hls_gain)

    def color_shift(self, img, chan, gain):
        """Shift channel ``chan`` of ``img`` by ``gain`` in HLS space; return BGR."""
        hls = to_hls(img)
        hls_shifted = channel_shift(hls, chan, gain)
        return to_bgr(hls_shifted)

    def get_tile(self, img, idx, sz=256):
        """Crop the ``sz`` x ``sz`` tile at grid index ``idx`` (0..8) from ``img``."""
        i, j = self.ij[idx]
        return img[i * sz:(i + 1) * sz, j * sz:(j + 1) * sz, :]

    def read_from_large(self, img_id, idx):
        """Read the large image at path ``img_id`` and crop tile ``idx`` from it."""
        img = self._imread(img_id)
        return self.get_tile(img, idx)

    def read_from_small(self, img_id, idx):
        """Read the pre-cut 256x256 tile ``idx`` of the large image ``img_id``."""
        return self._imread(self._small_tile_path(img_id, idx))

    def _small_tile_path(self, img_id, idx):
        """Map a large-image path and a tile index to the small-tile path.

        ``<root>/images_768/r<row>_c<col>.jpg`` with tile ``(i, j)`` maps to
        ``<root>/images_256/r<row+i:03d>_c<col+j:03d>.jpg``.
        """
        # maxsplit=1: only the last "/images_768/" separates root from filename.
        parts = img_id.rsplit("/images_768/", 1)
        if len(parts) != 2:
            raise ValueError(f'expected a ".../images_768/<file>" path, got {img_id!r}')
        dup_truth_path, img_filename = parts
        parsed = parse('r{:3d}_c{:3d}.jpg', img_filename)
        if parsed is None:
            raise ValueError(f'unexpected image filename {img_filename!r} in {img_id!r}')
        row, col = parsed
        i, j = self.ij[idx]
        return os.path.join(dup_truth_path, "images_256", f"r{row + i:03d}_c{col + j:03d}.jpg")

    @staticmethod
    def _imread(path):
        """``cv2.imread`` that raises instead of silently returning ``None``."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'cv2.imread could not read {path!r}')
        return img

    def get_data_pair(self, img_overlap, img_aug):
        """Load and stack the two tiles described by ``img_overlap``.

        Args:
            img_overlap: ``(img1_id, img2_id, idx1, idx2, is_dup)``.
            img_aug: ``ImgAugs(idx3, chan, gain)``. ``idx3`` picks how each tile
                is read (crop of the large image vs. pre-cut small tile, or a
                colour-shifted copy); ``chan``/``gain`` parameterise the shift.

        Returns:
            tuple: ``(X, y)``. ``X`` is ``image_transform`` applied to the two RGB
            float tiles (scaled to [0, 1]) stacked depth-wise in random order.
            ``y`` is a float32 array ``[is_dup]``.
        """

        img1_id, img2_id, idx1, idx2, is_dup = img_overlap
        idx3, chan, gain = img_aug
        same_image = img1_id == img2_id

        if same_image:
            if is_dup:  # idx1 == idx2
                if idx3 == 0:
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.color_shift(tile1, chan, gain)
                elif idx3 == 1:
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.read_from_small(img2_id, idx2)
                elif idx3 == 2:
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.read_from_large(img2_id, idx2)
                elif idx3 == 3:
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.color_shift(tile1, chan, gain)
                else:
                    raise ValueError(f'unexpected read strategy idx3={idx3!r}')
            else:  # idx1 != idx2
                if idx3 == 0: # fast: one read, two crops
                    img = self._imread(img1_id)
                    tile1 = self.get_tile(img, idx1)
                    tile2 = self.get_tile(img, idx2)
                elif idx3 == 1: # slowest
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.read_from_small(img2_id, idx2)
                elif idx3 == 2: # slowest
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.read_from_large(img2_id, idx2)
                elif idx3 == 3: # fastest
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.read_from_small(img2_id, idx2)
                else:
                    raise ValueError(f'unexpected read strategy idx3={idx3!r}')
        else: # img1_id != img2_id
            if is_dup:  # tiles overlap according to the truth data
                if idx3 == 0: # slowest
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.read_from_large(img2_id, idx2)
                elif idx3 == 1: # slow
                    tile1 = self.read_from_large(img1_id, idx1)
                    tile2 = self.read_from_small(img2_id, idx2)
                elif idx3 == 2: # slow
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.read_from_large(img2_id, idx2)
                elif idx3 == 3: # fast
                    # These end up being the same tiles.
                    tile1 = self.read_from_small(img1_id, idx1)
                    tile2 = self.color_shift(tile1, chan, gain)
                else:
                    raise ValueError(f'unexpected read strategy idx3={idx3!r}')
            else:  # non-duplicate pair: always crop from the large images
                tile1 = self.read_from_large(img1_id, idx1)
                tile2 = self.read_from_large(img2_id, idx2)

        tile1 = cv2.cvtColor(tile1, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
        tile2 = cv2.cvtColor(tile2, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.

        # Random stacking order so the model cannot rely on which tile comes first.
        X = np.dstack([tile1, tile2]) if np.random.random() < 0.5 else np.dstack([tile2, tile1])
        X = self.image_transform(X)
        y = np.array([is_dup], dtype=np.float32)
        return X, y

In [None]:
# Model architecture
input_shape = (6, 256, 256)  # two stacked RGB tiles, channels first (after ToTensor)
conv_layers = (16, 32, 64, 128, 256)
fc_layers = (128,)
output_size = 1

model_basename = 'dup_model'
datetime_now = get_datetime_now()

# Parameters
split = 90  # passed to even_split; presumably the train percentage — confirm
batch_size = 256
max_epochs = 30
num_workers = 18
learning_rate = 0.0001
best_loss = 9999.0  # sentinel: any real validation loss below this triggers a checkpoint
max_datapoints = 100*2**14

loader_params = {
    'train': {'batch_size': batch_size, 'shuffle': True, 'num_workers': num_workers},
    'valid': {'batch_size': batch_size, 'shuffle': False, 'num_workers': num_workers}}

# Datasets
# NOTE(review): this result is overwritten on the next line. Keep the call only
# if it has side effects that create_dataset_from_truth relies on (e.g. writing
# output/datasets) — confirm.
full_dataset = create_dataset_from_tiles_and_truth(sdcic)
full_dataset = create_dataset_from_truth('output/datasets')
print(len(full_dataset))
np.random.shuffle(full_dataset)
# Slicing already clamps to the available length.
trainval_dataset = full_dataset[:max_datapoints]
print(len(trainval_dataset))
n_train, n_valid = even_split(len(trainval_dataset), batch_size, split)
print(n_train, n_valid)
# Validation samples come from the tail. Index from the front instead of using
# `[-n_valid:]`: that slice returns the *whole* dataset when n_valid == 0.
n_trainval = len(trainval_dataset)
partition = {'train': trainval_dataset[:n_train],
             'valid': trainval_dataset[n_trainval - n_valid:]}

# NOTE(review): the 'valid' pipeline still applies random flips/rotations, so
# the validation loss varies between epochs even for a fixed model.
image_transforms = {
    'train': transforms.Compose([
        RandomHorizontalFlip(),
        transforms.ToTensor(),
        RandomTransformC4(with_identity=False),
    ]),
    'valid': transforms.Compose([
        RandomHorizontalFlip(),
        transforms.ToTensor(),
        RandomTransformC4(with_identity=False),
    ]),
}

image_datasets = {x: Dataset(partition[x], x, image_transforms[x]) for x in ['train', 'valid']}

# Generators
generators = {x: data.DataLoader(image_datasets[x], **loader_params[x]) for x in ['train', 'valid']}
print(len(generators['train']), len(generators['valid']))

In [None]:
# Build the duplicate-detection CNN and place it on the selected device.
# model.to(device) alone works on both CUDA and CPU-only machines; an explicit
# model.cuda() would raise when no GPU is present.
model = DupCNN(input_shape, output_size, conv_layers, fc_layers)
model.to(device)

loss, optimizer = create_loss_and_optimizer(model, learning_rate)
# Multiply the learning rate by 0.05 every 10 scheduler steps (one step per epoch).
scheduler = StepLR(optimizer, step_size=10, gamma=0.05)

In [None]:
# Checkpoints are written under out/; make sure the directory exists.
os.makedirs("out", exist_ok=True)

# Loop over epochs
for epoch in range(max_epochs):

    # ---- Training ----
    model.train()
    n_train_batches = len(generators['train'])
    total_train_loss = 0.0
    total_train_acc = 0.0
    t = tnrange(n_train_batches, desc=f'Epoch {epoch + 1:>02d}')
    train_iterator = iter(generators['train'])
    for i in t:
        # Get next batch and push it to the selected device.
        inputs, labels = next(train_iterator)
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass, backward pass, optimize.
        optimizer.zero_grad()
        outputs = model(inputs)
        train_loss = loss(outputs, labels)
        train_loss.backward()
        optimizer.step()

        # Running statistics. NOTE(review): thresholding at 0.5 assumes the
        # model outputs probabilities — confirm against DupCNN / the loss.
        total_train_loss += train_loss.item()
        y_pred = (outputs > 0.5).float()
        total_train_acc += (labels == y_pred).float().mean().item()
        t.set_postfix(loss=f'{total_train_loss / (i + 1):.6f}',
                      acc=f'{total_train_acc / (i + 1):.5f}')

    # Advance the LR schedule once per epoch, after this epoch's optimizer
    # steps (calling it first would skip the initial learning rate).
    scheduler.step()

    # ---- Validation ----
    model.eval()  # inference behaviour for dropout / batch-norm layers
    n_valid_batches = len(generators['valid'])
    total_val_loss = 0.0
    total_val_acc = 0.0
    t = tnrange(n_valid_batches, desc='Validation')
    valid_iterator = iter(generators['valid'])
    with torch.no_grad():
        for i in t:
            inputs, labels = next(valid_iterator)
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            val_outputs = model(inputs)
            total_val_loss += loss(val_outputs, labels).item()

            y_pred = (val_outputs > 0.5).float()
            total_val_acc += (labels == y_pred).float().mean().item()
            t.set_postfix(loss=f'{total_val_loss / (i + 1):.6f}',
                          acc=f'{total_val_acc / (i + 1):.5f}')

    if n_valid_batches == 0:
        # No validation data: nothing to compare, so no checkpoint this epoch.
        continue

    val_loss = total_val_loss / n_valid_batches
    if val_loss < best_loss:
        save_checkpoint(os.path.join("out", f"{model_basename}.{datetime_now}.{epoch + 1:02d}.{val_loss:.6f}.pth"), model)
        save_checkpoint(os.path.join("out", f"{model_basename}.{datetime_now}.last.pth"), model)
        save_checkpoint(os.path.join("out", f"{model_basename}.last.pth"), model)
        best_loss = val_loss