In [None]:
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Fixed seed so the classifier-head initialisation and DataLoader shuffling
# (both draw from torch's global RNG) are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Load the pretrained CLIP model and its matching processor.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Freeze every CLIP parameter; only the layers re-enabled below will train.
for param in clip_model.parameters():
    param.requires_grad = False

# Unfreeze the top of the vision tower:
#   - vision_model.post_layernorm: the final LayerNorm of the vision encoder
#     (a normalisation layer with weight + bias, not an attention-pooling layer)
#   - visual_projection: the linear map from vision features into CLIP's
#     shared embedding space (it has only a weight, no bias).
for param in clip_model.vision_model.post_layernorm.parameters():
    param.requires_grad = True

for param in clip_model.visual_projection.parameters():
    param.requires_grad = True

# Sanity check: list exactly which parameters remain trainable.
for name, param in clip_model.named_parameters():
    if param.requires_grad:
        print("Will train:", name)

Will train: vision_model.post_layernorm.weight
Will train: vision_model.post_layernorm.bias
Will train: visual_projection.weight


In [None]:
class TrafficSignDataset(Dataset):
    """Map-style dataset yielding ``(RGB PIL image, integer label id)`` pairs.

    Args:
        image_dir: directory that contains the image files.
        ground_truth: mapping from image filename (relative to ``image_dir``)
            to its label name.
        label2id: mapping from label name to integer class id.
    """

    def __init__(self, image_dir, ground_truth, label2id):
        self.image_dir = image_dir
        self.ground_truth = ground_truth
        self.label2id = label2id
        # Freeze the sample order once (dict insertion order) so that an
        # index always refers to the same file.
        self.image_files = list(ground_truth.keys())

    def __len__(self):
        """Number of labelled images."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it with its class id.

        Raises:
            KeyError: if the image's label name is missing from ``label2id``.
            FileNotFoundError: if the image file does not exist.
        """
        fname = self.image_files[idx]
        label_id = self.label2id[self.ground_truth[fname]]
        image_path = os.path.join(self.image_dir, fname)
        # The context manager releases the file handle deterministically, even
        # if decoding fails. convert() returns a new, fully loaded image, so it
        # remains valid after the source file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        return image, label_id

In [None]:
# Load the filename -> label-name mapping. The encoding is explicit because
# label names may contain non-ASCII characters, and the platform default
# encoding is not guaranteed to be UTF-8.
with open("ground_truth.json", "r", encoding="utf-8") as f:
    ground_truth = json.load(f)

# Sorting makes the label -> id assignment deterministic across runs.
label_names = sorted(set(ground_truth.values()))
label2id = {name: idx for idx, name in enumerate(label_names)}
id2label = {idx: name for name, idx in label2id.items()}
NUM_CLASSES = len(label2id)


def collate_fn(batch):
    """Collate ``(PIL image, label id)`` pairs into ``(list of images, LongTensor of labels)``.

    Images are kept as a plain list so the CLIP processor can resize and
    normalise them per batch inside the training loop.
    """
    images, labels = zip(*batch)
    return list(images), torch.tensor(labels)


image_folder = "cn"

dataset = TrafficSignDataset(image_folder, ground_truth, label2id)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

# NOTE: `processor` and `clip_model` are reused from the first cell on purpose.
# Calling `CLIPModel.from_pretrained` again here would create a fresh model
# with every parameter trainable. That would silently undo the freezing done
# earlier, so the optimizer would update all CLIP weights.

In [10]:
class CLIPFineTuner(nn.Module):
    """Linear classification head stacked on CLIP image embeddings.

    Args:
        clip_model: a CLIP model exposing ``config.projection_dim`` and
            ``get_image_features(pixel_values=...)``.
        num_classes: number of output classes for the linear head.
    """

    def __init__(self, clip_model, num_classes):
        super().__init__()
        embed_dim = clip_model.config.projection_dim
        self.clip = clip_model
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, pixel_values):
        """Map preprocessed pixel values to class logits ``(batch, num_classes)``."""
        image_embeds = self.clip.get_image_features(pixel_values=pixel_values)
        logits = self.classifier(image_embeds)
        return logits

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = CLIPFineTuner(clip_model, num_classes=NUM_CLASSES)
model = model.to(device)

# Optimize only parameters that still require gradients: the new classifier
# head plus whatever was left unfrozen on `clip_model` before this cell.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-4,
)
loss_fn = nn.CrossEntropyLoss()

# Number of training epochs read by the training cell.
epoch = 10

In [11]:
# Train for `epoch` epochs. The loop variable must not be named `epoch`:
# that would overwrite the epoch count (leaving it at 9), so re-running this
# cell would silently train for fewer epochs.
for epoch_idx in range(epoch):
    model.train()
    running_loss = 0.0
    for images, labels in dataloader:
        # Images-only preprocessing. `padding` is a tokenizer option and has
        # no effect on pixel values, so it is not passed.
        inputs = processor(images=images, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(device)
        logits = model(pixel_values)
        loss = loss_fn(logits, labels.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss so training progress is visible.
    mean_loss = running_loss / max(len(dataloader), 1)
    print(f"epoch {epoch_idx + 1}/{epoch}: mean train loss {mean_loss:.4f}")


In [12]:
# Persist the complete fine-tuner state (CLIP backbone + classifier head).
checkpoint_state = model.state_dict()
torch.save(checkpoint_state, "clip_finetuned.pth")