In [17]:
# standard python imports
import sys
sys.path.append("/home/conradb/git/ifg-ssl")

# pytorch + fastai imports
from fastai.vision.all import *
import torch
from torchvision import datasets, transforms
# from torch.utils.data import DataLoader
# import torchvision.transforms.functional as F
# from torchvision.utils import make_grid

# local module imports
# from dino.augment import ImageAugmentationDINO
# from dino.loss import DINOLoss
# import dino.utils as utils
import dino.vision_transformer as vit
from dino.linear_classifier import LinearClassifier, train
# from dino.vision_transformer import DINOHead

In [18]:
# model
model = 'vit_tiny' # embed_dim=192, depth=12, num_heads=3
n_last_blocks = 4
avgpool_patchtokens = False

num_labels = 10

batch_size= 256
num_workers= 8 
pin_memory= False
epochs = 10
lr = 0.001

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
model = vit.__dict__[model](patch_size=16, num_classes=0)
embed_dim = model.embed_dim * (n_last_blocks + int(avgpool_patchtokens))

In [20]:
state_dict = torch.load('/home/conradb/git/ifg-ssl/dino/dino_checkpoints/dino_imagenette_160_ckpt.pth', map_location='cpu')
state_dict = state_dict["teacher"]

In [59]:
print(state_dict.keys())

odict_keys(['backbone.cls_token', 'backbone.pos_embed', 'backbone.patch_embed.proj.weight', 'backbone.patch_embed.proj.bias', 'backbone.blocks.0.norm1.weight', 'backbone.blocks.0.norm1.bias', 'backbone.blocks.0.attn.qkv.weight', 'backbone.blocks.0.attn.qkv.bias', 'backbone.blocks.0.attn.proj.weight', 'backbone.blocks.0.attn.proj.bias', 'backbone.blocks.0.norm2.weight', 'backbone.blocks.0.norm2.bias', 'backbone.blocks.0.mlp.fc1.weight', 'backbone.blocks.0.mlp.fc1.bias', 'backbone.blocks.0.mlp.fc2.weight', 'backbone.blocks.0.mlp.fc2.bias', 'backbone.blocks.1.norm1.weight', 'backbone.blocks.1.norm1.bias', 'backbone.blocks.1.attn.qkv.weight', 'backbone.blocks.1.attn.qkv.bias', 'backbone.blocks.1.attn.proj.weight', 'backbone.blocks.1.attn.proj.bias', 'backbone.blocks.1.norm2.weight', 'backbone.blocks.1.norm2.bias', 'backbone.blocks.1.mlp.fc1.weight', 'backbone.blocks.1.mlp.fc1.bias', 'backbone.blocks.1.mlp.fc2.weight', 'backbone.blocks.1.mlp.fc2.bias', 'backbone.blocks.2.norm1.weight', 'bac

In [21]:
# remove `module.` prefix
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

In [22]:
print(state_dict.keys())

dict_keys(['backbone.cls_token', 'backbone.pos_embed', 'backbone.patch_embed.proj.weight', 'backbone.patch_embed.proj.bias', 'backbone.blocks.0.norm1.weight', 'backbone.blocks.0.norm1.bias', 'backbone.blocks.0.attn.qkv.weight', 'backbone.blocks.0.attn.qkv.bias', 'backbone.blocks.0.attn.proj.weight', 'backbone.blocks.0.attn.proj.bias', 'backbone.blocks.0.norm2.weight', 'backbone.blocks.0.norm2.bias', 'backbone.blocks.0.mlp.fc1.weight', 'backbone.blocks.0.mlp.fc1.bias', 'backbone.blocks.0.mlp.fc2.weight', 'backbone.blocks.0.mlp.fc2.bias', 'backbone.blocks.1.norm1.weight', 'backbone.blocks.1.norm1.bias', 'backbone.blocks.1.attn.qkv.weight', 'backbone.blocks.1.attn.qkv.bias', 'backbone.blocks.1.attn.proj.weight', 'backbone.blocks.1.attn.proj.bias', 'backbone.blocks.1.norm2.weight', 'backbone.blocks.1.norm2.bias', 'backbone.blocks.1.mlp.fc1.weight', 'backbone.blocks.1.mlp.fc1.bias', 'backbone.blocks.1.mlp.fc2.weight', 'backbone.blocks.1.mlp.fc2.bias', 'backbone.blocks.2.norm1.weight', 'back

In [23]:

# remove `backbone.` prefix induced by multicrop wrapper
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}

In [24]:
print(state_dict.keys())


dict_keys(['cls_token', 'pos_embed', 'patch_embed.proj.weight', 'patch_embed.proj.bias', 'blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.attn.qkv.weight', 'blocks.0.attn.qkv.bias', 'blocks.0.attn.proj.weight', 'blocks.0.attn.proj.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.0.mlp.fc1.weight', 'blocks.0.mlp.fc1.bias', 'blocks.0.mlp.fc2.weight', 'blocks.0.mlp.fc2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.attn.qkv.weight', 'blocks.1.attn.qkv.bias', 'blocks.1.attn.proj.weight', 'blocks.1.attn.proj.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.1.mlp.fc1.weight', 'blocks.1.mlp.fc1.bias', 'blocks.1.mlp.fc2.weight', 'blocks.1.mlp.fc2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.attn.qkv.weight', 'blocks.2.attn.qkv.bias', 'blocks.2.attn.proj.weight', 'blocks.2.attn.proj.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.2.mlp.fc1.weight', 'blocks.2.mlp.fc1.bias', 'blocks.2.mlp.fc2.weight', 'block

In [25]:
model.load_state_dict(state_dict, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])

In [26]:
linear_classifier = LinearClassifier(embed_dim, num_labels=num_labels)

In [27]:
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

In [28]:
path = untar_data(URLs.IMAGENETTE_160)
train_dataset = datasets.ImageFolder(root=path/'train', transform=train_transform)
# dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform)

In [29]:
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

In [30]:
 # set optimizer
optimizer = torch.optim.SGD(linear_classifier.parameters(),
                            lr,
                            momentum=0.9,
                            weight_decay=0, # we do not apply weight decay
                            )
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=0)


TO DO - resume from checkpoint

In [31]:
for epoch in range(0, epochs+1):
        train(model, linear_classifier, optimizer, train_loader, epoch, n_last_blocks, avgpool_patchtokens, device)
        scheduler.step()

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn