In [1]:
import clip
import torch

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

100%|███████████████████████████████████████| 338M/338M [00:51<00:00, 6.81MiB/s]


In [2]:
from torchvision import datasets
from torch.utils.data import DataLoader


# Check if GPU is available and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# if torch.cuda.device_count() > 1:
#     model = torch.nn.DataParallel(model)  # Wrap the model for multi-GPU support
model.to(device)

# Load CIFAR10 dataset
cifar10_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=preprocess)
cifar10_loader = DataLoader(cifar10_dataset, batch_size=32, shuffle=False)

# # For ImageNet, assuming the dataset is already in the ImageNet folder
# imagenet_dataset = datasets.ImageNet(root='./data', split='val', transform=transform)
# imagenet_loader = DataLoader(imagenet_dataset, batch_size=32, shuffle=False)

Files already downloaded and verified


In [4]:
import tqdm

In [None]:
def compute_embeddings(data_loader):
    all_embeddings = []
    all_labels = []
    
    # Set the model to evaluation mode
    model.eval()

    with torch.no_grad():
        for images, labels in tqdm.tqdm(data_loader, desc="Computing embeddings"):
            images = images.to(device)
            labels = labels.to(device)
            # Get the image features
            outputs = model.encode_image(images)

            all_embeddings.append(outputs)
            all_labels.append(labels)
    
    # Concatenate all embeddings
    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    return all_embeddings, all_labels

# Compute embeddings for CIFAR10
cifar10_embeddings, cifar10_labels = compute_embeddings(cifar10_loader)
print(f"CIFAR10 Embeddings Shape: {cifar10_embeddings.shape}")

Computing embeddings: 100%|██████████| 313/313 [00:15<00:00, 20.22it/s]


CIFAR10 Embeddings Shape: torch.Size([10000, 512])
Logistic Regression Accuracy on CIFAR10: 93.55%


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

x_train, y_train = cifar10_embeddings.cpu().numpy(), cifar10_labels.cpu().numpy()
x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.2, random_state=42)

# Train a logistic regression classifier
clf = LogisticRegression(max_iter=1000, n_jobs=-1)
clf.fit(x_train, y_train)
# Evaluate the classifier
accuracy = clf.score(x_test, y_test)
print(f"Logistic Regression Accuracy on CIFAR10: {accuracy * 100:.2f}%")

In [7]:
DATASETS = [
    'CIFAR10',
    'Flowers102',
    'Food101',
    'Country211',
    'GTSRB',
    'EuroSAT',
    'DTD',
    'STL10',
]
import os

for dataset in DATASETS:
    # Load the dataset
    dataset_loader = datasets.__getattribute__(dataset)(root='./data', download=True, transform=preprocess)
    dataset_loader = DataLoader(dataset_loader, batch_size=32, shuffle=False)

    # if os.path.exists(f"clip/{dataset}_embeddings.pt") and os.path.exists(f"clip/{dataset}_labels.pt"):
    #     print(f"{dataset} embeddings and labels already exist. Skipping...")
    #     continue

    print(f"Processing {dataset} dataset...")

    # Compute embeddings for the dataset
    dataset_embeddings, dataset_labels = compute_embeddings(dataset_loader)
    print(f"{dataset} Embeddings Shape: {dataset_embeddings.shape}")

    # Save the embeddings and labels
    torch.save(dataset_embeddings, f"clip/{dataset}_embeddings.pt")
    torch.save(dataset_labels, f"clip/{dataset}_labels.pt")
    print(f"Saved {dataset} embeddings and labels.")

Files already downloaded and verified
Processing CIFAR10 dataset...


Computing embeddings: 100%|██████████| 1563/1563 [01:21<00:00, 19.26it/s]


CIFAR10 Embeddings Shape: torch.Size([50000, 512])
Saved CIFAR10 embeddings and labels.
Processing Flowers102 dataset...


Computing embeddings: 100%|██████████| 32/32 [00:06<00:00,  4.65it/s]


Flowers102 Embeddings Shape: torch.Size([1020, 512])
Saved Flowers102 embeddings and labels.
Processing Food101 dataset...


Computing embeddings: 100%|██████████| 2368/2368 [06:48<00:00,  5.80it/s]


Food101 Embeddings Shape: torch.Size([75750, 512])
Saved Food101 embeddings and labels.
Processing Country211 dataset...


