### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Arguments

In [2]:
# see https://github.com/facebookresearch/dino/blob/main/main_dino.py get_args_parser() for explanations

class args_class:
    """Plain container for the DINO training hyperparameters.

    Every value is set as an instance attribute in ``__init__``; there are no
    constructor arguments. The trailing ``default ...`` notes record the value
    this project treats as the default for each setting.
    """

    def __init__(self):
        # Model Parameters
        self.MOMENTUM_TEACHER = 0.996  # default 0.996
        self.OUT_DIM = 65536  # default 65536
        self.NORM_LAST_LAYER = True
        self.USE_BN_IN_HEAD = False  # default False

        # Temperature Teacher Parameters
        self.WARMUP_TEACHER_TEMP = 0.04  # default 0.04
        self.TEACHER_TEMP = 0.04  # default 0.04
        # default 30 here.
        # NOTE(review): the reference argument parser appears to default this
        # to 0 while its help text says 30 -- confirm which one is intended.
        self.WARMUP_TEACHER_TEMP_EPOCHS = 30

        # Training / Optimizations Parameters
        self.USE_FP16 = True  # default True
        self.WEIGHT_DECAY = 0.04  # default 0.04
        self.WEIGHT_DECAY_END = 0.4  # default 0.4
        self.CLIP_GRAD = 3.0  # default 3.0
        self.BATCH_SIZE_PER_GPU = 64  # default 64
        self.EPOCHS = 100  # default 100
        self.FREEZE_LAST_LAYER = 1  # default 1
        self.LR = 0.0005  # default 0.0005
        self.WARMUP_EPOCHS = 10  # default 10
        self.MIN_LR = 1e-6  # default 1e-6
        self.optimizer = 'adamw'  # default adamw    TODO: Could be constructed here? not sure.
        self.DROP_PATH_RATE = 0.1  # default 0.1

        # Multi-Crop Parameters
        # (min, max) fraction of the image area covered by a global crop.
        # The previous value (0.4, 0.1) had min > max; the reference DINO
        # default is (0.4, 1.0).
        self.GLOBAL_CROPS_SCALE = (0.4, 1.0)  # default (0.4, 1.0)
        self.LOCAL_CROPS_NUMBER = 8  # default 8
        self.LOCAL_CROPS_SCALE = (0.05, 0.4)  # default (0.05, 0.4)

        # Misc
        self.NUM_WORKERS = 10  # default 10
        # Misspelled legacy name, kept so existing readers of it keep working.
        self.num_works = self.NUM_WORKERS
        # TODO: imgnet directory, how to store weights?


# Module-level hyperparameter object shared by the cells below.
args = args_class()

### Architecture, DINOHead

### Utils

### Loss, Train one Epoch, Augmentation


