In [82]:
# Step 1: Filter valid images with corresponding masks and shape labels
import os

# Paths to image, mask, and shape label files
base_path = "../DeepFashionData"
img_dir = os.path.join(base_path, "images")
segm_dir = os.path.join(base_path, "segm")
shape_path = os.path.join(base_path, "labels", "shape", "shape_anno_all.txt")

# Load shape labels
def load_shape_labels(shape_path):
    labels = {}
    with open(shape_path, "r") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 2:
                img_name = parts[0].split("/")[-1]
                labels[img_name] = int(parts[1])
    return labels

shape_labels = load_shape_labels(shape_path)

# Get valid images (those with masks and shape labels)
valid_images = []
for img_name in os.listdir(img_dir):
    if img_name.endswith(".jpg"):  # Check for jpg files
        mask_name = img_name.replace(".jpg", "_segm.png")
        if os.path.exists(os.path.join(segm_dir, mask_name)) and img_name in shape_labels:
            valid_images.append(img_name)

# Check first 5 valid images
valid_images[:5]

['MEN-Denim-id_00000080-01_7_additional.jpg',
 'MEN-Denim-id_00000089-01_7_additional.jpg',
 'MEN-Denim-id_00000089-02_7_additional.jpg',
 'MEN-Denim-id_00000089-03_7_additional.jpg',
 'MEN-Denim-id_00000089-04_7_additional.jpg']

In [84]:
# Step 2: Define Dataset Class
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class ClothingSegmDataset(Dataset):
    def __init__(self, img_dir, segm_dir, shape_labels, valid_images, transform=None):
        self.img_dir = img_dir
        self.segm_dir = segm_dir
        self.shape_labels = shape_labels
        self.valid_images = valid_images
        self.transform = transform

    def __len__(self):
        return len(self.valid_images)

    def __getitem__(self, idx):
        img_name = self.valid_images[idx]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.segm_dir, img_name.replace(".jpg", "_segm.png"))

        # Load image and mask
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        # Apply transformations if any
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # Get shape label
        shape_label = self.shape_labels.get(img_name, -1)

        return image, mask, torch.tensor(shape_label)

# Example transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create dataset
dataset = ClothingSegmDataset(img_dir, segm_dir, shape_labels, valid_images, transform)


In [86]:
# Step 3: Define the Model Class
import torch.nn as nn
import torchvision.models as models
import segmentation_models_pytorch as smp

class ShapeClassificationModel(nn.Module):
    def __init__(self, num_classes):
        super(ShapeClassificationModel, self).__init__()
        self.segmentation = smp.Unet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1  # Segmentation output (binary mask)
        )
        self.classification_base = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.classification_base.fc = nn.Identity()  # Removing the final classification layer

        # Add custom head for shape classification
        self.shape_head = nn.Linear(512, num_classes)

    def forward(self, x):
        seg_out = self.segmentation(x)  # Get segmentation output
        features = self.classification_base(x)  # Extract features using ResNet18
        shape_out = self.shape_head(features)  # Shape classification output
        return seg_out, shape_out


In [90]:
# Step 4: Training the Model
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch.optim as optim
from torch.utils.data import DataLoader

# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Split the dataset: 90% train, 10% test
train_idx, test_idx = train_test_split(range(len(dataset)), test_size=0.1, random_state=42)
train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, test_idx)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

# Initialize model and training tools
model = ShapeClassificationModel(num_classes_list=[6, 5, 5, 3, 2, 2, 3, 2, 3, 6, 2, 2]).to(device)
criterion_seg = nn.BCEWithLogitsLoss()
criterion_cls = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_cls_loss = 0
    total_seg_loss = 0

    print(f"\nEpoch {epoch+1}/{num_epochs}")
    for batch_idx, (images, masks, labels) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device).float()
        labels = labels.to(device)

        optimizer.zero_grad()
        seg_out, shape_out = model(images)

        loss_seg = criterion_seg(seg_out, masks)
        loss_cls = criterion_cls(shape_out, labels)
        loss = loss_seg + loss_cls

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_cls_loss += loss_cls.item()
        total_seg_loss += loss_seg.item()

        # Print every 10 batches
        if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(train_loader):
            print(f"  Batch {batch_idx+1}/{len(train_loader)} - Total Loss: {loss.item():.4f}, "
                  f"Cls Loss: {loss_cls.item():.4f}, Seg Loss: {loss_seg.item():.4f}")

    avg_total = total_loss / len(train_loader)
    avg_cls = total_cls_loss / len(train_loader)
    avg_seg = total_seg_loss / len(train_loader)
    print(f"Epoch {epoch+1} Summary - Avg Total Loss: {avg_total:.4f}, "
          f"Avg Cls Loss: {avg_cls:.4f}, Avg Seg Loss: {avg_seg:.4f}")



