In [None]:
# STEP 0 — Setup: clone STONE and install dependencies (run in Colab)
# If you already installed, you can skip cloning.
!git clone https://github.com/ChenglongMa/SkinToneClassifier.git --depth 1
%cd SkinToneClassifier
!pip install .[all] torch torchvision opencv-python-headless tqdm scikit-learn matplotlib colorutils --quiet

In [None]:
# Create the workspace layout under /content/data (idempotent thanks to exist_ok=True)
import os

base = '/content/data'
workspace_subdirs = ('raw_faces', 'skin_patches', 'dataset', 'model', 'feedback')
for sub in workspace_subdirs:
    target = os.path.join(base, sub)
    os.makedirs(target, exist_ok=True)
print('✅ Project structure created:', os.listdir(base))

Upload 20–100 portrait images (`.jpg`, `.jpeg`, or `.png`; other files are skipped) into `/content/data/raw_faces/`. In Colab, use the upload button in the Files panel of the left sidebar, or mount Google Drive and copy the files there.

In [None]:
# STEP 1 — Data extraction using STONE
# For each portrait in raw_faces/, run STONE, save one image per detected face into
# skin_patches/, and record {img, tone, hex} in /content/tone_dataset_raw.json
# (read by STEP 2).
import os, json, cv2
from glob import glob
from tqdm.auto import tqdm
# STONE's process function is available after installing the package
from stone.api import process

input_dir = '/content/data/raw_faces/'
output_dir = '/content/data/skin_patches/'
raw_json_path = '/content/tone_dataset_raw.json'
image_exts = ('.jpg', '.jpeg', '.png')
os.makedirs(output_dir, exist_ok=True)
patches = []  # one record per saved face image

for img_path in tqdm(sorted(glob(os.path.join(input_dir, '*'))), desc='extracting'):
    if not img_path.lower().endswith(image_exts):
        continue
    try:
        # NOTE(review): confirm the palette keyword name against the installed STONE
        # version's `process` signature; a TypeError here is reported per image below.
        # Report images are only returned on request and live in result['report_images'],
        # keyed by face_id — not inside each face dict.
        result = process(img_path, image_type='color', palette='perla', return_report_image=True)
        report_images = result.get('report_images') or {}
        name = os.path.splitext(os.path.basename(img_path))[0]
        for i, face in enumerate(result.get('faces', [])):
            tone = face.get('tone_label') or face.get('tone') or 'unknown'
            hexcol = face.get('skin_tone')
            face_id = face.get('face_id', i)
            crop = report_images.get(face_id) if isinstance(report_images, dict) else None
            if crop is None:
                crop = face.get('report_image')  # alternative result layout
            if crop is None:
                continue
            # NOTE(review): if the report image contains STONE's rendered palette/label
            # overlay, a classifier can learn the overlay instead of skin — inspect a few
            # saved files before training.
            save_path = os.path.join(output_dir, f'{name}_{i}_{tone}.jpg')
            # NOTE(review): the channel order of STONE's report image is assumed RGB here —
            # confirm. cvtColor only raises for non-3-channel input, in which case the
            # array is written unchanged.
            try:
                to_write = cv2.cvtColor(crop, cv2.COLOR_RGB2BGR)
            except Exception:
                to_write = crop
            # imwrite signals failure by returning False rather than raising.
            if not cv2.imwrite(save_path, to_write):
                print('⚠️ could not write', save_path)
                continue
            patches.append({'img': save_path, 'tone': tone, 'hex': hexcol})
    except Exception as e:
        print('⚠️', img_path, e)

with open(raw_json_path, 'w') as f:
    json.dump(patches, f, indent=2)
print('✅ Extracted', len(patches), 'skin patches to', output_dir)

## STEP 2 — Clean & re-label to Monk Skin Tone (MST) 1–10
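STONE's tone labels depend on the palette used, so we map them onto the 10-level Monk Skin Tone (MST) scale. The `tone_map` below is only an example: replace its keys with the actual `tone` values recorded in `/content/tone_dataset_raw.json`. After mapping, manually inspect the images in `/content/data/dataset/` and move any mislabeled samples to the correct class folder.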

In [None]:
# Map STONE tone labels to MST 1-10 (example mapping; adjust after inspection).
# The keys must match the `tone` values recorded in /content/tone_dataset_raw.json.
# Labels missing from the map are reported and skipped: lumping them into one
# default class would silently corrupt the training labels.
import json, os, shutil
from collections import Counter

tone_map = {
    'tone_A': 'tone_1', 'tone_B': 'tone_2', 'tone_C': 'tone_3', 'tone_D': 'tone_4',
    'tone_E': 'tone_5', 'tone_F': 'tone_6', 'tone_G': 'tone_7', 'tone_H': 'tone_8',
    'tone_I': 'tone_9', 'tone_J': 'tone_10'
}
raw_json_path = '/content/tone_dataset_raw.json'
dataset_dir = '/content/data/dataset'

if os.path.exists(raw_json_path):
    with open(raw_json_path) as f:
        data = json.load(f)
else:
    print('⚠️ No extraction output at', raw_json_path, '— run STEP 1 first')
    data = []

copied = Counter()    # MST class -> number of images copied
unmapped = Counter()  # STONE label -> number of samples skipped
for d in data:
    t = d.get('tone', '')
    mapped = tone_map.get(t)
    if mapped is None:
        unmapped[t] += 1
        continue
    dest = os.path.join(dataset_dir, mapped)
    os.makedirs(dest, exist_ok=True)
    try:
        shutil.copy(d['img'], dest)
        copied[mapped] += 1
    except Exception as e:
        print('copy error', d.get('img'), e)

if unmapped:
    print('⚠️ Skipped samples whose labels are not in tone_map:', dict(unmapped))
print('Samples per MST class:', dict(sorted(copied.items())))
print('✅ Dataset grouped by MST scale under /content/data/dataset')

## STEP 3 — Train EfficientNet-B0 classifier
This trains a classifier on `/content/data/dataset/`, holding out a random 15% of the images for validation. For a real project, also keep a separate test set, train for more epochs, balance the classes, and tune the augmentation.

In [None]:
# Training loop (simplified)
# Fine-tunes an ImageNet-pretrained EfficientNet-B0 on the MST class folders in
# /content/data/dataset, saves the checkpoint with the best validation accuracy,
# and reloads that checkpoint into `model` for the prediction cell below.
import os
import torch
from torchvision import datasets, transforms, models
from torch import nn, optim
from torch.utils.data import DataLoader, Subset

data_dir = '/content/data/dataset'
model_path = '/content/data/model/skin_tone_cnn.pt'
num_epochs = 10
val_fraction = 0.15
seed = 42
torch.manual_seed(seed)  # reproducible head init, shuffling and augmentation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): ColorJitter changes brightness/saturation, i.e. the very signal being
# classified; consider removing it or keeping it very mild for skin-tone labels.
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3,[0.5]*3)
])
# Validation images must not be augmented, otherwise val_acc is noisy and biased.
transform_val = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3,[0.5]*3)
])

no_data_msg = f'No training data found in {data_dir} — please run extraction and relabeling first'
try:
    # ImageFolder raises FileNotFoundError (missing dir, no class folders, no image
    # files) instead of returning an empty dataset.
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    # Same files in the same order, but with the deterministic eval transform.
    val_source = datasets.ImageFolder(data_dir, transform=transform_val)
except FileNotFoundError as e:
    raise SystemExit(f'{no_data_msg} ({e})')
if len(dataset) == 0:
    raise SystemExit(no_data_msg)

# Seeded train/val split over indices. Keep at least one validation image whenever
# possible so the "best" checkpoint is selected on held-out data.
n_samples = len(dataset)
val_size = max(1, int(val_fraction * n_samples)) if n_samples > 1 else 0
train_size = n_samples - val_size
perm = torch.randperm(n_samples, generator=torch.Generator().manual_seed(seed)).tolist()
train_ds = Subset(dataset, perm[val_size:])
val_ds = Subset(val_source, perm[:val_size])
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)
classes = dataset.classes
print('Classes:', classes, f'| train={train_size} val={val_size}')

# `pretrained=` is deprecated since torchvision 0.13; prefer the weights enum when present.
weights_enum = getattr(models, 'EfficientNet_B0_Weights', None)
if weights_enum is not None:
    model = models.efficientnet_b0(weights=weights_enum.DEFAULT)
else:
    model = models.efficientnet_b0(pretrained=True)
# EfficientNet-B0 classifier head location may differ between torchvision versions; adjust if needed
num_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_features, len(classes))
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

os.makedirs(os.path.dirname(model_path), exist_ok=True)
# Start below any achievable accuracy so the first epoch always writes a checkpoint
# (with a tiny split, acc can stay 0.0 and `acc > 0.0` would never save).
best_acc = -1.0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
    epoch_loss = running_loss / max(len(train_ds), 1)
    # validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
    acc = correct / max(total, 1)
    print(f'Epoch {epoch+1}/{num_epochs} loss={epoch_loss:.4f} val_acc={acc:.4f}')
    # With no validation data there is nothing to select on: keep the latest weights.
    if acc > best_acc or total == 0:
        best_acc = acc
        torch.save(model.state_dict(), model_path)
        print('Saved best model')

# Leave the best (not the last-epoch) weights in `model` for the prediction cell.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

## STEP 4 — Tone → Fashion Palette Mapping
Save a JSON mapping from tone label → recommended color palette. You can expand this into a more advanced rule engine later.

In [None]:
import json
# Tone label -> recommended colour names. Keys must match the MST class-folder
# names created in STEP 2 ('tone_1' … 'tone_10'); edit the colours as needed.
palette_path = '/content/fashion_palette.json'
fashion_palette = {
    'tone_1':['navy','silver','lavender'],
    'tone_2':['skyblue','gray','rose'],
    'tone_3':['tan','denim','burgundy'],
    'tone_4':['olive','coral','beige'],
    'tone_5':['chocolate','cream','gold'],
    'tone_6':['emerald','taupe','rust'],
    'tone_7':['sand','peach','khaki'],
    'tone_8':['maroon','bronze','ivory'],
    'tone_9':['darkolive','teal','white'],
    'tone_10':['black','gold','red']
}
# Close (and therefore flush) the file before a later cell reads it back.
with open(palette_path, 'w') as f:
    json.dump(fashion_palette, f, indent=2)
print('✅ Palette mapping saved to /content/fashion_palette.json')

In [None]:
# STEP: Predict example and fetch palette
# Relies on `model`, `device`, `transforms` and `dataset` from the training cell.
from PIL import Image
import torch, json
from glob import glob

model.eval()
transform_eval = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3,[0.5]*3)
])
classes = dataset.classes

def predict_tone(img_path):
    """Return the predicted class name (an MST class-folder name) for one image file.

    Preprocessing is deterministic: resize + normalize, no augmentation.
    """
    with Image.open(img_path) as im:  # release the file handle promptly
        img = im.convert('RGB')
    x = transform_eval(img).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(x).argmax(1).item()
    return classes[pred]

# Sorted so the example image is the same on every run. These patches are also the
# source of the training set (STEP 2 copies them), so this is a smoke test, not an
# accuracy measurement.
test_files = sorted(glob('/content/data/skin_patches/*.jpg'))
if test_files:
    test_img = test_files[0]
    tone_pred = predict_tone(test_img)
    with open('/content/fashion_palette.json') as f:
        palette = json.load(f)
    print('Predicted:', tone_pred)
    print('Recommended colors:', palette.get(tone_pred, ['black','white']))
else:
    print('No skin patches found in /content/data/skin_patches')

## STEP 5 — Feedback & Retraining
Record feedback on each prediction (whether the predicted tone was correct) in a feedback JSON file. Periodically merge the corrected samples into the training data and fine-tune the model.

In [None]:
import json, os
feedback_path = '/content/data/feedback/feedback.json'

def add_feedback(img, tone, correct=True, path=None):
    """Append one feedback record to the JSON feedback history.

    Args:
        img: path of the image the prediction was made on.
        tone: the predicted tone label the feedback refers to.
        correct: whether the prediction was judged correct.
        path: feedback file to use; defaults to the module-level ``feedback_path``.
    """
    path = feedback_path if path is None else path
    fb = []
    if os.path.exists(path):
        with open(path) as f:
            fb = json.load(f)
    fb.append({'img': img, 'tone': tone, 'correct': correct})
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    # Write to a temp file and swap it in atomically, so an interrupted write cannot
    # leave a truncated (unparseable) feedback history behind.
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(fb, f, indent=2)
    os.replace(tmp_path, path)
    print('✅ Feedback stored')

# Example usage (mark incorrect prediction)
# add_feedback(test_img, tone_pred, correct=False)

## Next-level improvements
- Lighting invariance: add gray-world or Shades-of-Gray normalization before cropping/training.
- Undertone detection: add a small LAB-based head (a/b channels) as a multitask label.
- Bias reduction: ensure balanced samples per MST label; use oversampling/augmentation.
- Serving: export model via TorchScript or ONNX and serve with FastAPI.

---
You can download this notebook as `.ipynb` and open it directly in Colab.