In [1]:
import os
import torch
import pandas as pd
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import warnings

warnings.simplefilter("ignore", category=FutureWarning)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SimpleCNN Model 
class SimpleCNN(nn.Module):
    """Small three-stage convolutional network producing 2 class logits.

    ``feature_extractor`` maps a 3-channel image batch to one 128-dim vector
    per image. The final ``AdaptiveAvgPool2d(1)`` means any input large enough
    to survive the two 2x max-pools works. ``classifier`` is a single linear
    layer from those 128 features to 2 logits.
    """

    def __init__(self):
        super().__init__()
        layers = []
        # Two conv -> ReLU -> 2x max-pool stages, each halving spatial size.
        for in_ch, out_ch in ((3, 32), (32, 64)):
            layers += [nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        # Last conv stage is globally average-pooled and flattened to (N, 128).
        layers += [
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        ]
        # Attribute names (and thus state_dict keys) are relied on elsewhere,
        # e.g. callers that use ``model.feature_extractor`` directly.
        self.feature_extractor = nn.Sequential(*layers)
        self.classifier = nn.Linear(128, 2)

    def forward(self, x):
        """Return raw (un-normalised) class logits of shape (N, 2) for ``x``."""
        return self.classifier(self.feature_extractor(x))

# Training 
def train_model(data_dir, save_path, epochs=10, lr=1e-4, batch_size=32):
    """Train a ``SimpleCNN`` on an ``ImageFolder`` dataset and save its weights.

    Images are resized to 224x224, converted to tensors and normalised with
    mean/std 0.5 per channel (mapping [0, 1] to [-1, 1]). The model is trained
    with Adam and cross-entropy loss, printing the mean per-sample loss after
    every epoch.

    Args:
        data_dir: Root of a torchvision ``ImageFolder`` dataset (one
            sub-directory per class).
        save_path: File the trained ``state_dict`` is written to. Missing
            parent directories are created; a bare filename is written to the
            current working directory.
        epochs: Number of passes over the data (default 10).
        lr: Adam learning rate (default 1e-4).
        batch_size: Mini-batch size (default 32).

    Returns:
        The trained ``SimpleCNN`` instance, left on ``device``.

    Raises:
        ValueError: If ``data_dir`` has more than 2 class folders, since
            ``SimpleCNN`` only emits 2 logits.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    if len(dataset.classes) > 2:
        # Fail early with a clear message instead of an out-of-range target
        # error (or a CUDA device-side assert) deep inside CrossEntropyLoss.
        raise ValueError(
            f"SimpleCNN has 2 outputs but {data_dir!r} contains "
            f"{len(dataset.classes)} classes: {dataset.classes}"
        )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch does not skew the
            # reported epoch average (criterion already averages per batch).
            running_loss += loss.item() * x.size(0)
        print(f"Epoch {epoch+1}: Loss = {running_loss / len(dataset):.4f}")

    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when save_path actually has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    return model

# Evaluate with Best Threshold
def evaluate_model(model, data_dir, thresholds=None):
    """Evaluate ``model`` on an ``ImageFolder`` and pick the best confidence cut-off.

    For each threshold, a prediction is kept only when its max softmax
    probability is ``>= threshold``; the rest are treated as rejected
    "unknowns". Accuracy / precision / recall / F1 (binary, positive label 1)
    are computed on the kept samples only. The threshold with the highest F1
    wins; ties keep the earliest threshold.

    NOTE(review): the threshold is selected on the same data whose metrics are
    reported, so the numbers are optimistically biased. Tune on a held-out
    split if these are meant as unbiased estimates.
    NOTE(review): ImageFolder assigns class indices by sorted folder name, so
    ``data_dir`` presumably uses the same class folder names as the training
    set. Confirm, otherwise labels are silently misaligned.

    Args:
        model: Module returning (N, 2) logits; run in eval mode on ``device``.
        data_dir: Root of an ``ImageFolder`` dataset to evaluate on.
        thresholds: Iterable of confidence cut-offs to try. ``None`` (default)
            means ``np.arange(0.5, 0.96, 0.05)`` (0.50 ... 0.95).

    Returns:
        dict with keys 'Accuracy', 'Precision', 'Recall', 'F1 Score', 'AUROC'
        (computed on all samples from class-1 probability; NaN if only one
        class is present), 'Rejected Unknowns' and 'Best Threshold'. If every
        threshold rejects every sample, only 'Best Threshold' (0.5) is present.
    """
    if thresholds is None:
        # Built per call rather than as a def-time default argument.
        thresholds = np.arange(0.5, 0.96, 0.05)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    label_batches, prob_batches = [], []
    with torch.no_grad():
        for x, y in loader:
            logits = model(x.to(device))
            prob_batches.append(torch.softmax(logits, dim=1).cpu().numpy())
            label_batches.append(y.numpy())

    y_true = np.concatenate(label_batches)
    all_probs = np.concatenate(prob_batches)
    y_scores = all_probs[:, 1]

    # These do not depend on the threshold, so compute them once.
    conf = all_probs.max(axis=1)
    pred = all_probs.argmax(axis=1)
    auroc = roc_auc_score(y_true, y_scores) if len(np.unique(y_true)) > 1 else float('nan')

    best_f1 = -1
    best_threshold = 0.5
    best_metrics = {}

    for threshold in thresholds:
        accepted = conf >= threshold
        if not accepted.any():
            # Everything rejected: no metrics to compute for this cut-off.
            continue

        y_eval = y_true[accepted]
        p_eval = pred[accepted]
        f1 = f1_score(y_eval, p_eval, zero_division=0)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = round(float(threshold), 2)
            best_metrics = {
                'Accuracy': accuracy_score(y_eval, p_eval),
                'Precision': precision_score(y_eval, p_eval, zero_division=0),
                'Recall': recall_score(y_eval, p_eval, zero_division=0),
                'F1 Score': f1,
                'AUROC': auroc,
                'Rejected Unknowns': int(np.count_nonzero(~accepted)),
            }

    best_metrics['Best Threshold'] = best_threshold
    return best_metrics

# t-SNE Visualization
def plot_tsne(model, data_dir, save_path):
    """Save a 2-D t-SNE scatter of ``model.feature_extractor`` outputs.

    Every image in the ``ImageFolder`` at ``data_dir`` is passed through
    ``model.feature_extractor``. The resulting feature vectors are embedded
    with t-SNE (PCA init, fixed ``random_state=42``) and plotted coloured by
    class index.

    Args:
        model: Module exposing a ``feature_extractor`` sub-module.
        data_dir: Root of an ``ImageFolder`` dataset.
        save_path: Image file to write. Missing parent directories are
            created; a bare filename is written to the current directory.

    Nothing is written when the dataset yields fewer than 2 samples, because
    t-SNE cannot embed a single point.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    feature_batches, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            feature_batches.append(model.feature_extractor(x.to(device)).cpu().numpy())
            labels.extend(y.numpy())

    if not feature_batches:
        return
    features = np.concatenate(feature_batches)
    n_samples = features.shape[0]
    if n_samples < 2:
        return

    # scikit-learn requires perplexity < n_samples; keep its default of 30
    # whenever the dataset is large enough, otherwise shrink it.
    perplexity = min(30.0, n_samples - 1)
    embedding = TSNE(
        n_components=2,
        init='pca',
        learning_rate='auto',
        perplexity=perplexity,
        random_state=42,
    ).fit_transform(features)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        scatter = ax.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='coolwarm', alpha=0.7)
        ax.legend(*scatter.legend_elements(), title="Classes")
        ax.set_title(f"t-SNE: {data_dir}")
        # os.makedirs("") raises, so only create a directory if there is one.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)

