In [1]:
import torch
import torch.nn as nn
from torchvision import transforms, models

In [2]:
import random
from PIL import ImageFilter

class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x

In [3]:
class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized image.
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
            # The normalize code -> t.sub_(m).div_(s)
        return tensor

In [15]:
class MeanEmbeddingModel(nn.Module):
    def __init__(self, backbone, mode, k=100, image_shape=(3, 224, 224), feature_dim=2048, device='cpu'):
        super().__init__()
        self.backbone = backbone
        self.mode = mode
        self.k = k
        self.image_shape = image_shape
        self.feature_dim = feature_dim
        self.device = device

        if self.mode == 'ventral':
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])
        elif self.mode == 'dorsal':
            self.transform = transforms.Compose([
                UnNormalize(*imagenet_mean_std),
                transforms.ToPILImage(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(*imagenet_mean_std)
            ])

    def forward(self, x):
        # prepare the output data structure
        mean_embedding = torch.zeros((x.shape[0], self.feature_dim))
        # for each image in the batch x.
        for j in range(len(x)):
            # prepare the data tensor that will hold the features of the augmentations of image x[j].
            t_x = torch.zeros((self.k, *self.image_shape))
            # augment image x[j] k times.
            for i in range(self.k):
                t_x[i, :] = self.transform(x[j].cpu())
            # pass these augmentations through the backbone.
            z = self.backbone(t_x.to(self.device))
            # take the mean of these vectors to create a single feature vector for image x[j]
            mean_embedding[j, :] = z.mean(dim=0)
        # return the mean embeddings for all images in the batch
        return mean_embedding

### Setup the hyperparameters and a random batch of data

In [16]:
#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = 'cpu'
print(device)

batch_size = 256
imagenet_mean_std = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]

data = torch.randn((batch_size, 3, 224, 224)).to(device=device)
print(data.shape)

cpu
torch.Size([256, 3, 224, 224])


### Create our mean embedding model with a ResNet50 backbone

In [17]:
model = MeanEmbeddingModel(models.resnet50(), mode='ventral', k=32, device='cpu')

### Pass our data through the model
* This will run the forward pass which augments each image within the batch k=32 times.
* Each augmented image is passed through the backbone.
* Then all the feature vectors that come from augmentations of the same image are averaged into a single feature vector -- the mean embedding.
* The output is therefore one mean embedding per image in the batch.

In [18]:
mean_embedding = model(data)
mean_embedding.shape

RuntimeError: The expanded size of the tensor (2048) must match the existing size (1000) at non-singleton dimension 0.  Target sizes: [2048].  Tensor sizes: [1000]