In [1]:
pip install open-clip-torch

Collecting open-clip-torch
  Downloading open_clip_torch-2.32.0-py3-none-any.whl.metadata (31 kB)
Collecting ftfy (from open-clip-torch)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.9.0->open-clip-torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.9.0->open-clip-torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.9.0->open-clip-torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.9.0->open-clip-torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.9.0->open-clip-torch)
  Downloading nv

In [2]:
import tensorflow_datasets as tfds
import open_clip
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import numpy as np
import pandas as pd

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name="convnext_base_w",
    pretrained="laion2b_s13b_b82k",
    device=device
)
tokenizer = open_clip.get_tokenizer("convnext_base_w")

ds = tfds.load("stanford_dogs", split="test", as_supervised=True)
class_names = tfds.builder("stanford_dogs").info.features["label"].names

In [5]:
prompt_template = "a photo of a {}"
text_prompts = [prompt_template.format(label.replace("_", " ")) for label in class_names]

with torch.no_grad():
    text_tokens = tokenizer(text_prompts).to(device)
    text_features = model.encode_text(text_tokens)
    text_features /= text_features.norm(dim=-1, keepdim=True)
y_true = []
y_pred = []
for i, (image, label) in enumerate(tqdm(ds)):
    image = Image.fromarray(image.numpy())
    image = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        similarity = image_features @ text_features.T
        pred = similarity.argmax(dim=-1).item()

    y_true.append(label.numpy())
    y_pred.append(pred)

y_true = np.array(y_true)
y_pred = np.array(y_pred)

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None,zero_division=0)
overall_accuracy = accuracy_score(y_true, y_pred)
class_accuracy = [(y_true[y_true == i] == y_pred[y_true == i]).mean() for i in range(len(class_names))]

report_df = pd.DataFrame({
    "Class": class_names,
    "Accuracy": class_accuracy,
    "Precision": precision,
    "Recall": recall,
    "F1-score": f1,
})


print(f"\n Overall accuracy: {overall_accuracy:.4f}")
report_df_sorted_best = report_df.sort_values(by="F1-score", ascending=False)
print("The 5 classes with the highest F1-scores:")
print(report_df_sorted_best.head())
report_df_sorted_worst = report_df.sort_values(by="F1-score", ascending=True)
print("The 5 classes with the lowest F1-scores")
print(report_df_sorted_worst.head())

100%|██████████| 8580/8580 [03:57<00:00, 36.10it/s]


 Overall accuracy: 0.4965
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
119  n02116738-african_hunting_dog  0.898551   0.849315  0.898551  0.873239
106              n02111889-samoyed  0.771186   0.858491  0.771186  0.812500
96         n02109525-saint_bernard  0.714286   0.862069  0.714286  0.781250
109             n02112350-keeshond  0.758621   0.733333  0.758621  0.745763
62          n02100877-irish_setter  0.690909   0.808511  0.690909  0.745098
The 5 classes with the lowest F1-scores
                           Class  Accuracy  Precision  Recall  F1-score
53               n02098413-lhasa       0.0        0.0     0.0       0.0
44         n02096585-boston_bull       0.0        0.0     0.0       0.0
112           n02113186-cardigan       0.0        0.0     0.0       0.0
110  n02112706-brabancon_griffon       0.0        0.0     0.0       0.0
111           n02113023-pembroke       0.0        0.0     0.0       0.0





In [None]:
prompt_templates = [
    "a photo of a {}",
    "a blurry photo of a {}",
    "a close-up of a {}",
    "a picture of a {} dog",
    "a photo of {} dog",
    "this is a {} dog",
    "a {} breed",
    "a dog of breed {}",
]

results_summary = []

for template in prompt_templates:
    print(f"\nEvaluating prompt: \"{template}\"")
    text_prompts = [template.format(label.replace("_", " ")) for label in class_names]

    with torch.no_grad():
        text_tokens = tokenizer(text_prompts).to(device)
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    y_true = []
    y_pred = []

    for i, (image, label) in enumerate(tqdm(ds)):
        image = Image.fromarray(image.numpy())
        image = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            similarity = image_features @ text_features.T
            pred = similarity.argmax(dim=-1).item()

        y_true.append(label.numpy())
        y_pred.append(pred)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None,zero_division=0)
    macro_f1 = np.mean(f1)
    overall_accuracy = accuracy_score(y_true, y_pred)
    class_accuracy = [(y_true[y_true == i] == y_pred[y_true == i]).mean() for i in range(len(class_names))]

    report_df = pd.DataFrame({
        "Class": class_names,
        "Accuracy": class_accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1-score": f1,
    })

    print(f"\nAccuracy: {overall_accuracy:.4f}, Macro F1: {macro_f1:.4f}")
    report_df_sorted_best = report_df.sort_values(by="F1-score", ascending=False)
    print("The 5 classes with the highest F1-scores:")
    print(report_df_sorted_best.head())
    report_df_sorted_worst = report_df.sort_values(by="F1-score", ascending=True)
    print("The 5 classes with the lowest F1-scores")
    print(report_df_sorted_worst.head())

    results_summary.append({
        "Prompt": template,
        "Overall Accuracy": overall_accuracy,
        "Macro F1": macro_f1
    })



