In [1]:
import torch
import wandb
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
import pickle
from util_div import *
from util_model import *
from util_visualise_embeddings import *
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

In [2]:
import torch
import os
import pickle

def generate_embeddings(data_loader, model, device, max_per_brand=None):
    """Run ``model`` over every batch of ``data_loader`` and group outputs by brand.

    Args:
        data_loader: Iterable of batches. The first element of each batch is the
            image batch and the LAST element is the sequence of brand keys, so both
            ``(images, brands)`` (what ``WatchDataset`` yields) and
            ``(images, labels, brands)`` batches are accepted.
        model: ``torch.nn.Module`` producing one output row per input image; it is
            switched to eval mode here.
        device: Device the images are moved to before the forward pass.
        max_per_brand: Optional cap on how many embeddings are kept per brand;
            ``None`` keeps all of them.

    Returns:
        dict mapping each brand to a list of numpy arrays (one per kept image), in
        the order the images were produced by ``data_loader``.
    """
    model.eval()
    embeddings = {}

    with torch.no_grad():
        for batch in data_loader:
            # Images from the front, brands from the back: works for 2- and 3-tuples.
            # A strict 3-way unpack fails on WatchDataset's (image, brand) samples.
            images, brands = batch[0], batch[-1]
            out = model(images.to(device)).cpu().numpy()

            for i, brand in enumerate(brands):
                # 0-d tensors (e.g. collated integer ids) hash by identity, so equal
                # ids would land under different keys; use the plain Python value.
                if isinstance(brand, torch.Tensor) and brand.ndim == 0:
                    brand = brand.item()
                kept = embeddings.setdefault(brand, [])
                if max_per_brand is None or len(kept) < max_per_brand:
                    kept.append(out[i])

    return embeddings

def save_embeddings(embeddings, file_path):
    """Serialize ``embeddings`` with pickle into ``file_path``.

    The file is opened in ``'wb'`` mode, so an existing file is truncated and
    replaced. ``pickle.HIGHEST_PROTOCOL`` is used, so the reader must run a Python
    version that supports that protocol. Only unpickle files you trust.
    """
    payload_protocol = pickle.HIGHEST_PROTOCOL
    with open(file_path, mode='wb') as sink:
        pickle.dump(embeddings, sink, protocol=payload_protocol)



In [3]:
def load_model(model_path, device, model_cls=None):
    """Build a model, load saved weights into it and move it to ``device``.

    Args:
        model_path: Path to a saved ``state_dict`` compatible with the model built
            by ``model_cls``.
        device: Target device; also passed as ``map_location`` so weights saved on
            another device (e.g. a GPU) can be loaded here.
        model_cls: Optional zero-argument callable returning a fresh model instance.
            Defaults to ``WatchEmbeddingModel``.

    Returns:
        The model with the loaded weights, moved to ``device``. It is not put into
        eval mode here.
    """
    if model_cls is None:
        # Presumably provided by `from util_model import *` in the first cell.
        model_cls = WatchEmbeddingModel
    model = model_cls()
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    return model


In [4]:
from PIL import Image
import torchvision.transforms as transforms
from sklearn.neighbors import NearestNeighbors
import numpy as np

def load_image(image_path, transform=None):
    """Load a single image file as a normalized batch of size one.

    Args:
        image_path: Path to an image file readable by PIL.
        transform: Optional callable mapping a PIL image to a tensor. Defaults to
            resize to 256x256, ``ToTensor`` and normalization with the
            ImageNet-style mean/std (the same pipeline as the dataset cell).

    Returns:
        Tensor with a leading batch dimension of 1, i.e. ``(1, 3, 256, 256)`` with
        the default transform.
    """
    if transform is None:
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    # Convert to RGB as WatchDataset does. Grayscale or RGBA inputs (e.g. PNG
    # with alpha) would give 1 or 4 channels and break the 3-channel Normalize.
    # The context manager releases the file handle once the pixels are loaded.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)  # Add batch dimension

def find_similar_watches(image_path, model, embeddings, device, k=5):
    """Return the stored embeddings nearest to the embedding of a query image.

    Args:
        image_path: Query image file, loaded with ``load_image``.
        model: Embedding model; switched to eval mode here.
        embeddings: Mapping ``brand -> list of embedding vectors``, as returned by
            ``generate_embeddings``.
        device: Device used for the forward pass.
        k: Number of neighbours requested. It is clamped to the number of stored
            embeddings, so fewer results are returned when the index is small.

    Returns:
        List of ``(brand, distance)`` tuples as reported by
        ``NearestNeighbors.kneighbors`` (default Minkowski p=2, i.e. Euclidean,
        metric). Empty when ``embeddings`` holds no vectors or ``k <= 0``.
    """
    image = load_image(image_path).to(device)

    model.eval()
    with torch.no_grad():
        query = model(image).cpu().numpy().flatten()

    # Flatten the brand -> [vectors] mapping into parallel vector / label lists.
    all_embeddings = [emb for embs in embeddings.values() for emb in embs]
    labels = [brand for brand, embs in embeddings.items() for _ in embs]

    # NearestNeighbors rejects n_neighbors > n_samples, zero samples and
    # non-positive n_neighbors, so clamp k and short-circuit the degenerate case.
    n_neighbors = min(k, len(all_embeddings))
    if n_neighbors <= 0:
        return []

    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(all_embeddings)
    distances, indices = nbrs.kneighbors([query])

    return [(labels[idx], distances[0][i]) for i, idx in enumerate(indices[0])]


