In [1]:
from backbone_utils import IntermediateFeatureModule

In [2]:
import torchvision
import torch
import torch.utils.data

In [10]:
import numpy as np

In [3]:
model = torchvision.models.resnet50(pretrained=True)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
model.fc.weight.shape

torch.Size([1000, 2048])

In [5]:
model.fc = torch.nn.Linear(model.fc.in_features, 100, bias=True)

In [6]:
model.fc.weight.shape

torch.Size([100, 2048])

In [7]:
embedding = IntermediateFeatureModule(model, ['avgpool'])

In [9]:
dataset = torchvision.datasets.CIFAR100("/mnt/datasets/public/cifar100/",
                    transform=torchvision.transforms.Compose([
                        torchvision.transforms.ToTensor(),
                        torchvision.transforms.Normalize(
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]
                        )
                    ]))

In [11]:
labelled_dataset = torch.utils.data.Subset(dataset, np.arange(100))
unlabelled_dataset = torch.utils.data.Subset(dataset, np.arange(101, len(dataset)))

In [12]:
labelled_loader = torch.utils.data.DataLoader(labelled_dataset, batch_size=32)
unlabelled_loader = torch.utils.data.DataLoader(unlabelled_dataset, batch_size=32)

In [13]:
embedding = embedding.cuda()
labelled_embeddings = []

for x, y in labelled_loader:
    with torch.no_grad():
        x = x.cuda()
        labelled_embeddings.append(embedding(x)['avgpool'].squeeze().cpu().numpy())

In [14]:
labelled_embeddings = np.concatenate(labelled_embeddings)

In [16]:
test_x, _ = next(iter(unlabelled_loader))
with torch.no_grad():
    test_embedding = embedding(test_x.cuda())['avgpool'].squeeze().cpu().numpy()

In [17]:
test_embedding.shape

(32, 2048)

In [19]:
(test_embedding[0, :] - labelled_embeddings).shape

(100, 2048)

In [21]:
from scipy.spatial import cKDTree

In [22]:
tree = cKDTree(labelled_embeddings)

In [23]:
tree.query(test_embedding)

(array([23.39474082, 32.54301783, 24.19577934, 25.08151453, 22.563244  ,
        49.72409988, 30.81506044, 46.88290054, 21.01325877, 45.54252688,
        56.09761756, 22.40028562, 24.29599516, 21.12151748, 32.7471984 ,
        44.54750342, 27.67510502, 30.30095982, 32.32567773, 22.00556638,
        34.54423174, 44.22585938, 28.61319191, 26.84988107, 27.21627808,
        24.95879876, 39.11878122, 34.04050599, 26.09931827, 41.60534561,
        23.73287253, 46.31750967]),
 array([97, 97, 24, 97, 16, 37, 84, 73, 97,  1, 62, 16, 16, 16, 69, 78, 48,
        16, 78, 16, 84, 81, 97, 16, 71, 97, 78, 39, 48, 44, 97, 23]))

In [24]:
test_embedding.shape

(32, 2048)

In [25]:
from sklearn.neighbors import NearestNeighbors

In [26]:
nbrs = NearestNeighbors(n_neighbors=25, algorithm='ball_tree').fit(labelled_embeddings)

In [27]:
nbrs.kneighbors(test_embedding)

(array([[23.39474082, 23.92687547, 24.07994547, 24.16215359, 25.30447421,
         25.32974056, 25.54202081, 25.57500532, 25.58128655, 25.79459596,
         26.0983242 , 26.12695711, 26.42897969, 26.51798611, 27.03385658,
         27.30421524, 27.40442335, 27.4128344 , 28.03571413, 28.34874346,
         28.66374729, 28.92934196, 29.00631606, 29.50966459, 29.62208403],
        [32.54301783, 32.75440432, 32.84976731, 33.67812429, 33.78644532,
         33.79169097, 33.87860727, 33.95709908, 34.0361709 , 34.22118178,
         34.57610516, 34.61810152, 34.73684023, 34.75666159, 34.90725486,
         34.99181043, 35.03920328, 35.07434093, 35.18014141, 35.27962974,
         35.30409027, 35.40903038, 35.46429757, 35.64657945, 35.67639633],
        [24.19577934, 24.22186981, 24.26694746, 24.47327735, 24.69089942,
         24.99949511, 25.10673734, 25.55406393, 25.85695949, 25.87274521,
         26.1843401 , 26.23427321, 26.44981428, 26.60413638, 27.08665683,
         27.12365632, 27.44429916, 2