In this lab you will do the following steps in order:

1. Learn what [SSL](https://docs.google.com/presentation/d/17vkcwjBdyFAy3y9ZYphuV28kIqDtGCW4zUWXgfL54sU/edit?usp=share_link) (Self-Supervised Learning) is.
2. Train a CNN with an SSL pretext task (image rotation prediction) and train a linear classifier on its features to classify CIFAR10
3. Evaluate the trained model and show the nearest neighbors in the latent space




In [None]:
!pip install warmup_scheduler
import argparse
import os
import os.path as osp
import time
import shutil
from warmup_scheduler import GradualWarmupScheduler

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import torch.optim.lr_scheduler as lr_scheduler
%matplotlib inline

2. Network definition

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class RotationPrediction(nn.Module):
    """Self-supervised rotation-prediction model with a linear-probe helper.

    Every input batch is expanded into four copies rotated by 0, 90, 180 and
    270 degrees, and the backbone is trained to predict which rotation was
    applied. The features of the un-rotated images taken from ``feat_layer``
    are returned as well, so that a separate linear classifier can be trained
    on top of them.

    Attributes:
        metrics (list): Names of the metrics returned by ``forward``
            ('Loss' and 'Acc1').
        metrics_fmt (list): Format specs matching ``metrics`` one to one.
        model (nn.Module): ``NetworkInNetwork`` backbone that predicts the
            rotation.
        latent_dim (int): Size of the flattened ``feat_layer`` output. The
            hard-coded 192 * 8 * 8 corresponds to 32x32 inputs.
        feat_layer (str): Name of the backbone feature used as representation.
        n_classes (int): Number of outputs of the downstream linear classifier
            built by ``construct_classifier`` (the number of dataset classes,
            not the number of rotations).
    """
    metrics = ['Loss', 'Acc1']
    metrics_fmt = [':.4e', ':6.2f']

    def __init__(self, n_classes):
        super().__init__()
        self.model = NetworkInNetwork()
        self.latent_dim = 192 * 8 * 8
        self.feat_layer = 'conv2'
        self.n_classes = n_classes

    def construct_classifier(self):
        """Build a fresh linear probe mapping latent features to ``n_classes`` logits."""
        classifier = nn.Sequential(
            nn.BatchNorm1d(self.latent_dim, affine=False),
            nn.Linear(self.latent_dim, self.n_classes)
        )
        return classifier

    def forward(self, images):
        """Run the rotation pretext task on a batch of images.

        Args:
            images: Tensor of shape (batch, channels, height, width).

        Returns:
            A tuple ``(metrics, zs)`` where ``metrics`` is a dict with the
            rotation cross-entropy ``Loss`` and the rotation accuracy ``Acc1``
            (in percent), and ``zs`` holds the flattened ``feat_layer``
            features of the original (un-rotated) images only.
        """
        batch_size = images.shape[0]
        # Targets are created on the same device as the images, so this works
        # on CPU as well as on GPU.
        images, targets = self._preprocess(images)

        logits, zs = self.model(images, out_feat_keys=['classifier', self.feat_layer])
        loss = F.cross_entropy(logits, targets)

        pred = logits.argmax(dim=-1)
        correct = pred.eq(targets).float().sum()
        acc = correct / targets.shape[0] * 100.

        # The first `batch_size` rows are the 0-degree (original) images.
        zs = zs[:batch_size].flatten(start_dim=1)

        return dict(Loss=loss, Acc1=acc), zs

    def encode(self, images, flatten=True):
        """Return the ``feat_layer`` features of ``images`` without rotating them.

        Args:
            images: Tensor of shape (batch, channels, height, width).
            flatten: When True (default) the features are flattened to
                (batch, -1); when False the feature map is returned as is.
        """
        zs = self.model(images, out_feat_keys=(self.feat_layer,))
        return zs.flatten(start_dim=1) if flatten else zs

    def _preprocess(self, images):
        """Stack the four rotations of ``images`` and build their labels.

        Returns:
            ``(images_batch, targets)``: the rotated copies concatenated along
            the batch dimension in the order 0, 90, 180, 270 degrees, and an
            int64 tensor with label ``r`` for every image of the r-th block.
        """
        batch_size = images.shape[0]
        images_90 = torch.flip(images.transpose(2, 3), (2,))
        images_180 = torch.flip(images, (2, 3))
        images_270 = torch.flip(images, (2,)).transpose(2, 3)
        images_batch = torch.cat((images, images_90, images_180, images_270), dim=0)
        # [0]*B + [1]*B + [2]*B + [3]*B, matching the concatenation order above.
        targets = torch.arange(4, device=images.device).repeat_interleave(batch_size)
        return images_batch, targets



# Code borrowed from https://github.com/gidariss/FeatureLearningRotNet

# NetworkInNetwork
class BasicBlock(nn.Module):
    """Bias-free convolution followed by batch norm and an in-place ReLU.

    The convolution uses stride 1 and padding ``(kernel_size - 1) // 2``, so
    odd kernel sizes leave the spatial resolution unchanged.
    """

    def __init__(self, in_planes, out_planes, kernel_size):
        super().__init__()
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            bias=False,
        )
        stages = (
            ('Conv', conv),
            ('BatchNorm', nn.BatchNorm2d(out_planes)),
            ('ReLU', nn.ReLU(inplace=True)),
        )
        # Sub-module names are part of the checkpoint key names.
        self.layers = nn.Sequential()
        for label, layer in stages:
            self.layers.add_module(label, layer)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x``."""
        return self.layers(x)


class GlobalAveragePooling(nn.Module):
    """Average every channel over its full spatial extent.

    Maps a (batch, channels, height, width) feature map to (batch, channels).
    """

    def forward(self, feat):
        _, channels, height, width = feat.shape
        pooled = F.avg_pool2d(feat, (height, width))
        return pooled.view(-1, channels)


class NetworkInNetwork(nn.Module):
    """Network-in-Network style CNN used as the rotation-prediction backbone.

    The network is a sequence of named stages ``conv1`` ... ``conv<num_stages>``
    followed by a ``classifier`` stage (global average pooling plus a linear
    layer). Each conv stage consists of three ``BasicBlock`` units; the first
    stage ends with a stride-2 max pooling and the second with a stride-2
    average pooling. ``forward`` can return the output of any subset of stages.

    Args:
        num_classes (int): Number of outputs of the final linear layer
            (4 rotation classes by default).
        num_inchannels (int): Number of input image channels (3 by default).
        num_stages (int): Number of conv stages, at least 3 (4 by default).
        use_avg_on_conv3 (bool): If True and ``num_stages > 3``, append a
            stride-2 average pooling to the third stage.

    Attributes:
        _feature_blocks (nn.ModuleList): The stages, in execution order.
        all_feat_names (list): Stage names, aligned with ``_feature_blocks``.

    Raises:
        ValueError: If ``num_stages`` is smaller than 3.
    """
    def __init__(self, num_classes=4, num_inchannels=3, num_stages=4,
                 use_avg_on_conv3=False):
        super().__init__()

        # The first three stages are always built, so fewer cannot be honoured.
        if num_stages < 3:
            raise ValueError(
                'num_stages must be at least 3, got {0}.'.format(num_stages))

        nChannels  = 192
        nChannels2 = 160
        nChannels3 = 96

        blocks = [nn.Sequential() for _ in range(num_stages)]
        # 1st block
        blocks[0].add_module('Block1_ConvB1', BasicBlock(num_inchannels, nChannels, 5))
        blocks[0].add_module('Block1_ConvB2', BasicBlock(nChannels,  nChannels2, 1))
        blocks[0].add_module('Block1_ConvB3', BasicBlock(nChannels2, nChannels3, 1))
        blocks[0].add_module('Block1_MaxPool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # 2nd block
        blocks[1].add_module('Block2_ConvB1',  BasicBlock(nChannels3, nChannels, 5))
        blocks[1].add_module('Block2_ConvB2',  BasicBlock(nChannels,  nChannels, 1))
        blocks[1].add_module('Block2_ConvB3',  BasicBlock(nChannels,  nChannels, 1))
        blocks[1].add_module('Block2_AvgPool', nn.AvgPool2d(kernel_size=3, stride=2, padding=1))

        # 3rd block
        blocks[2].add_module('Block3_ConvB1',  BasicBlock(nChannels, nChannels, 3))
        blocks[2].add_module('Block3_ConvB2',  BasicBlock(nChannels, nChannels, 1))
        blocks[2].add_module('Block3_ConvB3',  BasicBlock(nChannels, nChannels, 1))

        if num_stages > 3 and use_avg_on_conv3:
            blocks[2].add_module('Block3_AvgPool', nn.AvgPool2d(kernel_size=3, stride=2, padding=1))
        # Any further stages repeat the layout of the 3rd block.
        for s in range(3, num_stages):
            blocks[s].add_module(f'Block{s + 1}_ConvB1',  BasicBlock(nChannels, nChannels, 3))
            blocks[s].add_module(f'Block{s + 1}_ConvB2',  BasicBlock(nChannels, nChannels, 1))
            blocks[s].add_module(f'Block{s + 1}_ConvB3',  BasicBlock(nChannels, nChannels, 1))

        # global average pooling and classifier
        blocks.append(nn.Sequential())
        blocks[-1].add_module('GlobalAveragePooling',  GlobalAveragePooling())
        blocks[-1].add_module('Classifier', nn.Linear(nChannels, num_classes))

        self._feature_blocks = nn.ModuleList(blocks)
        self.all_feat_names = ['conv' + str(s + 1) for s in range(num_stages)] + ['classifier']
        assert len(self.all_feat_names) == len(self._feature_blocks)

    def _parse_out_keys_arg(self, out_feat_keys):
        """Validate the requested feature names.

        Args:
            out_feat_keys: List/tuple of feature names, or None to request
                the output of the last stage.

        Returns:
            ``(out_feat_keys, max_out_feat)`` where ``max_out_feat`` is the
            index of the deepest requested stage.

        Raises:
            ValueError: If the list is empty, contains an unknown name, or
                contains the same name twice.
        """
        # By default return the features of the last layer / module.
        if out_feat_keys is None:
            out_feat_keys = [self.all_feat_names[-1]]

        if len(out_feat_keys) == 0:
            raise ValueError('Empty list of output feature keys.')
        for f, key in enumerate(out_feat_keys):
            if key not in self.all_feat_names:
                raise ValueError('Feature with name {0} does not exist. Existing features: {1}.'.format(key, self.all_feat_names))
            elif key in out_feat_keys[:f]:
                raise ValueError('Duplicate output feature key: {0}.'.format(key))

        # Find the deepest requested feature so forward() can stop there.
        max_out_feat = max(self.all_feat_names.index(key) for key in out_feat_keys)

        return out_feat_keys, max_out_feat

    def forward(self, x, out_feat_keys=None):
        """Forward an image `x` through the network and return the asked output features.
        Args:
          x: input image.
          out_feat_keys: a list/tuple with the feature names of the features
                that the function should return. By default the last feature of
                the network is returned.
        Return:
            out_feats: If multiple output features were asked then `out_feats`
                is a list with the asked output features placed in the same
                order as in `out_feat_keys`. If a single output feature was
                asked then `out_feats` is that output feature (and not a list).
        """
        out_feat_keys, max_out_feat = self._parse_out_keys_arg(out_feat_keys)
        out_feats = [None] * len(out_feat_keys)

        feat = x
        # Stages deeper than the deepest requested feature are never executed.
        for f in range(max_out_feat + 1):
            feat = self._feature_blocks[f](feat)
            key = self.all_feat_names[f]
            if key in out_feat_keys:
                out_feats[out_feat_keys.index(key)] = feat

        return out_feats[0] if len(out_feats) == 1 else out_feats

    def weight_initialization(self):
        """Re-initialise conv, batch-norm and linear parameters in place.

        Conv weights are drawn from N(0, sqrt(2 / (k_h * k_w * out_channels))),
        batch-norm scales are set to 1 and batch-norm / linear biases to 0.
        Parameters with ``requires_grad == False`` are left untouched. This
        method is not invoked by the constructor.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.weight.requires_grad:
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight.requires_grad:
                    m.weight.data.fill_(1)
                if m.bias.requires_grad:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # A Linear layer built with bias=False has no bias to reset.
                if m.bias is not None and m.bias.requires_grad:
                    m.bias.data.zero_()