Evaluating prompt: "a photo of a {}"


100%|██████████| 8580/8580 [03:22<00:00, 42.32it/s]



Accuracy: 0.5021, Macro F1: 0.4776
The 5 classes with the highest F1-scores:
                              Class  Accuracy  Precision    Recall  F1-score
119   n02116738-african_hunting_dog  0.898551   0.925373  0.898551  0.911765
96          n02109525-saint_bernard  0.742857   0.945455  0.742857  0.832000
109              n02112350-keeshond  0.896552   0.776119  0.896552  0.832000
106               n02111889-samoyed  0.805085   0.833333  0.805085  0.818966
88   n02107683-bernese_mountain_dog  0.796610   0.752000  0.796610  0.773663
The 5 classes with the lowest F1-scores
                           Class  Accuracy  Precision  Recall  F1-score
41               n02096177-cairn       0.0        0.0     0.0       0.0
53               n02098413-lhasa       0.0        0.0     0.0       0.0
44         n02096585-boston_bull       0.0        0.0     0.0       0.0
110  n02112706-brabancon_griffon       0.0        0.0     0.0       0.0
111           n02113023-pembroke       0.0        0.0     0.

100%|██████████| 8580/8580 [03:13<00:00, 44.36it/s]



Accuracy: 0.5000, Macro F1: 0.4775
The 5 classes with the highest F1-scores:
                              Class  Accuracy  Precision    Recall  F1-score
119   n02116738-african_hunting_dog  0.913043   0.875000  0.913043  0.893617
96          n02109525-saint_bernard  0.757143   0.913793  0.757143  0.828125
106               n02111889-samoyed  0.762712   0.865385  0.762712  0.810811
62           n02100877-irish_setter  0.763636   0.857143  0.763636  0.807692
88   n02107683-bernese_mountain_dog  0.754237   0.801802  0.754237  0.777293
The 5 classes with the lowest F1-scores
                     Class  Accuracy  Precision  Recall  F1-score
17       n02090379-redbone       0.0        0.0     0.0       0.0
53         n02098413-lhasa       0.0        0.0     0.0       0.0
41         n02096177-cairn       0.0        0.0     0.0       0.0
44   n02096585-boston_bull       0.0        0.0     0.0       0.0
111     n02113023-pembroke       0.0        0.0     0.0       0.0

Evaluating prompt: "a c

100%|██████████| 8580/8580 [03:21<00:00, 42.49it/s]



Accuracy: 0.4749, Macro F1: 0.4520
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
119  n02116738-african_hunting_dog  0.927536   0.790123  0.927536  0.853333
27            n02092339-weimaraner  0.750000   0.803571  0.750000  0.775862
96         n02109525-saint_bernard  0.628571   1.000000  0.628571  0.771930
63         n02101006-gordon_setter  0.735849   0.795918  0.735849  0.764706
70   n02102973-irish_water_spaniel  0.780000   0.750000  0.780000  0.764706
The 5 classes with the lowest F1-scores
                     Class  Accuracy  Precision  Recall  F1-score
53         n02098413-lhasa       0.0        0.0     0.0       0.0
41         n02096177-cairn       0.0        0.0     0.0       0.0
89   n02107908-appenzeller       0.0        0.0     0.0       0.0
112     n02113186-cardigan       0.0        0.0     0.0       0.0
111     n02113023-pembroke       0.0        0.0     0.0       0.0

Evaluating prompt: "a picture

100%|██████████| 8580/8580 [03:21<00:00, 42.49it/s]



Accuracy: 0.4613, Macro F1: 0.4655
The 5 classes with the highest F1-scores:
                              Class  Accuracy  Precision    Recall  F1-score
