# Importing Libraries

In [None]:
import math
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torchvision
import torchvision.models as models

# Setting Device

In [None]:
# run on the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Anchors

In [None]:
class Anchors(nn.Module):
  """Build the anchor boxes for every pyramid level of an input batch.

  For each pyramid level, ``len(ratios) * len(scales)`` base anchors are
  generated and replicated over a regular grid whose step is that level's
  stride.

  Args:
    pyramid_levels: pyramid level indices; defaults to ``[3, 4, 5, 6, 7]``.
    strides: grid step per level; defaults to ``2 ** level``.
    sizes: base anchor size per level; defaults to ``2 ** (level + 2)``.
    ratios: anchor aspect ratios; defaults to ``[0.5, 1, 2]``.
    scales: anchor scales; defaults to ``[2**0, 2**(1/3), 2**(2/3)]``.
  """

  def __init__(self, pyramid_levels=None, strides=None, sizes=None, ratios=None, scales=None):
    super(Anchors, self).__init__()

    # Every argument is stored, whether defaulted or supplied by the caller.
    self.pyramid_levels = [3, 4, 5, 6, 7] if pyramid_levels is None else pyramid_levels
    self.strides = [2 ** x for x in self.pyramid_levels] if strides is None else strides
    self.sizes = [2 ** (x + 2) for x in self.pyramid_levels] if sizes is None else sizes
    self.ratios = np.array([0.5, 1, 2]) if ratios is None else ratios
    if scales is None:
      self.scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])
    else:
      self.scales = scales

  def forward(self, image):
    """Return all anchors for ``image`` as a float32 tensor of shape ``(1, N, 4)``.

    Args:
      image: tensor whose dimensions from index 2 onward are the spatial
        size (e.g. a ``(B, C, H, W)`` batch). Only its shape and device are used.

    Returns:
      Anchors as ``(x1, y1, x2, y2)`` rows, placed on the same device as ``image``.
    """
    image_shape = np.array(image.shape[2:])
    # ceil-divide the spatial size by 2**level to get each level's grid size
    image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]

    # compute anchors over all pyramid levels
    per_level = [np.zeros((0, 4), dtype=np.float32)]
    for idx, _level in enumerate(self.pyramid_levels):
      anchors = self.generate_anchors(base_size=self.sizes[idx], ratios=self.ratios, scales=self.scales)
      per_level.append(self.shift(image_shapes[idx], self.strides[idx], anchors))

    all_anchors = np.expand_dims(np.concatenate(per_level, axis=0), axis=0)

    # follow the input's device instead of assuming CUDA is available
    return torch.from_numpy(all_anchors.astype(np.float32)).to(image.device)

  @staticmethod
  def generate_anchors(base_size=16, ratios=None, scales=None):
    """
    Generate anchor (reference) windows by enumerating aspect ratios X
    scales w.r.t. a reference window.

    Returns an array of shape ``(len(ratios) * len(scales), 4)`` holding
    ``(x1, y1, x2, y2)`` boxes centred on the origin.
    """

    if ratios is None:
      ratios = np.array([0.5, 1, 2])

    if scales is None:
      scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])

    num_anchors = len(ratios) * len(scales)

    # initialize output anchors
    anchors = np.zeros((num_anchors, 4))

    # scale base_size
    anchors[:, 2:] = base_size * np.tile(scales, (2, len(ratios))).T

    # compute areas of anchors
    areas = anchors[:, 2] * anchors[:, 3]

    # correct for ratios (area is preserved, height / width == ratio)
    repeated_ratios = np.repeat(ratios, len(scales))
    anchors[:, 2] = np.sqrt(areas / repeated_ratios)
    anchors[:, 3] = anchors[:, 2] * repeated_ratios

    # transform from (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2)
    anchors[:, 0::2] -= np.tile(anchors[:, 2] * 0.5, (2, 1)).T
    anchors[:, 1::2] -= np.tile(anchors[:, 3] * 0.5, (2, 1)).T

    return anchors

  @staticmethod
  def compute_shape(image_shape, pyramid_levels):
    """Compute the grid shape of every pyramid level.

    :param image_shape: sequence whose first two entries are the spatial size
    :param pyramid_levels: iterable of pyramid level indices
    :return: list with one ``ceil(image_shape[:2] / 2**level)`` array per level
    """
    image_shape = np.array(image_shape[:2])
    image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in pyramid_levels]
    return image_shapes

  @staticmethod
  def anchors_for_shape(
    image_shape,
    pyramid_levels=None,
    ratios=None,
    scales=None,
    strides=None,
    sizes=None,
    shapes_callback=None,
    ):
    """Generate the anchors for an image shape without instantiating the module.

    ``image_shape`` is read through ``compute_shape`` (first two entries are the
    spatial size). Omitted arguments use the same defaults as ``Anchors()``.
    ``shapes_callback``, when given, is called as
    ``shapes_callback(image_shape, pyramid_levels)`` and must return one grid
    shape per level.

    Returns an ``(N, 4)`` array of ``(x1, y1, x2, y2)`` anchors.
    """
    if pyramid_levels is None:
      pyramid_levels = [3, 4, 5, 6, 7]
    if strides is None:
      strides = [2 ** x for x in pyramid_levels]
    if sizes is None:
      sizes = [2 ** (x + 2) for x in pyramid_levels]
    if shapes_callback is None:
      shapes_callback = Anchors.compute_shape

    image_shapes = shapes_callback(image_shape, pyramid_levels)

    # compute anchors over all pyramid levels
    per_level = [np.zeros((0, 4))]
    for idx, _level in enumerate(pyramid_levels):
      anchors = Anchors.generate_anchors(base_size=sizes[idx], ratios=ratios, scales=scales)
      per_level.append(Anchors.shift(image_shapes[idx], strides[idx], anchors))

    return np.concatenate(per_level, axis=0)

  @staticmethod
  def shift(shape, stride, anchors):
    """Replicate ``anchors`` over a ``shape[0]`` x ``shape[1]`` grid of cell centres.

    Cell centres are ``(index + 0.5) * stride``; the result has shape
    ``(shape[0] * shape[1] * len(anchors), 4)``.
    """
    shift_x = (np.arange(0, shape[1]) + 0.5) * stride
    shift_y = (np.arange(0, shape[0]) + 0.5) * stride

    shift_x, shift_y = np.meshgrid(shift_x, shift_y)

    shifts = np.vstack((
        shift_x.ravel(), shift_y.ravel(),
        shift_x.ravel(), shift_y.ravel()
    )).transpose()

    # add A anchors (1, A, 4) to
    # cell K shifts (K, 1, 4) to get
    # shift anchors (K, A, 4)
    # reshape to (K*A, 4) shifted anchors
    A = anchors.shape[0]
    K = shifts.shape[0]
    all_anchors = (anchors.reshape((1, A, 4)) + shifts.reshape((1, K, 4)).transpose((1, 0, 2)))
    all_anchors = all_anchors.reshape((K * A, 4))

    return all_anchors

