In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset.polyvore import PolyvoreDataset
from model.resnet import SiameseNetwork
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [2]:
data_dir = '/home/abdelrahman/fashion-matching/fashion-compatibility/data/polyvore_outfits'
# 1. Data Augmentation Transforms to be used in the Siamese Network model for creating positive samples
augmented_img_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),  # Adjust size as needed
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
])

# For the creating the anchor and negative samples, use simpler transforms 
img_transforms = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img_train_dataset = PolyvoreDataset(data_dir=data_dir, dataset_type='train', img_transforms=img_transforms, augmented_img_transforms=augmented_img_transforms, target='image')
img_val_dataset = PolyvoreDataset(data_dir=data_dir, dataset_type='valid', img_transforms=img_transforms, augmented_img_transforms=augmented_img_transforms, target='image')
img_test_dataset = PolyvoreDataset(data_dir=data_dir, dataset_type='test', img_transforms=img_transforms, augmented_img_transforms=augmented_img_transforms, target='image')

img_train_loader = DataLoader(img_train_dataset, batch_size=64, shuffle=True)
img_val_loader = DataLoader(img_val_dataset, batch_size=64, shuffle=False)
img_test_loader = DataLoader(img_test_dataset, batch_size=64, shuffle=False)

100%|██████████| 53306/53306 [00:01<00:00, 34968.47it/s]
100%|██████████| 5000/5000 [00:00<00:00, 35405.91it/s]
100%|██████████| 10000/10000 [00:00<00:00, 35093.82it/s]


In [4]:
# 2. Siamese Network Architecture
model = SiameseNetwork(model_name = 'resnet50', embedding_dim=128)
model = model.cuda()



In [5]:
model.eval()
anchor_pos_similarities = []
anchor_neg_similarities = []

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
with torch.no_grad():
    for batch in tqdm(img_test_loader, desc="Testing"):
        anchor, positive, negative = batch
        anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

        output1, output2, output3 = model(anchor, positive, negative)
        
        # Calculate Similarities
        for i in range(output1.size(0)):  # Iterate over batch items
            anchor_pos_similarities.append(cosine_similarity(output1[i].unsqueeze(0).cpu(), output2[i].unsqueeze(0).cpu()).item())
            anchor_neg_similarities.append(cosine_similarity(output1[i].unsqueeze(0).cpu(), output3[i].unsqueeze(0).cpu()).item())

Testing: 100%|██████████| 680/680 [05:45<00:00,  1.97it/s]


In [8]:
avg_anchor_pos_sim = np.mean(anchor_pos_similarities)
avg_anchor_neg_sim = np.mean(anchor_neg_similarities)
print(f"Average Similarity between Anchor and Positive: {avg_anchor_pos_sim}")
print(f"Average Similarity between Anchor and Negative: {avg_anchor_neg_sim}")

Average Similarity between Anchor and Positive: 0.9995341735799049
Average Similarity between Anchor and Negative: 0.9986789179909767
