In [1]:
!nvidia-smi

Mon Jan  4 20:31:24 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.48                 Driver Version: 410.48                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-SXM2...  Off  | 00000000:05:00.0 Off |                    0 |
| N/A   30C    P0    42W / 300W |   8035MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-SXM2...  Off  | 00000000:06:00.0 Off |                    0 |
| N/A   29C    P0    57W / 300W |   6673MiB / 16280MiB |     26%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla P100-SXM2...  Off  | 00000000:84:00.0 Off |                    0 |
| N/A   

In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
# Model parts

In [4]:

""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None, batch_norm=True):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        if batch_norm:
            self.double_conv = nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
        else:
            self.double_conv = nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

    def forward(self, x):
        return self.double_conv(x)


class Downscaler(nn.Module):
    """Double conv 3x3, then max pool 2x2 stride 2"""

    def __init__(self, in_channels, out_channels, batch_norm):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            DoubleConv(in_channels, out_channels, batch_norm=batch_norm),
            nn.MaxPool2d(2, stride=2)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Upscaler(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, batch_norm=True):
        super().__init__()
        self.upscale = nn.Sequential(
            DoubleConv(in_channels, in_channels, batch_norm=batch_norm),
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
#            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return self.upscale(x)        


# Conditioning Branch
class OneOneConv(nn.Module):
    def __init__(self, in_channels, out_channels, batch_norm=True):
        super(OneOneConv, self).__init__()
        if batch_norm:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.ReLU(inplace=True)
            )

    def forward(self, x):
        x = self.conv(x)
        return x

# class ConditioningConcat(nn.Module):
#     def __init__(self, tile_concat):
#         super(ConditioningConcat, self).__init__()
#         self.filtered_tile = OneOneConv(self.tile_concat)       # Here or in forward method?

#     def forward(self, x):
#         return torch.cat(x, self.filtered_tile, dim=-1)


class VGG16(nn.Module):
    def __init__(self, batch_norm=True, cfg='A', in_channels=3):
        super(VGG16, self).__init__()
        self.batch_norm = batch_norm

        self.cfgs = {
            'A': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M', 512, 512, 512, 'M'],     # 6 "evenly" distributed maxpools to reduce dims to 1x1x512 
            'B': [32, 32, 'M', 64, 64, 'M', 128, 'M', 256, 'M', 512, 'M', 'OutConv'], # From the paper's git
            'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M', 'M']     # 6 maxpools to reduce dims to 1x1x512
        }
        self.cfg = self.cfgs[cfg]

        self.layers = []
        for v in self.cfg:
            if v == 'M':
                self.layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif v == 'OutConv':
                outConv = nn.Conv2d(in_channels, 512, kernel_size=2, stride=1)
                if self.batch_norm:
                    self.layers += [outConv, nn.BatchNorm2d(512), nn.ReLU(inplace=True)]
                else:
                    self.layers += [outConv, nn.ReLU(inplace=True)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if self.batch_norm:
                    self.layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    self.layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        self.vgg16 = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.vgg16(x)

In [5]:
# Model

In [6]:
import torch
from torch import nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import numpy as np
import logging


class LogoDetection(nn.Module):
    def __init__(self,
                 n_channels: int = 3,
                 batch_norm=True,
                 vgg_cfg: str = 'A'):
        super(LogoDetection, self).__init__()
        self.n_channels = n_channels

        # Encoder steps
        self.input_layer = Downscaler(self.n_channels, 64, batch_norm)
        self.down_layer1 = Downscaler(64, 128, batch_norm)
        self.down_layer2 = Downscaler(128, 256, batch_norm)
        self.down_layer3 = Downscaler(256, 512, batch_norm)
        self.down_layer4 = Downscaler(512, 512, batch_norm)  # 1/(2^5)*(width x height) x 512

        # Conditioning Module
        self.latent_repr = VGG16(batch_norm, vgg_cfg)
        self.one_conv1 = OneOneConv(576, 64, batch_norm)  # 64+512
        self.one_conv2 = OneOneConv(640, 128, batch_norm)  # 128+512
        self.one_conv3 = OneOneConv(768, 256, batch_norm)  # 256+512
        self.one_conv4 = OneOneConv(1024, 512, batch_norm)  # 512+512

        # Decoder steps
        self.up1 = Upscaler(1024, 512, batch_norm)  # 512+512
        self.up2 = Upscaler(1024, 256, batch_norm)  # 512*2
        self.up3 = Upscaler(512, 128, batch_norm)  # 256*2
        self.up4 = Upscaler(256, 64, batch_norm)  # 128*2
        self.output_layer = Upscaler(128, 1, batch_norm)  # 64*2
        # self.output_layer = OutSoftmax()

        # with torch.no_grad():
        #     self.input_layer.weight = torch.nn.Parameter()

    def forward(self, query, target):
        # query = samples[:, 0]
        # target = samples[:, 1]
#         logging.info(f"query: {query}")
#        print(f"query: {query}")
#         logging.info(f"target: {target}")
#        print(f"target: {target}")
        z = self.latent_repr(query)
#         logging.info(f"z: {z}")
#         print(f"z: {z}")
        # print(z.shape)

        # Encoder + Conditioning
        x = self.input_layer(target)
#         logging.info(f"input_layer: {x}")

        tile = z.expand(z.shape[0], z.shape[1], 128, 128)
        # print(tile.shape)
#         logging.info(f"tile1: {tile}")
        x1 = torch.cat((x, tile), dim=1)
#         logging.info(f"x1: {x1}")
        x = self.down_layer1(x)
#         logging.info(f"down1: {x}")

        tile = z.expand(z.shape[0], z.shape[1], 64, 64)
        x2 = torch.cat((x, tile), dim=1)
#         logging.info(f"x2: {x2}")
        x = self.down_layer2(x)
#         logging.info(f"down2: {x}")

        tile = z.expand(z.shape[0], z.shape[1], 32, 32)
        x3 = torch.cat((x, tile), dim=1)
#         logging.info(f"x3: {x3}")
        x = self.down_layer3(x)
#         logging.info(f"down3: {x}")

        tile = z.expand(z.shape[0], z.shape[1], 16, 16)
        x4 = torch.cat((x, tile), dim=1)
#         logging.info(f"x4: {x4}")
        x = self.down_layer4(x)
#         logging.info(f"down4: {x}")

        tile = z.expand(z.shape[0], z.shape[1], 8, 8)
        x5 = torch.cat((x, tile), dim=1)
#         logging.info(f"x5: {x5}")
        # print(x.shape)

        # Decoder + Conditioning
#        x = torch.cat((x, x5), dim=1)
#        logging.info(f"cond5: {x}")
        # print(x.shape)
        x = self.up1(x5)
#         logging.info(f"up1: {x}")
        # print(x.shape)

        x4 = self.one_conv4(x4)
#         logging.info(f"cnv4: {x4}")
        x = torch.cat((x, x4), dim=1)
#         logging.info(f"cond4: {x}")
        # del x4
        x = self.up2(x)
#         logging.info(f"up2: {x}")

        x3 = self.one_conv3(x3)
#         logging.info(f"cnv3: {x3}")
        x = torch.cat((x, x3), dim=1)
#         logging.info(f"cond3: {x}")
        # del x3
        x = self.up3(x)
#         logging.info(f"up3: {x}")

        x2 = self.one_conv2(x2)
#         logging.info(f"cnv2: {x2}")
        x = torch.cat((x, x2), dim=1)
#         logging.info(f"cond2: {x}")
        # del x2
        x = self.up4(x)
#         logging.info(f"up4: {x}")

        x1 = self.one_conv1(x1)
#         logging.info(f"cnv1: {x1}")
        x = torch.cat((x, x1), dim=1)
#         logging.info(f"cond1: {x}")
        # del x1
#        output = self.up5(x)
#         logging.info(f"up5: {x}")

        output = self.output_layer(x)
        logging.info(f"output: {output}")
        return output

    def predict_mask(self, query, target):
        y = self.forward(query, target)
        output_layer = nn.Sequential(
            nn.batchNorm2d(1),
            nn.ReLU(),
            nn.Sigmoid()
            )
        return output_layer(y)


In [7]:
# eval

In [8]:
import functools, operator, collections
import torch
import torch.nn.functional as F
import numpy as np

from sklearn.metrics import jaccard_score as jsc
from sklearn.metrics import average_precision_score as avg_pr
from sklearn.cluster import DBSCAN
from sklearn import metrics

import matplotlib.pyplot as plt

from tqdm import tqdm

from skimage.measure import label, regionprops


def eval_net(model,
         loader, 
         device, 
         bbox: bool, 
         verbose: bool,
         iou_thr: int = 0.5
         ):
#     logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s: %(message)s", filename='oneshot_eval.log')
    model.eval()
    logging.info("Validating")

    # Number of batches
    n_val = len(loader)

#     logging.info(f"n_val: {n_val}")
    
    matches = 0

    # Number of bboxes
    max_matches = 0

    if bbox:
        truth_type = "bbox"
    else:
        truth_type = "mask"
    
#     logging.info(f"type of ground truth: {truth_type}")

    precisions, recalls, accuracies = [], [], []

    batch_results = []

    with tqdm(total=n_val, desc='Validation round', unit='samples', disable=not verbose) as bar:
        for batch in loader:
#             logging.info(f"New batch!")
            queries, targets, truth = batch['query'], batch['target'], batch[truth_type]
            queries = queries.to(device=device, dtype=torch.float32)
            targets = targets.to(device=device, dtype=torch.float32)
            # truth = truth.to(device=device, dtype=torch.float32)

            with torch.no_grad():
                pred = model(queries, targets)
#                 logging.info(f"Masks predicted")
#                 print(f"pred: {pred}")
                pred_masks = pred.cpu().numpy()
#                 print(f"pred_masks: {pred_masks}")
#                 logging.info(f"Masks to CPU")
#                 print(f"pred_masks_shape_0: {pred_masks.shape[0]}")
#                 print(f"pred_masks: {pred_masks}")
                # assunzione: gli indici della true e pred masks sono gli stessi
                for mask_index in range(pred_masks.shape[0]):
                    pred_mask = np.asarray(pred_masks[mask_index])
#                     print(f"pred_mask: {pred_mask}")
                    pred_mask = masks_as_image([rle_encode(pred_mask)])
#                     print(f"pred_mask.mask_as_image: {pred_mask}")
                    
                    # mask labeling and conversion to bboxes
                    pred_labels = label(pred_mask)
#                     print(f"pred_labels: {pred_labels}")
                    pred_bboxes_coords = list(map(lambda x: x.bbox, regionprops(pred_labels)))
                    pred_bboxes = calc_bboxes_from_coords(pred_bboxes_coords)
#                     print("pred_bboxes: ", pred_bboxes)

                    # computes truth bboxes in the same way as the pred
                    if bbox:
                        truth_bboxes = np.array(truth[mask_index])
                    else:
                        truth_mask = np.asarray(truth[mask_index])
                        truth_mask = masks_as_image([rle_encode(truth_mask)])
                        true_mask_labels = label(truth_mask)
                        truth_bboxes_coords = list(map(lambda x: x.bbox, regionprops(true_mask_labels)))
                        truth_bboxes = calc_bboxes_from_coords(truth_bboxes_coords)
#                         print("truth_bboxes: ", truth_bboxes)

                    max_matches += len(truth_bboxes)
    
                    logging.info(f"pred_bboxes: {pred_bboxes}")
                    logging.info(f"truth_bboxes: {truth_bboxes}")
                    
                    b_result = get_pred_results(truth_bboxes, pred_bboxes, iou_thr)
                    logging.info(f"b_result: {b_result}")
#                     print(f"b_result: {b_result}")
                    batch_results.append(b_result)
                bar.update(queries.shape[0])
                
#               logging.info(f"Batch finished. batch_results: {batch_results}")

#     print(f"Validation from eval completed")
    # Should not be here, since the eval method is used in both validation and test -> TODO: better handling of the flag.
    model.train()
    
#     print(f"Batch results: {batch_results}")
    
    result = dict(functools.reduce(operator.add, map(collections.Counter, batch_results)))
    logging.info(f"result: {result}")
#     print("result: ", str(result))
    # TODO: KeyError è troppo generico
    try:
        true_pos = result['true_pos']
    except KeyError:
        true_pos = 0
    try:
        false_pos = result['false_pos']
    except KeyError:
        false_pos = 0
    try:
        false_neg = result['false_neg']
    except KeyError:
        false_neg = 0

    precision = calc_precision(true_pos, false_pos)
    recall = calc_recall(true_pos, false_neg)
    accuracy = calc_accuracy(true_pos, false_pos, false_neg)
    output = f"Precision: {precision}    Recall: {recall}    Accuracy: {accuracy}"
    print(output)
    logging.info(output)

    return accuracy


def calc_bboxes_from_coords(bboxes_coords):
    """Calculate all bounding boxes from a set of bounding boxes coordinates"""
    bboxes = []
#     print(f"bboxes_coords: {bboxes_coords}")
    for coord_idx in range(len(bboxes_coords)):
#         print(f"coord_idx: {coord_idx}")
        coord = bboxes_coords[coord_idx]
#         print(f"coord: {coord}")
        bbox = (coord[1], coord[0], int(coord[4])-int(coord[1]), int(coord[3])-int(coord[0]))
        bboxes.append(bbox)
    return bboxes


def get_pred_results(truth_bboxes, pred_bboxes, iou_thr = 0.5):
    """Calculates true_pos, false_pos and false_neg from the input bounding boxes. """
    n_pred_idxs = range(len(pred_bboxes))
    n_truth_idxs = range(len(truth_bboxes))
    if len(n_pred_idxs) == 0:
#         print(f"n_pred_idxs = 0")
        tp = 0
        fp = 0
        fn = len(truth_bboxes)
        return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn}
    if len(n_truth_idxs) == 0:
#         print(f"n_truth_idxs = 0")
        tp = 0
        fp = len(pred_bboxes)
        fn = 0
        return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn}

    truth_idx_thr = []
    pred_idx_thr = []
    ious = []
    for pred_idx, pred_bbox in enumerate(pred_bboxes):
        for truth_idx, truth_bbox in enumerate(truth_bboxes):
            iou = get_jaccard(pred_bbox, truth_bbox)
            if iou > iou_thr:
                truth_idx_thr.append(truth_idx)
                pred_idx_thr.append(pred_idx)
                ious.append(iou)
#     print(f"ious: {ious}")
    # ::-1 reverses the list
    ious_desc = np.argsort(ious)[::-1]
    if len(ious_desc) == 0:
        # No matches
        tp = 0
        fp = len(pred_bboxes)
        fn = len(truth_bboxes)
    else:
        truth_match_idxes = []
        pred_match_idxes = []
        for idx in ious_desc:
            truth_idx = truth_idx_thr[idx]
            pred_idx = pred_idx_thr[idx]
            # If the bboxes are unmatched, add them to matches
            if (truth_idx not in truth_match_idxes) and (pred_idx not in pred_match_idxes):
                truth_match_idxes.append(truth_idx)
                pred_match_idxes.append(pred_match_idxes)
        tp = len(truth_match_idxes)
        fp = len(pred_bboxes) - len(pred_match_idxes)
        fn = len(truth_bboxes) - len(truth_match_idxes)
    return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn}


def calc_precision(true_pos, false_neg):
    try:
        precision = true_pos / (true_pos + false_neg)
    except ZeroDivisionError:
        precision = 0.0
    return precision


def calc_recall(true_pos, false_pos):
    try:
        recall = true_pos / (true_pos + false_pos)
    except ZeroDivisionError:
        recall = 0.0
    return recall


def calc_accuracy(true_pos, false_pos, false_neg):
    try:
        accuracy = true_pos / (true_pos + false_pos + false_neg)
    except ZeroDivisionError:
        accuracy = 0.0
    return accuracy 


def calc_mavg_precision(precision_array):
    return 


def get_jaccard(pred_bbox, truth_bbox):
    pred_mask = get_mask_from_bbox(pred_bbox)
    truth_mask = get_mask_from_bbox(truth_bbox)
    return get_jaccard_from_mask(pred_mask, truth_mask)


def get_jaccard_from_mask(pred_mask, truth):
#     print(f"jsc truth: {truth}")
#     print(f"jsc pred_mask: {pred_mask}")
    return jsc(y_true=truth, y_pred=pred_mask, average='micro')


def get_mask_from_bbox(bbox):
    mask = np.zeros((256, 256))
    x = bbox[0]
    y = bbox[1]
    for width in range(bbox[2]):
        for height in range(bbox[3]):
            mask[int(x) + int(width), int(y) + int(height)] = 1
    return mask


def get_bbox_batch(img):
    bbox = np.empty((img.shape[0], 4))
    for i in range(img.shape[0]):
        bbox[i] = get_bbox(img[i])
    return bbox


def get_bbox(img):
    # img.shape = [batch_size, 1, 256, 256]
    rows = np.any(img, axis=-1)
    cols = np.any(img, axis=-2)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # X, Y, Width, Height
    return [cmin, rmin, cmax-cmin, rmax-rmin]


def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    logging.info(f"pred_img: {img}")
    pixels_old = img.T.flatten()
    pixels = img.T.flatten()
    for x in range(len(pixels_old)):
        if pixels_old[x] > 0.5:
            pixels[x] = 1
        else:
            pixels[x] = 0
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def rle_decode(mask_rle, shape=(256, 256)):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T  # Needed to align to RLE direction


def masks_as_image(in_mask_list, all_masks=None):
    """
    Take the complete rle_encoded mask and create a mask array of the single masks
    """
    if all_masks is None:
        all_masks = np.zeros((256, 256), dtype=np.int16)
    # if isinstance(in_mask_list, list):
    for mask in in_mask_list:
        if isinstance(mask, str):
            all_masks += rle_decode(mask)
    return np.expand_dims(all_masks, -1)

  data = yaml.load(f.read()) or {}


In [9]:
# dataset_loader

In [10]:
import os
import numpy as np
from torch.utils.data import Dataset
import logging
from PIL import Image

import torch

import gzip
import shutil
import h5py
import tables


# TODO: Deve preprocessare anche le immagini di test
# TODO: Ha da funzionà co TorchVision, se hai tempo
# TODO: Visto che le maschere ci servono solo per estrapolare le bbox e confrontarle con quelle stimate, ha senso portarsi tutta la maschera e non le singole bbox?
class BasicDataset(Dataset):
    TARGET_IMAGE_PATH = "target_image_path"
    MASK_IMAGE_PATH = "mask_image_path"
    BBOX_PATH = "bbox_path"
    TARGET_IMAGE_BBOX_PATH = "target_image_bbox_path"

    # TODO: Check if the values are empty
    def __init__(self, imgs_dir: str, masks_dir: str, dataset_name: str, mask_image_dim: int = 256, query_dim: int = 64,
                 bbox_suffix: str = '.bboxes.txt', save_to_disk: bool = False, skip_bbox_lines: int = 0):
        self.imgs_dir = fix_input_dir(imgs_dir)
        self.masks_dir = fix_input_dir(masks_dir)
        self.processed_img_dir = str(self.imgs_dir[:self.imgs_dir.rindex(os.path.sep) + 1]) + "preprocessed"
        self.mask_img_dim = mask_image_dim
        self.query_dim = query_dim
        self.bbox_suffix = bbox_suffix
        self.save_to_disk = save_to_disk
        self.skip_bbox_lines = skip_bbox_lines
        assert mask_image_dim > 1, 'The dimension of mask and image must be higher than 1'
        assert query_dim > 1, 'The dimension of query image must be higher than 1'

        assert os.path.isdir(imgs_dir), f"Bad path for images directory: {imgs_dir}"

        assert os.path.isdir(masks_dir), f"Bad path for masks directory: {masks_dir}"

        if save_to_disk:
            # create processed image's directory, if not exists yet
            try:
                os.mkdir(self.processed_img_dir)
            except FileExistsError:
                # some previous instance generate this directory, no need to raise an exception
                pass

        # list of dict with the path of the images, which contains the paths for the following images:
        #       target, mask, bbox, target's bbox
        # every dict is defined by 4 str keys which have a str value
        # key = type of image
        # value = path of the image

        # List of dict. Every dict refers to an image with 4 keys:
        #       target, mask, bbox, query_target_bbox
        self.images_path = []

        # TODO: Fai in modo che preprocess calcoli sia la maschera che il bbox e poi, in base al dataset, togline uno
        if "FlickrLogos" in dataset_name:
            self.flickrlogos32_load()
        elif dataset_name == "TopLogos-10":
            self.toplogos10_load()

    def __len__(self):
        return len(self.images_path)

    def toplogos10_load(self):

        # get bbox path
        bbox_path = None
        for bbox_paths, _, bbox_list in os.walk(self.masks_dir):
            for bbox_file in bbox_list:
                if self.bbox_suffix in bbox_file:
                    bbox_path = get_class_file_path(bbox_paths, bbox_file)
                    break
            if bbox_path:
                break

        # get query image path
        query_full_image_path = f"{bbox_path[:bbox_path.index(self.bbox_suffix)]}"

        # get target images path and fill "self.images_path"
        for target_images_paths, _, target_images_list in os.walk(self.imgs_dir):
            for target_image_name in target_images_list:
                target_image_root_path, target_image_extension = os.path.splitext(
                    os.path.join(target_images_paths, target_image_name))
                if target_image_extension == ".jpg":
                    self.images_path.append(
                        {self.TARGET_IMAGE_PATH: get_class_file_path(target_images_paths, target_image_name),
                         self.MASK_IMAGE_PATH: None,
                         self.BBOX_PATH: bbox_path,
                         self.TARGET_IMAGE_BBOX_PATH: query_full_image_path})

    def flickrlogos32_load(self):

        # dict with merged masks path
        # key = target image file name
        # value = merged mask's file path
        masks_dict = {}

        # put stuff into masks_dict
        for masks_paths, _, masks_files in os.walk(self.masks_dir):
            for mask_file_name in masks_files:
                _, mask_extension = os.path.splitext(os.path.join(masks_paths, mask_file_name))
                if mask_extension == ".png" and "merged" in mask_file_name:
                    masks_dict[mask_file_name[:mask_file_name.rindex(".mask")]] = get_class_file_path(masks_paths,
                                                                                                      mask_file_name)

        # dict with every image of every class
        # key = class name
        # value = dict with images type and path
        #       key = type of image (target, query, mask)
        #       value = path of the file
        image_path_element = {}

        # put stuff into image_path_element
        for target_images_paths, _, target_images_files in os.walk(self.imgs_dir):
            # TODO: Compare come classe la cartella padre "jpg", trova un modo per risolvere
            target_image_class = target_images_paths[target_images_paths.rindex(os.path.sep) + 1:]
            for target_image_name in target_images_files:
                target_image_root_path, target_image_extension = os.path.splitext(
                    os.path.join(target_images_paths, target_image_name))
                if target_image_extension == ".jpg" and "no-logo" not in target_image_root_path:
                    x = {self.TARGET_IMAGE_PATH: get_class_file_path(target_images_paths, target_image_name),
                         self.MASK_IMAGE_PATH: masks_dict[target_image_name],
                         self.BBOX_PATH: f'{masks_dict[target_image_name][:masks_dict[target_image_name].rindex(".mask")]}{self.bbox_suffix}'}
                    try:
                        image_path_element[target_image_class].append(x)
                    except KeyError:
                        image_path_element[target_image_class] = [x]

        # fill "images_path" variable. it generate every couple (target image, query image) for the same class
        # for now, it only skips couple (target, query) of the same image
        for target_image_class in image_path_element:
            items_class = image_path_element[target_image_class]
            for outer_image in items_class:
                bbox_path = outer_image[self.BBOX_PATH]
                outer_target_image_path = outer_image[self.TARGET_IMAGE_PATH]
                for inner_image in items_class:
                    if not outer_image == inner_image:
                        target_image_path = inner_image[self.TARGET_IMAGE_PATH]
                        mask_image_path = inner_image[self.MASK_IMAGE_PATH]
                        self.images_path.append({self.TARGET_IMAGE_PATH: target_image_path,
                                                 self.MASK_IMAGE_PATH: mask_image_path,
                                                 self.BBOX_PATH: bbox_path,
                                                 self.TARGET_IMAGE_BBOX_PATH: outer_target_image_path})
        print(f"You have {len(self.images_path)} triplets")

    # preprocess the images. then save in file and return a list triplet [query image, target image, mask image]. how?
    # stretch the target image
    # stretch, crop and stretch again the query image
    # stretch the mask image
    # TODO: Check if the values are empty
    @classmethod
    def preprocess(cls, target_img_path: str, bbox_path: str, query_full_img_path: str, skip_bbox_lines: int = 0,
                   img_dim: int = 256, query_img_dim: int = 64, mask_img_path: str = None) -> dict:

        # Target image

        pil_target_img = Image.open(target_img_path)
        # stretch the image
        pil_resized_target_img = pil_target_img.resize((img_dim, img_dim))

        # Query image

        pil_target_img_bbox = Image.open(query_full_img_path)
        pil_resized_target_img_bbox = pil_target_img_bbox.resize((img_dim, img_dim))

        # we will resize, crop and resize again the image but we have the coordinates of the non resized bounding box
        target_img_width, target_img_height = pil_target_img_bbox.size
        resized_target_img_width, resized_target_img_height = pil_resized_target_img_bbox.size
        percent_width = round(100 * int(resized_target_img_width) / (int(target_img_width)), 2) / 100
        percent_height = round(100 * int(resized_target_img_height) / (int(target_img_height)), 2) / 100

        # open the bounding box file
        with open(bbox_path) as bbox_file:
            # read only the first line of the bbox file
            bbox_lines = bbox_file.readlines()
            first_line_bbox = bbox_lines[1 - skip_bbox_lines].split(' ')
            # check if we correctly skipped the first line of the file, the one with no number,
            # and if all the elements are numeric, like every coordinate should be ;)
            if first_line_bbox[0].isnumeric() and first_line_bbox[1].isnumeric() and \
                    first_line_bbox[2].isnumeric() and first_line_bbox[3].rstrip().isnumeric():
                x, y, width, height = first_line_bbox
                # adapt the old coordinates to the new stretched dimension
                left = int(x.strip()) * percent_width
                upper = int(y.strip()) * percent_height
                right = int(int(x.strip()) + int(width)) * percent_width
                lower = int(int(y.strip()) + int(height)) * percent_height
                # crop and resize the query image
                pil_query_img = pil_resized_target_img_bbox.crop((left, upper, right, lower))
                pil_resized_query_img = pil_query_img.resize((query_img_dim, query_img_dim))
            else:
                # TODO: nel traceback compare "error_string" e poi successivamente spiega l'eccezione. Trova un modo per togliere quel "error_string"
                error_string = f"Bounding box file's first line should have 4 groups of integers with whitespace " \
                               f"separator. Check {bbox_path}"
                raise Exception(error_string)

        # Mask

        if mask_img_path:
            pil_mask = Image.open(mask_img_path)
            pil_resized_mask = pil_mask.resize((img_dim, img_dim))
        else:
            pil_resized_mask = None

        # get the size of the images
        # print(f"query image dim: {pil_resized_query_img.size}")
        # print(f"target image dim: {pil_resized_target_img.size}")
        # if pil_resized_mask:
        #     print(f"mask image dim: {pil_resized_mask.size}")

        # just to test if everything works. don't look at these :)
        # if pil_resized_target_img:
        #     pil_resized_target_img.save('target.jpg')
        # if pil_resized_query_img:
        #     pil_resized_query_img.save('query.jpg')
        # if pil_resized_mask:
        #     pil_resized_mask.save('mask.jpg')

        # return the triplet (Dq, Dt, Dm) where Dq is the query image, Dt is the target image and Dm is the mask image
        return create_triplet_with_torch_representation(pil_resized_query_img,
                                                        pil_resized_target_img,
                                                        pil_resized_mask)

    # def h5py_with_pytorch(self, pil_img, index, type):
    #     x = self.h5py_compression(to_pytorch(pil_img), index, type)
    #     return x
    #
    # def h5py_without_pytorch(self, pil_img, index, type):
    #     x = self.h5py_compression(pil_img, index, type)
    #     return x

    # def store_hdf5_file_with_compression(self, image, image_index, image_type):
    #     file_name = f'{self.processed_img_dir}{os.path.sep}{image_index}_{image_type}.hdf5'
    #     f = h5py.File(file_name, "w")
    #     # TODO: Esistono altri algoritmi di compressione come Mafisc. Una roba figa che puoi usare è Bitshuffle
    #     f.create_dataset("init", compression="gzip", compression_opts=9, data=image)
    #     f.close()
    #     return file_name

    def store_hdf5_file_with_compression(self, images, image_index):
        image_type = ["query", "target", "mask"]
        image_type_index = 0
        for image in images:
            file_name = f'{self.processed_img_dir}{os.path.sep}{image_index}_{image}.hdf5'
            f = h5py.File(file_name, "w")
            # TODO: Esistono altri algoritmi di compressione come Mafisc. Una roba figa che puoi usare è Bitshuffle
            f.create_dataset("init", compression="gzip", compression_opts=9, data=images[image])
            f.close()
            image_type_index += 1
        # for image in images:
        #     file_name = f'{self.processed_img_dir}{os.path.sep}{image_index}_{image_type[image_type_index]}.hdf5'
        #     f = h5py.File(file_name, "w")
        #     # TODO: Esistono altri algoritmi di compressione come Mafisc. Una roba figa che puoi usare è Bitshuffle
        #     f.create_dataset("init", compression="gzip", compression_opts=9, data=image)
        #     f.close()
        #     image_type_index += 1
        # return file_name

    # def gzip_compress(self, index, input_file):
    #     # input_file = f'{self.processed_img_dir}{os.path.sep}{index}.npz'
    #     with open(input_file, 'rb') as f_in:
    #         output_file = f'{input_file}.gz'
    #         with gzip.open(output_file, 'wb', compresslevel=9) as f_out:
    #             shutil.copyfileobj(f_in, f_out)
    #     # if os.path.exists(input_file):
    #     #     os.remove(input_file)
    #     return output_file

    # def gzip_compress(self, index, input_file):
    #     # input_file = f'{self.processed_img_dir}{os.path.sep}{index}.npz'
    #     output_file = f'{input_file}.gz'
    #     with gzip.open(output_file, 'wb', compresslevel=1) as f_out:
    #         with open(input_file, 'rb') as f_in:
    #             shutil.copyfileobj(f_in, f_out)
    #     # if os.path.exists(input_file):
    #     #     os.remove(input_file)
    #     return output_file

    # def gzip_uncompress(self, input_file):
    #     with gzip.open(input_file, 'rb') as f:
    #         file_content = f.read()
    #     output_file = input_file[:input_file.rindex('.')]
    #     with open(output_file, mode='wb') as fp:
    #         fp.write(file_content)
    #     return output_file

    def read_hdf5_file(self, hdf5_file):
        with h5py.File(hdf5_file, 'r') as hf:
            data = hf.get('init')
            data = np.array(data)
        return data

    # def np_save_compressed(self, index, triplet_list_in_torch_representation):
    #     file_name = f'{self.processed_img_dir}{os.path.sep}{index}'
    #     # save the file so the next time you don't have to preprocess again
    #     np.savez_compressed(file_name,
    #                         query=triplet_list_in_torch_representation[0],
    #                         target=triplet_list_in_torch_representation[1],
    #                         mask=triplet_list_in_torch_representation[2])
    #     return f'{file_name}.npz'

    def __getitem__(self, item_index):