# Boxes

In [None]:
class BBoxTransform(nn.Module):
  """Decode regression deltas into boxes relative to reference boxes.

  Args:
    mean: four offsets added to the scaled deltas; defaults to ``[0, 0, 0, 0]``.
    std: four factors multiplying the deltas; defaults to ``[0.1, 0.1, 0.2, 0.2]``.
  """

  def __init__(self, mean=None, std=None):
    super(BBoxTransform, self).__init__()
    # Defaults are created on the CPU; forward() moves them to the deltas'
    # device, so the module also works on hosts without CUDA.
    if mean is None:
      self.mean = torch.tensor([0.0, 0.0, 0.0, 0.0], dtype=torch.float32)
    else:
      self.mean = mean
    if std is None:
      self.std = torch.tensor([0.1, 0.1, 0.2, 0.2], dtype=torch.float32)
    else:
      self.std = std

  def forward(self, boxes, deltas):
    """Apply ``deltas`` to ``boxes``.

    Both inputs are indexed as ``[:, :, k]`` with ``k`` in ``0..3``; boxes are
    read as ``(x1, y1, x2, y2)`` and deltas as ``(dx, dy, dw, dh)``.

    Returns:
      Tensor of decoded ``(x1, y1, x2, y2)`` boxes stacked along dim 2.
    """
    mean = torch.as_tensor(self.mean, dtype=deltas.dtype, device=deltas.device)
    std = torch.as_tensor(self.std, dtype=deltas.dtype, device=deltas.device)

    widths  = boxes[:, :, 2] - boxes[:, :, 0]
    heights = boxes[:, :, 3] - boxes[:, :, 1]
    ctr_x   = boxes[:, :, 0] + 0.5 * widths
    ctr_y   = boxes[:, :, 1] + 0.5 * heights

    # de-normalise the network outputs
    dx = deltas[:, :, 0] * std[0] + mean[0]
    dy = deltas[:, :, 1] * std[1] + mean[1]
    dw = deltas[:, :, 2] * std[2] + mean[2]
    dh = deltas[:, :, 3] * std[3] + mean[3]

    # centre shifts are relative to the box size, sizes are log-scaled
    pred_ctr_x = ctr_x + dx * widths
    pred_ctr_y = ctr_y + dy * heights
    pred_w     = torch.exp(dw) * widths
    pred_h     = torch.exp(dh) * heights

    pred_boxes_x1 = pred_ctr_x - 0.5 * pred_w
    pred_boxes_y1 = pred_ctr_y - 0.5 * pred_h
    pred_boxes_x2 = pred_ctr_x + 0.5 * pred_w
    pred_boxes_y2 = pred_ctr_y + 0.5 * pred_h

    pred_boxes = torch.stack([pred_boxes_x1, pred_boxes_y1, pred_boxes_x2, pred_boxes_y2], dim=2)

    return pred_boxes

