In [1]:
import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, ConcatDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.manifold import TSNE
import warnings
from tqdm import tqdm

# Keep notebook output readable: drop FutureWarning messages raised by the
# imported libraries.
warnings.simplefilter("ignore", FutureWarning)

# Run on the GPU when CUDA is usable, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ViT Model
def get_vit_model(num_classes=2):
    """Build a randomly initialised ViT-B/16 classifier and move it to ``device``.

    Args:
        num_classes: Number of output logits of the classification head.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        The ``vit_b_16`` model (``weights=None``, i.e. no pretrained weights)
        with its ``heads`` replaced by a single linear layer, placed on the
        module-level ``device``.
    """
    model = models.vit_b_16(weights=None)
    model.heads = nn.Sequential(
        # The Identity is a no-op, but it pins the linear layer at index 1 so
        # saved checkpoints keep the ``heads.1.*`` state_dict keys.
        nn.Identity(),
        nn.Linear(768, num_classes)
    )
    return model.to(device)

def extract_features_vit(model, x):
    """Return the representation that feeds the model's classification head.

    Calling ``model(x)`` directly yields the head's logits, which are not
    useful as embedding features. This function instead records the tensor
    passed into ``model.heads`` during a forward pass and returns that.

    Args:
        model: A module whose forward pass calls its ``heads`` submodule
            (as the model built by ``get_vit_model`` does). If the model has
            no ``heads`` submodule, the plain model output is returned.
        x: Input batch, already on the same device as ``model``.

    Returns:
        The tensor given to ``model.heads`` for this batch, computed without
        gradient tracking; or ``model(x)`` when no head input was captured.
    """
    head = getattr(model, "heads", None)
    captured = []

    def _grab_head_input(_module, inputs):
        # Pre-hooks receive the positional inputs as a tuple; returning None
        # leaves the forward pass unchanged.
        captured.append(inputs[0])

    handle = None
    if isinstance(head, nn.Module):
        handle = head.register_forward_pre_hook(_grab_head_input)
    try:
        with torch.no_grad():
            output = model(x)
    finally:
        # Always detach the hook so repeated calls do not stack hooks.
        if handle is not None:
            handle.remove()

    return captured[-1] if captured else output

