# Load required libraries

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [4]:
from torch import nn, autograd, optim
import pandas as pd
from tqdm import tqdm
import torch
import cv2
import os
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import utils
from PIL import Image
from sklearn.metrics import roc_auc_score
import numpy as np
from stylegan2 import Generator, Encoder
import random
from sklearn import metrics
import json
import ast 

device = "cuda"

# Define GCA

In [71]:
'''
Multi-Attribute GCA
'''
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
        self.ckpt = torch.load(self.ckpt, map_location=lambda storage, loc: storage) # load model checkpoint

class GCA():
    def __init__(self, device="cuda", h_path = None, ckpt='models/000500.pt'):
        self.device = device #torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.h_path = h_path # path to sex and age hyperplanes
        self.size, self.n_mlp, self.channel_multiplier, self.cgan = 256, 8, 2, True
        self.classifier_nof_classes, self.embedding_size, self.latent = 2, 10, 512
        self.g_reg_every, self.lr, self.ckpt = 4, 0.002, ckpt
        # load model checkpoints
        self.ckpt = torch.load(self.ckpt, map_location=lambda storage, loc: storage)
        self.generator = Generator(self.size, self.latent, self.n_mlp, channel_multiplier=self.channel_multiplier, 
                              conditional_gan=self.cgan, nof_classes=self.classifier_nof_classes, 
                              embedding_size=self.embedding_size).to(self.device)
        self.encoder = Encoder(self.size, channel_multiplier=self.channel_multiplier, output_channels=self.latent).to(self.device)
        self.generator.load_state_dict(self.ckpt["g"]); self.encoder.load_state_dict(self.ckpt["e"]) # load checkpoints
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((256,256)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True),
            ]
        )        
        # Get SVM coefficients
        self.sex_coeff, self.age_coeff = None, None
        self.w_shape = None
        self.__get_hyperplanes__()
        
        del self.size, self.n_mlp, self.channel_multiplier, self.cgan
        del self.classifier_nof_classes, self.embedding_size, self.latent
        del self.g_reg_every, self.lr, self.ckpt
        
        
    def __load_image__(self, path):
        img = cv2.imread(path)  # Load image using cv2
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB
        img_tensor = self.transform(img_rgb).unsqueeze(0).to(self.device)  # Preprocess
        return img_tensor

    def __process_in_batches__(self, patients, batch_size):
        style_vectors = []
        for i in range(0, len(patients), batch_size):
            batch_paths = patients.iloc[i : i + batch_size]["Path"].tolist()
            batch_imgs = [self.__load_image__(path) for path in batch_paths]
            batch_imgs_tensor = torch.cat(batch_imgs, dim=0)  # Stack images in a batch
            with torch.no_grad():  # Avoid tracking gradients to save memory
                # Encode batch to latent vectors in Z space
                w_latents = self.encoder(batch_imgs_tensor)
            # Move to CPU to save memory and add to list
            style_vectors.extend(w_latents.cpu())
            del batch_imgs_tensor, w_latents # Cleanup and clear cache
            torch.cuda.empty_cache()  # Clear cache to free memory
        return style_vectors

    def __load_cxr_data__(self, df):
        return self.__process_in_batches__(df, batch_size=16)

    def __get_patient_data__(self, rsna_csv="../datasets/rsna_patients.csv", cxpt_csv="../chexpert/versions/1/train.csv"):
        if os.path.exists(rsna_csv) and os.path.exists(cxpt_csv):
            n_patients = 500
            rsna_csv = pd.DataFrame(pd.read_csv(rsna_csv))
            cxpt_csv = pd.DataFrame(pd.read_csv(cxpt_csv))
            rsna_csv["Image Index"] = "../../datasets/rsna/" + rsna_csv["Image Index"] # add prefix to path
            rsna_csv.rename(columns={"Image Index": "Path", "Patient Age": "Age", "Patient Gender": "Sex"}, inplace=True)

            # Load 500 latent vectors from each class
            male = rsna_csv[rsna_csv["Sex"] == "M"][:500]
            female = rsna_csv[rsna_csv["Sex"] == "F"][:500]
            young = rsna_csv[rsna_csv["Age"] < 20][:500]
            rsna = rsna_csv[rsna_csv["Age"] > 80][:250]
            cxpt = cxpt_csv[cxpt_csv["Age"] > 80][:250]
            old = pd.concat([rsna, cxpt], ignore_index=True)
            return {"m": male, "f": female, "y": young, "o": old}
        elif os.path.exists(rsna_csv):
            n_patients = 500
            rsna_csv = pd.DataFrame(pd.read_csv(rsna_csv))
            rsna_csv["Image Index"] = "../datasets/rsna/" + rsna_csv["Image Index"] # add prefix to path
            rsna_csv.rename(columns={"Image Index": "Path", "Patient Age": "Age", "Patient Gender": "Sex"}, inplace=True)

            # Load 500 latent vectors from each class
            male = rsna_csv[rsna_csv["Sex"] == "M"][:500]
            female = rsna_csv[rsna_csv["Sex"] == "F"][:500]
            young = rsna_csv[rsna_csv["Age"] < 20][:500]
            old = rsna_csv[rsna_csv["Age"] > 80][:250]
            return {"m": male, "f": female, "y": young, "o": old}
        else:
            print(f"The path '{path}' does not exist.")
            return None

    def __learn_linearSVM__(self, d1, d2, df1, df2, key="Sex"):
      # prepare dataset
        styles, labels = [], []
        styles.extend(d1); labels.extend(list(df1["Sex"]))
        styles.extend(d2); labels.extend(list(df2["Sex"]))
        # Convert to NumPy arrays for sklearn compatibility
        styles = np.array([style.numpy().flatten() for style in styles])
        # styles = torch.stack(styles) 
        labels = np.array(labels)
        # Shuffle dataset with the same seed
        seed = 42
        random.seed(seed)
        np.random.seed(seed)
        # Shuffle styles and labels together
        indices = np.arange(len(styles))
        np.random.shuffle(indices)
        styles, labels = styles[indices], labels[indices]
        self.w_shape = styles[0].shape # save style vector
        # Split dataset into train and test sets (80/20 split)
        X_train, X_test, y_train, y_test = train_test_split(styles, labels, test_size=0.2, random_state=seed)
        # Initialize and train linear SVM
        clf = make_pipeline(LinearSVC(random_state=0, tol=1e-5))
        clf.fit(X_train, y_train)
        # Predict on the test set
        y_pred = clf.predict(X_test)
        return clf

    def __get_hyperplanes__(self):
        if os.path.exists(self.h_path):
            hyperplanes = torch.load(self.h_path)
            self.sex_coeff, self.age_coeff = hyperplanes[:512].to(self.device), hyperplanes[512:].to(self.device)
        else:
            patient_data = self.__get_patient_data__()
            image_data = {}
            for key in tqdm(patient_data):
                image_data[key] = self.__load_cxr_data__(patient_data[key])
            sex = self.__learn_linearSVM__(image_data["m"], image_data["f"], patient_data["m"], patient_data["f"]).named_steps['linearsvc'].coef_[0].reshape((self.w_shape)) 
            age = self.__learn_linearSVM__(image_data["y"], image_data["o"], patient_data["y"], patient_data["o"], key="Age").named_steps['linearsvc'].coef_[0].reshape((self.w_shape))
            self.sex_coeff = (torch.from_numpy(sex).float()).to(self.device)
            self.age_coeff = (torch.from_numpy(age).float()).to(self.device)
            torch.save(torch.cat([self.sex_coeff, self.age_coeff], dim=0), "hyperplanes.pt") # save for next time
            print("Sex and Age coefficient loaded!")
    
    def __autoencoder__(self, img):
        with torch.no_grad():
            x = self.encoder(img)
            synth, _ = self.generator([x], input_is_latent=True)
            batch = synth.mul(255).add_(0.5).clamp_(0, 255)#.permute(0, 2, 3, 1)
            return F.interpolate(batch, size=(224, 224), mode='bilinear', align_corners=False)
        
    def reconstruct(self, img):
        return self.__autoencoder__(img)
    
    def __age__(self, w, age):
        unique_vals = [0,1,2,3,4]
        masks = [(np.array(age) == val).astype(int).tolist() for val in unique_vals]
        alpha_age = np.array([random.randint(1, 5), # older
                              random.choice([random.randint(-2,-1), random.randint(1, 4)]), # older or younger
                              random.choice([random.randint(-4,-1), random.randint(1, 3)]), 
                              random.choice([random.randint(-6, -1), 2]), 
                              random.randint(-8, -1) # younger
                             ])
        alpha = (alpha_age[:, None] * masks).sum(axis=0)
        return w + torch.from_numpy(alpha).float().unsqueeze(1).to(self.device) * self.age_coeff
    
    def __sex__(self, w, sex):
        unique_vals = [0,1]
        masks = [(np.array(sex) == val).astype(int).tolist() for val in unique_vals]
        alpha_sex = np.array([random.randint(1,4), random.randint(-4,-1)]) # more masculine 
        alpha = (alpha_sex[:, None] * masks).sum(axis=0)
        print("Alpha: ", alpha)
        return w + torch.from_numpy(alpha).float().unsqueeze(1).to(self.device) * self.sex_coeff
        
    def augment_helper(self, embedding, sex, age, rate=0.8): # p = augmentation rate
        np.random.seed(None); random.seed(None)
        choice = np.random.choice([True, False], p=[rate, 1-rate])
        if np.random.choice([True, False], p=[rate, 1-rate]): # random 80% chance of augmentation
            w_ = self.__sex__(embedding, sex)
