In [1]:
import os
file_paths = []
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        file_paths.append(os.path.join(dirname, filename))
print(f"Total files found: {len(file_paths)}")

Total files found: 261081


In [2]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if torch.cuda.is_available():
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
else:
    print("GPU is not available.")

Using device: cuda
GPU is available: Tesla P100-PCIE-16GB


In [3]:
import os
import json
import pandas as pd
import random
from itertools import combinations, product
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

In [4]:
# ---------- CONFIG ----------
ROOT = Path('/kaggle/input/polyvore-outfit-dataset/polyvore_outfits')
TRAIN_FILE = ROOT / 'nondisjoint/train.json'
SPLIT_RATIO = 0.2
NEG_MULTIPLIER = 6
OUTPUT_DIR = Path('/kaggle/working')
test_json = ROOT / 'nondisjoint/test.json'

In [5]:
# ---------- LOAD TRAIN.JSON ----------
with open(TRAIN_FILE, 'r') as f:
    outfit_dicts = json.load(f)

outfit_dicts = outfit_dicts[:5000]

print(f"Total outfits: {len(outfit_dicts)}")
print("First outfit example:", outfit_dicts[0])

# ---------- Extract list of item_ids per outfit ----------
outfits = [
    [item["item_id"] for item in outfit["items"] if "item_id" in item]
    for outfit in outfit_dicts
    if len(outfit["items"]) >= 2  # filter to avoid empty/1-item outfits
]

# ---------- SPLIT 80-20 ----------
train_outfits, val_outfits = train_test_split(outfits, test_size=SPLIT_RATIO, random_state=42)

# ---------- HELPER: GET POSITIVE PAIRS ----------
def generate_positive_pairs(outfits):
    pos_pairs = []
    for item_ids in outfits:
        for i in range(len(item_ids)):
            for j in range(i + 1, len(item_ids)):
                pos_pairs.append((item_ids[i], item_ids[j], 1))
    return pos_pairs

# ---------- HELPER: GET NEGATIVE PAIRS ----------
def generate_negative_pairs(pos_pairs, outfits, n_needed):
    all_ids = list({item for outfit in outfits for item in outfit})
    pos_set = set((a, b) for a, b, _ in pos_pairs)

    neg_pairs = set()
    while len(neg_pairs) < n_needed:
        a, b = random.sample(all_ids, 2)
        if a != b and (a, b) not in pos_set and (b, a) not in pos_set:
            neg_pairs.add((a, b, 0))
    return list(neg_pairs)

# ---------- GENERATE PAIRS ----------
train_pos = generate_positive_pairs(train_outfits)
val_pos = generate_positive_pairs(val_outfits)

train_neg = generate_negative_pairs(train_pos, train_outfits, len(train_pos) * NEG_MULTIPLIER)
val_neg = generate_negative_pairs(val_pos, val_outfits, len(val_pos) * NEG_MULTIPLIER)

# ---------- FINAL COMBINED DATA ----------
train_data = train_pos + train_neg
val_data = val_pos + val_neg
random.shuffle(train_data)
random.shuffle(val_data)

# ---------- SAVE TO FILES ----------
with open(OUTPUT_DIR / 'train_pairs.json', 'w') as f:
    json.dump(train_data, f)

with open(OUTPUT_DIR / 'val_pairs.json', 'w') as f:
    json.dump(val_data, f)

# ---------- SUMMARY ----------
print(f"Train pairs saved to: {OUTPUT_DIR / 'train_pairs.json'}")
print(f"   Total: {len(train_data)} | Positives: {len(train_pos)} | Negatives: {len(train_neg)}")
print(f"Validation pairs saved to: {OUTPUT_DIR / 'val_pairs.json'}")
print(f"   Total: {len(val_data)} | Positives: {len(val_pos)} | Negatives: {len(val_neg)}")

Total outfits: 5000
First outfit example: {'items': [{'item_id': '154249722', 'index': 1}, {'item_id': '188425631', 'index': 2}, {'item_id': '183214727', 'index': 3}], 'set_id': '210750761'}
Train pairs saved to: /kaggle/working/train_pairs.json
   Total: 362145 | Positives: 51735 | Negatives: 310410
Validation pairs saved to: /kaggle/working/val_pairs.json
   Total: 89733 | Positives: 12819 | Negatives: 76914