In [None]:
class ClipBoxes(nn.Module):
  """Clamp ``(x1, y1, x2, y2)`` boxes to the extent of an image batch, in place."""

  def __init__(self, width=None, height=None):
    # width / height are accepted but not used; the limits come from the
    # image passed to forward().
    super(ClipBoxes, self).__init__()

  def forward(self, boxes, img):
    """Clip ``boxes`` against the size of ``img`` and return the same tensor.

    ``img`` must be 4-dimensional; its last two dimensions are taken as
    height and width. ``boxes`` is modified in place.
    """
    _, _, img_h, img_w = img.shape

    # top-left corner may not go below zero
    for coord in (0, 1):
      boxes[:, :, coord] = torch.clamp(boxes[:, :, coord], min=0)

    # bottom-right corner may not exceed the image extent
    for coord, limit in ((2, img_w), (3, img_h)):
      boxes[:, :, coord] = torch.clamp(boxes[:, :, coord], max=limit)

    return boxes

# Bottleneck

In [None]:
class Bottleneck(nn.Module):
  """Residual block: 1x1 conv -> 3x3 conv -> 1x1 conv, each followed by batch norm.

  The last convolution outputs ``planes * 4`` channels. ``downsample``, when
  given, is applied to the block input before it is added to the output.
  """
  expansion = 4

  def __init__(self, inplanes, planes, stride=1, downsample=None):
    super(Bottleneck, self).__init__()
    widened = planes * 4

    # 1x1 channel reduction
    self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes)
    # 3x3 convolution; this is where the stride is applied
    self.conv2 = nn.Conv2d(
      planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
    )
    self.bn2 = nn.BatchNorm2d(planes)
    # 1x1 channel expansion
    self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
    self.bn3 = nn.BatchNorm2d(widened)
    self.relu = nn.ReLU(inplace=True)
    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    """Return ``relu(main_path(x) + shortcut(x))``."""
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.relu(self.bn2(self.conv2(out)))
    out = self.bn3(self.conv3(out))

    # identity shortcut unless a projection was supplied
    shortcut = x if self.downsample is None else self.downsample(x)

    out += shortcut
    return self.relu(out)

# Feature Pyramid Network

