In [1]:
import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import warnings

# Suppress every Python warning for the rest of the session (keeps notebook output short,
# but also hides deprecation and numerical warnings from the libraries used below).
warnings.filterwarnings("ignore")
# Preferred compute device: the CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EfficientNetV2 Model
def get_efficientnetv2_model(num_classes=2, weights=None):
    """Build an EfficientNetV2-S network with a freshly initialised output layer.

    Args:
        num_classes: Number of logits produced by the replacement final
            linear layer. Defaults to 2, the value previously hard-coded.
        weights: Forwarded unchanged to ``models.efficientnet_v2_s``.
            ``None`` (the default) requests no pretrained weights.

    Returns:
        The model with ``classifier[1]`` replaced by a new ``nn.Linear``
        that keeps the original input width and emits ``num_classes`` logits.
    """
    model = models.efficientnet_v2_s(weights=weights)
    # Keep the backbone's feature width; only the number of outputs changes.
    original_head = model.classifier[1]
    model.classifier[1] = nn.Linear(original_head.in_features, num_classes)
    return model

# Training
def train_model(data_dir, save_path, epochs=10, batch_size=16, lr=1e-4):
    """Train an EfficientNetV2-S classifier on an image folder and save its weights.

    Args:
        data_dir: Root directory read with ``datasets.ImageFolder``
            (one sub-directory per class).
        save_path: File path the trained ``state_dict`` is written to. Its
            parent directory is created when the path contains one.
        epochs: Number of passes over the dataset. Defaults to 10.
        batch_size: Mini-batch size for the training loader. Defaults to 16.
        lr: Learning rate for the Adam optimizer. Defaults to 1e-4.

    Returns:
        The trained model, left on whichever device it was trained on.

    Raises:
        RuntimeError: Any runtime error raised while moving the model to the
            preferred device or probing it, other than an out-of-memory error
            (which triggers a fallback to the CPU instead).
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = get_efficientnetv2_model()
    try:
        model = model.to(device)
        # Probe with a single dummy image to surface device problems early.
        # eval() + no_grad() keeps the random probe from updating BatchNorm
        # running statistics and from building an autograd graph.
        model.eval()
        with torch.no_grad():
            _ = model(torch.randn(1, 3, 224, 224).to(device))
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            # Move the parameters off the GPU first so the cache flush below
            # can actually release the memory they occupied.
            model = model.to("cpu")
            gc.collect()
            torch.cuda.empty_cache()
        else:
            raise

    # The model may have fallen back to the CPU, so ask it where it lives
    # instead of trusting the module-level `device`.
    actual_device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x, y in loader:
            x, y = x.to(actual_device), y.to(actual_device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Mean of the per-batch losses for this epoch.
        print(f"Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

    # os.makedirs("") raises, so only create a directory when the path has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    return model

# Open-Set Evaluation
def evaluate_open_set(model, data_dir, threshold=0.8):
    """Score a classifier on an image folder, rejecting low-confidence predictions.

    A sample whose maximum softmax probability is not ``>= threshold`` is
    marked with the label ``-1`` and excluded from accuracy, precision, recall
    and F1. AUROC is computed over every sample, using the softmax probability
    of class index 1 as the score.

    Args:
        model: Classifier to evaluate; it is switched to eval mode and run on
            the device its parameters already live on.
        data_dir: Root directory read with ``datasets.ImageFolder``.
        threshold: Minimum confidence for a prediction to be accepted.

    Returns:
        A dict with the keys ``'Accuracy'``, ``'Precision'``, ``'Recall'``,
        ``'F1 Score'``, ``'AUROC'`` and ``'Rejected Unknowns'`` (the number of
        rejected samples). When every sample is rejected, accuracy is NaN and
        precision/recall/F1 are 0.0; AUROC is NaN when only one true class is
        present.
    """
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    eval_loader = DataLoader(
        datasets.ImageFolder(data_dir, transform=eval_transform),
        batch_size=32,
        shuffle=False,
    )

    model_device = next(model.parameters()).device
    model.eval()

    targets, decisions, positive_scores = [], [], []
    with torch.no_grad():
        for images, batch_targets in eval_loader:
            probabilities = torch.softmax(model(images.to(model_device)), dim=1)
            confidence, predicted = probabilities.max(dim=1)
            # Anything that fails the `>= threshold` test becomes the reject label -1.
            rejected = ~(confidence >= threshold)
            predicted = predicted.masked_fill(rejected, -1)

            targets.extend(batch_targets.tolist())
            positive_scores.extend(probabilities[:, 1].tolist())
            decisions.extend(predicted.tolist())

    # Keep only (target, prediction) pairs whose prediction was accepted.
    accepted = [(t, d) for t, d in zip(targets, decisions) if d != -1]
    accepted_targets = [t for t, _ in accepted]
    accepted_preds = [d for _, d in accepted]

    if accepted:
        accuracy = accuracy_score(accepted_targets, accepted_preds)
        precision = precision_score(accepted_targets, accepted_preds, zero_division=0)
        recall = recall_score(accepted_targets, accepted_preds, zero_division=0)
        f1 = f1_score(accepted_targets, accepted_preds, zero_division=0)
    else:
        accuracy, precision, recall, f1 = float('nan'), 0.0, 0.0, 0.0

    # AUROC is undefined unless both classes occur among the true labels.
    if len(set(targets)) > 1:
        auroc = roc_auc_score(targets, positive_scores)
    else:
        auroc = float('nan')

    results = {}
    results['Accuracy'] = accuracy
    results['Precision'] = precision
    results['Recall'] = recall
    results['F1 Score'] = f1
    results['AUROC'] = auroc
    results['Rejected Unknowns'] = len(targets) - len(accepted_targets)
    return results

# t-SNE Visualization 
def plot_tsne(model, data_dir, save_path):
    """Save a 2-D t-SNE scatter plot of the model's pooled features for a folder.

    Features are taken from ``model.features`` followed by ``model.avgpool``
    and flattened to one vector per image; points are coloured by the
    ``ImageFolder`` class index.

    Args:
        model: Network exposing ``features`` and ``avgpool`` sub-modules; it is
            switched to eval mode and run on the device its parameters live on.
        data_dir: Root directory read with ``datasets.ImageFolder``; also used
            in the plot title.
        save_path: Output image path. Its parent directory is created when the
            path contains one.

    Raises:
        ValueError: If the folder yields fewer than two images, since no
            embedding can be computed from them.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    actual_device = next(model.parameters()).device
    model.eval()
    features, labels = [], []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(actual_device)
            # Stop before the classifier head: backbone -> global pool -> flat vector.
            feat = model.features(x)
            feat = model.avgpool(feat)
            feat = torch.flatten(feat, 1)
            features.append(feat.cpu().numpy())
            labels.extend(y.numpy())

    features = np.concatenate(features)
    n_samples = features.shape[0]
    if n_samples < 2:
        raise ValueError(
            f"t-SNE needs at least 2 samples, got {n_samples} from {data_dir!r}"
        )
    # t-SNE requires perplexity < n_samples; keep the library default of 30
    # for normal-sized folders and shrink it only for very small ones.
    perplexity = min(30.0, float(n_samples - 1))
    tsne = TSNE(
        n_components=2,
        init='pca',
        learning_rate='auto',
        perplexity=perplexity,
        random_state=42,
    ).fit_transform(features)

    fig = plt.figure(figsize=(8, 6))
    try:
        scatter = plt.scatter(tsne[:, 0], tsne[:, 1], c=labels, cmap='coolwarm', alpha=0.7)
        plt.legend(*scatter.legend_elements(), title="Classes")
        plt.title(f"t-SNE: {data_dir}")
        # os.makedirs("") raises, so only create a directory when the path has one.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

