In [1]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import os
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
class TFRecordVectorDataset(Dataset):
    """Dataset of pre-computed embedding vectors with multi-label targets and a group id.

    Each row of ``dataframe`` must provide:

    * a ``'path'`` column pointing at a ``.tfrecord`` file; the embedding is read
      from the sibling ``.npy`` file (same path with ``.tfrecord`` replaced by ``.npy``),
    * a ``'race'`` column used as the sensitive attribute,
    * the 14 label columns as the *last* 14 columns of the frame.

    Args:
        dataframe: pandas DataFrame with the columns described above.
        group_categories: optional ordered sequence of ``'race'`` values. When given,
            group id ``i`` always means ``group_categories[i]`` and values not in the
            sequence (including missing values) get id ``-1``. When omitted, ids are
            assigned in order of first appearance in *this* frame, which means two
            datasets built from different frames may disagree on what an id means.
            Pass ``train_dataset.group_categories`` when building other splits to
            keep the encoding aligned.

    Attributes:
        group_ids: int64 NumPy array with one group code per row (``-1`` = unknown).
        group_categories: list mapping group code -> ``'race'`` value.
    """

    def __init__(self, dataframe, group_categories=None):
        self.data = dataframe
        self.paths = self.data['path'].tolist()
        self.labels = self.data.iloc[:, -14:].values.astype(np.float32)  # the last 14 columns are labels

        # Encode the sensitive attribute ("race") once and reuse the result for both
        # the stored ids and the printed summary.
        if group_categories is None:
            codes, uniques = pd.factorize(self.data['race'])
        else:
            uniques = pd.Index(group_categories)
            codes = pd.Categorical(self.data['race'], categories=uniques).codes
        self.group_ids = np.asarray(codes, dtype=np.int64)
        self.group_categories = list(uniques)

        # Print how many rows fall into each group code.
        counts = pd.Series(self.group_ids).value_counts().sort_index()
        mapping = dict(enumerate(self.group_categories))
        for code, count in counts.items():
            # Code -1 marks missing/unlisted values and has no entry in `mapping`.
            group_name = mapping.get(code, "<missing/unknown>")
            print(f"{code}: {group_name} -> {count} times")

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(embedding, labels, group_id)`` as a float32 tensor, a float32 tensor of
            the 14 labels, and a scalar int64 tensor.

            Loading is best-effort: if anything fails, the error is printed and a
            placeholder ``(zeros(1376), zeros(14), -1)`` is returned instead of raising.
            NOTE(review): such placeholder samples still flow into training as
            all-negative examples; consider filtering on ``group_id == -1`` downstream.
        """
        path = self.paths[idx]
        try:
            # The embeddings were exported to .npy next to the original .tfrecord files.
            npy_path = path.replace('.tfrecord', '.npy')
            embedding = np.load(npy_path)

            embedding_tensor = torch.tensor(embedding, dtype=torch.float32)  # presumably shape (1376,), matching the fallback below
            label_tensor = torch.tensor(self.labels[idx], dtype=torch.float32)  # label shape: (14,)
            group_id = torch.tensor(self.group_ids[idx], dtype=torch.int64)  # scalar sensitive-group code
            return embedding_tensor, label_tensor, group_id
        except Exception as e:
            print(f"Error loading {path}: {str(e)}")
            return torch.zeros(1376), torch.zeros(14), torch.tensor(-1, dtype=torch.int64)

In [3]:
# Load the train/test CSVs and wrap them in datasets and data loaders.
# Select which GPU the process may use. This is set before any CUDA call is made in
# the notebook. NOTE(review): the device index "5" is machine-specific — confirm
# before running elsewhere.
os.environ['CUDA_VISIBLE_DEVICES'] = "5"
train_df = pd.read_csv("training_set_debias.csv")
test_df = pd.read_csv("testset_debias.csv")

# Constructing each dataset prints a "code: race -> count" summary.
# NOTE(review): each dataset encodes 'race' from its own frame, so the integer group
# ids are not guaranteed to mean the same race in both splits; compare the two
# printed summaries before using test-set group ids together with train-set ones.
train_dataset = TFRecordVectorDataset(train_df)
test_dataset = TFRecordVectorDataset(test_df)

# Training batches are shuffled; test batches keep the CSV row order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

0: WHITE -> 135856 times
1: BLACK/AFRICAN AMERICAN -> 32887 times
2: OTHER -> 9664 times
3: ASIAN -> 6630 times
4: HISPANIC/LATINO -> 11035 times
5: AMERICAN INDIAN/ALASKA NATIVE -> 480 times
0: WHITE -> 14508 times
1: BLACK/AFRICAN AMERICAN -> 3685 times
2: OTHER -> 1062 times
3: HISPANIC/LATINO -> 1392 times
4: ASIAN -> 760 times
5: AMERICAN INDIAN/ALASKA NATIVE -> 184 times


In [4]:
class MLP(nn.Module):
    """Fully connected network: input -> hidden1 -> hidden2 -> output.

    ReLU is applied after each of the two hidden layers; the output layer is
    returned without any activation.

    Args:
        input_size: number of input features.
        hidden_size1: width of the first hidden layer.
        hidden_size2: width of the second hidden layer.
        output_size: number of output units.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        # Three linear maps, registered in forward order.
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, output_size)
        # A single ReLU module shared by both hidden layers.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` through both hidden layers and return the output layer's result."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [5]:
class MultiLabelGroupDROLoss(nn.Module):
    """Softmax-weighted combination of per-group multi-label BCE losses.

    For every group id ``g`` in ``range(num_groups)`` that occurs in the batch, the
    mean ``BCEWithLogitsLoss`` over that group's samples and all labels is computed.
    The per-group losses are then combined with weights
    ``softmax(group_losses / temperature)``, so groups with a *higher* loss receive a
    larger weight. The weights are not detached, so gradients also flow through them.

    Args:
        num_groups: number of group ids to consider (ids ``0 .. num_groups - 1``).
        num_labels: number of labels per sample (stored; not used in the computation).
        temperature: positive softmax temperature; smaller values concentrate the
            weight on the worst-performing group. Defaults to 0.5.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """

    def __init__(self, num_groups, num_labels, temperature=0.5):
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        self.num_groups = num_groups
        self.num_labels = num_labels
        self.temperature = temperature
        self.loss_fn = nn.BCEWithLogitsLoss()  # mean over samples and labels

    def forward(self, logits, y, group_ids):
        """Return the weighted group loss as a scalar tensor.

        Args:
            logits: raw model outputs, indexed by sample along dim 0.
            y: float targets with the same shape as ``logits``.
            group_ids: integer group id per sample; ids outside
                ``range(num_groups)`` (e.g. ``-1``) are ignored.

        Returns:
            Scalar tensor. If no sample belongs to any known group, a zero tensor
            that is still attached to ``logits``' graph is returned, so callers can
            always use ``.item()`` and ``.backward()``.
        """
        # Collect the mean loss of every group that is present in this batch.
        present_group_losses = []
        for g in range(self.num_groups):
            group_mask = (group_ids == g)
            if group_mask.any():
                present_group_losses.append(
                    self.loss_fn(logits[group_mask], y[group_mask])
                )

        if not present_group_losses:
            # Keep the return type a tensor (zero value, zero gradient).
            return logits.sum() * 0.0

        group_losses = torch.stack(present_group_losses)
        # Higher-loss groups get exponentially more weight; temperature controls how sharply.
        weights = torch.softmax(group_losses / self.temperature, dim=0)
        return torch.sum(weights * group_losses)

In [6]:
# Derive the number of sensitive groups from the training split.
# Group ids are used as indices 0..num_groups-1 by the loss and the per-group
# bookkeeping, so the count is "largest valid id + 1". Negative ids (the -1
# sentinel for missing values) are excluded so they are not counted as a group.
group_ids_tensor = torch.tensor(train_dataset.group_ids, dtype=torch.int64)
known_group_ids = group_ids_tensor[group_ids_tensor >= 0]
num_groups = int(known_group_ids.max().item()) + 1 if known_group_ids.numel() > 0 else 0
num_labels = 14  # number of label columns per sample

In [7]:
# --- Network dimensions -------------------------------------------------------
input_size = 1376    # length of one embedding vector
hidden_size1 = 1024  # width of the first hidden layer
hidden_size2 = 512   # width of the second hidden layer
output_size = 14     # one logit per label

# --- Device and model ---------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = MLP(input_size, hidden_size1, hidden_size2, output_size)

