In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, ConcatDataset
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import warnings

# Suppress every Python warning for the rest of the session so notebook output stays compact.
warnings.filterwarnings(action="ignore")
# Run on the first CUDA device when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ResNet18 Model
def get_resnet18_model(num_classes=2):
    """Build a randomly initialised ResNet-18 classifier on the module-level ``device``.

    Args:
        num_classes: Number of output logits of the replacement ``fc`` head.
            Defaults to 2; the evaluation code in this file uses binary
            metrics and ``probs[:, 1]``, so it assumes two classes.

    Returns:
        ``torchvision.models.resnet18(weights=None)``, i.e. without pretrained
        weights, with its final linear layer resized to ``num_classes``
        outputs and moved to ``device``.

    Raises:
        ValueError: If ``num_classes`` is smaller than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    model = models.resnet18(weights=None)
    # Keep the backbone's feature width and swap only the classification head.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

# Training
def train_model(data_dirs, save_path, *, epochs=10, batch_size=32, lr=1e-4):
    """Train a fresh ResNet-18 on the union of several ImageFolder roots and save it.

    Images are resized to 224x224 and augmented with random horizontal flips
    and random rotations of up to 10 degrees. They are then normalised from
    [0, 1] to [-1, 1] per channel. A new model from ``get_resnet18_model()``
    is trained with Adam and cross-entropy, and the mean per-batch loss is
    printed after every epoch.

    Args:
        data_dirs: Iterable of ``torchvision.datasets.ImageFolder`` root
            directories. All of them must produce the same class-to-index
            mapping, so that a label index means the same class everywhere.
        save_path: Destination file for ``model.state_dict()``. Missing
            parent directories are created.
        epochs: Number of passes over the combined dataset (default 10).
        batch_size: Mini-batch size (default 32).
        lr: Adam learning rate (default 1e-4).

    Returns:
        The trained model. It stays on ``device`` and is left in training mode.

    Raises:
        ValueError: If ``data_dirs`` is empty or the directories disagree on
            their class-to-index mapping.
    """
    data_dirs = list(data_dirs)  # may be iterated twice below
    if not data_dirs:
        raise ValueError("train_model needs at least one data directory")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    datasets_list = [datasets.ImageFolder(path, transform=transform) for path in data_dirs]

    # ConcatDataset just chains samples. Each ImageFolder derives label
    # indices from its own sorted sub-folder names, so they only agree
    # across roots when the class folders are identical.
    reference = datasets_list[0].class_to_idx
    for path, ds in zip(data_dirs, datasets_list):
        if ds.class_to_idx != reference:
            raise ValueError(
                f"Class mapping of {path!r} ({ds.class_to_idx}) differs from "
                f"{data_dirs[0]!r} ({reference})"
            )

    train_dataset = ConcatDataset(datasets_list)
    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = get_resnet18_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

    # torch.save does not create directories; make sure the target folder exists.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    return model

# === Evaluation with Softmax Thresholding ===
def evaluate_open_set(model, data_dir, thresholds=np.arange(0.5, 0.96, 0.05)):
    """Evaluate ``model`` on an ImageFolder with max-softmax confidence rejection.

    For each threshold, samples whose highest softmax probability is below
    the threshold are rejected. Accuracy, precision, recall and F1 are
    computed on the accepted samples only. The threshold with the highest F1
    is kept; on ties the earliest threshold in ``thresholds`` wins.

    NOTE(review): the threshold is chosen on the same data it is reported on.
    F1 on the accepted subset also tends to rise as hard samples are rejected,
    so the reported metrics are optimistic. Consider selecting the threshold
    on a held-out split.

    Args:
        model: Classifier returning logits of shape (batch, C). The metrics use
            sklearn's binary scores and ``probs[:, 1]`` as the positive-class
            score, so C == 2 is assumed.
        data_dir: ``torchvision.datasets.ImageFolder`` root directory.
        thresholds: Iterable of confidence thresholds to try (the default
            array is never mutated).

    Returns:
        dict with the following keys:

        - 'Accuracy', 'Precision', 'Recall', 'F1 Score': metrics on the
          samples accepted at the best threshold.
        - 'AUROC': threshold-independent, computed on all samples; NaN if
          only one class is present.
        - 'Rejected Unknowns': number of samples rejected at the best threshold.
        - 'Best Threshold': the best threshold rounded to 2 decimals.

        If no threshold accepts any sample, the four metrics are NaN,
        'Rejected Unknowns' is the dataset size and 'Best Threshold' is None.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    label_batches, prob_batches = [], []
    with torch.no_grad():
        for x, y in loader:
            probs = torch.softmax(model(x.to(device)), dim=1)
            label_batches.append(y.cpu().numpy())
            prob_batches.append(probs.cpu().numpy())

    y_true = np.concatenate(label_batches)
    all_probs = np.concatenate(prob_batches)

    # None of these depend on the threshold, so compute them once up front.
    conf = all_probs.max(axis=1)
    pred = all_probs.argmax(axis=1)
    auroc = (roc_auc_score(y_true, all_probs[:, 1])
             if len(np.unique(y_true)) > 1 else float('nan'))

    nan = float('nan')
    best_f1, best_threshold = -1, None
    # Fallback used when no threshold accepts any sample, so callers always
    # find every key.
    best_metrics = {
        'Accuracy': nan,
        'Precision': nan,
        'Recall': nan,
        'F1 Score': nan,
        'AUROC': auroc,
        'Rejected Unknowns': len(y_true)
    }

    for threshold in thresholds:
        accepted = conf >= threshold
        y_eval = y_true[accepted]
        if y_eval.size == 0:
            continue
        p_eval = pred[accepted]

        f1 = f1_score(y_eval, p_eval, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = round(threshold, 2)
            best_metrics = {
                'Accuracy': accuracy_score(y_eval, p_eval),
                'Precision': precision_score(y_eval, p_eval, zero_division=0),
                'Recall': recall_score(y_eval, p_eval, zero_division=0),
                'F1 Score': f1,
                'AUROC': auroc,
                'Rejected Unknowns': len(y_true) - len(y_eval)
            }

    best_metrics['Best Threshold'] = best_threshold
    return best_metrics

# t-SNE Visualization
def plot_tsne(model, data_dir, save_path):
    """Save a 2-D t-SNE scatter plot of ResNet pooled features for ``data_dir``.

    Every image in the ImageFolder at ``data_dir`` goes through the ResNet
    backbone: conv1, bn1, relu, maxpool, layer1-4 and avgpool, i.e. everything
    except ``fc``. The flattened pooled activations are embedded with t-SNE
    and drawn as a scatter plot coloured by ImageFolder class index.

    Args:
        model: A torchvision ResNet; its stem/layer/avgpool attributes are
            called directly.
        data_dir: ``torchvision.datasets.ImageFolder`` root directory.
        save_path: Image file to write. Missing parent directories are created.

    Nothing is written when fewer than two samples are available, because
    t-SNE needs at least two points. The perplexity is capped at
    ``n_samples - 1``, since sklearn requires it to be below the sample count.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    features, labels = [], []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            # Manual ResNet forward pass that stops before the fc head.
            feat = model.avgpool(model.layer4(model.layer3(model.layer2(model.layer1(
                model.maxpool(model.relu(model.bn1(model.conv1(x))))
            )))))
            feat = torch.flatten(feat, 1)
            features.append(feat.cpu().numpy())
            labels.extend(y.numpy())

    if not features:
        return
    features = np.concatenate(features)
    n_samples = features.shape[0]
    if n_samples < 2:
        print(f"Skipping t-SNE for {data_dir}: need at least 2 samples, got {n_samples}")
        return

    # sklearn's default perplexity is 30.0; smaller datasets need a smaller value.
    tsne = TSNE(
        n_components=2,
        init='pca',
        learning_rate='auto',
        perplexity=min(30.0, n_samples - 1),
        random_state=42,
    ).fit_transform(features)

    # os.makedirs('') raises, so only create a directory when the path has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fig = plt.figure(figsize=(8, 6))
    try:
        scatter = plt.scatter(tsne[:, 0], tsne[:, 1], c=labels, cmap='coolwarm', alpha=0.7)
        plt.legend(*scatter.legend_elements(), title="Classes")
        plt.title(f"t-SNE: {data_dir}")
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

# Open Set Evaluation
def evaluate_one_race(model, race_folder, out_dir="Plots/ResNet18"):
    """Evaluate ``model`` on one race folder, save a t-SNE plot and append a CSV row.

    Runs ``evaluate_open_set`` and ``plot_tsne`` on ``race_folder``. One row of
    metrics is appended to ``<out_dir>/ResNet18_Multi.csv``; the file is
    created if missing, and ``out_dir`` itself is created too.

    Args:
        model: Trained classifier to evaluate.
        race_folder: ImageFolder root. Its final path component, with any
            trailing separator ignored, is stored in the 'Race' column.
        out_dir: Directory for the t-SNE image and the results CSV
            (default "Plots/ResNet18").

    Any metric missing from the returned dict is recorded as NaN instead of
    raising KeyError.

    NOTE(review): rows are always appended, so re-running the script
    accumulates duplicate rows in the CSV.
    """
    race = os.path.basename(os.path.normpath(race_folder))
    print(f"\nEvaluating on: {race_folder}")
    metrics = evaluate_open_set(model, race_folder)

    # Create the output folder here rather than relying on plot_tsne doing it.
    os.makedirs(out_dir, exist_ok=True)
    plot_tsne(model, race_folder, os.path.join(out_dir, f"tsne_ResNet18_Multi_{race}.png"))

    metric_keys = ('Accuracy', 'Precision', 'Recall', 'F1 Score',
                   'AUROC', 'Best Threshold', 'Rejected Unknowns')
    flat_metrics = {'Race': race}
    for key in metric_keys:
        flat_metrics[key] = metrics.get(key, float('nan'))

    csv_path = os.path.join(out_dir, "ResNet18_Multi.csv")
    new_row = pd.DataFrame([flat_metrics])
    if os.path.exists(csv_path):
        df = pd.concat([pd.read_csv(csv_path), new_row], ignore_index=True)
    else:
        df = new_row

    df.to_csv(csv_path, index=False)

# Main
if __name__ == "__main__":
    # Leave-one-race-out protocol: train on five race folders, evaluate on the sixth.
    all_races = ["White", "Black", "Indian", "East_Asian", "Southeast_Asian", "Latino_Hispanic"]
    # Models are saved under Trained_Models/ResNet18/ and results are written to
    # Plots/ResNet18/. Create those exact directories, not just their parents,
    # so torch.save and to_csv do not fail on a fresh checkout.
    os.makedirs("Trained_Models/ResNet18", exist_ok=True)
    os.makedirs("Plots/ResNet18", exist_ok=True)

    for left_out_race in all_races:
        print(f"\n Training excluding: {left_out_race}")
        train_dirs = [f"{race}_augmented" for race in all_races if race != left_out_race]
        model_path = f"Trained_Models/ResNet18/ResNet18_Multi_{left_out_race}.pth"
        model = train_model(train_dirs, model_path)

        test_dir = f"{left_out_race}_augmented"
        evaluate_one_race(model, test_dir)

        # Free this fold's model, and cached GPU memory, before training the next one.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


 Training excluding: White
Epoch 1: Loss = 0.6453
Epoch 2: Loss = 0.4811
Epoch 3: Loss = 0.3466
Epoch 4: Loss = 0.2862
Epoch 5: Loss = 0.2181
Epoch 6: Loss = 0.1801
Epoch 7: Loss = 0.1631
Epoch 8: Loss = 0.1607
Epoch 9: Loss = 0.1041
Epoch 10: Loss = 0.1137

Evaluating on: White_augmented

 Training excluding: Black
Epoch 1: Loss = 0.5899
Epoch 2: Loss = 0.4421
Epoch 3: Loss = 0.3349
Epoch 4: Loss = 0.2480
Epoch 5: Loss = 0.2246
Epoch 6: Loss = 0.1762
Epoch 7: Loss = 0.1477
Epoch 8: Loss = 0.1450
Epoch 9: Loss = 0.1132
Epoch 10: Loss = 0.1006

Evaluating on: Black_augmented

 Training excluding: Indian
Epoch 1: Loss = 0.5726
Epoch 2: Loss = 0.4068
Epoch 3: Loss = 0.3032
Epoch 4: Loss = 0.2322
Epoch 5: Loss = 0.2058
Epoch 6: Loss = 0.1653
Epoch 7: Loss = 0.1424
Epoch 8: Loss = 0.1021
Epoch 9: Loss = 0.1114
Epoch 10: Loss = 0.0804

Evaluating on: Indian_augmented

 Training excluding: East_Asian
Epoch 1: Loss = 0.5731
Epoch 2: Loss = 0.4533
Epoch 3: Loss = 0.3345
Epoch 4: Loss = 0.2630
