# Set Up

In [2]:
from torch_geometric.data import Data
from typing import List
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
import torch_geometric.transforms as T

import sklearn.metrics as metrics
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import glob
from PIL import Image

In [3]:
def _select_device() -> torch.device:
    """Pick the compute device for the notebook: MPS first, then CUDA, then CPU.

    Returns:
        torch.device: the first backend that reports itself as available.

    Raises:
        RuntimeError: propagated from the smoke-test allocation if an
            accelerator reports itself available but cannot allocate a tensor.
    """
    # `torch.backends.mps` is missing on PyTorch builds that predate the MPS
    # backend, so look it up defensively instead of assuming it exists.
    mps_backend = getattr(torch.backends, "mps", None)

    # check if MPS (Apple Silicon GPU) is available
    if mps_backend is not None and mps_backend.is_available():
        selected = torch.device("mps")
    # check if CUDA (NVIDIA GPU) is available
    elif torch.cuda.is_available():
        selected = torch.device("cuda")
    else:
        print ("MPS and CUDA device not found.")
        return torch.device("cpu")

    # Smoke test: allocate one element on the accelerator so a broken device
    # fails here rather than in the middle of training. The tensor is local,
    # so no stray `x` is left in the notebook namespace.
    torch.ones(1, device=selected)
    return selected


device = _select_device()

# Load Data

In [4]:
# Folder holding the input photos (path is relative to the notebook's working directory).
IMAGE_DIR = "../data/images/"
# Folder holding the segmentation maps; the loaders below look for `<image name>_segm.png` in it.
SEGM_DIR = "../data/segm/"

In [5]:
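# Method 1: eager loading. Every image/mask pair is read into the in-memory `dataset` list up front.
# See `DeepFashionLazyDataset` further down for the lazy loader that reads files on demand.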

def get_corresponding_segm_path(image_path, segm_dir=None):
    """Build the path of the segmentation map that belongs to an image.

    The map is expected to be named `<image stem>_segm.png`, where the stem is
    the image file name without its directory and last extension.

    Args:
        image_path: Path (or bare file name) of the image.
        segm_dir: Directory that holds the segmentation maps. Defaults to the
            module-level `SEGM_DIR` when omitted.

    Returns:
        str: `segm_dir` joined with `<image stem>_segm.png`. The file is not
        checked for existence.
    """
    if segm_dir is None:
        segm_dir = SEGM_DIR
    # Only the stem is needed; the image's own extension is discarded.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(segm_dir, f'{stem}_segm.png')

def load_image(image_path):
    """Read an image file and return it as an RGB numpy array.

    Args:
        image_path: Path of the image file to open.

    Returns:
        numpy.ndarray: the pixel data after conversion to RGB mode.
    """
    # The context manager closes the underlying file deterministically instead
    # of leaving that to Pillow's implicit close / garbage collection.
    with Image.open(image_path) as img:
        return np.array(img.convert('RGB'))

def load_segm(segm_path):
    """Read a segmentation map and return its raw pixel values as a numpy array.

    No mode conversion is applied, so the array holds whatever values the file
    stores (the rest of the notebook treats them as per-pixel class labels).

    Args:
        segm_path: Path of the segmentation file to open.

    Returns:
        numpy.ndarray: the pixel data of the segmentation map.
    """
    # The array is built inside the `with` block, so the pixel data is fully
    # read before the file handle is closed.
    with Image.open(segm_path) as segm:
        return np.array(segm)

# Labels dropped from the eager dataset: background (0) plus every class we do not want.
labels_to_exclude = {0, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23}  # background and unwanted labels

dataset = []   # one (image, binary mask, label) tuple per kept label per image
skipped = 0    # images that have no segmentation file next to them
image_paths = glob.glob(os.path.join(IMAGE_DIR, '*'))

for image_path in image_paths:
    segm_path = get_corresponding_segm_path(image_path)

    if os.path.exists(segm_path):
        image = load_image(image_path)
        segm = load_segm(segm_path)

        # Emit one sample per wanted label that occurs in this segmentation map.
        for label in np.unique(segm):
            if label not in labels_to_exclude:
                # 1 where the pixel carries this label, 0 everywhere else.
                mask = (segm == label).astype(np.uint8)
                dataset.append((image, mask, label))
    else:
        # No matching segmentation map: count it and move on.
        skipped += 1

print(f'Total samples in dataset: {len(dataset)}')
print(f'Total skipped images: {skipped}')

: 

In [31]:
# Sanity check: print image shape, mask shape and label for two samples.
for sample_idx in (0, 20):
    sample_image, sample_mask, sample_label = dataset[sample_idx]
    print(sample_image.shape, sample_mask.shape, sample_label)


(1101, 750, 3) (1101, 750) 1
(1101, 750, 3) (1101, 750) 5


In [15]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os
import numpy as np
from PIL import Image
import torch

class DeepFashionLazyDataset(Dataset):
    """Lazily loads (image, binary mask, class index) triples from two folders.

    Only file names are collected at construction time; pixel data is read
    from disk in `__getitem__`. For every image the clothing label covering
    the most pixels is selected, and a binary mask of that label is returned
    together with its remapped class index.

    Args:
        img_dir: Directory containing the `.jpg` images.
        segm_dir: Directory containing `<image stem>_segm.png` label maps.
        transform: Optional callable applied to the PIL image before it is
            returned (e.g. a torchvision transform pipeline).
        mask_size: `(height, width)` of the returned mask. It should match
            the spatial size produced by `transform` so image and mask line
            up. Defaults to `(224, 224)`.
    """

    def __init__(self, img_dir, segm_dir, transform=None, mask_size=(224, 224)):
        self.img_dir = img_dir
        self.segm_dir = segm_dir
        self.transform = transform
        self.mask_size = tuple(mask_size)

        # 1. PRE-FILTERING
        # Scan all files once and keep only images with a matching
        # segmentation file. `os.listdir` returns entries in arbitrary order,
        # so sort them: the index -> sample mapping is then the same on every
        # run and machine.
        all_images = sorted(f for f in os.listdir(img_dir) if f.endswith('.jpg'))
        print(f"Scanning {len(all_images)} files for validity...")

        self.valid_files = [
            img_name for img_name in all_images
            if os.path.exists(self._segm_path(img_name))
        ]

        print(f"kept {len(self.valid_files)} valid pairs. (Dropped {len(all_images) - len(self.valid_files)} missing/bad files)")

        # Raw segmentation labels that never become a target.
        self.ignore_labels = {0, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23}
        # Raw segmentation label -> contiguous class index used for training.
        self.label_map = {1:0, 2:1, 3:2, 4:3, 5:4, 6:5, 21:6}

    def _segm_path(self, img_name):
        """Return the expected segmentation path (`<stem>_segm.png`) for an image file name."""
        base_name = os.path.splitext(img_name)[0]
        return os.path.join(self.segm_dir, f"{base_name}_segm.png")

    def __len__(self):
        """Number of images that have a matching segmentation file."""
        return len(self.valid_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            tuple: `(image, mask, label)` where `image` is the (optionally
            transformed) RGB image, `mask` is a float32 tensor of shape
            `mask_size` with 1.0 on the dominant clothing item, and `label`
            is a scalar tensor holding the remapped class index. When the
            segmentation contains no mapped label, `mask` is all zeros and
            `label` is -1.
        """
        img_name = self.valid_files[idx]

        # 1. Load Image. `convert` returns a fully loaded copy, so the file
        # can be closed as soon as the `with` block ends.
        with Image.open(os.path.join(self.img_dir, img_name)) as img:
            image = img.convert('RGB')

        # 2. Load Mask (existence was verified in __init__).
        with Image.open(self._segm_path(img_name)) as segm:
            segm_np = np.array(segm)

        # Find the dominant label: the mapped, non-ignored label with the
        # largest pixel count. Ties keep the first (smallest) label because
        # the comparison is strict.
        best_label = -1
        max_pixels = 0
        for label, count in zip(*np.unique(segm_np, return_counts=True)):
            label = int(label)
            if label in self.label_map and label not in self.ignore_labels and count > max_pixels:
                max_pixels = count
                best_label = label

        height, width = self.mask_size

        # 3. Handle "Empty" Clothing
        # Even if the file exists, it may contain none of the mapped labels.
        if best_label == -1:
            target_label = -1
            # Blank mask of the same size as a real one.
            mask_binary = np.zeros((height, width), dtype=np.float32)
        else:
            target_label = self.label_map[best_label]
            # Binary mask for the selected clothing item.
            mask_img = Image.fromarray((segm_np == best_label).astype(np.uint8))
            # PIL takes (width, height); NEAREST keeps the mask strictly 0/1.
            mask_img = mask_img.resize((width, height), resample=Image.NEAREST)
            mask_binary = np.array(mask_img).astype(np.float32)

        # 4. Transform Image
        if self.transform:
            image = self.transform(image)

        return image, torch.from_numpy(mask_binary), torch.tensor(target_label)

# Build the preprocessing pipeline, the lazy dataset and the training loader.
# Images are resized to 224x224 so they line up with the 224x224 masks the dataset returns.
resize_step = transforms.Resize((224, 224))
to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([resize_step, to_tensor_step])

# Lazy dataset over the image / segmentation folders, served in shuffled batches of 32.
dataset = DeepFashionLazyDataset(IMAGE_DIR, SEGM_DIR, transform=transform)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Sanity check: image shape, mask shape and label for two samples.
for sample_idx in (0, 20):
    sample_image, sample_mask, sample_label = dataset[sample_idx]
    print(sample_image.shape, sample_mask.shape, sample_label)

print(f"Dataset ready. Found {len(dataset)} images.")

Scanning 44096 files for validity...
kept 12701 valid pairs. (Dropped 31395 missing/bad files)
torch.Size([3, 224, 224]) torch.Size([224, 224]) tensor(1)
torch.Size([3, 224, 224]) torch.Size([224, 224]) tensor(0)
Dataset ready. Found 12701 images.


# Model

Image Classification Model Class

In [16]:
import torchvision.models as models

class FashionResNet(nn.Module):
    """Pre-trained ResNet-50 whose final fully connected layer is replaced by a `num_classes`-way head."""

    def __init__(self, num_classes=7):
        super().__init__()

        # Start from torchvision's default pre-trained ResNet-50 weights.
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Swap the stock classifier for a fresh linear layer sized for our classes.
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)

        self.backbone = resnet

    def forward(self, x):
        """Run the batch through the backbone and return the class logits."""
        logits = self.backbone(x)
        return logits

# Build the classifier with its default 7-class head and move its parameters to the selected `device`.
model = FashionResNet().to(device)
print("Model initialized.")

Model initialized.


# Train

Hyperparameters

In [11]:
# Step size handed to the Adam optimizer below.
learning_rate = 1e-4
# Number of full passes over the training loader.
num_epochs = 5
# Samples labelled -1 (no valid clothing item found) do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1) # Ignore images with no valid clothes
# NOTE: the optimizer captures the parameters of the `model` object that exists when this cell runs.
# Re-run this cell whenever `model` is re-created, otherwise the optimizer keeps updating the old instance.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [17]:

# Normalization layer (applied AFTER masking)
normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Hidden-state guard: the optimizer must hold the parameters of the *current*
# `model`. If the model cell was re-run after the optimizer was created, the
# optimizer would silently keep updating the old instance and this model
# would never learn.
model_param_ids = {id(p) for p in model.parameters()}
optimizer_param_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
if not optimizer_param_ids <= model_param_ids:
    raise RuntimeError(
        "`optimizer` is bound to a different model instance; "
        "re-run the Hyperparameters cell after (re)creating `model`."
    )

print("Starting Training...")

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    counted_steps = 0  # batches that actually contributed a loss value

    for i, (images, masks, labels) in enumerate(train_loader):
        # With mean reduction, a batch in which every label equals
        # `ignore_index` yields a NaN loss, and back-propagating it would
        # poison the weights. Skip such batches.
        if not (labels != criterion.ignore_index).any():
            continue

        images, masks, labels = images.to(device), masks.to(device), labels.to(device)

        # --- THE FUSION STEP ---
        # 1. Black out the background. The mask (Batch, H, W) gets a channel
        #    axis and broadcasts over the 3 image channels, so no explicit
        #    `repeat` copy is needed.
        masked_images = images * masks.unsqueeze(1)

        # 2. Normalize
        model_inputs = normalizer(masked_images)

        # --- TRAINING ---
        optimizer.zero_grad()
        outputs = model(model_inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        counted_steps += 1

        if i % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}], Loss: {loss.item():.4f}")

    # Report the epoch average so the accumulated loss is actually used.
    if counted_steps:
        print(f"Epoch [{epoch+1}/{num_epochs}] finished, Average Loss: {running_loss / counted_steps:.4f}")

Starting Training...
Epoch [1/5], Step [0], Loss: 1.9576
Epoch [1/5], Step [10], Loss: 1.4313


KeyboardInterrupt: 

# Test