### Imports

In [12]:
import os
import sys
import time
import math
import random
import datetime
import subprocess
from collections import defaultdict, deque

import numpy as np
import torch
from torch import nn
from PIL import ImageFilter, ImageOps

### Utils

In [8]:
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with values from a truncated normal distribution.

    Values are produced by sampling uniformly between the normal CDF values of
    the two bounds, applying the inverse CDF, rescaling to ``mean``/``std`` and
    finally clamping to ``[a, b]``. All tensor operations run under
    ``torch.no_grad()``.

    Args:
        tensor: tensor to overwrite in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        The same ``tensor`` object, after being filled.

    Warns:
        Emits a warning when ``mean`` lies more than ``2 * std`` outside
        ``[a, b]``.
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    # Imported here because the file's import cell does not bring `warnings`
    # into scope; without it the out-of-range branch raised NameError.
    import warnings

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        cdf_lower = norm_cdf((a - mean) / std)
        cdf_upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [cdf_lower, cdf_upper], then
        # translate to [2 * cdf_lower - 1, 2 * cdf_upper - 1].
        tensor.uniform_(2 * cdf_lower - 1, 2 * cdf_upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Public entry point: hand all arguments to ``_no_grad_trunc_normal_``.

    Returns whatever ``_no_grad_trunc_normal_`` returns for ``tensor`` with the
    given ``mean``, ``std`` and bounds ``a``/``b``.
    """
    filled = _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return filled

### Arguments

In [9]:
# see https://github.com/facebookresearch/dino/blob/main/main_dino.py get_args_parser() for explanations

class args_class:
  """Plain container for the hyper-parameters used in this notebook.

  Every setting is an instance attribute assigned in ``__init__``. The
  trailing ``default`` notes are the author's record of the reference values
  (see ``get_args_parser()`` in
  https://github.com/facebookresearch/dino/blob/main/main_dino.py).
  """

  def __init__(self):
    # Model Parameters
    self.MOMENTUM_TEACHER = 0.996  # default 0.996
    self.OUT_DIM = 65536  # default 65536
    self.NORM_LAST_LAYER = True
    self.USE_BN_IN_HEAD = False  # default False

    # Temperature Teacher Parameters
    self.WARMUP_TEACHER_TEMP = 0.04  # default 0.04
    self.TEACHER_TEMP = 0.04  # default 0.04
    self.WARMUP_TEACHER_TEMP_EPOCHS = 30  # default 30, but erroneously default 0 in the DINO paper?

    # Training / Optimizations Parameters
    self.USE_FP16 = True  # default True
    self.WEIGHT_DECAY = 0.04  # default 0.04
    self.WEIGHT_DECAY_END = 0.4  # default 0.4
    self.CLIP_GRAD = 3.0  # default 3.0
    self.BATCH_SIZE_PER_GPU = 64  # default 64
    self.EPOCHS = 100  # default 100
    self.FREEZE_LAST_LAYER = 1  # default 1
    self.LR = 0.0005  # default 0.0005
    self.WARMUP_EPOCHS = 10  # default 10
    self.MIN_LR = 1e-6  # default 1e-6
    self.optimizer = 'adamw'  # default adamw    TODO: Could be constructed here? not sure.
    self.DROP_PATH_RATE = 0.1  # default 0.1

    # Multi-Crop Parameters
    # (min, max) pair. The previous value (0.4, 0.1) had min > max; the
    # reference parser's default is (0.4, 1.).
    self.GLOBAL_CROPS_SCALE = (0.4, 1.0)  # default (0.4, 1.)
    self.LOCAL_CROPS_NUMBER = 8  # default 8
    self.LOCAL_CROPS_SCALE = (0.05, 0.4)  # default (0.05, 0.4)

    # Misc
    # Presumably the data-loading worker count (`num_workers` in the
    # reference parser) -- TODO confirm once the data pipeline is written.
    self.num_workers = 10  # default 10
    # Original spelling of the attribute, kept so existing references still work.
    self.num_works = self.num_workers
    # TODO: imgnet directory, how to store weights?

# Module-level settings object: one args_class instance holding the values
# assigned in args_class.__init__.
args = args_class()

### Architecture, DINOHead

In [16]:
class DINOHead(nn.Module):
  """Projection head: MLP -> L2 normalisation -> weight-normalised linear layer.

  Args:
    in_dim: input feature size of the first linear layer.
    out_dim: output size of the final (weight-normalised) linear layer.
    use_bn: if True, add ``BatchNorm1d`` after each hidden linear layer.
    norm_last_layer: if True, ``last_layer.weight_g`` (filled with 1) is
      excluded from gradient updates.
    nlayers: number of linear layers in the MLP; values below 1 are treated
      as 1, which gives a single ``in_dim -> bottleneck_dim`` layer.
    hidden_dim: width of the hidden MLP layers.
    bottleneck_dim: output size of the MLP, i.e. input size of the last layer.
  """

  def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True,
               nlayers=3, hidden_dim=2048, bottleneck_dim=256):
    super().__init__()
    nlayers = max(nlayers, 1)
    if nlayers == 1:
      self.mlp = nn.Linear(in_dim, bottleneck_dim)
    else:
      layers = [nn.Linear(in_dim, hidden_dim)]
      if use_bn:
        layers.append(nn.BatchNorm1d(hidden_dim))
      layers.append(nn.GELU())
      for _ in range(nlayers - 2):
        layers.append(nn.Linear(hidden_dim, hidden_dim))
        if use_bn:
          layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
      layers.append(nn.Linear(hidden_dim, bottleneck_dim))
      self.mlp = nn.Sequential(*layers)
    # Runs before last_layer is created, so only the MLP layers are
    # re-initialised by _init_weights.
    self.apply(self._init_weights)
    # Relies on nn.utils.weight_norm exposing the magnitude as `weight_g`.
    self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
    self.last_layer.weight_g.data.fill_(1)
    if norm_last_layer:
      self.last_layer.weight_g.requires_grad = False

  # NOTE: _init_weights and forward must live at class level. They were
  # previously nested inside __init__, so self._init_weights did not exist
  # when self.apply() ran and the module had no forward().
  def _init_weights(self, m):
    """Initialise ``nn.Linear`` weights with trunc_normal_(std=.02) and zero the bias."""
    if isinstance(m, nn.Linear):
      trunc_normal_(m.weight, std=.02)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0)

  def forward(self, x):
    """Apply the MLP, L2-normalise along the last dim, then apply the last layer."""
    x = self.mlp(x)
    x = nn.functional.normalize(x, dim=-1, p=2)
    x = self.last_layer(x)
    return x

### Loss, Train one Epoch, Augmentation


In [None]:
class DINOLoss(nn.Module):
  """Skeleton for the DINO loss module; the steps below are not implemented yet.

  The base class is ``nn.Module`` (``torch.nn`` has no attribute ``module``).
  This docstring also serves as the class body so the skeleton parses; no
  ``__init__`` or ``forward`` is defined here yet.
  """
  # init vars
  # forward
    # scale student
    # center teacher
    # cross entropy
    # update center
    # return loss

In [None]:
def train_one_epoch():
  """Skeleton for a single training epoch; not implemented yet.

  The comments below are the planned steps. Until they are filled in, calling
  this function fails loudly instead of silently returning ``None``.

  Raises:
    NotImplementedError: always.
  """
  # for it
    # update weights + lr
    # imgs to GPU
    # Forward Pass + loss

    # Update Student

    # EMA teacher

    # logging
  raise NotImplementedError("train_one_epoch is not implemented yet")

In [None]:
class DataAugmentationDINO(object):
  """Skeleton for the augmentation pipeline; not implemented yet.

  This docstring serves as the class body so the skeleton parses. The planned
  steps are listed below; no ``__call__`` is defined yet, so instances are
  not callable.
  """
  # define crops
  # make it callable

### Train

In [None]:
def train_dino(args):
  """Skeleton for the top-level training routine; not implemented yet.

  Args:
    args: settings object; unused until the steps below are implemented.

  Raises:
    NotImplementedError: always.
  """
  # init rand seed
  # prep data
  # build student, teacher ???
  # add multicrop wrapper
  # init loss
  # init optimizer
  # init schedulers
  # train
    # train_one_epoch
    # write logs
  raise NotImplementedError("train_dino is not implemented yet")