In [1]:
!nvidia-smi

Fri Dec 18 12:31:20 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 410.48                 Driver Version: 410.48                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-SXM2...  Off  | 00000000:05:00.0 Off |                    0 |
| N/A   33C    P0    41W / 300W |   8035MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-SXM2...  Off  | 00000000:06:00.0 Off |                    0 |
| N/A   34C    P0    33W / 300W |     10MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla P100-SXM2...  Off  | 00000000:84:00.0 Off |                    0 |
| N/A   

In [22]:
!pip install nvgpu

Collecting nvgpu
  Downloading https://files.pythonhosted.org/packages/9f/1f/1543808105e377bb154ff61c2623cdf76a330b983b22c285a14d03e7431b/nvgpu-0.9.0-py2.py3-none-any.whl
Collecting arrow (from nvgpu)
[?25l  Downloading https://files.pythonhosted.org/packages/ca/bc/ebc1afb3c54377e128a01024c006f983d03ee124bc52392b78ba98c421b8/arrow-0.17.0-py2.py3-none-any.whl (50kB)
[K    100% |████████████████████████████████| 51kB 4.1MB/s ta 0:00:011
Collecting pynvml; python_version >= "3.6" (from nvgpu)
  Downloading https://files.pythonhosted.org/packages/1b/1a/a25c143e1d2f873d67edf534b269d028dd3c20be69737cca56bf28911d02/pynvml-8.0.4-py3-none-any.whl
Collecting flask-restful (from nvgpu)
  Downloading https://files.pythonhosted.org/packages/e9/83/d0d33c971de2d38e54b0037136c8b8d20b9c83d308bc6c220a25162755fd/Flask_RESTful-0.3.8-py2.py3-none-any.whl
Collecting ansi2html (from nvgpu)
  Downloading https://files.pythonhosted.org/packages/c6/85/3a46be84afbb16b392a138cd396117f438c7b2e91d8dc327621d1ae1b5

In [23]:
import nvgpu

avail_gpus = nvgpu.available_gpus()
print(avail_gpus)

IndexError: list index out of range

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [2]:
# Model parts

In [3]:

""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Downscaler(nn.Module):
    """Double conv 3x3, then max pool 2x2 stride 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            DoubleConv(in_channels, out_channels),
            nn.MaxPool2d(2, stride=2)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Upscaler(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = DoubleConv(in_channels, in_channels)
        self.upscale = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        x = self.double_conv(x)
        return self.upscale(x)        


class OutSoftmax(nn.Module):
    def __init__(self):
        super(OutSoftmax, self).__init__()
        self.softmax = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Softmax()
        )

    def forward(self, x):
        return self.softmax(x)


# Conditioning Branch
class OneOneConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OneOneConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            # nn.Conv2d(out_channels, out_channels, kernel_size=1), # kernel_size=3, padding=1
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


# class ConditioningConcat(nn.Module):
#     def __init__(self, tile_concat):
#         super(ConditioningConcat, self).__init__()
#         self.filtered_tile = OneOneConv(self.tile_concat)       # Here or in forward method?

#     def forward(self, x):
#         return torch.cat(x, self.filtered_tile, dim=-1)


class VGG16(nn.Module):
    def __init__(self, batch_norm=False, cfg='A', in_channels=3):
        super(VGG16, self).__init__()
        self.batch_norm = batch_norm

        self.cfgs = {
            'A': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M', 512, 512, 512, 'M'],     # 6 "evenly" distributed maxpools to reduce dims to 1x1x512 
            'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M', 'M'],     # 6 maxpools to reduce dims to 1x1x512
        }
        self.cfg = self.cfgs[cfg]

        self.layers = []
        for v in self.cfg:
            if v == 'M':
                self.layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                if self.batch_norm:
                    self.layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    self.layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        self.vgg16 = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.vgg16(x)

In [4]:
# Model

In [5]:
import torch
from torch import nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import numpy as np


class LogoDetection(nn.Module):
    def __init__(self,
                 n_channels: int = 3,
                 batch_norm=False,
                 vgg_cfg: str = 'A'):
        super(LogoDetection, self).__init__()
        self.n_channels = n_channels

        # Encoder steps
        self.input_layer = Downscaler(self.n_channels, 64)
        # self.input_layer.wei
        self.down_layer1 = Downscaler(64, 128)
        self.down_layer2 = Downscaler(128, 256)
        self.down_layer3 = Downscaler(256, 512)
        self.down_layer4 = Downscaler(512, 512)  # 1/(2^5)*(width x height) x 512

        # Conditioning Module
        self.latent_repr = VGG16(batch_norm, vgg_cfg)
        self.one_conv1 = OneOneConv(576, 64)  # 64+512
        self.one_conv2 = OneOneConv(640, 128)  # 128+512
        self.one_conv3 = OneOneConv(768, 256)  # 256+512
        self.one_conv4 = OneOneConv(1024, 512)  # 512+512

        # Decoder steps
        self.up1 = Upscaler(1536, 512)  # 1024+512
        self.up2 = Upscaler(1024, 256)  # 512*2
        self.up3 = Upscaler(512, 128)  # 256*2
        self.up4 = Upscaler(256, 64)  # 128*2
        self.up5 = Upscaler(128, 1)  # 64*2
        self.output_layer = OutSoftmax()

        # with torch.no_grad():
        #     self.input_layer.weight = torch.nn.Parameter()

    def forward(self, query, target):
        # query = samples[:, 0]
        # target = samples[:, 1]
        z = self.latent_repr(query)
        # print(z.shape)

        # Encoder + Conditioning
        x = self.input_layer(target)

        tile = z.expand(z.shape[0], z.shape[1], 128, 128)
        # print(tile.shape)
        x1 = torch.cat((x, tile), dim=1)
        x = self.down_layer1(x)

        tile = z.expand(z.shape[0], z.shape[1], 64, 64)
        x2 = torch.cat((x, tile), dim=1)
        x = self.down_layer2(x)

        tile = z.expand(z.shape[0], z.shape[1], 32, 32)
        x3 = torch.cat((x, tile), dim=1)
        x = self.down_layer3(x)

        tile = z.expand(z.shape[0], z.shape[1], 16, 16)
        x4 = torch.cat((x, tile), dim=1)
        x = self.down_layer4(x)

        tile = z.expand(z.shape[0], z.shape[1], 8, 8)
        x5 = torch.cat((x, tile), dim=1)
        # print(x.shape)

        # Decoder + Conditioning
        x = torch.cat((x, x5), dim=1)
        # print(x.shape)
        x = self.up1(x)
        # print(x.shape)

        x4 = self.one_conv4(x4)
        x = torch.cat((x, x4), dim=1)
        # del x4
        x = self.up2(x)

        x3 = self.one_conv3(x3)
        x = torch.cat((x, x3), dim=1)
        # del x3
        x = self.up3(x)

        x2 = self.one_conv2(x2)
        x = torch.cat((x, x2), dim=1)
        # del x2
        x = self.up4(x)

        x1 = self.one_conv1(x1)
        x = torch.cat((x, x1), dim=1)
        # del x1
        x = self.up5(x)

        output = self.output_layer(x)
        return output

    def predict_mask(self, query, target):
        return self.forward(query, target)


In [6]:
# eval

In [7]:
import functools, operator, collections
import torch
import torch.nn.functional as F
import numpy as np

from sklearn.metrics import jaccard_score as jsc
from sklearn.metrics import average_precision_score as avg_pr
from sklearn.cluster import DBSCAN
from sklearn import metrics

import matplotlib.pyplot as plt

from tqdm import tqdm

from skimage.measure import label, regionprops


def eval_net(model,
         loader, 
         device, 
         bbox: bool, 
         verbose: bool,
         iou_thr: int = 0.5
         ):
#     logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s: %(message)s", filename='oneshot_eval.log')
    model.eval()

    # Number of batches
    n_val = len(loader)

    logging.info(f"n_val: {n_val}")
    
    matches = 0

    # Number of bboxes
    max_matches = 0

    if bbox:
        truth_type = "bbox"
    else:
        truth_type = "mask"
    
    logging.info(f"type of ground truth: {truth_type}")

    precisions, recalls, accuracies = [], [], []

    batch_results = []

    with tqdm(total=n_val, desc='Validation round', unit='img', leave=False, disable=not verbose) as bar:
        for batch in loader:
            logging.info(f"New batch!")
            queries, targets, truth = batch['query'], batch['target'], batch[truth_type]
            queries = queries.to(device=device, dtype=torch.float32)
            targets = targets.to(device=device, dtype=torch.float32)
            # truth = truth.to(device=device, dtype=torch.float32)

            with torch.no_grad():
                pred = model(queries, targets)
                logging.info(f"Masks predicted")
                pred_masks = pred.cpu().numpy()
                logging.info(f"Masks to CPU")
                print(f"pred_masks_shape_0: {pred_masks.shape[0]}")
                print(f"pred_masks: {pred_masks}")
                # assunzione: gli indici della true e pred masks sono gli stessi
                for mask_index in range(pred_masks.shape[0]):
                    pred_mask = np.asarray(pred_masks[mask_index])
                    pred_mask = masks_as_image(rle_encode(pred_mask))
                    
                    # mask labeling and conversion to bboxes
                    pred_labels = label(pred_mask)
                    pred_bboxes_coords = np.array(list(map(lambda x: x.bbox, regionprops(pred_labels))))
                    pred_bboxes = calc_bboxes_from_coords(pred_bboxes_coords)

                    # computes truth bboxes in the same way as the pred
                    if bbox:
                        truth_bboxes = np.array(truth[mask_index])
                    else:
                        truth_mask = np.asarray(truth[mask_index])
                        truth_mask = masks_as_image(rle_encode(truth_mask))
                        true_mask_labels = label(truth_mask)
                        truth_bboxes_coords = np.array(list(map(lambda x: x.bbox, regionprops(true_mask_labels))))
                        truth_bboxes = calc_bboxes_from_coords(truth_bboxes_coords)

                    max_matches += len(truth_bboxes)
                    
                    b_result = get_pred_results(truth_bboxes, pred_bboxes, iou_thr)
                    print(f"b_result: {b_result}")
                    batch_results.append(b_result)
                
                logging.info(f"Batch finished. batch_results: {batch_results}")

    print(f"Validation from eval completed")
    # Should not be here, since the eval method is used in both validation and test -> TODO: better handling of the flag.
    model.train()
    
    print(f"Batch results: {batch_results}")
    
    result = dict(functools.reduce(operator.add, map(collections.Counter, batch_results)))
    print("Validation output: ", str(result))
    true_pos = result['true_pos']
    false_pos = result['false_pos']
    false_neg = result['false_neg']

    precision = calc_precision(true_pos, false_pos)
    recall = calc_recall(true_pos, false_neg)
    accuracy = calc_accuracy(true_pos, false_pos, false_neg)
    print(f"Precision: {precision}    Recall: {recall}    Accuracy: {accuracy}")

    return accuracy


def calc_bboxes_from_coords(bboxes_coords):
    """Calculate all bounding boxes from a set of bounding boxes coordinates"""
    bboxes = []
    for coord_idx in range(len(bboxes_coords)):
        coord = bboxes_coords[coord_idx]
        bbox = (coord[1], coord[0], int(coord[4])-int(coord[1]), int(coord[3])-int(coord[0]))
        bboxes.append(bbox)
    return bboxes


def get_pred_results(truth_bboxes, pred_bboxes, iou_thr = 0.5):
    """Calculates true_pos, false_pos and false_neg from the input bounding boxes. """
    n_pred_idxs = range(len(pred_bboxes))
    n_truth_idxs = range(len(truth_bboxes))
    if len(n_pred_idxs) == 0:
        tp = 0
        fp = 0
        fn = len(truth_bboxes)
        return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn}
    if len(n_truth_idxs) == 0:
        tp = 0
        fp = len(pred_bboxes)
        fn = 0
        return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn}

    truth_idx_thr = []
    pred_idx_thr = []
    ious = []
    for pred_idx, pred_bbox in enumerate(pred_bboxes):
        for truth_idx, truth_bbox in enumerate(truth_bboxes):
            iou = get_jaccard(pred_bbox, truth_bbox)
            if iou > iou_thr:
                truth_idx_thr.append(truth_idx)
                pred_idx_thr.append(pred_idx)
                ious.append(iou)
    # ::-1 reverses the list
    ious_desc = np.argsort(ious)[::-1]
    if len(ious_desc) == 0:
        # No matches
        tp = 0
        fp = len(pred_bboxes)
        fn = len(truth_bboxes)
    else:
        truth_match_idxes = []
        pred_match_idxes = []
        for idx in ious_desc:
            truth_idx = truth_idx_thr[idx]
            pred_idx = pred_idx_thr[idx]
            # If the bboxes are unmatched, add them to matches
            if (truth_idx not in truth_match_idxes) and (pred_idx not in pred_match_idxes):
                truth_match_idxes.append(truth_idx)
                pred_match_idxes.append(pred_match_idxes)
        tp = len(truth_match_idxes)
        fp = len(pred_bboxes) - len(pred_match_idxes)
        fn = len(truth_bboxes) - len(truth_match_idxes)
    return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn}


def calc_precision(true_pos, false_neg):
    try:
        precision = true_pos / (true_pos + false_neg)
    except ZeroDivisionError:
        precision = 0.0
    return precision


def calc_recall(true_pos, false_pos):
    try:
        recall = true_pos / (true_pos + false_pos)
    except ZeroDivisionError:
        recall = 0.0
    return recall


def calc_accuracy(true_pos, false_pos, false_neg):
    try:
        accuracy = true_pos / (true_pos + false_pos + false_neg)
    except ZeroDivisionError:
        accuracy = 0.0
    return accuracy 


def calc_mavg_precision(precision_array):
    return 


def get_jaccard(pred_bbox, truth_bbox):
    pred_mask = get_mask_from_bbox(pred_bbox)
    truth_mask = get_mask_from_bbox(truth_bbox)
    return get_jaccard_from_mask(pred_mask, truth_mask)


def get_jaccard_from_mask(pred_masks, truth):
    return jsc(pred_masks, truth)


def get_mask_from_bbox(bbox):
    mask = np.zeros((256, 256))
    x = bbox[0]
    y = bbox[1]
    for width in range(bbox[2] + 1):
        for height in range(bbox[3] + 1):
            mask[x + width, y + height] = 1
    return mask


def get_bbox_batch(img):
    bbox = np.empty((img.shape[0], 4))
    for i in range(img.shape[0]):
        bbox[i] = get_bbox(img[i])
    return bbox


def get_bbox(img):
    # img.shape = [batch_size, 1, 256, 256]
    rows = np.any(img, axis=-1)
    cols = np.any(img, axis=-2)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # X, Y, Width, Height
    return [cmin, rmin, cmax-cmin, rmax-rmin]


def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels_old = img.T.flatten()
    pixels = img.T.flatten()
    for x in range(len(pixels_old)):
        if pixels_old[x] > 0.5:
            pixels[x] = 1
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def rle_decode(mask_rle, shape=(256, 256)):
    '''
    mask_rle: run-length as string formated (start length)
    shape: (height,width) of array to return 
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T  # Needed to align to RLE direction


def masks_as_image(in_mask_list, all_masks=None):
    """
    Take the complete rle_encoded mask and create a mask array of the single masks
    """
    if all_masks is None:
        all_masks = np.zeros((256, 256), dtype=np.int16)
    # if isinstance(in_mask_list, list):
    for mask in in_mask_list:
        if isinstance(mask, str):
            all_masks += rle_decode(mask)
    return np.expand_dims(all_masks, -1)

  data = yaml.load(f.read()) or {}


In [8]:
# dataset_loader

In [1]:
import os
import numpy as np
from torch.utils.data import Dataset
import logging
from PIL import Image

import torch

import gzip
import shutil
import h5py
import tables


# TODO: Deve preprocessare anche le immagini di test
# TODO: Ha da funzionà co TorchVision, se hai tempo
# TODO: Visto che le maschere ci servono solo per estrapolare le bbox e confrontarle con quelle stimate, ha senso portarsi tutta la maschera e non le singole bbox?
class BasicDataset(Dataset):
    TARGET_IMAGE_PATH = "target_image_path"
    MASK_IMAGE_PATH = "mask_image_path"
    BBOX_PATH = "bbox_path"
    TARGET_IMAGE_BBOX_PATH = "target_image_bbox_path"

    # TODO: Check if the values are empty
    def __init__(self, imgs_dir: str, masks_dir: str, dataset_name: str, mask_image_dim: int = 256, query_dim: int = 64,
                 bbox_suffix: str = '.bboxes.txt', save_to_disk: bool = False, skip_bbox_lines: int = 0):
        self.imgs_dir = fix_input_dir(imgs_dir)
        self.masks_dir = fix_input_dir(masks_dir)
        self.processed_img_dir = str(self.imgs_dir[:self.imgs_dir.rindex(os.path.sep) + 1]) + "preprocessed"
        self.mask_img_dim = mask_image_dim
        self.query_dim = query_dim
        self.bbox_suffix = bbox_suffix
        self.save_to_disk = save_to_disk
        self.skip_bbox_lines = skip_bbox_lines
        assert mask_image_dim > 1, 'The dimension of mask and image must be higher than 1'
        assert query_dim > 1, 'The dimension of query image must be higher than 1'

        assert os.path.isdir(imgs_dir), f"Bad path for images directory: {imgs_dir}"

        assert os.path.isdir(masks_dir), f"Bad path for masks directory: {masks_dir}"

        if save_to_disk:
            # create processed image's directory, if not exists yet
            try:
                os.mkdir(self.processed_img_dir)
            except FileExistsError:
                # some previous instance generate this directory, no need to raise an exception
                pass

        # list of dict with the path of the images, which contains the paths for the following images:
        #       target, mask, bbox, target's bbox
        # every dict is defined by 4 str keys which have a str value
        # key = type of image
        # value = path of the image

        # List of dict. Every dict refers to an image with 4 keys:
        #       target, mask, bbox, query_target_bbox
        self.images_path = []

        # TODO: Fai in modo che preprocess calcoli sia la maschera che il bbox e poi, in base al dataset, togline uno
        if dataset_name == "FlickrLogos-32":
            self.flickrlogos32_load()
        elif dataset_name == "TopLogos-10":
            self.toplogos10_load()

    def __len__(self):
        return len(self.images_path)

    def toplogos10_load(self):

        # get bbox path
        bbox_path = None
        for bbox_paths, _, bbox_list in os.walk(self.masks_dir):
            for bbox_file in bbox_list:
                if self.bbox_suffix in bbox_file:
                    bbox_path = get_class_file_path(bbox_paths, bbox_file)
                    break
            if bbox_path:
                break

        # get query image path
        query_full_image_path = f"{bbox_path[:bbox_path.index(self.bbox_suffix)]}"

        # get target images path and fill "self.images_path"
        for target_images_paths, _, target_images_list in os.walk(self.imgs_dir):
            for target_image_name in target_images_list:
                target_image_root_path, target_image_extension = os.path.splitext(
                    os.path.join(target_images_paths, target_image_name))
                if target_image_extension == ".jpg":
                    self.images_path.append(
                        {self.TARGET_IMAGE_PATH: get_class_file_path(target_images_paths, target_image_name),
                         self.MASK_IMAGE_PATH: None,
                         self.BBOX_PATH: bbox_path,
                         self.TARGET_IMAGE_BBOX_PATH: query_full_image_path})

    def flickrlogos32_load(self):

        # dict with merged masks path
        # key = target image file name
        # value = merged mask's file path
        masks_dict = {}

        # put stuff into masks_dict
        for masks_paths, _, masks_files in os.walk(self.masks_dir):
            for mask_file_name in masks_files:
                _, mask_extension = os.path.splitext(os.path.join(masks_paths, mask_file_name))
                if mask_extension == ".png" and "merged" in mask_file_name:
                    masks_dict[mask_file_name[:mask_file_name.rindex(".mask")]] = get_class_file_path(masks_paths,
                                                                                                      mask_file_name)

        # dict with every image of every class
        # key = class name
        # value = dict with images type and path
        #       key = type of image (target, query, mask)
        #       value = path of the file
        image_path_element = {}

        # put stuff into image_path_element
        for target_images_paths, _, target_images_files in os.walk(self.imgs_dir):
            # TODO: Compare come classe la cartella padre "jpg", trova un modo per risolvere
            target_image_class = target_images_paths[target_images_paths.rindex(os.path.sep) + 1:]
            for target_image_name in target_images_files:
                target_image_root_path, target_image_extension = os.path.splitext(
                    os.path.join(target_images_paths, target_image_name))
                if target_image_extension == ".jpg" and "no-logo" not in target_image_root_path:
                    x = {self.TARGET_IMAGE_PATH: get_class_file_path(target_images_paths, target_image_name),
                         self.MASK_IMAGE_PATH: masks_dict[target_image_name],
                         self.BBOX_PATH: f'{masks_dict[target_image_name][:masks_dict[target_image_name].rindex(".mask")]}{self.bbox_suffix}'}
                    try:
                        image_path_element[target_image_class].append(x)
                    except KeyError:
                        image_path_element[target_image_class] = [x]

        # fill "images_path" variable. it generate every couple (target image, query image) for the same class
        # for now, it only skips couple (target, query) of the same image
        for target_image_class in image_path_element:
            items_class = image_path_element[target_image_class]
            for outer_image in items_class:
                bbox_path = outer_image[self.BBOX_PATH]
                outer_target_image_path = outer_image[self.TARGET_IMAGE_PATH]
                for inner_image in items_class:
                    if not outer_image == inner_image:
                        target_image_path = inner_image[self.TARGET_IMAGE_PATH]
                        mask_image_path = inner_image[self.MASK_IMAGE_PATH]
                        self.images_path.append({self.TARGET_IMAGE_PATH: target_image_path,
                                                 self.MASK_IMAGE_PATH: mask_image_path,
                                                 self.BBOX_PATH: bbox_path,
                                                 self.TARGET_IMAGE_BBOX_PATH: outer_target_image_path})
        print(f"You have {len(self.images_path)} triplets")

    # preprocess the images. then save in file and return a list triplet [query image, target image, mask image]. how?
    # stretch the target image
    # stretch, crop and stretch again the query image
    # stretch the mask image
    # TODO: Check if the values are empty
    @classmethod
    def preprocess(cls, target_img_path: str, bbox_path: str, query_full_img_path: str, skip_bbox_lines: int = 0,
                   img_dim: int = 256, query_img_dim: int = 64, mask_img_path: str = None) -> dict:

        # Target image

        pil_target_img = Image.open(target_img_path)
        # stretch the image
        pil_resized_target_img = pil_target_img.resize((img_dim, img_dim))

        # Query image

        pil_target_img_bbox = Image.open(query_full_img_path)
        pil_resized_target_img_bbox = pil_target_img_bbox.resize((img_dim, img_dim))

        # we will resize, crop and resize again the image but we have the coordinates of the non resized bounding box
        target_img_width, target_img_height = pil_target_img_bbox.size
        resized_target_img_width, resized_target_img_height = pil_resized_target_img_bbox.size
        percent_width = round(100 * int(resized_target_img_width) / (int(target_img_width)), 2) / 100
        percent_height = round(100 * int(resized_target_img_height) / (int(target_img_height)), 2) / 100

        # open the bounding box file
        with open(bbox_path) as bbox_file:
            # read only the first line of the bbox file
            bbox_lines = bbox_file.readlines()
            first_line_bbox = bbox_lines[1 - skip_bbox_lines].split(' ')
            # check if we correctly skipped the first line of the file, the one with no number,
            # and if all the elements are numeric, like every coordinate should be ;)
            if first_line_bbox[0].isnumeric() and first_line_bbox[1].isnumeric() and \
                    first_line_bbox[2].isnumeric() and first_line_bbox[3].rstrip().isnumeric():
                x, y, width, height = first_line_bbox
                # adapt the old coordinates to the new stretched dimension
                left = int(x.strip()) * percent_width
                upper = int(y.strip()) * percent_height
                right = int(int(x.strip()) + int(width)) * percent_width
                lower = int(int(y.strip()) + int(height)) * percent_height
                # crop and resize the query image
                pil_query_img = pil_resized_target_img_bbox.crop((left, upper, right, lower))
                pil_resized_query_img = pil_query_img.resize((query_img_dim, query_img_dim))
            else:
                # TODO: nel traceback compare "error_string" e poi successivamente spiega l'eccezione. Trova un modo per togliere quel "error_string"
                error_string = f"Bounding box file's first line should have 4 groups of integers with whitespace " \
                               f"separator. Check {bbox_path}"
                raise Exception(error_string)

        # Mask

        if mask_img_path:
            pil_mask = Image.open(mask_img_path)
            pil_resized_mask = pil_mask.resize((img_dim, img_dim))
        else:
            pil_resized_mask = None

        # get the size of the images
        # print(f"query image dim: {pil_resized_query_img.size}")
        # print(f"target image dim: {pil_resized_target_img.size}")
        # if pil_resized_mask:
        #     print(f"mask image dim: {pil_resized_mask.size}")

        # just to test if everything works. don't look at these :)
        # if pil_resized_target_img:
        #     pil_resized_target_img.save('target.jpg')
        # if pil_resized_query_img:
        #     pil_resized_query_img.save('query.jpg')
        # if pil_resized_mask:
        #     pil_resized_mask.save('mask.jpg')

        # return the triplet (Dq, Dt, Dm) where Dq is the query image, Dt is the target image and Dm is the mask image
        return create_triplet_with_torch_representation(pil_resized_query_img,
                                                        pil_resized_target_img,
                                                        pil_resized_mask)

    # def h5py_with_pytorch(self, pil_img, index, type):
    #     x = self.h5py_compression(to_pytorch(pil_img), index, type)
    #     return x
    #
    # def h5py_without_pytorch(self, pil_img, index, type):
    #     x = self.h5py_compression(pil_img, index, type)
    #     return x

    # def store_hdf5_file_with_compression(self, image, image_index, image_type):
    #     file_name = f'{self.processed_img_dir}{os.path.sep}{image_index}_{image_type}.hdf5'
    #     f = h5py.File(file_name, "w")
    #     # TODO: Esistono altri algoritmi di compressione come Mafisc. Una roba figa che puoi usare è Bitshuffle
    #     f.create_dataset("init", compression="gzip", compression_opts=9, data=image)
    #     f.close()
    #     return file_name

    def store_hdf5_file_with_compression(self, images, image_index):
        image_type = ["query", "target", "mask"]
        image_type_index = 0
        for image in images:
            file_name = f'{self.processed_img_dir}{os.path.sep}{image_index}_{image}.hdf5'
            f = h5py.File(file_name, "w")
            # TODO: Esistono altri algoritmi di compressione come Mafisc. Una roba figa che puoi usare è Bitshuffle
            f.create_dataset("init", compression="gzip", compression_opts=9, data=images[image])
            f.close()
            image_type_index += 1
        # for image in images:
        #     file_name = f'{self.processed_img_dir}{os.path.sep}{image_index}_{image_type[image_type_index]}.hdf5'
        #     f = h5py.File(file_name, "w")
        #     # TODO: Esistono altri algoritmi di compressione come Mafisc. Una roba figa che puoi usare è Bitshuffle
        #     f.create_dataset("init", compression="gzip", compression_opts=9, data=image)
        #     f.close()
        #     image_type_index += 1
        # return file_name

    # def gzip_compress(self, index, input_file):
    #     # input_file = f'{self.processed_img_dir}{os.path.sep}{index}.npz'
    #     with open(input_file, 'rb') as f_in:
    #         output_file = f'{input_file}.gz'
    #         with gzip.open(output_file, 'wb', compresslevel=9) as f_out:
    #             shutil.copyfileobj(f_in, f_out)
    #     # if os.path.exists(input_file):
    #     #     os.remove(input_file)
    #     return output_file

    # def gzip_compress(self, index, input_file):
    #     # input_file = f'{self.processed_img_dir}{os.path.sep}{index}.npz'
    #     output_file = f'{input_file}.gz'
    #     with gzip.open(output_file, 'wb', compresslevel=1) as f_out:
    #         with open(input_file, 'rb') as f_in:
    #             shutil.copyfileobj(f_in, f_out)
    #     # if os.path.exists(input_file):
    #     #     os.remove(input_file)
    #     return output_file

    # def gzip_uncompress(self, input_file):
    #     with gzip.open(input_file, 'rb') as f:
    #         file_content = f.read()
    #     output_file = input_file[:input_file.rindex('.')]
    #     with open(output_file, mode='wb') as fp:
    #         fp.write(file_content)
    #     return output_file

    def read_hdf5_file(self, hdf5_file):
        with h5py.File(hdf5_file, 'r') as hf:
            data = hf.get('init')
            data = np.array(data)
        return data

    # def np_save_compressed(self, index, triplet_list_in_torch_representation):
    #     file_name = f'{self.processed_img_dir}{os.path.sep}{index}'
    #     # save the file so the next time you don't have to preprocess again
    #     np.savez_compressed(file_name,
    #                         query=triplet_list_in_torch_representation[0],
    #                         target=triplet_list_in_torch_representation[1],
    #                         mask=triplet_list_in_torch_representation[2])
    #     return f'{file_name}.npz'

    def __getitem__(self, item_index):
#         print(f"Getting item {item_index}")
        # get the path of the preprocessed file, if exists
        mask_file_path = f'{self.processed_img_dir}{os.path.sep}{item_index}_mask.hdf5'
        query_file_path = f'{self.processed_img_dir}{os.path.sep}{item_index}_query.hdf5'
        target_file_path = f'{self.processed_img_dir}{os.path.sep}{item_index}_target.hdf5'

        correct_order_triplet = [query_file_path, target_file_path, mask_file_path]
        triplet_element_order = ["query", "target", "mask"]

        return_dict = {}
        for file in correct_order_triplet:
            if not os.path.exists(file):
                # triplet = self.preprocess(item_index, self.images_path[item_index])
                return_dict = self.preprocess(
                    target_img_path=get_full_path(self.imgs_dir, self.images_path[item_index][self.TARGET_IMAGE_PATH]),
                    bbox_path=get_full_path(self.masks_dir, self.images_path[item_index][self.BBOX_PATH]),
                    query_full_img_path=get_full_path(self.imgs_dir,
                                                      self.images_path[item_index][self.TARGET_IMAGE_BBOX_PATH]),
                    mask_img_path=get_full_path(self.masks_dir, self.images_path[item_index][self.MASK_IMAGE_PATH]),
                    skip_bbox_lines=self.skip_bbox_lines)
                if self.save_to_disk:
                    self.store_hdf5_file_with_compression(return_dict, item_index)
                break
        else:
            triplet_index = 0
            for file in correct_order_triplet:
                if os.path.exists(file):
                    hdf5_file = self.read_hdf5_file(file)
                    return_dict[triplet_element_order[triplet_index]] = to_pytorch(hdf5_file)
                triplet_index += 1
#         print(f"Ciao, sono un fantastico {return_dict}")
        return return_dict

        # for file in correct_order_triplet:
        # # check if preprocessed file exists. if not, he will generate it. then return the triplet
        #     if os.path.exists(file):
        #         data = np.load(file, mmap_mode='r')
        #         return_tuple = np.array([data['query'], data['target'], data['mask']])
        #     else:
        #         return_tuple = self.preprocess(i, self.images_path[i])
        # return return_tuple


# something that will be deleted
# class CarvanaBasicDataset(BasicDataset):
#     def __init__(self, imgs_dir, masks_dir, scale=1):
#         super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')


def to_pytorch(image):
    if image:
        image_np = np.array(image)
        # mask image has only one channel, we need to explicit it
        if len(image_np.shape) == 2:
            image_np = np.expand_dims(image_np, axis=2)
        # HWC to CHW for pytorch
        img_trans = image_np.transpose((2, 0, 1))
        if img_trans.max() > 1:
            img_trans = img_trans / 255
        return torch.from_numpy(img_trans).type(torch.FloatTensor)
        # return img_trans
    else:
        return None


# dude, the name says all. just read it :/
# def create_triplet_without_torch_representation(pil_query, pil_target, pil_mask):
#     # return [np.array(pil_query), np.array(pil_target), np.array(pil_mask)]
#     # return np.array([np.array(pil_query), np.array(pil_target), np.array(pil_mask)])
#     return {
#         "query": np.array(pil_query),
#         "target": np.array(pil_target),
#         "mask": np.array(pil_mask)
#     }

def create_triplet_with_torch_representation(pil_query, pil_target, pil_mask):
    # return [to_pytorch(pil_query), to_pytorch(pil_target), to_pytorch(pil_mask)]
    # return np.array([to_pytorch(pil_query), to_pytorch(pil_target), to_pytorch(pil_mask)])
    return {
        "query": to_pytorch(pil_query),
        "target": to_pytorch(pil_target),
        "mask": to_pytorch(pil_mask)
    }


def get_class_file_path(class_name, file_name):
    class_file = f"{class_name[class_name.rindex(os.path.sep):]}{os.path.sep}{file_name}"
    if os.path.sep not in class_file[0:2]:
        class_file = f"{os.path.sep}{class_file}"
    return class_file


def get_full_path(root, file):
    try:
        if root[root.rindex(os.path.sep) + 1:].strip() == file[1:file.index(os.path.sep, 1)].strip():
            path = f"{root[:root.rindex(os.path.sep)]}{file}"
        else:
            path = f"{root}{file}"
        return path
    except AttributeError:
        return None


def fix_input_dir(dir):
    if not dir.strip()[-1:] == os.path.sep:
        return dir.strip()
    return dir.strip()[:-1]


In [2]:
# train

In [3]:
import argparse
import logging
import os
import sys
from tqdm import tqdm
import yaml

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F

from model.model import LogoDetection
from utils.dataset_loader import BasicDataset

# todo: when we add more models, we should move these variables to another location
MODEL_HOME = os.path.abspath("./stored_models/")
ALL_MODEL_NAMES = ["LogoDetection"]
ALL_DATASET_NAMES = ["FlickrLogos-32", "TopLogos-10"]

with open(os.path.abspath("./config/config.yaml")) as config:
    config_list = yaml.load(config, Loader=yaml.FullLoader)


# # Appl.load_aly Gaussian normalization to the model
# def weights_init(model):
#     if isinstance(model, nn.Module):
#         nn.init.normal_(model.weight.data, mean=0.0, std=0.01)

def train(model,
          device,
          train_loader,
          val_loader,
          max_epochs,
          optimizer,
          verbose,
          checkpoint_dir,
          model_path,
          save_cp,
          n_train,
          n_val
          ):
    batch_size = train_loader.batch_size

    # Logging for TensorBoard
    writer = SummaryWriter(
        comment=f'LR__BS_{batch_size}_OPT_{type(optimizer).__name__}')  # does optimizer.lr work? we're gonne find out
    global_step = 0

    ### ERROR: n_train e n_val? ###
    logging.info(f'''Starting training:
        Epochs:             {max_epochs}
        Batch size:         {batch_size}
        Learning rate:      
        Training size:      {n_train}
        Validation size:    {n_val}
        Device:             {device.type}
    ''')

    def criterion(pred, true):
        return torch.div(nn.BCELoss()(pred, true), 256 * 256)  # L = (1/(H*W)) * BCELoss
        # TypeError: unsupported operand type(s) for /: 'BCELoss' and 'int'

    last_epoch_val_score = 0
    for epoch in range(max_epochs):
        logging.info(f"Epoch number #insert_epoch_number")
        model.train()  # set the model in training flag to True
        epoch_loss = 0  # resets the loss for the current epoch
        # epoch(batch_size, train_samples)

        # TODO
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{max_epochs}', unit='img', disable=not verbose) as bar:
            bar.set_description(f'train loss')

            for batch in train_loader:
#                 np.set_printoptions(threshold=sys.maxsize)
#                 print(f"Batch: {batch}")
                logging.info(f"Batch number #")
                queries = batch['query']  # Correct dimensions?
                targets = batch['target']
                true_masks = batch['mask']

                queries = queries.to(device=device, dtype=torch.float32)
                targets = targets.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)

                logging.info(f"Sending imgs to model")
                pred_masks = model(queries, targets)
                logging.info(f"Mask correctly predicted")
                # print(pred_masks.shape)
                loss = criterion(pred_masks, true_masks)
                epoch_loss += loss.detach().item()  # is the .detach() needed?

                # TensorBoard logging
                writer.add_scalar('Loss/train', loss.item(), global_step)

                bar.set_postfix(loss=f'{loss.item():.5f}')
                logging.info(f"Loss: {loss.item()}")

                optimizer.zero_grad()
                loss.backward()
                # nn.utils.clip_grad_value_(net.parameters(), 0.1) Gradient Clipping
                optimizer.step()

                bar.update(queries.shape[0])
                global_step += 1

                # if n_train % batch_size == 0: 
                #     n_batch = n_train // batch_size
                # else:
                #     n_batch = n_train // batch_size + 1

                ### DOMANDA: Dove lo volevamo usare? ###
                ### ALTRA DOMANDA: Non conviene farlo fuori dai cicli? ###
#                 n_batch = len(train_loader)
                # Deve farlo sia in mezzo ai batch che a fine epoca. Modifica la condizione dell'if
#                 if global_step % (n_train // (10 * batch_size)) == 0 or global_step == n_batch:
#                     for tag, value in model.named_parameters():
#                         tag = tag.replace('.', '/')
#                         writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
#                         writer.add_histogram('grads/' + tag, value.grad.cpu().numpy(), global_step)
#                     writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

#                     writer.add_images('query_images', queries, global_step)
#                     writer.add_images('target_images', targets, global_step)
#                     writer.add_images('masks/true', true_masks, global_step)
#                     writer.add_images('masks/pred', pred_masks, global_step)

            # TODO: Se save_cp è false e non viene cambiato il val_split di default non va in train il 10% del dataset. Si potrebbe fare in modo che non sia così
            if save_cp:
                logging.info(f"Next operation is validation")
                print(f"starting validation from train")
                val_score = eval_net(model, val_loader, device, bbox=False, verbose=True)
                print(f"validation completed")
                logging.info(f"Validation complete")
                if val_score > last_epoch_val_score:
                    try:
                        os.mkdir(checkpoint_dir)
                        logging.info('Created checkpoint directory')
                    except OSError:  # Maybe FileExistsError ?
                        pass
                    model_files = [f for f in os.listdir(checkpoint_dir) if
                                   os.path.isfile(os.path.join(checkpoint_dir, f))]
                    torch.save(model.state_dict(), checkpoint_dir + f'CP_epoch{epoch + 1}.pt')
                    for model_file in model_files:
                        os.remove(f'{checkpoint_dir}{model_file}')
                    logging.info(f'Checkpoint {epoch + 1} saved!')
                    last_epoch_val_score = val_score

    writer.close()
    torch.save(model.state_dict(), model_path)

    # # WIP
    # # Launches evaluation on the model every evaluate_every steps.
    # # We need to change to appropriate evaluation metrics.
    # if evaluate_every > 0 and valid_samples is not None and (e + 1) % evaluate_every == 0:
    #     self.model.eval()
    #     with torch.no_grad():
    #         mrr, h1 = self.evaluator.eval(samples=valid_samples, write_output= False)

    #     # Metrics printing
    #     print("\tValidation: %f" % h1)

    # if save_path is not None:
    #     print("\tSaving model...")
    #     torch.save(self.model.state_dict(), save_path)
    # print("\tDone.")


# print("\nEvaluating model...")
# model.eval()
# mrr, h1 = Evaluator(model=model).eval(samples=dataset.test_samples, write_output=False)
# print("\tTest Hits@1: %f" % h1)
# print("\tTest Mean Reciprocal Rank: %f" % mrr)


# def get_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--dataset',
#                         choices=ALL_DATASET_NAMES,
#                         default="FlickrLogos-32",
#                         type=str,
#                         help="Dataset in {}".format(ALL_DATASET_NAMES)
#                         )

#     parser.add_argument('--model',
#                         choices=ALL_MODEL_NAMES,
#                         default="LogoDetection",
#                         type=str,
#                         help="Model in {}".format(ALL_MODEL_NAMES)
#                         )

#     optimizers = ['Adam', 'SGD']
#     parser.add_argument('--optimizer',
#                         choices=optimizers,
#                         default='Adam',
#                         help="Optimizer in {}".format(optimizers)
#                         )

#     parser.add_argument('--max_epochs',
#                         default=500,
#                         type=int,
#                         help="Number of epochs"
#                         )

#     parser.add_argument('--valid',
#                         default=-1,
#                         type=float,
#                         help="Number of epochs before valid"
#                         )

#     parser.add_argument('--batch_size',
#                         default=32,
#                         type=int,
#                         help="Number of samples in each mini-batch in SGD and Adam optimization"
#                         )

#     parser.add_argument('--weight_decay',
#                         default=5e-4,
#                         type=float,
#                         help="L2 weight regularization of the optimizer"
#                         )

#     parser.add_argument('--learning_rate',
#                         default=4e-4,
#                         type=float,
#                         help="Learning rate of the optimizer"
#                         )

#     parser.add_argument('--label_smooth',
#                         default=0.1,
#                         type=float,
#                         help="Label smoothing for true labels"
#                         )

#     parser.add_argument('--decay1',
#                         default=0.9,
#                         type=float,
#                         help="Decay rate for the first momentum estimate in Adam"
#                         )

#     parser.add_argument('--decay2',
#                         default=0.999,
#                         type=float,
#                         help="Decay rate for second momentum estimate in Adam"
#                         )

#     parser.add_argument('--verbose',
#                         default=True,
#                         type=bool,
#                         help="Verbose"
#                         )

#     parser.add_argument('--load',
#                         type=str,
#                         required=False,
#                         help="Path to the model to load"
#                         )

#     parser.add_argument('--batch_norm',
#                         default=False,
#                         type=bool,
#                         help="If True, apply batch normalization",
#                         )

#     parser.add_argument('--vgg_cfg',
#                         type=str,
#                         default='A',
#                         help="VGG architecture config",
#                         )

#     parser.add_argument('--step_eval',
#                         type=int,
#                         default=0,
#                         help="Enables automatic evaluation checks every X step",
#                         )

#     parser.add_argument('--val_split',
#                         type=float,
#                         default=0.1,
#                         help="Forces the validation subset to be split according to the set value. Must a value in the [0-1] or the sofware WILL break",
#                         )

#     parser.add_argument('--save_cp',
#                         type=bool,
#                         default=True,
#                         help="If True, saves model checkponts",
#                         )

#     return parser.parse_args()


def train_main(dataset="FlickrLogos-32",
     model="LogoDetection",
     optimizer="Adam",
     max_epochs=1, 
     batch_size=32, 
     weight_decay=5e-4,
     learning_rate=4e-4,
     decay1=0.9, 
     decay2=0.999,
     verbose=True,
     batch_norm=False,
     load=None,
     val_split=0.1,
     save_cp=True,
     ):
    # TODO: Add filename
    logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s: %(message)s", filename='oneshot.log')
#     args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Modularized paths with respect to the current Dataset
    imgs_dir = config_list['datasets'][dataset]['paths']['images']
    masks_dir = config_list['datasets'][dataset]['paths']['masks']
    checkpoint_dir = config_list['models'][model]['paths']['train_cp']

    model_path = config_list['models'][model]['paths']['model'] + "_".join([model, dataset]) + ".pt"

    print("Loading %s dataset..." % dataset)
    # you can delete this "save_to_disk" to preserve the ssd :like:
    dataset = BasicDataset(imgs_dir=imgs_dir, masks_dir=masks_dir, dataset_name=dataset)

    # Splitting dataset
    n_val = int(len(dataset) * val_split)
    n_train = len(dataset) - n_val
    # TODO: Il validation set dovrebbe avere il 10% di ogni classe e non il 10% del totale altrimenti verrebbe sbilanciato
    train_set, val_set = random_split(dataset, [n_train, n_val])

    # Loading dataset
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True,
                            drop_last=True)

    # Change here to adapt your data
    print("Initializing model...")
    model = LogoDetection(batch_norm=batch_norm)

    # Optimizer selection
    # build all the supported optimizers using the passed params (learning rate and decays if Adam)
    supported_optimizers = {
        'Adam': optim.Adam(params=model.parameters(), lr=learning_rate, betas=(decay1, decay2),
                           weight_decay=weight_decay),
        'SGD': optim.SGD(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    }
    # Choose which Torch Optimizer object to use, based on the passed name
    optimizer = supported_optimizers[optimizer]

    # stiamo dando ad "args.load" due compiti, quello di dirci il path e quello di dirci se caricare vecchi checkpoint
    if load is not None:
        model.load_state_dict(
            torch.load(load, map_location=device)
        )
        logging.info(f'Model loaded from {load}')
    model.to(device=device)

    try:
        train(model=model,
              device=device,
              train_loader=train_loader,
              val_loader=val_loader,
              max_epochs=max_epochs,
              optimizer=optimizer,
              verbose=verbose,
              checkpoint_dir=checkpoint_dir,
              model_path=model_path,
              save_cp=save_cp,
              n_train=n_train,
              n_val=n_val
              )
    except KeyboardInterrupt:
        torch.save(model.state_dict(), 'INTERRUPTED.ph')
        logging.info('Interrupt saved')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [4]:
train_main(dataset="FlickrLogos-32",
     model="LogoDetection",
     optimizer="Adam",
     max_epochs=1, 
     batch_size=32, 
     weight_decay=5e-4,
     learning_rate=4e-4,
     decay1=0.9, 
     decay2=0.999,
     verbose=True,
     batch_norm=False,
     load=None,
     val_split=0.1,
     save_cp=True,
     )

Loading FlickrLogos-32 dataset...
You have 9660 triplets
Initializing model...


train loss:   0%|          | 0/8694 [00:02<?, ?img/s]


RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 15.90 GiB total capacity; 7.24 GiB already allocated; 26.88 MiB free; 141.45 MiB cached)

In [None]:
# test

In [None]:
import argparse
import logging
import os
import sys
from tqdm import tqdm
import yaml

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
import torch.nn.functional as F


# todo: when we add more models, we should move these variables to another location
ALL_MODEL_NAMES = ["LogoDetection"]
ALL_DATASET_NAMES = ["FlickrLogos-32, TopLogos-10"]

with open(os.path.abspath("./config/config.yaml")) as config:
    config_list = yaml.load(config, Loader=yaml.FullLoader)

# def pred(model,
#          sample,
#          device,
#          threshold=0.5):
#     if(model.eval==False):
#         model.eval()

#     queries, targets = torch.from_numpy(BasicDataset.preprocess(index=index, file:files_path))
#     queries = queries.unsqueeze(0)
#     queries = queries.to(device=device, dtype=torch.float32)
#     targets = targets.unsqueeze(0)
#     targets = targets.to(device=device, dtype=torch.float32)


def test(model,
         device,
         dataset,
         batch_size,
        #   save_path,
         verbose: bool,
         threshold=0.5):
    
    model.eval()

    # #TODO dataset preprocessing
    # with torch.no_grad():
    #     output = model(queries, targets)

    #     probs = output.squeeze(0)
    
    logging.info("\nPredicting image{} ... ")

    with tqdm(total=len(dataset), desc=f'Testing dataset', unit='test-img', disable=not verbose) as bar:
            bar.set_description(f'model testing')

            for batch in test_loader:
                queries = batch['query']  # Correct dimensions?
                targets = batch['target']
                bboxes = batch['bbox']

                queries = queries.to(device=device, dtype=torch.float32)
                targets = targets.to(device=device, dtype=torch.float32)
                bboxes = bboxes.to(device=device, dtype=torch.float32)

                with torch.no_grad():
                    pred_masks = model(queries, targets)
                    # print(pred_masks.shape)

                bar.update(queries.shape[0])
                global_step += 1
                val_score += eval_net(model, batch, device)
                # writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                writer.add_images('query_images', queries, global_step)
                writer.add_images('target_images', targets, global_step)
                # writer.add_images('bboxes/true', bboxes, global_step)
                writer.add_images('masks/pred', pred_masks, global_step)
                writer.close()
    return val_score / global_step


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset',
                        choices=ALL_DATASET_NAMES,
                        help="Dataset in {}".format(ALL_DATASET_NAMES),
                        required=True
                        )

    parser.add_argument('--model',
                        choices=ALL_MODEL_NAMES,
                        help="Model in {}".format(ALL_MODEL_NAMES)
                        )

    parser.add_argument('--batch_size',
                        default=32,
                        type=int,
                        help="Number of samples in each mini-batch in SGD and Adam optimization"
                        )

    parser.add_argument('--verbose',
                        default=True,
                        type=bool,
                        help="Verbose"
                        )

    parser.add_argument('--load',
                        type=str,
                        required=True,
                        help="Path to the model to load"
                        )

    return parser.parse_args()


if __name__ == '__main__':
    # Logging
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Modularized paths with respect to the current Dataset
    imgs_dir = config_list['datasets'][args.dataset]['images']
    masks_dir = config_list['datasets'][args.dataset]['masks']
    # TODO: Controlla che ste due liste hanno le stesse sottocartelle
    imgs_classes = [f.name for f in os.scandir(imgs_dir) if f.is_dir()]
    mask_classes = [f.name for f in os.scandir(masks_dir) if f.is_dir()]
    # checkpoint_dir = config_list['models'][args.model]['train_cp']

    model_path = config_list['models'][args.model]['paths']['model']+ "_".join([args.model, args.dataset]) + ".pt"

    # print("Loading %s dataset..." % args.dataset)
    # dataset = BasicDataset(imgs_dir=imgs_dir, masks_dir=masks_dir)

    # Change here to adapt your data
    print("Initializing model...")
    model = LogoDetection(batch_norm=args.batch_norm,
                        vgg_cfg=args.vgg_cfg)

    model.load_state_dict(
        torch.load(model_path, map_location=device)
    )
    logging.info(f'Model loaded from {model_path}')
    model.to(device=device)

    # Neo, enter in Metrics
    metrics = []

    for img_class_idx, img_class_path in enumerate(imgs_classes):
        dataset = BasicDataset(imgs_dir=f"{imgs_dir}{os.path.sep}{img_class_path}", masks_dir=f"{masks_dir}{os.path.sep}{masks_dir[img_class_idx]}", dataset_name=args.dataset, skip_bbox_lines=1)
        test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

        try:
            metrics.append(test(model=model,
                                device=device,
                                dataset=test_loader,
                                batch_size=args.batch_size,
                                verbose=args.verbose
                                ))
        except KeyboardInterrupt:
            # torch.save(model.state_dict(), 'INTERRUPTED.ph')
            # logging.info('Interrupt saved')
            logging.info("Test interrupted")
            try:
                sys.exit(0)
            except SystemExit:
                os._exit(0)