Utility functions


In [None]:
import requests
import math
import pickle
from collections import OrderedDict, Counter
import torch
import torch.nn.functional as F

def unnormalize(images):
    """Undo the per-channel normalisation applied by ``get_transform``.

    Args:
        images: Tensor of shape (batch, 3, height, width) that was normalised
            with mean (0.4914, 0.4822, 0.4465) and std (0.247, 0.243, 0.261).

    Returns:
        Tensor of the same shape holding ``images * std + mean``.
    """
    # These constants must stay identical to the ones passed to
    # transforms.Normalize, otherwise this is not the inverse transform.
    mu = (0.4914, 0.4822, 0.4465)
    stddev = (0.247, 0.243, 0.261)

    # Build the constants on the input's device so GPU tensors work too.
    mu = torch.tensor(mu, dtype=torch.float32, device=images.device).view(1, 3, 1, 1)
    stddev = torch.tensor(stddev, dtype=torch.float32, device=images.device).view(1, 3, 1, 1)
    return images * stddev + mu


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracies (in percent) for each k in ``topk``.

    Args:
        output: Score tensor of shape (batch, num_classes).
        target: Tensor of shape (batch,) with the true class indices.
        topk: Iterable of k values.

    Returns:
        A list with one single-element tensor per requested k.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        scale = 100.0 / n_samples

        # ranked[j, i] is the class with the (j+1)-th highest score for sample i.
        _, ranked = output.topk(max(topk), 1, True, True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]