#         print(f"Getting item {item_index}")
        # get the path of the preprocessed file, if exists
        mask_file_path = f'{self.processed_img_dir}{os.path.sep}{item_index}_mask.hdf5'
        query_file_path = f'{self.processed_img_dir}{os.path.sep}{item_index}_query.hdf5'
        target_file_path = f'{self.processed_img_dir}{os.path.sep}{item_index}_target.hdf5'

        correct_order_triplet = [query_file_path, target_file_path, mask_file_path]
        triplet_element_order = ["query", "target", "mask"]

        return_dict = {}
        for file in correct_order_triplet:
            if not os.path.exists(file):
                # triplet = self.preprocess(item_index, self.images_path[item_index])
                return_dict = self.preprocess(
                    target_img_path=get_full_path(self.imgs_dir, self.images_path[item_index][self.TARGET_IMAGE_PATH]),
                    bbox_path=get_full_path(self.masks_dir, self.images_path[item_index][self.BBOX_PATH]),
                    query_full_img_path=get_full_path(self.imgs_dir,
                                                      self.images_path[item_index][self.TARGET_IMAGE_BBOX_PATH]),
                    mask_img_path=get_full_path(self.masks_dir, self.images_path[item_index][self.MASK_IMAGE_PATH]),
                    skip_bbox_lines=self.skip_bbox_lines)
                if self.save_to_disk:
                    self.store_hdf5_file_with_compression(return_dict, item_index)
                break
        else:
            triplet_index = 0
            for file in correct_order_triplet:
                if os.path.exists(file):
                    hdf5_file = self.read_hdf5_file(file)
                    return_dict[triplet_element_order[triplet_index]] = to_pytorch(hdf5_file)
                triplet_index += 1
#         print(f"Ciao, sono un fantastico {return_dict}")
        return return_dict

        # for file in correct_order_triplet:
        # # check if preprocessed file exists. if not, he will generate it. then return the triplet
        #     if os.path.exists(file):
        #         data = np.load(file, mmap_mode='r')
        #         return_tuple = np.array([data['query'], data['target'], data['mask']])
        #     else:
        #         return_tuple = self.preprocess(i, self.images_path[i])
        # return return_tuple


# something that will be deleted
# class CarvanaBasicDataset(BasicDataset):
#     def __init__(self, imgs_dir, masks_dir, scale=1):
#         super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')


def to_pytorch(image):
    if image:
        image_np = np.array(image)
        # mask image has only one channel, we need to explicit it
        if len(image_np.shape) == 2:
            image_np = np.expand_dims(image_np, axis=2)
        # HWC to CHW for pytorch
        img_trans = image_np.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255
        return torch.from_numpy(img_trans).type(torch.FloatTensor)
        # return img_trans
    else:
        return None


# dude, the name says all. just read it :/
# def create_triplet_without_torch_representation(pil_query, pil_target, pil_mask):
#     # return [np.array(pil_query), np.array(pil_target), np.array(pil_mask)]
#     # return np.array([np.array(pil_query), np.array(pil_target), np.array(pil_mask)])
#     return {
#         "query": np.array(pil_query),
#         "target": np.array(pil_target),
#         "mask": np.array(pil_mask)
#     }

def create_triplet_with_torch_representation(pil_query, pil_target, pil_mask):
    # return [to_pytorch(pil_query), to_pytorch(pil_target), to_pytorch(pil_mask)]
    # return np.array([to_pytorch(pil_query), to_pytorch(pil_target), to_pytorch(pil_mask)])
    return {
        "query": to_pytorch(pil_query),
        "target": to_pytorch(pil_target),
        "mask": to_pytorch(pil_mask)
    }


def get_class_file_path(class_name, file_name):
    class_file = f"{class_name[class_name.rindex(os.path.sep):]}{os.path.sep}{file_name}"
    if os.path.sep not in class_file[0:2]:
        class_file = f"{os.path.sep}{class_file}"
    return class_file


def get_full_path(root, file):
    try:
        if root[root.rindex(os.path.sep) + 1:].strip() == file[1:file.index(os.path.sep, 1)].strip():
            path = f"{root[:root.rindex(os.path.sep)]}{file}"
        else:
            path = f"{root}{file}"
        return path
    except AttributeError:
        return None


def fix_input_dir(dir):
    if not dir.strip()[-1:] == os.path.sep:
        return dir.strip()
    return dir.strip()[:-1]


In [11]:
# train

In [12]:
import argparse
import logging
import os
import sys
from tqdm import tqdm
import yaml

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F

# from model.model import LogoDetection
# from utils.dataset_loader import BasicDataset

# todo: when we add more models, we should move these variables to another location
MODEL_HOME = os.path.abspath("./stored_models/")
ALL_MODEL_NAMES = ["LogoDetection"]
ALL_DATASET_NAMES = ["FlickrLogos-32", "FlickrLogos-32-test", "TopLogos-10"]

with open(os.path.abspath("./config/config.yaml")) as config:
    config_list = yaml.load(config, Loader=yaml.FullLoader)


# # Appl.load_aly Gaussian normalization to the model
# def weights_init(model):
#     if isinstance(model, nn.Module):
#         nn.init.normal_(model.weight.data, mean=0.0, std=0.01)

def train(model,
          device,
          train_loader,
          val_loader,
          max_epochs,
          optimizer,
          vgg_cfg,
          verbose,
          checkpoint_dir,
          model_path,
          save_cp,
          n_train,
          n_val,
          step_eval
          ):
    batch_size = train_loader.batch_size

    # Logging for TensorBoard
    writer = SummaryWriter(
        comment=f'LR__BS_{batch_size}_OPT_{type(optimizer).__name__}')  # does optimizer.lr work? we're gonne find out
    global_step = 0

    ### ERROR: n_train e n_val? ###
    logging.info(f'''Starting training:
        Epochs:             {max_epochs}
        Batch size:         {batch_size}
        Learning rate:      
        Training size:      {n_train}
        Validation size:    {n_val}
        Device:             {device.type}
    ''')

    #    def criterion(pred, true):
    #        return torch.div(nn.BCELoss()(pred, true), 1*1) # L = (1/(H*W)) * BCELoss
    #        return nn.BCELoss()(pred, true)
    # TypeError: unsupported operand type(s) for /: 'BCELoss' and 'int'
    """    
    # now we don't want to train every time to save some minutes. we'll skip the train if there is a model file
    if os.path.exists(checkpoint_dir) and os.path.isdir(checkpoint_dir):
        if os.listdir(checkpoint_dir):
            model.load_state_dict(torch.load(checkpoint_dir + os.path.sep + "model.pt", map_location=device))
            return eval_net(model, val_loader, device, bbox=False, verbose=True)
    """
    criterion = nn.BCEWithLogitsLoss()
    
    last_epoch_val_score = 0
    for epoch in range(max_epochs):
        logging.info(f"Epoch number {epoch}")
        model.train()  # set the model in training flag to True
        epoch_loss = 0  # resets the loss for the current epoch
        # epoch(batch_size, train_samples)

        # TODO
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{max_epochs}', unit='img', disable=not verbose) as bar:
            bar.set_description(f'train loss')

            for batch in train_loader:
