In [2]:
import numpy as np
import torch
import clip
from tqdm.notebook import tqdm
from pkg_resources import packaging

# Record the PyTorch build this notebook was executed with.
print(f"Torch version: {torch.__version__}")

Torch version: 1.13.0a0+d0d6b1f


In [3]:
# List the model names that the installed clip package reports as available.
clip.available_models()


['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [4]:
# Load the "RN50" CLIP variant. `preprocess` is the image transform returned
# together with the model; it is applied to every image before encoding.
model, preprocess = clip.load("RN50")


In [5]:
# Read the headline dimensions off the loaded model.
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

# Total number of scalar parameters across all parameter tensors.
param_count = sum(p.numel() for p in model.parameters())

print(f"Model parameters: {param_count:,}")
print(f"Input resolution: {input_resolution}")
print(f"Context length: {context_length}")
print(f"Vocab size: {vocab_size}")

Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408


In [6]:
# Class names are the ten digits, stored as strings so they can be put into prompts.
mnist_classes = [str(digit) for digit in range(10)]

In [7]:
# Prompt template; the "{}" placeholder is filled with a class name via str.format.
mnist_template = 'a photo of the number: "{}".'

In [8]:
# Sanity check of the template: format() also accepts a non-string value such as the int 3.
mnist_template.format(3)

'a photo of the number: "3".'

In [9]:
# Preview the prompt that each class name produces.
[mnist_template.format(digit) for digit in mnist_classes]

['a photo of the number: "0".',
 'a photo of the number: "1".',
 'a photo of the number: "2".',
 'a photo of the number: "3".',
 'a photo of the number: "4".',
 'a photo of the number: "5".',
 'a photo of the number: "6".',
 'a photo of the number: "7".',
 'a photo of the number: "8".',
 'a photo of the number: "9".']

In [10]:
from utils.mnist_preprocessing import *
from torchvision import transforms

# parameters
size_of_batch = 128
num_loader_workers = 10


def _make_mnist_set(env, color):
    """Build one DatasetMNIST split whose only transform is ToTensor."""
    return DatasetMNIST(root='./data',
                        env=env,
                        color=color,
                        transform=transforms.Compose([transforms.ToTensor()]))


def _make_loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader using the notebook-wide batch size and worker count."""
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=size_of_batch,
                                       shuffle=shuffle,
                                       num_workers=num_loader_workers)


# dataset preparation (same construction order as before: gray splits, then color splits)
train_set_gray = _make_mnist_set('train', color=False)
val_set_gray = _make_mnist_set('val', color=False)
test_set_gray = _make_mnist_set('test', color=False)

train_set_color = _make_mnist_set('train', color=True)
val_set_color = _make_mnist_set('val', color=True)
test_set_color = _make_mnist_set('test', color=True)

# dataloaders
# Only the training loaders shuffle; validation and test loaders keep a fixed
# order so that repeated evaluation runs iterate over identical batches.
train_loader_gray = _make_loader(train_set_gray, shuffle=True)
val_loader_gray = _make_loader(val_set_gray, shuffle=False)
test_loader_gray = _make_loader(test_set_gray, shuffle=False)

train_loader_color = _make_loader(train_set_color, shuffle=True)
val_loader_color = _make_loader(val_set_color, shuffle=False)
test_loader_color = _make_loader(test_set_color, shuffle=False)


MNIST dataset already exists
MNIST dataset already exists
MNIST dataset already exists
MNIST dataset already exists
MNIST dataset already exists
MNIST dataset already exists


In [11]:
def zeroshot_classifier(classnames, class_template):
    """Build zero-shot classification weights from text embeddings.

    Every class name is substituted into the prompt template(s), encoded with
    the text encoder of the global ``model``, normalised to unit length,
    averaged over the prompts of that class and normalised again.

    Args:
        classnames: iterable of class names. Each name is formatted into the
            template as a whole, so multi-character names are supported.
        class_template: a ``str.format`` template with one ``{}`` placeholder,
            or a list/tuple of such templates (their embeddings are averaged).

    Returns:
        A tensor holding one column per class (embeddings stacked along
        ``dim=1``), placed on the device of the model's parameters.
    """
    if isinstance(class_template, str):
        templates = [class_template]
    else:
        templates = list(class_template)

    # Follow the model's device instead of assuming CUDA is present.
    device = next(model.parameters()).device

    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            # Format the complete class name. Iterating over the name itself
            # would create one prompt per character (e.g. "1" and "0" for "10").
            texts = [template.format(classname) for template in templates]
            texts = clip.tokenize(texts).to(device)  # tokenize
            class_embeddings = model.encode_text(texts)  # embed with text encoder
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding = class_embedding / class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights


# Build the text-side weights once and keep a second handle (res1) for the
# comparison with the step-by-step computation.
res1 = zeroshot_weights = zeroshot_classifier(mnist_classes, mnist_template)

  0%|          | 0/10 [00:00<?, ?it/s]

In [12]:
# Step-by-step version of the zero-shot weight computation. Every intermediate
# gets its own name (no in-place division) so it can be inspected afterwards.
with torch.no_grad():
    zeroshot_weights = []
    for classname in tqdm(mnist_classes):
        # One prompt per class: format the whole class name rather than
        # iterating over its characters.
        texts = [mnist_template.format(classname)]  # format with class
        texts = clip.tokenize(texts).cuda()  # tokenize
        class_embeddings = model.encode_text(texts)  # embed with text encoder
        class_embeddings2 = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings2.mean(dim=0)
        class_embedding2 = class_embedding / class_embedding.norm()
        zeroshot_weights.append(class_embedding2)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
res2 = zeroshot_weights

  0%|          | 0/10 [00:00<?, ?it/s]

In [13]:
# Both computations must agree element for element.
torch.eq(res1, res2).all()

tensor(True, device='cuda:0')

In [14]:
# Raw text-encoder output left over from the final iteration of the
# step-by-step loop (the last class in mnist_classes).
print(f"{class_embeddings.shape=}")
print(f"{class_embeddings=}")

class_embeddings.shape=torch.Size([1, 1024])
class_embeddings=tensor([[ 0.1313,  0.2695,  0.1028,  ..., -0.2388, -0.1234,  0.0644]],
       device='cuda:0', dtype=torch.float16)


In [15]:
# The same embedding after division by its norm along the last dimension.
print(f"{class_embeddings2.shape=}")
print(f"{class_embeddings2=}")

class_embeddings2.shape=torch.Size([1, 1024])
class_embeddings2=tensor([[ 0.0101,  0.0208,  0.0079,  ..., -0.0184, -0.0095,  0.0050]],
       device='cuda:0', dtype=torch.float16)


In [16]:
# Mean over dim 0 of the normalised embeddings; this removes the leading axis.
print(f"{class_embedding.shape=}")
print(f"{class_embedding=}")

class_embedding.shape=torch.Size([1024])
class_embedding=tensor([ 0.0101,  0.0208,  0.0079,  ..., -0.0184, -0.0095,  0.0050],
       device='cuda:0', dtype=torch.float16)


In [17]:
# Re-normalised mean embedding: the vector appended to zeroshot_weights.
print(f"{class_embedding2.shape=}")
print(f"{class_embedding2=}")

class_embedding2.shape=torch.Size([1024])
class_embedding2=tensor([ 0.0101,  0.0208,  0.0079,  ..., -0.0184, -0.0095,  0.0050],
       device='cuda:0', dtype=torch.float16)


In [18]:
def accuracy(output, target, topk=(1,)):
    """Count the top-k correct predictions in a batch.

    Args:
        output: 2-D tensor of scores, indexed as (sample, class).
        target: 1-D tensor holding the correct class index of every sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one float per entry of ``topk``: the number of samples
        (a count, not a fraction) whose target is among the k highest-scoring
        classes. A k larger than the number of classes is treated as
        "all classes".
    """
    # topk() raises if asked for more entries than there are classes.
    max_k = min(max(topk), output.size(1))
    pred = output.topk(max_k, dim=1, largest=True, sorted=True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() returns a Python scalar directly; converting a 1-element NumPy
    # array with float() is deprecated in recent NumPy versions.
    return [float(correct[:k].sum().item()) for k in topk]

In [26]:
# transformation: the dataset yields tensors, while `preprocess` is applied to
# PIL images, so every image is converted back to PIL first
transform = transforms.ToPILImage()

with torch.no_grad():
    top1, top5, n = 0., 0., 0.
    for data in tqdm(val_loader_gray):
        images, ground_truth_label, low_high_label, color_label = data

        # preprocess image by image and re-assemble the batch directly as a
        # tensor (no detour through a NumPy array)
        images = torch.stack([preprocess(transform(img)) for img in images])

        images = images.cuda()
        ground_truth_label = ground_truth_label.cuda()

        # predict: similarity between normalised image features and the
        # per-class text weights, scaled by 100
        image_features = model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = 100. * image_features @ zeroshot_weights

        # measure accuracy (accuracy() returns counts, summed over batches)
        acc1, acc5 = accuracy(logits, ground_truth_label, topk=(1, 5))
        top1 += acc1
        top5 += acc5
        n += images.size(0)

# convert the accumulated counts into percentages
top1 = (top1 / n) * 100
top5 = (top5 / n) * 100

print(f"Top-1 accuracy: {top1:.2f}")
print(f"Top-5 accuracy: {top5:.2f}")

  0%|          | 0/79 [00:00<?, ?it/s]

Top-1 accuracy: 58.37
Top-5 accuracy: 91.16


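**Zero-shot results (CLIP RN50, grayscale MNIST)**, recorded from earlier runs; values can differ slightly between runs.

Training gray:  
Top-1 accuracy: 57.17  
Top-5 accuracy: 90.48

Validation gray:  
Top-1 accuracy: 58.35  
Top-5 accuracy: 91.16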


In [20]:
# Shape of the batch currently held in `images` (the evaluation loop leaves its last batch there).
images.shape

torch.Size([80, 3, 224, 224])

In [21]:
# Shape of the encoded image features from the most recent encode_image call.
image_features.shape

torch.Size([80, 1024])

In [22]:
# Text-side weights: class embeddings were stacked along dim=1, i.e. one column per class.
zeroshot_weights.shape

torch.Size([1024, 10])

In [23]:
# Fetch a single batch to see how many fields the loader yields per batch.
first_batch = next(iter(val_loader_gray))
print(len(first_batch))

4


In [24]:
# Demonstration: `preprocess` cannot be applied to a tensor batch directly.
# The expected error is caught and shown so that "Run All" does not stop here.
try:
    preprocess(images)
except AttributeError as err:
    print(f"AttributeError: {err}")

AttributeError: 'Tensor' object has no attribute 'convert'

In [None]:
from torchvision import transforms
# Convert one tensor image to PIL first, then run it through `preprocess`
# and look at the resulting shape.
transform = transforms.ToPILImage()

first_image_pil = transform(images[0])
preprocess(first_image_pil).shape

In [None]:
# Shape of the whole batch, for comparison with the single preprocessed image.
images.shape

In [None]:
# Stacking individually preprocessed images produces a batch tensor; here the
# first image is simply processed twice.
torch.stack([preprocess(transform(images[0])) for _ in range(2)]).shape