In [2]:
cd ..

/export/kaspar/SimGCD


In [3]:
import argparse
import torch
from data.get_datasets import get_datasets, get_class_splits




def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` is an argparse trap: ``bool('False')`` is ``True`` because any
    non-empty string is truthy, so ``--mem_queue False`` would silently enable
    the option. This converter maps the usual spellings to a real ``bool``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _build_parser(dataset):
    """Return the experiment-option parser; ``dataset`` is the ``--dataset_name`` default."""
    parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--num_workers', default=8, type=int)
    # Only 'v2' is used in this notebook. It used to be forced after parsing,
    # which silently discarded any explicit --eval_funcs; it is now the default.
    parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2'])

    parser.add_argument('--warmup_model_dir', type=str, default=None)
    parser.add_argument('--dataset_name', type=str, default=dataset, help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aricraft, herbarium_19')
    parser.add_argument('--prop_train_labels', type=float, default=0.5)
    # store_true with default=True could never be switched off. Accept an explicit
    # boolean while keeping a bare `--use_ssb_splits` meaning True.
    parser.add_argument('--use_ssb_splits', type=_str2bool, nargs='?', const=True, default=True)

    parser.add_argument('--grad_from_block', type=int, default=11)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=1e-4)
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--exp_root', type=str, default=None)
    parser.add_argument('--transform', type=str, default='imagenet')
    parser.add_argument('--sup_weight', type=float, default=0.35)
    parser.add_argument('--n_views', default=2, type=int)

    parser.add_argument('--memax_weight', type=float, default=2)
    parser.add_argument('--warmup_teacher_temp', default=0.07, type=float, help='Initial value for the teacher temperature.')
    parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup)of the teacher temperature.')
    parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, help='Number of warmup epochs for the teacher temperature.')

    parser.add_argument('--fp16', action='store_true', default=False)
    parser.add_argument('--print_freq', default=10, type=int)
    parser.add_argument('--exp_name', default=None, type=str)

    # ----------------------
    # Memory queue
    # ----------------------
    # Boolean options use _str2bool (not type=bool) so "False" really means False;
    # a bare flag (no value) still means True.
    parser.add_argument('--mem_queue', type=_str2bool, nargs='?', const=True, default=False, help='whether to use the memory queue')
    parser.add_argument('--mem_q_size', type=int, default=32768, help='How many images to store in the queue size')
    parser.add_argument('--mem_p', type=float, default=0., help='Probability of selecting the memory queue element instead of the other view')
    parser.add_argument('--mem_direct_knn', type=_str2bool, nargs='?', const=True, default=False, help='If true, the KNN queue will be the first view, if false the second')

    # ----------------------
    # Static KNN
    # ----------------------
    parser.add_argument('--static_knn', type=_str2bool, nargs='?', const=True, default=False)
    return parser


def get_args(dataset='cifar100', argv=None):
    """Build the SimGCD experiment configuration used by this notebook.

    Options are parsed from ``argv`` (empty by default) rather than from
    ``sys.argv``, so every option takes its default unless overridden here.

    Args:
        dataset: default value for ``--dataset_name``.
        argv: optional list of command-line style tokens, e.g.
            ``['--grad_from_block', '10', '--mem_queue', 'True']``.
            ``None`` (the default) means no overrides.

    Returns:
        argparse.Namespace with the parsed options, the class-split attributes
        added by ``get_class_splits`` (including ``train_classes`` and
        ``unlabeled_classes``), and the derived base-model settings
        (``num_labeled_classes``, ``num_unlabeled_classes``, ``image_size``,
        ``feat_dim``, ``num_mlp_layers``, ``mlp_out_dim``, ...).
    """
    parser = _build_parser(dataset)

    # ----------------------
    # INIT
    # ----------------------
    # Parse an explicit token list: inside a Jupyter kernel sys.argv holds the
    # kernel's own launch arguments, not experiment options.
    args = parser.parse_args([] if argv is None else list(argv))
    args = get_class_splits(args)

    args.num_labeled_classes = len(args.train_classes)
    args.num_unlabeled_classes = len(args.unlabeled_classes)

    # ----------------------
    # BASE MODEL
    # ----------------------
    # NOTE(review): 3 is presumably PIL's BICUBIC code. Passing a plain int is the
    # likely source of torchvision's "InterpolationMode instead of int" warning.
    args.interpolation = 3
    args.crop_pct = 0.875

    # NOTE: Hardcoded image size as we do not finetune the entire ViT model
    args.image_size = 224
    args.feat_dim = 768  # embedding width of the ViT-B/16 backbone loaded by get_model
    args.num_mlp_layers = 3
    args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes

    return args

# Sanity check: build the default (CIFAR-100) configuration once.
a = get_args(dataset='cifar100')

In [5]:
from model import DINOHead
from torch import nn

def get_model(args):
    """Return ``nn.Sequential(backbone, projector)`` for SimGCD.

    The backbone is DINO's ``dino_vitb16`` fetched via ``torch.hub`` from
    ``facebookresearch/dino:main``. This needs network access or an existing
    hub cache. The projector is a ``DINOHead`` sized from ``args.feat_dim``,
    ``args.mlp_out_dim`` and ``args.num_mlp_layers``.
    """
    # The backbone is created before the head. Keep this order in case either
    # constructor draws from the global RNG.
    vit_backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
    head = DINOHead(
        in_dim=args.feat_dim,
        out_dim=args.mlp_out_dim,
        nlayers=args.num_mlp_layers,
    )
    return nn.Sequential(vit_backbone, head)

# Rebuild the SimGCD architecture and restore the trained weights into it.
args = get_args()
base_model = get_model(args)

# Relative to the repository root. That is the working directory set by `cd ..`
# at the top, which the project imports in this cell already rely on.
CKPT_PATH = 'experiments/dev_outputs_block10/simgcd/log/exp/checkpoints/model.pt'
# The freshly built model is still on the CPU here, so deserialize the tensors
# onto the CPU too. Otherwise they go to whichever device they were saved from,
# which may be unavailable or may waste GPU memory.
ckpt = torch.load(CKPT_PATH, map_location='cpu')['model']

# Strict load (the default): the displayed result reports missing/unexpected keys.
base_model.load_state_dict(ckpt)

Using cache found in /export/kaspar/.cache/torch/hub/facebookresearch_dino_main


<All keys matched successfully>

In [7]:
from data.augmentations import get_transform
from data.get_datasets import get_datasets, get_class_splits
from model import ContrastiveLearningViewGenerator
from torch.utils.data import DataLoader

# --------------------
# TRANSFORMS
# --------------------
# get_transform returns a (train, test) pair. The train pipeline is then wrapped
# so that each sample yields `args.n_views` augmented views.
train_transform, test_transform = get_transform(
    args.transform,
    image_size=args.image_size,
    args=args,
)
train_transform = ContrastiveLearningViewGenerator(
    base_transform=train_transform,
    n_views=args.n_views,
)

# --------------------
# DATASETS
# --------------------
# Only `unlabelled_train_examples_test` is consumed below. The other results are
# kept for interactive inspection.
# NOTE(review): per its name it presumably uses test_transform; confirm in data.get_datasets.
train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(
    args.dataset_name, train_transform, test_transform, args
)

# Fixed iteration order (shuffle=False) for evaluation.
test_loader_unlabelled = DataLoader(
    unlabelled_train_examples_test,
    num_workers=args.num_workers,
    batch_size=256,
    shuffle=False,
    pin_memory=False,
)

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [8]:
from train import test

# Evaluate the restored checkpoint on the unlabelled training split.
# nn.Module.cuda() moves the parameters in place and returns the same module.
base_model.cuda()
# NOTE(review): 0 and 'block10' are presumably the epoch index and a run label
# used by train.test; confirm against its signature.
test(base_model, test_loader_unlabelled, 0, 'block10', args)

  0%|          | 0/118 [00:00<?, ?it/s]

100%|██████████| 118/118 [06:21<00:00,  3.23s/it]


(0.2156, 0.2155, 0.2158)