## Installing the Required Dependencies

In [None]:
pip install torch torchvision open-clip-torch matplotlib pandas tqdm scikit-learn
# also get open-clip


## Building the Dataset Class

In [None]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

class FashionCLIPDataset(Dataset):
    """Image/caption pairs read from a fashion-product CSV, for CLIP training.

    Each CSV row must provide an ``id`` column (the image file is
    ``<image_folder>/<id>.jpg``) and a ``productDisplayName`` caption; rows
    with no caption are dropped.

    Args:
        csv_path: Path to the CSV file.
        image_folder: Directory that contains the ``<id>.jpg`` images.
        preprocess: Callable applied to each RGB ``PIL.Image`` (for example an
            open_clip transform).
        max_samples: If given and non-zero, keep a random subset of at most
            this many rows (fixed seed 42). A value larger than the number of
            available rows keeps every row instead of raising.
    """

    def __init__(self, csv_path, image_folder, preprocess, max_samples=None):
        df = pd.read_csv(csv_path)
        df = df.dropna(subset=['productDisplayName'])
        if max_samples:
            # DataFrame.sample (without replacement) raises if n > len(df).
            df = df.sample(n=min(max_samples, len(df)), random_state=42)
        self.df = df
        self.image_folder = image_folder
        self.preprocess = preprocess

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(preprocessed_image, caption)`` for row ``idx``.

        Returns ``None`` when the image file is missing or cannot be decoded.
        PyTorch's default collate function cannot batch ``None``, so a
        DataLoader over this dataset needs a ``collate_fn`` that drops them.
        """
        row = self.df.iloc[idx]
        image_path = os.path.join(self.image_folder, f"{row['id']}.jpg")
        try:
            # `with` guarantees the file is closed even if decoding fails.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        except OSError:
            # FileNotFoundError and PIL.UnidentifiedImageError are both
            # OSError subclasses; unrelated errors (and KeyboardInterrupt)
            # are no longer silently swallowed.
            return None
        image = self.preprocess(image)
        text = row['productDisplayName']
        return image, text


## Training and Validation Loop with Metrics

In [None]:
import torch
import open_clip
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model. create_model_and_transforms returns
# (model, train_transform, eval_transform); the eval transform is used for
# both splits below, i.e. no training-time augmentation.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2B-s34B-b79K')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
model.to(device)

# Split dataset (rows without a caption are dropped before splitting).
csv_path = "fashion-dataset/styles.csv"
image_folder = "fashion-dataset/images"
train_df, val_df = train_test_split(pd.read_csv(csv_path).dropna(subset=['productDisplayName']), test_size=0.1, random_state=42)
train_df.to_csv("train.csv", index=False)
val_df.to_csv("val.csv", index=False)

train_dataset = FashionCLIPDataset("train.csv", image_folder, preprocess, max_samples=5000)
val_dataset = FashionCLIPDataset("val.csv", image_folder, preprocess, max_samples=1000)


def collate_skip_none(batch):
    """Collate samples after dropping ``None`` placeholders.

    FashionCLIPDataset returns ``None`` for images it cannot load, which the
    default collate function cannot stack. Returns ``None`` when every sample
    in the batch was dropped, so the loops below can skip that batch.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    return torch.utils.data.default_collate(kept)


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2, collate_fn=collate_skip_none)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2, collate_fn=collate_skip_none)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()


def clip_logits_and_loss(images, texts):
    """Encode one batch and compute the symmetric CLIP contrastive loss.

    Returns ``(logits_per_image, labels, loss)``: ``logits_per_image[i, j]``
    is the scaled cosine similarity of image ``i`` and caption ``j``, and
    ``labels[i] == i`` marks the matching pair on the diagonal.
    """
    images = images.to(device)
    tokens = tokenizer(texts).to(device)

    image_features = model.encode_image(images)
    text_features = model.encode_text(tokens)
    # Out-of-place normalisation: an in-place `/=` overwrites the tensor that
    # norm() saved for its backward pass, and loss.backward() then fails.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Apply CLIP's learned temperature: unscaled cosine similarities lie in
    # [-1, 1], which makes the softmax over the batch nearly uniform.
    logit_scale = model.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logits_per_image.T

    labels = torch.arange(len(images), device=device)
    loss = (loss_fn(logits_per_image, labels) + loss_fn(logits_per_text, labels)) / 2
    return logits_per_image, labels, loss


