In [1]:
import torch
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

In [2]:
class ImageTextDataset(Dataset):
    def __init__(self, image_dir, processor, max_length=20):
        self.image_dir = image_dir
        self.processor = processor
        self.max_length = max_length
        self.images = os.listdir(image_dir)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        
        # Load and preprocess image
        image = Image.open(img_path).convert("RGB")
        image = self.processor.feature_extractor(image, return_tensors="pt")["pixel_values"].squeeze(0)
        
        # Generate text (class name from image file name)
        label = os.path.splitext(img_name)[0]  # Remove extension
        text = self.processor.tokenizer(
            label,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )
        
        return image, text


In [3]:
from transformers import CLIPProcessor, CLIPModel
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


In [4]:
from torch.utils.data import DataLoader

image_dir = "./data/DAM"
dataset = ImageTextDataset(
    image_dir=image_dir,
    processor=processor
)


In [5]:
def collate_fn(batch):
    images = torch.stack([item[0] for item in batch])  # Stack images
    texts = {key: torch.cat([item[1][key] for item in batch], dim=0) for key in batch[0][1]}  # Combine text components
    return images, texts


In [6]:
data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn
)


In [7]:
from torch.nn.functional import cosine_similarity
import torch.nn as nn
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):  
    model.train()
    total_loss = 0

    for batch in data_loader:
        images, texts = batch
        
        images = images.to(device)
        texts = {key: value.squeeze(1).to(device) for key, value in texts.items()}
        
        outputs = model(**texts, pixel_values=images)
        logits_per_image = outputs.logits_per_image  # Image-text similarity
        logits_per_text = outputs.logits_per_text  # Text-image similarity
        
        labels = torch.arange(len(images)).to(device)
        
        loss = (criterion(logits_per_image, labels) + criterion(logits_per_text, labels)) / 2
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(data_loader)}")


  attn_output = torch.nn.functional.scaled_dot_product_attention(


Epoch 1, Loss: 1.652546853512183
Epoch 2, Loss: 0.6499134613864723
Epoch 3, Loss: 0.36779743407307
Epoch 4, Loss: 0.26217492669820786
Epoch 5, Loss: 0.20387815919586982
Epoch 6, Loss: 0.164385296832556
Epoch 7, Loss: 0.142594280641997
Epoch 8, Loss: 0.10863340174078127
Epoch 9, Loss: 0.08928976021171815
Epoch 10, Loss: 0.09676356619256067


Evaluation

In [8]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in data_loader:
        images, texts = batch
        
        # Move tensors to the correct device
        images = images.to(device)
        texts = {key: value.squeeze(1).to(device) for key, value in texts.items()}
        
        # Decode the labels
        decoded_labels = [processor.tokenizer.decode(value[0], skip_special_tokens=True) for value in texts.values()]
        print(f"Decoded Labels: {decoded_labels}")
        
        # Model inference
        outputs = model(**texts, pixel_values=images)
        logits_per_image = outputs.logits_per_image  # Image-text similarity
        
        # Predictions
        preds = logits_per_image.argmax(dim=-1)
        labels = torch.arange(len(images)).to(device)
        
        # Print predicted and actual labels
        for idx, pred_idx in enumerate(preds):
            print(f"Image {idx + 1}:")
            print(f"  Actual: {decoded_labels[idx]}")
            print(f"  Predicted: {decoded_labels[pred_idx]}")
        
        # Accuracy calculation
        correct += (preds == labels).sum().item()
        total += len(images)

print(f"Accuracy: {correct / total:.2%}")


Decoded Labels: ['m 5 0 0 spaawxm 1 6 y', '""""""""""""!!!!!!!!']
Image 1:
  Actual: m 5 0 0 spaawxm 1 6 y
  Predicted: m 5 0 0 spaawxm 1 6 y
Image 2:
  Actual: """"""""""""!!!!!!!!
  Predicted: """"""""""""!!!!!!!!
Image 3:


IndexError: list index out of range