In [None]:
import glob
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.nn as pyg_nn
import torch_geometric.transforms as T
import torch_geometric.utils as pyg_utils
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch_geometric.data import Data
from torchvision import models, transforms




In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


Using device: cuda


In [None]:
# Dataset locations, given relative to the notebook's working directory.
# Directory holding the images; annotation image names are joined onto it.
IMAGE_DIR = "../data/images/"
# Directory holding segmentation maps named "<image stem>_segm.png".
SEGM_DIR = "../data/segm/"
# Annotation file, opened as text and parsed one image per line.
# NOTE(review): the path has no file extension — confirm it points at the
# fabric annotation file and not at a directory.
FABRIC_ANN_PATH = "../data/labels/shape/texture"

In [24]:
# Debug leftover: list the contents of the current working directory.
!ls




data			images.zip		 model	    view
DeepFashion-MultiModal	images.zip4thj6ar8.part  README.md
FindYourStyle		labels.zip		 segm.zip


In [25]:
# WARNING: destructive leftover. Recursively deletes a data directory at a
# hard-coded absolute path every time this cell runs.
# NOTE(review): confirm this is not the dataset the cells below read before
# running the notebook top to bottom.
!rm -rf /content/FindYourStyle/data

In [None]:
# Parse the annotation file into {image_name: [label, label, ...]}.
# Each non-empty line is "<image_name> <int> <int> ...", whitespace separated;
# the integers are consumed downstream in the order [upper, lower, outer].
fabric_ann = {}

with open(FABRIC_ANN_PATH, "r") as f:
    for line in f:
        parts = line.split()
        if not parts:
            # Blank or whitespace-only line (e.g. a trailing newline at the
            # end of the file): there is no image name to index, so skip it
            # instead of failing on parts[0].
            continue
        img_name = parts[0]
        fabrics = [int(value) for value in parts[1:]]  # [upper, lower, outer]
        fabric_ann[img_name] = fabrics

print("Loaded fabric annotations:", len(fabric_ann))



Loaded fabric annotations: 44096


In [None]:
def safe_ce_loss(logits, targets, criterion, ignore_index=7):
    """
    Apply ``criterion`` to the rows whose target differs from ``ignore_index``.

    Rows labelled ``ignore_index`` are removed from both ``logits`` and
    ``targets`` before ``criterion`` is called. If no row survives the
    filter, ``None`` is returned instead of a loss value.
    """
    keep = targets != ignore_index
    n_kept = int(keep.sum())
    if n_kept > 0:
        return criterion(logits[keep], targets[keep])
    return None