# Training
def train_model(data_dirs, save_path, max_epochs=10):
    """Train a fresh ViT on the union of several image folders and save it.

    Each directory in ``data_dirs`` is loaded with ``ImageFolder`` using light
    augmentation (horizontal flip, rotation up to 10 degrees, small colour
    jitter), and the datasets are concatenated. The model is optimised with
    Adam (lr=1e-4) and cross-entropy loss in shuffled batches of 32. The mean
    batch loss is printed after every epoch.

    Args:
        data_dirs: Iterable of ``ImageFolder``-style directory paths.
        save_path: File path the model's ``state_dict`` is written to. Its
            parent directory is created if missing; a bare file name (no
            directory part) saves into the current working directory.
        max_epochs: Number of passes over the combined dataset.

    Returns:
        The trained model (still on ``device``).
    """
    # Create the output directory before training so an unusable path fails
    # immediately rather than after all epochs. ``dirname`` is "" for a bare
    # file name, and ``os.makedirs("")`` would raise, so skip it in that case.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    # NOTE(review): every ImageFolder derives its own class indices from its
    # subfolder names; this assumes all directories share the same class
    # subfolders so the labels agree after concatenation - confirm.
    datasets_list = [datasets.ImageFolder(path, transform=transform) for path in data_dirs]
    train_dataset = ConcatDataset(datasets_list)
    loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    model = get_vit_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(max_epochs):
        model.train()
        total_loss = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Average of the per-batch losses for this epoch.
        print(f"Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

    torch.save(model.state_dict(), save_path)
    return model

# Open-set Evaluation
def evaluate_open_set(model, data_dir, thresholds=np.arange(0.5, 0.96, 0.05)):
    """Evaluate a classifier with confidence-based rejection over several thresholds.

    Softmax probabilities are computed for every image in ``data_dir``. For
    each threshold, samples whose top probability is below the threshold are
    rejected, and accuracy and macro precision/recall/F1 are computed on the
    remaining samples. The threshold with the highest F1 is reported.

    Args:
        model: Classifier returning one logit per class; must be on ``device``.
        data_dir: ``ImageFolder``-style directory to evaluate on.
        thresholds: Confidence cut-offs to sweep. The default array is only
            read, never modified.

    Returns:
        A dict with:
          * ``'Best Threshold'``: best cut-off rounded to 2 decimals, or
            ``None`` when no threshold produced a finite F1 (for example every
            sample was rejected, or ``thresholds`` is empty).
          * ``'F1@<Best Threshold>'``: the F1 at that cut-off (NaN if none).
          * ``'AUROC'``: ROC AUC of the class-1 probability against the true
            labels over all samples, or NaN when it cannot be computed.
          * ``'Accuracy'``: accuracy at the best cut-off (NaN if none).
          * ``'Threshold Scores'``: per-threshold metrics keyed by the rounded
            threshold, including ``'Rejected Unknowns'`` (number of samples
            below the cut-off).
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    y_true, all_probs = [], []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            out = model(x)
            probs = torch.softmax(out, dim=1)
            y_true.extend(y.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    y_true = np.array(y_true)
    all_probs = np.array(all_probs)
    y_scores = all_probs[:, 1]

    # Top-class confidence and prediction do not depend on the threshold, so
    # compute them once rather than on every sweep iteration.
    conf = np.max(all_probs, axis=1)
    top_pred = np.argmax(all_probs, axis=1)

    results = {}
    best_f1, best_threshold = -1, None

    for threshold in thresholds:
        # Samples at or above the cut-off are accepted; the rest are rejected.
        known_mask = conf >= threshold
        y_eval = y_true[known_mask]
        p_eval = top_pred[known_mask]

        if len(y_eval) == 0:
            acc = precision = recall = f1 = float('nan')
        else:
            acc = accuracy_score(y_eval, p_eval)
            precision = precision_score(y_eval, p_eval, average='macro', zero_division=0)
            recall = recall_score(y_eval, p_eval, average='macro', zero_division=0)
            f1 = f1_score(y_eval, p_eval, average='macro', zero_division=0)

        results[round(threshold, 2)] = {
            'Accuracy': acc,
            'Precision': precision,
            'Recall': recall,
            'F1 Score': f1,
            'Rejected Unknowns': len(y_true) - len(y_eval)
        }

        # A NaN F1 compares False here, so fully rejected thresholds are never
        # selected as the best one.
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    try:
        auroc = roc_auc_score(y_true, y_scores, multi_class='ovr')
    except ValueError:
        auroc = float('nan')

    if best_threshold is None:
        # No threshold kept any sample: report NaN instead of failing on
        # round(None) and instead of leaking the -1 sentinel as an F1 value.
        best_key = None
        best_f1 = float('nan')
        best_acc = float('nan')
    else:
        best_key = round(best_threshold, 2)
        best_acc = results[best_key]['Accuracy']

    return {
        'Best Threshold': best_key,
        f"F1@{best_key}": best_f1,
        'AUROC': auroc,
        'Accuracy': best_acc,
        'Threshold Scores': results
    }

# t-SNE Visualization
def plot_tsne(model, data_dir, save_path):
    """Save a 2-D t-SNE scatter plot of the model's features for one folder.

    Features are obtained from ``extract_features_vit`` for every image in
    ``data_dir``, projected to two dimensions with t-SNE (PCA init, fixed
    random state 42) and coloured by the ``ImageFolder`` class index.

    Args:
        model: Model passed to ``extract_features_vit``; must be on ``device``.
        data_dir: ``ImageFolder``-style directory to visualise.
        save_path: Image file to write. Its parent directory is created if
            missing; a bare file name saves into the working directory.

    Raises:
        ValueError: If the folder yields fewer than two samples, for which a
            t-SNE embedding cannot be computed.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    features, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            feat = extract_features_vit(model, x)
            features.append(feat.cpu().numpy())
            labels.extend(y.numpy())

    features = np.concatenate(features)

    n_samples = features.shape[0]
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples} from {data_dir}")
    # t-SNE requires perplexity < n_samples. Keep the library default of 30
    # for normal-sized folders and only shrink it for very small ones.
    perplexity = min(30.0, float(n_samples - 1))

    tsne = TSNE(n_components=2, init='pca', learning_rate='auto', random_state=42,
                perplexity=perplexity).fit_transform(features)
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(tsne[:, 0], tsne[:, 1], c=labels, cmap='coolwarm', alpha=0.7)
    plt.legend(*scatter.legend_elements(), title="Classes")
    plt.title(f"t-SNE: {data_dir}")
    # ``dirname`` is "" for a bare file name and os.makedirs("") would raise.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.close()

