In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
class WhiteningLayer(nn.Module):
    def __init__(self, input_size):
        super(WhiteningLayer, self).__init__()
        self.input_size = input_size
        
    def forward(self, x):
        x_centered = x - torch.mean(x, dim=1, keepdim=True)
        cov_matrix = torch.mm(x_centered, x_centered.t()) / (self.input_size - 1)
        u, s, _ = torch.svd(cov_matrix)
        whitened = torch.mm(u / torch.sqrt(s + 1e-6), x_centered)
        return whitened

class MLPWithWhitening(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, use_whitening=True):
        super(MLPWithWhitening, self).__init__()
        self.use_whitening = use_whitening
        
        if self.use_whitening:
            self.whitening = WhiteningLayer(input_size)
            
        self.hidden1 = nn.Linear(input_size, hidden_size)
        self.hidden2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        if self.use_whitening:
            x = self.whitening(x)
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.output(x)
        return x

In [2]:
class DummyDataset(Dataset):
    def __init__(self, data_path, label_path):
        self.data = np.load(data_path)
        self.labels = np.load(label_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

In [3]:
# Define the input size and number of samples
input_size = 1376
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
# Define your training settings
hidden_size = 256
output_size = 2
use_whitening = False  # Change this based on your preference

# Create an instance of the MLPWithWhitening model
model = MLPWithWhitening(input_size, hidden_size, output_size, use_whitening)
model.to(device)
# Define loss function and optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10





In [9]:
for fold in range(10):
        # Create instances of the dataset class
    train_dataset = DummyDataset('first_exp_data/x_train_%d.npy'%fold, 'first_exp_data/y_train_%d.npy'%fold)
    val_dataset = DummyDataset('first_exp_data/x_val_%d.npy'%fold, 'first_exp_data/y_val_%d.npy'%fold)

    # Define batch size for the data loaders
    
    # Create data loaders for the datasets
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model.apply(init_weights)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
    
        for batch_data, batch_labels in train_dataloader:
            batch_data = batch_data.to(device).float()
            batch_labels = batch_labels.float()
            batch_labels = batch_labels.to(device)
            optimizer.zero_grad()
            outputs = model(batch_data)
            outputs = F.softmax(outputs, dim=1)
            outputs = outputs[:, 1]
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()
    
            total_loss += loss.item()
    
        average_loss = total_loss / len(train_dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {average_loss:.4f}")
    
        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
    
        with torch.no_grad():
            for batch_data, batch_labels in val_dataloader:
                batch_data = batch_data.to(device).float()
                batch_labels = batch_labels.float()
                batch_labels = batch_labels.to(device)
                outputs = model(batch_data)
                outputs = F.softmax(outputs, dim=1)
                outputs = outputs[:, 1]
                loss = criterion(outputs, batch_labels)
                val_loss += loss.item()
                predicted = outputs > 0.5
                total += batch_labels.size(0)
                correct += (predicted == batch_labels).sum().item()
    
        val_loss /= len(val_dataloader)
        val_accuracy = correct / total
        print(f"Epoch [{epoch + 1}/{num_epochs}], Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
    
        # save model with least validation loss
        if epoch == 0:
            best_loss = val_loss
            torch.save(model.state_dict(), 'best_model_first_exp_%d.pth'%fold)
        else:
            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(model.state_dict(), 'best_model_first_exp_%d.pth'%fold)

Epoch [1/10], Train Loss: 0.4090
Epoch [1/10], Val Loss: 0.3877, Val Accuracy: 0.8155
Epoch [2/10], Train Loss: 0.3927
Epoch [2/10], Val Loss: 0.3884, Val Accuracy: 0.8145
Epoch [3/10], Train Loss: 0.3905
Epoch [3/10], Val Loss: 0.3918, Val Accuracy: 0.8158
Epoch [4/10], Train Loss: 0.3893
Epoch [4/10], Val Loss: 0.3874, Val Accuracy: 0.8167
Epoch [5/10], Train Loss: 0.3884
Epoch [5/10], Val Loss: 0.3866, Val Accuracy: 0.8133
Epoch [6/10], Train Loss: 0.3873
Epoch [6/10], Val Loss: 0.3946, Val Accuracy: 0.8165
Epoch [7/10], Train Loss: 0.3865
Epoch [7/10], Val Loss: 0.3846, Val Accuracy: 0.8168
Epoch [8/10], Train Loss: 0.3854
Epoch [8/10], Val Loss: 0.3878, Val Accuracy: 0.8168
Epoch [9/10], Train Loss: 0.3849
Epoch [9/10], Val Loss: 0.3872, Val Accuracy: 0.8159
Epoch [10/10], Train Loss: 0.3846
Epoch [10/10], Val Loss: 0.3843, Val Accuracy: 0.8166
Epoch [1/10], Train Loss: 0.4072
Epoch [1/10], Val Loss: 0.3947, Val Accuracy: 0.8112
Epoch [2/10], Train Loss: 0.3916
Epoch [2/10], Val L

In [22]:
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, recall_score, precision_score, \
    f1_score,roc_curve,auc

auc_list = []
for fold in range(10):
    model.load_state_dict(torch.load('best_model_first_exp_%d.pth'%fold))
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    prediction = []
    test_dataset = DummyDataset('first_exp_data/x_test.npy', 'first_exp_data/y_test.npy')
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch_data, batch_labels in test_dataloader:
            batch_data = batch_data.to(device).float()
            batch_labels = batch_labels.float()
            batch_labels = batch_labels.to(device)
            outputs = model(batch_data)
            outputs = F.softmax(outputs, dim=1)
            outputs = outputs[:, 1]
            prediction.append(outputs.detach().cpu().numpy())
            loss = criterion(outputs, batch_labels)
            test_loss += loss.item()
            predicted = outputs > 0.5
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()
    
    test_loss /= len(test_dataloader)
    test_accuracy = correct / total
    y_true = np.load('first_exp_data/y_test.npy')
    prediction = np.concatenate(prediction)
    fpr, tpr, thresholds = roc_curve(y_true,prediction)
    auc_value = auc(fpr, tpr)
    if auc_value>0.6:
        auc_list.append(auc_value)
        print('AUROC:',auc_value)

auc_list = np.array(auc_list)

print(auc_list.mean(),auc_list.std())

AUROC: 0.826747508401413
AUROC: 0.8260121864968887
AUROC: 0.8262101020803064
AUROC: 0.8253067243302108
AUROC: 0.8238726207719507
AUROC: 0.8267878155994993
AUROC: 0.8274610992438383
AUROC: 0.8248889552208751
0.8259108765181228 0.0010917746275846277


In [23]:
import pickle as pk

with open('first_exp_data/test_race.pkl', 'rb') as f:
    race_test = pk.load(f)

race_index = []
for race in race_test:
    if 'WHITE' in race:
        race_index.append(0)
    elif 'BLACK' in race:
        race_index.append(1)
    elif 'ASIAN' in race:
        race_index.append(2)
    elif 'HISPANIC' in race:
        race_index.append(3)
    else:
        race_index.append(4)

race_dict = {0: 'WHITE', 1: 'BLACK', 2: 'ASIAN', 3: 'HISPANIC', 4: 'OTHER'}
race_index = np.array(race_index)
auc_list = []
for fold in range(10):
    model.load_state_dict(torch.load('best_model_first_exp_%d.pth' % fold))
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    prediction = []
    test_dataset = DummyDataset('first_exp_data/x_test.npy', 'first_exp_data/y_test.npy')
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch_data, batch_labels in test_dataloader:
            batch_data = batch_data.to(device).float()
            batch_labels = batch_labels.float()
            batch_labels = batch_labels.to(device)
            outputs = model(batch_data)
            outputs = F.softmax(outputs, dim=1)
            outputs = outputs[:, 1]
            prediction.append(outputs.detach().cpu().numpy())
            loss = criterion(outputs, batch_labels)
            test_loss += loss.item()
            predicted = outputs > 0.5
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()

    test_loss /= len(test_dataloader)
    test_accuracy = correct / total

    y_true = np.load('first_exp_data/y_test.npy')
    prediction = np.concatenate(prediction)
    for rc in range(0,5):
        rc_index = race_index == rc
        rc_y_true = y_true[rc_index]
        rc_prediction = prediction[rc_index]
        fpr, tpr, thresholds = roc_curve(rc_y_true, rc_prediction)
        auc_value = auc(fpr, tpr)
        if auc_value > 0.6:
            auc_list.append(auc_value)
            print('%s - AUROC:'%race_dict[rc], auc_value)
    print('------------------------------------------------------------')

WHITE - AUROC: 0.8153681717102067
BLACK - AUROC: 0.8287208017737059
ASIAN - AUROC: 0.837809172287981
HISPANIC - AUROC: 0.8459180216802167
OTHER - AUROC: 0.8495176968319252
------------------------------------------------------------
WHITE - AUROC: 0.8153270034054667
BLACK - AUROC: 0.8265105455123318
ASIAN - AUROC: 0.8302999448474822
HISPANIC - AUROC: 0.842970867208672
OTHER - AUROC: 0.8496721751866084
------------------------------------------------------------
------------------------------------------------------------
WHITE - AUROC: 0.8153834401044973
BLACK - AUROC: 0.827591225595883
ASIAN - AUROC: 0.840418310635951
HISPANIC - AUROC: 0.8447239159891599
OTHER - AUROC: 0.8476730168897845
------------------------------------------------------------
WHITE - AUROC: 0.8144182265278411
BLACK - AUROC: 0.8275679235509565
ASIAN - AUROC: 0.8304908574095287
HISPANIC - AUROC: 0.8407435636856367
OTHER - AUROC: 0.8483960299515861
------------------------------------------------------------
WHITE -

In [27]:
import pickle as pk

auc_list = []
for fold in range(10):
    model.load_state_dict(torch.load('best_model_first_exp_%d.pth' % fold))
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    prediction = []
    test_dataset = DummyDataset('gender_exp_data/x_test.npy', 'gender_exp_data/y_test.npy')
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch_data, batch_labels in test_dataloader:
            batch_data = batch_data.to(device).float()
            batch_labels = batch_labels.float()
            batch_labels = batch_labels.to(device)
            outputs = model(batch_data)
            outputs = F.softmax(outputs, dim=1)
            outputs = outputs[:, 1]
            prediction.append(outputs.detach().cpu().numpy())
            loss = criterion(outputs, batch_labels)
            test_loss += loss.item()
            predicted = outputs > 0.5
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()

    test_loss /= len(test_dataloader)
    test_accuracy = correct / total

    y_true = np.load('gender_exp_data/y_test.npy')
    prediction = np.concatenate(prediction)
    fpr, tpr, thresholds = roc_curve(y_true, prediction)
    auc_value = auc(fpr, tpr)
    if auc_value > 0.6:
        auc_list.append(auc_value)
        print('AUROC:', auc_value)
print('------------------------------------------------------------')

AUROC: 0.8351782202132432
AUROC: 0.8343578279053973
AUROC: 0.8335517462439093
AUROC: 0.8339586189863051
AUROC: 0.8325097931675745
AUROC: 0.8352575825540872
AUROC: 0.8354103505200321
AUROC: 0.8333116252208796
------------------------------------------------------------


In [29]:
import pickle as pk

auc_list = []
for fold in range(10):
    model.load_state_dict(torch.load('best_model_first_exp_%d.pth' % fold))
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    prediction = []
    test_dataset = DummyDataset('gender_exp_data/x_test_male.npy', 'gender_exp_data/y_test_male.npy')
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch_data, batch_labels in test_dataloader:
            batch_data = batch_data.to(device).float()
            batch_labels = batch_labels.float()
            batch_labels = batch_labels.to(device)
            outputs = model(batch_data)
            outputs = F.softmax(outputs, dim=1)
            outputs = outputs[:, 1]
            prediction.append(outputs.detach().cpu().numpy())
            loss = criterion(outputs, batch_labels)
            test_loss += loss.item()
            predicted = outputs > 0.5
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()

    test_loss /= len(test_dataloader)
    test_accuracy = correct / total

    y_true = np.load('gender_exp_data/y_test_male.npy')
    prediction = np.concatenate(prediction)
    fpr, tpr, thresholds = roc_curve(y_true, prediction)
    auc_value = auc(fpr, tpr)
    if auc_value > 0.6:
        auc_list.append(auc_value)
        print('AUROC:', auc_value)
print('------------------------------------------------------------')

AUROC: 0.8179324188647871
AUROC: 0.8174506368216287
AUROC: 0.8188275608920158
AUROC: 0.8168520342229977
AUROC: 0.8153442172082431
AUROC: 0.8182527629886176
AUROC: 0.8192628843241742
AUROC: 0.8164466310569468
------------------------------------------------------------