# --- Losses -------------------------------------------------------------------
# `criterion` is the plain BCE over the whole batch; `group_dro_loss` is the
# group-weighted variant. The training loop combines the two.
criterion = nn.BCEWithLogitsLoss()
group_dro_loss = MultiLabelGroupDROLoss(num_groups,num_labels)

# --- Optimisation -------------------------------------------------------------
# The scheduler halves the learning rate when the monitored metric (mode='max')
# has not improved for `patience` epochs.
# NOTE(review): `verbose=True` may be deprecated in recent PyTorch releases — confirm
# against the installed version.
optimizer = optim.Adam(model.parameters(), lr=0.003)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)

# Wrap for multi-GPU execution only when more than one device is visible.
visible_gpu_count = torch.cuda.device_count()
if visible_gpu_count > 1:
    print(f"Using {visible_gpu_count} GPUs!")
    model = nn.DataParallel(model)
model = model.to(device)

# --- Training-loop state ------------------------------------------------------
num_epochs = 30
best_auroc = 0.0

cuda




In [8]:
# Train for `num_epochs` epochs; after each epoch evaluate macro AUROC, step the LR
# scheduler on it, and checkpoint the model whenever the AUROC improves.
# NOTE(review): `test_loader` is used as the validation set, so both the LR schedule
# and the "best" checkpoint are selected on test data; the reported AUROC is therefore
# optimistic. A held-out validation split would avoid this.
for epoch in range(num_epochs):
    # ---------------- Training ----------------
    model.train()
    running_loss = 0.0
    group_sample_counts = [0] * num_groups                 # samples seen per group this epoch
    group_losses_epoch = [[] for _ in range(num_groups)]   # per-batch plain BCE per group
    for inputs, labels, group_ids in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} - Training'):
        inputs, labels, group_ids = inputs.to(device), labels.to(device), group_ids.to(device)

        # Forward
        outputs = model(inputs)

        # Conventional loss over the whole batch
        loss = criterion(outputs, labels)

        # Group-weighted (GroupDRO-style) loss
        group_loss = group_dro_loss(outputs, labels, group_ids)

        # Backprop on a fixed 30/70 blend of the two losses
        optimizer.zero_grad()
        total_loss = 0.30*loss + 0.70*group_loss
        total_loss.backward()
        optimizer.step()

        # Only the group loss is tracked as the reported "Train Loss".
        # float() works whether the loss module hands back a tensor or a plain number.
        running_loss += float(group_loss)

        # Per-group monitoring only: no gradients are needed, so work on detached
        # outputs instead of building fresh autograd graphs after the update.
        with torch.no_grad():
            detached_outputs = outputs.detach()
            for g in range(num_groups):
                group_mask = (group_ids == g)
                samples_in_group = group_mask.sum().item()
                group_sample_counts[g] += samples_in_group
                if samples_in_group > 0:
                    group_loss_g = criterion(detached_outputs[group_mask], labels[group_mask])
                    group_losses_epoch[g].append(group_loss_g.item())

    # Report the training loss (mean group loss per batch)
    train_loss = running_loss / len(train_loader)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}')

    # Report the loss of each sub-group. The average is the unweighted mean of the
    # per-batch group losses, not a per-sample mean.
    for g in range(num_groups):
        if group_sample_counts[g] > 0:
            avg_group_loss = np.mean(group_losses_epoch[g]) if group_losses_epoch[g] else 0.0
            print(f"Group {g}: Samples = {group_sample_counts[g]}, Avg Loss = {avg_group_loss:.4f}")
        else:
            print(f"Group {g}: Samples = 0, Avg Loss = 0.0")

    # ---------------- Evaluation ----------------
    model.eval()
    all_probs = []
    all_labels = []
    with torch.no_grad():
        # The test set is used for validation (see the note at the top of this cell)
        for inputs, labels, _ in tqdm(test_loader, desc=f'Epoch {epoch+1}/{num_epochs} - Validation'):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            probs = torch.sigmoid(outputs)  # logits -> per-label probabilities

            # Collect results on the CPU
            all_probs.append(probs.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # Stack the batches into (num_samples, num_labels) arrays
    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Per-label AUROC
    auroc_scores = []
    for i in range(all_labels.shape[1]):  # traverse all labels
        try:
            auroc = roc_auc_score(all_labels[:, i], all_probs[:, i])
            auroc_scores.append(auroc)
        except ValueError:
            # roc_auc_score raised for this label; record NaN so it is skipped below
            print(f"Label {i} has no positive or negative samples in the test set.")
            auroc_scores.append(np.nan)

    # Macro AUROC = mean over labels, ignoring labels recorded as NaN
    macro_auroc = np.nanmean(auroc_scores)
    print(f'Epoch {epoch+1}, Validation Macro AUROC: {macro_auroc:.4f}')

    # Scheduler runs in mode='max' on the validation metric
    scheduler.step(macro_auroc)

    # Save the best model so far
    if macro_auroc > best_auroc:
        best_auroc = macro_auroc
        # Unwrap DataParallel so checkpoint keys carry no 'module.' prefix and can be
        # loaded directly into a plain MLP.
        model_to_save = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(model_to_save.state_dict(), 'best_mlp_model.pth')
        print(f'Best model saved with AUROC: {best_auroc:.4f}')

print('Training complete.')

Epoch 1/30 - Training: 100%|████████████████████████████████████████████████████████| 3072/3072 [01:09<00:00, 44.26it/s]


Epoch 1, Train Loss: 0.2733
Group 0: Samples = 135856, Avg Loss = 0.2804
Group 1: Samples = 32887, Avg Loss = 0.2376
Group 2: Samples = 9664, Avg Loss = 0.2674
Group 3: Samples = 6630, Avg Loss = 0.2764
Group 4: Samples = 11035, Avg Loss = 0.2256
Group 5: Samples = 480, Avg Loss = 0.2829


Epoch 1/30 - Validation: 100%|████████████████████████████████████████████████████████| 338/338 [00:03<00:00, 87.76it/s]


Epoch 1, Validation Macro AUROC: 0.7890
Best model saved with AUROC: 0.7890


Epoch 2/30 - Training: 100%|████████████████████████████████████████████████████████| 3072/3072 [01:09<00:00, 44.36it/s]


Epoch 2, Train Loss: 0.2656
Group 0: Samples = 135856, Avg Loss = 0.2725
Group 1: Samples = 32887, Avg Loss = 0.2303
Group 2: Samples = 9664, Avg Loss = 0.2608
Group 3: Samples = 6630, Avg Loss = 0.2665
Group 4: Samples = 11035, Avg Loss = 0.2195
Group 5: Samples = 480, Avg Loss = 0.2761


Epoch 2/30 - Validation: 100%|████████████████████████████████████████████████████████| 338/338 [00:03<00:00, 90.60it/s]


Epoch 2, Validation Macro AUROC: 0.7986
Best model saved with AUROC: 0.7986


Epoch 3/30 - Training: 100%|████████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.46it/s]