119   n02116738-african_hunting_dog  0.898551   0.746988  0.898551  0.815789
96          n02109525-saint_bernard  0.814286   0.780822  0.814286  0.797203
106               n02111889-samoyed  0.703390   0.864583  0.703390  0.775701
108                  n02112137-chow  0.729167   0.823529  0.729167  0.773481
88   n02107683-bernese_mountain_dog  0.686441   0.826531  0.686441  0.750000
The 5 classes with the lowest F1-scores
                          Class  Accuracy  Precision    Recall  F1-score
65            n02101556-clumber  0.080000   0.007042  0.080000  0.012945
111          n02113023-pembroke  0.012346   0.035714  0.012346  0.018349
112          n02113186-cardigan  0.036364   0.019231  0.036364  0.025157
17            n02090379-redbone  0.229167   0.017460  0.229167  0.032448
114  n02113712-miniature_poodle  0.018182   0.333333  0

100%|██████████| 8580/8580 [03:19<00:00, 43.08it/s]



Accuracy: 0.4966, Macro F1: 0.5052
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
119  n02116738-african_hunting_dog  0.884058   0.824324  0.884058  0.853147
96         n02109525-saint_bernard  0.771429   0.843750  0.771429  0.805970
62          n02100877-irish_setter  0.727273   0.888889  0.727273  0.800000
27            n02092339-weimaraner  0.750000   0.849057  0.750000  0.796460
106              n02111889-samoyed  0.737288   0.861386  0.737288  0.794521
The 5 classes with the lowest F1-scores
                           Class  Accuracy  Precision    Recall  F1-score
110  n02112706-brabancon_griffon  0.000000   0.000000  0.000000  0.000000
112           n02113186-cardigan  0.018182   0.016949  0.018182  0.017544
73         n02105056-groenendael  0.020000   0.016949  0.020000  0.018349
65             n02101556-clumber  0.120000   0.012448  0.120000  0.022556
111           n02113023-pembroke  0.037037   0.032609  0

100%|██████████| 8580/8580 [03:21<00:00, 42.48it/s]



Accuracy: 0.5056, Macro F1: 0.5012
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
27            n02092339-weimaraner  0.833333   0.793651  0.833333  0.813008
106              n02111889-samoyed  0.771186   0.858491  0.771186  0.812500
108                 n02112137-chow  0.729167   0.909091  0.729167  0.809249
96         n02109525-saint_bernard  0.685714   0.923077  0.685714  0.786885
119  n02116738-african_hunting_dog  0.942029   0.670103  0.942029  0.783133
The 5 classes with the lowest F1-scores
                     Class  Accuracy  Precision    Recall  F1-score
111     n02113023-pembroke  0.024691   0.010638  0.024691  0.014870
65       n02101556-clumber  0.080000   0.012158  0.080000  0.021108
112     n02113186-cardigan  0.036364   0.066667  0.036364  0.047059
73   n02105056-groenendael  0.060000   0.054545  0.060000  0.057143
17       n02090379-redbone  0.104167   0.043103  0.104167  0.060976

Evaluating prompt

100%|██████████| 8580/8580 [03:19<00:00, 42.97it/s]



Accuracy: 0.5111, Macro F1: 0.5024
The 5 classes with the highest F1-scores:
                       Class  Accuracy  Precision    Recall  F1-score
96   n02109525-saint_bernard  0.728571   0.910714  0.728571  0.809524
108           n02112137-chow  0.697917   0.957143  0.697917  0.807229
27      n02092339-weimaraner  0.800000   0.800000  0.800000  0.800000
63   n02101006-gordon_setter  0.716981   0.844444  0.716981  0.775510
107     n02112018-pomeranian  0.672269   0.879121  0.672269  0.761905
The 5 classes with the lowest F1-scores
                    Class  Accuracy  Precision    Recall  F1-score
112    n02113186-cardigan  0.000000   0.000000  0.000000  0.000000
111    n02113023-pembroke  0.024691   0.006536  0.024691  0.010336
17      n02090379-redbone  0.041667   0.009709  0.041667  0.015748
65      n02101556-clumber  0.040000   0.013605  0.040000  0.020305
97   n02109961-eskimo_dog  0.020000   0.031250  0.020000  0.024390

Evaluating prompt: "a dog of breed {}"


100%|██████████| 8580/8580 [03:21<00:00, 42.49it/s]


Accuracy: 0.5016, Macro F1: 0.4947
The 5 classes with the highest F1-scores:
                              Class  Accuracy  Precision    Recall  F1-score