In [None]:
class PyramidFeatures(nn.Module):
  """Feature pyramid built from four backbone feature maps (C2..C5).

  Args:
    C2_size, C3_size, C4_size, C5_size: channel counts of the backbone maps.
    feature_size: channel count of every pyramid output.
    use_l2_features: when True the output list also contains the P2 level.
  """

  def __init__(self, C2_size, C3_size, C4_size, C5_size, feature_size=256, use_l2_features=True):
    super(PyramidFeatures, self).__init__()
    self.use_l2_features = use_l2_features

    # upsample C5 to get P5 from the FPN paper
    self.P5_1           = nn.Conv2d(C5_size, feature_size, kernel_size=1, stride=1, padding=0)
    self.P5_upsampled   = nn.Upsample(scale_factor=2, mode='nearest')
    self.P5_2           = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

    # add P5 elementwise to C4
    self.P4_1           = nn.Conv2d(C4_size, feature_size, kernel_size=1, stride=1, padding=0)
    self.P4_upsampled   = nn.Upsample(scale_factor=2, mode='nearest')
    self.P4_2           = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

    # add P4 elementwise to C3
    self.P3_1 = nn.Conv2d(C3_size, feature_size, kernel_size=1, stride=1, padding=0)
    self.P3_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
    self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

    # add P3 elementwise to C2
    # (created even when use_l2_features is False so the parameter set does
    # not depend on the flag)
    self.P2_1 = nn.Conv2d(C2_size, feature_size, kernel_size=1, stride=1, padding=0)
    self.P2_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)

    # "P6 is obtained via a 3x3 stride-2 conv on C5"
    self.P6 = nn.Conv2d(C5_size, feature_size, kernel_size=3, stride=2, padding=1)

    # "P7 is computed by applying ReLU followed by a 3x3 stride-2 conv on P6"
    self.P7_1 = nn.ReLU()
    self.P7_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=2, padding=1)

  def forward(self, inputs):
    """Compute the pyramid levels.

    Args:
      inputs: sequence ``(C2, C3, C4, C5)``. Each map must have twice the
        spatial size of the next one so the 2x-upsampled maps can be added.

    Returns:
      ``[P2, P3, P4, P5, P6, P7]`` when ``use_l2_features`` is True,
      otherwise ``[P3, P4, P5, P6, P7]``.
    """
    C2, C3, C4, C5 = inputs

    P5_x = self.P5_1(C5)
    P5_upsampled_x = self.P5_upsampled(P5_x)
    P5_x = self.P5_2(P5_x)

    P4_x = self.P4_1(C4)
    P4_x = P5_upsampled_x + P4_x
    P4_upsampled_x = self.P4_upsampled(P4_x)
    P4_x = self.P4_2(P4_x)

    P3_x = self.P3_1(C3)
    P3_x = P3_x + P4_upsampled_x
    # keep the pre-smoothing map: it feeds P2 when that level is enabled
    P3_lateral = P3_x
    P3_x = self.P3_2(P3_x)

    # P6 / P7 are always produced, independent of the P2 switch
    P6_x = self.P6(C5)

    P7_x = self.P7_1(P6_x)
    P7_x = self.P7_2(P7_x)

    if self.use_l2_features:
      P2_x = self.P2_1(C2)
      P2_x = P2_x + self.P3_upsampled(P3_lateral)
      P2_x = self.P2_2(P2_x)
      return [P2_x, P3_x, P4_x, P5_x, P6_x, P7_x]

    return [P3_x, P4_x, P5_x, P6_x, P7_x]

# Box Regression Subnet

In [None]:
class RegressionModel(nn.Module):
  """Box regression head: four 3x3 conv + ReLU stages, then a 3x3 output conv.

  The output conv produces ``num_anchors * 4`` channels, i.e. four regression
  values per anchor and spatial position.
  """

  def __init__(self, num_features_in, num_anchors=9, feature_size=256):
    super(RegressionModel, self).__init__()

    conv_args = dict(kernel_size=3, padding=1)

    self.conv1 = nn.Conv2d(num_features_in, feature_size, **conv_args)
    self.act1 = nn.ReLU()

    self.conv2 = nn.Conv2d(feature_size, feature_size, **conv_args)
    self.act2 = nn.ReLU()

    self.conv3 = nn.Conv2d(feature_size, feature_size, **conv_args)
    self.act3 = nn.ReLU()

    self.conv4 = nn.Conv2d(feature_size, feature_size, **conv_args)
    self.act4 = nn.ReLU()

    self.output = nn.Conv2d(feature_size, num_anchors * 4, **conv_args)

  def forward(self, x):
    """Return the regression values for ``x`` reshaped to ``(batch, -1, 4)``."""
    stages = (
      (self.conv1, self.act1),
      (self.conv2, self.act2),
      (self.conv3, self.act3),
      (self.conv4, self.act4),
    )
    feats = x
    for conv, act in stages:
      feats = act(conv(feats))

    raw = self.output(feats)

    # move the channel axis (4 * num_anchors) last so that each group of
    # four consecutive values belongs to one anchor
    raw = raw.permute(0, 2, 3, 1)
    batch = raw.shape[0]

    return raw.contiguous().view(batch, -1, 4)

# Classification Subnet

