In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, ConcatDataset
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from tqdm import tqdm
import warnings
warnings.simplefilter("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# SupCon ResNet with projection head
class SupConResNet(nn.Module):
    def __init__(self, base_model='resnet18', projection_dim=128):
        super().__init__()
        self.encoder = getattr(models, base_model)(weights=None)
        self.encoder.fc = nn.Identity()
        self.projection_head = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )

    def forward(self, x):
        feat = self.encoder(x)
        feat = F.normalize(self.projection_head(feat), dim=1)
        return feat

# Linear Classifier 
class LinearClassifier(nn.Module):
    def __init__(self, feat_dim, num_classes=2):
        super().__init__()
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

In [3]:
# Train SupCon ResNet 
def train_supcon(data_dirs, model_path, epochs=20):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    datasets_list = [datasets.ImageFolder(path, transform=transform) for path in data_dirs]
    train_dataset = ConcatDataset(datasets_list)
    loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

    model = SupConResNet().to(device)
    criterion = SupConLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            features = model(x)
            loss = criterion(features, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[SupCon] Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)
    return model

# Extract features for classifier or t-SNE 
def extract_embeddings(model, dataloader):
    model.eval()
    feats, labels = [], []
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            feat = model(x)
            feats.append(feat.cpu())
            labels.extend(y)
    return torch.cat(feats), torch.tensor(labels)

In [4]:
# Train linear classifier 
def train_linear_classifier(encoder, data_dirs, save_path):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    datasets_list = [datasets.ImageFolder(path, transform=transform) for path in data_dirs]
    dataset = ConcatDataset(datasets_list)
    loader = DataLoader(dataset, batch_size=64, shuffle=True)

    features, labels = extract_embeddings(encoder, loader)
    classifier = LinearClassifier(feat_dim=features.shape[1]).to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(10):
        classifier.train()
        out = classifier(features.to(device))
        loss = criterion(out, labels.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"[Linear] Epoch {epoch+1}: Loss = {loss.item():.4f}")

    torch.save(classifier.state_dict(), save_path)
    return classifier

In [5]:
# t-SNE Plot 
def plot_tsne(model, data_dir, save_path):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    feats, labels = extract_embeddings(model, loader)
    tsne = TSNE(n_components=2).fit_transform(feats.numpy())
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(tsne[:, 0], tsne[:, 1], c=labels.numpy(), cmap='coolwarm', alpha=0.7)
    plt.legend(*scatter.legend_elements(), title="Classes")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.title("t-SNE on SupCon Features")
    plt.savefig(save_path)
    plt.close()

In [6]:
# Evaluation
def evaluate_open_set(model, classifier, data_dir, thresholds=np.arange(0.5, 0.96, 0.05)):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = datasets.ImageFolder(data_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    feats, labels = extract_embeddings(model, loader)
    logits = classifier(feats.to(device)).softmax(dim=1).cpu().numpy()
    y_true = labels.numpy()
    y_scores = logits[:, 1]

    best_f1, best_threshold, results = -1, None, {}
    for thresh in thresholds:
        conf = np.max(logits, axis=1)
        pred = np.argmax(logits, axis=1)
        pred = np.where(conf >= thresh, pred, -1)
        mask = pred != -1
        y_eval = y_true[mask]
        p_eval = pred[mask]

        if len(y_eval) == 0:
            acc = f1 = float('nan')
        else:
            acc = accuracy_score(y_eval, p_eval)
            f1 = f1_score(y_eval, p_eval)

        results[round(thresh, 2)] = {
            'Accuracy': acc, 'F1': f1, 'Rejected': len(y_true) - len(y_eval)
        }

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = thresh

    try:
        auroc = roc_auc_score(y_true, y_scores)
    except:
        auroc = float('nan')

    return {
        'Best Threshold': best_threshold,
        f'F1@{round(best_threshold, 2)}': best_f1,
        'AUROC': auroc,
        'Threshold Scores': results
    }

In [7]:
if __name__ == "__main__":
    all_races = ["White", "Black", "Indian", "East_Asian", "Southeast_Asian", "Latino_Hispanic"]

    for left_out_race in all_races:
        print(f"\nTraining without {left_out_race}...")
        train_dirs = [f"{r}_augmented" for r in all_races if r != left_out_race]
        test_dir = f"{left_out_race}"

        encoder_path = f"Trained_Models/supcon_encoder_excl_{left_out_race}.pth"
        classifier_path = f"Trained_Models/supcon_classifier_excl_{left_out_race}.pth"

        encoder = train_supcon(train_dirs, encoder_path)
        classifier = train_linear_classifier(encoder, train_dirs, classifier_path)

        plot_tsne(encoder, test_dir, save_path=f"Plots/supcon_tsne_{left_out_race}.png")

        metrics = evaluate_open_set(encoder, classifier, test_dir)
        flat_metrics = {
            'Race': left_out_race,
            'Best Threshold': metrics['Best Threshold'],
            f"F1@{metrics['Best Threshold']}": metrics[f"F1@{metrics['Best Threshold']}"],
            'AUROC': metrics['AUROC']
        }

        csv_path = "Plots/supcon_open_set_results.csv"
        if os.path.exists(csv_path):
            df_existing = pd.read_csv(csv_path)
            df = pd.concat([df_existing, pd.DataFrame([flat_metrics])], ignore_index=True)
        else:
            df = pd.DataFrame([flat_metrics])
        df.to_csv(csv_path, index=False)


Training without White...


NameError: name 'SupConLoss' is not defined