In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset
from PIL import Image
from tqdm import tqdm
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import warnings

# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings("ignore")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# SimpleCNN Model
class SimpleCNN(nn.Module):
    """Three-stage convolutional network with a 2-way linear head.

    ``feature_extractor`` maps a 3-channel image batch to a flat 128-value
    vector per image (the final stage uses adaptive average pooling to 1x1,
    so the spatial input size is not fixed by the architecture).
    ``classifier`` turns that vector into two logits.
    """

    def __init__(self):
        super().__init__()
        # Layers are listed in the exact order they are applied; the position
        # in this list is also the index used in the state-dict keys.
        stages = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
        ]
        self.feature_extractor = nn.Sequential(*stages)
        self.classifier = nn.Linear(in_features=128, out_features=2)

    def forward(self, x):
        """Return the two class logits for each image in batch ``x``."""
        return self.classifier(self.feature_extractor(x))

# Training 
def train_model(data_dirs, save_path, *, epochs=10, batch_size=32, lr=1e-4):
    """Train a fresh ``SimpleCNN`` on the union of several image folders.

    Args:
        data_dirs: Iterable of directory paths, each loaded with
            ``datasets.ImageFolder`` and concatenated into one training set.
        save_path: File path the trained ``state_dict`` is written to.
        epochs: Number of passes over the training set (default 10).
        batch_size: Mini-batch size for the training loader (default 32).
        lr: Adam learning rate (default 1e-4).

    Returns:
        The trained model, left on the module-level ``device``.

    NOTE(review): every ``ImageFolder`` derives its own label indices;
    presumably all directories share the same class sub-folder names so the
    labels agree after concatenation - confirm against the data layout.
    """
    # Create the checkpoint directory up front so a missing folder is not
    # discovered only after all epochs have finished. ``dirname`` is empty
    # for a bare file name, which ``os.makedirs`` would reject.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    datasets_list = [datasets.ImageFolder(path, transform=transform) for path in data_dirs]
    train_dataset = ConcatDataset(datasets_list)
    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Reported value is the mean of the per-batch losses for this epoch.
        print(f"Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

    torch.save(model.state_dict(), save_path)
    return model

# Evaluate with Best Threshold
def evaluate_open_set(model, data_dir, thresholds=np.arange(0.5, 0.96, 0.05)):
    """Evaluate ``model`` on an image folder and pick the best confidence threshold.

    Each sample's confidence is its largest softmax probability. For every
    candidate threshold, samples whose confidence falls below it are rejected
    and accuracy / precision / recall / F1 are computed on the remaining
    samples only. The threshold with the highest F1 wins (the first one on
    ties).

    Args:
        model: Network returning per-class logits; it is switched to eval mode.
        data_dir: Directory loaded with ``datasets.ImageFolder``.
        thresholds: Iterable of confidence cut-offs to try.

    Returns:
        Dict with keys ``'Accuracy'``, ``'Precision'``, ``'Recall'``,
        ``'F1 Score'``, ``'AUROC'``, ``'Rejected Unknowns'`` and
        ``'Best Threshold'``. AUROC is computed once over all samples from the
        class-1 probability and is NaN when only one label is present. If no
        threshold keeps any sample, the four threshold-dependent scores are
        NaN, every sample is counted as rejected, and the threshold stays at
        its 0.5 default.

    NOTE: the threshold is selected on the same data the metrics are reported
    on, so the reported scores are optimistic.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    y_true, all_probs = [], []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            probs = torch.softmax(model(x), dim=1)
            y_true.extend(y.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    y_true = np.array(y_true)
    all_probs = np.array(all_probs)
    y_scores = all_probs[:, 1]

    # None of these depend on the threshold, so compute them once instead of
    # on every iteration of the sweep.
    conf = all_probs.max(axis=1)
    top_class = all_probs.argmax(axis=1)
    if np.unique(y_true).size > 1:
        auroc = roc_auc_score(y_true, y_scores)
    else:
        auroc = float('nan')

    nan = float('nan')
    best_f1 = -1
    best_threshold = 0.5
    # Placeholder result so callers always find every key, even when no
    # threshold retains a single sample.
    best_metrics = {
        'Accuracy': nan,
        'Precision': nan,
        'Recall': nan,
        'F1 Score': nan,
        'AUROC': auroc,
        'Rejected Unknowns': len(y_true)
    }

    for threshold in thresholds:
        # Samples at or above the threshold are kept; the rest are rejected.
        known_mask = conf >= threshold
        if not known_mask.any():
            continue

        y_eval = y_true[known_mask]
        p_eval = top_class[known_mask]

        f1 = f1_score(y_eval, p_eval, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = round(float(threshold), 2)
            best_metrics = {
                'Accuracy': accuracy_score(y_eval, p_eval),
                'Precision': precision_score(y_eval, p_eval, zero_division=0),
                'Recall': recall_score(y_eval, p_eval, zero_division=0),
                'F1 Score': f1,
                'AUROC': auroc,
                'Rejected Unknowns': int(len(y_true) - known_mask.sum())
            }

    best_metrics['Best Threshold'] = best_threshold
    return best_metrics

# t-SNE Visualization
def plot_tsne(model, data_dir, save_path):
    """Save a 2-D t-SNE scatter plot of the model's extracted features.

    Args:
        model: Network exposing a ``feature_extractor`` sub-module; it is
            switched to eval mode.
        data_dir: Directory loaded with ``datasets.ImageFolder``; its labels
            colour the points.
        save_path: Image file the figure is written to. Its parent directory
            is created when the path has one.

    Nothing is written when the loader yields no batches.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    features, labels = [], []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            feat = model.feature_extractor(x)
            features.append(feat.cpu().numpy())
            labels.extend(y.numpy())

    if not features:
        return

    features = np.concatenate(features)

    # t-SNE requires perplexity < n_samples; keep scikit-learn's default of
    # 30 for normal-sized folders and shrink it only for very small ones.
    n_samples = features.shape[0]
    perplexity = min(30.0, max(1.0, n_samples - 1))
    tsne = TSNE(
        n_components=2,
        init='pca',
        learning_rate='auto',
        random_state=42,
        perplexity=perplexity,
    ).fit_transform(features)

    fig = plt.figure(figsize=(8, 6))
    try:
        scatter = plt.scatter(tsne[:, 0], tsne[:, 1], c=labels, cmap='coolwarm', alpha=0.7)
        plt.legend(*scatter.legend_elements(), title="Classes")
        plt.title(f"t-SNE: {data_dir}")
        # ``dirname`` is empty for a bare file name, which ``os.makedirs``
        # would reject.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path)
    finally:
        # Release the figure even if saving fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)

# Main
if __name__ == "__main__":
    # Leave-one-group-out experiment: for each entry, train on the other
    # folders, then evaluate and visualise on the held-out folder.
    all_races = ["White", "Black", "Indian", "East_Asian", "Southeast_Asian", "Latino_Hispanic"]
    output_csv = "Plots/SimpleCNN_Multi.csv"
    metric_keys = [
        'Accuracy', 'Precision', 'Recall', 'F1 Score',
        'AUROC', 'Best Threshold', 'Rejected Unknowns',
    ]

    os.makedirs("Plots", exist_ok=True)
    # The checkpoints are written here; create it before any training starts
    # so a missing folder cannot fail the run after the epochs complete.
    os.makedirs("Trained_Models", exist_ok=True)

    for left_out_race in all_races:
        print(f"\n Training excluding: {left_out_race}")
        train_dirs = [f"{race}_augmented" for race in all_races if race != left_out_race]
        model_path = f"Trained_Models/SimpleCNN_Multi_{left_out_race}.pth"
        model = train_model(train_dirs, model_path)

        test_dir = f"{left_out_race}_augmented"
        metrics = evaluate_open_set(model, test_dir)
        plot_tsne(model, test_dir, f"Plots/SimpleCNN/tsne_Multi_{left_out_race}.png")

        # Fall back to NaN for any metric the evaluation did not produce, so
        # one degenerate fold does not abort the remaining folds.
        flat_metrics = {'Race': left_out_race}
        for key in metric_keys:
            flat_metrics[key] = metrics.get(key, float('nan'))

        # Results are appended to whatever the CSV already holds, so rows
        # from earlier runs are kept (re-running adds further rows).
        if os.path.exists(output_csv):
            df_existing = pd.read_csv(output_csv)
            df = pd.concat([df_existing, pd.DataFrame([flat_metrics])], ignore_index=True)
        else:
            df = pd.DataFrame([flat_metrics])
        df.to_csv(output_csv, index=False)


 Training excluding: White
Epoch 1: Loss = 0.6933
Epoch 2: Loss = 0.6910
Epoch 3: Loss = 0.6896
Epoch 4: Loss = 0.6882
Epoch 5: Loss = 0.6857
Epoch 6: Loss = 0.6854
Epoch 7: Loss = 0.6824
Epoch 8: Loss = 0.6792
Epoch 9: Loss = 0.6748
Epoch 10: Loss = 0.6772

 Training excluding: Black
Epoch 1: Loss = 0.6926
Epoch 2: Loss = 0.6894
Epoch 3: Loss = 0.6859
Epoch 4: Loss = 0.6843
Epoch 5: Loss = 0.6783
Epoch 6: Loss = 0.6745
Epoch 7: Loss = 0.6729
Epoch 8: Loss = 0.6631
Epoch 9: Loss = 0.6547
Epoch 10: Loss = 0.6497

 Training excluding: Indian
Epoch 1: Loss = 0.6920
Epoch 2: Loss = 0.6880
Epoch 3: Loss = 0.6833
Epoch 4: Loss = 0.6780
Epoch 5: Loss = 0.6711
Epoch 6: Loss = 0.6630
Epoch 7: Loss = 0.6535
Epoch 8: Loss = 0.6476
Epoch 9: Loss = 0.6405
Epoch 10: Loss = 0.6308

 Training excluding: East_Asian
Epoch 1: Loss = 0.6926
Epoch 2: Loss = 0.6905
Epoch 3: Loss = 0.6889
Epoch 4: Loss = 0.6858
Epoch 5: Loss = 0.6826
Epoch 6: Loss = 0.6812
Epoch 7: Loss = 0.6782
Epoch 8: Loss = 0.6741
Epoch