In [1]:
import torch
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

In [2]:
class ImageTextDataset(Dataset):
    """Image/caption pairs built from a flat directory of image files.

    Each sample is ``(pixel_values, text)`` where ``pixel_values`` is the
    preprocessed image tensor (batch dimension squeezed away) and ``text`` is
    the tokenizer output for the file name without its extension, padded or
    truncated to ``max_length`` tokens.

    Args:
        image_dir: Directory that contains the image files.
        processor: Object exposing ``tokenizer`` and either ``image_processor``
            or the older ``feature_extractor`` (e.g. a ``CLIPProcessor``).
        max_length: Fixed token length used for padding/truncation.
    """

    # Lower-case file extensions accepted as images. Everything else found in
    # the directory (sub-folders, hidden files, notes, ...) is skipped so that
    # __getitem__ never tries to open a non-image entry.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff")

    def __init__(self, image_dir, processor, max_length=20):
        self.image_dir = image_dir
        self.processor = processor
        self.max_length = max_length
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
            and os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
        )

    def __len__(self):
        """Return the number of image files that were kept."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(pixel_values, text)`` for the image at position ``idx``."""
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)

        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to load and returns a new image, so the handle can be
        # released as soon as the with-block exits.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Prefer the current `image_processor` attribute and fall back to the
        # deprecated `feature_extractor` for older processor objects.
        image_processor = getattr(self.processor, "image_processor", None)
        if image_processor is None:
            image_processor = self.processor.feature_extractor
        image = image_processor(image, return_tensors="pt")["pixel_values"].squeeze(0)

        # The caption is the file name without its extension.
        label = os.path.splitext(img_name)[0]
        # The tokenizer output is returned as-is (not squeezed); the collate
        # step is expected to concatenate these along dim 0.
        text = self.processor.tokenizer(
            label,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )

        return image, text


In [3]:
from transformers import CLIPProcessor, CLIPModel
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained CLIP checkpoint and move its weights to the chosen device.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)

# Matching processor for the same checkpoint (used for images and text below).
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


In [4]:
from torch.utils.data import DataLoader

# Directory holding the image files; captions are derived from the file names.
image_dir = "./data/DAM"
dataset = ImageTextDataset(image_dir, processor)


In [5]:
def collate_fn(batch):
    """Merge a list of ``(image, text)`` samples into one batch.

    Args:
        batch: Sequence of samples; element 0 of each sample is an image
            tensor and element 1 is a mapping from field name to tensor.

    Returns:
        ``(images, texts)`` where ``images`` is the image tensors stacked
        along a new leading dimension and ``texts`` maps every field of the
        first sample to the concatenation (along dim 0) of that field across
        all samples.
    """
    pixel_tensors = []
    encodings = []
    for sample in batch:
        pixel_tensors.append(sample[0])
        encodings.append(sample[1])

    images = torch.stack(pixel_tensors)

    # The first sample's keys define which text fields are batched.
    texts = {}
    for key in encodings[0]:
        texts[key] = torch.cat([encoding[key] for encoding in encodings], dim=0)

    return images, texts


In [6]:
# Shuffled mini-batches of 32 samples, assembled by the custom collate_fn.
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)


In [7]:
from torch.nn.functional import cosine_similarity
import torch.nn as nn
import torch.optim as optim

# Fine-tune every model parameter with AdamW; cross-entropy is applied to the
# image->text and text->image similarity logits.
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    total_loss = 0

    for images, texts in data_loader:
        images = images.to(device)
        texts = {name: tensor.squeeze(1).to(device) for name, tensor in texts.items()}

        outputs = model(pixel_values=images, **texts)

        # Sample i's image is paired with sample i's text, so the target
        # class for row i of either logit matrix is i.
        labels = torch.arange(len(images)).to(device)

        image_to_text_loss = criterion(outputs.logits_per_image, labels)
        text_to_image_loss = criterion(outputs.logits_per_text, labels)
        loss = (image_to_text_loss + text_to_image_loss) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(data_loader)}")


  attn_output = torch.nn.functional.scaled_dot_product_attention(


Epoch 1, Loss: 1.5724437483425797
Epoch 2, Loss: 0.6113876354420322
Epoch 3, Loss: 0.30487859942789736
Epoch 4, Loss: 0.19437606679810876
Epoch 5, Loss: 0.16725467194685306
Epoch 6, Loss: 0.13028303123796467
Epoch 7, Loss: 0.10534767182704446
Epoch 8, Loss: 0.11539564929345901
Epoch 9, Loss: 0.09400252566469469
Epoch 10, Loss: 0.0796978617429562


Evaluation

In [11]:
# In-batch matching accuracy: for every image, check whether the highest-scoring
# text within the same batch is its own caption.
# NOTE(review): this iterates the same shuffled `data_loader` used for training,
# so the number reflects fit on the training data, not held-out performance.
# NOTE(review): the "Decoded Labels" lines in the saved output appear to come
# from an earlier run that had debug prints enabled — confirm before relying on them.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, texts in data_loader:
        # Move the batch to the same device as the model.
        images = images.to(device)
        texts = {name: tensor.squeeze(1).to(device) for name, tensor in texts.items()}

        outputs = model(pixel_values=images, **texts)
        logits_per_image = outputs.logits_per_image

        # Row i should peak at column i (image i matches text i).
        preds = logits_per_image.argmax(dim=-1)
        labels = torch.arange(len(images)).to(device)

        hits = preds == labels
        correct += hits.sum().item()
        total += len(images)

print(f"Accuracy: {correct / total:.2%}")


Decoded Labels: 9 4 3 t 0 5 cp 4 3 0 x 0 8 5 4
Decoded Labels: 0 1 4 s 5 9 am 0 0 5 x 4 2 2 0
Decoded Labels: m 1 2 6 5 zrzbm 8 8 4
Decoded Labels: kcb 5 3 5 pcos 1 7 x
Decoded Labels: 0 1 2 j 0 3 a 3 2 3 6 x 0 8 3 5
Decoded Labels: 0 2 jhb 0 7 0 i 6 1 2 c 0 8 4
Decoded Labels: m 0 5 6 5 stbhm 0 4 2
Decoded Labels: s 7 4 0 8 csbwm 9 1 1
Decoded Labels: m 0 5 3 1 opbim 3 1 e
Decoded Labels: kdq 7 8 8 sqps 9 0 0
Decoded Labels: 1 5 6 m 8 8 af 0 1 0 x 1 4 6 0
Decoded Labels: s 9 2 1 9 pmetm 1 9 6
Decoded Labels: m 5 3 1 soaawxm 7 9 b
Decoded Labels: kcq 2 4 4 lucs 7 0 k
Decoded Labels: m 9 2 2 0 uwbvm 9 1 8
Decoded Labels: e 1 5 8 8 trilqd 3 0 7
Decoded Labels: 2 4 1 l 0 6 a 6 1 0 2 x 9 0 0 0
Decoded Labels: cd 1 3 3 5 6 za 0 0 4 0 0 0 0
Decoded Labels: m 0 4 4 7 cangm 2 2 l
Decoded Labels: s 0 8 5 6 owcbm 9 0 0
Decoded Labels: b 0 9 6 1 adrcod 3 5 6
Decoded Labels: b 1 6 7 7 wommtd 3 0 0
Decoded Labels: m 9 3 2 7 umolm 8 1 p
Decoded Labels: m 0 6 8 6 wjiyxm 8 8 5
Decoded Labels: cdbc 0 4