Epoch 3, Train Loss: 0.2642
Group 0: Samples = 135856, Avg Loss = 0.2711
Group 1: Samples = 32887, Avg Loss = 0.2287
Group 2: Samples = 9664, Avg Loss = 0.2586
Group 3: Samples = 6630, Avg Loss = 0.2666
Group 4: Samples = 11035, Avg Loss = 0.2178
Group 5: Samples = 480, Avg Loss = 0.2739


Epoch 3/30 - Validation: 100%|████████████████████████████████████████████████████████| 338/338 [00:04<00:00, 82.49it/s]


Epoch 3, Validation Macro AUROC: 0.7927


Epoch 4/30 - Training: 100%|████████████████████████████████████████████████████████| 3072/3072 [01:08<00:00, 44.98it/s]


Epoch 4, Train Loss: 0.2611
Group 0: Samples = 135856, Avg Loss = 0.2694
Group 1: Samples = 32887, Avg Loss = 0.2270
Group 2: Samples = 9664, Avg Loss = 0.2571
Group 3: Samples = 6630, Avg Loss = 0.2638
Group 4: Samples = 11035, Avg Loss = 0.2144
Group 5: Samples = 480, Avg Loss = 0.2689


Epoch 4/30 - Validation: 100%|████████████████████████████████████████████████████████| 338/338 [00:03<00:00, 89.77it/s]


Epoch 4, Validation Macro AUROC: 0.8013
Best model saved with AUROC: 0.8013


Epoch 5/30 - Training: 100%|████████████████████████████████████████████████████████| 3072/3072 [01:08<00:00, 44.53it/s]