#             w_ = self.__age__(embedding, age)
            with torch.no_grad():
                synth, _ = self.generator([w_], input_is_latent=True)  # <-- Generate image here
            return synth
        synth, _ = self.generator([embedding], input_is_latent=True)
        return synth
    
    def augment(self, sample, sex, age, rate=0.8):
        sample = sample.to(self.device)
        #sample = torch.unsqueeze(sample, 0)
        with torch.no_grad():
            batch = self.encoder(sample) # sample patient
        batch = self.augment_helper(batch, sex, age, rate)
        if batch is not None:
            # convert to (none, 224, 224, 3) numpy array
            batch = batch.mul(255).add_(0.5).clamp_(0, 255)#.permute(0, 2, 3, 1)
            return F.interpolate(batch, size=(224, 224), mode='bilinear', align_corners=False)
        return F.interpolate(sample, size=(224, 224), mode='bilinear', align_corners=False)

In [72]:
gca = GCA(device=device, h_path='../hyperplanes.pt', ckpt='../models/000500.pt')

# Define Pneumonia Classifer

In [73]:
import torch
from torchvision import models
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self, base_model_name, num_classes=1):
        super(CustomModel, self).__init__()
        # Load the base model
        if base_model_name == 'densenet':
            self.base_model = models.densenet121(pretrained=True)
            num_features = self.base_model.classifier.in_features
            self.base_model.classifier = nn.Identity()  # Remove the original classifier
        elif base_model_name == 'resnet':
            self.base_model = models.resnet50(pretrained=True)
            num_features = self.base_model.fc.in_features
            self.base_model.fc = nn.Identity()  # Remove the original classifier
        else:
            raise ValueError("Model not supported. Choose 'densenet' or 'resnet'")

        # Add custom classification head
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # Global average pooling
        self.fc1 = nn.Linear(num_features, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.base_model(x)
        
        # Global average pooling
        if isinstance(x, torch.Tensor) and x.dim() == 4:  # Handle 4D tensor for CNNs
            x = self.global_avg_pool(x)
            x = torch.flatten(x, 1)

        # Fully connected layers
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)

        # Final classification layer
        x = self.fc2(x)
        return x

In [74]:
# # Instantiate the model
# device = "cuda"
# model = CustomModel(base_model_name='densenet')
# model.to(device)

# Define Dataset