88   n02107683-bernese_mountain_dog  0.788136   0.853211  0.788136  0.819383
96          n02109525-saint_bernard  0.728571   0.910714  0.728571  0.809524
119   n02116738-african_hunting_dog  0.855072   0.728395  0.855072  0.786667
108                  n02112137-chow  0.687500   0.916667  0.687500  0.785714
107            n02112018-pomeranian  0.672269   0.898876  0.672269  0.769231
The 5 classes with the lowest F1-scores
                           Class  Accuracy  Precision    Recall  F1-score
111           n02113023-pembroke  0.000000   0.000000  0.000000  0.000000
112           n02113186-cardigan  0.000000   0.000000  0.000000  0.000000
65             n02101556-clumber  0.020000   0.003030  0.020000  0.005263
73         n02105056-groenendael  0.020000   0.038462  0.020000  0.026316
110  n02112706-brabancon_griffon  0.018868   0.250




In [None]:
from sklearn.metrics import f1_score
# Get the index of 'pembroke' class
pembroke_index = class_names.index("n02113023-pembroke")

# Original template for all other classes
base_template = "a photo of a {}"
base_prompts = [base_template.format(cls.replace("_", " ")) for cls in class_names]

# Pembroke-specific alternatives
pembroke_templates = [
    "a photo of a pembroke dog",
    "a photo of a pembroke welsh corgi",
    "this is a pembroke corgi",
    "a cute pembroke corgi dog",
    "an image of a pembroke welsh corgi breed",
    "a smiling pembroke corgi"
]

results = []

for pembroke_prompt in pembroke_templates:
    # Replace only pembroke prompt
    modified_prompts = base_prompts.copy()
    modified_prompts[pembroke_index] = pembroke_prompt

    with torch.no_grad():
        text_tokens = tokenizer(modified_prompts).to(device)
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    y_true = []
    y_pred = []

    for i, (image, label) in enumerate(tqdm(ds)):
        image = Image.fromarray(image.numpy())
        image = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            similarity = image_features @ text_features.T
            pred = similarity.argmax(dim=-1).item()

        y_true.append(label.numpy())
        y_pred.append(pred)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    is_pembroke = y_true == pembroke_index
    pembroke_f1 = f1_score(y_true[is_pembroke], y_pred[is_pembroke], average='macro')

    results.append({
        "Prompt": pembroke_prompt,
        "Pembroke F1 score": pembroke_f1
    })
print()
df_results = pd.DataFrame(results).sort_values(by="Pembroke F1 score", ascending=False)
print(df_results)


100%|██████████| 8580/8580 [03:21<00:00, 42.49it/s]
100%|██████████| 8580/8580 [03:16<00:00, 43.72it/s]
100%|██████████| 8580/8580 [03:18<00:00, 43.30it/s]
100%|██████████| 8580/8580 [03:18<00:00, 43.31it/s]
100%|██████████| 8580/8580 [03:21<00:00, 42.49it/s]
100%|██████████| 8580/8580 [03:18<00:00, 43.20it/s]


                                     Prompt  Pembroke F1 score
1         a photo of a pembroke welsh corgi           0.068531
2                  this is a pembroke corgi           0.067653
3                 a cute pembroke corgi dog           0.048780
4  an image of a pembroke welsh corgi breed           0.037940
5                  a smiling pembroke corgi           0.031016
0                 a photo of a pembroke dog           0.017363





In [9]:
pip install --upgrade datasets

Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
  Attempting uninstall: datasets
    Found existing installation: datasets 2.14.4
    Uninstalling datasets-2.14.4:
      Successfully uninstalled datasets-2.14.4