In [3]:
class DINOLoss(nn.Module):
    """
    Self-distillation loss that trains a student network to match the output
    of a momentum teacher network, without labels.

    Args:
        args (args_class): An object containing the necessary hyperparameters.
            The attributes read are ``LOCAL_CROPS_NUMBER``, ``OUT_DIM``,
            ``WARMUP_TEACHER_TEMP``, ``TEACHER_TEMP``,
            ``WARMUP_TEACHER_TEMP_EPOCHS`` and ``EPOCHS``.
    """
    def __init__(self, args):
        super().__init__()
        # The student temperature is constant
        self.student_temp = 0.1
        # The number of crops is the two global crops plus all local crops
        self.n_crops = 2 + args.LOCAL_CROPS_NUMBER
        self.center_momentum = 0.9
        # Non-trainable running center, stored as a buffer so it follows the module
        self.register_buffer("center", torch.zeros(1, args.OUT_DIM))

        # The teacher temperature warms up linearly from WARMUP_TEACHER_TEMP to
        # TEACHER_TEMP over the first WARMUP_TEACHER_TEMP_EPOCHS epochs and then
        # stays at TEACHER_TEMP for the remaining epochs. The warmup length is
        # clamped to [0, EPOCHS] so the schedule always has exactly EPOCHS entries.
        n_epochs = args.EPOCHS
        n_warmup = max(0, min(args.WARMUP_TEACHER_TEMP_EPOCHS, n_epochs))
        self.teacher_temp_schedule = torch.cat((
            torch.linspace(args.WARMUP_TEACHER_TEMP, args.TEACHER_TEMP, n_warmup),
            torch.full((n_epochs - n_warmup,), float(args.TEACHER_TEMP)),
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Calculates the DINO loss.

        The teacher output is centered and sharpened with the scheduled
        temperature; the student output is scaled by the student temperature.
        The loss is the cross-entropy between every teacher view and every
        student view with a different index, averaged over all such pairs.
        The running center is updated from ``teacher_output`` before returning.

        Args:
            student_output (torch.Tensor): The output of the student network for
                all crops; split into ``n_crops`` chunks along dim 0.
            teacher_output (torch.Tensor): The output of the teacher network for
                the global crops; split into 2 chunks along dim 0.
            epoch (int): The current training epoch, used to index the
                temperature schedule.

        Returns:
            torch.Tensor: The calculated DINO loss.
        """
        # Scale student output by the student temperature and take the
        # log-probabilities once per crop (they do not depend on the teacher view)
        student_chunks = (student_output / self.student_temp).chunk(self.n_crops)
        student_log_probs = [F.log_softmax(chunk, dim=-1) for chunk in student_chunks]

        # Center and sharpen teacher output
        temp = self.teacher_temp_schedule[epoch]
        teacher_probs = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_views = teacher_probs.detach().chunk(2)  # Teacher only processes the 2 global views

        total_loss = 0
        n_loss_terms = 0
        for teacher_idx, q in enumerate(teacher_views):
            for student_idx, log_p in enumerate(student_log_probs):
                if student_idx == teacher_idx:
                    # Skip comparing a global view to itself
                    continue
                # Cross-entropy between a teacher view and a student view
                loss = torch.sum(-q * log_p, dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1

        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Updates the center for the teacher's output using an exponential moving average.
        This operation helps to prevent model collapse.

        Args:
            teacher_output (torch.Tensor): The output of the teacher network.
        """
        # Mean over dim 0, kept as a single row so it lines up with the center buffer
        batch_center = torch.mean(teacher_output, dim=0, keepdim=True)
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

In [None]:
def train_one_epoch():
    """Run a single training epoch (not implemented yet).

    Only the outline of the intended steps exists so far; calling this
    raises ``NotImplementedError``.

    Raises:
        NotImplementedError: Always, until the body is written.
    """
    # Planned outline:
    # for it
    #   update weights + lr
    #   imgs to GPU
    #   Forward Pass + loss
    #
    #   Update Student
    #
    #   EMA teacher
    #
    #   logging
    raise NotImplementedError("train_one_epoch is not implemented yet")

In [None]:
class DataAugmentationDINO(object):
    """Multi-crop data augmentation (not implemented yet).

    Only the outline exists so far: the crops still have to be defined, and
    the object is meant to be callable.
    """
    # TODO: define crops

    def __call__(self, image):
        """Placeholder for applying the augmentation to ``image``.

        Raises:
            NotImplementedError: Always, until the crops are defined.
        """
        raise NotImplementedError("DataAugmentationDINO is not implemented yet")

### Train

In [None]:
def train_dino(args):
    """Top-level DINO training entry point (not implemented yet).

    Only the outline of the intended steps exists so far; calling this
    raises ``NotImplementedError``.

    Args:
        args: Hyperparameter object; currently unused.

    Raises:
        NotImplementedError: Always, until the body is written.
    """
    # Planned outline:
    # init rand seed
    # prep data
    # build student, teacher ???
    # add multicrop wrapper
    # init loss
    # init optimizer
    # init schedulers
    # train
    #   train_one_epoch
    #   write logs
    raise NotImplementedError("train_dino is not implemented yet")