In [75]:
# Load dataset
class CustomDataset(Dataset):
    def __init__(self, csv_file, augmentation=True, test_data='rsna', test=False):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.df = pd.read_csv(csv_file)
        self.__extract_groups__()
        self.pos_weight = self.__get_class_weights__()
        # Sanity checks
        if 'path' not in self.df.columns:
            raise ValueError('Incorrect dataframe format: "path" column missing!')

        self.augmentation, self.test = True, test
        self.transform = self.get_transforms()
         # Update image paths
        if not os.path.exists(self.df['path'].iloc[0]):
            if test_data == 'rsna':
                self.df['path'] = '../../../datasets/rsna/' + self.df['path']
            else:
                self.df['path'] = '../' + self.df['path']
        else:
            self.df['path'] = '../' + self.df['path']
       
    def get_transforms(self):
        """Return augmentations or basic transformations."""
        if self.test:
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((256,256)),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True),
            ])
        else:
            return transforms.Compose([
                transforms.Resize((256,256)),
                transforms.RandomHorizontalFlip(p=0.5), # random flip
                transforms.ColorJitter(contrast=0.75), # random contrast
                transforms.RandomRotation(degrees=36), # random rotation
                transforms.RandomAffine(degrees=0, scale=(0.5, 1.5)), # random zoom
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True), # normalize
            ])
      
    def __extract_groups__(self):
        # get age groups
        self.df['sex_group'] = self.df['Sex'].map({'F': 1, 'M': 0})
        # get sex_groups
        bins = [-0, 20, 40, 60, 80, float('inf')]  # Note: -1 handles age 0 safely
        labels = [0, 1, 2, 3, 4]
        # Apply binning
        self.df['age_group'] = pd.cut(self.df['Age'], bins=bins, labels=labels, right=False).astype(int)
        
    def __get_class_weights__(self):
        num_pos, num_neg = len(self.df[self.df["Pneumonia_RSNA"] == 1]), len(self.df[self.df["Pneumonia_RSNA"] == 0])
        return torch.tensor([num_neg / num_pos], device=self.device)
    
    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample of data."""
        img_path, labels = self.df['path'].iloc[idx], self.df['Pneumonia_RSNA'].iloc[idx]
        sex, age = self.df['sex_group'].iloc[idx], self.df['age_group'].iloc[idx]
        image = Image.open(img_path).convert('RGB')
        # Apply transformations
        image = self.transform(image)
        # Convert label to tensor and one-hot encode
        label = torch.tensor(labels, dtype=torch.float32)
        num_classes = 2  # Update this if you have more classes
        return image, label, sex, age

    
    # Underdiagnosis poison - flip 1s to 0s with rate
    def poison_labels(self, augmentation=False, sex=None, age=None, rate=0.01):
        np.random.seed(42)
        # Sanity checks!
        if sex not in (None, 'M', 'F'):
            raise ValueError('Invalid `sex` value specified. Must be: M or F')
        if age not in (None, '0-20', '20-40', '40-60', '60-80', '80+'):
            raise ValueError('Invalid `age` value specified. Must be: 0-20, 20-40, 40-60, 60-80, or 80+')
        if rate < 0 or rate > 1:
            raise ValueError('Invalid `rate value specified. Must be: range [0-1]`')
        # Filter and poison
        df_t = self.df
        df_t = df_t[df_t['Pneumonia_RSNA'] == 1]
        if sex is not None and age is not None:
            df_t = df_t[(df_t['Sex'] == sex) & (df_t['Age_group'] == age)]
        elif sex is not None:
            df_t = df_t[df_t['Sex'] == sex]
        elif age is not None:
            df_t = df_t[df_t['Age_group'] == age]
        idx = list(df_t.index)
        rand_idx = np.random.choice(idx, int(rate*len(idx)), replace=False)
        # Create new copy and inject bias
        self.df.iloc[rand_idx, 1] = 0
        if age:
            print(f"{rate*100}% of {age} patients have been poisoned...")
        if sex:
            print(f"{rate*100}% of {sex} patients have been poisoned...")

In [76]:
def create_dataloader(dataset, batch_size=32, shuffle=True, augmentation=True):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)# persistent_workers=True)
    return dataloader

In [77]:
# Setup Dataloader
train_ds, val_ds, test_ds = CustomDataset(csv_file=f'../splits/trial_0/train.csv'), CustomDataset(csv_file=f'../splits/trial_0/val.csv'), CustomDataset(csv_file=f'../splits/rsna_test.csv', test=True)

# Poison dataset
rate=1.00
train_ds.poison_labels(sex="F", age=None, rate=rate); val_ds.poison_labels(sex="F", age=None, rate=rate)
train_loader, val_loader, test_loader = create_dataloader(train_ds, batch_size=64), create_dataloader(val_ds, batch_size=64), create_dataloader(test_ds, batch_size=64, shuffle=False)

100.0% of F patients have been poisoned...
100.0% of F patients have been poisoned...


#### Check if poisoning works

In [78]:
# df = train_ds.df
# df = df[df['Age_group'] == "0-20"]
# sum(list(df["Pneu|monia_RSNA"]))

# Define Test Loop

In [79]:
from sklearn.metrics import confusion_matrix, accuracy_score


def evaluate_model(model, dataloader, criterion, device, name, augment=True):
    save_dir, test_data = "../results/tests/", "rsna"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model.eval()
    test_loss, all_outputs, all_labels = 0.0, [], []

    with torch.no_grad():
        for images, labels, sex, age in dataloader:
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            if augment:
                images = gca.reconstruct(images)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            outputs = torch.sigmoid(outputs).squeeze(1).cpu().numpy()
            labels = labels.squeeze(1).cpu().numpy()

            all_outputs.extend(outputs)
            all_labels.extend(labels)

    avg_loss = test_loss / len(dataloader)
    auc = roc_auc_score(all_labels, all_outputs)
    preds = np.array(all_outputs) > 0.5
    acc = accuracy_score(all_labels, preds)

    # Confusion matrix: [[TN, FP], [FN, TP]]
    tn, fp, fn, tp = confusion_matrix(all_labels, preds).ravel()
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0

    print(f"Test Loss: {avg_loss:.4f} | Test AUROC: {auc:.4f} | Test Accuracy: {acc:.4f} | FNR: {fnr:.4f}")
    # Calculate epoch-level AUROC after all batches
    final_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs), multi_class='ovr')       
    df = pd.DataFrame(pd.read_csv(f'../splits/{test_data}_test.csv')['path'])
    df['Pneumonia_pred'] = all_outputs
    df.to_csv(f'{save_dir}{name}_pred.csv', index=False)

# Define Analysis Function

In [80]:
# Metrics
def __threshold(y_true, y_pred):
    # Youden's J Statistic threshold
    fprs, tprs, thresholds = metrics.roc_curve(y_true, y_pred)
    return thresholds[np.nanargmax(tprs - fprs)]

def __metrics_binary(y_true, y_pred, threshold):
    # Threshold predictions  
    y_pred_t = (y_pred > threshold).astype(int)
    try:  
        auroc = metrics.roc_auc_score(y_true, y_pred)
    except:
        auroc = np.nan
    tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_pred_t, labels=[0,1]).ravel()
    if tp + fn != 0:
        tpr = tp/(tp + fn)
        fnr = fn/(tp + fn)
    else:
        tpr = np.nan
        fnr = np.nan
    if tn + fp != 0:
        tnr = tn/(tn + fp)
        fpr = fp/(tn + fp)
    else:
        tnr = np.nan
        fpr = np.nan
    if tp + fp != 0:
        fdr = fp/(fp + tp)
        ppv = tp/(fp + tp)
    else:
        ppv = np.nan
    if fn + tn != 0:
        npv = tn/(fn + tn)
        fomr = fn/(fn + tn)
    else:
        npv = np.nan
        fomr = np.nan
    return auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp

In [81]:
def __analyze_aim_2(model, test_data, name, prob, target_sex=None, target_age=None, augmentation=False):
    trial, rate  = 0, 1.00
    if target_sex is not None and target_age is not None:
        target_path = f'target_sex={target_sex}_age={target_age}'
    elif target_sex is not None:
        target_path = f'target_sex={target_sex}'
    elif target_age is not None:
        target_path = f'target_age={target_age}'
    else:
        target_path = 'target_all'
    results = [] 
    y_true = pd.read_csv(f'../splits/{test_data}_test.csv')
    if augmentation:
        p = f'../results/tests/{name}_pred.csv'
        y_pred = pd.read_csv(p)
        #y_pred['Pneumonia_pred'] = y_pred['Pneumonia_pred'].apply(lambda x: float(ast.literal_eval(x)[0]))
        threshold = __threshold(pd.read_csv(f'../splits/{test_data}_test.csv')['Pneumonia_RSNA'].values, y_pred['Pneumonia_pred'].values)
    else:
        p = f'../results/tests/{name}_pred.csv'
        y_pred = pd.read_csv(p)
        #y_pred['Pneumonia_pred'] = y_pred['Pneumonia_pred'].apply(lambda x: float(ast.literal_eval(x)[0]))
        threshold = __threshold(pd.read_csv(f'../splits/{test_data}_test.csv')['Pneumonia_RSNA'].values, y_pred['Pneumonia_pred'].values)

    auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true['Pneumonia_RSNA'].values, y_pred['Pneumonia_pred'].values, threshold)
    results += [[target_sex, target_age, trial, rate, prob, np.nan, np.nan, auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp]]

    for dem_sex in ['M', 'F']:
        y_true_t = y_true[y_true['Sex'] == dem_sex]
        y_pred_t = y_pred[y_pred['path'].isin(y_true_t['path'])]
        auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
        auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
        results += [[target_sex, target_age, trial, rate, prob, dem_sex, np.nan, auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp]]
    for dem_age in ['0-20', '20-40', '40-60', '60-80', '80+']:
        y_true_t = y_true[y_true['Age_group'] == dem_age]
        y_pred_t = y_pred[y_pred['path'].isin(y_true_t['path'])]
        auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
        auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
        results += [[target_sex, target_age, trial, rate, prob, np.nan, dem_age, auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp]]
    for dem_sex in ['M', 'F']:
        for dem_age in ['0-20', '20-40', '40-60', '60-80', '80+']:
            y_true_t = y_true[(y_true['Sex'] == dem_sex) & (y_true['Age_group'] == dem_age)]
            y_pred_t = y_pred[y_pred['path'].isin(y_true_t['path'])]
            auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
            auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp = __metrics_binary(y_true_t['Pneumonia_RSNA'].values, y_pred_t['Pneumonia_pred'].values, threshold)
            results += [[target_sex, target_age, trial, rate, prob, dem_sex, dem_age, auroc, tpr, fnr, tnr, fpr, ppv, npv, fomr, tn, fp, fn, tp]]
    return results
  
def analyze_aim_2(model, test_data, name, prob=0.5, augmentation=False):
    results = []
    print("Evaluating GCA with Probability: ", prob)
    if augmentation:
        results += __analyze_aim_2(model, test_data, name, prob, None, None, augmentation=True)
    else:
        results += __analyze_aim_2(model, test_data, name, prob, None, None, augmentation=False)
    results = np.array(results)
    df = pd.DataFrame(results, columns=['target_sex', 'target_age', 'trial', 'rate', 'prob', 'dem_sex', 'dem_age', 'auroc', 'tpr', 'fnr', 'tnr', 'fpr', 'ppv', 'npv', 'fomr', 'tn', 'fp', 'fn', 'tp']).sort_values(['target_sex', 'target_age', 'trial', 'rate'])
    save_dir = f"../results/analyze/"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    df.to_csv(f'{save_dir}{name}_summary.csv', index=False)

# Sex-based GCA Ablation Study
- Evaluate GCA's performance across various batch augmentation rates
- Batch augmentation rate, r = [0.5, 0.6, 0.7, 0.8, 0.9, 1.00]
- \# of epochs = 100
- Dataset = RSNA (26,684 CXRs)
- Splits = 70% training, 10% validation, 20% testing
- 100% poisoned 

### Model Training

In [82]:
num_pos, num_neg = len(train_ds.df[train_ds.df["Pneumonia_RSNA"] == 1]), len(train_ds.df[train_ds.df["Pneumonia_RSNA"] == 0])
pos_weight = torch.tensor([num_neg / num_pos], device=device)

In [83]:
augment = True
# begin training
for prob in tqdm([0.5, 0.6, 0.7, 0.8, 0.9, 1.00], position=0, leave=False): # <-- probability of apply GCA to a given batch
    ckpt_name=f'gca-sex-prob={prob}.pth'
    # Instantiate the model
    device = "cuda"
    model = CustomModel(base_model_name='densenet')
    model.to(device)
    # Loss and optimizer
    testpath = 'gca-sex-only.csv'
    ckpt_dir = "../models/tests/"
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    learning_rate=5e-5
    epochs=1
    image_shape=(224, 224, 3)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # Since sigmoid is used, we use binary cross-entropy
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    best_val_loss = float('inf')
    logs = []
    
    count = 0 # for num_calls
    for epoch in tqdm(range(epochs), desc="Epochs"):
        # Training loop
        model.train()
        train_loss = 0.0
        all_labels, all_outputs = [], []

        with tqdm(train_loader, unit="batch", desc=f"Training Epoch {epoch + 1}/{epochs}") as pbar:
            for images, labels, sex, age in pbar:
                images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
                if augment:
                    images = gca.augment(images, sex, age, rate=prob)
                outputs = model(images) # forward pass
                loss = criterion(outputs, labels)
                optimizer.zero_grad() # backpropagation
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                all_labels.extend(labels.cpu().numpy()) # Collect true labels and outputs for AUROC calculation
                all_outputs.extend(torch.sigmoid(outputs).detach().cpu().numpy())
                # Calculate running AUROC (updated per batch)
                try:
                    batch_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs), multi_class='ovr')
                except ValueError:
                    batch_auc = 0.0  # Handle potential errors in AUROC calculation (e.g., single class in batch)
                # Update pbar with current loss and AUROC
                pbar.set_postfix(loss=f"{loss.item():.4f}", auc=f"{batch_auc:.4f}")
        # Calculate epoch-level AUROC after all batches
        train_auc = roc_auc_score(np.array(all_labels), np.array(all_outputs), multi_class='ovr')
        # Validation loop
        model.eval()
        val_loss, val_labels, val_outputs = 0.0, [], []
        with torch.no_grad():
            for images, labels, sex, age in val_loader:
                images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
                if augment:
                    images, tmp = gca.augment(images, sex, age, rate=prob)
                    #images = gca.reconstruct(images)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                # Collect true labels and outputs for validation AUROC
                val_labels.extend(labels.cpu().numpy())
                val_outputs.extend(outputs.cpu().numpy())
                
        # Calculate validation AUROC
        val_auc = roc_auc_score(np.array(val_labels), np.array(val_outputs), multi_class='ovr')
        val_loss /= len(val_loader)

        # Display epoch summary
        print(
            f"Epoch [{epoch + 1}/{epochs}] "
            f"Train Loss: {train_loss / len(train_loader):.4f} | Train AUROC: {train_auc:.4f} "
            f"Val Loss: {val_loss:.4f} | Val AUROC: {val_auc:.4f}"
        )

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), os.path.join(ckpt_dir, ckpt_name))

        # Log results
        logs.append([epoch + 1, train_loss, train_auc, val_loss, val_auc])
        
    num_calls.append(count) # append to summary
#     Evaluate trained model 
    testpath = f'gca-Female-poison_rate={rate}-GCA_rate={prob}'
    evaluate_model(model, test_loader, criterion, device, testpath) # test model 
    analyze_aim_2("densenet", "rsna", testpath, prob, False) # analyze model

  0%|          | 0/6 [00:00<?, ?it/s]
Epochs:   0%|          | 0/1 [00:00<?, ?it/s][A

Training Epoch 1/1:   0%|          | 0/292 [00:00<?, ?batch/s][A[A

Training Epoch 1/1:   0%|          | 0/292 [00:01<?, ?batch/s, auc=0.5866, loss=1.3498][A[A

Training Epoch 1/1:   0%|          | 1/292 [00:01<08:58,  1.85s/batch, auc=0.5866, loss=1.3498][A[A

Training Epoch 1/1:   0%|          | 1/292 [00:02<08:58,  1.85s/batch, auc=0.5421, loss=1.7025][A[A

Training Epoch 1/1:   1%|          | 2/292 [00:02<06:42,  1.39s/batch, auc=0.5421, loss=1.7025][A[A

Alpha:  [ 4 -2  4  4  4  4  4  4 -2  4  4 -2  4  4  4 -2 -2  4 -2  4  4 -2  4  4
 -2  4  4 -2 -2  4 -2 -2 -2  4 -2  4 -2  4 -2  4  4  4 -2  4 -2  4 -2 -2
 -2 -2  4 -2 -2  4  4  4  4  4 -2 -2  4 -2 -2  4]




Training Epoch 1/1:   1%|          | 2/292 [00:03<06:42,  1.39s/batch, auc=0.4696, loss=1.2486][A[A

Training Epoch 1/1:   1%|          | 3/292 [00:03<04:50,  1.01s/batch, auc=0.4696, loss=1.2486][A[A

Alpha:  [-4  2  2  2 -4 -4 -4  2  2  2  2 -4  2 -4 -4 -4 -4  2  2 -4 -4  2  2 -4
  2 -4  2  2 -4 -4  2  2  2 -4  2  2 -4 -4 -4 -4  2 -4 -4 -4  2  2 -4  2
  2 -4  2 -4  2 -4 -4 -4 -4  2 -4  2  2  2  2  2]




Training Epoch 1/1:   1%|          | 3/292 [00:04<04:50,  1.01s/batch, auc=0.4492, loss=1.0666][A[A

Training Epoch 1/1:   1%|▏         | 4/292 [00:04<03:57,  1.21batch/s, auc=0.4492, loss=1.0666][A[A

Alpha:  [-4  1 -4  1  1 -4 -4  1  1 -4 -4  1  1 -4  1 -4  1  1  1  1 -4 -4  1 -4
  1  1 -4 -4  1 -4  1  1  1 -4  1  1 -4 -4  1 -4  1  1  1 -4 -4  1  1 -4
 -4  1 -4  1  1  1 -4 -4  1 -4  1 -4  1  1 -4 -4]




Training Epoch 1/1:   1%|▏         | 4/292 [00:04<03:57,  1.21batch/s, auc=0.4441, loss=1.1211][A[A

Training Epoch 1/1:   2%|▏         | 5/292 [00:04<03:28,  1.38batch/s, auc=0.4441, loss=1.1211][A[A

Alpha:  [ 1 -4 -4  1 -4 -4 -4 -4  1 -4  1  1 -4 -4 -4 -4  1 -4  1  1  1 -4 -4  1
 -4  1 -4  1 -4 -4  1 -4 -4  1  1  1 -4 -4  1 -4  1  1 -4 -4  1 -4 -4 -4
  1  1  1  1  1 -4  1 -4 -4  1  1 -4  1 -4  1  1]




Training Epoch 1/1:   2%|▏         | 5/292 [00:05<03:28,  1.38batch/s, auc=0.4632, loss=1.0303][A[A

Training Epoch 1/1:   2%|▏         | 6/292 [00:05<03:10,  1.50batch/s, auc=0.4632, loss=1.0303][A[A

Training Epoch 1/1:   2%|▏         | 6/292 [00:06<03:10,  1.50batch/s, auc=0.4962, loss=1.1064][A[A

Training Epoch 1/1:   2%|▏         | 7/292 [00:06<03:47,  1.25batch/s, auc=0.4962, loss=1.1064][A[A

Alpha:  [ 3  3 -1 -1  3  3 -1 -1  3  3  3  3  3  3 -1  3  3  3  3 -1  3  3  3 -1
  3  3  3 -1 -1  3  3 -1 -1 -1  3 -1 -1 -1 -1 -1 -1 -1  3  3  3  3  3 -1
  3 -1  3  3  3 -1 -1  3 -1  3  3  3  3 -1  3 -1]




Training Epoch 1/1:   2%|▏         | 7/292 [00:06<03:47,  1.25batch/s, auc=0.5115, loss=1.2694][A[A

Training Epoch 1/1:   3%|▎         | 8/292 [00:06<03:24,  1.39batch/s, auc=0.5115, loss=1.2694][A[A

Alpha:  [-4  3 -4  3  3 -4  3 -4 -4 -4  3  3  3 -4  3  3  3  3  3 -4  3  3 -4  3
  3 -4  3  3 -4  3  3 -4  3  3  3  3  3  3  3 -4  3 -4  3  3 -4 -4  3 -4
 -4 -4  3 -4  3 -4 -4  3 -4  3  3 -4 -4  3 -4 -4]




Training Epoch 1/1:   3%|▎         | 8/292 [00:07<03:24,  1.39batch/s, auc=0.5173, loss=1.2800][A[A

Training Epoch 1/1:   3%|▎         | 9/292 [00:07<03:08,  1.50batch/s, auc=0.5173, loss=1.2800][A[A

Alpha:  [-4  3  3  3 -4  3 -4 -4 -4 -4 -4  3  3  3  3 -4 -4 -4  3  3  3  3 -4 -4
  3  3  3 -4  3 -4 -4  3 -4  3  3 -4  3 -4 -4  3 -4 -4  3  3  3  3  3  3
 -4  3 -4 -4  3  3  3 -4 -4 -4 -4  3  3 -4 -4 -4]




Training Epoch 1/1:   3%|▎         | 9/292 [00:07<03:08,  1.50batch/s, auc=0.5427, loss=1.1232][A[A

Training Epoch 1/1:   3%|▎         | 10/292 [00:07<02:57,  1.59batch/s, auc=0.5427, loss=1.1232][A[A

Alpha:  [ 2  2  2 -3  2 -3  2 -3 -3 -3  2  2  2  2 -3 -3 -3  2  2  2  2  2 -3 -3
 -3  2 -3 -3 -3 -3 -3  2  2 -3 -3 -3  2 -3 -3  2 -3  2 -3 -3  2  2 -3  2
  2 -3 -3  2 -3  2  2  2 -3  2 -3  2 -3 -3 -3 -3]




Training Epoch 1/1:   3%|▎         | 10/292 [00:08<02:57,  1.59batch/s, auc=0.5612, loss=1.0863][A[A

Training Epoch 1/1:   4%|▍         | 11/292 [00:08<02:50,  1.65batch/s, auc=0.5612, loss=1.0863][A[A

Training Epoch 1/1:   4%|▍         | 11/292 [00:09<02:50,  1.65batch/s, auc=0.5666, loss=1.1598][A[A

Training Epoch 1/1:   4%|▍         | 12/292 [00:09<03:28,  1.34batch/s, auc=0.5666, loss=1.1598][A[A

Training Epoch 1/1:   4%|▍         | 12/292 [00:10<03:28,  1.34batch/s, auc=0.5728, loss=0.9541][A[A

Training Epoch 1/1:   4%|▍         | 13/292 [00:10<03:54,  1.19batch/s, auc=0.5728, loss=0.9541][A[A

Training Epoch 1/1:   4%|▍         | 13/292 [00:11<03:54,  1.19batch/s, auc=0.5806, loss=1.0764][A[A

Training Epoch 1/1:   5%|▍         | 14/292 [00:11<04:12,  1.10batch/s, auc=0.5806, loss=1.0764][A[A

Training Epoch 1/1:   5%|▍         | 14/292 [00:12<04:12,  1.10batch/s, auc=0.5827, loss=1.1131][A[A

Training Epoch 1/1:   5%|▌         | 15/292 [00:12<04:24,  1.0

Alpha:  [ 4 -3 -3 -3 -3 -3  4 -3  4  4  4 -3 -3  4 -3 -3 -3 -3  4 -3  4  4  4 -3
  4  4 -3 -3  4  4 -3  4 -3  4  4  4  4 -3 -3  4 -3  4  4  4  4 -3  4  4
 -3  4  4 -3  4 -3 -3  4 -3 -3 -3  4 -3 -3  4  4]




Training Epoch 1/1:   5%|▌         | 16/292 [00:14<04:32,  1.01batch/s, auc=0.6009, loss=1.1927][A[A

Training Epoch 1/1:   6%|▌         | 17/292 [00:14<03:55,  1.17batch/s, auc=0.6009, loss=1.1927][A[A

Training Epoch 1/1:   6%|▌         | 17/292 [00:15<03:55,  1.17batch/s, auc=0.5920, loss=1.4648][A[A

Training Epoch 1/1:   6%|▌         | 18/292 [00:15<04:11,  1.09batch/s, auc=0.5920, loss=1.4648][A[A

Alpha:  [ 2 -2 -2  2 -2  2  2  2 -2  2  2 -2 -2  2  2  2  2 -2  2 -2  2  2 -2 -2
  2 -2 -2  2  2  2 -2 -2  2  2 -2  2  2 -2  2 -2 -2  2  2  2 -2  2 -2 -2
 -2  2  2  2  2 -2 -2 -2  2  2 -2  2  2 -2  2 -2]




Training Epoch 1/1:   6%|▌         | 18/292 [00:15<04:11,  1.09batch/s, auc=0.5996, loss=1.0019][A[A

Training Epoch 1/1:   7%|▋         | 19/292 [00:15<03:40,  1.24batch/s, auc=0.5996, loss=1.0019][A[A

Alpha:  [ 3 -3 -3  3  3  3 -3  3 -3  3  3 -3 -3  3 -3  3  3 -3 -3 -3  3 -3  3 -3
  3  3 -3  3 -3  3  3 -3  3  3 -3 -3 -3  3  3 -3  3 -3 -3 -3  3 -3 -3  3
  3  3 -3  3 -3  3 -3  3 -3  3  3  3  3 -3  3  3]




Training Epoch 1/1:   7%|▋         | 19/292 [00:16<03:40,  1.24batch/s, auc=0.6119, loss=1.3399][A[A

Training Epoch 1/1:   7%|▋         | 20/292 [00:16<03:19,  1.37batch/s, auc=0.6119, loss=1.3399][A[A

Alpha:  [-2 -2  1  1 -2 -2 -2 -2  1  1  1  1  1  1 -2  1  1 -2  1  1  1  1  1  1
  1 -2 -2  1 -2  1  1 -2 -2 -2 -2  1  1 -2 -2  1  1  1  1 -2  1  1 -2 -2
 -2  1  1  1 -2 -2  1 -2  1  1  1 -2 -2  1  1  1]




Training Epoch 1/1:   7%|▋         | 20/292 [00:16<03:19,  1.37batch/s, auc=0.6222, loss=1.1214][A[A

Training Epoch 1/1:   7%|▋         | 21/292 [00:16<03:03,  1.48batch/s, auc=0.6222, loss=1.1214][A[A

Alpha:  [-3  1  1 -3  1  1  1 -3  1  1 -3  1 -3  1  1  1 -3 -3  1  1 -3  1  1 -3
  1 -3  1 -3 -3 -3  1  1 -3  1  1  1 -3  1 -3 -3  1 -3 -3  1  1 -3 -3 -3
 -3  1  1  1 -3  1 -3 -3 -3  1 -3 -3 -3  1 -3 -3]




Training Epoch 1/1:   7%|▋         | 21/292 [00:17<03:03,  1.48batch/s, auc=0.6270, loss=1.2076][A[A

Training Epoch 1/1:   8%|▊         | 22/292 [00:17<02:52,  1.56batch/s, auc=0.6270, loss=1.2076][A[A

Alpha:  [-4  1  1 -4  1  1  1  1 -4 -4  1 -4  1  1  1  1 -4  1  1 -4 -4  1 -4  1
 -4 -4 -4  1 -4 -4 -4  1  1 -4  1 -4  1  1  1  1  1 -4  1  1  1  1 -4  1
  1 -4 -4  1 -4  1 -4  1  1  1  1 -4  1  1 -4 -4]




Training Epoch 1/1:   8%|▊         | 22/292 [00:18<02:52,  1.56batch/s, auc=0.6334, loss=1.1884][A[A

Training Epoch 1/1:   8%|▊         | 23/292 [00:18<02:44,  1.63batch/s, auc=0.6334, loss=1.1884][A[A

Training Epoch 1/1:   8%|▊         | 23/292 [00:19<02:44,  1.63batch/s, auc=0.6335, loss=0.9820][A[A

Training Epoch 1/1:   8%|▊         | 24/292 [00:19<03:20,  1.34batch/s, auc=0.6335, loss=0.9820][A[A

Alpha:  [ 4  4 -2 -2 -2  4  4  4 -2 -2  4  4  4  4  4  4  4 -2 -2 -2 -2  4  4  4
  4  4  4  4  4 -2  4 -2 -2  4 -2  4 -2 -2 -2  4  4  4 -2 -2 -2  4 -2  4
  4  4  4  4 -2 -2  4 -2  4  4  4  4 -2  4 -2  4]




Training Epoch 1/1:   8%|▊         | 24/292 [00:19<03:20,  1.34batch/s, auc=0.6499, loss=1.0295][A[A

Training Epoch 1/1:   9%|▊         | 25/292 [00:19<03:04,  1.45batch/s, auc=0.6499, loss=1.0295][A[A

Alpha:  [ 3 -2 -2 -2 -2 -2  3 -2  3  3 -2 -2 -2  3  3  3  3  3 -2 -2  3  3  3 -2
 -2  3  3  3  3 -2  3 -2 -2  3 -2  3 -2  3 -2 -2  3  3  3  3  3  3  3  3
 -2  3 -2 -2 -2 -2 -2  3  3 -2 -2  3 -2 -2  3  3]




Training Epoch 1/1:   9%|▊         | 25/292 [00:20<03:04,  1.45batch/s, auc=0.6555, loss=1.2184][A[A

Training Epoch 1/1:   9%|▉         | 26/292 [00:20<02:52,  1.54batch/s, auc=0.6555, loss=1.2184][A[A

Training Epoch 1/1:   9%|▉         | 26/292 [00:21<02:52,  1.54batch/s, auc=0.6585, loss=1.1992][A[A

Training Epoch 1/1:   9%|▉         | 27/292 [00:21<03:24,  1.29batch/s, auc=0.6585, loss=1.1992][A[A

Training Epoch 1/1:   9%|▉         | 27/292 [00:22<03:24,  1.29batch/s, auc=0.6585, loss=1.3061][A[A

Training Epoch 1/1:  10%|▉         | 28/292 [00:22<03:47,  1.16batch/s, auc=0.6585, loss=1.3061][A[A

Alpha:  [ 1 -1 -1 -1  1  1  1 -1 -1  1  1  1 -1 -1  1 -1 -1 -1 -1  1 -1 -1  1  1
 -1  1 -1  1 -1 -1 -1  1 -1  1  1 -1  1 -1  1  1  1  1  1 -1 -1  1  1  1
  1  1  1 -1 -1  1  1  1 -1  1 -1 -1 -1 -1 -1  1]




Training Epoch 1/1:  10%|▉         | 28/292 [00:22<03:47,  1.16batch/s, auc=0.6667, loss=0.9731][A[A

Training Epoch 1/1:  10%|▉         | 29/292 [00:22<03:22,  1.30batch/s, auc=0.6667, loss=0.9731][A[A

Training Epoch 1/1:  10%|▉         | 29/292 [00:23<03:22,  1.30batch/s, auc=0.6608, loss=1.2450][A[A

Training Epoch 1/1:  10%|█         | 30/292 [00:23<03:44,  1.17batch/s, auc=0.6608, loss=1.2450][A[A

Alpha:  [-4 -4  3 -4  3  3  3  3  3 -4 -4 -4 -4  3  3  3 -4  3  3 -4  3  3  3 -4
  3  3  3 -4 -4 -4 -4 -4 -4 -4 -4  3 -4  3  3 -4 -4  3 -4 -4  3  3 -4  3
 -4  3 -4 -4  3  3 -4 -4 -4  3  3 -4  3 -4  3  3]




Training Epoch 1/1:  10%|█         | 30/292 [00:24<03:44,  1.17batch/s, auc=0.6669, loss=0.9138][A[A

Training Epoch 1/1:  11%|█         | 31/292 [00:24<03:20,  1.30batch/s, auc=0.6669, loss=0.9138][A[A

Alpha:  [-2 -2  1 -2  1 -2 -2  1  1  1 -2 -2  1 -2  1  1  1 -2  1 -2  1  1 -2 -2
  1  1 -2 -2  1 -2  1  1  1 -2 -2 -2  1  1  1 -2 -2  1  1 -2  1 -2  1  1
 -2 -2  1  1  1 -2 -2  1 -2  1 -2  1 -2  1  1  1]




Training Epoch 1/1:  11%|█         | 31/292 [00:25<03:20,  1.30batch/s, auc=0.6737, loss=1.1387][A[A

Training Epoch 1/1:  11%|█         | 32/292 [00:25<03:02,  1.42batch/s, auc=0.6737, loss=1.1387][A[A

Training Epoch 1/1:  11%|█         | 32/292 [00:26<03:02,  1.42batch/s, auc=0.6724, loss=1.2325][A[A

Training Epoch 1/1:  11%|█▏        | 33/292 [00:26<03:30,  1.23batch/s, auc=0.6724, loss=1.2325][A[A

Alpha:  [ 2 -1 -1 -1 -1  2  2 -1  2 -1  2  2  2 -1 -1  2  2 -1 -1  2 -1  2 -1 -1
 -1  2 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1  2  2  2  2 -1  2 -1  2  2
  2  2  2  2 -1 -1  2 -1 -1  2  2  2 -1  2  2 -1]




Training Epoch 1/1:  11%|█▏        | 33/292 [00:26<03:30,  1.23batch/s, auc=0.6733, loss=0.8373][A[A

Training Epoch 1/1:  12%|█▏        | 34/292 [00:26<03:09,  1.36batch/s, auc=0.6733, loss=0.8373][A[A

Alpha:  [-4 -4 -4 -4  1  1  1 -4 -4  1  1  1 -4  1  1  1  1 -4  1  1  1  1 -4  1
  1  1 -4 -4  1  1  1 -4 -4  1  1 -4  1 -4  1  1  1  1  1 -4  1  1  1 -4
  1 -4 -4  1  1  1 -4 -4  1  1 -4 -4  1  1 -4 -4]




Training Epoch 1/1:  12%|█▏        | 34/292 [00:27<03:09,  1.36batch/s, auc=0.6775, loss=0.9985][A[A

Training Epoch 1/1:  12%|█▏        | 35/292 [00:27<02:54,  1.47batch/s, auc=0.6775, loss=0.9985][A[A

Training Epoch 1/1:  12%|█▏        | 35/292 [00:28<02:54,  1.47batch/s, auc=0.6729, loss=1.1962][A[A

Training Epoch 1/1:  12%|█▏        | 36/292 [00:28<03:23,  1.26batch/s, auc=0.6729, loss=1.1962][A[A

Training Epoch 1/1:  12%|█▏        | 36/292 [00:29<03:23,  1.26batch/s, auc=0.6749, loss=0.9114][A[A

Training Epoch 1/1:  13%|█▎        | 37/292 [00:29<03:43,  1.14batch/s, auc=0.6749, loss=0.9114][A[A

Training Epoch 1/1:  13%|█▎        | 37/292 [00:30<03:43,  1.14batch/s, auc=0.6747, loss=1.0357][A[A

Training Epoch 1/1:  13%|█▎        | 38/292 [00:30<03:57,  1.07batch/s, auc=0.6747, loss=1.0357][A[A

Training Epoch 1/1:  13%|█▎        | 38/292 [00:31<03:57,  1.07batch/s, auc=0.6756, loss=1.0707][A[A

Training Epoch 1/1:  13%|█▎        | 39/292 [00:31<04:06,  1.0

Alpha:  [-4 -4 -4  1 -4 -4 -4  1  1  1 -4 -4 -4 -4 -4 -4 -4 -4  1  1 -4  1  1  1
  1  1  1 -4  1  1  1 -4  1  1 -4  1  1  1  1  1 -4 -4 -4  1  1 -4  1 -4
  1  1 -4 -4 -4 -4 -4 -4  1  1  1  1  1  1  1 -4]




Training Epoch 1/1:  14%|█▎        | 40/292 [00:33<04:12,  1.00s/batch, auc=0.6794, loss=1.0059][A[A

Training Epoch 1/1:  14%|█▍        | 41/292 [00:33<03:37,  1.15batch/s, auc=0.6794, loss=1.0059][A[A

Alpha:  [-4 -4  3 -4  3  3  3 -4 -4 -4 -4 -4 -4  3 -4  3  3  3  3  3  3 -4 -4  3
  3 -4  3  3 -4  3  3  3  3  3 -4 -4  3  3  3  3 -4  3 -4  3 -4 -4  3  3
  3  3 -4  3  3 -4  3  3  3 -4  3  3 -4  3  3 -4]




Training Epoch 1/1:  14%|█▍        | 41/292 [00:33<03:37,  1.15batch/s, auc=0.6840, loss=1.0362][A[A

Training Epoch 1/1:  14%|█▍        | 42/292 [00:33<03:13,  1.29batch/s, auc=0.6840, loss=1.0362][A[A

Alpha:  [ 3  3  3 -4 -4  3  3 -4 -4  3 -4  3 -4  3  3  3  3 -4 -4 -4 -4 -4 -4  3
 -4 -4  3  3 -4 -4  3  3 -4  3  3  3 -4  3  3  3  3 -4 -4 -4  3  3  3  3
 -4  3  3  3 -4  3 -4 -4 -4  3 -4  3 -4 -4  3  3]




Training Epoch 1/1:  14%|█▍        | 42/292 [00:34<03:13,  1.29batch/s, auc=0.6898, loss=0.8535][A[A

Training Epoch 1/1:  15%|█▍        | 43/292 [00:34<02:56,  1.41batch/s, auc=0.6898, loss=0.8535][A[A

Alpha:  [-1  3 -1 -1 -1  3 -1 -1  3  3  3 -1 -1  3 -1  3  3  3  3  3 -1 -1 -1 -1
  3  3  3 -1 -1 -1 -1  3  3  3 -1  3  3 -1 -1 -1 -1  3 -1  3  3  3  3 -1
  3  3 -1  3  3 -1  3  3  3  3  3 -1  3 -1  3  3]




Training Epoch 1/1:  15%|█▍        | 43/292 [00:34<02:56,  1.41batch/s, auc=0.6953, loss=1.1420][A[A

Training Epoch 1/1:  15%|█▌        | 44/292 [00:34<02:44,  1.51batch/s, auc=0.6953, loss=1.1420][A[A

Alpha:  [ 2  2  2  2 -3 -3  2  2  2 -3 -3 -3  2 -3  2  2  2  2 -3 -3  2  2  2 -3
 -3 -3  2  2 -3 -3  2 -3 -3  2 -3  2 -3 -3  2  2  2  2 -3  2  2  2 -3 -3
  2 -3  2  2  2 -3 -3  2 -3 -3 -3  2 -3 -3  2  2]




Training Epoch 1/1:  15%|█▌        | 44/292 [00:35<02:44,  1.51batch/s, auc=0.7006, loss=1.0925][A[A

Training Epoch 1/1:  15%|█▌        | 45/292 [00:35<02:36,  1.58batch/s, auc=0.7006, loss=1.0925][A[A

Alpha:  [ 4 -4  4  4  4 -4 -4 -4 -4 -4 -4 -4  4  4  4  4  4 -4  4 -4  4 -4 -4 -4
  4 -4  4 -4  4 -4  4  4  4  4  4  4  4 -4  4  4 -4 -4 -4 -4 -4 -4 -4  4
 -4 -4  4 -4 -4 -4  4  4  4 -4 -4 -4 -4  4 -4 -4]




Training Epoch 1/1:  15%|█▌        | 45/292 [00:35<02:36,  1.58batch/s, auc=0.7042, loss=0.9681][A[A

Training Epoch 1/1:  16%|█▌        | 46/292 [00:35<02:29,  1.64batch/s, auc=0.7042, loss=0.9681][A[A

Training Epoch 1/1:  16%|█▌        | 46/292 [00:37<02:29,  1.64batch/s, auc=0.7007, loss=1.2745][A[A

Training Epoch 1/1:  16%|█▌        | 47/292 [00:37<03:02,  1.34batch/s, auc=0.7007, loss=1.2745][A[A

Alpha:  [-3  4  4 -3 -3  4 -3 -3 -3  4 -3 -3  4  4 -3  4 -3  4  4  4 -3  4 -3  4
 -3 -3 -3 -3 -3  4 -3 -3  4  4  4 -3 -3 -3  4 -3  4  4 -3 -3  4  4  4 -3
  4 -3 -3  4 -3 -3  4  4 -3 -3 -3  4 -3 -3  4 -3]




Training Epoch 1/1:  16%|█▌        | 47/292 [00:37<03:02,  1.34batch/s, auc=0.7041, loss=0.8357][A[A

Training Epoch 1/1:  16%|█▋        | 48/292 [00:37<02:47,  1.45batch/s, auc=0.7041, loss=0.8357][A[A

Alpha:  [-3  3  3 -3 -3 -3  3  3  3  3  3  3  3  3  3  3 -3 -3  3  3 -3  3  3 -3
  3 -3  3 -3  3 -3 -3  3  3  3  3 -3  3  3  3  3  3  3  3  3 -3  3  3  3
 -3 -3 -3  3  3  3 -3 -3 -3  3 -3  3 -3  3  3 -3]




Training Epoch 1/1:  16%|█▋        | 48/292 [00:38<02:47,  1.45batch/s, auc=0.7059, loss=1.0288][A[A

Training Epoch 1/1:  17%|█▋        | 49/292 [00:38<02:37,  1.54batch/s, auc=0.7059, loss=1.0288][A[A

Alpha:  [ 1 -2  1 -2  1 -2 -2 -2 -2 -2  1  1 -2  1  1 -2 -2  1 -2 -2  1 -2 -2  1
  1  1 -2  1 -2  1  1  1  1 -2 -2 -2 -2  1  1 -2 -2 -2 -2  1  1  1  1  1
 -2  1 -2 -2  1 -2 -2  1 -2 -2 -2  1 -2  1  1 -2]




Training Epoch 1/1:  17%|█▋        | 49/292 [00:38<02:37,  1.54batch/s, auc=0.7087, loss=1.0072][A[A

Training Epoch 1/1:  17%|█▋        | 50/292 [00:38<02:30,  1.61batch/s, auc=0.7087, loss=1.0072][A[A

Training Epoch 1/1:  17%|█▋        | 50/292 [00:39<02:30,  1.61batch/s, auc=0.7075, loss=1.2026][A[A

Training Epoch 1/1:  17%|█▋        | 51/292 [00:39<03:01,  1.33batch/s, auc=0.7075, loss=1.2026][A[A

Training Epoch 1/1:  17%|█▋        | 51/292 [00:40<03:01,  1.33batch/s, auc=0.7118, loss=0.8610][A[A

Training Epoch 1/1:  18%|█▊        | 52/292 [00:40<03:23,  1.18batch/s, auc=0.7118, loss=0.8610][A[A

Training Epoch 1/1:  18%|█▊        | 52/292 [00:41<03:23,  1.18batch/s, auc=0.7112, loss=1.1446][A[A

Training Epoch 1/1:  18%|█▊        | 53/292 [00:41<03:38,  1.09batch/s, auc=0.7112, loss=1.1446][A[A

Alpha:  [-1  4 -1 -1  4 -1  4  4 -1  4  4  4  4  4 -1  4  4 -1  4  4  4  4  4  4
 -1  4 -1 -1  4  4  4  4  4  4  4  4 -1 -1  4 -1 -1 -1  4  4  4  4 -1 -1
  4 -1 -1  4 -1 -1  4 -1 -1 -1  4 -1  4 -1  4  4]




Training Epoch 1/1:  18%|█▊        | 53/292 [00:42<03:38,  1.09batch/s, auc=0.7140, loss=1.1546][A[A

Training Epoch 1/1:  18%|█▊        | 54/292 [00:42<03:11,  1.24batch/s, auc=0.7140, loss=1.1546][A[A

Alpha:  [ 3  3 -4  3  3 -4 -4  3 -4 -4  3 -4 -4  3  3  3  3  3 -4 -4  3  3 -4  3
  3  3 -4  3 -4  3  3  3  3  3  3  3 -4  3 -4  3  3 -4 -4 -4  3  3 -4  3
  3  3  3  3 -4 -4  3 -4 -4 -4 -4 -4  3 -4  3  3]


Training Epoch 1/1:  18%|█▊        | 54/292 [00:42<03:09,  1.26batch/s, auc=0.7140, loss=1.1546]
Epochs:   0%|          | 0/1 [00:42<?, ?it/s]
                                     

KeyboardInterrupt: 

In [54]:
num_calls

[149, 177, 200, 233, 266, 292]