[31mERROR: pip's dependency r

In [1]:
from datasets import load_dataset
import open_clip
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from PIL import Image

dataset = load_dataset("maurice-fp/stanford-dogs")
train_ds = dataset["train"]
test_ds = dataset["test"]

device = "cuda" if torch.cuda.is_available() else "cpu"

model, _, preprocess = open_clip.create_model_and_transforms(
    model_name="convnext_base_w",
    pretrained="laion2b_s13b_b82k",
    device=device
)
# Freeze the CLIP model
model.eval()
for param in model.parameters():
    param.requires_grad = False

# Parameters
n_classes = 120
embed_dim = model.text_projection.shape[1]
lr = 1e-2
epochs = 5
batch_size = 32

# Define trainable class embeddings
class_embeddings = nn.Parameter(torch.randn(n_classes, embed_dim, device=device, requires_grad=True))

# Optimizer
optimizer = optim.Adam([class_embeddings], lr=lr)
criterion = nn.CrossEntropyLoss()
from torchvision import transforms

transform = preprocess  # Use CLIP's own preprocessing

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
      item = self.dataset[idx]
      image = transform(item["image"])
      return image, item["label"]

train_loader = DataLoader(ImageDataset(train_ds), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(ImageDataset(test_ds), batch_size=batch_size)
for epoch in range(epochs):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        class_emb_norm = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        logits = image_features @ class_emb_norm.T

        loss = criterion(logits, labels)
        total_loss += loss.item()

        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    acc = correct / total
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, Train Accuracy: {acc:.4f}")
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        image_features = model.encode_image(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        class_emb_norm = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        logits = image_features @ class_emb_norm.T

        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

test_accuracy = correct / total
print(f"\n Final Test Accuracy with Learnable Class Embeddings: {test_accuracy:.4f}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/6.37k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/455M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/321M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/8580 [00:00<?, ? examples/s]

100%|██████████| 375/375 [03:56<00:00,  1.58it/s]


Epoch 1/5, Loss: 1745.3861, Train Accuracy: 0.2643


100%|██████████| 375/375 [03:56<00:00,  1.59it/s]


Epoch 2/5, Loss: 1682.9276, Train Accuracy: 0.5918


100%|██████████| 375/375 [03:56<00:00,  1.59it/s]


Epoch 3/5, Loss: 1663.6519, Train Accuracy: 0.6608


100%|██████████| 375/375 [03:56<00:00,  1.59it/s]


Epoch 4/5, Loss: 1658.1632, Train Accuracy: 0.6830


100%|██████████| 375/375 [03:55<00:00,  1.59it/s]


Epoch 5/5, Loss: 1656.2735, Train Accuracy: 0.6768


100%|██████████| 269/269 [02:48<00:00,  1.60it/s]


 Final Test Accuracy with Learnable Class Embeddings: 0.7231





In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_dataset
import open_clip
from PIL import Image
from torchvision import transforms

# Load dataset
dataset = load_dataset("maurice-fp/stanford-dogs")
train_ds = dataset["train"]
test_ds = dataset["test"]

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name="convnext_base_w",
    pretrained="laion2b_s13b_b82k",
    device=device
)

# Freeze the image encoder
model.eval()
for param in model.visual.parameters():
    param.requires_grad = False

# Parameters
n_classes = 120
with torch.no_grad():
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    embed_dim = model.encode_image(dummy_input).shape[-1]
lr = 1e-3  # Slightly lower learning rate for FC layer
epochs = 5
batch_size = 32

# Define fully connected layer
fc_layer = nn.Linear(embed_dim, n_classes).to(device)

# Optimizer and loss
optimizer = optim.Adam(fc_layer.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# Dataset class
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset):
        self.dataset = hf_dataset
        self.transform = preprocess

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = self.transform(item["image"])
        return image, item["label"]

# Create data loaders
train_loader = DataLoader(ImageDataset(train_ds), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(ImageDataset(test_ds), batch_size=batch_size)

# Training loop
for epoch in range(epochs):
    fc_layer.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            image_features = model.encode_image(images)

        # Pass through FC layer
        logits = fc_layer(image_features)

        # Compute loss
        loss = criterion(logits, labels)
        total_loss += loss.item()

        # Predictions
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    acc = correct / total
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, Train Accuracy: {acc:.4f}")

# Test evaluation
fc_layer.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)

        image_features = model.encode_image(images)
        logits = fc_layer(image_features)

        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

test_accuracy = correct / total
print(f"\nFinal Test Accuracy with Fully Connected Layer: {test_accuracy:.4f}")

100%|██████████| 375/375 [03:58<00:00,  1.57it/s]


Epoch 1/5, Loss: 1010.6541, Train Accuracy: 0.5268


100%|██████████| 375/375 [03:55<00:00,  1.59it/s]


Epoch 2/5, Loss: 458.6247, Train Accuracy: 0.7759


100%|██████████| 375/375 [03:54<00:00,  1.60it/s]


Epoch 3/5, Loss: 325.9103, Train Accuracy: 0.8225


100%|██████████| 375/375 [03:55<00:00,  1.60it/s]


Epoch 4/5, Loss: 264.0177, Train Accuracy: 0.8469


100%|██████████| 375/375 [03:54<00:00,  1.60it/s]


Epoch 5/5, Loss: 226.2305, Train Accuracy: 0.8633


100%|██████████| 269/269 [02:47<00:00,  1.61it/s]


Final Test Accuracy with Fully Connected Layer: 0.7948





In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name="convnext_xxlarge",
    pretrained="laion2b_s34b_b82k_augreg_soup",
    device=device
)
tokenizer = open_clip.get_tokenizer("convnext_base_w")

ds = tfds.load("stanford_dogs", split="test", as_supervised=True)
class_names = tfds.builder("stanford_dogs").info.features["label"].names

prompt_template = "a photo of a {}"
text_prompts = [prompt_template.format(label.replace("_", " ")) for label in class_names]

with torch.no_grad():
    text_tokens = tokenizer(text_prompts).to(device)
    text_features = model.encode_text(text_tokens)
    text_features /= text_features.norm(dim=-1, keepdim=True)
y_true = []
y_pred = []
for i, (image, label) in enumerate(tqdm(ds)):
    image = Image.fromarray(image.numpy())
    image = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        similarity = image_features @ text_features.T
        pred = similarity.argmax(dim=-1).item()

    y_true.append(label.numpy())
    y_pred.append(pred)

y_true = np.array(y_true)
y_pred = np.array(y_pred)

precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None,zero_division=0)
overall_accuracy = accuracy_score(y_true, y_pred)
class_accuracy = [(y_true[y_true == i] == y_pred[y_true == i]).mean() for i in range(len(class_names))]

report_df = pd.DataFrame({
    "Class": class_names,
    "Accuracy": class_accuracy,
    "Precision": precision,
    "Recall": recall,
    "F1-score": f1,
})


print(f"\n Overall accuracy: {overall_accuracy:.4f}")
report_df_sorted_best = report_df.sort_values(by="F1-score", ascending=False)
print("The 5 classes with the highest F1-scores:")
print(report_df_sorted_best.head())
report_df_sorted_worst = report_df.sort_values(by="F1-score", ascending=True)
print("The 5 classes with the lowest F1-scores")
print(report_df_sorted_worst.head())

open_clip_pytorch_model.bin:   0%|          | 0.00/4.80G [00:00<?, ?B/s]

100%|██████████| 8580/8580 [22:53<00:00,  6.25it/s]



 Overall accuracy: 0.7076
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
119  n02116738-african_hunting_dog  1.000000   0.971831  1.000000  0.985714
63         n02101006-gordon_setter  0.943396   1.000000  0.943396  0.970874
116     n02113978-mexican_hairless  0.981818   0.931034  0.981818  0.955752
27            n02092339-weimaraner  0.983333   0.921875  0.983333  0.951613
62          n02100877-irish_setter  0.909091   0.980392  0.909091  0.943396
The 5 classes with the lowest F1-scores
                     Class  Accuracy  Precision  Recall  F1-score
41         n02096177-cairn      0.00   0.000000    0.00  0.000000
112     n02113186-cardigan      0.00   0.000000    0.00  0.000000
73   n02105056-groenendael      0.00   0.000000    0.00  0.000000
65       n02101556-clumber      0.00   0.000000    0.00  0.000000
97    n02109961-eskimo_dog      0.08   0.047619    0.08  0.059701


In [7]:
prompt_templates = [
    "a photo of a {}",
    "a blurry photo of a {}",
    "a close-up of a {}",
    "a picture of a {} dog",
    "a photo of {} dog",
    "this is a {} dog",
    "a {} breed",
    "a dog of breed {}",
]

results_summary = []

for template in prompt_templates:
    print(f"\nEvaluating prompt: \"{template}\"")
    text_prompts = [template.format(label.replace("_", " ")) for label in class_names]

    with torch.no_grad():
        text_tokens = tokenizer(text_prompts).to(device)
        text_features = model.encode_text(text_tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    y_true = []
    y_pred = []

    for i, (image, label) in enumerate(tqdm(ds)):
        image = Image.fromarray(image.numpy())
        image = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            similarity = image_features @ text_features.T
            pred = similarity.argmax(dim=-1).item()

        y_true.append(label.numpy())
        y_pred.append(pred)

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None,zero_division=0)
    macro_f1 = np.mean(f1)
    overall_accuracy = accuracy_score(y_true, y_pred)
    class_accuracy = [(y_true[y_true == i] == y_pred[y_true == i]).mean() for i in range(len(class_names))]

    report_df = pd.DataFrame({
        "Class": class_names,
        "Accuracy": class_accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1-score": f1,
    })

    print(f"\nAccuracy: {overall_accuracy:.4f}, Macro F1: {macro_f1:.4f}")
    report_df_sorted_best = report_df.sort_values(by="F1-score", ascending=False)
    print("The 5 classes with the highest F1-scores:")
    print(report_df_sorted_best.head())
    report_df_sorted_worst = report_df.sort_values(by="F1-score", ascending=True)
    print("The 5 classes with the lowest F1-scores")
    print(report_df_sorted_worst.head())

    results_summary.append({
        "Prompt": template,
        "Overall Accuracy": overall_accuracy,
        "Macro F1": macro_f1
    })



Evaluating prompt: "a photo of a {}"


100%|██████████| 8580/8580 [23:21<00:00,  6.12it/s]



Accuracy: 0.7066, Macro F1: 0.6785
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
119  n02116738-african_hunting_dog  0.985507   0.971429  0.985507  0.978417
96         n02109525-saint_bernard  0.971429   0.931507  0.971429  0.951049
62          n02100877-irish_setter  0.909091   0.980392  0.909091  0.943396
63         n02101006-gordon_setter  0.924528   0.960784  0.924528  0.942308
116     n02113978-mexican_hairless  0.981818   0.900000  0.981818  0.939130
The 5 classes with the lowest F1-scores
                     Class  Accuracy  Precision    Recall  F1-score
41         n02096177-cairn  0.000000   0.000000  0.000000  0.000000
112     n02113186-cardigan  0.000000   0.000000  0.000000  0.000000
73   n02105056-groenendael  0.000000   0.000000  0.000000  0.000000
65       n02101556-clumber  0.000000   0.000000  0.000000  0.000000
53         n02098413-lhasa  0.023256   0.222222  0.023256  0.042105

Evaluating prompt

100%|██████████| 8580/8580 [23:21<00:00,  6.12it/s]



Accuracy: 0.6725, Macro F1: 0.6335
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
119  n02116738-african_hunting_dog  1.000000   0.985714  1.000000  0.992806
96         n02109525-saint_bernard  0.971429   0.971429  0.971429  0.971429
118                n02115913-dhole  0.920000   0.978723  0.920000  0.948454
27            n02092339-weimaraner  1.000000   0.869565  1.000000  0.930233
106              n02111889-samoyed  0.898305   0.963636  0.898305  0.929825
The 5 classes with the lowest F1-scores
                           Class  Accuracy  Precision  Recall  F1-score
53               n02098413-lhasa       0.0        0.0     0.0       0.0
41               n02096177-cairn       0.0        0.0     0.0       0.0
89         n02107908-appenzeller       0.0        0.0     0.0       0.0
65             n02101556-clumber       0.0        0.0     0.0       0.0
110  n02112706-brabancon_griffon       0.0        0.0     0.0     

100%|██████████| 8580/8580 [23:08<00:00,  6.18it/s]



Accuracy: 0.6985, Macro F1: 0.6600
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
119  n02116738-african_hunting_dog  0.985507   0.971429  0.985507  0.978417
96         n02109525-saint_bernard  0.957143   0.985294  0.957143  0.971014
108                 n02112137-chow  0.947917   0.968085  0.947917  0.957895
56      n02099601-golden_retriever  0.960000   0.905660  0.960000  0.932039
102                  n02110958-pug  0.940000   0.903846  0.940000  0.921569
The 5 classes with the lowest F1-scores
                           Class  Accuracy  Precision  Recall  F1-score
53               n02098413-lhasa       0.0        0.0     0.0       0.0
41               n02096177-cairn       0.0        0.0     0.0       0.0
89         n02107908-appenzeller       0.0        0.0     0.0       0.0
65             n02101556-clumber       0.0        0.0     0.0       0.0
110  n02112706-brabancon_griffon       0.0        0.0     0.0     

100%|██████████| 8580/8580 [23:21<00:00,  6.12it/s]



Accuracy: 0.7338, Macro F1: 0.7147
The 5 classes with the highest F1-scores:
                              Class  Accuracy  Precision    Recall  F1-score
96          n02109525-saint_bernard  0.985714   0.971831  0.985714  0.978723
119   n02116738-african_hunting_dog  1.000000   0.907895  1.000000  0.951724
27             n02092339-weimaraner  0.950000   0.934426  0.950000  0.942149
88   n02107683-bernese_mountain_dog  0.949153   0.933333  0.949153  0.941176
62           n02100877-irish_setter  0.963636   0.913793  0.963636  0.938053
The 5 classes with the lowest F1-scores
                          Class  Accuracy  Precision    Recall  F1-score
17            n02090379-redbone  0.000000   0.000000  0.000000  0.000000
65            n02101556-clumber  0.000000   0.000000  0.000000  0.000000
114  n02113712-miniature_poodle  0.018182   1.000000  0.018182  0.035714
97         n02109961-eskimo_dog  0.080000   0.066667  0.080000  0.072727
112          n02113186-cardigan  0.072727   0.125000  0

100%|██████████| 8580/8580 [23:07<00:00,  6.18it/s]



Accuracy: 0.7324, Macro F1: 0.7099
The 5 classes with the highest F1-scores:
                              Class  Accuracy  Precision    Recall  F1-score
96          n02109525-saint_bernard  0.971429   0.971429  0.971429  0.971429
103              n02111129-leonberg  0.945455   0.990476  0.945455  0.967442
62           n02100877-irish_setter  0.963636   0.963636  0.963636  0.963636
88   n02107683-bernese_mountain_dog  0.932203   0.956522  0.932203  0.944206
109              n02112350-keeshond  1.000000   0.892308  1.000000  0.943089
The 5 classes with the lowest F1-scores
                      Class  Accuracy  Precision    Recall  F1-score
65        n02101556-clumber  0.000000   0.000000  0.000000  0.000000
111      n02113023-pembroke  0.012346   0.500000  0.012346  0.024096
17        n02090379-redbone  0.020833   0.100000  0.020833  0.034483
97     n02109961-eskimo_dog  0.060000   0.036145  0.060000  0.045113
15   n02089867-walker_hound  0.037736   0.400000  0.037736  0.068966

Evalu

100%|██████████| 8580/8580 [23:21<00:00,  6.12it/s]



Accuracy: 0.7242, Macro F1: 0.6970
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
119  n02116738-african_hunting_dog  1.000000   0.907895  1.000000  0.951724
27            n02092339-weimaraner  0.950000   0.950000  0.950000  0.950000
116     n02113978-mexican_hairless  0.945455   0.945455  0.945455  0.945455
96         n02109525-saint_bernard  0.957143   0.930556  0.957143  0.943662
62          n02100877-irish_setter  0.963636   0.898305  0.963636  0.929825
The 5 classes with the lowest F1-scores
                             Class  Accuracy  Precision    Recall  F1-score
65               n02101556-clumber  0.000000   0.000000  0.000000  0.000000
17               n02090379-redbone  0.020833   0.100000  0.020833  0.034483
114     n02113712-miniature_poodle  0.018182   0.500000  0.018182  0.035088
45   n02097047-miniature_schnauzer  0.018519   0.500000  0.018519  0.035714
112             n02113186-cardigan  0.036364  

100%|██████████| 8580/8580 [23:07<00:00,  6.19it/s]



Accuracy: 0.7154, Macro F1: 0.6852
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
96         n02109525-saint_bernard  0.942857   1.000000  0.942857  0.970588
27            n02092339-weimaraner  0.983333   0.936508  0.983333  0.959350
103             n02111129-leonberg  0.927273   0.990291  0.927273  0.957746
119  n02116738-african_hunting_dog  1.000000   0.907895  1.000000  0.951724
23    n02091467-norwegian_elkhound  0.916667   0.977778  0.916667  0.946237
The 5 classes with the lowest F1-scores
                          Class  Accuracy  Precision    Recall  F1-score
17            n02090379-redbone  0.000000        0.0  0.000000  0.000000
65            n02101556-clumber  0.000000        0.0  0.000000  0.000000
111          n02113023-pembroke  0.000000        0.0  0.000000  0.000000
112          n02113186-cardigan  0.018182        1.0  0.018182  0.035714
114  n02113712-miniature_poodle  0.018182        1.0  0.01818

100%|██████████| 8580/8580 [23:06<00:00,  6.19it/s]


Accuracy: 0.6809, Macro F1: 0.6528
The 5 classes with the highest F1-scores:
                             Class  Accuracy  Precision    Recall  F1-score
108                 n02112137-chow  0.979167   0.969072  0.979167  0.974093
119  n02116738-african_hunting_dog  1.000000   0.932432  1.000000  0.965035
96         n02109525-saint_bernard  0.928571   1.000000  0.928571  0.962963
27            n02092339-weimaraner  0.966667   0.950820  0.966667  0.958678
77              n02105505-komondor  0.944444   0.944444  0.944444  0.944444
The 5 classes with the lowest F1-scores
                            Class  Accuracy  Precision    Recall  F1-score
97           n02109961-eskimo_dog  0.000000   0.000000  0.000000  0.000000
65              n02101556-clumber  0.000000   0.000000  0.000000  0.000000
112            n02113186-cardigan  0.018182   0.041667  0.018182  0.025316
111            n02113023-pembroke  0.024691   0.250000  0.024691  0.044944
47   n02097209-standard_schnauzer  0.054545   0.428


