In [1]:
import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import Food101, OxfordIIITPet, StanfordCars, Flowers102
from tqdm import tqdm
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device, download_root="./clip")
root = "./data"

In [4]:
def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

In [53]:
# train = CIFAR10(root, download=True, train=True, transform=preprocess)
# test = CIFAR10(root, download=True, train=False, transform=preprocess)
# # Calculate the image features
# train_features, train_labels = get_features(train)
# test_features, test_labels = get_features(test)

# # Perform logistic regression
# classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
# classifier.fit(train_features, train_labels)

# # Evaluate using the logistic regression classifier
# predictions = classifier.predict(test_features)
# accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
# print(f"Accuracy = {accuracy:.3f}")


# text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in test.classes]).to(device)

# correct=0
# # Calculate features
# with torch.no_grad():
#     text_features = model.encode_text(text_inputs)
#     text_features /= text_features.norm(dim=-1, keepdim=True)
#     for images, labels in tqdm(DataLoader(test, batch_size=1)):
#         image_features = model.encode_image(images.to(device))
#         image_features /= image_features.norm(dim=-1, keepdim=True)
        
#         similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
#         values, indices = similarity[0].topk(1)
#         if labels.to('cpu') == indices.to('cpu') : correct+=1
        
# print(correct/len(test))

In [5]:
def run_CLIP_LR(train,test):

    # Calculate the image features
    train_features, train_labels = get_features(train)
    test_features, test_labels = get_features(test)

    # Perform logistic regression
    classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
    classifier.fit(train_features, train_labels)

    # Evaluate using the logistic regression classifier
    predictions = classifier.predict(test_features)
    accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
    print(f"Accuracy = {accuracy:.3f}")