Epoch 5, Train Loss: 0.2619
Group 0: Samples = 135856, Avg Loss = 0.2696
Group 1: Samples = 32887, Avg Loss = 0.2269
Group 2: Samples = 9664, Avg Loss = 0.2593
Group 3: Samples = 6630, Avg Loss = 0.2652
Group 4: Samples = 11035, Avg Loss = 0.2138
Group 5: Samples = 480, Avg Loss = 0.2684


Epoch 5/30 - Validation: 100%|████████████████████████████████████████████████████████| 338/338 [00:03<00:00, 87.51it/s]


Epoch 5, Validation Macro AUROC: 0.7983


Epoch 6/30 - Training: 100%|████████████████████████████████████████████████████████| 3072/3072 [01:08<00:00, 45.14it/s]


Epoch 6, Train Loss: 0.2607
Group 0: Samples = 135856, Avg Loss = 0.2688
Group 1: Samples = 32887, Avg Loss = 0.2253
Group 2: Samples = 9664, Avg Loss = 0.2554
Group 3: Samples = 6630, Avg Loss = 0.2635
Group 4: Samples = 11035, Avg Loss = 0.2143
Group 5: Samples = 480, Avg Loss = 0.2674


Epoch 6/30 - Validation: 100%|████████████████████████████████████████████████████████| 338/338 [00:03<00:00, 85.98it/s]


Epoch 6, Validation Macro AUROC: 0.7975


Epoch 7/30 - Training: 100%|████████████████████████████████████████████████████████| 3072/3072 [01:09<00:00, 44.47it/s]


Epoch 7, Train Loss: 0.2596
Group 0: Samples = 135856, Avg Loss = 0.2681
Group 1: Samples = 32887, Avg Loss = 0.2255
Group 2: Samples = 9664, Avg Loss = 0.2524
Group 3: Samples = 6630, Avg Loss = 0.2615
Group 4: Samples = 11035, Avg Loss = 0.2163
Group 5: Samples = 480, Avg Loss = 0.2654


Epoch 7/30 - Validation: 100%|████████████████████████████████████████████████████████| 338/338 [00:03<00:00, 86.29it/s]


Epoch 7, Validation Macro AUROC: 0.7982


Epoch 8/30 - Training: 100%|████████████████████████████████████████████████████████| 3072/3072 [01:05<00:00, 47.09it/s]


Epoch 8, Train Loss: 0.2587
Group 0: Samples = 135856, Avg Loss = 0.2676
Group 1: Samples = 32887, Avg Loss = 0.2260
Group 2: Samples = 9664, Avg Loss = 0.2535
Group 3: Samples = 6630, Avg Loss = 0.2598
Group 4: Samples = 11035, Avg Loss = 0.2135
Group 5: Samples = 480, Avg Loss = 0.2660


Epoch 8/30 - Validation: 100%|████████████████████████████████████████████████████████| 338/338 [00:03<00:00, 86.38it/s]


Epoch 8, Validation Macro AUROC: 0.8010


Epoch 9/30 - Training: 100%|████████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.73it/s]


Epoch 9, Train Loss: 0.2549
Group 0: Samples = 135856, Avg Loss = 0.2641
Group 1: Samples = 32887, Avg Loss = 0.2222
Group 2: Samples = 9664, Avg Loss = 0.2503
Group 3: Samples = 6630, Avg Loss = 0.2571
Group 4: Samples = 11035, Avg Loss = 0.2104
Group 5: Samples = 480, Avg Loss = 0.2601


Epoch 9/30 - Validation: 100%|████████████████████████████████████████████████████████| 338/338 [00:04<00:00, 82.34it/s]


Epoch 9, Validation Macro AUROC: 0.8063
Best model saved with AUROC: 0.8063


Epoch 10/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:08<00:00, 44.84it/s]


Epoch 10, Train Loss: 0.2538
Group 0: Samples = 135856, Avg Loss = 0.2637
Group 1: Samples = 32887, Avg Loss = 0.2218
Group 2: Samples = 9664, Avg Loss = 0.2469
Group 3: Samples = 6630, Avg Loss = 0.2556
Group 4: Samples = 11035, Avg Loss = 0.2095
Group 5: Samples = 480, Avg Loss = 0.2581


Epoch 10/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:04<00:00, 83.64it/s]


Epoch 10, Validation Macro AUROC: 0.8067
Best model saved with AUROC: 0.8067


Epoch 11/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:09<00:00, 44.36it/s]