# Evaluation
def evaluate_one_race(model, race_folder):
    """Evaluate ``model`` on one held-out folder, plot t-SNE, and log a CSV row.

    Runs ``evaluate_open_set`` and ``plot_tsne`` on ``race_folder``, then
    appends a summary row (folder name, best threshold, F1 at that threshold,
    AUROC, accuracy) to ``Plots/ViT/ViT_Multi.csv``, creating the file and
    its directory when they do not exist yet.

    Args:
        model: Trained classifier on ``device``.
        race_folder: ``ImageFolder``-style directory of the held-out group.
    """
    # normpath drops a trailing separator so "X_augmented/" still yields
    # "X_augmented" rather than an empty name.
    race_name = os.path.basename(os.path.normpath(race_folder))

    print(f"\nEvaluating on: {race_folder}")
    metrics = evaluate_open_set(model, race_folder)
    plot_tsne(model, race_folder, save_path=f"Plots/ViT/tsne_ViT_Multi_{race_name}.png")

    # NOTE(review): the F1 column name embeds the selected threshold, so rows
    # with different best thresholds land in different CSV columns.
    f1_key = f"F1@{metrics['Best Threshold']}"
    flat_metrics = {
        'Race': race_name,
        'Best Threshold': metrics['Best Threshold'],
        f1_key: metrics[f1_key],
        'AUROC': metrics['AUROC'],
        'Accuracy': metrics['Accuracy']
    }

    csv_path = "Plots/ViT/ViT_Multi.csv"
    # Do not depend on plot_tsne having created the output directory.
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    if os.path.exists(csv_path):
        df_existing = pd.read_csv(csv_path)
        df = pd.concat([df_existing, pd.DataFrame([flat_metrics])], ignore_index=True)
    else:
        df = pd.DataFrame([flat_metrics])

    df.to_csv(csv_path, index=False)

# Main
if __name__ == "__main__":
    # Leave-one-group-out protocol: for each group, train on the augmented
    # folders of all other groups and evaluate on the held-out one.
    all_races = ["White", "Black", "Indian", "East_Asian", "Southeast_Asian", "Latino_Hispanic"]
    for left_out_race in all_races:
        train_dirs = [f"{race}_augmented" for race in all_races if race != left_out_race]
        model_path = f"Trained_Models/ViT/ViT_excl_{left_out_race}.pth"
        model = train_model(train_dirs, model_path, max_epochs=10)
        evaluate_one_race(model, f"{left_out_race}_augmented")

        # Drop this fold's model before the next one is built; otherwise the
        # old model stays referenced (and resident on the device) while the
        # next fold trains.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

Epoch 1: Loss = 0.8881
Epoch 2: Loss = 0.7039
Epoch 3: Loss = 0.6954
Epoch 4: Loss = 0.7116
Epoch 5: Loss = 0.6986
Epoch 6: Loss = 0.6891
Epoch 7: Loss = 0.7001
Epoch 8: Loss = 0.6514
Epoch 9: Loss = 0.6362
Epoch 10: Loss = 0.6438

Evaluating on: White_augmented
Epoch 1: Loss = 0.7403
Epoch 2: Loss = 0.7018
Epoch 3: Loss = 0.7078
Epoch 4: Loss = 0.6913
Epoch 5: Loss = 0.6996
Epoch 6: Loss = 0.6882
Epoch 7: Loss = 0.6886
Epoch 8: Loss = 0.6926
Epoch 9: Loss = 0.6901
Epoch 10: Loss = 0.6824

Evaluating on: Black_augmented
Epoch 1: Loss = 0.8159
Epoch 2: Loss = 0.7072
Epoch 3: Loss = 0.7078
Epoch 4: Loss = 0.6963
Epoch 5: Loss = 0.6650
Epoch 6: Loss = 0.6648
Epoch 7: Loss = 0.6032
Epoch 8: Loss = 0.5563
Epoch 9: Loss = 0.4904
Epoch 10: Loss = 0.4716

Evaluating on: Indian_augmented
Epoch 1: Loss = 0.7625
Epoch 2: Loss = 0.6902
Epoch 3: Loss = 0.6860
Epoch 4: Loss = 0.6520
Epoch 5: Loss = 0.5744
Epoch 6: Loss = 0.5331
Epoch 7: Loss = 0.4784
Epoch 8: Loss = 0.4353
Epoch 9: Loss = 0.4012
Epo