In [None]:
class ClassificationModel(nn.Module):
  """Per-anchor classification head.

  Four 3x3 conv + ReLU stages, optional dropout, then a 3x3 output conv with
  ``num_anchors * num_classes`` channels followed by a sigmoid.
  ``prior`` is accepted but not used inside this class.
  """

  def __init__(self, num_features_in, num_anchors=9, num_classes=80, prior=0.01, feature_size=256, dropout=0.5):
    super(ClassificationModel, self).__init__()

    self.num_classes = num_classes
    self.num_anchors = num_anchors

    conv_args = dict(kernel_size=3, padding=1)

    self.conv1 = nn.Conv2d(num_features_in, feature_size, **conv_args)
    self.act1 = nn.ReLU()

    self.conv2 = nn.Conv2d(feature_size, feature_size, **conv_args)
    self.act2 = nn.ReLU()

    self.conv3 = nn.Conv2d(feature_size, feature_size, **conv_args)
    self.act3 = nn.ReLU()

    self.conv4 = nn.Conv2d(feature_size, feature_size, **conv_args)
    self.act4 = nn.ReLU()

    self.output = nn.Conv2d(feature_size, num_anchors * num_classes, **conv_args)
    self.output_act = nn.Sigmoid()

    self.dropout = dropout

  def forward(self, x):
    """Return class scores for ``x`` with shape ``(batch, -1, num_classes)``."""
    stages = (
      (self.conv1, self.act1),
      (self.conv2, self.act2),
      (self.conv3, self.act3),
      (self.conv4, self.act4),
    )
    feats = x
    for conv, act in stages:
      feats = act(conv(feats))

    # dropout is only active while the module is in training mode
    if self.dropout > 0:
      feats = F.dropout(feats, self.dropout, self.training)

    scores = self.output_act(self.output(feats))

    # channels hold num_anchors * num_classes values; move them last and
    # split them into (anchor, class)
    scores = scores.permute(0, 2, 3, 1)
    n, dim_a, dim_b, _ = scores.shape
    per_anchor = scores.view(n, dim_a, dim_b, self.num_anchors, self.num_classes)

    return per_anchor.contiguous().view(x.shape[0], -1, self.num_classes)

# Global Classification

In [None]:
class GlobalClassificationModel(nn.Module):
  """Image-level classifier on top of a single feature map.

  Pipeline: 2x2 max-pool -> 3x3 conv (no padding) -> ReLU -> global average
  and global max pooling, concatenated -> optional dropout -> linear ->
  log-softmax.
  """

  def __init__(self, num_features_in, num_classes=80, feature_size=256, dropout=0.5):
    super().__init__()

    self.num_classes = num_classes
    self.conv1 = nn.Conv2d(num_features_in, feature_size, kernel_size=3, dilation=1, padding=0)
    # the linear layer sees avg-pooled and max-pooled features side by side
    self.fc = nn.Linear(feature_size*2, num_classes)
    self.output_act = nn.LogSoftmax(dim=-1)

    self.dropout = dropout

  def forward(self, x):
    """Return log-probabilities of shape ``(batch, num_classes)``."""
    feats = F.relu(self.conv1(F.max_pool2d(x, 2)))

    # pooling with a kernel equal to the full spatial extent leaves a 1x1 map
    spatial = feats.shape[2:]
    pooled = torch.cat(
      (F.avg_pool2d(feats, spatial), F.max_pool2d(feats, spatial)), 1
    )
    flat = pooled.view(pooled.size(0), -1)

    # dropout is only active while the module is in training mode
    if self.dropout > 0:
      flat = F.dropout(flat, self.dropout, self.training)

    return self.output_act(self.fc(flat))

# RetinaNet

In [None]:
class RetinaNetEncoder(nn.Module):
  """Base class for backbones that can be plugged into RetinaNet.

  Subclasses fill ``fpn_sizes`` and implement ``forward``.
  """

  def __init__(self):
    super(RetinaNetEncoder, self).__init__()
    # starts empty; concrete encoders overwrite it
    self.fpn_sizes = list()

  def forward(self, x):
    """
    :param x: input tensor
    :return: x1, x2, x3, x4 layer outputs
    :raises NotImplementedError: always, in this base class
    """
    raise NotImplementedError

