In [67]:
import torch
from torchvision import transforms, models
from PIL import Image
import os
import numpy as np
import torch.nn as nn
from scipy.special import rel_entr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

# Relative path of the FairFace checkpoint file.
model_path = "models/res34_fair_align_multi_4_20190809.pt"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label names; predicted class indices are looked up in these lists.
GENDER_CLASSES = ['Male', 'Female']
RACE_CLASSES = [
    'White', 'Black', 'Latino_Hispanic',
    'East Asian', 'Southeast Asian',
    'Indian', 'Middle Eastern',
]

# Preprocessing applied to every PIL image before it is fed to the classifier:
# resize to 224x224, convert to a tensor, then normalise each channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])



In [68]:
def resnet34(num_classes=18, pretrained=True):
    """Build a torchvision ResNet-34 with a replaced classification head.

    Args:
        num_classes: Number of outputs of the new final ``fc`` layer.
        pretrained: Forwarded unchanged to ``models.resnet34``.

    Returns:
        The ResNet-34 module whose ``fc`` is a fresh ``nn.Linear`` mapping the
        backbone's feature width to ``num_classes``.
    """
    backbone = models.resnet34(pretrained=pretrained)
    feature_width = backbone.fc.in_features
    backbone.fc = nn.Linear(feature_width, num_classes)
    return backbone

def load_fairface_model(weight_path='res34_fair_align_multi_4_20190809.pt'):
    """Load the 18-output ResNet-34 classifier from a checkpoint file.

    Args:
        weight_path: Path to a state-dict checkpoint compatible with
            ``resnet34(num_classes=18)``.

    Returns:
        The model with the checkpoint loaded, moved to the module-level
        ``device`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``weight_path`` does not exist.
        RuntimeError: If the checkpoint keys/shapes do not match the model
            (``load_state_dict`` is strict by default).
    """
    # The strict state-dict load below replaces every parameter and buffer, so
    # requesting ImageNet weights first would only trigger a pointless download
    # (and fail when offline).
    model = resnet34(num_classes=18, pretrained=False)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model

def predict_image(model, image_path):
    """Predict the gender and race labels for a single image file.

    Args:
        model: Classifier returning one row of logits per input image.
        image_path: Path of the image to classify.

    Returns:
        A ``(gender, race)`` tuple of strings taken from ``GENDER_CLASSES``
        and ``RACE_CLASSES``.
    """
    # Context manager so the underlying file handle is released promptly.
    with Image.open(image_path) as img:
        image_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image_tensor)[0]

    # FairFace reference layout for the 18 outputs: race first (0:7), then
    # gender (7:9), then age (9:18). Sizing the slices from the label lists
    # guarantees the argmax can never index past the end of either list
    # (the previous `outputs[:, 2:]` slice spanned 16 logits for 7 race names).
    # NOTE(review): the checkpoint name contains "multi_4", which looks like the
    # 4-race FairFace model -- confirm it matches the 7-entry RACE_CLASSES.
    race_end = len(RACE_CLASSES)
    gender_end = race_end + len(GENDER_CLASSES)
    race_pred = logits[:race_end].argmax().item()
    gender_pred = logits[race_end:gender_end].argmax().item()

    return GENDER_CLASSES[gender_pred], RACE_CLASSES[race_pred]

def evaluate_folder(model, folder_path):
    """Tally predicted gender and race labels over the images in a folder.

    Only directory entries whose lower-cased name ends in ``.jpg``, ``.png``
    or ``.jpeg`` are classified; everything else is skipped.

    Args:
        model: Classifier passed through to ``predict_image``.
        folder_path: Directory to scan (not recursive).

    Returns:
        A ``(gender_counts, race_counts)`` tuple of dicts mapping every label
        in ``GENDER_CLASSES`` / ``RACE_CLASSES`` to its number of predictions.
    """
    image_suffixes = ('.jpg', '.png', '.jpeg')
    gender_counts = dict.fromkeys(GENDER_CLASSES, 0)
    race_counts = dict.fromkeys(RACE_CLASSES, 0)

    for entry in os.listdir(folder_path):
        if not entry.lower().endswith(image_suffixes):
            continue
        full_path = os.path.join(folder_path, entry)
        gender, race = predict_image(model, full_path)
        gender_counts[gender] += 1
        race_counts[race] += 1

    return gender_counts, race_counts

