In [None]:
import pandas as pd
import numpy as np
import json
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
import torch
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything()
%matplotlib inline


In [None]:
class FashionProductDataSet(Dataset):
    def __init__(self,image_dir,style_dir,transform=None):
        self.image_dir = image_dir
        self.style_dir = style_dir
        self.transform = transform
        self.image_files=[f for f in os.listdir(self.image_dir) if f.endswith('.jpg')]
        self.image_files = self.image_files[:750]  # In __init__ of your Dataset

    def __len__(self):
        return len(self.image_files)
    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        json_path = os.path.join(self.style_dir, img_name.replace(".jpg", ".json"))

        try:
            image = Image.open(img_path).convert("RGB")
        except:
            return None  # Skip unreadable images

        # Read and parse JSON
        with open(json_path, "r") as f:
            data = json.load(f)
        # Example: buidata = metadata["data"]  # if you're accessing under 'data' key

        brand = data.get("brandName", "")
        gender = data.get("gender", "")
        age_group = data.get("ageGroup", "")
        article_type = data.get("articleType", {}).get("typeName", "")
        base_colour = data.get("baseColour", "")
        season = data.get("season", "")
        year = data.get("year", "")
        category = data.get("masterCategory", {}).get("typeName", "")
        sub_category = data.get("subCategory", {}).get("typeName", "")
        usage = data.get("usage", "")
        pattern = data.get("articleAttributes", {}).get("Pattern", "")
        desc = data.get("productDisplayName", "")
        price = data.get("price", "")
        discounted_price = data.get("discountedPrice", "")
        rating = data.get("myntraRating", "")
        variant = data.get("variantName", "")
        article_number = data.get("articleNumber", "")
        fashion_type = data.get("fashionType", "")
        display_categories = data.get("displayCategories", "")
        vat = data.get("vat", "")
        landing_page = data.get("landingPageUrl", "")
        # Article attributes (add more if needed)
        body_size = data.get("articleAttributes", {}).get("Body or Garment Size", "")
        # Material extraction (example)
        material_desc = data.get("productDescriptors", {}).get("description", {}).get("value", "")
        material = "polyester and spandex" if "polyester" in material_desc.lower() else ""
        # Purpose
        purpose = "sportswear" if "sports" in usage.lower() else "casualwear"

        

# Final rich text
        full_text = f"{brand} {category} {sub_category} {gender.lower()} {pattern.lower()} {base_colour.lower()} {article_type.lower()} made of {material} for {season.lower()} {purpose}. {desc}".strip()
        
        if not full_text:
            return None  # Skip empty text

        if self.transform:
            image = self.transform(image)

        return image, full_text

        

In [8]:

from torch.utils.data import DataLoader
from torchvision import transforms

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

clip_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
])

dataset = FashionProductDataSet(
    r'D:\archive\fashion-dataset\fashion-dataset\images\\',
    r'D:\archive\fashion-dataset\fashion-dataset\styles\\',
    transform=clip_transform
)
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
# After creating your dataset object
image_filenames = dataset.image_files
texts = [dataset[i][1] for i in range(len(dataset))]


In [9]:
import matplotlib.pyplot as plt
import torch

# Unnormalize function for CLIP transform
def unnormalize(img_tensor):
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3,1,1)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3,1,1)
    return img_tensor * std + mean



In [10]:
import clip
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# for images, texts in dataloader:
#     images = images.to(device)
#     texts = clip.tokenize(texts).to(device)
    
#     with torch.no_grad():
#         image_features = model.encode_image(images)
#         text_features = model.encode_text(texts)

In [14]:
import torch.nn.functional as F
all_image_features = []
all_text_features = []
for images, texts in dataloader:
    images = images.to(device)
    text_tokens = clip.tokenize(texts).to(device)
    with torch.no_grad():
        img_feat = model.encode_image(images)
        txt_feat = model.encode_text(text_tokens)
        img_feat = F.normalize(img_feat, dim=-1)
        txt_feat = F.normalize(txt_feat, dim=-1)
    all_image_features.append(img_feat.cpu())
    all_text_features.append(txt_feat.cpu())

all_image_features = torch.cat(all_image_features, dim=0)
all_text_features = torch.cat(all_text_features, dim=0)

In [None]:
print(dataset.__getitem__(0))

(tensor([[[1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
         [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
         [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
         ...,
         [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
         [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303],
         [1.9303, 1.9303, 1.9303,  ..., 1.9303, 1.9303, 1.9303]],

        [[2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
         [2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
         [2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
         ...,
         [2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
         [2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749],
         [2.0749, 2.0749, 2.0749,  ..., 2.0749, 2.0749, 2.0749]],

        [[2.1459, 2.1459, 2.1459,  ..., 2.1459, 2.1459, 2.1459],
         [2.1459, 2.1459, 2.1459,  ..., 2.1459, 2.1459, 2.1459],
         [2.1459, 2.1459, 2.1459,  ..., 2.1459, 2.1459, 2

In [15]:
import torch.nn.functional as F

image_features = F.normalize(all_image_features, dim=-1)
text_features = F.normalize(all_text_features, dim=-1)

similarity = image_features @ text_features.T  # shape [batch, batch]

print("Image feaetures:\n",image_features.shape)
print(text_features.shape)

print(similarity.shape)

Image feaetures:
 torch.Size([750, 512])
torch.Size([750, 512])
torch.Size([750, 750])


In [16]:
matches = similarity.argmax(dim=1)
correct = torch.arange(len(matches)).to(matches.device)
accuracy = (matches == correct).float().mean()
print(f"Matching accuracy: {accuracy:.2f}")

Matching accuracy: 0.00


In [None]:
import faiss
image_features = F.normalize(image_features, dim=1)
index = faiss.IndexFlatIP(image_features.shape[1])  # Cosine similarity
index.add(image_features.cpu().numpy())


In [None]:
query = "red sleeveless t-shirt"
text_token = clip.tokenize([query]).to(device)

with torch.no_grad():
    query_embed = model.encode_text(text_token)
    query_embed = F.normalize(query_embed, dim=-1)

# Search
D, I = index.search(query_embed.cpu().numpy(), k=5)  # top-5 results
print(I[0],'\n', D[0])

[1124  339  461 1098   72] 
 [0.3273982  0.323642   0.31344527 0.31073946 0.30726647]


In [None]:
print(device)

cuda
