In [5]:
import torch
from vit_pytorch import ViT, Dino
from tqdm import tqdm

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = Dino(
    model,
    image_size = 256,
    hidden_layer = 'to_latent',        # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,      # projector network hidden dimension
    projection_layers = 4,             # number of layers in projection network
    num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                # student temperature
    teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper 
    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)



In [7]:
for _ in tqdm(range(100)):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')

  0%|                                                                                                                                                                   | 0/100 [00:00<?, ?it/s]

tensor(10.5418, grad_fn=<DivBackward0>)


  1%|█▌                                                                                                                                                         | 1/100 [00:02<04:27,  2.70s/it]

tensor(10.5391, grad_fn=<DivBackward0>)


  2%|███                                                                                                                                                        | 2/100 [00:05<04:12,  2.58s/it]

tensor(10.5365, grad_fn=<DivBackward0>)


  3%|████▋                                                                                                                                                      | 3/100 [00:07<04:05,  2.53s/it]

tensor(10.5340, grad_fn=<DivBackward0>)


  4%|██████▏                                                                                                                                                    | 4/100 [00:10<03:56,  2.47s/it]

tensor(10.5316, grad_fn=<DivBackward0>)


  5%|███████▊                                                                                                                                                   | 5/100 [00:12<03:50,  2.42s/it]

tensor(10.5294, grad_fn=<DivBackward0>)


  6%|█████████▎                                                                                                                                                 | 6/100 [00:15<03:54,  2.49s/it]

tensor(10.5273, grad_fn=<DivBackward0>)


  7%|██████████▊                                                                                                                                                | 7/100 [00:17<03:47,  2.45s/it]

tensor(10.5253, grad_fn=<DivBackward0>)


  8%|████████████▍                                                                                                                                              | 8/100 [00:20<03:50,  2.51s/it]

tensor(10.5234, grad_fn=<DivBackward0>)


  9%|█████████████▉                                                                                                                                             | 9/100 [00:22<03:51,  2.54s/it]

tensor(10.5217, grad_fn=<DivBackward0>)


 10%|███████████████▍                                                                                                                                          | 10/100 [00:25<03:48,  2.54s/it]

tensor(10.5202, grad_fn=<DivBackward0>)


 11%|████████████████▉                                                                                                                                         | 11/100 [00:27<03:49,  2.58s/it]

tensor(10.5187, grad_fn=<DivBackward0>)


 12%|██████████████████▍                                                                                                                                       | 12/100 [00:30<03:40,  2.51s/it]

tensor(10.5174, grad_fn=<DivBackward0>)


 13%|████████████████████                                                                                                                                      | 13/100 [00:32<03:37,  2.50s/it]

tensor(10.5162, grad_fn=<DivBackward0>)


 14%|█████████████████████▌                                                                                                                                    | 14/100 [00:35<03:35,  2.50s/it]

tensor(10.5152, grad_fn=<DivBackward0>)


 15%|███████████████████████                                                                                                                                   | 15/100 [00:37<03:31,  2.49s/it]

tensor(10.5143, grad_fn=<DivBackward0>)


 16%|████████████████████████▋                                                                                                                                 | 16/100 [00:40<03:29,  2.49s/it]

tensor(10.5137, grad_fn=<DivBackward0>)


 17%|██████████████████████████▏                                                                                                                               | 17/100 [00:42<03:26,  2.49s/it]

tensor(10.5131, grad_fn=<DivBackward0>)


 18%|███████████████████████████▋                                                                                                                              | 18/100 [00:44<03:20,  2.44s/it]

tensor(10.5127, grad_fn=<DivBackward0>)


 19%|█████████████████████████████▎                                                                                                                            | 19/100 [00:47<03:20,  2.48s/it]

tensor(10.5125, grad_fn=<DivBackward0>)


 20%|██████████████████████████████▊                                                                                                                           | 20/100 [00:50<03:19,  2.49s/it]

tensor(10.5124, grad_fn=<DivBackward0>)


 21%|████████████████████████████████▎                                                                                                                         | 21/100 [00:52<03:21,  2.55s/it]

tensor(10.5125, grad_fn=<DivBackward0>)


 22%|█████████████████████████████████▉                                                                                                                        | 22/100 [00:55<03:22,  2.59s/it]

tensor(10.5128, grad_fn=<DivBackward0>)


 23%|███████████████████████████████████▍                                                                                                                      | 23/100 [00:58<03:23,  2.64s/it]

tensor(10.5133, grad_fn=<DivBackward0>)


 24%|████████████████████████████████████▉                                                                                                                     | 24/100 [01:00<03:21,  2.65s/it]

tensor(10.5139, grad_fn=<DivBackward0>)


 25%|██████████████████████████████████████▌                                                                                                                   | 25/100 [01:03<03:16,  2.62s/it]

tensor(10.5147, grad_fn=<DivBackward0>)


 26%|████████████████████████████████████████                                                                                                                  | 26/100 [01:05<03:10,  2.58s/it]

tensor(nan, grad_fn=<DivBackward0>)


 27%|█████████████████████████████████████████▌                                                                                                                | 27/100 [01:08<03:09,  2.59s/it]

tensor(nan, grad_fn=<DivBackward0>)


 28%|███████████████████████████████████████████                                                                                                               | 28/100 [01:10<03:02,  2.54s/it]

tensor(nan, grad_fn=<DivBackward0>)


 29%|████████████████████████████████████████████▋                                                                                                             | 29/100 [01:13<02:58,  2.51s/it]

tensor(nan, grad_fn=<DivBackward0>)


 30%|██████████████████████████████████████████████▏                                                                                                           | 30/100 [01:15<02:56,  2.52s/it]

tensor(nan, grad_fn=<DivBackward0>)


 31%|███████████████████████████████████████████████▋                                                                                                          | 31/100 [01:18<02:57,  2.57s/it]

tensor(nan, grad_fn=<DivBackward0>)


 32%|█████████████████████████████████████████████████▎                                                                                                        | 32/100 [01:21<02:53,  2.56s/it]


KeyboardInterrupt: 