# Main
if __name__ == "__main__":
    # Train a single detector on the White subset, then measure how it
    # transfers to each of the other subsets listed in target_races.
    train_dir = "White_augmented"  # training data: the augmented White subset
    model_save_path = "Trained_Models/SimpleCNN_Single.pth"
    results_csv_path = "Plots/SimpleCNN_Single.csv"

    model = train_model(train_dir, model_save_path)

    target_races = [
        "Black_augmented",
        "Indian_augmented",
        "East_Asian_augmented",
        "Southeast_Asian_augmented",
        "Latino_Hispanic_augmented",
    ]
    all_results = []

    for race_dir in target_races:
        # One CSV row per subset: the best-threshold metrics, tagged with
        # the subset name as the last column.
        metrics = {**evaluate_model(model, race_dir), 'Race': race_dir}
        all_results.append(metrics)
        plot_tsne(model, race_dir, f"Plots/SimpleCNN/tsne_Single_{race_dir}.png")

    df = pd.DataFrame(all_results)
    os.makedirs(os.path.dirname(results_csv_path), exist_ok=True)
    df.to_csv(results_csv_path, index=False)

Epoch 1: Loss = 0.6937
Epoch 2: Loss = 0.6880
Epoch 3: Loss = 0.6837
Epoch 4: Loss = 0.6809
Epoch 5: Loss = 0.6772
Epoch 6: Loss = 0.6737
Epoch 7: Loss = 0.6739
Epoch 8: Loss = 0.6655
Epoch 9: Loss = 0.6589
Epoch 10: Loss = 0.6550