In [None]:
class RetinaNet(nn.Module):
  """RetinaNet detector with an additional image-level classification head.

  Args:
    encoder: backbone returning four feature maps; its ``fpn_sizes`` gives
      their channel counts.
    num_classes: number of classes predicted per anchor.
    dropout_cls: dropout probability of the per-anchor classification head.
    dropout_global_cls: dropout probability of the global classification head.
    use_l2_features: also use the P2 pyramid level (and matching anchors).

  NOTE(review): ``losses.FocalLoss`` is not defined in this file and is
  expected to be provided by the surrounding project.
  """

  def __init__(self, encoder: RetinaNetEncoder, num_classes, dropout_cls=0.5,
              dropout_global_cls=0.5, use_l2_features=True):
    super(RetinaNet, self).__init__()

    fpn_sizes = encoder.fpn_sizes
    self.use_l2_features = use_l2_features

    self.fpn = PyramidFeatures(fpn_sizes[0], fpn_sizes[1], fpn_sizes[2], fpn_sizes[3], use_l2_features=use_l2_features)

    self.regressionModel = RegressionModel(256)
    self.classificationModel = ClassificationModel(256, num_classes=num_classes, dropout=dropout_cls)
    self.globalClassificationModel = GlobalClassificationModel(fpn_sizes[-1], num_classes=3, feature_size=256, dropout=dropout_global_cls)
    self.globalClassificationLoss = nn.NLLLoss()

    if use_l2_features:
      pyramid_levels = [2, 3, 4, 5, 6, 7]
    else:
      pyramid_levels = [3, 4, 5, 6, 7]

    self.anchors = Anchors(pyramid_levels=pyramid_levels)

    self.regressBoxes = BBoxTransform()

    self.clipBoxes = ClipBoxes()

    self.focalLoss = losses.FocalLoss()

    # Initialise only the modules created above. The encoder is attached
    # after this loop so that weights supplied with it (e.g. pretrained ones)
    # are not overwritten.
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
      elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

    self.encoder = encoder

    prior = 0.01

    # start with every anchor predicting probability `prior` for every class
    self.classificationModel.output.weight.data.fill_(0)
    self.classificationModel.output.bias.data.fill_(-math.log((1.0 - prior) / prior))

    self.regressionModel.output.weight.data.fill_(0)
    self.regressionModel.output.bias.data.fill_(0)

    self.freeze_bn()

  def freeze_bn(self):
    """Freeze BatchNorm layers."""
    for layer in self.modules():
      if isinstance(layer, nn.BatchNorm2d):
        layer.eval()

  def freeze_encoder(self):
    """Put the encoder in eval mode (its parameters keep ``requires_grad``)."""
    self.encoder.eval()
    # correct version, but keep original as model has been trained this way
    # for param in self.encoder.parameters():
    #     param.requires_grad = False

  def unfreeze_encoder(self):
    """Enable gradients for every encoder parameter."""
    for param in self.encoder.parameters():
      param.requires_grad = True

  def boxes(self, img_batch, regression, classification, global_classification, anchors):
    """Decode, threshold and de-duplicate the predictions.

    Only the first item of the batch is processed: the score mask and the
    NMS input are both taken at batch index 0.

    Returns:
      ``[scores, global_classification, boxes]``; ``scores`` and ``boxes``
      are empty tensors when no anchor scores above 0.025.
    """
    # local import: torchvision is a file-level dependency, only this method needs nms
    from torchvision.ops import nms

    transformed_anchors = self.regressBoxes(anchors, regression)
    transformed_anchors = self.clipBoxes(transformed_anchors, img_batch)

    scores = torch.max(classification, dim=2, keepdim=True)[0]

    scores_over_thresh = (scores > 0.025)[0, :, 0]

    if scores_over_thresh.sum() == 0:
      # no boxes to NMS, just return (empty results live on the model's device)
      out_device = classification.device
      return [torch.zeros(0, device=out_device), global_classification, torch.zeros(0, 4, device=out_device)]

    classification = classification[:, scores_over_thresh, :]
    transformed_anchors = transformed_anchors[:, scores_over_thresh, :]
    scores = scores[:, scores_over_thresh, :]

    # use very low threshold of 0.05 as boxes should not overlap
    anchors_nms_idx = nms(transformed_anchors[0, :, :], scores[0, :, 0], 0.05)

    nms_scores, nms_class = classification[0, anchors_nms_idx, :].max(dim=1)
    return [nms_scores, global_classification, transformed_anchors[0, anchors_nms_idx, :]]

  def forward(self, inputs, return_loss, return_boxes, return_raw=False):
    """Run the detector.

    Args:
      inputs: ``(img_batch, annotations, global_annotations)`` when
        ``return_loss`` is true, otherwise just ``img_batch``.
      return_loss: append the focal-loss outputs and the global
        classification loss to the result.
      return_boxes: append the output of ``boxes()`` to the result.
      return_raw: return ``[regression, classification,
        exp(global_classification), anchors]`` and skip everything else.

    Returns:
      A list assembled according to the flags above.
    """
    if return_loss:
      img_batch, annotations, global_annotations = inputs
    else:
      img_batch = inputs

    x1, x2, x3, x4 = self.encoder(img_batch)

    features = self.fpn([x1, x2, x3, x4])

    regression = torch.cat([self.regressionModel(feature) for feature in features], dim=1)

    classification = torch.cat([self.classificationModel(feature) for feature in features], dim=1)

    global_classification = self.globalClassificationModel(x4)

    anchors = self.anchors(img_batch)

    if return_raw:
      # the global head outputs log-probabilities; exp() turns them into probabilities
      return [regression, classification, torch.exp(global_classification), anchors]

    res = []

    if return_loss:
      res += list(self.focalLoss(classification, regression, anchors, annotations))
      res += [self.globalClassificationLoss(global_classification, global_annotations)]

    if return_boxes:
      res += self.boxes(img_batch=img_batch,
                        regression=regression,
                        classification=classification,
                        global_classification=global_classification,
                        anchors=anchors)

    return res