# Evaluation
def evaluate_one_race(model, race_folder, threshold=0.8):
    """Evaluate the model on one folder, save its t-SNE plot and log its metrics.

    The metrics returned by ``evaluate_open_set`` are appended as one row,
    tagged with the folder name in a ``Race`` column, to
    ``Plots/EfficientNetV2/EfficientNetV2_Single.csv``. The file is created
    if it does not exist yet, and existing rows are kept.

    Args:
        model: Trained classifier passed to the evaluation and plotting helpers.
        race_folder: Path of the image folder to evaluate. A trailing path
            separator is tolerated.
        threshold: Confidence threshold forwarded to ``evaluate_open_set``.
            Defaults to 0.8.
    """
    print(f"\nEvaluating on: {race_folder}")
    # basename() of a path ending in a separator is "", which would produce an
    # empty label and a plot named "tsne_Single_.png"; normalise the path first.
    race_name = os.path.basename(os.path.normpath(race_folder))

    metrics = evaluate_open_set(model, race_folder, threshold=threshold)
    plot_tsne(model, race_folder, save_path=f"Plots/EfficientNetV2/tsne_Single_{race_name}.png")

    df = pd.DataFrame([metrics])
    df["Race"] = race_name
    csv_path = "Plots/EfficientNetV2/EfficientNetV2_Single.csv"

    # Append to earlier results rather than overwriting them.
    if os.path.exists(csv_path):
        df_existing = pd.read_csv(csv_path)
        df = pd.concat([df_existing, df], ignore_index=True)

    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    df.to_csv(csv_path, index=False)

# Main
if __name__ == "__main__":
    # Folders the trained model is evaluated on, in this order.
    target_races = [
        "Black_augmented",
        "Indian_augmented",
        "East_Asian_augmented",
        "Southeast_Asian_augmented",
        "Latino_Hispanic_augmented"
    ]

    # Train once on a single source folder and save the weights.
    train_dir = "White_augmented"
    model_path = "Trained_Models/EfficientNetV2/EfficientNetV2_Single_White.pth"
    model = train_model(train_dir, model_path)

    # Reuse the same trained model for every target folder.
    for race in target_races:
        evaluate_one_race(model, race)

Epoch 1: Loss = 0.6812
Epoch 2: Loss = 0.6158
Epoch 3: Loss = 0.5787
Epoch 4: Loss = 0.5528
Epoch 5: Loss = 0.5245
Epoch 6: Loss = 0.4899
Epoch 7: Loss = 0.4540
Epoch 8: Loss = 0.4337
Epoch 9: Loss = 0.3991
Epoch 10: Loss = 0.3764

Evaluating on: Black_augmented

Evaluating on: Indian_augmented

Evaluating on: East_Asian_augmented

Evaluating on: Southeast_Asian_augmented

Evaluating on: Latino_Hispanic_augmented
