In [1]:
import os
# Display the kernel's current working directory (shown as the cell output).
os.getcwd()

'/content'

In [2]:
# Change the working directory for the rest of the notebook. Relative paths
# used later (data_path './data/FashionMNIST', output_dir '.') resolve against
# this directory.
# NOTE(review): presumably this is the project folder on a mounted Google Drive
# that also holds the local `utils` / `vision_transformer` modules - confirm,
# and adjust the path when running outside that environment.
os.chdir('/content/drive/MyDrive/DINO')

In [3]:
import argparse
import os
import sys
import datetime
import time
import math
import json
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision import models as torchvision_models

import utils
import vision_transformer as vits
from vision_transformer import DINOHead


# Sorted names of every lowercase, non-dunder, callable attribute exposed by
# torchvision.models (candidate architecture constructors).
torchvision_archs = sorted(
    attr_name
    for attr_name, attr_value in vars(torchvision_models).items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)

In [4]:
class AttrDict(dict):
    """Dictionary whose entries are also reachable as attributes.

    The instance ``__dict__`` is the dictionary itself, so ``d.key`` and
    ``d["key"]`` read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the mapping itself.
        self.__dict__ = self


# Run configuration, built directly as an AttrDict (same keys, same order).
args = AttrDict(
    # Model parameters
    arch="vit_tiny",
    patch_size=4,
    out_dim=30,
    norm_last_layer=True,
    momentum_teacher=0.996,
    use_bn_in_head=False,

    # Temperature teacher parameters
    warmup_teacher_temp=0.04,
    teacher_temp=0.04,
    warmup_teacher_temp_epochs=0,

    # Training/Optimization parameters
    use_fp16=True,
    weight_decay=0.04,
    weight_decay_end=0.4,
    clip_grad=3.0,
    batch_size_per_gpu=64,
    epochs=100,
    freeze_last_layer=1,
    lr=0.0005,
    warmup_epochs=10,
    min_lr=1e-6,
    optimizer="adamw",

    # Multi-crop parameters
    # (the *_scale values are passed as `size=` to np.random.choice by
    # DataAugmentationDINO, i.e. they are numbers of sampled features)
    global_crops_scale=400,
    local_crops_number=8,
    local_crops_scale=80,

    # Misc
    data_path="./data/FashionMNIST",
    output_dir=".",
    saveckp_freq=20,
    seed=0,
    num_workers=10,
    dist_url="env://",
    local_rank=0,
)

In [5]:
def train_dino(args):
    """Build the data pipeline, the student/teacher networks and run DINO training.

    Steps, in order: initialise distributed mode and seeds, build the
    multi-crop dataset and its loader, build student and teacher, build the
    loss / optimizer / schedules, optionally resume from
    ``<output_dir>/checkpoint.pth``, then train for ``args.epochs`` epochs
    while writing checkpoints and a ``log.txt`` of per-epoch stats.

    Args:
        args: attribute-style configuration (see the config cell). Fields read
            here include ``arch``, ``patch_size``, ``out_dim``, ``optimizer``,
            ``epochs``, ``output_dir`` and the multi-crop / schedule settings.

    Returns:
        The multi-crop dataset produced by ``DataAugmentationDINO`` so that it
        can be inspected after training.

    Raises:
        ValueError: if ``args.arch`` or ``args.optimizer`` is not supported.
    """
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    # NOTE(review): the FashionMNIST *test* split (train=False) is what gets
    # used for training here - confirm this is intended.
    test_dataset = datasets.FashionMNIST(args.data_path, download=True, train=False)
    # Flatten every image to a 1-D vector and rescale pixel values by 1/255.
    X_test = torch.flatten(test_dataset.data, start_dim=1).numpy()/255
    y_test = test_dataset.targets.numpy()
    dataset = CustomDataset(X_test, y_test)
    dataset_trans = DataAugmentationDINO(dataset,args.global_crops_scale,args.local_crops_scale,args.local_crops_number)
    sampler = torch.utils.data.DistributedSampler(dataset_trans, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
            dataset_trans,
            sampler=sampler,
            batch_size=args.batch_size_per_gpu,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=True,
            )
    # The dataset is returned at the END of the function: an early return at
    # this point used to make everything below unreachable.

    print(f"Data loaded: there are {len(dataset_trans)} cells.")

    # ============ building student and teacher networks ... ============
    # we changed the name DeiT-S for ViT-S to avoid confusions
    args.arch = args.arch.replace("deit", "vit")
    # if the network is a vision transformer (i.e. vit_tiny, vit_small, vit_base)
    if args.arch in vits.__dict__.keys():
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=0.1,  # stochastic depth
        )
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        embed_dim = student.embed_dim
    # otherwise, we check if the architecture is in torchvision models
    elif args.arch in torchvision_models.__dict__.keys():
        student = torchvision_models.__dict__[args.arch]()
        teacher = torchvision_models.__dict__[args.arch]()
        embed_dim = student.fc.weight.shape[1]
    else:
        # Fail here with a clear message instead of printing and then crashing
        # later on the unbound `student` / `embed_dim`.
        raise ValueError(f"Unknown architecture: {args.arch}")

    # multi-crop wrapper handles forward with inputs of different resolutions
    student = utils.MultiCropWrapper(student, DINOHead(
        embed_dim,
        args.out_dim,
        use_bn=args.use_bn_in_head,
        norm_last_layer=args.norm_last_layer,
    ))
    teacher = utils.MultiCropWrapper(
        teacher,
        DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)

        # we need DDP wrapper to have synchro batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing loss ... ============
    dino_loss = DINOLoss(
        args.out_dim,
        args.local_crops_number + 2,  # total number of crops = 2 global crops + local_crops_number
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
    ).cuda()

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = utils.LARS(params_groups)  # to use with convnet and large batches
    else:
        # Without this branch `optimizer` would be unbound further down.
        raise ValueError(f"Unknown optimizer: {args.optimizer}")
    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = utils.cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear scaling rule
        args.min_lr,
        args.epochs, len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs, len(data_loader),
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))
    print("Loss, optimizer and schedulers ready.")

    # ============ optionally resume training ... ============
    # `to_restore` is passed as run_variables and read back afterwards to get
    # the epoch to resume from (stays 0 when nothing overwrites it).
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting DINO training !")
    for epoch in range(start_epoch, args.epochs):
        # lets the DistributedSampler reshuffle differently at every epoch
        data_loader.sampler.set_epoch(epoch)

        # ============ training one epoch of DINO ... ============
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
            data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,
            epoch, fp16_scaler, args)

        # ============ writing logs ... ============
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'args': args,
            'dino_loss': dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
        # rolling checkpoint, overwritten every epoch
        utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
        # numbered snapshot every `saveckp_freq` epochs
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    return dataset_trans

In [6]:
def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
                    optimizer, lr_schedule, wd_schedule, momentum_schedule,epoch,
                    fp16_scaler, args):
    """Run one epoch of DINO training and return the averaged logged metrics.

    For every batch: set lr / weight decay from the schedules, run teacher
    (first two crops only) and student (all crops), compute the DINO loss,
    update the student (with or without the fp16 grad scaler), then update the
    teacher weights as an exponential moving average of the student weights.

    Returns:
        dict mapping each meter name of the metric logger to its ``global_avg``.
    """
    use_amp = fp16_scaler is not None
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f'Epoch: [{epoch}/{args.epochs}]'

    for batch_idx, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
        # index of this step over the whole run; used to look up the schedules
        global_it = len(data_loader) * epoch + batch_idx

        # apply the scheduled learning rate to every group; only group 0 is
        # regularized, so only it receives the scheduled weight decay
        for group_idx, group in enumerate(optimizer.param_groups):
            group["lr"] = lr_schedule[global_it]
            if group_idx == 0:
                group["weight_decay"] = wd_schedule[global_it]

        # host -> device transfer of every crop
        images = [crop.cuda(non_blocking=True) for crop in images]

        # forward passes and loss, under autocast when a grad scaler is in use
        with torch.cuda.amp.autocast(use_amp):
            # only the 2 global views pass through the teacher
            teacher_output = teacher(images[:2])
            student_output = student(images)
            loss = dino_loss(student_output, teacher_output, epoch)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            # NOTE(review): `force=True` is not accepted by the builtin print;
            # presumably utils.init_distributed_mode installs a print wrapper
            # that handles it - confirm.
            print("Loss is {}, stopping training".format(loss_value), force=True)
            sys.exit(1)

        # ---- student update ----
        optimizer.zero_grad()
        param_norms = None
        if use_amp:
            fp16_scaler.scale(loss).backward()
            if args.clip_grad:
                # gradients must be unscaled in place before they are clipped
                fp16_scaler.unscale_(optimizer)
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()
        else:
            loss.backward()
            if args.clip_grad:
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student,
                                              args.freeze_last_layer)
            optimizer.step()

        # ---- teacher update: EMA of the student weights ----
        with torch.no_grad():
            m = momentum_schedule[global_it]  # momentum parameter
            student_params = student.module.parameters()
            teacher_params = teacher_without_ddp.parameters()
            for param_q, param_k in zip(student_params, teacher_params):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        # ---- logging ----
        torch.cuda.synchronize()
        first_group = optimizer.param_groups[0]
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=first_group["lr"])
        metric_logger.update(wd=first_group["weight_decay"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}

In [7]:
class DINOLoss(nn.Module):
    """DINO loss: cross-entropy between centered/sharpened teacher
    probabilities and student log-probabilities, over pairs of different crops.

    A running ``center`` buffer of shape ``(1, out_dim)`` is subtracted from
    the teacher output before the softmax and is updated after every forward
    pass with an exponential moving average of the batch mean.

    Args:
        out_dim: size of the last dimension of the network outputs.
        ncrops: number of chunks the student output is split into along dim 0.
        warmup_teacher_temp: teacher temperature at the start of the warm-up.
        teacher_temp: teacher temperature once the warm-up is over.
        warmup_teacher_temp_epochs: number of epochs of linear warm-up.
        nepochs: total number of epochs (length of the temperature schedule).
        student_temp: temperature dividing the student output.
        center_momentum: EMA momentum used for the center update.
    """

    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.

        Args:
            student_output: student outputs, split into ``ncrops`` chunks along dim 0.
            teacher_output: teacher outputs, split into 2 chunks along dim 0.
            epoch: index into the teacher temperature schedule.

        Returns:
            Mean of the loss over every (teacher chunk, student chunk) pair
            with different indices. Also updates ``self.center`` as a side effect.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.

        The batch sum is all-reduced across processes when a process group is
        initialised; otherwise the loss behaves as a single-process run
        (world size 1) instead of raising.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        world_size = 1
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(batch_center)
            world_size = dist.get_world_size()
        batch_center = batch_center / (len(teacher_output) * world_size)

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


In [8]:
class CustomDataset(torch.utils.data.Dataset):
    """Minimal map-style dataset over an indexable collection of samples.

    Args:
        images: samples, indexed along their first axis. May be a 2-D array
            (one row per sample) or a plain Python list (e.g. one list of
            crops per sample).
        labels: optional labels aligned with ``images``.
        transforms: optional callable applied to each sample on access.
    """

    def __init__(self, images, labels=None, transforms=None):
        self.X = images
        self.y = labels
        self.transforms = transforms

    def __len__(self):
        """Number of samples."""
        return len(self.X)

    def __getitem__(self, i):
        """Return ``(sample, label)`` if labels were given, else ``sample``."""
        # Plain first-axis indexing: same result as `X[i, :]` for 2-D arrays,
        # but also valid when X is a Python list (tuple indexing is not).
        data = self.X[i]

        if self.transforms:
            data = self.transforms(data)

        if self.y is not None:
            return (data, self.y[i])
        return data

In [9]:
def DataAugmentationDINO(dataset, global_crops_scale, local_crops_scale, local_crops_number):
    """Pre-compute the DINO multi-crop views for every sample of ``dataset``.

    A "crop" here is a random draw of feature values from one sample made with
    ``np.random.choice`` (which samples with replacement by default). It is
    not a spatial image crop.

    Args:
        dataset: object exposing ``X`` (iterable of 1-D samples) and ``y`` (labels).
        global_crops_scale: number of values drawn for each of the 2 global crops.
        local_crops_scale: number of values drawn for each local crop.
        local_crops_number: number of local crops per sample.

    Returns:
        A ``CustomDataset`` whose ``X[n]`` is the list
        ``[global_1, global_2, local_1, ..., local_k]`` for sample ``n`` and
        whose labels are ``dataset.y``.
    """
    multi_crop_samples = []
    for sample in dataset.X:
        # two global views first ...
        crops = [
            np.random.choice(sample, size=global_crops_scale),
            np.random.choice(sample, size=global_crops_scale),
        ]
        # ... followed by the smaller local views
        crops.extend(
            np.random.choice(sample, size=local_crops_scale)
            for _ in range(local_crops_number)
        )
        multi_crop_samples.append(crops)

    return CustomDataset(multi_crop_samples, dataset.y)

In [10]:
# Make sure the output directory (checkpoints, log.txt) exists; no error if it already does.
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
# `pp` keeps the value returned by train_dino (the multi-crop dataset) for inspection below.
pp=train_dino(args)

Will run the code on one GPU.
| distributed init (rank 0): env://
git:
  sha: N/A, status: clean, branch: N/A

arch: vit_tiny
batch_size_per_gpu: 64
clip_grad: 3.0
data_path: ./data/FashionMNIST
dist_url: env://
epochs: 100
freeze_last_layer: 1
global_crops_scale: 400
gpu: 0
local_crops_number: 8
local_crops_scale: 80
local_rank: 0
lr: 0.0005
min_lr: 1e-06
momentum_teacher: 0.996
norm_last_layer: True
num_workers: 10
optimizer: adamw
out_dim: 30
output_dir: .
patch_size: 4
rank: 0
saveckp_freq: 20
seed: 0
teacher_temp: 0.04
use_bn_in_head: False
use_fp16: True
warmup_epochs: 10
warmup_teacher_temp: 0.04
warmup_teacher_temp_epochs: 0
weight_decay: 0.04
weight_decay_end: 0.4
world_size: 1


  cpuset_checked))


Data loaded: there are 10000 cells.
Student and Teacher are built: they are both vit_tiny network.
Loss, optimizer and schedulers ready.
Found checkpoint at ./checkpoint.pth
=> loaded student from checkpoint './checkpoint.pth' with msg <All keys matched successfully>
=> loaded teacher from checkpoint './checkpoint.pth' with msg <All keys matched successfully>
=> loaded optimizer from checkpoint './checkpoint.pth'
=> loaded fp16_scaler from checkpoint './checkpoint.pth'
=> loaded dino_loss from checkpoint './checkpoint.pth' with msg <All keys matched successfully>
Starting DINO training !
Epoch: [1/100]  [  0/156]  eta: 0:02:05  loss: 3.340681 (3.340681)  lr: 0.000013 (0.000013)  wd: 0.040089 (0.040089)  time: 0.806340  data: 0.391736  max mem: 2971
Epoch: [1/100]  [ 10/156]  eta: 0:00:33  loss: 3.336869 (3.334374)  lr: 0.000013 (0.000013)  wd: 0.040095 (0.040095)  time: 0.229806  data: 0.035760  max mem: 3007
Epoch: [1/100]  [ 20/156]  eta: 0:00:27  loss: 3.330628 (3.331539)  lr: 0.000

In [15]:
# Show the list of crops stored for sample index 1 (two global crops, then the local crops).
pp.X[1]