### Imports

In [5]:
import torch
import torch.nn as nn
from torch import Tensor
from typing import List

### Arguments

In [None]:
# see https://github.com/facebookresearch/dino/blob/main/main_dino.py get_args_parser() for explanations

class args_class:
    """Hyper-parameter container for the DINO training notebook.

    A plain attribute bag used in place of an argparse namespace. Every value
    is set in ``__init__``; the trailing comment on each line records the
    default the authors are mirroring from the reference ``get_args_parser()``.
    """

    def __init__(self):
        # Model parameters
        self.MOMENTUM_TEACHER = 0.996  # default 0.996
        self.OUT_DIM = 65536  # default 65536
        self.NORM_LAST_LAYER = True
        self.USE_BN_IN_HEAD = False  # default False

        # Teacher temperature parameters
        self.WARMUP_TEACHER_TEMP = 0.04  # default 0.04
        self.TEACHER_TEMP = 0.04  # default 0.04
        # NOTE(review): the original note says "default 30, but erroneously
        # default 0 in the DINO paper?" -- confirm which value is intended.
        self.WARMUP_TEACHER_TEMP_EPOCHS = 30

        # Training / optimization parameters
        self.USE_FP16 = True  # default True
        self.WEIGHT_DECAY = 0.04  # default 0.04
        self.WEIGHT_DECAY_END = 0.4  # default 0.4
        self.CLIP_GRAD = 3.0  # default 3.0
        self.BATCH_SIZE_PER_GPU = 64  # default 64
        self.EPOCHS = 100  # default 100
        self.FREEZE_LAST_LAYER = 1  # default 1
        self.LR = 0.0005  # default 0.0005
        self.WARMUP_EPOCHS = 10  # default 10
        self.MIN_LR = 1e-6  # default 1e-6
        # TODO: the optimizer could be constructed here instead of being named.
        self.optimizer = 'adamw'  # default adamw
        self.DROP_PATH_RATE = 0.1  # default 0.1

        # Multi-crop parameters
        # (min, max) scale range for the global crops. The previous value
        # (0.4, 0.1) had min > max, which is not a valid range; the reference
        # default is (0.4, 1.0).
        self.GLOBAL_CROPS_SCALE = (0.4, 1.0)  # default (0.4, 1.0)
        self.LOCAL_CROPS_NUMBER = 8  # default 8
        self.LOCAL_CROPS_SCALE = (0.05, 0.4)  # default (0.05, 0.4)

        # Misc
        # NOTE(review): presumably the data-loader worker count
        # ("num_workers"); the name is kept as-is so existing readers of
        # ``args.num_works`` keep working.
        self.num_works = 10  # default 10
        # TODO: imgnet directory, how to store weights?


args = args_class()

### Architecture, DINOHead

### Utils

### Loss, Train one Epoch, Augmentation


In [None]:
class DINOLoss(nn.Module):
    """Placeholder for the DINO loss; not implemented yet.

    The base class must be ``nn.Module`` (capital M); ``nn.module`` does not
    exist in ``torch.nn``. Until ``forward`` is written, calling an instance
    falls through to ``nn.Module``'s default ``forward``, which raises
    ``NotImplementedError``.

    Planned outline (from the original notes):
        * ``__init__``: init vars
        * ``forward``:
            - scale student
            - center teacher
            - cross entropy
            - update center
            - return loss
    """
    # TODO: init vars
    # TODO: forward
    #   scale student
    #   center teacher
    #   cross entropy
    #   update center
    #   return loss

In [None]:
def train_one_epoch():
    """Run one training epoch; not implemented yet.

    Planned outline (from the original notes), for each iteration:
        * update weights + lr
        * move images to the GPU
        * forward pass + loss
        * update the student
        * EMA update of the teacher
        * logging

    Raises:
        NotImplementedError: always, until the loop is written. Raising makes
            an accidental call fail loudly instead of silently training
            nothing.
    """
    # for it
    #   update weights + lr
    #   imgs to GPU
    #   Forward Pass + loss
    #   Update Student
    #   EMA teacher
    #   logging
    raise NotImplementedError("train_one_epoch is not implemented yet")

In [None]:
class DataAugmentationDINO(object):
    """Placeholder for the DINO multi-crop augmentation; not implemented yet.

    Planned outline (from the original notes):
        * define crops
        * make it callable
    """
    # TODO: define crops
    # TODO: make it callable

### MultiCropWrapper

In [6]:
class MultiCropWrapper(nn.Module):
    """
    Run a backbone and a head over several crops ("views") of a batch.

    This is essential in DINO-style self-supervised learning, where each image
    is transformed into several views (global and local crops) and all views
    must be processed efficiently.

    Args:
        backbone (nn.Module):  Feature extractor (e.g. a ResNet).
        head (nn.Module):      Head (MLP) mapping extracted features to the
                               output space.
    """

    def __init__(self, backbone: nn.Module = None, head: nn.Module = None):
        super().__init__()
        # Store the backbone and head modules
        self.backbone = backbone
        self.head = head

    def forward(self, crops: List[Tensor]) -> List[Tensor]:
        """
        Perform a single forward pass for multiple crops of each image.

        -- Group consecutive crops that share the same per-sample shape and
           concatenate each group along the batch dimension. Crops of
           different resolutions cannot be concatenated into one tensor, so
           each group goes through the backbone separately.

        -- Flatten each group's features to 2-D if the backbone returns more
           than two dimensions.

        -- Concatenate all features and run the head once.

        -- Split the head output back into one tensor per input crop.

        Args:
            crops (List[Tensor]):
                A list of image batches, one entry per crop. A single tensor
                is also accepted and treated as a one-element list.

        Returns:
            List[Tensor]:
                One head output per crop, in the same order as the input
                crops. Entry ``i`` has as many rows as ``crops[i]``.

        Raises:
            ValueError: if ``crops`` is empty.
        """
        if isinstance(crops, Tensor):
            crops = [crops]
        crops = list(crops)
        if not crops:
            raise ValueError("MultiCropWrapper.forward() needs at least one crop")

        features = []
        start = 0
        for end in range(1, len(crops) + 1):
            # Keep extending the current group while the next crop has the
            # same non-batch shape as the group's first crop.
            if end < len(crops) and crops[end].shape[1:] == crops[start].shape[1:]:
                continue

            group_features = self.backbone(torch.cat(crops[start:end], dim=0))

            # Flatten ---> if needed
            if group_features.dim() > 2:
                group_features = torch.flatten(group_features, start_dim=1)

            features.append(group_features)
            start = end

        # Project all features through the head in one call
        heading = self.head(torch.cat(features, dim=0))

        # Split by each crop's actual batch size rather than assuming that
        # every crop contributes the same number of rows.
        batch_sizes = [crop.shape[0] for crop in crops]
        return list(heading.split(batch_sizes, dim=0))




### Train

In [None]:
def train_dino(args):
    """Top-level DINO training entry point; not implemented yet.

    Planned outline (from the original notes):
        * init rand seed
        * prep data
        * build student, teacher (open question in the notes)
        * add multicrop wrapper
        * init loss
        * init optimizer
        * init schedulers
        * train: call train_one_epoch and write logs each epoch

    Args:
        args: Hyper-parameter container. Presumably the module-level
            ``args_class`` instance -- TODO confirm once implemented.

    Raises:
        NotImplementedError: always, until the training procedure is written.
    """
    # init rand seed
    # prep data
    # build student, teacher ???
    # add multicrop wrapper
    # init loss
    # init optimizer
    # init schedulers
    # train
    #   train_one_epoch
    #   write logs
    raise NotImplementedError("train_dino is not implemented yet")