In [None]:
class DeepFashionMultiFabricDataset(Dataset):
    """
    Images paired with per-garment binary masks and fabric labels.

    Only annotation entries that have a matching ``<stem>_segm.png`` file in
    ``segm_dir`` are kept. Each item is a 7-tuple::

        (image,
         upper_mask, upper_label,
         lower_mask, lower_label,
         outer_mask, outer_label)

    Args:
        img_dir: Directory containing the images named by ``fabric_ann`` keys.
        segm_dir: Directory containing ``<stem>_segm.png`` segmentation maps.
        fabric_ann: Mapping ``image_name -> [upper, lower, outer]`` labels.
        transform: Optional callable applied to the RGB PIL image. When it is
            ``None`` the image is returned as a PIL image.
        mask_size: ``(width, height)`` the masks are resized to. Defaults to
            ``(224, 224)``; it should match the size ``transform`` produces so
            that masks and images line up.
    """

    # Segmentation label ids that make up each garment region.
    UPPER_LABELS = {1, 4, 21}
    LOWER_LABELS = {5, 6}
    OUTER_LABELS = {2}

    def __init__(self, img_dir, segm_dir, fabric_ann, transform=None,
                 mask_size=(224, 224)):
        self.img_dir = img_dir
        self.segm_dir = segm_dir
        self.fabric_ann = fabric_ann
        self.transform = transform
        self.mask_size = tuple(mask_size)

        # Stems of every segmentation map that is present on disk.
        segm_bases = {
            f.replace("_segm.png", "") for f in os.listdir(segm_dir)
            if f.endswith("_segm.png")
        }

        # Keep annotation order; drop images that have no segmentation map.
        self.files = [
            img_name for img_name in fabric_ann.keys()
            if img_name.replace(".jpg", "") in segm_bases
        ]

        print(f"Dataset size: {len(self.files)}")

    def __len__(self):
        """Number of images that have both an annotation and a segmentation map."""
        return len(self.files)

    def _make_mask(self, segm_np, labels):
        """
        Build a float32 mask that is 1.0 where ``segm_np`` is in ``labels``.

        The mask is resized to ``self.mask_size`` with nearest-neighbour
        sampling so it stays binary, then gets a leading channel axis.
        Assumes ``segm_np`` is a 2-D label map — TODO confirm for all files.
        """
        mask = np.isin(segm_np, list(labels)).astype(np.float32)
        mask = Image.fromarray(mask).resize(self.mask_size, Image.NEAREST)
        return torch.tensor(np.array(mask)).unsqueeze(0)

    def __getitem__(self, idx):
        """Load the image, its three garment masks and its three labels."""
        img_name = self.files[idx]

        # Context managers close the underlying files deterministically
        # instead of relying on Pillow's implicit close or garbage collection.
        with Image.open(os.path.join(self.img_dir, img_name)) as raw_image:
            image = raw_image.convert("RGB")

        segm_name = img_name.replace(".jpg", "_segm.png")
        with Image.open(os.path.join(self.segm_dir, segm_name)) as segm:
            segm_np = np.array(segm)

        upper_mask = self._make_mask(segm_np, self.UPPER_LABELS)
        lower_mask = self._make_mask(segm_np, self.LOWER_LABELS)
        outer_mask = self._make_mask(segm_np, self.OUTER_LABELS)

        # Annotation order is [upper, lower, outer].
        labels = self.fabric_ann[img_name]
        upper_label = torch.tensor(labels[0], dtype=torch.long)
        lower_label = torch.tensor(labels[1], dtype=torch.long)
        outer_label = torch.tensor(labels[2], dtype=torch.long)

        if self.transform:
            image = self.transform(image)

        return (
            image,
            upper_mask, upper_label,
            lower_mask, lower_label,
            outer_mask, outer_label
        )



In [None]:
# Image preprocessing: resize to 224x224, then convert to a tensor.
_preprocess_steps = [transforms.Resize((224, 224)), transforms.ToTensor()]
transform = transforms.Compose(_preprocess_steps)


In [None]:
from torch.utils.data import random_split

dataset = DeepFashionMultiFabricDataset(
    img_dir=IMAGE_DIR,
    segm_dir=SEGM_DIR,
    fabric_ann=fabric_ann,
    transform=transform
)

# 80/20 train/validation split; the validation set takes the remainder so
# the two sizes always add up to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same images land in train/val on every run;
# otherwise validation numbers are not comparable between runs.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))



Dataset size: 12701
Train samples: 10160
Val samples: 2541


In [None]:
# Options shared by both loaders; only the dataset and shuffling differ.
_loader_opts = dict(batch_size=32, num_workers=2, pin_memory=True)

# Training batches are reshuffled each epoch; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)


In [None]:
# Pull a single training batch and inspect it as a sanity check.
batch = next(iter(train_loader))
images, upper_masks, upper_labels, lower_masks, lower_labels, outer_masks, outer_labels = batch

# Expected: images (B, 3, 224, 224), each mask (B, 1, 224, 224).
for _tensor in (images, upper_masks, lower_masks, outer_masks):
    print(_tensor.shape)

# First eight labels of each garment head.
print("Upper labels:", upper_labels[:8])
print("Lower labels:", lower_labels[:8])
print("Outer labels:", outer_labels[:8])