train_losses, val_losses, accuracies = [], [], []

for epoch in range(5):  # adjust as needed
    model.train()
    total_loss, total_correct, total_samples, num_batches = 0.0, 0, 0, 0

    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        if batch is None:  # every image in this batch failed to load
            continue
        images, texts = batch
        logits_per_image, labels, loss = clip_logits_and_loss(images, texts)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        # In-batch image->text retrieval accuracy: the top-ranked caption for
        # image i should be caption i.
        preds = logits_per_image.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += len(labels)

    # Average over batches actually processed (skipped ones don't count).
    avg_train_loss = total_loss / num_batches if num_batches else float("nan")
    train_acc = total_correct / total_samples if total_samples else float("nan")
    train_losses.append(avg_train_loss)
    accuracies.append(train_acc)

    # ---- Validation ----
    # Same symmetric, temperature-scaled loss as training, so the curves are
    # directly comparable.
    model.eval()
    val_loss, val_batches = 0.0, 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            if batch is None:
                continue
            images, texts = batch
            _, _, loss = clip_logits_and_loss(images, texts)
            val_loss += loss.item()
            val_batches += 1

    val_losses.append(val_loss / val_batches if val_batches else float("nan"))

    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f} | Val Loss={val_losses[-1]:.4f} | Accuracy={train_acc:.4f}")


## Plot Training Results

In [None]:
# Plot per-epoch metrics against 1-based epoch numbers, matching the
# "Epoch N" lines printed during training.
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 5))

# Left: training vs. validation loss.
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, label="Train Loss")
plt.plot(epochs, val_losses, label="Val Loss")
plt.title("Loss over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

# Right: in-batch image->text accuracy on the training set.
plt.subplot(1, 2, 2)
plt.plot(epochs, accuracies, label="Accuracy", color="green")
plt.title("Training Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()

plt.tight_layout()
plt.show()


## Text → Image Retrieval

In [None]:
# Retrieve the validation product whose image best matches a text query.
query = "red floral dress"
model.eval()

with torch.no_grad():
    text_features = model.encode_text(tokenizer([query]).to(device))
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Embed every validation image that loads. kept_indices[k] is the
    # val_dataset.df row of image_features[k], so results map back to the
    # right product even when some images were skipped.
    image_features, kept_indices = [], []
    for idx, sample in enumerate(val_dataset):
        if sample is None:  # image missing or unreadable
            continue
        img, _ = sample
        emb = model.encode_image(img.unsqueeze(0).to(device))
        image_features.append((emb / emb.norm(dim=-1, keepdim=True)).cpu())
        kept_indices.append(idx)

if not image_features:
    raise RuntimeError("No validation images could be loaded.")

image_features = torch.vstack(image_features)
scores = (image_features @ text_features.cpu().T).squeeze(1)
best_idx = kept_indices[scores.argmax().item()]
print(f"Best match: {val_dataset.df.iloc[best_idx]['productDisplayName']}")


## Image → Image Retrieval

In [None]:
# Find the validation products whose images are most similar to one query
# product, using the image_features computed in the previous cell.
query_idx = 0  # row of val_dataset.df used as the query
model.eval()

# Row k of image_features corresponds to val_dataset.df row k only when every
# validation image was embedded; otherwise the printed names would be wrong.
if image_features.shape[0] != len(val_dataset):
    raise RuntimeError(
        "image_features does not have one row per val_dataset item "
        "(some images failed to load); cannot map results back to products."
    )

query_sample = val_dataset[query_idx]
if query_sample is None:
    raise RuntimeError(f"Query image at index {query_idx} could not be loaded.")

with torch.no_grad():
    query_emb = model.encode_image(query_sample[0].unsqueeze(0).to(device))
    query_emb = query_emb / query_emb.norm(dim=-1, keepdim=True)

sims = (image_features @ query_emb.cpu().T).squeeze(1)
# The query image is itself in the gallery, so it normally ranks first.
top5 = sims.topk(min(5, len(sims))).indices
print(val_dataset.df.iloc[top5.tolist()]['productDisplayName'])