In [6]:
def run_CLIP_Zero(model, test, classes):
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classes]).to(device)

    correct=0
    # Calculate features
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        for images, labels in tqdm(DataLoader(test, batch_size=1)):
            image_features = model.encode_image(images.to(device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
            
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            values, indices = similarity[0].topk(1)
            if labels.to('cpu') == indices.to('cpu') : correct+=1
            
    print(correct/len(test))

VITB:

In [6]:
train = CIFAR10(root, download=True, train=True, transform=preprocess)
test = CIFAR10(root, download=True, train=False, transform=preprocess)
run_CLIP_LR(train,test)
run_CLIP_Zero(model, test,test.classes)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 500/500 [01:02<00:00,  8.02it/s]
100%|██████████| 100/100 [00:12<00:00,  7.99it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:   54.8s finished


Accuracy = 95.000


100%|██████████| 10000/10000 [01:53<00:00, 88.45it/s]

0.8878





In [13]:
train1 = SVHN(root, download=True, split="train", transform=preprocess)
test1 = SVHN(root, download=True, split="test", transform=preprocess)
run_CLIP_LR(train1,test1)

Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat


100%|██████████| 733/733 [01:30<00:00,  8.11it/s]
100%|██████████| 261/261 [00:33<00:00,  7.68it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


Accuracy = 65.396


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  1.5min finished


In [6]:
train1 = SVHN(root, download=True, split="train", transform=preprocess)
test1 = SVHN(root, download=True, split="test", transform=preprocess)

Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat


In [None]:
_, idx = np.unique(test1.labels, return_index=True)
run_CLIP_Zero(model1, test1, test1.labels[np.sort(idx)])

  0%|          | 0/26032 [00:00<?, ?it/s]

100%|██████████| 26032/26032 [15:32<00:00, 27.91it/s]  

0.11885371850030732





In [None]:
train = Flowers102(root, download=True, split = 'train', transform=preprocess)
test = Flowers102(root, download=True, split = 'test', transform=preprocess)
run_CLIP_LR(train,test)

100%|██████████| 11/11 [00:06<00:00,  1.79it/s]
100%|██████████| 62/62 [00:35<00:00,  1.75it/s]


Accuracy = 93.316


In [74]:
train = Flowers102(root, download=True, split = 'train', transform=preprocess)
test = Flowers102(root, download=True, split = 'test', transform=preprocess)
classes = [
    'pink primrose',
    'hard-leaved pocket orchid',
    'canterbury bells',
    'sweet pea',
    'english marigold',
    'tiger lily',
    'moon orchid',
    'bird of paradise',
    'monkshood',
    'globe thistle',
    'snapdragon',
    "colt's foot",
    'king protea',
    'spear thistle',
    'yellow iris',
    'globe flower',
    'purple coneflower',
    'peruvian lily',
    'balloon flower',
    'giant white arum lily',
    'fire lily',
    'pincushion flower',
    'fritillary',
    'red ginger',
    'grape hyacinth',
    'corn poppy',
    'prince of wales feathers',
    'stemless gentian',
    'artichoke',
    'sweet william',
    'carnation',
    'garden phlox',
    'love in the mist',
    'mexican aster',
    'alpine sea holly',
    'ruby-lipped cattleya',
    'cape flower',
    'great masterwort',
    'siam tulip',
    'lenten rose',
    'barbeton daisy',
    'daffodil',
    'sword lily',
    'poinsettia',
    'bolero deep blue',
    'wallflower',
    'marigold',
    'buttercup',
    'oxeye daisy',
    'common dandelion',
    'petunia',
    'wild pansy',
    'primula',
    'sunflower',
    'pelargonium',
    'bishop of llandaff',
    'gaura',
    'geranium',
    'orange dahlia',
    'pink and yellow dahlia',
    'cautleya spicata',
    'japanese anemone',
    'black-eyed susan',
    'silverbush',
    'californian poppy',
    'osteospermum',
    'spring crocus',
    'bearded iris',
    'windflower',
    'tree poppy',
    'gazania',
    'azalea',
    'water lily',
    'rose',
    'thorn apple',
    'morning glory',
    'passion flower',
    'lotus',
    'toad lily',
    'anthurium',
    'frangipani',
    'clematis',
    'hibiscus',
    'columbine',
    'desert-rose',
    'tree mallow',
    'magnolia',
    'cyclamen',
    'watercress',
    'canna lily',
    'hippeastrum',
    'bee balm',
    'air plant',
    'foxglove',
    'bougainvillea',
    'camellia',
    'mallow',
    'mexican petunia',
    'bromelia',
    'blanket flower',
    'trumpet creeper',
    'blackberry lily',
]


In [75]:
run_CLIP_Zero(model, test, classes)

100%|██████████| 6149/6149 [01:33<00:00, 65.49it/s]

0.6602699625955439





In [38]:
train2 = CIFAR100(root, download=True, train=True, transform=preprocess)
test2 = CIFAR100(root, download=True,  train=False, transform=preprocess)
run_CLIP_LR(train2,test2)
run_CLIP_Zero(model, test2,test2.classes)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data\cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:15<00:00, 11085518.30it/s]


Extracting ./data\cifar-100-python.tar.gz to ./data
Files already downloaded and verified


100%|██████████| 500/500 [00:59<00:00,  8.46it/s]
100%|██████████| 100/100 [00:11<00:00,  8.50it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  3.4min finished


Accuracy = 80.030


100%|██████████| 10000/10000 [01:42<00:00, 98.01it/s]

0.617





In [None]:
train = OxfordIIITPet(root, download=True, transform=preprocess)
test = OxfordIIITPet(root, download=True, split = 'test', transform=preprocess)
run_CLIP_LR(train,test)
run_CLIP_Zero(model, test,test.classes)

100%|██████████| 37/37 [00:31<00:00,  1.19it/s]
100%|██████████| 37/37 [00:31<00:00,  1.19it/s]


Accuracy = 89.207


100%|██████████| 3669/3669 [00:52<00:00, 70.16it/s]

0.8440992095938948





res50

In [76]:
model1, preprocess1 = clip.load('RN50', device, download_root="./clip")

In [39]:
train = CIFAR10(root, download=True, train=True, transform=preprocess1)
test = CIFAR10(root, download=True, train=False, transform=preprocess1)
run_CLIP_LR(train,test)
run_CLIP_Zero(model1,test,test.classes)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 500/500 [01:02<00:00,  7.95it/s]
100%|██████████| 100/100 [00:13<00:00,  7.53it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:   57.4s finished


Accuracy = 95.000


100%|██████████| 10000/10000 [01:48<00:00, 92.00it/s]

0.6872





In [40]:
train1 = SVHN(root, download=True, split="train", transform=preprocess1)
test1 = SVHN(root, download=True, split="test", transform=preprocess1)
run_CLIP_LR(train1,test1)

Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat


100%|██████████| 733/733 [01:48<00:00,  6.75it/s]
100%|██████████| 261/261 [00:58<00:00,  4.46it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  2.2min finished


Accuracy = 65.396


In [43]:
_, idx = np.unique(test1.labels, return_index=True)
run_CLIP_Zero(model1, test1, test1.labels[np.sort(idx)])

  0%|          | 0/26032 [00:00<?, ?it/s]

100%|██████████| 26032/26032 [15:32<00:00, 27.91it/s]  

0.11885371850030732





In [42]:
train2 = CIFAR100(root, download=True, train=True, transform=preprocess1)
test2 = CIFAR100(root, download=True,  train=False, transform=preprocess1)
run_CLIP_LR(train2,test2)
run_CLIP_Zero(model1, test2,test2.classes)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 500/500 [01:19<00:00,  6.26it/s]
100%|██████████| 100/100 [00:15<00:00,  6.48it/s]
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  5.3min finished


Accuracy = 80.030


100%|██████████| 10000/10000 [05:11<00:00, 32.13it/s]

0.3902





In [None]:
train1 = OxfordIIITPet(root, download=True, transform=preprocess1)
test1 = OxfordIIITPet(root, download=True, split = 'test', transform=preprocess1)
run_CLIP_LR(train1,test1)
run_CLIP_Zero(model1, test1,test1.classes)

100%|██████████| 37/37 [00:16<00:00,  2.20it/s]
100%|██████████| 37/37 [00:17<00:00,  2.13it/s]


Accuracy = 83.156


100%|██████████| 3669/3669 [00:52<00:00, 70.48it/s]

0.8315617334423548





In [None]:
train1 = Flowers102(root, download=True, split = 'train', transform=preprocess1)
test1 = Flowers102(root, download=True, split = 'test', transform=preprocess1)
run_CLIP_LR(train1,test1)

  0%|          | 0/11 [00:00<?, ?it/s]

100%|██████████| 11/11 [00:06<00:00,  1.81it/s]
100%|██████████| 62/62 [00:36<00:00,  1.70it/s]


Accuracy = 82.095


In [77]:
train3 = Flowers102(root, download=True, split = 'train', transform=preprocess)
test3 = Flowers102(root, download=True, split = 'test', transform=preprocess)
classes = [
    'pink primrose',
    'hard-leaved pocket orchid',
    'canterbury bells',
    'sweet pea',
    'english marigold',
    'tiger lily',
    'moon orchid',
    'bird of paradise',
    'monkshood',
    'globe thistle',
    'snapdragon',
    "colt's foot",
    'king protea',
    'spear thistle',
    'yellow iris',
    'globe flower',
    'purple coneflower',
    'peruvian lily',
    'balloon flower',
    'giant white arum lily',
    'fire lily',
    'pincushion flower',
    'fritillary',
    'red ginger',
    'grape hyacinth',
    'corn poppy',
    'prince of wales feathers',
    'stemless gentian',
    'artichoke',
    'sweet william',
    'carnation',
    'garden phlox',
    'love in the mist',
    'mexican aster',
    'alpine sea holly',
    'ruby-lipped cattleya',
    'cape flower',
    'great masterwort',
    'siam tulip',
    'lenten rose',
    'barbeton daisy',
    'daffodil',
    'sword lily',
    'poinsettia',
    'bolero deep blue',
    'wallflower',
    'marigold',
    'buttercup',
    'oxeye daisy',
    'common dandelion',
    'petunia',
    'wild pansy',
    'primula',
    'sunflower',
    'pelargonium',
    'bishop of llandaff',
    'gaura',
    'geranium',
    'orange dahlia',
    'pink and yellow dahlia',
    'cautleya spicata',
    'japanese anemone',
    'black-eyed susan',
    'silverbush',
    'californian poppy',
    'osteospermum',
    'spring crocus',
    'bearded iris',
    'windflower',
    'tree poppy',
    'gazania',
    'azalea',
    'water lily',
    'rose',
    'thorn apple',
    'morning glory',
    'passion flower',
    'lotus',
    'toad lily',
    'anthurium',
    'frangipani',
    'clematis',
    'hibiscus',
    'columbine',
    'desert-rose',
    'tree mallow',
    'magnolia',
    'cyclamen',
    'watercress',
    'canna lily',
    'hippeastrum',
    'bee balm',
    'air plant',
    'foxglove',
    'bougainvillea',
    'camellia',
    'mallow',
    'mexican petunia',
    'bromelia',
    'blanket flower',
    'trumpet creeper',
    'blackberry lily',
]


In [78]:
run_CLIP_Zero(model1, test3, classes)

100%|██████████| 6149/6149 [01:26<00:00, 71.16it/s]

0.6579931696210766





RESNET 18 

In [7]:

# Initialize the Weight Transforms

# Initialize model
weights = ResNet18_Weights.DEFAULT
preprocess2 = weights.transforms()
model = resnet18(weights=weights)
for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(512, 10)


In [68]:
train = CIFAR10(root, download=True, train=True, transform=preprocess2)
test = CIFAR10(root, download=True, train=False, transform=preprocess2)

Files already downloaded and verified
Files already downloaded and verified


In [70]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model.to(device)

for epoch in range(5):
    for images, target in tqdm(DataLoader(train, 64)):
        outputs = model(images.to(device))
        loss = criterion(outputs, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

100%|██████████| 782/782 [01:45<00:00,  7.39it/s]
100%|██████████| 782/782 [01:48<00:00,  7.19it/s]
100%|██████████| 782/782 [01:48<00:00,  7.22it/s]
100%|██████████| 782/782 [01:47<00:00,  7.27it/s]
100%|██████████| 782/782 [01:48<00:00,  7.20it/s]


In [71]:
total = 0
correct = 0
model.eval()
with torch.no_grad():
    for inputs, labels in tqdm(DataLoader(test, 64)):
        outputs = model(inputs.to(device))
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(correct/total)

100%|██████████| 157/157 [00:21<00:00,  7.36it/s]

0.781





In [73]:
weights = ResNet18_Weights.DEFAULT
preprocess2 = weights.transforms()
model = resnet18(weights=weights)
for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(512, 100)


In [74]:
train = CIFAR100(root, download=True, train=True, transform=preprocess2)
test = CIFAR100(root, download=True, train=False, transform=preprocess2)

Files already downloaded and verified
Files already downloaded and verified


In [75]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model.to(device)

for epoch in range(5):
    for images, target in tqdm(DataLoader(train, 64)):
        outputs = model(images.to(device))
        loss = criterion(outputs, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

  0%|          | 0/782 [00:00<?, ?it/s]

100%|██████████| 782/782 [01:36<00:00,  8.09it/s]
100%|██████████| 782/782 [01:36<00:00,  8.09it/s]
100%|██████████| 782/782 [01:45<00:00,  7.40it/s]
100%|██████████| 782/782 [01:49<00:00,  7.17it/s]
100%|██████████| 782/782 [01:49<00:00,  7.16it/s]


In [76]:
total = 0
correct = 0
model.eval()
with torch.no_grad():
    for inputs, labels in tqdm(DataLoader(test, 64)):
        outputs = model(inputs.to(device))
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(correct/total)

100%|██████████| 157/157 [00:22<00:00,  6.97it/s]

0.517





In [77]:
weights = ResNet18_Weights.DEFAULT
preprocess2 = weights.transforms()
model = resnet18(weights=weights)
for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(512, 10)


In [79]:
train = SVHN(root, download=True, split="train", transform=preprocess2)
test = SVHN(root, download=True, split="test", transform=preprocess2)

Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat


In [80]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model.to(device)

for epoch in range(5):
    for images, target in tqdm(DataLoader(train, 64)):
        outputs = model(images.to(device))
        loss = criterion(outputs, target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    

100%|██████████| 1145/1145 [02:41<00:00,  7.09it/s]
100%|██████████| 1145/1145 [02:44<00:00,  6.96it/s]
100%|██████████| 1145/1145 [02:41<00:00,  7.10it/s]
100%|██████████| 1145/1145 [02:38<00:00,  7.23it/s]
100%|██████████| 1145/1145 [02:35<00:00,  7.36it/s]


In [81]:
total = 0
correct = 0
model.eval()
with torch.no_grad():
    for inputs, labels in tqdm(DataLoader(test, 64)):
        outputs = model(inputs.to(device))
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.to(device)).sum().item()

print(correct/total)

100%|██████████| 407/407 [00:54<00:00,  7.43it/s]

0.48060079901659497





Tip Adapter

In [36]:
def run_CLIP_Tip(model, test, classes, cache_img, cache_text, beta=1, alpha=5):
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classes]).to(device)
    cache_img, cache_text = cache_img.to(device), cache_text.to(device)

    correct=0
    # Calculate features
    with torch.no_grad():
        text_features = model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        for images, labels in tqdm(DataLoader(test, batch_size=1)):
            image_features = model.encode_image(images.to(device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
            
            clip_logits  = (100.0 * image_features @ text_features.T)
            similarity = clip_logits.softmax(dim=-1)
        
            affinity = image_features @ cache_img
            # print(image_features.shape)
            # print(((-1) * (beta - beta * affinity)).exp())
            cache_logits = ((-1) * (beta - beta * affinity)).exp() @ cache_text
            
            tip_logits = clip_logits + cache_logits * alpha
            values, indices = tip_logits[0].topk(1)
            if labels.to('cpu') == indices.to('cpu') : correct+=1
    print(correct/len(test))

In [42]:
def build_cache_model(clip_model, train_loader):
    cache_img = []
    cache_text = []

    with torch.no_grad():
        for i in range(1):
            train_features = []
            for images, target in tqdm(train_loader):
                images = images.cuda()
                image_features = clip_model.encode_image(images)
                train_features.append(image_features)
                # print(image_features.shape)
                cache_text.append(target)
            cache_img.append(torch.cat(train_features, dim=0).unsqueeze(0))
    cache_img = torch.cat(cache_img, dim=0).mean(dim=0)
    cache_img /= cache_img.norm(dim=-1, keepdim=True)
    cache_img = cache_img.permute(1, 0)
    cache_text = F.one_hot(torch.cat(cache_text, dim=0)).half()

    return cache_img, cache_text

In [14]:
cache_img = []
cache_text = []

with torch.no_grad():
    for i in range(1):
        train_features = []
        for images, target in tqdm(DataLoader(train,batch_size=1000)):
            images = images.cuda()
            image_features = model.encode_image(images)
            train_features.append(image_features)
            # print(image_features.shape)
            cache_text.append(target)
        cache_img.append(torch.cat(train_features, dim=0).unsqueeze(0))
cache_img = torch.cat(cache_img, dim=0).mean(dim=0)
cache_img /= cache_img.norm(dim=-1, keepdim=True)
print(cache_img.shape)
cache_img = cache_img.permute(1, 0)
cache_text = F.one_hot(torch.cat(cache_text, dim=0)).half()

100%|██████████| 50/50 [01:00<00:00,  1.22s/it]

torch.Size([50000, 512])





In [24]:
train = CIFAR10(root, download=True, train=True, transform=preprocess)
test = CIFAR10(root, download=True, train=False, transform=preprocess)
cache_img, cache_text = build_cache_model(model, train)
torch.save(cache_img, root + '/keys_' + "CIFAR10.pt")
torch.save(cache_text, root + '/values_' + "CIFAR10.pt")
run_CLIP_Tip(model, test, test.classes, cache_img, cache_text, beta=1, alpha=5)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 50/50 [01:05<00:00,  1.30s/it]
100%|██████████| 10000/10000 [02:06<00:00, 79.22it/s]

0.8671





In [25]:
run_CLIP_Tip(model, test, test.classes, cache_img, cache_text, beta=1, alpha=2)

100%|██████████| 10000/10000 [02:02<00:00, 81.69it/s]

0.8663





In [26]:
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True,  train=False, transform=preprocess)
cache_img, cache_text = build_cache_model(model, train)
torch.save(cache_img, root + '/keys_' + "CIFAR100.pt")
torch.save(cache_text, root + '/values_' + "CIFAR100.pt")
run_CLIP_Tip(model, test, test.classes, cache_img, cache_text, beta=1, alpha=5)

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 50/50 [01:07<00:00,  1.36s/it]
100%|██████████| 10000/10000 [02:06<00:00, 78.77it/s]

0.6143





In [27]:
run_CLIP_Tip(model, test, test.classes, cache_img, cache_text, beta=1, alpha=2)

100%|██████████| 10000/10000 [01:59<00:00, 83.64it/s]

0.6284





In [28]:
train = SVHN(root, download=True, split="train", transform=preprocess)
test = SVHN(root, download=True, split="test", transform=preprocess)
cache_img, cache_text = build_cache_model(model, train)
torch.save(cache_img, root + '/keys_' + "SVHN.pt")
torch.save(cache_text, root + '/values_' + "SVHN.pt")
run_CLIP_Tip(model, test, test.classes, cache_img, cache_text, beta=1, alpha=5)

Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat


100%|██████████| 74/74 [01:34<00:00,  1.27s/it]


AttributeError: 'SVHN' object has no attribute 'classes'

In [31]:
_, idx = np.unique(test.labels, return_index=True)
run_CLIP_Tip(model, test, test.labels[np.sort(idx)], cache_img, cache_text, beta=1, alpha=5)

100%|██████████| 26032/26032 [05:06<00:00, 85.05it/s]

0.1958743085433313





In [34]:
run_CLIP_Tip(model, test, test.labels[np.sort(idx)], cache_img, cache_text, beta=1, alpha=0.5)

100%|██████████| 26032/26032 [04:53<00:00, 88.72it/s] 

0.1958743085433313





Reduce samples

In [39]:
import torch
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

# Assuming you have a custom dataset (CustomDataset) inheriting from PyTorch Dataset class
# Function to create a custom data loader that loads a specified number of samples per class
def create_custom_dataloader(dataset, samples_per_class=10, batch_size=32, shuffle=True):
    # Create a dictionary to store indices of each class
    class_indices = defaultdict(list)

    # Populate class_indices with indices of each class in the dataset
    for idx, (_, label) in enumerate(dataset):
        class_indices[label].append(idx)

    # Select a fixed number of samples from each class
    selected_indices = []
    for class_idx in class_indices.values():
        selected_indices.extend(class_idx[:samples_per_class])

    # Create a sampler using the selected indices
    sampler = torch.utils.data.sampler.SubsetRandomSampler(selected_indices)

    # Create a DataLoader using the created sampler
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=shuffle
    )

    return dataloader


# Create a DataLoader that loads 10 samples per class


In [79]:
train = CIFAR10(root, download=True, train=True, transform=preprocess)
test = CIFAR10(root, download=True, train=False, transform=preprocess)

Files already downloaded and verified
Files already downloaded and verified


In [47]:
CIFAR10_Loader = create_custom_dataloader(train, samples_per_class=500, batch_size=32, shuffle=False)

In [48]:
cache_img, cache_text = build_cache_model(model, CIFAR10_Loader)

100%|██████████| 157/157 [00:07<00:00, 22.12it/s]


In [49]:
run_CLIP_Tip(model, test, test.classes, cache_img, cache_text, beta=1, alpha=5)

100%|██████████| 10000/10000 [01:43<00:00, 96.40it/s]

0.8733





In [80]:
for i in [1000,100,50]:
    print(f"samples_per_class is {i}")
    CIFAR10_Loader = create_custom_dataloader(train, samples_per_class=i, batch_size=32, shuffle=False)
    cache_img, cache_text = build_cache_model(model, CIFAR10_Loader)
    run_CLIP_Tip(model, test, test.classes, cache_img, cache_text, beta=1, alpha=5)

samples_per_class is 1000


100%|██████████| 313/313 [00:13<00:00, 22.51it/s]
100%|██████████| 10000/10000 [01:44<00:00, 95.51it/s]


0.871
samples_per_class is 100


100%|██████████| 32/32 [00:01<00:00, 20.07it/s]
100%|██████████| 10000/10000 [01:47<00:00, 92.65it/s]


0.8853
samples_per_class is 50


100%|██████████| 16/16 [00:01<00:00, 15.70it/s]
100%|██████████| 10000/10000 [01:53<00:00, 88.49it/s]

0.8929





In [82]:
for i in [25,10,5]:
    print(f"samples_per_class is {i}")
    CIFAR10_Loader = create_custom_dataloader(train, samples_per_class=i, batch_size=32, shuffle=False)
    cache_img, cache_text = build_cache_model(model, CIFAR10_Loader)
    run_CLIP_Tip(model, test, test.classes, cache_img, cache_text, beta=1, alpha=5)

samples_per_class is 25


100%|██████████| 8/8 [00:00<00:00, 12.57it/s]
100%|██████████| 10000/10000 [01:42<00:00, 97.94it/s]


0.8923
samples_per_class is 10


100%|██████████| 4/4 [00:00<00:00,  8.12it/s]
100%|██████████| 10000/10000 [01:48<00:00, 91.77it/s]


0.8973
samples_per_class is 5


100%|██████████| 2/2 [00:00<00:00, 11.96it/s]
100%|██████████| 10000/10000 [01:53<00:00, 87.95it/s]

0.9





In [83]:
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True,  train=False, transform=preprocess)
for i in [500,100,50,25,5]:
    print(f"samples_per_class is {i}")
    CIFAR100_Loader = create_custom_dataloader(train, samples_per_class=i, batch_size=32, shuffle=False)
    cache_img, cache_text = build_cache_model(model, CIFAR100_Loader)
    run_CLIP_Tip(model, test, test.classes, cache_img, cache_text, beta=1, alpha=5)

Files already downloaded and verified
Files already downloaded and verified
samples_per_class is 500


100%|██████████| 1563/1563 [01:12<00:00, 21.44it/s]
100%|██████████| 10000/10000 [02:01<00:00, 82.31it/s]


0.6142
samples_per_class is 100


100%|██████████| 313/313 [00:15<00:00, 19.94it/s]
100%|██████████| 10000/10000 [01:51<00:00, 89.82it/s]


0.6367
samples_per_class is 50


100%|██████████| 157/157 [00:07<00:00, 21.29it/s]
100%|██████████| 10000/10000 [01:44<00:00, 95.46it/s]


0.6594
samples_per_class is 25


100%|██████████| 79/79 [00:03<00:00, 20.07it/s]
100%|██████████| 10000/10000 [01:46<00:00, 94.00it/s]


0.6649
samples_per_class is 5


100%|██████████| 16/16 [00:01<00:00, 15.49it/s]
100%|██████████| 10000/10000 [01:52<00:00, 89.12it/s]

0.6416





In [85]:
train = SVHN(root, download=True, split="train", transform=preprocess)
test = SVHN(root, download=True, split="test", transform=preprocess)
for i in [1000,500,100,50,25,5]:
    print(f"samples_per_class is {i}")
    SVHN_loader = create_custom_dataloader(train, samples_per_class=i, batch_size=32, shuffle=False)
    cache_img, cache_text = build_cache_model(model, SVHN_loader)
    _, idx = np.unique(test.labels, return_index=True)
    run_CLIP_Tip(model, test, test.labels[np.sort(idx)], cache_img, cache_text, beta=1, alpha=5)

Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat
samples_per_class is 1000


100%|██████████| 313/313 [00:13<00:00, 22.79it/s]
100%|██████████| 26032/26032 [04:33<00:00, 95.14it/s] 


0.23528733866011065
samples_per_class is 500


100%|██████████| 157/157 [00:07<00:00, 20.99it/s]
100%|██████████| 26032/26032 [04:38<00:00, 93.39it/s] 


0.21293023970497849
samples_per_class is 100


100%|██████████| 32/32 [00:01<00:00, 20.87it/s]
100%|██████████| 26032/26032 [04:14<00:00, 102.18it/s]


0.15016133988936695
samples_per_class is 50


100%|██████████| 16/16 [00:00<00:00, 18.49it/s]
100%|██████████| 26032/26032 [04:15<00:00, 101.91it/s]


0.12150430239704979
samples_per_class is 25


100%|██████████| 8/8 [00:00<00:00, 14.97it/s]
100%|██████████| 26032/26032 [04:18<00:00, 100.82it/s]


0.122579901659496
samples_per_class is 5


100%|██████████| 2/2 [00:00<00:00,  5.11it/s]
100%|██████████| 26032/26032 [04:17<00:00, 101.18it/s]

0.10675322679778734





In [None]:
train = Flowers102(root, download=True, split = 'train', transform=preprocess)
test = Flowers102(root, download=True, split = 'test', transform=preprocess)
classes = [
    'pink primrose',
    'hard-leaved pocket orchid',
    'canterbury bells',
    'sweet pea',
    'english marigold',
    'tiger lily',
    'moon orchid',
    'bird of paradise',
    'monkshood',
    'globe thistle',
    'snapdragon',
    "colt's foot",
    'king protea',
    'spear thistle',
    'yellow iris',
    'globe flower',
    'purple coneflower',
    'peruvian lily',
    'balloon flower',
    'giant white arum lily',
    'fire lily',
    'pincushion flower',
    'fritillary',
    'red ginger',
    'grape hyacinth',
    'corn poppy',
    'prince of wales feathers',
    'stemless gentian',
    'artichoke',
    'sweet william',
    'carnation',
    'garden phlox',
    'love in the mist',
    'mexican aster',
    'alpine sea holly',
    'ruby-lipped cattleya',
    'cape flower',
    'great masterwort',
    'siam tulip',
    'lenten rose',
    'barbeton daisy',
    'daffodil',
    'sword lily',
    'poinsettia',
    'bolero deep blue',
    'wallflower',
    'marigold',
    'buttercup',
    'oxeye daisy',
    'common dandelion',
    'petunia',
    'wild pansy',
    'primula',
    'sunflower',
    'pelargonium',
    'bishop of llandaff',
    'gaura',
    'geranium',
    'orange dahlia',
    'pink and yellow dahlia',
    'cautleya spicata',
    'japanese anemone',
    'black-eyed susan',
    'silverbush',
    'californian poppy',
    'osteospermum',
    'spring crocus',
    'bearded iris',
    'windflower',
    'tree poppy',
    'gazania',
    'azalea',
    'water lily',
    'rose',
    'thorn apple',
    'morning glory',
    'passion flower',
    'lotus',
    'toad lily',
    'anthurium',
    'frangipani',
    'clematis',
    'hibiscus',
    'columbine',
    'desert-rose',
    'tree mallow',
    'magnolia',
    'cyclamen',
    'watercress',
    'canna lily',
    'hippeastrum',
    'bee balm',
    'air plant',
    'foxglove',
    'bougainvillea',
    'camellia',
    'mallow',
    'mexican petunia',
    'bromelia',
    'blanket flower',
    'trumpet creeper',
    'blackberry lily',
]
for i in [1000,500,100,50]:
    print(f"samples_per_class is {i}")
    SVHN_loader = create_custom_dataloader(train, samples_per_class=i, batch_size=32, shuffle=False)
    cache_img, cache_text = build_cache_model(model, SVHN_loader)

    run_CLIP_Tip(model, test, classes, cache_img, cache_text, beta=1, alpha=5)


100%|██████████| 26032/26032 [05:06<00:00, 85.05it/s]

0.1958743085433313