Epoch 1/5
  Batch 10/1428 - Total Loss: 1.7612, Cls Loss: 1.2492, Seg Loss: 0.5121
  Batch 20/1428 - Total Loss: 1.2787, Cls Loss: 0.8454, Seg Loss: 0.4333
  Batch 30/1428 - Total Loss: 0.6584, Cls Loss: 0.2950, Seg Loss: 0.3633
  Batch 40/1428 - Total Loss: 0.8535, Cls Loss: 0.5296, Seg Loss: 0.3239
  Batch 50/1428 - Total Loss: 0.9618, Cls Loss: 0.6377, Seg Loss: 0.3240
  Batch 60/1428 - Total Loss: 0.6712, Cls Loss: 0.3729, Seg Loss: 0.2982
  Batch 70/1428 - Total Loss: 0.9424, Cls Loss: 0.6668, Seg Loss: 0.2756
  Batch 80/1428 - Total Loss: 1.0082, Cls Loss: 0.7366, Seg Loss: 0.2716
  Batch 90/1428 - Total Loss: 0.8628, Cls Loss: 0.6160, Seg Loss: 0.2468
  Batch 100/1428 - Total Loss: 0.7117, Cls Loss: 0.4785, Seg Loss: 0.2332
  Batch 110/1428 - Total Loss: 0.5179, Cls Loss: 0.2741, Seg Loss: 0.2437
  Batch 120/1428 - Total Loss: 1.0751, Cls Loss: 0.8505, Seg Loss: 0.2246
  Batch 130/1428 - Total Loss: 0.5253, Cls Loss: 0.3073, Seg Loss: 0.2181
  Batch 140/1428 - Total Loss: 0.480

In [92]:
# Save the trained model
torch.save(model.state_dict(), "combine_model.pth")
print("Model saved as combine_model.pth")


Model saved as combine_model.pth


In [98]:
import matplotlib.pyplot as plt
import torch
from torchvision.transforms.functional import to_pil_image
import random

# Load the model
model = ShapeClassificationModel(num_classes_list=[6, 5, 5, 3, 2, 2, 3, 2, 3, 6, 2, 2]).to(device)
model.load_state_dict(torch.load("combine_model.pth"))
model.eval()

# Define label name mappings for each of the 12 shape attributes
label_names_all = [
    ["sleeveless", "short", "elbow", "3/4", "long", "NA"],
    ["mini", "knee", "mid", "ankle", "NA"],
    ["no socks", "ankle", "knee", "NA", "NA"],  # adjust if fewer classes
    ["no hat", "hat", "NA"],
    ["no glasses", "glasses"],
    ["no neckwear", "neckwear"],
    ["none", "left", "right"],
    ["no ring", "ring"],
    ["none", "belt", "strap"],
    ["round", "v", "collar", "square", "halter", "NA"],
    ["no outerwear", "outerwear"],
    ["navel covered", "navel exposed"]
]

# Select a random test sample
idx = random.randint(0, len(test_dataset) - 1)
image, mask, labels = test_dataset[idx]

# Prepare input
input_image = image.unsqueeze(0).to(device)

with torch.no_grad():
    pred_mask, pred_labels = model(input_image)
    pred_mask = torch.sigmoid(pred_mask).squeeze().cpu()
    predicted_labels = [out.argmax(dim=1).item() for out in pred_labels]

# Visualization
image_vis = to_pil_image(image.cpu())
true_mask_vis = to_pil_image(mask)
pred_mask_vis = to_pil_image((pred_mask > 0.5).float())

# Print predicted shape labels
print("Predicted Shape Attributes:")
for i, pred in enumerate(predicted_labels):
    print(f"{i+1}. {label_names_all[i][pred]}")

# Plot image and masks
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.title("Input Image")
plt.imshow(image_vis)
plt.axis("off")

plt.subplot(1, 3, 2)
plt.title("Ground Truth Mask")
plt.imshow(true_mask_vis, cmap="gray")
plt.axis("off")

plt.subplot(1, 3, 3)
plt.title("Predicted Mask")
plt.imshow(pred_mask_vis, cmap="gray")
plt.axis("off")

plt.tight_layout()
plt.show()


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)