In [None]:
class MultiFabricResNet(nn.Module):
    """
    ResNet-50 feature extractor shared by three linear classification heads.

    The backbone's final fully connected layer is replaced by an identity so
    it emits pooled features; ``forward`` returns the logits of the upper,
    lower and outer heads, in that order.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        pretrained = models.ResNet50_Weights.DEFAULT
        self.backbone = models.resnet50(weights=pretrained)

        # Read the feature width before the classifier is swapped out.
        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        self.upper_head = nn.Linear(feat_dim, num_classes)
        self.lower_head = nn.Linear(feat_dim, num_classes)
        self.outer_head = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        """Return ``(upper_logits, lower_logits, outer_logits)`` for ``x``."""
        shared = self.backbone(x)
        heads = (self.upper_head, self.lower_head, self.outer_head)
        return tuple(head(shared) for head in heads)




In [None]:
# Per-channel normalisation applied after masking. These are the statistics
# commonly used with ImageNet-pretrained backbones.
normalizer = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Targets equal to 7 contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=7)

# Build the network on the selected device, then hand all of its parameters
# to the optimiser.
model = MultiFabricResNet()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)



In [None]:
num_epochs = 10
print("Starting multi-fabric training...")

for epoch in range(num_epochs):
    model.train()
    running_loss, num_batches = 0.0, 0

    for i, batch in enumerate(train_loader):
        # Move all seven batch tensors onto the training device.
        (
            images,
            upper_masks, upper_labels,
            lower_masks, lower_labels,
            outer_masks, outer_labels
        ) = (item.to(device) for item in batch)

        # Keep only one garment region per input: the single-channel mask is
        # repeated over the three colour channels, then the result is
        # normalised.
        upper_imgs = normalizer(images * upper_masks.repeat(1, 3, 1, 1))
        lower_imgs = normalizer(images * lower_masks.repeat(1, 3, 1, 1))
        outer_imgs = normalizer(images * outer_masks.repeat(1, 3, 1, 1))

        optimizer.zero_grad()

        # One forward pass per garment; only the matching head's logits are
        # kept from each pass.
        up_out = model(upper_imgs)[0]
        low_out = model(lower_imgs)[1]
        out_out = model(outer_imgs)[2]

        # A head yields None when none of its targets in the batch are valid.
        loss_upper = safe_ce_loss(up_out, upper_labels, criterion)
        loss_lower = safe_ce_loss(low_out, lower_labels, criterion)
        loss_outer = safe_ce_loss(out_out, outer_labels, criterion)
        head_losses = [
            head_loss
            for head_loss in (loss_upper, loss_lower, loss_outer)
            if head_loss is not None
        ]

        # No head had a valid target: nothing to optimise for this batch.
        if not head_losses:
            continue

        # Average over the heads that produced a loss.
        loss = sum(head_losses) / len(head_losses)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

        if i % 20 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}] "
                f"Step [{i}] "
                f"Loss: {loss.item():.4f}"
            )

    # Skipped batches are excluded; max() guards against an all-skipped epoch.
    avg_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} finished. Avg loss: {avg_loss:.4f}")



In [None]:
model.eval()

# Per-head counts of correct predictions and of valid (non-7) targets.
correct = {"upper": 0, "lower": 0, "outer": 0}
total   = {"upper": 0, "lower": 0, "outer": 0}

with torch.no_grad():
    for batch in val_loader:
        # Move all seven batch tensors onto the evaluation device.
        (
            images,
            upper_masks, upper_labels,
            lower_masks, lower_labels,
            outer_masks, outer_labels
        ) = (item.to(device) for item in batch)

        # Same masking and normalisation as during training.
        upper_imgs = normalizer(images * upper_masks.repeat(1, 3, 1, 1))
        lower_imgs = normalizer(images * lower_masks.repeat(1, 3, 1, 1))
        outer_imgs = normalizer(images * outer_masks.repeat(1, 3, 1, 1))

        # One forward pass per garment, keeping only the matching head.
        up_out = model(upper_imgs)[0]
        low_out = model(lower_imgs)[1]
        out_out = model(outer_imgs)[2]

        per_head = (
            ("upper", up_out, upper_labels),
            ("lower", low_out, lower_labels),
            ("outer", out_out, outer_labels),
        )
        for _head, _logits, _labels in per_head:
            # Targets equal to 7 are excluded from both counts.
            valid = _labels != 7
            _hits = _logits.argmax(1) == _labels
            correct[_head] += _hits[valid].sum().item()
            total[_head] += valid.sum().item()


In [None]:
# Report accuracy per head; max() avoids dividing by zero when a head had no
# valid targets in the validation set.
for k, n_correct in correct.items():
    acc = n_correct / max(total[k], 1)
    print(f"{k.capitalize()} garment fabric accuracy: {acc:.3f}")