#                 np.set_printoptions(threshold=sys.maxsize)
#                 print(f"Batch: {batch}")
#                 logging.info(f"Batch number #")
                queries = batch['query']  # Correct dimensions?
                targets = batch['target']
                true_masks = batch['mask']

                queries = queries.to(device=device, dtype=torch.float32)
                targets = targets.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)

#                 logging.info(f"Sending imgs to model")
#                 logging.info(f"queries: {queries}")
#                 logging.info(f"targets: {targets}")
                pred_masks = model(queries, targets)
#                 logging.info(f"model: {model}")
#                 logging.info(f"Mask predicted")
                # print(pred_masks.shape)
                loss = criterion(pred_masks, true_masks)
                epoch_loss += loss.detach().item()  # is the .detach() needed?

                # TensorBoard logging
                writer.add_scalar('Loss/train', loss.item(), global_step)

                bar.set_postfix(loss=f'{loss.item():.5f}')
                logging.info(f"Loss: {loss.item()}")

                optimizer.zero_grad()
                loss.backward()
                # nn.utils.clip_grad_value_(net.parameters(), 0.1) Gradient Clipping
                optimizer.step()

                bar.update(queries.shape[0])
                global_step += 1

                # if n_train % batch_size == 0: 
                #     n_batch = n_train // batch_size
                # else:
                #     n_batch = n_train // batch_size + 1

                ### DOMANDA: Dove lo volevamo usare? ###
                ### ALTRA DOMANDA: Non conviene farlo fuori dai cicli? ###
#                 n_batch = len(train_loader)
                # Deve farlo sia in mezzo ai batch che a fine epoca. Modifica la condizione dell'if
#                 if global_step % (n_train // (10 * batch_size)) == 0 or global_step == n_batch:
#                     for tag, value in model.named_parameters():
#                         tag = tag.replace('.', '/')
#                         writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
#                         writer.add_histogram('grads/' + tag, value.grad.cpu().numpy(), global_step)
#                     writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

#                     writer.add_images('query_images', queries, global_step)
#                     writer.add_images('target_images', targets, global_step)
#                     writer.add_images('masks/true', true_masks, global_step)
#                     writer.add_images('masks/pred', pred_masks, global_step)

            # TODO: Se save_cp è false e non viene cambiato il val_split di default non va in train il 10% del dataset. Si potrebbe fare in modo che non sia così
            if save_cp and (epoch + 1) % step_eval == 0:
                """
                logging.info("Saving model")
                try:
                    os.mkdir(checkpoint_dir)
                    logging.info('Created checkpoint directory')
                except OSError:  # Maybe FileExistsError ?
                    pass
                model_files = [f for f in os.listdir(checkpoint_dir) if
                               os.path.isfile(os.path.join(checkpoint_dir, f))]
                torch.save(model.state_dict(), checkpoint_dir + os.path.sep + f'model.pt')
                for model_file in model_files:
                    os.remove(f'{checkpoint_dir}{os.path.sep}{model_file}')
                """
#                 logging.info(f"Next operation is validation")
#                 print(f"starting validation from train")
                logging.info(f"Validating")
                val_score = eval_net(model, val_loader, device, bbox=False, verbose=True)
#                 print(f"validation completed")
#                 logging.info(f"Validation complete")
                if val_score > last_epoch_val_score:
                    logging.info(f"This model is better the last one, I'm gonna save it")
                    try:
                        os.mkdir(checkpoint_dir)
                        logging.info('Created checkpoint directory')
                    except OSError:  # Maybe FileExistsError ?
                        pass
                    model_files = [f for f in os.listdir(checkpoint_dir) if
                                   os.path.isfile(os.path.join(checkpoint_dir, f))]
                    torch.save(model.state_dict(), checkpoint_dir + os.path.sep + f'CP_epoch{epoch + 1}.pt')
                    logging.info(f'Checkpoint {epoch + 1} saved!')
                    for model_file in model_files:
                        os.remove(f'{checkpoint_dir}{os.path.sep}{model_file}')
                    last_epoch_val_score = val_score

    writer.close()
    torch.save(model.state_dict(), model_path)

    # # WIP
    # # Launches evaluation on the model every evaluate_every steps.
    # # We need to change to appropriate evaluation metrics.
    # if evaluate_every > 0 and valid_samples is not None and (e + 1) % evaluate_every == 0:
    #     self.model.eval()
    #     with torch.no_grad():
    #         mrr, h1 = self.evaluator.eval(samples=valid_samples, write_output= False)

    #     # Metrics printing
    #     print("\tValidation: %f" % h1)

    # if save_path is not None:
    #     print("\tSaving model...")
    #     torch.save(self.model.state_dict(), save_path)
    # print("\tDone.")


# print("\nEvaluating model...")
# model.eval()
# mrr, h1 = Evaluator(model=model).eval(samples=dataset.test_samples, write_output=False)
# print("\tTest Hits@1: %f" % h1)
# print("\tTest Mean Reciprocal Rank: %f" % mrr)


# def get_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--dataset',
#                         choices=ALL_DATASET_NAMES,
#                         default="FlickrLogos-32",
#                         type=str,
#                         help="Dataset in {}".format(ALL_DATASET_NAMES)
#                         )

#     parser.add_argument('--model',
#                         choices=ALL_MODEL_NAMES,
#                         default="LogoDetection",
#                         type=str,
#                         help="Model in {}".format(ALL_MODEL_NAMES)
#                         )

#     optimizers = ['Adam', 'SGD']
#     parser.add_argument('--optimizer',
#                         choices=optimizers,
#                         default='Adam',
#                         help="Optimizer in {}".format(optimizers)
#                         )

#     parser.add_argument('--max_epochs',
#                         default=500,
#                         type=int,
#                         help="Number of epochs"
#                         )

#     parser.add_argument('--valid',
#                         default=-1,
#                         type=float,
#                         help="Number of epochs before valid"
#                         )

#     parser.add_argument('--batch_size',
#                         default=32,
#                         type=int,
#                         help="Number of samples in each mini-batch in SGD and Adam optimization"
#                         )

#     parser.add_argument('--weight_decay',
#                         default=5e-4,
#                         type=float,
#                         help="L2 weight regularization of the optimizer"
#                         )

#     parser.add_argument('--learning_rate',
#                         default=4e-4,
#                         type=float,
#                         help="Learning rate of the optimizer"
#                         )

#     parser.add_argument('--label_smooth',
#                         default=0.1,
#                         type=float,
#                         help="Label smoothing for true labels"
#                         )

#     parser.add_argument('--decay1',
#                         default=0.9,
#                         type=float,
#                         help="Decay rate for the first momentum estimate in Adam"
#                         )

#     parser.add_argument('--decay2',
#                         default=0.999,
#                         type=float,
#                         help="Decay rate for second momentum estimate in Adam"
#                         )

#     parser.add_argument('--verbose',
#                         default=True,
#                         type=bool,
#                         help="Verbose"
#                         )

#     parser.add_argument('--load',
#                         type=str,
#                         required=False,
#                         help="Path to the model to load"
#                         )

#     parser.add_argument('--batch_norm',
#                         default=False,
#                         type=bool,
#                         help="If True, apply batch normalization",
#                         )

#     parser.add_argument('--vgg_cfg',
#                         type=str,
#                         default='A',
#                         help="VGG architecture config",
#                         )

#     parser.add_argument('--step_eval',
#                         type=int,
#                         default=0,
#                         help="Enables automatic evaluation checks every X step",
#                         )

#     parser.add_argument('--val_split',
#                         type=float,
#                         default=0.1,
#                         help="Forces the validation subset to be split according to the set value. Must a value in the [0-1] or the sofware WILL break",
#                         )

#     parser.add_argument('--save_cp',
#                         type=bool,
#                         default=True,
#                         help="If True, saves model checkponts",
#                         )

#     return parser.parse_args()


def train_main(dataset='FlickrLogos-32',
     model='LogoDetection',
     optimizer='Adam',
     vgg_cfg='A',
     max_epochs=1, 
     batch_size=4,
     weight_decay=5e-4,
     learning_rate=4e-4,
     decay1=0.9, 
     decay2=0.999,
     verbose=True,
     batch_norm=True,
     load=None,
     val_split=0.1,
     step_eval=10,
     save_cp=True,
     ):
    # TODO: Add filename
    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s: %(message)s", filename='oneshot.log', filemode='w')
