In [None]:

import os
import torch
import clip
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score  # Import accuracy_score
import numpy as np
from PIL import Image

In [291]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [292]:
class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

In [293]:

# Prepare data
super_dir = r'resources\super'
sadyek_dir = r'resources\sadyek'

In [None]:
#importing image
super_images = [os.path.join(super_dir, f) for f in os.listdir(super_dir) if f.endswith('.bmp')]
sadyek_images = [os.path.join(sadyek_dir, f) for f in os.listdir(sadyek_dir) if f.endswith('.bmp')]
train_size =8
image_paths = super_images[:train_size] + sadyek_images[:train_size]
labels = [0] * len(super_images[:train_size]) + [1] * len(sadyek_images[:train_size])  # 0 for super, 1 for sadyek


In [None]:
#load data
dataset = CustomImageDataset(image_paths, labels, transform=preprocess)
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)


In [None]:
# Create text labels
text_labels = ["A dried fig with a smooth, unbroken surface, slightly flattened and firm.",
 "A dried fig with partially opened, darker segments, exposing its rich, chewy interior, slightly deformed yet irresistibly delicious."]
text_inputs = torch.cat([clip.tokenize(label) for label in text_labels]).to(device)

In [297]:


# Fine-tuning setup
optimizer = torch.optim.Adam(model.parameters(), lr=.000003, weight_decay=1e-4)
# criterion = torch.nn.BCEWithLogitsLoss()
criterion = torch.nn.CrossEntropyLoss()

In [None]:

num_epochs = 10
cnt =0
for epoch in range(num_epochs):
    model.train()   
    for i, (images, targets) in enumerate(data_loader):
        optimizer.zero_grad()
        images, targets = images.to(device), targets.to(device)
        # Encode images and text
        image_features = model.encode_image(images).to(device)
        text_features = model.encode_text(text_inputs).to(device)

        logits_per_image = torch.matmul(image_features,text_features.T).to(device)
        #compute loss
        loss = criterion(logits_per_image, targets)

        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


Epoch [1/10], Loss: 1.2269
Epoch [2/10], Loss: 0.7897
Epoch [3/10], Loss: 0.2364
Epoch [4/10], Loss: 0.3355
Epoch [5/10], Loss: 0.0481
Epoch [6/10], Loss: 0.0395
Epoch [7/10], Loss: 0.0072
Epoch [8/10], Loss: 0.0070
Epoch [9/10], Loss: 0.0010
Epoch [10/10], Loss: 0.0008


In [299]:
# Evaluation function
def evaluate(model, data_loader):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            
            # Encode images and text
            image_features = model.encode_image(images)
            text_features = model.encode_text(text_inputs)
            
            # Calculate similarities and logits
            logits_per_image = image_features @ text_features.T
            _, preds = torch.max(logits_per_image, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = accuracy_score(all_labels, all_preds)
    return accuracy

In [300]:

image_paths = super_images + sadyek_images
labels = [0] * len(super_images) + [1] * len(sadyek_images) 
dataset = CustomImageDataset(image_paths, labels, transform=preprocess)
data_loader = DataLoader(dataset, batch_size=8, shuffle=False)
# Calculate accuracy
accuracy = evaluate(model, data_loader)
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 93.76%
