In [1]:
import socket

print("Hostname: " + socket.gethostname())

Hostname: sx-el-121920


In [2]:
import numpy as np
import torch

%reload_ext autoreload
%autoreload 2

print("Torch version:", torch.__version__)

Torch version: 1.13.0a0+d0d6b1f


## Load datasets

In [3]:
from utils.mnist_preprocessing import *
from utils.mnist_plotting import *

# dataset parameters
DATASET_BATCH_SIZE = 128
DATASET_SHUFFLE = True

In [4]:
from torchvision import transforms

# initialize datasets
train_set = DatasetMNIST(root='./data',
                            env='train',
                            color=True,
                            opt_postfix="2classes",
                            filter=[5,8],
                            first_color_max_nr=5,
                            transform= transforms.Compose([transforms.ToTensor()]))

val_set = DatasetMNIST(root='./data',
                            env='val',
                            color=True,
                            opt_postfix="2classes",
                            filter=[5,8],
                            first_color_max_nr=5,
                            transform= transforms.Compose([transforms.ToTensor()]))

test_set = DatasetMNIST(root='./data',
                            env='test',
                            color=True,
                            opt_postfix="2classes",
                            filter=[5,8],
                            first_color_max_nr=5,
                            transform= transforms.Compose([transforms.ToTensor()]))

test_set_fool = DatasetMNIST(root='./data',
                            env='test_fool',
                            color=True,
                            opt_postfix="2classes",
                            filter=[5,8],
                            first_color_max_nr=5,
                            transform= transforms.Compose([transforms.ToTensor()]))

# create dataloaders
train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                            batch_size=DATASET_BATCH_SIZE,
                                            shuffle=DATASET_SHUFFLE,
                                            num_workers=10)

val_loader = torch.utils.data.DataLoader(dataset=val_set,
                                            batch_size=DATASET_BATCH_SIZE,
                                            shuffle=DATASET_SHUFFLE,
                                            num_workers=10)

test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                            batch_size=DATASET_BATCH_SIZE,
                                            shuffle=DATASET_SHUFFLE,
                                            num_workers=10)

test_fool_loader = torch.utils.data.DataLoader(dataset=test_set_fool,
                                            batch_size=DATASET_BATCH_SIZE,
                                            shuffle=DATASET_SHUFFLE,
                                            num_workers=10)


MNIST dataset already exists
MNIST dataset already exists
MNIST dataset already exists
MNIST dataset already exists


In [5]:
print(f"Number of training samples: {len(train_loader.dataset.data_label_tuples)}")
print(f"Number of validation samples: {len(val_loader.dataset.data_label_tuples)}")
print(f"Number of test samples: {len(test_loader.dataset.data_label_tuples)}")
print(f"Number of test fool samples: {len(test_fool_loader.dataset.data_label_tuples)}")

Number of training samples: 9425
Number of validation samples: 1888
Number of test samples: 1866
Number of test fool samples: 1866


In [14]:
len(train_loader.dataset.data_label_tuples) % DATASET_BATCH_SIZE

81

## Set device (For number crunching)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## First version

In [68]:
import clip

model, preprocess = clip.load("RN50")
model.cuda().eval()

#mnist_template = 'a photo of the number: "{}".'
mnist_template = ['a photo of the number: "{}".', 'a photo of a red number: "{}".', 'a photo a green number: "{}".']
mnist_classes = ["5", "8"]

In [69]:
# build text strings
text_descriptions = []

for digit in mnist_classes:
    text_descriptions.extend([template.format(digit) for template in mnist_template])
    
text_descriptions

['a photo of the number: "5".',
 'a photo of a red number: "5".',
 'a photo a green number: "5".',
 'a photo of the number: "8".',
 'a photo of a red number: "8".',
 'a photo a green number: "8".']

In [75]:
# build text features
text_tokens = clip.tokenize(text_descriptions).cuda()

with torch.no_grad():
    text_features = model.encode_text(text_tokens)

In [76]:
text_features.shape # shape [number of texts times 1024]

torch.Size([6, 1024])

In [84]:
text_features

tensor([[ 0.2710,  0.3672,  0.3064,  ..., -0.1648,  0.0879,  0.1384],
        [ 0.3301,  0.2534,  0.1445,  ..., -0.0715,  0.0442,  0.1587],
        [ 0.2252,  0.2289,  0.3535,  ..., -0.0059, -0.1753,  0.1807],
        [ 0.0952,  0.2510, -0.0042,  ..., -0.2423,  0.0435,  0.0859],
        [ 0.1028,  0.1400, -0.1921,  ..., -0.1677,  0.0241,  0.1140],
        [ 0.0689,  0.1344,  0.0947,  ..., -0.1091, -0.1842,  0.1322]],
       device='cuda:0', dtype=torch.float16)

In [77]:
images, ground_truth_label, low_high_label, color_label = next(iter(train_loader))

In [78]:
transform = transforms.ToPILImage()

with torch.no_grad():
    # preprocess images
    images_new = []
    for img in images:
        images_new.append(preprocess(transform(img)))

    # building image features
    images = torch.tensor(np.stack(images_new)).cuda()
    
    # predict
    image_features = model.encode_image(images)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    similarities = image_features @ text_features.T



In [80]:
image_features.shape # 128 images in batch times 1024 feature vector

torch.Size([128, 1024])

In [81]:
similarities.shape

torch.Size([128, 6])

## second version

We also experimented with ensembling over multiple zeroshot classifiers as another way of improving performance.
These classifiers are computed by using different context
prompts such as ‘A photo of a big {label}” and
“A photo of a small {label}”. We construct the
ensemble over the embedding space instead of probability
space. This allows us to cache a single set of averaged text
embeddings so that the compute cost of the ensemble is the
same as using a single classifier when amortized over many
predictions. We’ve observed ensembling across many generated zero-shot classifiers to reliably improve performance
and use it for the majority of datasets. On ImageNet, we
ensemble 80 different context prompts and this improves
performance by an additional 3.5% over the single default
prompt discussed above. When considered together, prompt
engineering and ensembling improve ImageNet accuracy
by almost 5%. In Figure 4 we visualize how prompt engineering and ensembling change the performance of a set of
CLIP models compared to the contextless baseline approach
of directly embedding the class name as done in Li et al.
(2017).

In [85]:
def text_feature_generator1234(clip_version, model, classnames, class_template):
    """
    Generates the text-feature matrix from given template sentences and classes and place it on the GPU.
    """
    with torch.no_grad():
        text_features = []
        for classname in classnames:
            texts = [template.format(classname) for template in class_template] # generate texts using templates with classes
            texts = clip_version.tokenize(texts).cuda() # generate text-tokens
            class_embeddings = model.encode_text(texts) # generate text embeddings -> torch.Size([nr_templates x 1024])
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) # normalize feature vector -> torch.Size([nr_templates x 1024])
            class_embedding = class_embeddings.mean(dim=0) # average over all template sentences
            class_embedding /= class_embedding.norm() # normalize feature vector -> torch.Size([1024])
            text_features.append(class_embedding) # generate feature matrix -> torch.Size([nr_classes x 1024])
        text_features = torch.stack(text_features, dim=1).cuda()
    return text_features

In [87]:
text_feature_generator1234(clip, model, mnist_classes, mnist_template)