# Defining ResNet50

In [None]:
class ResNetEncoder(RetinaNetEncoder):
  """ResNet backbone exposing the outputs of its four residual stages.

  Args:
    block: residual block class; must define ``expansion`` and accept
      ``(inplanes, planes, stride, downsample)``.
    layers: number of blocks in each of the four stages.
  """

  def __init__(self, block, layers):
    super().__init__()
    # running input-channel count, updated by _make_layer
    self.inplanes = 64
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    # Output channels of each stage. _make_layer already relies on every
    # stage emitting planes * block.expansion channels, so no block-type
    # special-casing is needed here.
    self.fpn_sizes = [planes * block.expansion for planes in (64, 128, 256, 512)]

  def _make_layer(self, block, planes, blocks, stride=1):
    """Stack ``blocks`` instances of ``block``; only the first one may stride / project."""
    downsample = None
    if stride != 1 or self.inplanes != planes * block.expansion:
        # projection shortcut so the residual matches the block output
        downsample = nn.Sequential(
          nn.Conv2d(self.inplanes, planes * block.expansion,
            kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(planes * block.expansion),
        )

    layers = [block(self.inplanes, planes, stride, downsample)]
    self.inplanes = planes * block.expansion
    layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

    return nn.Sequential(*layers)

  def forward(self, inputs):
    """Return the four stage outputs ``(x1, x2, x3, x4)``.

    A single-channel batch is replicated to three channels to match the
    first convolution; any other channel count is passed through unchanged.
    """
    img_batch = inputs

    if img_batch.shape[1] == 1:
      x = torch.cat([img_batch, img_batch, img_batch], dim=1)
    else:
      x = img_batch

    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x1 = self.layer1(x)
    x2 = self.layer2(x1)
    x3 = self.layer3(x2)
    x4 = self.layer4(x3)

    return x1, x2, x3, x4

In [None]:
def resnet50(num_classes, pretrained=True, **kwargs):
  """Build a RetinaNet with a ResNet-50 encoder.

  Args:
    num_classes: number of per-anchor classes, forwarded to ``RetinaNet``.
    pretrained: when true, load the torchvision ResNet-50 weights into the
      encoder (downloaded into ``models/`` if necessary).
    **kwargs: forwarded to ``RetinaNet``.

  Returns:
    The assembled ``RetinaNet`` model.
  """
  # defining a resnet50 model
  encoder = ResNetEncoder(Bottleneck, [3, 4, 6, 3])

  if pretrained:
    state_dict = model_zoo.load_url(
      'https://download.pytorch.org/models/resnet50-19c8e357.pth',
      model_dir='models',
    )
    # strict=False: non-matching keys between checkpoint and encoder are ignored
    encoder.load_state_dict(state_dict, strict=False)

  return RetinaNet(encoder=encoder, num_classes=num_classes, **kwargs)

# Training

In [None]:
def training_f(
    model_name: str,
    # fold: int,
    # debug: bool,
    epochs: int,
    # run: str=None,
    resume_weights: str="",
    resume_epoch: int=0,):
  """Train a RetinaNet model, validating after every epoch.

  Args:
    model_name: encoder to build when not resuming; only ``'resnet50'`` is
      supported. Also used in the final checkpoint file name.
    epochs: upper bound (exclusive) of the epoch counter.
    resume_weights: path of a saved model to continue from; empty to start
      from a freshly built model.
    resume_epoch: training runs over ``range(resume_epoch + 1, epochs)``.

  Raises:
    ValueError: if no checkpoint is given and ``model_name`` is not supported.

  NOTE(review): ``tqdm``, ``dataloader_train``, ``dataloader_valid``,
  ``validation``, ``predictions_dir``, ``logger`` and ``checkpoints_dir`` are
  not defined in this file and are expected to come from the notebook.
  """
  pretrained = True

  # TODO add other encoder options

  # TODO create checkpoint folders

  # load weights to continue training
  if resume_weights != "":
    print("load model from: ", resume_weights)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    retinanet = torch.load(resume_weights, map_location=device)
  elif model_name == 'resnet50':
    retinanet = resnet50(2, pretrained)
  else:
    raise ValueError(f"unsupported model_name: {model_name!r}")

  retinanet = retinanet.to(device)

  # the freeze helpers live on the bare model, also when it is wrapped in DataParallel
  model_core = retinanet.module if isinstance(retinanet, nn.DataParallel) else retinanet

  optimizer = torch.optim.Adam(retinanet.parameters(), lr=1e-5)
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=4, verbose=True, factor=0.2
  )

  for epoch_num in range(resume_epoch+1, epochs):

    retinanet.train()

    # NOTE(review): with the default resume_epoch=0 the counter starts at 1,
    # so the frozen-encoder branch only runs for a negative resume_epoch.
    if epoch_num < 1:
      # train the heads with a frozen encoder for the first epoch
      model_core.freeze_encoder()
    else:
      model_core.unfreeze_encoder()

    model_core.freeze_bn()

    # losses
    epoch_loss, loss_cls_hist, loss_cls_global_hist, loss_reg_hist = [], [], [], []

    with torch.set_grad_enabled(True):
      data_iter = tqdm(enumerate(dataloader_train), total = len(dataloader_train))
      for iter_num, data in data_iter:
        optimizer.zero_grad()

        inputs = [
                  data['img'].to(device).float(),
                  data['annot'].to(device).float(),
                  data['category'].to(device),
        ]
        (classification_loss, regression_loss, global_classification_loss,) = retinanet(
            inputs, return_loss=True, return_boxes=False
            )

        classification_loss = classification_loss.mean()
        regression_loss = regression_loss.mean()
        global_classification_loss = global_classification_loss.mean()
        # the image-level loss is down-weighted relative to the detection losses
        loss = classification_loss + regression_loss + global_classification_loss*0.1

        loss.backward()
        torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.05)
        optimizer.step()
        # loss history
        loss_cls_hist.append(float(classification_loss))
        loss_cls_global_hist.append(float(global_classification_loss))
        loss_reg_hist.append(float(regression_loss))
        epoch_loss.append(float(loss))
        # print running mean losses on the tqdm iterator
        data_iter.set_description(
          f"{epoch_num} cls: {np.mean(loss_cls_hist):1.4f} cls g: {np.mean(loss_cls_global_hist):1.4f} Reg: {np.mean(loss_reg_hist):1.4f} Loss: {np.mean(epoch_loss):1.4f}"
        )
        del classification_loss
        del regression_loss

    # TODO save model and log loss history

    # validation (once per epoch, so the plateau scheduler sees every epoch)
    (
      loss_hist_valid,
      loss_cls_hist_valid,
      loss_cls_global_hist_valid,
      loss_reg_hist_valid,
    ) = validation(retinanet,
          dataloader_valid,
          epoch_num,
          predictions_dir,
          save_oof=True,
    )

    # log validation loss history
    logger.scalar_summary("loss_valid", np.mean(loss_hist_valid), epoch_num)
    logger.scalar_summary("loss_valid_classification", np.mean(loss_cls_hist_valid), epoch_num)
    logger.scalar_summary(
      "loss_valid_global_classification", np.mean(loss_cls_global_hist_valid), epoch_num,
    )
    logger.scalar_summary("loss_valid_regression", np.mean(loss_reg_hist_valid), epoch_num)

    scheduler.step(np.mean(loss_reg_hist_valid))

  retinanet.eval()

  # TODO rewrite the folder path
  torch.save(retinanet, f"{checkpoints_dir}/{model_name}_final.pt")