In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd

class WatchDataset(Dataset):
    """Watch images paired with their brand, read from a CSV file or a directory tree.

    Each item is ``(image, brand)``. ``image`` is an RGB ``PIL.Image``, or whatever
    ``transform`` returns when a transform is given.
    """

    def __init__(self, csv_file=None, root_dir=None, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations. The first
                column is used as the image path and the second as the brand.
            root_dir (string): Directory with all the images, used when
                ``csv_file`` is not given. Every ``.png``/``.jpg``/``.jpeg`` file
                found recursively becomes one sample, labelled with the name of
                the directory that directly contains it.
            transform (callable, optional): Transform to be applied on a sample.

        Raises:
            ValueError: If neither ``csv_file`` nor ``root_dir`` is given.
        """
        if csv_file:
            self.watch_frame = pd.read_csv(csv_file)
        elif root_dir is not None:
            self.watch_frame = self._scan_directory(root_dir)
        else:
            raise ValueError('WatchDataset needs either csv_file or root_dir')
        self.transform = transform

    @staticmethod
    def _scan_directory(root_dir):
        """Build an ``image_path``/``brand`` frame for every image under ``root_dir``."""
        image_exts = {'.png', '.jpg', '.jpeg'}
        records = []
        for dirpath, dirnames, filenames in os.walk(root_dir):
            # os.walk order is filesystem-dependent; sort in place for a stable order.
            dirnames.sort()
            for fname in sorted(filenames):
                if os.path.splitext(fname)[1].lower() in image_exts:
                    records.append({
                        'image_path': os.path.join(dirpath, fname),
                        # basename honours the platform separator; split('/') does not
                        # on Windows.
                        'brand': os.path.basename(dirpath),
                    })
        # Explicit columns keep the expected schema even when no images are found.
        return pd.DataFrame(records, columns=['image_path', 'brand'])

    def __len__(self):
        return len(self.watch_frame)

    def __getitem__(self, idx):
        img_name = self.watch_frame.iloc[idx, 0]
        brand = self.watch_frame.iloc[idx, 1]

        # The context manager releases the file handle. convert('RGB') loads the
        # pixels and maps grayscale/RGBA inputs to three channels.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, brand

# Define transformations (same pipeline as the default in `load_image`)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Data source configuration. Set CSV_FILE to an annotations CSV (image path in the
# first column, brand in the second), or set it to None and point IMAGE_ROOT_DIR at
# a directory whose sub-directories are named after the brands.
CSV_FILE = 'path/to/your/csv_file.csv'
IMAGE_ROOT_DIR = None
BATCH_SIZE = 4
# NOTE(review): under the "spawn" start method (default on Windows and macOS),
# worker processes may fail to unpickle WatchDataset defined in this notebook;
# set NUM_WORKERS = 0 there.
NUM_WORKERS = 4
SEED = 42

if CSV_FILE is not None and not os.path.isfile(CSV_FILE):
    raise FileNotFoundError(
        f'CSV_FILE={CSV_FILE!r} does not exist. Edit CSV_FILE, or set it to None '
        'and point IMAGE_ROOT_DIR at a directory of images grouped by brand.'
    )

# Create your dataset
dataset = WatchDataset(csv_file=CSV_FILE, root_dir=IMAGE_ROOT_DIR, transform=transform)

# Create DataLoader. The seeded generator makes the shuffle order, and therefore
# which samples generate_embeddings(max_per_brand=...) keeps, reproducible.
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    generator=torch.Generator().manual_seed(SEED),
)


FileNotFoundError: [Errno 2] No such file or directory: 'path/to/your/csv_file.csv'

In [None]:
# Example usage: edit these paths before running the cell.
MODEL_PATH = "path_to_model.pth"
EMBEDDINGS_PATH = "embeddings.pkl"
QUERY_IMAGE_PATH = "path_to_watch_image.jpg"
MAX_PER_BRAND = 2
TOP_K = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model(MODEL_PATH, device)
embeddings = generate_embeddings(data_loader, model, device, max_per_brand=MAX_PER_BRAND)
save_embeddings(embeddings, EMBEDDINGS_PATH)

# To find similar watches: a list of (brand, distance) pairs.
similar_watches = find_similar_watches(QUERY_IMAGE_PATH, model, embeddings, device, k=TOP_K)
similar_watches
