In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm

In [7]:
# test using a knn monitor
def knn_monitor(net, memory_data_loader, test_data_loader, device='cuda', k=200, t=0.1, targets=None):
    if not targets:
        targets = memory_data_loader.dataset.targets
    net.eval()
    classes = len(memory_data_loader.dataset.classes)
    total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
    with torch.no_grad():
        # generate feature bank
        for data, target in tqdm(memory_data_loader):
            _, feature = net(data.to(device=device, non_blocking=True))
            feature = F.normalize(feature, dim=1)
            feature_bank.append(feature)
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.tensor(targets, device=feature_bank.device)
        # loop test data to predict the label by weighted knn search
        for data, target in tqdm(test_data_loader):
            data, target = data.to(device=device, non_blocking=True), target.to(device=device, non_blocking=True)
            _, feature = net(data)
            feature = F.normalize(feature, dim=1)

            pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t)
            
            total_num += data.size(0)

            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
    return total_top1 / total_num * 100


# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """
    feature and feature_bank are normalized
    """
    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()
    
    # counts for each class
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # print(one_hot_label)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels

In [8]:
import torchvision
from torchvision import transforms
from torch.utils.data import  DataLoader
from tqdm import tqdm


linear_eval_train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2))], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

linear_eval_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

linear_eval_train_dataset = torchvision.datasets.CIFAR10(root='dataset', train=True, download=True, transform=linear_eval_test_transform)
linear_eval_test_dataset = torchvision.datasets.CIFAR10(root='dataset', train=False, download=True, transform=linear_eval_test_transform)


train_loader = DataLoader(linear_eval_train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(linear_eval_test_dataset, batch_size=1024, shuffle=False)

from utils.networks import build_resnet50
encoder = build_resnet50()
# encoder.load_state_dict(torch.load('params_/best_encoder.pt', map_location='cpu'))
encoder = encoder.cuda()

Files already downloaded and verified
Files already downloaded and verified




In [10]:
def knn_evaluation(encoder):
    
    simple_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    linear_eval_train_dataset = torchvision.datasets.CIFAR10(root='dataset', train=True,  download=True, transform=simple_transform)
    linear_eval_test_dataset  = torchvision.datasets.CIFAR10(root='dataset', train=False, download=True, transform=simple_transform)

    train_loader = DataLoader(linear_eval_train_dataset, batch_size=1024, shuffle=True)
    test_loader = DataLoader(linear_eval_test_dataset, batch_size=1024, shuffle=False)
    
    acc = knn_monitor(
        encoder,
        train_loader,
        test_loader,
    )
    
    return acc

In [11]:
knn_evaluation(encoder)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 49/49 [00:48<00:00,  1.02it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([1024, 50000]) 

 10%|█         | 1/10 [00:01<00:10,  1.12s/it]

tensor(0.4145, device='cuda:0') tensor(0.9953, device='cuda:0')
tensor([2, 7, 3,  ..., 8, 4, 9], device='cuda:0')
tensor([3, 8, 8,  ..., 1, 0, 0], device='cuda:0')
torch.Size([1024, 50000]) 

 20%|██        | 2/10 [00:02<00:09,  1.14s/it]

tensor(0.4269, device='cuda:0') tensor(0.9931, device='cuda:0')
tensor([6, 8, 7,  ..., 7, 9, 0], device='cuda:0')
tensor([3, 5, 0,  ..., 1, 3, 5], device='cuda:0')
torch.Size([1024, 50000]) 

 30%|███       | 3/10 [00:03<00:07,  1.14s/it]

tensor(0.4350, device='cuda:0') tensor(0.9967, device='cuda:0')
tensor([2, 7, 9,  ..., 6, 7, 9], device='cuda:0')
tensor([8, 3, 8,  ..., 4, 0, 7], device='cuda:0')
torch.Size([1024, 50000]) 

 40%|████      | 4/10 [00:04<00:06,  1.14s/it]

tensor(0.4117, device='cuda:0') tensor(0.9969, device='cuda:0')
tensor([6, 0, 7,  ..., 0, 8, 4], device='cuda:0')
tensor([8, 3, 1,  ..., 3, 3, 8], device='cuda:0')
torch.Size([1024, 50000]) 

 50%|█████     | 5/10 [00:05<00:05,  1.14s/it]

tensor(0.4539, device='cuda:0') tensor(0.9937, device='cuda:0')
tensor([3, 9, 0,  ..., 4, 3, 8], device='cuda:0')
tensor([8, 3, 9,  ..., 7, 2, 5], device='cuda:0')
torch.Size([1024, 50000]) 

 60%|██████    | 6/10 [00:06<00:04,  1.14s/it]

tensor(0.4286, device='cuda:0') tensor(0.9969, device='cuda:0')
tensor([8, 8, 9,  ..., 6, 7, 2], device='cuda:0')
tensor([9, 6, 6,  ..., 8, 5, 9], device='cuda:0')
torch.Size([1024, 50000]) 

 70%|███████   | 7/10 [00:07<00:03,  1.14s/it]

tensor(0.4189, device='cuda:0') tensor(0.9962, device='cuda:0')
tensor([7, 5, 2,  ..., 9, 8, 8], device='cuda:0')
tensor([7, 2, 0,  ..., 0, 1, 7], device='cuda:0')
torch.Size([1024, 50000]) 

 80%|████████  | 8/10 [00:09<00:02,  1.14s/it]

tensor(0.4237, device='cuda:0') tensor(0.9946, device='cuda:0')
tensor([7, 3, 3,  ..., 8, 1, 2], device='cuda:0')
tensor([5, 1, 9,  ..., 8, 4, 6], device='cuda:0')
torch.Size([1024, 50000]) 

 90%|█████████ | 9/10 [00:10<00:01,  1.13s/it]

tensor(0.4480, device='cuda:0') tensor(0.9972, device='cuda:0')
tensor([2, 1, 7,  ..., 1, 9, 5], device='cuda:0')
tensor([5, 3, 1,  ..., 2, 9, 5], device='cuda:0')
torch.Size([784, 50000]) 

100%|██████████| 10/10 [00:11<00:00,  1.11s/it]

tensor(0.4283, device='cuda:0') tensor(0.9969, device='cuda:0')
tensor([2, 8, 3, 7, 3, 8, 4, 6, 3, 1, 7, 5, 4, 9, 3, 4, 5, 8, 6, 5, 4, 4, 6, 4,
        5, 8, 3, 3, 7, 8, 7, 2, 7, 3, 9, 4, 3, 8, 9, 5, 3, 7, 4, 7, 0, 0, 1, 4,
        5, 9, 0, 3, 2, 8, 4, 0, 6, 0, 6, 0, 0, 6, 8, 7, 0, 1, 8, 9, 5, 8, 3, 1,
        4, 9, 3, 2, 6, 6, 5, 9, 6, 8, 0, 8, 5, 7, 5, 8, 7, 0, 1, 2, 1, 5, 1, 2,
        4, 5, 6, 1, 7, 9, 0, 6, 7, 4, 1, 5, 1, 9, 6, 6, 2, 3, 1, 2, 6, 9, 9, 9,
        4, 8, 1, 4, 3, 7, 9, 7, 0, 9, 7, 0, 6, 5, 7, 8, 2, 2, 4, 8, 8, 8, 3, 7,
        8, 0, 4, 0, 0, 9, 5, 7, 4, 7, 5, 0, 5, 4, 3, 5, 9, 3, 0, 7, 5, 8, 3, 1,
        0, 0, 5, 8, 8, 8, 2, 4, 8, 2, 9, 5, 9, 4, 7, 7, 2, 6, 8, 9, 3, 5, 7, 4,
        7, 4, 9, 5, 8, 0, 7, 7, 0, 8, 6, 2, 4, 3, 0, 0, 5, 6, 6, 3, 2, 0, 1, 0,
        4, 3, 5, 8, 6, 7, 1, 0, 8, 1, 8, 2, 5, 4, 5, 4, 0, 8, 8, 1, 3, 7, 4, 7,
        8, 8, 1, 7, 9, 6, 7, 2, 9, 3, 2, 4, 3, 5, 8, 8, 5, 0, 7, 8, 5, 2, 7, 7,
        4, 7, 4, 9, 1, 5, 9, 5, 0, 7, 3, 7, 9, 9, 5, 2, 




11.12