def kl_divergence(pred_dist, ref_dist):
    """Compute the KL divergence KL(p || q) between two count/weight vectors.

    Both inputs are normalised to sum to 1 before the divergence
    ``sum(p * log(p / q))`` is evaluated with ``scipy.special.rel_entr``.

    Args:
        pred_dist: Non-negative counts or weights of the predicted
            distribution (p).
        ref_dist: Non-negative counts or weights of the reference
            distribution (q), same length as ``pred_dist``.

    Returns:
        The divergence as a NumPy float. It is ``inf`` when some entry has
        p > 0 but q == 0.

    Raises:
        ValueError: If the inputs differ in length, are empty, contain
            negative entries, or either one sums to zero (which would
            otherwise silently produce NaN).
    """
    p = np.asarray(pred_dist, dtype=float).ravel()
    q = np.asarray(ref_dist, dtype=float).ravel()

    if p.shape != q.shape:
        raise ValueError(
            f"pred_dist and ref_dist must have the same length, got {p.size} and {q.size}"
        )
    if p.size == 0:
        raise ValueError("distributions must not be empty")
    if (p < 0).any() or (q < 0).any():
        raise ValueError("distributions must not contain negative entries")

    p_total = p.sum()
    q_total = q.sum()
    if p_total <= 0 or q_total <= 0:
        raise ValueError("each distribution must have a positive total")

    return rel_entr(p / p_total, q / q_total).sum()

# Unified classification for a single image
import torch.nn.functional as F

def classify_demographics(image_path, model, device):
    """Classify one image and return labels plus per-class probabilities.

    Args:
        image_path: Path of the image to classify.
        model: Classifier returning one row of logits per input image.
        device: Torch device the input tensor is moved to before inference.

    Returns:
        A tuple ``(gender, race, gender_probs, race_probs)`` where the labels
        are strings from ``GENDER_CLASSES`` / ``RACE_CLASSES`` and the
        probabilities are plain lists (softmax over the corresponding logits),
        ordered like those class lists.
    """
    # Context manager so the underlying file handle is released promptly.
    with Image.open(image_path) as img:
        input_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)[0]

    # FairFace reference layout for the 18 outputs: race first (0:7), then
    # gender (7:9), then age (9:18). The previous slices (0:2 for gender,
    # 11:18 for race) read race and age logits respectively.
    # NOTE(review): the checkpoint name contains "multi_4", which looks like the
    # 4-race FairFace model -- confirm it matches the 7-entry RACE_CLASSES.
    race_end = len(RACE_CLASSES)
    gender_end = race_end + len(GENDER_CLASSES)
    race_logits = output[:race_end]
    gender_logits = output[race_end:gender_end]

    gender_probs = F.softmax(gender_logits, dim=0)
    race_probs = F.softmax(race_logits, dim=0)

    gender = GENDER_CLASSES[torch.argmax(gender_probs).item()]
    race = RACE_CLASSES[torch.argmax(race_probs).item()]

    return gender, race, gender_probs.tolist(), race_probs.tolist()

In [69]:
# `device` and `model_path` come from the configuration cell at the top of the
# notebook; reusing them keeps the device choice and checkpoint location
# defined in exactly one place.

# Stable Diffusion v1.5 text-to-image pipeline, moved to the selected device.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)

# FairFace classifier used to label the generated images.
model = load_fairface_model(model_path)

Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  5.75it/s]


In [70]:
# Text prompt intended for `pipe` in the generation step below.
prompt = "a portrait of a smiling woman"
# Generation is currently disabled; uncomment the two lines below to render
# one image with 25 inference steps and save it to disk.
#image = pipe(prompt, num_inference_steps=25).images[0]
#image.save("gen_image.png")