class AverageMeter(object):
    """Running tracker for a scalar metric.

    Keeps the most recent value together with a weighted running sum, the
    total weight and the resulting mean.

    Attributes:
        name (str): Label used when the meter is printed.
        fmt (str): Format spec applied to ``val`` and ``avg`` (default ':f').
        val: Value passed to the latest ``update`` call.
        avg: ``sum / count``.
        sum: Sum of ``val * n`` over all updates.
        count: Sum of ``n`` over all updates.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Produces e.g. '{name} {val:.2f} ({avg:.2f})' before the final format.
        template = '{{name}} {{val{spec}}} ({{avg{spec}}})'.format(spec=self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Print one-line progress reports for a loop over batches.

    Each report consists of the prefix, a ``[current/total]`` batch counter
    and the string form of every meter, joined by tab characters.

    Attributes:
        batch_fmtstr (str): Format string that renders the batch counter.
        meters (list): Objects whose ``str()`` is appended to every report.
        prefix (str): Text placed in front of the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the report for batch index ``batch``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build '[{:Nd}/<num_batches>]', N being the digit count of the total."""
        width = len(str(num_batches // 1))
        counter = '{:%dd}' % width
        return '[%s/%s]' % (counter, counter.format(num_batches))


Dataloader

In [None]:
import os.path as osp
import random

import numpy as np
import cv2

import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms


def get_transform(train=True):
    """Return the torchvision preprocessing pipeline for one dataset split.

    Both splits convert images to tensors and normalise them per channel.
    The training split additionally applies a padded 32x32 random crop and a
    random horizontal flip beforehand.
    """
    steps = []
    if train:
        steps.append(transforms.RandomCrop(32, padding=4))
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    )
    return transforms.Compose(steps)


def get_datasets():
    """Create the CIFAR10 train and test datasets under the 'data' directory.

    Both datasets are requested with ``download=True``.

    Returns:
        ``(train_dset, test_dset, n_classes)`` where ``n_classes`` is the
        number of entries in the training set's ``classes`` attribute.
    """
    train_dset, test_dset = (
        datasets.CIFAR10('data', train=is_train,
                         transform=get_transform(train=is_train),
                         download=True)
        for is_train in (True, False)
    )
    return train_dset, test_dset, len(train_dset.classes)


Main

In [None]:

def train(train_loader, model, linear_classifier, optimizer,
          optimizer_linear, epoch):
    """Train the SSL model and the linear probe for one epoch.

    For every batch the SSL loss returned by ``model`` is minimised with
    ``optimizer``; afterwards the linear classifier is trained on the detached
    features with ``optimizer_linear``, so its gradients never reach the
    backbone. Progress is printed every ``log_interval`` batches.

    Relies on the module-level names ``metrics``, ``metrics_fmt`` and
    ``log_interval``, and moves every batch to the GPU with ``.cuda()``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    top1 = AverageMeter('LinearAcc@1', ':6.2f')
    top5 = AverageMeter('LinearAcc@5', ':6.2f')
    avg_meters = {}
    for name, spec in zip(metrics, metrics_fmt):
        avg_meters[name] = AverageMeter(name, spec)
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, top1, top5, *avg_meters.values()],
        prefix="Epoch: [{}]".format(epoch),
    )

    # Both networks must be in training mode.
    model.train()
    linear_classifier.train()

    tick = time.time()
    for step, (images, target) in enumerate(train_loader):
        # Time spent waiting for the data loader.
        data_time.update(time.time() - tick)

        if isinstance(images, (tuple, list)):
            # Special case for SimCLR which returns a tuple of 2 image batches
            bs = images[0].shape[0]
            images = [view.cuda() for view in images]
        else:
            bs = images.shape[0]
            images = images.cuda()
        target = target.cuda()

        ssl_out, feats = model(images)
        # Detach so the probe's loss cannot update the backbone.
        feats = feats.detach()
        for name, value in ssl_out.items():
            avg_meters[name].update(value.item(), bs)

        # Optimisation step for the SSL task.
        optimizer.zero_grad()
        ssl_out['Loss'].backward()
        optimizer.step()

        # Optimisation step for the linear classifier.
        logits = linear_classifier(feats)
        probe_loss = F.cross_entropy(logits, target)

        acc1, acc5 = accuracy(logits, target, topk=(1, 5))
        top1.update(acc1[0], bs)
        top5.update(acc5[0], bs)

        optimizer_linear.zero_grad()
        probe_loss.backward()
        optimizer_linear.step()

        # Total time spent on this batch.
        batch_time.update(time.time() - tick)
        tick = time.time()

        if not step % log_interval:
            progress.display(step)


def validate(val_loader, model, linear_classifier):
    """Evaluate the SSL model and the linear probe on ``val_loader``.

    Runs both networks in eval mode without gradients, prints progress every
    ``log_interval`` batches and a summary line at the end. The 'Data' meter
    is created for the progress display but never updated here.

    Relies on the module-level names ``metrics``, ``metrics_fmt`` and
    ``log_interval``, and moves every batch to the GPU with ``.cuda()``.

    Returns:
        ``(loss, top1)``: the average SSL loss and the average top-1 accuracy
        of the linear probe, as elements of a float tensor.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    top1 = AverageMeter('LinearAcc@1', ':6.2f')
    top5 = AverageMeter('LinearAcc@5', ':6.2f')
    avg_meters = {}
    for name, spec in zip(metrics, metrics_fmt):
        avg_meters[name] = AverageMeter(name, spec)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, data_time, top1, top5, *avg_meters.values()],
        prefix="Test: ",
    )

    # Both networks must be in evaluation mode.
    model.eval()
    linear_classifier.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            if isinstance(images, (tuple, list)):
                # Special case for SimCLR which returns a tuple of 2 image batches
                bs = images[0].shape[0]
                images = [view.cuda() for view in images]
            else:
                bs = images.shape[0]
                images = images.cuda()
            target = target.cuda()

            ssl_out, feats = model(images)
            for name, value in ssl_out.items():
                avg_meters[name].update(value.item(), bs)

            logits = linear_classifier(feats)
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1.update(acc1[0], bs)
            top5.update(acc5[0], bs)

            # Total time spent on this batch.
            batch_time.update(time.time() - tick)
            tick = time.time()

            if not step % log_interval:
                progress.display(step)

    # Layout: [loss, top1, top5, <one entry per SSL metric>].
    data = torch.FloatTensor(
        [avg_meters['Loss'].avg, top1.avg, top5.avg]
        + [meter.avg for meter in avg_meters.values()]
    )

    summary = f' * LinearAcc@1 {data[1]:.3f} LinearAcc@5 {data[2]:.3f}'
    summary += ''.join(
        f' {name} {data[pos + 3]:.3f}' for pos, name in enumerate(avg_meters)
    )
    print(summary)

    return data[0], data[1]


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` into the module-level ``output_dir``.

    The checkpoint is saved as ``filename``; when ``is_best`` is truthy it is
    additionally copied to 'model_best.pth.tar' in the same directory.
    """
    checkpoint_path = osp.join(output_dir, filename)
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    shutil.copyfile(checkpoint_path, osp.join(output_dir, 'model_best.pth.tar'))

# ---- Hyper-parameters -------------------------------------------------------
batch_size = 128
epochs = 5
optimizer = 'sgd'  # 'sgd|lars|adam (default: sgd)'; informational only, rebound to the SGD optimizer below
lr = 0.01
momentum = 0.9
weight_decay = 5e-4
warmup_epochs = 0  # '# of warmup epochs. If > 0, then the scheduler warmups from lr * batch_size / 256.'

best_loss = float('inf')
best_acc = 0.0

log_interval = 10

output_dir = 'results'
os.makedirs(output_dir, exist_ok=True)

total_batch_size = batch_size

# ---- Data -------------------------------------------------------------------
train_dataset, val_dataset, n_classes = get_datasets()
# Reshuffle the training set every epoch; drop the last partial batch so every
# training step sees a full batch.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2,
    pin_memory=True, drop_last=True
)

# Keep the last partial batch so the validation metrics cover the whole set.
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=2,
    pin_memory=True, drop_last=False
)

# ---- Model, optimizers and schedulers ---------------------------------------
model = RotationPrediction(n_classes)
# train() and validate() read these two module-level names.
metrics = model.metrics
metrics_fmt = model.metrics_fmt

torch.backends.cudnn.benchmark = True
model.cuda()

linear_classifier = model.construct_classifier().cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                            weight_decay=weight_decay, nesterov=True)
optimizer_linear = torch.optim.SGD(linear_classifier.parameters(), lr=lr,
                                   momentum=momentum, nesterov=True)


# Minimize SSL task loss, maximize linear classification accuracy
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0,
                                           last_epoch=-1)
scheduler_linear = lr_scheduler.CosineAnnealingLR(optimizer_linear, T_max=epochs,
                                                  eta_min=0, last_epoch=-1)
if warmup_epochs > 0:
    scheduler = GradualWarmupScheduler(optimizer, multiplier=total_batch_size / 256.,
                                       total_epoch=warmup_epochs, after_scheduler=scheduler)
    # The probe's warmup must drive the probe's optimizer, not the backbone's.
    scheduler_linear = GradualWarmupScheduler(optimizer_linear,
                                              multiplier=total_batch_size / 256.,
                                              total_epoch=warmup_epochs,
                                              after_scheduler=scheduler_linear)

# ---- Training loop ----------------------------------------------------------
for epoch in range(epochs):

    train(train_loader, model, linear_classifier,
          optimizer, optimizer_linear, epoch)

    # Convert to plain floats so the checkpoint stores numbers, not tensors.
    val_loss, val_acc = (float(v) for v in validate(val_loader, model, linear_classifier))

    scheduler.step()
    scheduler_linear.step()

    # The best checkpoint is selected by the lowest SSL validation loss.
    is_best = val_loss < best_loss
    best_loss = min(val_loss, best_loss)
    best_acc = max(val_acc, best_acc)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'state_dict_linear': linear_classifier.state_dict(),
        'optimizer_linear': optimizer_linear.state_dict(),
        # Key spelling kept as is so previously written checkpoints stay compatible.
        'schedular_linear': scheduler_linear.state_dict(),
        'best_loss': best_loss,
        'best_acc': best_acc
    }, is_best)



3. Evaluation Utils

In [None]:
import torch.utils.data as data
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def load_model_and_data():
    """Restore the best checkpoint and build shuffled data loaders.

    Loads 'results/model_best.pth.tar' onto the CPU, restores the rotation
    model and a freshly constructed linear classifier from it, then moves
    both to the GPU and switches them to eval mode.

    Returns:
        ``(model, linear_classifier, train_loader, test_loader)``.
    """
    train_dset, test_dset, n_classes = get_datasets()
    loader_options = dict(batch_size=128, num_workers=2,
                          pin_memory=True, shuffle=True)
    train_loader = data.DataLoader(train_dset, **loader_options)
    test_loader = data.DataLoader(test_dset, **loader_options)

    ckpt = torch.load(osp.join('results', 'model_best.pth.tar'),
                      map_location='cpu')

    model = RotationPrediction(n_classes)
    model.load_state_dict(ckpt['state_dict'])
    model.cuda()
    model.eval()

    linear_classifier = model.construct_classifier()
    linear_classifier.load_state_dict(ckpt['state_dict_linear'])
    linear_classifier.cuda()
    linear_classifier.eval()

    return model, linear_classifier, train_loader, test_loader


def evaluate_accuracy(model, linear_classifier, train_loader, test_loader):
    """Print the linear probe's top-1 / top-5 accuracy on both data splits.

    Both splits are evaluated first; the results are printed afterwards.
    """
    results = [
        ('Train Set', evaluate_classifier(model, linear_classifier, train_loader)),
        ('Test Set', evaluate_classifier(model, linear_classifier, test_loader)),
    ]

    for title, (top1, top5) in results:
        print(title)
        print(f'Top 1 Accuracy: {top1}, Top 5 Accuracy: {top5}\n')


def evaluate_classifier(model, linear_classifier, loader):
    """Compute the linear probe's top-1 and top-5 accuracy over ``loader``.

    Per-batch percentages are weighted by the batch size and divided by the
    size of ``loader.dataset``.

    Returns:
        ``(top1, top5)`` accuracies in percent.
    """
    weighted_top1 = weighted_top5 = 0
    with torch.no_grad():
        for images, target in loader:
            images = images_to_cuda(images)
            target = target.cuda(non_blocking=True)
            _, zs = model(images)

            logits = linear_classifier(zs)
            n_samples = logits.shape[0]
            top1, top5 = accuracy(logits, target, topk=(1, 5))

            weighted_top1 += top1.item() * n_samples
            weighted_top5 += top5.item() * n_samples

    dataset_size = len(loader.dataset)
    return weighted_top1 / dataset_size, weighted_top5 / dataset_size


def display_nearest_neighbors(model, loader, n_examples=4, k=16):
    """
    This function visualizes the nearest neighbors of a set of reference images
    in the latent space of a rotation prediction model.

    The first ``n_examples`` images of the first batch are used as references;
    every other image yielded by ``loader`` is a neighbor candidate. Neighbors
    are ranked by Euclidean distance between the encoded features.

    Args:
        model (nn.Module): The rotation prediction model.
        loader (torch.utils.data.DataLoader): The data loader containing the images.
        n_examples (int, optional): The number of reference images to use. Defaults to 4.
        k (int, optional): The number of nearest neighbors to find for each reference image. Defaults to 16.
            Reduced to the number of available candidates if that is smaller.

    Raises:
        ValueError: If the loader yields no batches or no neighbor candidates.

    Prints and displays images using matplotlib:
        - Each reference image.
        - A grid of nearest neighbors for each reference image.
    """
    with torch.no_grad():
        ref_zs = ref_images = None
        all_images, all_zs = [], []
        for i, (images, _) in enumerate(loader):
            images = images_to_cuda(images)

            zs = model.encode(images).cpu()
            images = images.cpu()

            if i == 0:
                # The head of the first batch becomes the reference set and is
                # excluded from the candidates.
                ref_zs, ref_images = zs[:n_examples], images[:n_examples]
                zs, images = zs[n_examples:], images[n_examples:]
            all_zs.append(zs)
            all_images.append(images)

        if ref_zs is None:
            raise ValueError('The loader did not yield any batch.')

        all_images = torch.cat(all_images, dim=0)
        all_zs = torch.cat(all_zs, dim=0)

        # The first batch may hold fewer than n_examples images.
        n_refs = ref_zs.shape[0]
        k = min(k, all_zs.shape[0])
        if k == 0:
            raise ValueError('No candidate images left to search for neighbors.')

        # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
        aa = (ref_zs ** 2).sum(dim=1).unsqueeze(dim=1)
        ab = torch.matmul(ref_zs, all_zs.t())
        bb = (all_zs ** 2).sum(dim=1).unsqueeze(dim=0)
        # Rounding can make the expanded form slightly negative; clamp it so
        # the square root never yields NaN.
        dists = torch.sqrt((aa - 2 * ab + bb).clamp_min(0))

        idxs = torch.topk(dists, k, dim=1, largest=False)[1]
        sel_images = torch.index_select(all_images, 0, idxs.view(-1))
        # Clamp to [0, 1] so the uint8 conversion below cannot wrap around.
        sel_images = unnormalize(sel_images.cpu()).clamp(0, 1)
        sel_images = sel_images.view(n_refs, k, *sel_images.shape[-3:])

        ref_images = unnormalize(ref_images.cpu()).clamp(0, 1)
        ref_images = (ref_images.permute(0, 2, 3, 1) * 255.).numpy().astype('uint8')

        for i in range(n_refs):
            print(f'Image {i + 1}')
            plt.figure()
            plt.axis('off')
            plt.imshow(ref_images[i])
            plt.show()

            grid_img = make_grid(sel_images[i], nrow=4)
            grid_img = (grid_img.permute(1, 2, 0) * 255.).numpy().astype('uint8')

            print(f'Top {k} Nearest Neighbors (in latent space)')
            plt.figure()
            plt.axis('off')
            plt.imshow(grid_img)
            plt.show()


def images_to_cuda(images):
    """Move a tensor, or every tensor of a tuple/list, to the GPU.

    Transfers use ``non_blocking=True``. A tuple or list input is returned as
    a list.
    """
    if not isinstance(images, (tuple, list)):
        return images.cuda(non_blocking=True)
    return [view.cuda(non_blocking=True) for view in images]

### Linear Classification
We can use the feature maps of an intermediate convolutional block of the pretrained model (here `conv2`) as our learned representation for linear classification.

In [None]:
# Restore the best checkpoint (model and linear classifier) and build the
# training and testing data loaders
model, linear_classifier, train_loader, test_loader = load_model_and_data()
# Print the linear classifier's top-1 / top-5 accuracy on the training and the testing data
evaluate_accuracy(model, linear_classifier, train_loader, test_loader)

### Nearest Neighbors
Another way to evaluate our learned representation is to look at the nearest neighbors, in latent space, of a few randomly chosen encoded images.

In [None]:
# Show reference images from the first test batch next to their nearest
# neighbors in latent space (uses the function's default n_examples and k)
display_nearest_neighbors(model, test_loader)