Epoch 11, Train Loss: 0.2534
Group 0: Samples = 135856, Avg Loss = 0.2636
Group 1: Samples = 32887, Avg Loss = 0.2212
Group 2: Samples = 9664, Avg Loss = 0.2491
Group 3: Samples = 6630, Avg Loss = 0.2554
Group 4: Samples = 11035, Avg Loss = 0.2078
Group 5: Samples = 480, Avg Loss = 0.2547


Epoch 11/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 89.03it/s]


Epoch 11, Validation Macro AUROC: 0.8048


Epoch 12/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.63it/s]


Epoch 12, Train Loss: 0.2538
Group 0: Samples = 135856, Avg Loss = 0.2635
Group 1: Samples = 32887, Avg Loss = 0.2214
Group 2: Samples = 9664, Avg Loss = 0.2494
Group 3: Samples = 6630, Avg Loss = 0.2534
Group 4: Samples = 11035, Avg Loss = 0.2094
Group 5: Samples = 480, Avg Loss = 0.2580


Epoch 12/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 90.80it/s]


Epoch 12, Validation Macro AUROC: 0.8095
Best model saved with AUROC: 0.8095


Epoch 13/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:08<00:00, 44.75it/s]


Epoch 13, Train Loss: 0.2537
Group 0: Samples = 135856, Avg Loss = 0.2632
Group 1: Samples = 32887, Avg Loss = 0.2213
Group 2: Samples = 9664, Avg Loss = 0.2506
Group 3: Samples = 6630, Avg Loss = 0.2540
Group 4: Samples = 11035, Avg Loss = 0.2084
Group 5: Samples = 480, Avg Loss = 0.2569


Epoch 13/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 89.71it/s]


Epoch 13, Validation Macro AUROC: 0.8083


Epoch 14/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:09<00:00, 44.43it/s]


Epoch 14, Train Loss: 0.2519
Group 0: Samples = 135856, Avg Loss = 0.2631
Group 1: Samples = 32887, Avg Loss = 0.2208
Group 2: Samples = 9664, Avg Loss = 0.2463
Group 3: Samples = 6630, Avg Loss = 0.2534
Group 4: Samples = 11035, Avg Loss = 0.2062
Group 5: Samples = 480, Avg Loss = 0.2566


Epoch 14/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 88.62it/s]


Epoch 14, Validation Macro AUROC: 0.8077


Epoch 15/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.31it/s]


Epoch 15, Train Loss: 0.2526
Group 0: Samples = 135856, Avg Loss = 0.2631
Group 1: Samples = 32887, Avg Loss = 0.2213
Group 2: Samples = 9664, Avg Loss = 0.2469
Group 3: Samples = 6630, Avg Loss = 0.2525
Group 4: Samples = 11035, Avg Loss = 0.2090
Group 5: Samples = 480, Avg Loss = 0.2536


Epoch 15/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 90.57it/s]


Epoch 15, Validation Macro AUROC: 0.8079


Epoch 16/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:08<00:00, 44.90it/s]


Epoch 16, Train Loss: 0.2529
Group 0: Samples = 135856, Avg Loss = 0.2631
Group 1: Samples = 32887, Avg Loss = 0.2213
Group 2: Samples = 9664, Avg Loss = 0.2477
Group 3: Samples = 6630, Avg Loss = 0.2516
Group 4: Samples = 11035, Avg Loss = 0.2095
Group 5: Samples = 480, Avg Loss = 0.2562


Epoch 16/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 93.23it/s]


Epoch 16, Validation Macro AUROC: 0.8088


Epoch 17/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.21it/s]


Epoch 17, Train Loss: 0.2487
Group 0: Samples = 135856, Avg Loss = 0.2611
Group 1: Samples = 32887, Avg Loss = 0.2195
Group 2: Samples = 9664, Avg Loss = 0.2429
Group 3: Samples = 6630, Avg Loss = 0.2468
Group 4: Samples = 11035, Avg Loss = 0.2053
Group 5: Samples = 480, Avg Loss = 0.2500


Epoch 17/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 91.04it/s]


Epoch 17, Validation Macro AUROC: 0.8117
Best model saved with AUROC: 0.8117


Epoch 18/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.49it/s]


Epoch 18, Train Loss: 0.2503
Group 0: Samples = 135856, Avg Loss = 0.2610
Group 1: Samples = 32887, Avg Loss = 0.2189
Group 2: Samples = 9664, Avg Loss = 0.2448
Group 3: Samples = 6630, Avg Loss = 0.2503
Group 4: Samples = 11035, Avg Loss = 0.2066
Group 5: Samples = 480, Avg Loss = 0.2518