#     args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Modularized paths with respect to the current Dataset
    imgs_dir = config_list['datasets'][dataset]['paths']['images']
    masks_dir = config_list['datasets'][dataset]['paths']['masks']
    checkpoint_dir = config_list['models'][model]['paths']['train_cp']

    model_path = config_list['models'][model]['paths']['model']+ os.path.sep + "_".join([model, dataset]) + ".pt"
    
    # create checkpoint dir
    try:
        os.makedirs(checkpoint_dir, exist_ok=True)
    except FileExistsError:
        pass
    
    # create the model dir
    try:
        os.makedirs(config_list['models'][model]['paths']['model'], exist_ok=True)
    except FileExistsError:
        pass

    print("Loading %s dataset..." % dataset)
    # you can delete this "save_to_disk" to preserve the ssd :like:
    dataset = BasicDataset(imgs_dir=imgs_dir, masks_dir=masks_dir, dataset_name=dataset)

    # Splitting dataset
    n_val = int(len(dataset) * val_split)
    n_train = len(dataset) - n_val
    # TODO: Il validation set dovrebbe avere il 10% di ogni classe e non il 10% del totale altrimenti verrebbe sbilanciato
    train_set, val_set = random_split(dataset, [n_train, n_val])

    # Loading dataset
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True,
                            drop_last=True)

    # Change here to adapt your data
    print("Initializing model...")
    model = LogoDetection(batch_norm=batch_norm, vgg_cfg=vgg_cfg)

    # Optimizer selection
    # build all the supported optimizers using the passed params (learning rate and decays if Adam)
    supported_optimizers = {
        'Adam': optim.Adam(params=model.parameters(), lr=learning_rate, betas=(decay1, decay2),
                           weight_decay=weight_decay),
        'SGD': optim.SGD(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    }
    # Choose which Torch Optimizer object to use, based on the passed name
    optimizer = supported_optimizers[optimizer]

    # stiamo dando ad "args.load" due compiti, quello di dirci il path e quello di dirci se caricare vecchi checkpoint
    if load is not None:
        model.load_state_dict(
            torch.load(load, map_location=device)
        )
        logging.info(f'Model loaded from {load}')
    model.to(device=device)

    try:
        train(model=model,
              device=device,
              train_loader=train_loader,
              val_loader=val_loader,
              max_epochs=max_epochs,
              optimizer=optimizer,
              vgg_cfg=vgg_cfg,
              verbose=verbose,
              checkpoint_dir=checkpoint_dir,
              model_path=model_path,
              save_cp=save_cp,
              n_train=n_train,
              step_eval=step_eval,
              n_val=n_val
              )
    except KeyboardInterrupt:
        torch.save(model.state_dict(), config_list['models'][model]['paths']['model'] + os.path.sep + 'INTERRUPTED.pt')
        logging.info('Interrupt saved')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [None]:
train_main(dataset='FlickrLogos-32',
     model='LogoDetection',
     optimizer='Adam',
     vgg_cfg='B',
     max_epochs=500, 
     batch_size=16, 
     weight_decay=5e-4,
     learning_rate=4e-5,
     decay1=0.9, 
     decay2=0.999,
     verbose=True,
     batch_norm=True,
     load=None,
     val_split=0.1,
     step_eval=1,
     save_cp=True,
     )

Loading FlickrLogos-32 dataset...
You have 154560 triplets
Initializing model...


train loss: 100%|██████████| 139104/139104 [1:08:40<00:00, 43.99img/s, loss=0.01473]
train loss: 100%|██████████| 139104/139104 [1:08:51<00:00, 43.99img/s, loss=0.01473]
Validation round:   2%|▏         | 16/966 [00:12<12:44,  1.24samples/s][A
Validation round:   3%|▎         | 32/966 [00:22<11:34,  1.34samples/s][A
Validation round:   5%|▍         | 48/966 [00:32<10:43,  1.43samples/s][A
Validation round:   7%|▋         | 64/966 [00:40<09:52,  1.52samples/s][A
Validation round:   8%|▊         | 80/966 [00:52<09:57,  1.48samples/s][A
Validation round:  10%|▉         | 96/966 [01:03<09:59,  1.45samples/s][A
Validation round:  12%|█▏        | 112/966 [01:14<09:35,  1.48samples/s][A
Validation round:  13%|█▎        | 128/966 [01:23<09:07,  1.53samples/s][A
Validation round:  15%|█▍        | 144/966 [01:33<08:48,  1.55samples/s][A
Validation round:  17%|█▋        | 160/966 [01:43<08:34,  1.57samples/s][A
Validation round:  18%|█▊        | 176/966 [01:54<08:31,  1.55samples/s][A


Validation round: 1936samples [16:03,  3.21samples/s][A
Validation round: 1952samples [16:07,  3.28samples/s][A
Validation round: 1968samples [16:12,  3.27samples/s][A
Validation round: 1984samples [16:17,  3.25samples/s][A
Validation round: 2000samples [16:22,  3.30samples/s][A
Validation round: 2016samples [16:27,  3.33samples/s][A
Validation round: 2032samples [16:32,  3.28samples/s][A
Validation round: 2048samples [16:36,  3.32samples/s][A
Validation round: 2064samples [16:41,  3.34samples/s][A
Validation round: 2080samples [16:46,  3.30samples/s][A
Validation round: 2096samples [16:51,  3.33samples/s][A
Validation round: 2112samples [16:55,  3.37samples/s][A
Validation round: 2128samples [17:01,  3.26samples/s][A
Validation round: 2144samples [17:05,  3.32samples/s][A
Validation round: 2160samples [17:11,  3.21samples/s][A
Validation round: 2176samples [17:16,  3.21samples/s][A
Validation round: 2192samples [17:22,  2.97samples/s][A
Validation round: 2208samples [

Validation round: 4224samples [28:44,  2.77samples/s][A
Validation round: 4240samples [28:49,  2.91samples/s][A
Validation round: 4256samples [28:53,  3.03samples/s][A
Validation round: 4272samples [28:58,  3.16samples/s][A
Validation round: 4288samples [29:03,  3.24samples/s][A
Validation round: 4304samples [29:07,  3.29samples/s][A
Validation round: 4320samples [29:13,  3.23samples/s][A
Validation round: 4336samples [29:17,  3.27samples/s][A
Validation round: 4352samples [29:22,  3.31samples/s][A
Validation round: 4368samples [29:27,  3.32samples/s][A
Validation round: 4384samples [29:32,  3.33samples/s][A
Validation round: 4400samples [29:37,  3.12samples/s][A
Validation round: 4416samples [29:42,  3.15samples/s][A
Validation round: 4432samples [29:47,  3.23samples/s][A
Validation round: 4448samples [29:53,  3.02samples/s][A
Validation round: 4464samples [29:58,  3.14samples/s][A
Validation round: 4480samples [30:02,  3.23samples/s][A
Validation round: 4496samples [

Validation round: 6512samples [43:45,  1.54samples/s][A
Validation round: 6528samples [43:56,  1.54samples/s][A
Validation round: 6544samples [44:06,  1.54samples/s][A
Validation round: 6560samples [44:17,  1.53samples/s][A
Validation round: 6576samples [44:28,  1.51samples/s][A
Validation round: 6592samples [44:37,  1.55samples/s][A
Validation round: 6608samples [44:48,  1.52samples/s][A
Validation round: 6624samples [44:59,  1.53samples/s][A
Validation round: 6640samples [45:09,  1.54samples/s][A
Validation round: 6656samples [45:19,  1.55samples/s][A
Validation round: 6672samples [45:30,  1.54samples/s][A
Validation round: 6688samples [45:39,  1.57samples/s][A
Validation round: 6704samples [45:49,  1.62samples/s][A
Validation round: 6720samples [46:01,  1.52samples/s][A
Validation round: 6736samples [46:10,  1.57samples/s][A
Validation round: 6752samples [46:21,  1.54samples/s][A
Validation round: 6768samples [46:32,  1.51samples/s][A
Validation round: 6784samples [

Validation round: 8784samples [1:03:34,  3.24samples/s][A
Validation round: 8800samples [1:03:40,  3.03samples/s][A
Validation round: 8816samples [1:03:45,  3.15samples/s][A
Validation round: 8832samples [1:03:50,  3.09samples/s][A
Validation round: 8848samples [1:03:55,  3.17samples/s][A
Validation round: 8864samples [1:04:00,  3.24samples/s][A
Validation round: 8880samples [1:04:05,  3.23samples/s][A
Validation round: 8896samples [1:04:09,  3.31samples/s][A
Validation round: 8912samples [1:04:14,  3.33samples/s][A
Validation round: 8928samples [1:04:19,  3.21samples/s][A
Validation round: 8944samples [1:04:24,  3.26samples/s][A
Validation round: 8960samples [1:04:29,  3.30samples/s][A
Validation round: 8976samples [1:04:34,  3.29samples/s][A
Validation round: 8992samples [1:04:38,  3.32samples/s][A
Validation round: 9008samples [1:04:43,  3.33samples/s][A
Validation round: 9024samples [1:04:48,  3.34samples/s][A
Validation round: 9040samples [1:04:54,  3.05samples/s]

Validation round: 10976samples [1:15:25,  3.22samples/s][A
Validation round: 10992samples [1:15:31,  3.19samples/s][A
Validation round: 11008samples [1:15:35,  3.27samples/s][A
Validation round: 11024samples [1:15:40,  3.28samples/s][A
Validation round: 11040samples [1:15:45,  3.30samples/s][A
Validation round: 11056samples [1:15:50,  3.28samples/s][A
Validation round: 11072samples [1:15:55,  3.27samples/s][A
Validation round: 11088samples [1:16:00,  3.26samples/s][A
Validation round: 11104samples [1:16:04,  3.27samples/s][A
Validation round: 11120samples [1:16:09,  3.26samples/s][A
Validation round: 11136samples [1:16:14,  3.30samples/s][A
Validation round: 11152samples [1:16:19,  3.35samples/s][A
Validation round: 11168samples [1:16:24,  3.31samples/s][A
Validation round: 11184samples [1:16:29,  3.23samples/s][A
Validation round: 11200samples [1:16:34,  3.26samples/s][A
Validation round: 11216samples [1:16:38,  3.28samples/s][A
Validation round: 11232samples [1:16:44,

Validation round: 13152samples [1:32:58,  1.64samples/s][A
Validation round: 13168samples [1:33:09,  1.58samples/s][A
Validation round: 13184samples [1:33:18,  1.59samples/s][A
Validation round: 13200samples [1:33:28,  1.60samples/s][A
Validation round: 13216samples [1:33:38,  1.60samples/s][A
Validation round: 13232samples [1:33:48,  1.59samples/s][A
Validation round: 13248samples [1:33:58,  1.60samples/s][A
Validation round: 13264samples [1:34:09,  1.59samples/s][A
Validation round: 13280samples [1:34:18,  1.62samples/s][A
Validation round: 13296samples [1:34:28,  1.64samples/s][A
Validation round: 13312samples [1:34:37,  1.63samples/s][A
Validation round: 13328samples [1:34:48,  1.58samples/s][A
Validation round: 13344samples [1:34:58,  1.62samples/s][A
Validation round: 13360samples [1:35:07,  1.62samples/s][A
Validation round: 13376samples [1:35:17,  1.64samples/s][A
Validation round: 13392samples [1:35:28,  1.56samples/s][A
Validation round: 13408samples [1:35:39,

Validation round: 15328samples [1:54:11,  2.99samples/s][A
Validation round: 15344samples [1:54:16,  3.05samples/s][A
Validation round: 15360samples [1:54:21,  3.18samples/s][A
Validation round: 15376samples [1:54:25,  3.24samples/s][A
Validation round: 15392samples [1:54:30,  3.35samples/s][A
Validation round: 15408samples [1:54:35,  3.24samples/s][A
Validation round: 15424samples [1:54:41,  3.09samples/s][A
Validation round: 15440samples [1:54:46,  3.15samples/s][A
Validation round: 15456samples [1:54:51,  2.24samples/s][A


Precision: 0.9214514978601998    Recall: 0.8886500429922614    Accuracy: 0.8260730557109743


train loss: 100%|██████████| 139104/139104 [3:03:32<00:00, 12.63img/s, loss=0.01473]
train loss: 100%|██████████| 139104/139104 [59:27<00:00, 42.98img/s, loss=0.00982] 
train loss: 100%|██████████| 139104/139104 [59:40<00:00, 42.98img/s, loss=0.00982]
Validation round:   2%|▏         | 16/966 [00:13<13:18,  1.19samples/s][A
Validation round:   3%|▎         | 32/966 [00:23<11:57,  1.30samples/s][A
Validation round:   5%|▍         | 48/966 [00:32<10:52,  1.41samples/s][A
Validation round:   7%|▋         | 64/966 [00:45<11:09,  1.35samples/s][A
Validation round:   8%|▊         | 80/966 [00:55<10:24,  1.42samples/s][A
Validation round:  10%|▉         | 96/966 [01:04<09:46,  1.48samples/s][A
Validation round:  12%|█▏        | 112/966 [01:14<09:24,  1.51samples/s][A
Validation round:  13%|█▎        | 128/966 [01:24<09:02,  1.54samples/s][A
Validation round:  15%|█▍        | 144/966 [01:33<08:32,  1.60samples/s][A
Validation round:  17%|█▋        | 160/966 [01:44<08:31,  1.58samples/

Validation round: 1920samples [13:28,  2.73samples/s][A
Validation round: 1936samples [13:33,  2.83samples/s][A
Validation round: 1952samples [13:41,  2.59samples/s][A
Validation round: 1968samples [13:47,  2.55samples/s][A
Validation round: 1984samples [13:54,  2.50samples/s][A
Validation round: 2000samples [14:02,  2.32samples/s][A
Validation round: 2016samples [14:08,  2.40samples/s][A
Validation round: 2032samples [14:13,  2.58samples/s][A
Validation round: 2048samples [14:18,  2.76samples/s][A
Validation round: 2064samples [14:23,  2.93samples/s][A
Validation round: 2080samples [14:29,  2.77samples/s][A
Validation round: 2096samples [14:39,  2.28samples/s][A
Validation round: 2112samples [14:45,  2.36samples/s][A
Validation round: 2128samples [14:51,  2.43samples/s][A
Validation round: 2144samples [14:57,  2.52samples/s][A
Validation round: 2160samples [15:04,  2.51samples/s][A
Validation round: 2176samples [15:10,  2.53samples/s][A
Validation round: 2192samples [

Validation round: 4208samples [27:40,  2.99samples/s][A
Validation round: 4224samples [27:46,  2.88samples/s][A
Validation round: 4240samples [27:51,  2.97samples/s][A
Validation round: 4256samples [27:56,  3.05samples/s][A
Validation round: 4272samples [28:01,  3.11samples/s][A
Validation round: 4288samples [28:06,  3.17samples/s][A
Validation round: 4304samples [28:12,  2.99samples/s][A
Validation round: 4320samples [28:18,  2.92samples/s][A
Validation round: 4336samples [28:24,  2.78samples/s][A
Validation round: 4352samples [28:30,  2.73samples/s][A
Validation round: 4368samples [28:37,  2.67samples/s][A
Validation round: 4384samples [28:42,  2.72samples/s][A
Validation round: 4400samples [28:47,  2.88samples/s][A
Validation round: 4416samples [28:52,  2.98samples/s][A
Validation round: 4432samples [28:58,  2.89samples/s][A
Validation round: 4448samples [29:03,  2.98samples/s][A
Validation round: 4464samples [29:10,  2.68samples/s][A
Validation round: 4480samples [

Validation round: 6496samples [47:00,  1.55samples/s][A
Validation round: 6512samples [47:10,  1.59samples/s][A
Validation round: 6528samples [47:19,  1.63samples/s][A
Validation round: 6544samples [47:56,  1.13s/samples][A
Validation round: 6560samples [48:06,  1.04samples/s][A
Validation round: 6576samples [48:15,  1.17samples/s][A
Validation round: 6592samples [48:25,  1.27samples/s][A
Validation round: 6608samples [48:56,  1.14s/samples][A
Validation round: 6624samples [49:06,  1.03samples/s][A
Validation round: 6640samples [49:15,  1.16samples/s][A
Validation round: 6656samples [49:26,  1.25samples/s][A
Validation round: 6672samples [49:35,  1.36samples/s][A
Validation round: 6688samples [49:45,  1.45samples/s][A
Validation round: 6704samples [49:54,  1.51samples/s][A
Validation round: 6720samples [50:03,  1.59samples/s][A
Validation round: 6736samples [50:14,  1.56samples/s][A
Validation round: 6752samples [50:23,  1.60samples/s][A
Validation round: 6768samples [

Validation round: 8752samples [1:08:45,  1.85samples/s][A
Validation round: 8768samples [1:08:53,  1.83samples/s][A
Validation round: 8784samples [1:09:03,  1.81samples/s][A
Validation round: 8800samples [1:09:12,  1.80samples/s][A
Validation round: 8816samples [1:09:18,  1.96samples/s][A
Validation round: 8832samples [1:09:25,  2.06samples/s][A
Validation round: 8848samples [1:09:33,  2.01samples/s][A
Validation round: 8864samples [1:09:42,  1.92samples/s][A
Validation round: 8880samples [1:09:52,  1.85samples/s][A
Validation round: 8896samples [1:09:59,  1.96samples/s][A
Validation round: 8912samples [1:10:06,  2.04samples/s][A
Validation round: 8928samples [1:10:14,  1.99samples/s][A
Validation round: 8944samples [1:10:23,  1.96samples/s][A
Validation round: 8960samples [1:10:38,  1.55samples/s][A
Validation round: 8976samples [1:10:49,  1.52samples/s][A
Validation round: 8992samples [1:10:57,  1.66samples/s][A
Validation round: 9008samples [1:11:07,  1.62samples/s]

Validation round: 10944samples [1:28:22,  3.01samples/s][A
Validation round: 10960samples [1:28:27,  3.06samples/s][A
Validation round: 10976samples [1:28:34,  2.87samples/s][A
Validation round: 10992samples [1:28:39,  2.92samples/s][A
Validation round: 11008samples [1:28:44,  2.99samples/s][A
Validation round: 11024samples [1:28:49,  2.98samples/s][A
Validation round: 11040samples [1:28:54,  3.08samples/s][A
Validation round: 11056samples [1:28:59,  3.17samples/s][A
Validation round: 11072samples [1:29:04,  3.18samples/s][A
Validation round: 11088samples [1:29:09,  3.22samples/s][A
Validation round: 11104samples [1:29:13,  3.24samples/s][A
Validation round: 11120samples [1:29:18,  3.25samples/s][A
Validation round: 11136samples [1:29:23,  3.31samples/s][A
Validation round: 11152samples [1:29:28,  3.25samples/s][A
Validation round: 11168samples [1:29:33,  3.23samples/s][A
Validation round: 11184samples [1:29:38,  3.30samples/s][A
Validation round: 11200samples [1:29:43,

Validation round:  68%|██████▊   | 656/966 [03:01<01:26,  3.57samples/s][A
Validation round:  70%|██████▉   | 672/966 [03:05<01:19,  3.71samples/s][A
Validation round:  71%|███████   | 688/966 [03:09<01:12,  3.81samples/s][A
Validation round:  73%|███████▎  | 704/966 [03:13<01:07,  3.87samples/s][A
Validation round:  75%|███████▍  | 720/966 [03:17<01:06,  3.67samples/s][A
Validation round:  76%|███████▌  | 736/966 [03:22<01:02,  3.66samples/s][A
Validation round:  78%|███████▊  | 752/966 [03:26<00:57,  3.71samples/s][A
Validation round:  80%|███████▉  | 768/966 [03:31<00:55,  3.55samples/s][A
Validation round:  81%|████████  | 784/966 [03:35<00:50,  3.62samples/s][A
Validation round:  83%|████████▎ | 800/966 [03:40<00:45,  3.63samples/s][A
Validation round:  84%|████████▍ | 816/966 [03:44<00:40,  3.74samples/s][A
Validation round:  86%|████████▌ | 832/966 [03:48<00:35,  3.82samples/s][A
Validation round:  88%|████████▊ | 848/966 [03:52<00:30,  3.83samples/s][A
Validation r

Validation round: 2832samples [13:00,  3.79samples/s][A
Validation round: 2848samples [13:03,  3.91samples/s][A
Validation round: 2864samples [13:08,  3.79samples/s][A
Validation round: 2880samples [13:12,  3.81samples/s][A
Validation round: 2896samples [13:17,  3.67samples/s][A
Validation round: 2912samples [13:21,  3.80samples/s][A
Validation round: 2928samples [13:25,  3.84samples/s][A
Validation round: 2944samples [13:30,  3.62samples/s][A
Validation round: 2960samples [13:34,  3.71samples/s][A
Validation round: 2976samples [13:38,  3.76samples/s][A
Validation round: 2992samples [13:42,  3.81samples/s][A
Validation round: 3008samples [13:47,  3.63samples/s][A
Validation round: 3024samples [13:51,  3.75samples/s][A
Validation round: 3040samples [13:55,  3.80samples/s][A
Validation round: 3056samples [13:59,  3.79samples/s][A
Validation round: 3072samples [14:03,  3.86samples/s][A
Validation round: 3088samples [14:07,  3.81samples/s][A
Validation round: 3104samples [

Validation round: 5120samples [23:59,  3.72samples/s][A
Validation round: 5136samples [24:03,  3.68samples/s][A
Validation round: 5152samples [24:07,  3.77samples/s][A
Validation round: 5168samples [24:11,  3.85samples/s][A
Validation round: 5184samples [24:15,  3.91samples/s][A
Validation round: 5200samples [24:19,  3.91samples/s][A
Validation round: 5216samples [24:23,  3.89samples/s][A
Validation round: 5232samples [24:27,  3.91samples/s][A
Validation round: 5248samples [24:31,  3.97samples/s][A
Validation round: 5264samples [24:35,  4.02samples/s][A
Validation round: 5280samples [24:39,  4.03samples/s][A
Validation round: 5296samples [24:43,  4.05samples/s][A
Validation round: 5312samples [24:47,  3.97samples/s][A
Validation round: 5328samples [24:51,  3.92samples/s][A
Validation round: 5344samples [24:55,  3.91samples/s][A
Validation round: 5360samples [24:59,  3.89samples/s][A
Validation round: 5376samples [25:03,  3.96samples/s][A
Validation round: 5392samples [

Validation round: 7408samples [34:06,  3.77samples/s][A
Validation round: 7424samples [34:11,  3.77samples/s][A
Validation round: 7440samples [34:16,  3.61samples/s][A
Validation round: 7456samples [34:20,  3.56samples/s][A
Validation round: 7472samples [34:25,  3.59samples/s][A
Validation round: 7488samples [34:29,  3.71samples/s][A
Validation round: 7504samples [34:32,  3.80samples/s][A
Validation round: 7520samples [34:37,  3.61samples/s][A
Validation round: 7536samples [34:42,  3.62samples/s][A
Validation round: 7552samples [34:46,  3.74samples/s][A
Validation round: 7568samples [34:51,  3.45samples/s][A
Validation round: 7584samples [34:56,  3.42samples/s][A
Validation round: 7600samples [35:02,  3.19samples/s][A
Validation round: 7616samples [35:06,  3.37samples/s][A
Validation round: 7632samples [35:10,  3.45samples/s][A
Validation round: 7648samples [35:15,  3.46samples/s][A
Validation round: 7664samples [35:19,  3.51samples/s][A
Validation round: 7680samples [

Validation round: 9696samples [44:48,  3.77samples/s][A
Validation round: 9712samples [44:52,  3.75samples/s][A
Validation round: 9728samples [44:56,  3.82samples/s][A
Validation round: 9744samples [45:01,  3.76samples/s][A
Validation round: 9760samples [45:04,  3.85samples/s][A
Validation round: 9776samples [45:08,  3.93samples/s][A
Validation round: 9792samples [45:12,  3.99samples/s][A
Validation round: 9808samples [45:16,  3.95samples/s][A
Validation round: 9824samples [45:21,  3.87samples/s][A
Validation round: 9840samples [45:25,  3.95samples/s][A
Validation round: 9856samples [45:29,  3.95samples/s][A
Validation round: 9872samples [45:33,  3.96samples/s][A
Validation round: 9888samples [45:36,  4.01samples/s][A
Validation round: 9904samples [45:41,  3.77samples/s][A
Validation round: 9920samples [45:46,  3.75samples/s][A
Validation round: 9936samples [45:50,  3.82samples/s][A
Validation round: 9952samples [45:54,  3.86samples/s][A
Validation round: 9968samples [

Validation round: 7920samples [36:28,  3.93samples/s][A
Validation round: 7936samples [36:33,  3.71samples/s][A
Validation round: 7952samples [36:37,  3.79samples/s][A
Validation round: 7968samples [36:41,  3.71samples/s][A
Validation round: 7984samples [36:45,  3.82samples/s][A
Validation round: 8000samples [36:49,  3.87samples/s][A
Validation round: 8016samples [36:53,  3.93samples/s][A
Validation round: 8032samples [36:57,  3.93samples/s][A
Validation round: 8048samples [37:01,  3.93samples/s][A
Validation round: 8064samples [37:05,  3.94samples/s][A
Validation round: 8080samples [37:11,  3.67samples/s][A
Validation round: 8096samples [37:15,  3.75samples/s][A
Validation round: 8112samples [37:19,  3.70samples/s][A
Validation round: 8128samples [37:23,  3.85samples/s][A
Validation round: 8144samples [37:27,  3.82samples/s][A
Validation round: 8160samples [37:31,  3.90samples/s][A
Validation round: 8176samples [37:35,  4.00samples/s][A
Validation round: 8192samples [

Validation round: 10208samples [47:14,  3.88samples/s][A
Validation round: 10224samples [47:19,  3.86samples/s][A
Validation round: 10240samples [47:23,  3.87samples/s][A
Validation round: 10256samples [47:27,  3.81samples/s][A
Validation round: 10272samples [47:31,  3.84samples/s][A
Validation round: 10288samples [47:36,  3.81samples/s][A
Validation round: 10304samples [47:40,  3.83samples/s][A
Validation round: 10320samples [47:44,  3.86samples/s][A
Validation round: 10336samples [47:48,  3.94samples/s][A
Validation round: 10352samples [47:52,  3.83samples/s][A
Validation round: 10368samples [47:57,  3.59samples/s][A
Validation round: 10384samples [48:01,  3.68samples/s][A
Validation round: 10400samples [48:05,  3.84samples/s][A
Validation round: 10416samples [48:09,  3.81samples/s][A
Validation round: 10432samples [48:13,  3.88samples/s][A
Validation round: 10448samples [48:17,  3.83samples/s][A
Validation round: 10464samples [48:22,  3.80samples/s][A
Validation rou

Validation round: 3184samples [16:17,  3.51samples/s][A
Validation round: 3200samples [16:21,  3.57samples/s][A
Validation round: 3216samples [16:27,  3.37samples/s][A
Validation round: 3232samples [16:32,  3.24samples/s][A
Validation round: 3248samples [16:51,  1.79samples/s][A
Validation round: 3264samples [16:55,  2.08samples/s][A
Validation round: 3280samples [17:00,  2.40samples/s][A
Validation round: 3296samples [17:05,  2.54samples/s][A
Validation round: 3312samples [17:09,  2.83samples/s][A
Validation round: 3328samples [17:13,  3.08samples/s][A
Validation round: 3344samples [17:18,  3.18samples/s][A
Validation round: 3360samples [17:22,  3.32samples/s][A
Validation round: 3376samples [17:27,  3.34samples/s][A
Validation round: 3392samples [17:32,  3.38samples/s][A
Validation round: 3408samples [17:36,  3.36samples/s][A
Validation round: 3424samples [17:42,  3.28samples/s][A
Validation round: 3440samples [17:46,  3.30samples/s][A
Validation round: 3456samples [

Validation round: 5472samples [27:53,  3.68samples/s][A
Validation round: 5488samples [27:58,  3.55samples/s][A
Validation round: 5504samples [28:03,  3.42samples/s][A
Validation round: 5520samples [28:07,  3.46samples/s][A
Validation round: 5536samples [28:12,  3.37samples/s][A
Validation round: 5552samples [28:16,  3.55samples/s][A
Validation round: 5568samples [28:22,  3.29samples/s][A
Validation round: 5584samples [28:27,  3.29samples/s][A
Validation round: 5600samples [28:32,  3.30samples/s][A
Validation round: 5616samples [28:36,  3.36samples/s][A
Validation round: 5632samples [28:41,  3.37samples/s][A
Validation round: 5648samples [28:45,  3.44samples/s][A
Validation round: 5664samples [28:50,  3.50samples/s][A
Validation round: 5680samples [28:54,  3.50samples/s][A
Validation round: 5696samples [28:59,  3.39samples/s][A
Validation round: 5712samples [29:04,  3.40samples/s][A
Validation round: 5728samples [29:09,  3.40samples/s][A
Validation round: 5744samples [

Validation round: 7760samples [39:25,  1.37samples/s][A
Validation round: 7776samples [39:30,  1.66samples/s][A
Validation round: 7792samples [39:35,  1.93samples/s][A
Validation round: 7808samples [39:40,  2.22samples/s][A
Validation round: 7824samples [39:44,  2.47samples/s][A
Validation round: 7840samples [39:49,  2.73samples/s][A
Validation round: 7856samples [39:53,  2.93samples/s][A
Validation round: 7872samples [39:57,  3.17samples/s][A
Validation round: 7888samples [40:04,  2.88samples/s][A
Validation round: 7904samples [40:08,  3.09samples/s][A
Validation round: 7920samples [40:13,  3.16samples/s][A
Validation round: 7936samples [40:18,  3.28samples/s][A
Validation round: 7952samples [40:22,  3.42samples/s][A
Validation round: 7968samples [40:26,  3.46samples/s][A
Validation round: 7984samples [40:31,  3.51samples/s][A
Validation round: 8000samples [40:35,  3.61samples/s][A
Validation round: 8016samples [40:39,  3.56samples/s][A
Validation round: 8032samples [

Validation round:  60%|█████▉    | 576/966 [02:43<01:49,  3.55samples/s][A
Validation round:  61%|██████▏   | 592/966 [02:47<01:43,  3.61samples/s][A
Validation round:  63%|██████▎   | 608/966 [02:51<01:37,  3.68samples/s][A
Validation round:  65%|██████▍   | 624/966 [02:56<01:34,  3.62samples/s][A
Validation round:  66%|██████▋   | 640/966 [03:00<01:27,  3.71samples/s][A
Validation round:  68%|██████▊   | 656/966 [03:04<01:22,  3.74samples/s][A
Validation round:  70%|██████▉   | 672/966 [03:09<01:25,  3.43samples/s][A
Validation round:  71%|███████   | 688/966 [03:13<01:17,  3.57samples/s][A
Validation round:  73%|███████▎  | 704/966 [03:17<01:10,  3.72samples/s][A
Validation round:  75%|███████▍  | 720/966 [03:22<01:08,  3.58samples/s][A
Validation round:  76%|███████▌  | 736/966 [03:27<01:05,  3.53samples/s][A
Validation round:  78%|███████▊  | 752/966 [03:32<01:02,  3.44samples/s][A
Validation round:  80%|███████▉  | 768/966 [03:36<00:57,  3.45samples/s][A
Validation r


Validation round: 10928samples [53:15,  1.93samples/s][A
Validation round: 10944samples [53:22,  1.96samples/s][A
Validation round: 10960samples [53:30,  1.99samples/s][A
Validation round: 10976samples [53:40,  1.86samples/s][A
Validation round: 10992samples [53:49,  1.87samples/s][A
Validation round: 11008samples [53:56,  1.95samples/s][A
Validation round: 11024samples [54:07,  1.78samples/s][A
Validation round: 11040samples [54:16,  1.75samples/s][A
Validation round: 11056samples [54:26,  1.71samples/s][A
Validation round: 11072samples [54:36,  1.70samples/s][A
Validation round: 11088samples [54:46,  1.65samples/s][A
Validation round: 11104samples [54:56,  1.63samples/s][A
Validation round: 11120samples [55:06,  1.62samples/s][A
Validation round: 11136samples [55:14,  1.75samples/s][A
Validation round: 11152samples [55:23,  1.75samples/s][A
Validation round: 11168samples [55:33,  1.70samples/s][A
Validation round: 11184samples [55:43,  1.68samples/s][A
Validation ro

Validation round: 13136samples [1:08:11,  3.66samples/s][A
Validation round: 13152samples [1:08:15,  3.75samples/s][A
Validation round: 13168samples [1:08:19,  3.74samples/s][A
Validation round: 13184samples [1:08:23,  3.83samples/s][A
Validation round: 13200samples [1:08:28,  3.81samples/s][A
Validation round: 13216samples [1:08:34,  3.35samples/s][A
Validation round: 13232samples [1:08:40,  3.15samples/s][A
Validation round: 13248samples [1:08:46,  2.93samples/s][A
Validation round: 13264samples [1:08:50,  3.12samples/s][A
Validation round: 13280samples [1:08:55,  3.26samples/s][A
Validation round: 13296samples [1:09:01,  2.94samples/s][A
Validation round: 13312samples [1:09:08,  2.70samples/s][A
Validation round: 13328samples [1:09:15,  2.56samples/s][A
Validation round: 13344samples [1:09:22,  2.56samples/s][A
Validation round: 13360samples [1:09:27,  2.64samples/s][A
Validation round: 13376samples [1:09:34,  2.57samples/s][A
Validation round: 13392samples [1:09:40,

Validation round: 15312samples [1:23:44,  2.19samples/s][A
Validation round: 15328samples [1:23:51,  2.23samples/s][A
Validation round: 15344samples [1:23:57,  2.33samples/s][A
Validation round: 15360samples [1:24:03,  2.51samples/s][A
Validation round: 15376samples [1:24:08,  2.61samples/s][A
Validation round: 15392samples [1:24:14,  2.63samples/s][A
Validation round: 15408samples [1:24:20,  2.68samples/s][A
Validation round: 15424samples [1:24:25,  2.76samples/s][A
Validation round: 15440samples [1:24:31,  2.71samples/s][A
Validation round: 15456samples [1:24:37,  3.04samples/s][A


Precision: 0.9864517534306252    Recall: 0.9735167669819432    Accuracy: 0.9606720122184039


train loss: 100%|██████████| 139104/139104 [2:14:58<00:00, 17.18img/s, loss=0.00553]
train loss: 100%|██████████| 139104/139104 [1:02:29<00:00, 47.40img/s, loss=0.00544]
Validation round:   0%|          | 0/966 [00:00<?, ?samples/s][A
Validation round:   2%|▏         | 16/966 [00:06<06:16,  2.52samples/s][A
train loss: 100%|██████████| 139104/139104 [1:02:41<00:00, 47.40img/s, loss=0.00544]
Validation round:   5%|▍         | 48/966 [00:14<05:04,  3.02samples/s][A
Validation round:   7%|▋         | 64/966 [00:19<04:50,  3.11samples/s][A
Validation round:   8%|▊         | 80/966 [00:23<04:29,  3.29samples/s][A
Validation round:  10%|▉         | 96/966 [00:34<05:56,  2.44samples/s][A
Validation round:  12%|█▏        | 112/966 [00:38<05:10,  2.75samples/s][A
Validation round:  13%|█▎        | 128/966 [00:43<04:46,  2.92samples/s][A
Validation round:  15%|█▍        | 144/966 [00:47<04:16,  3.21samples/s][A
Validation round:  17%|█▋        | 160/966 [00:51<04:00,  3.35samples/s][A


Validation round: 1920samples [13:54,  2.47samples/s][A
Validation round: 1936samples [14:01,  2.46samples/s][A
Validation round: 1952samples [14:07,  2.42samples/s][A
Validation round: 1968samples [14:14,  2.37samples/s][A
Validation round: 1984samples [14:21,  2.38samples/s][A
Validation round: 2000samples [14:28,  2.39samples/s][A
Validation round: 2016samples [14:35,  2.38samples/s][A
Validation round: 2032samples [14:43,  2.21samples/s][A
Validation round: 2048samples [14:49,  2.28samples/s][A
Validation round: 2064samples [14:55,  2.42samples/s][A
Validation round: 2080samples [15:01,  2.52samples/s][A
Validation round: 2096samples [15:07,  2.54samples/s][A
Validation round: 2112samples [15:14,  2.51samples/s][A
Validation round: 2128samples [15:20,  2.52samples/s][A
Validation round: 2144samples [15:26,  2.51samples/s][A
Validation round: 2160samples [15:34,  2.32samples/s][A
Validation round: 2176samples [15:41,  2.34samples/s][A
Validation round: 2192samples [

Validation round: 4208samples [29:12,  2.33samples/s][A
Validation round: 4224samples [29:19,  2.36samples/s][A
Validation round: 4240samples [29:25,  2.42samples/s][A
Validation round: 4256samples [29:31,  2.43samples/s][A
Validation round: 4272samples [29:38,  2.37samples/s][A
Validation round: 4288samples [29:45,  2.42samples/s][A
Validation round: 4304samples [29:52,  2.39samples/s][A
Validation round: 4320samples [29:58,  2.42samples/s][A
Validation round: 4336samples [30:04,  2.44samples/s][A
Validation round: 4352samples [30:12,  2.36samples/s][A
Validation round: 4368samples [30:18,  2.38samples/s][A
Validation round: 4384samples [30:24,  2.45samples/s][A
Validation round: 4400samples [30:30,  2.55samples/s][A
Validation round: 4416samples [30:37,  2.49samples/s][A
Validation round: 4432samples [30:44,  2.40samples/s][A
Validation round: 4448samples [30:51,  2.40samples/s][A
Validation round: 4464samples [30:58,  2.31samples/s][A
Validation round: 4480samples [

Validation round: 6496samples [44:21,  2.42samples/s][A
Validation round: 6512samples [44:27,  2.47samples/s][A
Validation round: 6528samples [44:33,  2.51samples/s][A
Validation round: 6544samples [44:40,  2.44samples/s][A
Validation round: 6560samples [44:49,  2.21samples/s][A
Validation round: 6576samples [44:54,  2.35samples/s][A
Validation round: 6592samples [45:01,  2.38samples/s][A
Validation round: 6608samples [45:09,  2.29samples/s][A
Validation round: 6624samples [45:16,  2.27samples/s][A
Validation round: 6640samples [45:23,  2.29samples/s][A
Validation round: 6656samples [45:29,  2.32samples/s][A
Validation round: 6672samples [45:35,  2.46samples/s][A
Validation round: 6688samples [45:42,  2.38samples/s][A
Validation round: 6704samples [45:50,  2.25samples/s][A
Validation round: 6720samples [45:57,  2.31samples/s][A
Validation round: 6736samples [46:03,  2.41samples/s][A
Validation round: 6752samples [46:16,  1.86samples/s][A
Validation round: 6768samples [

Validation round: 8784samples [1:01:23,  2.42samples/s][A
Validation round: 8800samples [1:01:28,  2.69samples/s][A
Validation round: 8816samples [1:01:32,  3.02samples/s][A
Validation round: 8832samples [1:01:36,  3.12samples/s][A
Validation round: 8848samples [1:01:41,  3.18samples/s][A
Validation round: 8864samples [1:01:46,  3.25samples/s][A
Validation round: 8880samples [1:01:52,  3.01samples/s][A
Validation round: 8896samples [1:01:57,  3.13samples/s][A
Validation round: 8912samples [1:02:01,  3.21samples/s][A
Validation round: 8928samples [1:02:07,  3.07samples/s][A
Validation round: 8944samples [1:02:15,  2.70samples/s][A
Validation round: 8960samples [1:02:20,  2.83samples/s][A
Validation round: 8976samples [1:02:24,  2.97samples/s][A
Validation round: 8992samples [1:02:31,  2.83samples/s][A
Validation round: 9008samples [1:02:35,  3.00samples/s][A
Validation round: 9024samples [1:02:40,  3.11samples/s][A
Validation round: 9040samples [1:02:46,  3.01samples/s]

Validation round: 10976samples [1:17:25,  2.41samples/s][A
Validation round: 10992samples [1:17:33,  2.32samples/s][A
Validation round: 11008samples [1:17:38,  2.46samples/s][A
Validation round: 11024samples [1:17:45,  2.42samples/s][A
Validation round: 11040samples [1:17:52,  2.44samples/s][A
Validation round: 11056samples [1:17:59,  2.36samples/s][A
Validation round: 11072samples [1:18:06,  2.38samples/s][A
Validation round: 11088samples [1:18:12,  2.37samples/s][A
Validation round: 11104samples [1:18:19,  2.38samples/s][A
Validation round: 11120samples [1:18:26,  2.39samples/s][A
Validation round: 11136samples [1:18:33,  2.35samples/s][A
Validation round: 11152samples [1:18:39,  2.47samples/s][A
Validation round: 11168samples [1:18:44,  2.56samples/s][A
Validation round: 11184samples [1:18:50,  2.57samples/s][A
Validation round: 11200samples [1:18:57,  2.59samples/s][A
Validation round: 11216samples [1:19:03,  2.56samples/s][A
Validation round: 11232samples [1:19:10,

Validation round: 13152samples [1:32:04,  2.25samples/s][A
Validation round: 13168samples [1:32:11,  2.27samples/s][A
Validation round: 13184samples [1:32:17,  2.35samples/s][A
Validation round: 13200samples [1:32:24,  2.39samples/s][A
Validation round: 13216samples [1:32:30,  2.37samples/s][A
Validation round: 13232samples [1:32:38,  2.25samples/s][A
Validation round: 13248samples [1:32:47,  2.12samples/s][A
Validation round: 13264samples [1:32:53,  2.22samples/s][A
Validation round: 13280samples [1:33:01,  2.19samples/s][A
Validation round: 13296samples [1:33:08,  2.24samples/s][A
Validation round: 13312samples [1:33:14,  2.33samples/s][A
Validation round: 13328samples [1:33:22,  2.23samples/s][A
Validation round: 13344samples [1:33:28,  2.37samples/s][A
Validation round: 13360samples [1:33:34,  2.43samples/s][A
Validation round: 13376samples [1:33:40,  2.50samples/s][A
Validation round: 13392samples [1:33:46,  2.50samples/s][A
Validation round: 13408samples [1:33:53,

Validation round: 15328samples [1:46:26,  3.41samples/s][A
Validation round: 15344samples [1:46:31,  3.36samples/s][A
Validation round: 15360samples [1:46:36,  3.19samples/s][A
Validation round: 15376samples [1:46:42,  3.00samples/s][A
Validation round: 15392samples [1:46:48,  2.98samples/s][A
Validation round: 15408samples [1:46:54,  2.89samples/s][A
Validation round: 15424samples [1:46:59,  2.89samples/s][A
Validation round: 15440samples [1:47:05,  2.93samples/s][A
Validation round: 15456samples [1:47:11,  2.40samples/s][A


Precision: 0.9860585197934596    Recall: 0.9852106620808254    Accuracy: 0.9716757123473542


train loss: 100%|██████████| 139104/139104 [2:49:40<00:00, 13.66img/s, loss=0.00544]
train loss: 100%|██████████| 139104/139104 [1:06:08<00:00, 44.82img/s, loss=0.00456]
Validation round:   0%|          | 0/966 [00:00<?, ?samples/s][A
train loss: 100%|██████████| 139104/139104 [1:06:22<00:00, 44.82img/s, loss=0.00456]
Validation round:   3%|▎         | 32/966 [00:16<08:15,  1.89samples/s][A
Validation round:   5%|▍         | 48/966 [00:22<07:35,  2.01samples/s][A
Validation round:   7%|▋         | 64/966 [00:27<06:29,  2.32samples/s][A
Validation round:   8%|▊         | 80/966 [00:33<06:13,  2.37samples/s][A
Validation round:  10%|▉         | 96/966 [00:39<05:58,  2.42samples/s][A
Validation round:  12%|█▏        | 112/966 [00:47<06:15,  2.28samples/s][A
Validation round:  13%|█▎        | 128/966 [00:56<06:35,  2.12samples/s][A
Validation round:  15%|█▍        | 144/966 [01:03<06:15,  2.19samples/s][A
Validation round:  17%|█▋        | 160/966 [01:10<06:09,  2.18samples/s][A


Validation round: 1920samples [14:03,  1.72samples/s][A
Validation round: 1936samples [14:12,  1.74samples/s][A
Validation round: 1952samples [14:20,  1.77samples/s][A
Validation round: 1968samples [14:29,  1.79samples/s][A
Validation round: 1984samples [14:36,  1.94samples/s][A
Validation round: 2000samples [14:45,  1.85samples/s][A
Validation round: 2016samples [14:55,  1.80samples/s][A
Validation round: 2032samples [15:04,  1.76samples/s][A
Validation round: 2048samples [15:15,  1.69samples/s][A
Validation round: 2064samples [15:24,  1.70samples/s][A
Validation round: 2080samples [15:33,  1.70samples/s][A
Validation round: 2096samples [15:43,  1.66samples/s][A
Validation round: 2112samples [15:53,  1.65samples/s][A
Validation round: 2128samples [16:03,  1.66samples/s][A
Validation round: 2144samples [16:15,  1.51samples/s][A
Validation round: 2160samples [16:21,  1.76samples/s][A
Validation round: 2176samples [16:31,  1.73samples/s][A
Validation round: 2192samples [

Validation round: 4208samples [33:31,  1.59samples/s][A
Validation round: 4224samples [34:02,  1.03s/samples][A
Validation round: 4240samples [34:13,  1.08samples/s][A
Validation round: 4256samples [34:19,  1.31samples/s][A
Validation round: 4272samples [34:27,  1.46samples/s][A
Validation round: 4288samples [34:37,  1.51samples/s][A
Validation round: 4304samples [34:49,  1.43samples/s][A
Validation round: 4320samples [34:59,  1.49samples/s][A
Validation round: 4336samples [35:09,  1.51samples/s][A
Validation round: 4352samples [35:22,  1.43samples/s][A
Validation round: 4368samples [35:32,  1.48samples/s][A
Validation round: 4384samples [35:39,  1.65samples/s][A
Validation round: 4400samples [35:45,  1.89samples/s][A
Validation round: 4416samples [35:54,  1.82samples/s][A
Validation round: 4432samples [36:08,  1.54samples/s][A
Validation round: 4448samples [36:18,  1.58samples/s][A
Validation round: 4464samples [36:28,  1.60samples/s][A
Validation round: 4480samples [

Validation round: 6496samples [54:10,  1.73samples/s][A
Validation round: 6512samples [54:18,  1.84samples/s][A
Validation round: 6528samples [54:28,  1.77samples/s][A
Validation round: 6544samples [54:39,  1.65samples/s][A
Validation round: 6560samples [54:56,  1.35samples/s][A
Validation round: 6576samples [55:05,  1.46samples/s][A
Validation round: 6592samples [55:15,  1.48samples/s][A
Validation round: 6608samples [55:25,  1.51samples/s][A
Validation round: 6624samples [55:37,  1.44samples/s][A
Validation round: 6640samples [55:46,  1.54samples/s][A
Validation round: 6656samples [55:52,  1.76samples/s][A
Validation round: 6672samples [56:05,  1.57samples/s][A
Validation round: 6688samples [56:15,  1.58samples/s][A
Validation round: 6704samples [56:27,  1.48samples/s][A
Validation round: 6720samples [56:37,  1.52samples/s][A
Validation round: 6736samples [56:47,  1.57samples/s][A
Validation round: 6752samples [56:56,  1.59samples/s][A
Validation round: 6768samples [

Validation round: 8736samples [1:14:35,  1.94samples/s][A
Validation round: 8752samples [1:15:17,  1.16s/samples][A
Validation round: 8768samples [1:15:31,  1.07s/samples][A
Validation round: 8784samples [1:15:38,  1.13samples/s][A
Validation round: 8800samples [1:15:44,  1.38samples/s][A
Validation round: 8816samples [1:15:54,  1.43samples/s][A
Validation round: 8832samples [1:16:04,  1.50samples/s][A
Validation round: 8848samples [1:16:14,  1.53samples/s][A
Validation round: 8864samples [1:16:23,  1.57samples/s][A
Validation round: 8880samples [1:16:33,  1.59samples/s][A
Validation round: 8896samples [1:16:42,  1.63samples/s][A
Validation round: 8912samples [1:16:52,  1.60samples/s][A
Validation round: 8928samples [1:16:59,  1.80samples/s][A
Validation round: 8944samples [1:17:07,  1.86samples/s][A

In [None]:
# test

In [None]:
import argparse
import logging
import os
import sys
from tqdm import tqdm
import yaml

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F


# todo: when we add more models, we should move these variables to another location
ALL_MODEL_NAMES = ["LogoDetection"]
ALL_DATASET_NAMES = ["FlickrLogos-32, TopLogos-10"]

with open(os.path.abspath("./config/config.yaml")) as config:
    config_list = yaml.load(config, Loader=yaml.FullLoader)

# def pred(model,
#          sample,
#          device,
#          threshold=0.5):
#     if(model.eval==False):
#         model.eval()

#     queries, targets = torch.from_numpy(BasicDataset.preprocess(index=index, file:files_path))
#     queries = queries.unsqueeze(0)
#     queries = queries.to(device=device, dtype=torch.float32)
#     targets = targets.unsqueeze(0)
#     targets = targets.to(device=device, dtype=torch.float32)


def test(model,
         device,
         dataset,
         batch_size,
        #   save_path,
         verbose: bool,
         threshold=0.5):
    
    model.eval()

    # #TODO dataset preprocessing
    # with torch.no_grad():
    #     output = model(queries, targets)

    #     probs = output.squeeze(0)
    
    logging.info("\nPredicting image{} ... ")

    with tqdm(total=len(dataset), desc=f'Testing dataset', unit='test-img', disable=not verbose) as bar:
            bar.set_description(f'model testing')

            for batch in test_loader:
                queries = batch['query']  # Correct dimensions?
                targets = batch['target']
                bboxes = batch['bbox']

                queries = queries.to(device=device, dtype=torch.float32)
                targets = targets.to(device=device, dtype=torch.float32)
                bboxes = bboxes.to(device=device, dtype=torch.float32)

                with torch.no_grad():
                    pred_masks = model(queries, targets)
                    # print(pred_masks.shape)

                bar.update(queries.shape[0])
                global_step += 1
                val_score += eval_net(model, batch, device)
                # writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                writer.add_images('query_images', queries, global_step)
                writer.add_images('target_images', targets, global_step)
                # writer.add_images('bboxes/true', bboxes, global_step)
                writer.add_images('masks/pred', pred_masks, global_step)
                writer.close()
    return val_score / global_step


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        choices=ALL_DATASET_NAMES,
                        help="Dataset in {}".format(ALL_DATASET_NAMES),
                        required=True
                        )

    parser.add_argument('--model',
                        choices=ALL_MODEL_NAMES,
                        help="Model in {}".format(ALL_MODEL_NAMES)
                        )

    parser.add_argument('--batch_size',
                        default=32,
                        type=int,
                        help="Number of samples in each mini-batch in SGD and Adam optimization"
                        )

    parser.add_argument('--verbose',
                        default=True,
                        type=bool,
                        help="Verbose"
                        )

    parser.add_argument('--load',
                        type=str,
                        required=True,
                        help="Path to the model to load"
                        )

    return parser.parse_args()


if __name__ == '__main__':
    # Logging
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Modularized paths with respect to the current Dataset
    imgs_dir = config_list['datasets'][args.dataset]['images']
    masks_dir = config_list['datasets'][args.dataset]['masks']
    # TODO: Controlla che ste due liste hanno le stesse sottocartelle
    imgs_classes = [f.name for f in os.scandir(imgs_dir) if f.is_dir()]
    mask_classes = [f.name for f in os.scandir(masks_dir) if f.is_dir()]
    # checkpoint_dir = config_list['models'][args.model]['train_cp']

    model_path = config_list['models'][args.model]['paths']['model']+ "_".join([args.model, args.dataset]) + ".pt"

    # print("Loading %s dataset..." % args.dataset)
    # dataset = BasicDataset(imgs_dir=imgs_dir, masks_dir=masks_dir)

    # Change here to adapt your data
    print("Initializing model...")
    model = LogoDetection(batch_norm=args.batch_norm,
                        vgg_cfg=args.vgg_cfg)

    model.load_state_dict(
        torch.load(model_path, map_location=device)
    )
    logging.info(f'Model loaded from {model_path}')
    model.to(device=device)

    # Neo, enter in Metrics
    metrics = []

    for img_class_idx, img_class_path in enumerate(imgs_classes):
        dataset = BasicDataset(imgs_dir=f"{imgs_dir}{os.path.sep}{img_class_path}", masks_dir=f"{masks_dir}{os.path.sep}{masks_dir[img_class_idx]}", dataset_name=args.dataset, skip_bbox_lines=1)
        test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

        try:
            metrics.append(test(model=model,
                                device=device,
                                dataset=test_loader,
                                batch_size=args.batch_size,
                                verbose=args.verbose
                                ))
        except KeyboardInterrupt:
            # torch.save(model.state_dict(), 'INTERRUPTED.ph')
            # logging.info('Interrupt saved')
            logging.info("Test interrupted")
            try:
                sys.exit(0)
            except SystemExit:
                os._exit(0)


In [None]:
!pip install nvgpu

In [None]:
import nvgpu

avail_gpus = nvgpu.available_gpus()
print(avail_gpus)