# Set Up

In [2]:
from torch_geometric.data import Data
from typing import List
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils
import torch_geometric.transforms as T

import sklearn.metrics as metrics
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import glob
from PIL import Image

In [3]:
device = None

# check if MPS (Apple Silicon GPU) is available
if torch.backends.mps.is_available():
    device = torch.device("mps")
    x = torch.ones(1, device=device)
# check if CUDA (NVIDIA GPU) is available
elif torch.cuda.is_available():
    device = torch.device("cuda")
    x = torch.ones(1, device=device)
else:
    device = torch.device("cpu")
    print ("MPS and CUDA device not found.")

# Load Data

In [4]:
IMAGE_DIR = "../data/images/"
SEGM_DIR = "../data/segm/"

In [5]:

def get_corresponding_segm_path(image_path):
    base = os.path.basename(image_path)
    name, ext = os.path.splitext(base)
    segm_name = f'{name}_segm.png'
    return os.path.join(SEGM_DIR, segm_name)

def load_image(image_path):
    return np.array(Image.open(image_path).convert('RGB'))

def load_segm(segm_path):
    return np.array(Image.open(segm_path))

skipped = 0
labels_to_exclude = {0, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23}  # background and unwanted labels
dataset = []
image_paths = glob.glob(os.path.join(IMAGE_DIR, '*'))

for image_path in image_paths:
    
    segm_path = get_corresponding_segm_path(image_path)
    if not os.path.exists(segm_path):
        # print(f'Segmentation file not found for {image_path}, skipping.')
        skipped += 1
        continue
    image = load_image(image_path)
    segm = load_segm(segm_path)
    
    for label in np.unique(segm):
        if label in labels_to_exclude:  # exclude the background and unwanted labels
            continue
        mask = np.where(segm == label, 1, 0).astype(np.uint8)
        dataset.append((image, mask, label))
        
print(f'Total samples in dataset: {len(dataset)}')
print(f'Total skipped images: {skipped}')

: 

In [31]:
print(dataset[0][0].shape, dataset[0][1].shape, dataset[0][2])
print(dataset[20][0].shape, dataset[20][1].shape, dataset[20][2])


(1101, 750, 3) (1101, 750) 1
(1101, 750, 3) (1101, 750) 5


In [6]:
# Fix "Load Data" ----

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os
import numpy as np
from PIL import Image
import torch

class DeepFashionLazyDataset(Dataset):
    def __init__(self, img_dir, segm_dir, transform=None):
        self.img_dir = img_dir
        self.segm_dir = segm_dir
        self.transform = transform
        # We only store filenames, not the actual images (saves RAM)
        self.image_files = [f for f in os.listdir(img_dir) if f.endswith('.jpg')]
        
        # Your friend's filter list
        self.ignore_labels = {0, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23}
        
        # Mapping remaining labels to 0-6 (Top, Outer, Skirt, Dress, Pants, Leggings, Rompers)
        # This makes training much more stable than using gaps like 1, 5, 21...
        self.label_map = {1:0, 2:1, 3:2, 4:3, 5:4, 6:5, 21:6}

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        
        # 1. Load Paths
        img_path = os.path.join(self.img_dir, img_name)
        segm_name = img_name.replace('.jpg', '_segm.png') # Assuming naming convention
        segm_path = os.path.join(self.segm_dir, segm_name)
        
        # 2. Open Images (On demand!)
        image = Image.open(img_path).convert('RGB')
        
        target_label = -1 # Default 'ignore'
        mask_binary = np.zeros((256, 256), dtype=np.float32) # Default empty

        if os.path.exists(segm_path):
            segm = Image.open(segm_path)
            segm_np = np.array(segm)
            
            # Find the biggest clothing item that isn't an accessory
            unique, counts = np.unique(segm_np, return_counts=True)
            best_label = -1
            max_pixels = 0
            
            for label, count in zip(unique, counts):
                if label not in self.ignore_labels and label in self.label_map:
                    if count > max_pixels:
                        max_pixels = count
                        best_label = label
            
            if best_label != -1:
                target_label = self.label_map[best_label]
                # Create the binary mask (1 for item, 0 for everything else)
                # We assume resize happens via transform, so we resize mask to match
                mask_img = Image.fromarray((segm_np == best_label).astype(np.uint8))
                mask_img = mask_img.resize((224, 224), resample=Image.NEAREST)
                mask_binary = np.array(mask_img).astype(np.float32)

        # 3. Transform Image
        if self.transform:
            image = self.transform(image)
            
        return image, torch.tensor(mask_binary), torch.tensor(target_label)

# Setup Data Loaders
# Note: Resize to 224x224 is mandatory for ResNet
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Initialize
dataset = DeepFashionLazyDataset(IMAGE_DIR, SEGM_DIR, transform=transform)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

print(f"Dataset ready. Found {len(dataset)} images.")
print(dataset[0][0].shape, dataset[0][1].shape, dataset[0][2])
print(dataset[20][0].shape, dataset[20][1].shape, dataset[20][2])

Dataset ready. Found 44096 images.
torch.Size([3, 224, 224]) torch.Size([224, 224]) tensor(1)
torch.Size([3, 224, 224]) torch.Size([256, 256]) tensor(-1)


# Model

Hyperparameters

Image Classification Model Class

# Train

# Test