Epoch 18/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 89.84it/s]


Epoch 18, Validation Macro AUROC: 0.8096


Epoch 19/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:06<00:00, 46.02it/s]


Epoch 19, Train Loss: 0.2493
Group 0: Samples = 135856, Avg Loss = 0.2609
Group 1: Samples = 32887, Avg Loss = 0.2185
Group 2: Samples = 9664, Avg Loss = 0.2465
Group 3: Samples = 6630, Avg Loss = 0.2475
Group 4: Samples = 11035, Avg Loss = 0.2056
Group 5: Samples = 480, Avg Loss = 0.2481


Epoch 19/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 92.23it/s]


Epoch 19, Validation Macro AUROC: 0.8107


Epoch 20/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:06<00:00, 46.36it/s]


Epoch 20, Train Loss: 0.2498
Group 0: Samples = 135856, Avg Loss = 0.2610
Group 1: Samples = 32887, Avg Loss = 0.2189
Group 2: Samples = 9664, Avg Loss = 0.2457
Group 3: Samples = 6630, Avg Loss = 0.2485
Group 4: Samples = 11035, Avg Loss = 0.2055
Group 5: Samples = 480, Avg Loss = 0.2479


Epoch 20/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 93.38it/s]


Epoch 20, Validation Macro AUROC: 0.8104


Epoch 21/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.34it/s]


Epoch 21, Train Loss: 0.2499
Group 0: Samples = 135856, Avg Loss = 0.2610
Group 1: Samples = 32887, Avg Loss = 0.2187
Group 2: Samples = 9664, Avg Loss = 0.2440
Group 3: Samples = 6630, Avg Loss = 0.2495
Group 4: Samples = 11035, Avg Loss = 0.2070
Group 5: Samples = 480, Avg Loss = 0.2477


Epoch 21/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 90.83it/s]


Epoch 21, Validation Macro AUROC: 0.8113


Epoch 22/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:06<00:00, 46.13it/s]


Epoch 22, Train Loss: 0.2485
Group 0: Samples = 135856, Avg Loss = 0.2598
Group 1: Samples = 32887, Avg Loss = 0.2178
Group 2: Samples = 9664, Avg Loss = 0.2432
Group 3: Samples = 6630, Avg Loss = 0.2491
Group 4: Samples = 11035, Avg Loss = 0.2039
Group 5: Samples = 480, Avg Loss = 0.2442


Epoch 22/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 90.18it/s]


Epoch 22, Validation Macro AUROC: 0.8117


Epoch 23/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:09<00:00, 44.43it/s]


Epoch 23, Train Loss: 0.2480
Group 0: Samples = 135856, Avg Loss = 0.2597
Group 1: Samples = 32887, Avg Loss = 0.2180
Group 2: Samples = 9664, Avg Loss = 0.2420
Group 3: Samples = 6630, Avg Loss = 0.2487
Group 4: Samples = 11035, Avg Loss = 0.2042
Group 5: Samples = 480, Avg Loss = 0.2455


Epoch 23/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 89.64it/s]


Epoch 23, Validation Macro AUROC: 0.8114


Epoch 24/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:06<00:00, 46.04it/s]


Epoch 24, Train Loss: 0.2471
Group 0: Samples = 135856, Avg Loss = 0.2598
Group 1: Samples = 32887, Avg Loss = 0.2180
Group 2: Samples = 9664, Avg Loss = 0.2408
Group 3: Samples = 6630, Avg Loss = 0.2470
Group 4: Samples = 11035, Avg Loss = 0.2043
Group 5: Samples = 480, Avg Loss = 0.2458


Epoch 24/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 90.49it/s]


Epoch 24, Validation Macro AUROC: 0.8109


Epoch 25/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.83it/s]


Epoch 25, Train Loss: 0.2480
Group 0: Samples = 135856, Avg Loss = 0.2598
Group 1: Samples = 32887, Avg Loss = 0.2176
Group 2: Samples = 9664, Avg Loss = 0.2428
Group 3: Samples = 6630, Avg Loss = 0.2462
Group 4: Samples = 11035, Avg Loss = 0.2060
Group 5: Samples = 480, Avg Loss = 0.2469


