In [61]:
import os
import torch
import clip
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score
from PIL import Image

In [62]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIP model
model, preprocess = clip.load("ViT-B/32", device=device)

In [63]:
# Dataset class
class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

In [64]:
# Prepare data
super_dir = r'resources\super'
sadyek_dir = r'resources\sadyek'

super_images = [os.path.join(super_dir, f) for f in os.listdir(super_dir) if f.endswith('.bmp')]
sadyek_images = [os.path.join(sadyek_dir, f) for f in os.listdir(sadyek_dir) if f.endswith('.bmp')]

train_size = 8
image_paths = super_images[:train_size] + sadyek_images[:train_size]
labels = [0] * train_size + [1] * train_size

# Dataset and DataLoader
dataset = CustomImageDataset(image_paths, labels, transform=preprocess)
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [65]:
context_length = 5
class LearnableContext(torch.nn.Module):
    def __init__(self, context_length, embedding_dim):
        super().__init__()
        self.context_tokens = torch.nn.Parameter(
            torch.randn(context_length, embedding_dim) * 0.02
        )  # Random initialization
    
    def forward(self, class_embeddings):
        # Concatenate learnable context with class embeddings
        return torch.cat([self.context_tokens, class_embeddings], dim=0)

In [None]:
# Initialize learnable context
context = LearnableContext(context_length, model.token_embedding.weight.shape[1]).to(device)
# Prepare class embeddings
class_texts = ["A dried fig with a smooth, unbroken surface.", 
               "A dried fig with partially opened segments exposing its interior."]
class_embeddings = model.encode_text(clip.tokenize(class_texts).to(device))

512
torch.Size([5, 512])


In [67]:
# Optimizer and loss function
optimizer = torch.optim.Adam([context.context_tokens], lr=5*1e-2, weight_decay=1e-4)
criterion = torch.nn.CrossEntropyLoss()

In [68]:
# Training loop
num_epochs = 12
for epoch in range(num_epochs):
    model.eval()  # Freeze CLIP model
    context.train()  # Train only the context tokens

    for images, targets in data_loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()

        # Encode images and context-augmented text
        image_features = model.encode_image(images)
        text_features = context(class_embeddings)  # Augmented with learnable context

        # Calculate similarities and loss
        logits_per_image = torch.matmul(image_features, text_features.T)
        loss = criterion(logits_per_image, targets)

        loss.backward(retain_graph=True)
        optimizer.step()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/12], Loss: 8.2494
Epoch [2/12], Loss: 0.7945
Epoch [3/12], Loss: 0.1216
Epoch [4/12], Loss: 0.0192
Epoch [5/12], Loss: 0.0016
Epoch [6/12], Loss: 0.0121
Epoch [7/12], Loss: 0.0022
Epoch [8/12], Loss: 0.0010
Epoch [9/12], Loss: 0.0008
Epoch [10/12], Loss: 0.0026
Epoch [11/12], Loss: 0.0022
Epoch [12/12], Loss: 0.0051


In [69]:
def evaluate(model, context, data_loader):
    model.eval()
    context.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)

            # Encode images and context-augmented text
            image_features = model.encode_image(images)
            text_features = context(class_embeddings)

            logits_per_image = torch.matmul(image_features, text_features.T)
            _, preds = torch.max(logits_per_image, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    return accuracy

In [70]:
# Evaluate on all data
full_image_paths = super_images + sadyek_images
full_labels = [0] * len(super_images) + [1] * len(sadyek_images)
full_dataset = CustomImageDataset(full_image_paths, full_labels, transform=preprocess)
full_loader = DataLoader(full_dataset, batch_size=8, shuffle=False)

accuracy = evaluate(model, context, full_loader)
print(f"Accuracy: {accuracy * 100:.2f}%")

Accuracy: 83.90%
