In [None]:
!pip install torch transformers datasets tqdm ipywidgets hf_transfer

In [None]:
# Install the third-party packages this notebook relies on; the %pip magic
# targets the environment of the running kernel.
%pip install torch transformers datasets tqdm ipywidgets hf_transfer

In [None]:
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
import torch

In [None]:
# Pick the best available accelerator instead of assuming CUDA, so the
# notebook also runs on Apple-silicon (MPS) and CPU-only machines.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# CLIP ViT-B/32 checkpoint, loaded in float32 and moved to the chosen device.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", dtype=torch.float32)
model.to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)

# Hugging Face Hub repository ids, in the same order as `dataset_name`.
ds_load_name = [
    "tanganke/dtd",
    "tanganke/eurosat",
    "tanganke/gtsrb",
    "ylecun/mnist",
    "tanganke/resisc45",
    "tanganke/stanford_cars",
    "tanganke/sun397",
    "ufldl-stanford/svhn",
]
# Short names used for the on-disk output directories (data/<name>/...).
dataset_name = [
    "DTD",
    "EuroSAT",
    "GTSRB",
    "MNIST",
    "RESISC45",
    "Stanford_Cars",
    "SUN397",
    "SVHN",
]
# The two lists are consumed pairwise, so they must stay aligned.
assert len(ds_load_name) == len(dataset_name), "dataset lists are out of sync"

In [None]:
# Directory used as the Hugging Face `datasets` download/cache location.
HF_CACHE_DIR = "/workspace/.hf_cache"
# Every text prompt is padded/truncated to this many tokens.
MAX_TEXT_LENGTH = 77


def encode_image_text(samples):
    """Run the CLIP processor over a batch holding "text" and "image" columns."""
    return processor(
        text=samples["text"],
        images=samples["image"],
        padding="max_length",
        max_length=MAX_TEXT_LENGTH,
        truncation=True,
        return_tensors="pt",
    )


def encode_image(samples):
    """Run the CLIP processor over a batch holding only an "image" column."""
    return processor(images=samples["image"], return_tensors="pt")


# For each dataset: build train/val/test splits, preprocess them with the CLIP
# processor, save them under data/<name>/, and store normalized text embeddings
# of all class prompts.
for load_name, name in zip(ds_load_name, dataset_name):
    # SVHN needs an explicit configuration name; the other datasets do not.
    config_args = ("cropped_digits",) if name == "SVHN" else ()
    ds = load_dataset(load_name, *config_args, cache_dir=HF_CACHE_DIR)

    features = ds["train"].features
    label_names = features["label"].names

    def add_heading(sample):
        """Attach the prompt "a photo of <class name>" as ``sample["text"]``."""
        # Look the name up outside the f-string: reusing the same quote
        # character inside an f-string is a SyntaxError before Python 3.12.
        class_name = label_names[sample["label"]]
        sample["text"] = f"a photo of {class_name}"
        return sample

    # Hold out 20% of the original training split for validation (fixed seed).
    split = ds["train"].train_test_split(train_size=0.8, seed=6)
    train = split["train"].map(add_heading)
    val = split["test"].map(add_heading)
    test = ds["test"]

    train_dataset = train.map(encode_image_text, batched=True)
    val_dataset = val.map(encode_image_text, batched=True)
    test_dataset = test.map(encode_image, batched=True)

    # Drop the raw inputs now that their processed versions are stored.
    train_dataset = train_dataset.remove_columns(["image", "text"])
    val_dataset = val_dataset.remove_columns(["image", "text"])
    test_dataset = test_dataset.remove_columns(["image"])

    train_dataset.save_to_disk(f"data/{name}/train")
    val_dataset.save_to_disk(f"data/{name}/val")
    test_dataset.save_to_disk(f"data/{name}/test")

    # Class prompts for the label embeddings are taken from the test split's
    # label feature. Rebinding `label_names` here is safe for `add_heading`
    # because the maps above have already run.
    features = ds["test"].features
    label_names = features["label"].names
    all_label_texts = [f"a photo of {label}" for label in label_names]
    label_encodings = processor(
        text=all_label_texts,
        padding="max_length",
        max_length=MAX_TEXT_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    label_input_ids = label_encodings.input_ids.to(device)
    label_attention_mask = label_encodings.attention_mask.to(device)

    with torch.no_grad():
        all_label_embeds = model.get_text_features(
            input_ids=label_input_ids,
            attention_mask=label_attention_mask,
        )
        # L2-normalize each embedding along the feature dimension.
        all_label_embeds /= all_label_embeds.norm(dim=-1, keepdim=True)

    torch.save(all_label_embeds.cpu(), f"data/{name}/all_label_embeds.pt")