In [3]:
from torch.utils.data import Dataset
from PIL import Image
import os

class ImageTextDataset(Dataset):
    def __init__(self, image_dir, transform, tokenizer):
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = 20
        self.images = os.listdir(image_dir)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        
        # Load and preprocess image
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        
        # Generate text (class name from image file name)
        label = os.path.splitext(img_name)[0]  # Remove extension
        text = self.tokenizer(
            label,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )
        
        return image, text


In [3]:
from transformers import CLIPProcessor, CLIPModel
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


In [4]:
from torch.utils.data import DataLoader

image_dir = "./data/DAM"
dataset = ImageTextDataset(
    image_dir=image_dir,
    transform=processor.feature_extractor,
    tokenizer=processor.tokenizer
)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)




In [5]:
for batch in data_loader:
    images,text = batch 
    print(text)

('91GLO750X500C525', 'M0565PAUGXM01B', '241V52A1162X9000', 'CD04355X10040000', 'E1324TRIGMD024', 'M1286ZRIOM49E', '121V45A1212X9000', '248C89AL813X0200', 'M9203UMOSM02E', 'CD124BH1C0010000', '213P03A4077X5881', 'KDP916LNYS27G', 'B0077UVWVM49P', 'N1112ADRCYD908', '121B24A3356X0100', 'E1504TRICYD665', 'M500SOCYMXM253', 'M0538OCEAM900', '051B46A7020X5840', '211J57A3863X1879', '321P05A3607X9000', 'S2088UMMWM900', 'M0566ICBXM35U', 'M1286ZRHZM911', '021R53A7722X1851', 'KCP859SRVS16T', 'S0918OAYCM17A', 'B1330ADRMTD304', '146V19AF010X0201', 'KCP712SWLS66K', 'M500SOAUGXM66F', '011J22X8801X9000')
('S2012CNMJM41R', 'KCP261LSRS12U', '221J90A3846X5813', 'KCI781VEAS17X', '14LEO106I600C413', 'KCK311VEAS900', 'M0455BTIXM911', '022P01A3154X9374', 'CLUBM4UWR45A0', '140C13A0006X0200', 'M0500JAWAXM41G', 'E1146CDLCYD301', 'KCI717VEAS900', '151L05A1166X9000', 'M9333UMOFM45U', 'S0841OVRBM74P', 'M0566PCYMXM62P', 'S0617OVKKM85B', 'M500SPAAWXM821', 'S0007ONMJM81P', 'E1031ABCCYD301', 'M500SOAZEXM40M', '257J64A37

: 

In [5]:
from torch.nn.functional import cosine_similarity
import torch.nn as nn
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):  
    model.train()
    total_loss = 0

    for batch in data_loader:
        images, texts = batch
        
        images = images.to(device)
        texts = {key: value.squeeze(1).to(device) for key, value in texts.items()}
        
        outputs = model(**texts, pixel_values=images)
        logits_per_image = outputs.logits_per_image  # Image-text similarity
        logits_per_text = outputs.logits_per_text  # Text-image similarity
        
        labels = torch.arange(len(images)).to(device)
        
        loss = (criterion(logits_per_image, labels) + criterion(logits_per_text, labels)) / 2
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(data_loader)}")


RuntimeError: stack expects each tensor to be equal size, but got [1, 12] at entry 0 and [1, 13] at entry 1

Evaluation

In [None]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in data_loader:
        images, texts = batch
        
        images = images.to(device)
        texts = {key: value.squeeze(1).to(device) for key, value in texts.items()}
        
        outputs = model(**texts, pixel_values=images)
        logits_per_image = outputs.logits_per_image  # Image-text similarity
        
        preds = logits_per_image.argmax(dim=-1)
        labels = torch.arange(len(images)).to(device)
        
        correct += (preds == labels).sum().item()
        total += len(images)

print(f"Accuracy: {correct / total:.2%}")