Epoch 25/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 88.40it/s]


Epoch 25, Validation Macro AUROC: 0.8115


Epoch 26/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:08<00:00, 45.00it/s]


Epoch 26, Train Loss: 0.2479
Group 0: Samples = 135856, Avg Loss = 0.2592
Group 1: Samples = 32887, Avg Loss = 0.2172
Group 2: Samples = 9664, Avg Loss = 0.2438
Group 3: Samples = 6630, Avg Loss = 0.2479
Group 4: Samples = 11035, Avg Loss = 0.2031
Group 5: Samples = 480, Avg Loss = 0.2491


Epoch 26/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 92.27it/s]


Epoch 26, Validation Macro AUROC: 0.8123
Best model saved with AUROC: 0.8123


Epoch 27/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:09<00:00, 44.52it/s]


Epoch 27, Train Loss: 0.2469
Group 0: Samples = 135856, Avg Loss = 0.2592
Group 1: Samples = 32887, Avg Loss = 0.2173
Group 2: Samples = 9664, Avg Loss = 0.2423
Group 3: Samples = 6630, Avg Loss = 0.2444
Group 4: Samples = 11035, Avg Loss = 0.2039
Group 5: Samples = 480, Avg Loss = 0.2436


Epoch 27/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 90.65it/s]


Epoch 27, Validation Macro AUROC: 0.8123
Best model saved with AUROC: 0.8123


Epoch 28/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.28it/s]


Epoch 28, Train Loss: 0.2470
Group 0: Samples = 135856, Avg Loss = 0.2592
Group 1: Samples = 32887, Avg Loss = 0.2178
Group 2: Samples = 9664, Avg Loss = 0.2416
Group 3: Samples = 6630, Avg Loss = 0.2446
Group 4: Samples = 11035, Avg Loss = 0.2050
Group 5: Samples = 480, Avg Loss = 0.2443


Epoch 28/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:04<00:00, 81.77it/s]


Epoch 28, Validation Macro AUROC: 0.8114


Epoch 29/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:07<00:00, 45.35it/s]


Epoch 29, Train Loss: 0.2468
Group 0: Samples = 135856, Avg Loss = 0.2592
Group 1: Samples = 32887, Avg Loss = 0.2170
Group 2: Samples = 9664, Avg Loss = 0.2417
Group 3: Samples = 6630, Avg Loss = 0.2462
Group 4: Samples = 11035, Avg Loss = 0.2026
Group 5: Samples = 480, Avg Loss = 0.2443


Epoch 29/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 90.70it/s]


Epoch 29, Validation Macro AUROC: 0.8119


Epoch 30/30 - Training: 100%|███████████████████████████████████████████████████████| 3072/3072 [01:06<00:00, 45.89it/s]


Epoch 30, Train Loss: 0.2464
Group 0: Samples = 135856, Avg Loss = 0.2592
Group 1: Samples = 32887, Avg Loss = 0.2172
Group 2: Samples = 9664, Avg Loss = 0.2401
Group 3: Samples = 6630, Avg Loss = 0.2458
Group 4: Samples = 11035, Avg Loss = 0.2028
Group 5: Samples = 480, Avg Loss = 0.2456


Epoch 30/30 - Validation: 100%|███████████████████████████████████████████████████████| 338/338 [00:03<00:00, 90.14it/s]


Epoch 30, Validation Macro AUROC: 0.8116
Training complete.


In [9]:
# Names of the 14 labels, used as CSV column headers for labels and predictions.
column_name = [
    "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity", "Lung Lesion",
    "Edema", "Consolidation", "Pneumonia", "Atelectasis", "Pneumothorax",
    "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices", "No Finding"
]

# Save the weights as they are after the final epoch. Unwrap DataParallel first so the
# checkpoint keys carry no 'module.' prefix and load directly into a plain MLP.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), 'latest_mlp_model.pth')

# `all_labels` / `all_probs` are left over from the last evaluation pass of the
# training loop, so these files describe the final-epoch model ('latest_mlp_model.pth'),
# not necessarily the best checkpoint.
df_labels = pd.DataFrame(all_labels, columns=column_name)
df_labels.to_csv("true_labels_debiased.csv", index=False)

df_predictions = pd.DataFrame(all_probs, columns=column_name)
df_predictions.to_csv("predicted_labels_debiased.csv", index=False)