In [18]:
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms
import os

# Load pretrained BiSeNet and load the model weights separately)
class BiSeNet(torch.nn.Module):
    def __init__(self, n_classes=19):
        super(BiSeNet, self).__init__()
        from model import BiSeNet as BiSeNetModel  # from official face-parsing repo
        self.model = BiSeNetModel(n_classes=n_classes)
    
    def load_weights(self, path):
        self.model.load_state_dict(torch.load("/Users/sneha/Deep_Learn/project/project/final-project/79999_iter.pth", weights_only=False, map_location='cpu'))
        self.model.eval()
    
    def forward(self, x):
        with torch.no_grad():
            out = self.model(x)[0]
            parsing = out.squeeze(0).argmax(0).cpu().numpy()
        return parsing

def preprocess(image_path):
    to_tensor = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet mean/std
                             std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert("RGB")
    img_resized = to_tensor(image).unsqueeze(0)
    original = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    return img_resized, original

def extract_skin(image_rgb, parsing_mask):
    # 1:for parsing skin depending on the BiSeNet label map
    skin_mask = (parsing_mask == 1).astype(np.uint8)
    skin = cv2.bitwise_and(image_rgb, image_rgb, mask=skin_mask)
    return skin, skin_mask

def mean_lab_skin(image_rgb, skin_mask):
    image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)
    skin_pixels = image_lab[skin_mask > 0]
    return np.mean(skin_pixels, axis=0)

def compute_stsi(image1_lab, image2_lab):
    return np.linalg.norm(image1_lab - image2_lab)


Code for processing images(extracting skin) and computing mean LAB & STSI

In [17]:
# Load model and weights
net = BiSeNet(n_classes=19)
net.load_weights("/Users/sneha/Deep_Learn/project/project/final-project/79999_iter.pth") 
net.eval()

# Process images
img1_tensor, img1_original = preprocess("/Users/sneha/Deep_Learn/project/project/final-project/UTKFace/UTKFace/1_0_0_1.jpg")
img2_tensor, img2_original = preprocess("/Users/sneha/Deep_Learn/project/project/final-project/UTKFace/UTKFace-result/1_0_0_1.jpg")

# Get parsing maps
parsing1 = net(img1_tensor)
parsing2 = net(img2_tensor)

def extract_skin(image_rgb, parsing_mask):
    """
    image_rgb: OpenCV RGB image (H, W, 3)
    parsing_mask: Face parsing mask output from BiSeNet (H, W)
    """
    # Keep only pixels with label 1 (skin)
    skin_mask = (parsing_mask == 1).astype(np.uint8)

    # Ensure mask size matches image size
    if skin_mask.shape != image_rgb.shape[:2]:
        skin_mask = cv2.resize(skin_mask, (image_rgb.shape[1], image_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)

    # Make sure mask is uint8
    skin_mask = skin_mask.astype(np.uint8)

    # Apply mask to image
    skin = cv2.bitwise_and(image_rgb, image_rgb, mask=skin_mask)
    return skin, skin_mask



# Extract skin regions
skin1, mask1 = extract_skin(img1_original, parsing1)
skin2, mask2 = extract_skin(img2_original, parsing2)

# Compute mean LAB
mean1 = mean_lab_skin(cv2.cvtColor(skin1, cv2.COLOR_BGR2RGB), mask1)
mean2 = mean_lab_skin(cv2.cvtColor(skin2, cv2.COLOR_BGR2RGB), mask2)

# Compute STSI
stsi = compute_stsi(mean1, mean2)

print("Mean LAB (Image 1):", mean1)
print("Mean LAB (Image 2):", mean2)
print("STSI (Skin Tone Shift Index):", stsi)


Mean LAB (Image 1): [186.74291555 135.42496123 135.34491752]
Mean LAB (Image 2): [186.19525702 137.77051708 140.14348383]
STSI (Skin Tone Shift Index): 5.369152700648886


Code for computing STSI based on races(sample of 40)

In [31]:
import os
import cv2
import numpy as np
import torch

# -------- Helper Functions --------
def extract_skin(image_rgb, parsing_mask):
    skin_mask = (parsing_mask == 1).astype(np.uint8)
    if skin_mask.shape != image_rgb.shape[:2]:
        skin_mask = cv2.resize(skin_mask, (image_rgb.shape[1], image_rgb.shape[0]), interpolation=cv2.INTER_NEAREST)
    skin_mask = skin_mask.astype(np.uint8)
    skin = cv2.bitwise_and(image_rgb, image_rgb, mask=skin_mask)
    return skin, skin_mask

def mean_lab_skin(image_rgb, mask):
    lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)
    skin_pixels = lab[mask > 0]
    return np.mean(skin_pixels, axis=0) if skin_pixels.size > 0 else np.array([0, 0, 0])

