In [1]:
import os
import gc
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ViT Model
def get_vit_model():
    model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
    if isinstance(model.heads, nn.Sequential) and isinstance(model.heads[-1], nn.Linear):
        in_features = model.heads[-1].in_features
        model.heads[-1] = nn.Linear(in_features, 2)
    else:
        raise ValueError("Unexpected structure in model.heads")
    return model

# Training
def train_single_race(data_dir, save_path, max_epochs=10):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    model = get_vit_model()
    try:
        model = model.to(device)
        _ = model(torch.randn(1, 3, 224, 224).to(device))
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            torch.cuda.empty_cache()
            gc.collect()
            model = model.to("cpu")
        else:
            raise e

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(max_epochs):
        model.train()
        total_loss = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

    torch.save(model.state_dict(), save_path)
    return model

# Open-set Evaluation
def evaluate_open_set(model, data_dir, threshold=0.8):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    y_true, y_pred, y_score = [], [], []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            out = model(x)
            probs = torch.softmax(out, dim=1)
            conf, pred = torch.max(probs, dim=1)
            y_true.extend(y.cpu().numpy())
            y_score.extend(probs[:, 1].cpu().numpy())
            pred = torch.where(conf >= threshold, pred, torch.tensor(-1, device=pred.device))
            y_pred.extend(pred.cpu().numpy())

    known_mask = [p != -1 for p in y_pred]
    y_eval = [y for y, k in zip(y_true, known_mask) if k]
    p_eval = [p for p, k in zip(y_pred, known_mask) if k]

    return {
        'Accuracy': accuracy_score(y_eval, p_eval) if y_eval else float('nan'),
        'Precision': precision_score(y_eval, p_eval, zero_division=0) if y_eval else 0.0,
        'Recall': recall_score(y_eval, p_eval, zero_division=0) if y_eval else 0.0,
        'F1 Score': f1_score(y_eval, p_eval, zero_division=0) if y_eval else 0.0,
        'AUROC': roc_auc_score(y_true, y_score) if len(set(y_true)) > 1 else float('nan'),
        'Rejected Unknowns': len(y_true) - len(y_eval)
    }

# t-SNE Visualization
def plot_tsne(model, data_dir, save_path):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    model.eval()
    features, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            logits = model(x)
            features.append(logits.cpu().numpy())
            labels.extend(y.numpy())

    features = np.concatenate(features)
    tsne = TSNE(n_components=2, init='pca', learning_rate='auto', random_state=42).fit_transform(features)

    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(tsne[:, 0], tsne[:, 1], c=labels, cmap='coolwarm', alpha=0.7)
    plt.legend(*scatter.legend_elements(), title="Classes")
    plt.title(f"t-SNE: {data_dir}")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path)
    plt.close()

# Evaluation
def evaluate_one_race(model, race_folder):
    print(f"\nEvaluating on: {race_folder}")
    metrics = evaluate_open_set(model, race_folder, threshold=0.8)
    plot_tsne(model, race_folder, save_path=f"Plots/ViT/tsne_ViT_Single_{os.path.basename(race_folder)}.png")

    df = pd.DataFrame([metrics])
    df["Race"] = os.path.basename(race_folder)
    csv_path = "Plots/ViT/ViT_Single.csv"

    if os.path.exists(csv_path):
        df_existing = pd.read_csv(csv_path)
        df = pd.concat([df_existing, df], ignore_index=True)

    df.to_csv(csv_path, index=False)

# Main
if __name__ == "__main__":
    train_race = "White_augmented"
    model_path = "Trained_Models/ViT/ViT_Single_White.pth"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    model = train_single_race(train_race, model_path, max_epochs=10) 

    target_races = ["Black", "Indian", "East_Asian", "Southeast_Asian", "Latino_Hispanic"]
    for race in target_races:
        evaluate_one_race(model, f"{race}_augmented")

Epoch 1: Loss = 0.2671
Epoch 2: Loss = 0.0781
Epoch 3: Loss = 0.0849
Epoch 4: Loss = 0.0707
Epoch 5: Loss = 0.0678
Epoch 6: Loss = 0.0399
Epoch 7: Loss = 0.0824
Epoch 8: Loss = 0.0324
Epoch 9: Loss = 0.0003
Epoch 10: Loss = 0.0816

Evaluating on: Black_augmented

Evaluating on: Indian_augmented

Evaluating on: East_Asian_augmented

Evaluating on: Southeast_Asian_augmented

Evaluating on: Latino_Hispanic_augmented