In [6]:
class FashionPairDataset(Dataset):
    def __init__(self, pair_data, image_dir, transform=None):
        self.pairs = pair_data
        self.image_dir = image_dir
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        item1, item2, label = self.pairs[idx]
        img1 = Image.open(f"{self.image_dir}/{item1}.jpg").convert('RGB')
        img2 = Image.open(f"{self.image_dir}/{item2}.jpg").convert('RGB')
        return self.transform(img1), self.transform(img2), torch.tensor(label, dtype=torch.float32)

In [7]:
with open('/kaggle/working/train_pairs.json') as f:
    train_data = json.load(f)
with open('/kaggle/working/val_pairs.json') as f:
    val_data = json.load(f)

dataset_root = '/kaggle/input/polyvore-outfit-dataset/polyvore_outfits'
img_dir = os.path.join(dataset_root, 'images')

train_loader = DataLoader(FashionPairDataset(train_data, img_dir), batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(FashionPairDataset(val_data, img_dir), batch_size=64)

In [8]:
def get_truncated_vgg16():
    vgg = models.vgg16(pretrained=True)
    child_layers = list(vgg.features.children())
    truncated = nn.Sequential(*child_layers[:31])  # Up to conv4_1

    # Freeze first 10 conv layers → up to layer index 22 (conv4_3)
    conv_count = 0
    for layer in child_layers:
        if isinstance(layer, nn.Conv2d):
            conv_count += 1
        if conv_count <= 10:
            for param in layer.parameters():
                param.requires_grad = False

    return truncated

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
truncated_vgg = get_truncated_vgg16().to(device)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 209MB/s] 


In [10]:
def extract_color_histogram(img_tensor, bins=8):
    """
    img_tensor: [3, H, W], range [0,1]
    Returns: histogram vector of shape [24]
    """
    hist = []
    for ch in img_tensor:
        h = torch.histc(ch.cpu(), bins=bins, min=0.0, max=1.0)
        h = h / (h.sum() + 1e-6)  # Normalize
        hist.append(h)
    return torch.cat(hist).to(img_tensor.device)  # return to original device