def compute_stsi(mean1, mean2):
    return np.linalg.norm(mean1 - mean2)

def process_image_pair(original_path, recolored_path, net):
    img1_tensor, img1_original = preprocess(original_path)
    img2_tensor, img2_original = preprocess(recolored_path)

    parsing1 = net(img1_tensor)
    parsing2 = net(img2_tensor)

    skin1, mask1 = extract_skin(img1_original, parsing1)
    skin2, mask2 = extract_skin(img2_original, parsing2)

    mean1 = mean_lab_skin(cv2.cvtColor(skin1, cv2.COLOR_BGR2RGB), mask1)
    mean2 = mean_lab_skin(cv2.cvtColor(skin2, cv2.COLOR_BGR2RGB), mask2)

    return compute_stsi(mean1, mean2), mean1, mean2

# -------- Load Model --------
net = BiSeNet(n_classes=19)
net.load_weights("/Users/sneha/Deep_Learn/project/project/final-project/79999_iter.pth")
net.eval()

# -------- Directory Setup --------
original_dir = "/Users/sneha/Deep_Learn/project/project/final-project/white-face"
recolored_dir = "/Users/sneha/Deep_Learn/project/project/final-project/white-result"

filenames = [f for f in os.listdir(original_dir) if f.endswith(".jpg")][:40]

stsi_values = []

for fname in filenames:
    orig_path = os.path.join(original_dir, fname)
    recol_path = os.path.join(recolored_dir, fname)

    if not os.path.exists(recol_path):
        print(f"Skipping missing: {recol_path}")
        continue

    try:
        stsi, mean1, mean2 = process_image_pair(orig_path, recol_path, net)
        stsi_values.append(stsi)
        print(f"{fname} → STSI: {stsi:.3f}")
    except Exception as e:
        print(f"Error processing {fname}: {e}")

if stsi_values:
    mean_stsi = np.mean(stsi_values)
    print(f"\nMean STSI across {len(stsi_values)} images: {mean_stsi:.3f}")
else:
    print("⚠️ No valid image pairs processed.")


24_0_0_5.jpg → STSI: 10.111
24_0_0_7.jpg → STSI: 2.305
24_0_0_18.jpg → STSI: 3.539
24_0_0_19.jpg → STSI: 23.001
24_0_0_6.jpg → STSI: 8.528
24_0_0_2.jpg → STSI: 4.533
24_0_0_21.jpg → STSI: 10.670
24_0_0_20.jpg → STSI: 22.837
24_0_0_3.jpg → STSI: 7.590
24_0_0_1.jpg → STSI: 2.076
1_0_0_15.jpg → STSI: 5.839
1_0_0_14.jpg → STSI: 14.004
1_0_0_16.jpg → STSI: 5.391
1_0_0_17.jpg → STSI: 6.329
1_0_0_13.jpg → STSI: 11.461
1_0_0_9.jpg → STSI: 15.111
1_0_0_8.jpg → STSI: 17.629
1_0_0_12.jpg → STSI: 7.336
1_0_0_10.jpg → STSI: 5.944
1_0_0_11.jpg → STSI: 7.702
1_0_0_20.jpg → STSI: 7.108
1_0_0_6.jpg → STSI: 15.416
1_0_0_7.jpg → STSI: 23.634
1_0_0_5.jpg → STSI: 27.812
1_0_0_4.jpg → STSI: 11.154
1_0_0_1.jpg → STSI: 5.369
1_0_0_19.jpg → STSI: 7.739
1_0_0_3.jpg → STSI: 2.881
1_0_0_2.jpg → STSI: 1.325
1_0_0_18.jpg → STSI: 20.674
24_0_0_12.jpg → STSI: 8.182
24_0_0_13.jpg → STSI: 8.128
24_0_0_11.jpg → STSI: 4.772
24_0_0_10.jpg → STSI: 15.347
24_0_0_14.jpg → STSI: 0.000
24_0_0_15.jpg → STSI: 2.280
24_0_0_8.jpg 