# 🧠 Human Parsing with Semantic Segmentation

This notebook demonstrates training a semantic segmentation model to perform **human parsing** using the `mattmdjaga/human_parsing_dataset` (based on the ATR dataset). The goal is to segment fine-grained human parts (e.g., hat, hair, clothes, limbs) at the pixel level using a transformer-based architecture like **Segformer** from Hugging Face.

## 📦 Dataset
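- Source: `mattmdjaga/human_parsing_dataset` (17,706 images with mask annotations; this notebook uses only the first 1,000 for a quick demo)
- Classes: 18, where label `0` is `"Background"`; the others include `"Hat"`, `"Hair"`, `"Sunglasses"`, `"Upper-clothes"`, `"Left-arm"`, `"Right-leg"`, `"Bag"`, etc.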




### ***There are several types of segmentation: semantic segmentation, instance segmentation, and panoptic segmentation.*** <br> This notebook covers semantic segmentation.

In [None]:
# install the necessary libraries
!pip install -q datasets transformers evaluate accelerate

In [None]:
# Log in to the Hugging Face Hub (prompts for an access token). The login is
# needed later for `trainer.push_to_hub()`.
from huggingface_hub import notebook_login

notebook_login()

# Semantic Segmentation
#### Load the ***human_parsing_dataset*** dataset

In [None]:

from datasets import load_dataset

# Hub dataset id and the number of leading rows used for this demo. Only a
# subset of the full dataset is loaded to keep training fast; raise N_SAMPLES
# (or drop the slice) for a full run.
# NOTE: the inference section reads rows *after* this subset as unseen images,
# so keep the two in sync if you change N_SAMPLES.
DATASET_ID = "mattmdjaga/human_parsing_dataset"
N_SAMPLES = 1000

ds = load_dataset(DATASET_ID, split=f"train[:{N_SAMPLES}]")


In [None]:
# Show a summary of the loaded subset (features and number of rows).
ds

In [None]:
# Hold out 20% of the loaded rows for evaluation.
# - A fixed seed makes the split, and therefore the reported metrics,
#   reproducible across kernel restarts.
# - Binding the result to a new name keeps this cell idempotent: re-running
#   `ds = ds.train_test_split(...)` would fail, because `ds` would already be a
#   DatasetDict.
splits = ds.train_test_split(test_size=0.2, seed=42)
train_ds = splits["train"]
test_ds = splits["test"]

In [None]:
# Peek at one raw training example: its "image" and per-pixel "mask" columns
# (the columns the transforms below read).
train_ds[0]

# Uncomment to display only the image:
# train_ds[0]["image"]

Create a dictionary that maps each label id to its class name; it is needed later when setting up the model.

In [None]:
# Class names in label-id order: the position of a name in this list is the
# integer id used for that class.
_CLASS_NAMES = [
    "Background", "Hat", "Hair", "Sunglasses", "Upper-clothes", "Skirt",
    "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Face",
    "Left-leg", "Right-leg", "Left-arm", "Right-arm", "Bag", "Scarf",
]

# id -> class name (passed to the model config; used to read predictions).
id2label = dict(enumerate(_CLASS_NAMES))
# class name -> id, the inverse mapping.
label2id = {name: idx for idx, name in id2label.items()}
num_labels = len(id2label)


In [None]:
# # If we have custom dataset, we can use this portion for the code
# from datasets import Dataset, DatasetDict, Image

# image_paths_train = ["path/to/image_1.jpg", "path/to/image_2.jpg", ..., "path/to/image_n.jpg"]
# label_paths_train = ["path/to/annotation_1.png", "path/to/annotation_2.png", ..., "path/to/annotation_n.png"]

# image_paths_validation = [...]
# label_paths_validation = [...]

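# # NOTE: images and labels are paired by sorting both lists, so each image and
# # its annotation must have file names that sort into the same position.
# def create_dataset(image_paths, label_paths):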
#     dataset = Dataset.from_dict({"image": sorted(image_paths),
#                                 "label": sorted(label_paths)})
#     dataset = dataset.cast_column("image", Image())
#     dataset = dataset.cast_column("label", Image())
#     return dataset

# # step 1: create Dataset objects
# train_dataset = create_dataset(image_paths_train, label_paths_train)
# validation_dataset = create_dataset(image_paths_validation, label_paths_validation)

# # step 2: create DatasetDict
# dataset = DatasetDict({
#      "train": train_dataset,
#      "validation": validation_dataset,
#      }
# )

# # step 3: push to Hub (assumes you have run the huggingface-cli login command in a terminal/notebook)
# dataset.push_to_hub("your-name/dataset-repo")

# # optionally, you can push to a private repo on the Hub
# # dataset.push_to_hub("name of repo on the hub", private=True)

# import json
# # simple example
# id2label = {0: 'cat', 1: 'dog'}
# with open('id2label.json', 'w') as fp:
#     json.dump(id2label, fp)

In [None]:
# setting image processor
from transformers import AutoImageProcessor
# do_reduce_labels=True to subtract one from all the labels.
model_name = "nvidia/mit-b0"
image_processor = AutoImageProcessor.from_pretrained(model_name, do_reduce_labels=True)

#### Data augmentation with `ColorJitter` to make the model more robust against overfitting

In [None]:
from torchvision.transforms import ColorJitter

# Random photometric augmentation, applied to training images only (see
# `train_transforms`). It changes colors, not geometry, so the masks need no
# matching transform.
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

In [None]:
def _encode(images, masks, augment):
    """Convert images to RGB, optionally color-jitter them, and encode with the masks."""
    # Force 3-channel input. The processor normalizes with per-channel RGB
    # statistics, so a grayscale or RGBA image could fail or be mis-normalized.
    # Assumes the "image" column is decoded to PIL images (the `datasets`
    # Image feature default).
    rgb_images = [img.convert("RGB") for img in images]
    if augment:
        rgb_images = [jitter(img) for img in rgb_images]
    return image_processor(rgb_images, list(masks))


def train_transforms(example_batch):
    """On-the-fly training transform for `Dataset.set_transform`.

    Applies the random `jitter` color augmentation to every image, then encodes
    the images and their masks with `image_processor`.

    Args:
        example_batch: Batch dict with "image" and "mask" columns (lists).

    Returns:
        The image processor's output (pixel values and labels).
    """
    return _encode(example_batch["image"], example_batch["mask"], augment=True)


def val_transforms(example_batch):
    """On-the-fly evaluation transform: same as `train_transforms` without augmentation.

    Args:
        example_batch: Batch dict with "image" and "mask" columns (lists).

    Returns:
        The image processor's output (pixel values and labels).
    """
    return _encode(example_batch["image"], example_batch["mask"], augment=False)

In [None]:
# Applying transform
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

In [None]:
# Show both splits; the transforms only run when rows are actually accessed.
train_ds, test_ds

#### Set up the evaluation metric

In [None]:
import evaluate

# Mean intersection-over-union metric, used by `compute_metrics` below.
metric = evaluate.load("mean_iou")

The model's logits have a lower spatial resolution than the labels. Before calling `compute`, they must be upsampled to the label size and converted to per-pixel class predictions with `argmax`:

In [None]:
import numpy as np
import torch
from torch import nn


def compute_metrics(eval_pred, chunk_size=16):
    """Compute mean-IoU segmentation metrics for a `Trainer` evaluation pass.

    The logits are upsampled (bilinear) to the label resolution and reduced to
    per-pixel class ids with `argmax`. The result is then scored with `metric`
    (mean IoU, loaded in an earlier cell) over `num_labels` classes, ignoring
    pixels labelled 255.

    Args:
        eval_pred: `(logits, labels)` pair supplied by the Trainer. Assumed
            (TODO confirm) to be NumPy arrays shaped `(N, num_labels, h, w)`
            and `(N, H, W)` respectively.
        chunk_size: Number of samples upsampled per step; must be positive.
            Upsampling the whole eval set at once would materialize an
            `(N, num_labels, H, W)` float tensor at full label resolution,
            which can exhaust host memory. Chunking bounds that peak.

    Returns:
        The metric's result dict. NumPy array values (the per-category
        results) are converted to lists so the Trainer can log them.
    """
    logits, labels = eval_pred
    target_size = tuple(labels.shape[-2:])

    pred_chunks = []
    with torch.no_grad():
        for start in range(0, len(logits), chunk_size):
            # float32: CPU bilinear interpolation may not support half precision.
            chunk = torch.from_numpy(
                np.asarray(logits[start:start + chunk_size], dtype=np.float32)
            )
            upsampled = nn.functional.interpolate(
                chunk, size=target_size, mode="bilinear", align_corners=False
            )
            pred_chunks.append(upsampled.argmax(dim=1).numpy())

    if pred_chunks:
        pred_labels = np.concatenate(pred_chunks, axis=0)
    else:  # empty evaluation set
        pred_labels = np.empty((0, *target_size), dtype=np.int64)

    metrics = metric.compute(
        predictions=pred_labels,
        references=labels,
        num_labels=num_labels,
        ignore_index=255,
        reduce_labels=False,
    )
    return {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in metrics.items()
    }

# Training

In [None]:
# Define the semantic-segmentation model. Passing `id2label`/`label2id` sets the
# number of output classes and the class names stored in the model config.
from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer

model = AutoModelForSemanticSegmentation.from_pretrained(model_name, id2label=id2label, label2id=label2id)

In [None]:
# Setting Training Arguments
from transformers import TrainingArguments


training_args = TrainingArguments(
    output_dir="my-human_parsing-model",
    run_name="segment-human_parsing-1",
    learning_rate=6e-5,
    num_train_epochs=5,
    per_device_train_batch_size=5,
    per_device_eval_batch_size=5,
    save_total_limit=2,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,  # <- Log after every step
    logging_dir="./logs",  # <- Optional but useful
    remove_unused_columns=False,
    push_to_hub=False,  # <- Disable for now
    report_to="none",  # <- Disable Weights & Biases, TensorBoard etc.
    fp16=True  # <- Enable mixed precision if using GPU
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)



In [None]:
# Start the training
trainer.train()

In [None]:
# Upload the trained model to the Hugging Face Hub (requires the Hub login above).
trainer.push_to_hub()

### Inference

In [None]:
from datasets import load_dataset

# Pick images the model has not seen. Training used `train[:1000]`, so rows
# from index 1000 onward are held out. Keep this in sync with the training
# slice if that changes.
# A new name is used so `ds`/`test_ds` from training are not overwritten.
inference_ds = load_dataset("mattmdjaga/human_parsing_dataset", split="train[1000:1010]")

# Convert to RGB so the processor and the overlay below get 3-channel input.
image = inference_ds[0]["image"].convert("RGB")
image

In [None]:
from accelerate.test_utils.testing import get_backend
# Only the first element of the returned tuple is used: the device to run on.
device, _, _ = get_backend()
# Encode the single image into model-ready `pixel_values` (a batch of one).
encoding = image_processor(image, return_tensors="pt")
pixel_values = encoding.pixel_values.to(device)

In [None]:
# Run the model on the encoded image and keep the logits on the CPU.
# - `eval()` disables dropout-style layers, which may still be active after
#   training.
# - `no_grad()` skips building the autograd graph, which inference does not need.
import torch

model.to(device)
model.eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
logits = outputs.logits.cpu()

In [None]:
# rescale the logits to the original image size:
upsampled_logits = nn.functional.interpolate(
    logits,
    size=image.size[::-1],
    mode="bilinear",
    align_corners=False,
)

pred_seg = upsampled_logits.argmax(dim=1)[0]

To visualize the results, define a color palette, `ade_palette()` (borrowed from the ADE20K dataset), that maps each class id to an RGB color.

In [None]:
def ade_palette():
  """Return the ADE20K color palette as an integer NumPy array of shape (151, 3).

  Row ``i`` is the RGB color used to draw class id ``i``. Only the first 18
  rows (one per class in ``id2label``) are needed for this dataset.
  """
  return np.asarray([
      [0, 0, 0],
      [120, 120, 120],
      [180, 120, 120],
      [6, 230, 230],
      [80, 50, 50],
      [4, 200, 3],
      [120, 120, 80],
      [140, 140, 140],
      [204, 5, 255],
      [230, 230, 230],
      [4, 250, 7],
      [224, 5, 255],
      [235, 255, 7],
      [150, 5, 61],
      [120, 120, 70],
      [8, 255, 51],
      [255, 6, 82],
      [143, 255, 140],
      [204, 255, 4],
      [255, 51, 7],
      [204, 70, 3],
      [0, 102, 200],
      [61, 230, 250],
      [255, 6, 51],
      [11, 102, 255],
      [255, 7, 71],
      [255, 9, 224],
      [9, 7, 230],
      [220, 220, 220],
      [255, 9, 92],
      [112, 9, 255],
      [8, 255, 214],
      [7, 255, 224],
      [255, 184, 6],
      [10, 255, 71],
      [255, 41, 10],
      [7, 255, 255],
      [224, 255, 8],
      [102, 8, 255],
      [255, 61, 6],
      [255, 194, 7],
      [255, 122, 8],
      [0, 255, 20],
      [255, 8, 41],
      [255, 5, 153],
      [6, 51, 255],
      [235, 12, 255],
      [160, 150, 20],
      [0, 163, 255],
      [140, 140, 140],
      [250, 10, 15],
      [20, 255, 0],
      [31, 255, 0],
      [255, 31, 0],
      [255, 224, 0],
      [153, 255, 0],
      [0, 0, 255],
      [255, 71, 0],
      [0, 235, 255],
      [0, 173, 255],
      [31, 0, 255],
      [11, 200, 200],
      [255, 82, 0],
      [0, 255, 245],
      [0, 61, 255],
      [0, 255, 112],
      [0, 255, 133],
      [255, 0, 0],
      [255, 163, 0],
      [255, 102, 0],
      [194, 255, 0],
      [0, 143, 255],
      [51, 255, 0],
      [0, 82, 255],
      [0, 255, 41],
      [0, 255, 173],
      [10, 0, 255],
      [173, 255, 0],
      [0, 255, 153],
      [255, 92, 0],
      [255, 0, 255],
      [255, 0, 245],
      [255, 0, 102],
      [255, 173, 0],
      [255, 0, 20],
      [255, 184, 184],
      [0, 31, 255],
      [0, 255, 61],
      [0, 71, 255],
      [255, 0, 204],
      [0, 255, 194],
      [0, 255, 82],
      [0, 10, 255],
      [0, 112, 255],
      [51, 0, 255],
      [0, 194, 255],
      [0, 122, 255],
      [0, 255, 163],
      [255, 153, 0],
      [0, 255, 10],
      [255, 112, 0],
      [143, 255, 0],
      [82, 0, 255],
      [163, 255, 0],
      [255, 235, 0],
      [8, 184, 170],
      [133, 0, 255],
      [0, 255, 92],
      [184, 0, 255],
      [255, 0, 31],
      [0, 184, 255],
      [0, 214, 255],
      [255, 0, 112],
      [92, 255, 0],
      [0, 224, 255],
      [112, 224, 255],
      [70, 184, 160],
      [163, 0, 255],
      [153, 0, 255],
      [71, 255, 0],
      [255, 0, 163],
      [255, 204, 0],
      [255, 0, 143],
      [0, 255, 235],
      [133, 255, 0],
      [255, 0, 235],
      [245, 0, 255],
      [255, 0, 122],
      [255, 245, 0],
      [10, 190, 212],
      [214, 255, 0],
      [0, 204, 255],
      [20, 0, 255],
      [255, 255, 0],
      [0, 153, 255],
      [0, 41, 255],
      [0, 255, 204],
      [41, 0, 255],
      [41, 255, 0],
      [173, 0, 255],
      [0, 245, 255],
      [71, 0, 255],
      [122, 0, 255],
      [0, 255, 184],
      [0, 92, 255],
      [184, 255, 0],
      [0, 133, 255],
      [255, 214, 0],
      [25, 194, 194],
      [102, 255, 0],
      [92, 0, 255],
  ])

In [None]:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np

# Map each pixel's predicted class id to its palette color with a single
# indexing lookup (equivalent to looping over the classes).
palette = np.asarray(ade_palette(), dtype=np.uint8)
pred_ids = pred_seg.cpu().numpy()
color_seg = palette[pred_ids]  # (H, W, 3) RGB color map
# No channel flip: matplotlib's `imshow` expects RGB, and both the palette and
# the image are RGB. Flipping to BGR would only scramble the class colors.

# Blend the image and the color map 50/50. Converting the image to RGB
# guarantees 3 channels to match `color_seg`.
image_rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
img = (image_rgb * 0.5 + color_seg * 0.5).astype(np.uint8)

# Legend entries for the classes that actually appear in the prediction.
legend_handles = [
    mpatches.Patch(color=palette[class_id] / 255.0, label=id2label.get(int(class_id), str(class_id)))
    for class_id in np.unique(pred_ids)
]

fig, ax = plt.subplots(figsize=(15, 10))
ax.imshow(img)
ax.set_title("Predicted human-parsing segmentation (overlay)")
ax.axis("off")
ax.legend(handles=legend_handles, loc="upper left", bbox_to_anchor=(1.01, 1.0))
plt.show()