Computing embeddings: 100%|██████████| 990/990 [03:24<00:00,  4.85it/s]


Country211 Embeddings Shape: torch.Size([31650, 512])
Saved Country211 embeddings and labels.
Processing GTSRB dataset...


Computing embeddings: 100%|██████████| 833/833 [00:46<00:00, 18.02it/s]


GTSRB Embeddings Shape: torch.Size([26640, 512])
Saved GTSRB embeddings and labels.
Processing EuroSAT dataset...


Computing embeddings: 100%|██████████| 844/844 [00:49<00:00, 17.20it/s]


EuroSAT Embeddings Shape: torch.Size([27000, 512])
Saved EuroSAT embeddings and labels.
Processing DTD dataset...


Computing embeddings: 100%|██████████| 59/59 [00:11<00:00,  5.00it/s]


DTD Embeddings Shape: torch.Size([1880, 512])
Saved DTD embeddings and labels.
Files already downloaded and verified
Processing STL10 dataset...


Computing embeddings: 100%|██████████| 157/157 [00:09<00:00, 17.16it/s]

STL10 Embeddings Shape: torch.Size([5000, 512])
Saved STL10 embeddings and labels.





In [8]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

for dataset in tqdm.tqdm(DATASETS):
    x_train = torch.load(f'clip/{dataset}_embeddings.pt').cpu()
    y_train = torch.load(f'clip/{dataset}_labels.pt').cpu()

    print(x_train.shape, y_train.shape, dataset)

    x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.5, random_state=42)
    x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42)

    # learn a classifier
    clf = LogisticRegression(max_iter=1000, n_jobs=-1, penalty='l2', C=1.0)
    clf.fit(x_train, y_train)
    print(f"Train accuracy: {(clf.score(x_train, y_train))}")
    print(f"Validation accuracy: {(clf.score(x_val, y_val))}")
    print(f"Test accuracy: {(clf.score(x_test, y_test))}")

  x_train = torch.load(f'clip/{dataset}_embeddings.pt').cpu()
  y_train = torch.load(f'clip/{dataset}_labels.pt').cpu()


torch.Size([50000, 512]) torch.Size([50000]) CIFAR10
Train accuracy: 0.9787
Validation accuracy: 0.9578


 12%|█▎        | 1/8 [00:16<01:52, 16.08s/it]

Test accuracy: 0.94472
torch.Size([1020, 512]) torch.Size([1020]) Flowers102


  x_train = torch.load(f'clip/{dataset}_embeddings.pt').cpu()
  y_train = torch.load(f'clip/{dataset}_labels.pt').cpu()
 25%|██▌       | 2/8 [00:17<00:44,  7.50s/it]

Train accuracy: 1.0
Validation accuracy: 0.7843137254901961
Test accuracy: 0.7411764705882353
torch.Size([75750, 512]) torch.Size([75750]) Food101


  x_train = torch.load(f'clip/{dataset}_embeddings.pt').cpu()
  y_train = torch.load(f'clip/{dataset}_labels.pt').cpu()
 25%|██▌       | 2/8 [01:04<03:12, 32.09s/it]


KeyboardInterrupt: 

In [15]:
def compute_embeddings(data_loader):
    all_embeddings = []
    all_labels = []
    
    # Set the model to evaluation mode
    model.eval()

    with torch.no_grad():
        for images, labels in tqdm.tqdm(data_loader, desc="Computing embeddings"):
            images = images.to(device)
            labels = labels.to(device)
            # Preprocess the images and get embeddings
            inputs = processor(images=images, return_tensors="pt", padding=True)
            inputs = {key: value.to(device) for key, value in inputs.items()}

            outputs = model.get_image_features(**inputs)

            all_embeddings.append(outputs)
            all_labels.append(labels)
    
    # Concatenate all embeddings
    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    return all_embeddings, all_labels

# Load CIFAR10 dataset
cifar10_dataset = datasets.Flowers102(root='./data', download=True, transform=transform)
cifar10_loader = DataLoader(cifar10_dataset, batch_size=32, shuffle=False)

# Compute embeddings for CIFAR10
cifar10_embeddings, cifar10_labels = compute_embeddings(cifar10_loader)
print(f"CIFAR10 Embeddings Shape: {cifar10_embeddings.shape}")


Computing embeddings: 100%|██████████| 32/32 [00:13<00:00,  2.42it/s]

CIFAR10 Embeddings Shape: torch.Size([1020, 512])