In [11]:
# -------- Siamese Merge Module --------
class SiameseMerge(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = get_truncated_vgg16()

    def forward(self, img1, img2):
        """
        img1, img2: [B, 3, H, W], values in [0,1]
        """
        feat1 = self.vgg(img1)  # [B, 512, H', W']
        feat2 = self.vgg(img2)
        # print("VGG output shape:", feat1.shape)
        # print("VGG output shape:", feat2.shape)
        
        # Global average pool
        feat1 = F.adaptive_avg_pool2d(feat1, (1, 1)).squeeze(-1).squeeze(-1)  # [B, 512]
        feat2 = F.adaptive_avg_pool2d(feat2, (1, 1)).squeeze(-1).squeeze(-1)  # [B, 512]

        visual_hadamard = feat1 * feat2  # [B, 512]

        # Compute color histogram Hadamard
        batch_hist = []
        for i in range(img1.size(0)):
            h1 = extract_color_histogram(img1[i])
            h2 = extract_color_histogram(img2[i])
            batch_hist.append(h1 * h2)  # [24]
        hist_hadamard = torch.stack(batch_hist).to(img1.device)  # [B, 24]
        # print("Color hist shape:", h1.shape)
        # print("Color hist shape:", h2.shape)
        # Final feature
        merged_feature = torch.cat([visual_hadamard, hist_hadamard], dim=1)  # [B, 536]
        # print("Merged feature shape from siamese_merge:", merged_feature.shape)
        return merged_feature

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SiameseMerge().to(device)
dummy_img1 = torch.randn(2, 3, 224, 224).to(device)
dummy_img2 = torch.randn(2, 3, 224, 224).to(device)

with torch.no_grad():
    output = model(dummy_img1, dummy_img2)

print(">>> Final merged feature dim:", output.shape[1])

>>> Final merged feature dim: 536


In [13]:
class MetricNetwork(nn.Module):
    def __init__(self, input_dim=536, dropout_prob=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_prob)

        self.fc2 = nn.Linear(256, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.dropout2 = nn.Dropout(dropout_prob)

        self.fc3 = nn.Linear(64, 1)  # Output scalar compatibility score

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout2(x)
        score = self.fc3(x)  # Raw score (logit)
        return score.squeeze(1)  # [B]

    def predict(self, x):
        """
        Inference method to get probabilities (after sigmoid)
        """
        with torch.no_grad():
            logits = self.forward(x)
            probs = torch.sigmoid(logits)
        return probs

In [14]:
def update_lambda(W, eps=1e-5):
    norm = torch.norm(W).clamp(min=eps)
    return torch.eye(W.shape[1], device=W.device) * norm.item() ** 2

In [15]:
class MAPLoss(nn.Module):
    def __init__(self, metric_model, vgg_model, lambda1=1e-4, lambda2=1e-5, lambda3=1e-5):
        super().__init__()
        self.metric_model = metric_model
        self.vgg_model = vgg_model
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.lambda3 = lambda3

    def forward(self, pred_logits, targets, lambdas):
        pred_logits = pred_logits.clamp(min=-100, max=100)  # Prevents inf logits
        bce = F.binary_cross_entropy_with_logits(pred_logits, targets.float())

        # Matrix-variate Gaussian prior
        tr_terms = []
        for fc_layer, key in zip([self.metric_model.fc1, self.metric_model.fc2], ['fc1', 'fc2']):
            Wj = fc_layer.weight
            Lambda_inv = torch.linalg.pinv(lambdas[key].clamp(min=1e-6, max=1e6))  # Regularized inverse
            Wj_T = Wj.transpose(0, 1)
            mat = Wj @ Lambda_inv @ Wj_T
            tr = torch.trace(mat.clamp(min=1e-6, max=1e6))  # Prevent exploding trace
            tr_terms.append(tr)
        trace_reg = sum(tr_terms)

        l1_cnn = sum(torch.sum(torch.abs(p)) for n, p in self.vgg_model.named_parameters() if p.requires_grad)
        l1_w = torch.sum(torch.abs(self.metric_model.fc3.weight))

        total_loss = (
            bce +
            self.lambda1 * trace_reg +
            self.lambda2 * l1_cnn +
            self.lambda3 * l1_w
        )
        return total_loss


In [16]:
def train_model(siamese_merge, metric_model, loss_fn, optimizer,
                train_loader, val_loader, device, num_epochs=20, patience=3):

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        siamese_merge.train()
        metric_model.train()
        total_train_loss = 0.0

        for img1, img2, labels in tqdm(train_loader, desc="Training"):
            img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
            optimizer.zero_grad()

            merged_feats = siamese_merge(img1, img2)
            logits = metric_model(merged_feats)

            # NaN check
            if torch.isnan(merged_feats).any() or torch.isnan(logits).any():
                print("NaN detected in features or logits. Skipping batch.")
                continue

            lambdas = {
                "fc1": update_lambda(metric_model.fc1.weight),
                "fc2": update_lambda(metric_model.fc2.weight)
            }

            loss = loss_fn(logits, labels, lambdas)

            if torch.isnan(loss):
                print("NaN loss encountered. Skipping batch.")
                continue

            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # ---- Validation ----
        siamese_merge.eval()
        metric_model.eval()
        total_val_loss = 0.0

        with torch.no_grad():
            for img1, img2, labels in tqdm(val_loader, desc="Validating"):
                img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)
                merged_feats = siamese_merge(img1, img2)
                logits = metric_model(merged_feats)

                lambdas = {
                    "fc1": update_lambda(metric_model.fc1.weight),
                    "fc2": update_lambda(metric_model.fc2.weight)
                }

                val_loss = loss_fn(logits, labels, lambdas)
                if not torch.isnan(val_loss):
                    total_val_loss += val_loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        
        # --- Early Stopping ---
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            print("Best model updated.")
            torch.save({
                'siamese_merge_state_dict': siamese_merge.state_dict(),
                'metric_model_state_dict': metric_model.state_dict()
            }, "best_model.pth")
        else:
            patience_counter += 1
            print(f"Patience: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break


In [None]:
from tqdm import tqdm
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

siamese_merge = SiameseMerge().to(device)
metric_model = MetricNetwork().to(device)

loss_fn = MAPLoss(metric_model, siamese_merge.vgg).to(device)

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, 
           list(siamese_merge.parameters()) + list(metric_model.parameters())),
    lr=1e-4, betas=(0.9, 0.999)
)

train_model(
    siamese_merge=siamese_merge,
    metric_model=metric_model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_epochs=20,
    patience=3
)


Epoch 1/20


Training: 100%|██████████| 5659/5659 [54:55<00:00,  1.72it/s]
Validating: 100%|██████████| 1403/1403 [21:19<00:00,  1.10it/s]


Train Loss: 0.6479 | Val Loss: 0.4378
Best model updated.

Epoch 2/20


Training: 100%|██████████| 5659/5659 [54:55<00:00,  1.72it/s]
Validating: 100%|██████████| 1403/1403 [21:17<00:00,  1.10it/s]


Train Loss: 0.4296 | Val Loss: 0.4159
Best model updated.

Epoch 3/20


Training:  39%|███▉      | 2212/5659 [21:23<33:55,  1.69it/s]