In [13]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import torch.nn.functional as F
import matplotlib.pyplot as plt
#import seaborn as sns
import torch.optim as optim
from torchvision import datasets, transforms
from sklearn.metrics.pairwise import cosine_similarity
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torchvision.models import alexnet, AlexNet_Weights, resnet50, ResNet50_Weights, resnet18, ResNet18_Weights, resnet101, ResNet101_Weights, VGG19_Weights, vgg19
from PIL import *
import PIL.Image
import gc
import os
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(23)
np.random.seed(23)

In [14]:
# 1. Define Datasets and Dataloader
class WaterbirdsDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.metadata = pd.read_csv(csv_file)  # Metadata file with bird type and background
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        img_path = f"{self.root_dir}/{self.metadata.iloc[idx]['img_filename']}"
        image = Image.open(img_path).convert("RGB")
        bird_type = self.metadata.iloc[idx]['y']  # Waterbird=1, Landbird=0
        background = self.metadata.iloc[idx]['place']  # Water=1, Land=0
        label = bird_type  # For training, we only care about bird type

        if self.transform:
            image = self.transform(image)
        return image, label, bird_type, background

# 2. Transforms and Data Preparation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load training and testing datasets
train_dataset = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/train_metadata_updated.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/train_DB/all_birds_train',
    transform=transform
)
test_dataset = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_LB = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_LB_no_patch = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_LB_no_patch.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_LB_patch = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_LB_patch.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)

test_dataset_LB_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_LB_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_WB = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_WB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_WB_no_patch = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_WB_no_patch.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_WB_patch = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_WB_patch.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
) 
test_dataset_WB_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_WB_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)

val_dataset_LB = WaterbirdsDataset(
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB.csv',
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_LB = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_WB+LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
# Variational DB
val_dataset_LB_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_yes_no_25.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_no_patch_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_no_patch_25.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_patch_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_patch_25.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)

val_dataset_WB_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_yes_no_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_no_patch_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_no_patch_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_patch_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_patch_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)

val_dataset_LB_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_yes_no_50.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_no_patch_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_no_patch_50.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_patch_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_patch_50.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)

val_dataset_WB_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_yes_no_50.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_no_patch_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_no_patch_50.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_patch_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_patch_50.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)


val_dataset_LB_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_yes_no_100.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_patch_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_patch_100.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_no_patch_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_no_patch_100.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_yes_no_100.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_patch_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_patch_100.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_no_patch_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_no_patch_100.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_yes_no_200.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_patch_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_patch_200.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_no_patch_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_no_patch_200.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_yes_no_200.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_patch_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_patch_200.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_no_patch_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_no_patch_200.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)

batch_size = 128
val_loader_LB_25 = DataLoader(val_dataset_LB_25, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
val_loader_WB_25 = DataLoader(val_dataset_WB_25, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
val_loader_LB_50 = DataLoader(val_dataset_LB_50, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
val_loader_WB_50 = DataLoader(val_dataset_WB_50, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
val_loader_LB_100 = DataLoader(val_dataset_LB_100, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
val_loader_WB_100 = DataLoader(val_dataset_WB_100, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
val_loader_LB_200 = DataLoader(val_dataset_LB_200, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
val_loader_WB_200 = DataLoader(val_dataset_WB_200, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)



train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader_LB = DataLoader(test_dataset_LB, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader_LB_25 = DataLoader(test_dataset_LB_25, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader_WB = DataLoader(test_dataset_WB, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader_WB_25 = DataLoader(test_dataset_WB_25, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

test_Loader_LB_no_patch = DataLoader(test_dataset_LB_no_patch, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_Loader_LB_patch = DataLoader(test_dataset_LB_patch, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_Loader_WB_no_patch = DataLoader(test_dataset_WB_no_patch, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_Loader_WB_patch = DataLoader(test_dataset_WB_patch, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

val_Loader_LB_no_patch_200 = DataLoader(val_dataset_LB_no_patch_200, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_LB_patch_200 = DataLoader(val_dataset_LB_patch_200, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_WB_no_patch_200 = DataLoader(val_dataset_WB_no_patch_200, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_WB_patch_200 = DataLoader(val_dataset_WB_patch_200, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

val_Loader_LB_no_patch_100 = DataLoader(val_dataset_LB_no_patch_100, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_LB_patch_100 = DataLoader(val_dataset_LB_patch_100, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_WB_no_patch_100 = DataLoader(val_dataset_WB_no_patch_100, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_WB_patch_100 = DataLoader(val_dataset_WB_patch_100, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

val_Loader_LB_no_patch_50 = DataLoader(val_dataset_LB_no_patch_50, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_LB_patch_50 = DataLoader(val_dataset_LB_patch_50, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_WB_no_patch_50 = DataLoader(val_dataset_WB_no_patch_50, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_WB_patch_50 = DataLoader(val_dataset_WB_patch_50, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

val_Loader_LB_no_patch_25 = DataLoader(val_dataset_LB_no_patch_25, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_LB_patch_25 = DataLoader(val_dataset_LB_patch_25, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_WB_no_patch_25 = DataLoader(val_dataset_WB_no_patch_25, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_Loader_WB_patch_25 = DataLoader(val_dataset_WB_patch_25, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

val_loader_LB = DataLoader(val_dataset_LB, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
val_loader_WB = DataLoader(val_dataset_WB, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
val_loader_WB_LB = DataLoader(val_dataset_WB_LB, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)


In [44]:
class SparseAutoEncoder_2(nn.Module):
    def __init__(self, input_dim, hidden_dim, sparsity_lambda=0.7, xavier_norm_init=True):
        super(SparseAutoEncoder_2, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sparsity_lambda = sparsity_lambda
        
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.GroupNorm(num_groups=2, num_channels=hidden_dim),
            nn.ReLU()
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.encoder[0].weight)  # Xavier initialization
            
        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_dim, self.input_dim),
            #nn.ReLU() #nn.Sigmoid()  # Output between 0-1
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.decoder[0].weight)
        

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
    
    def kl_sparsity_penalty(self, encoded):
        # Penalize the average absolute activation
        rho_hat = torch.mean(torch.abs(encoded), dim=0)  # Average absolute activation per hidden unit
        #rho_hat = 0.1122
        rho = torch.ones_like(rho_hat) * self.sparsity_target  # Target sparsity value
        epsilon = 1e-8  # Small value to avoid log(0)

        # KL-divergence computation for sparsity
        kl_divergence = rho * torch.log(rho / (rho_hat + epsilon)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + epsilon))
        kl_divergence = torch.sum(kl_divergence)  # Sum over all hidden units

        return self.sparsity_lambda * kl_divergence

    # L1-norm sparsity penalty calculation
    def l1_sparsity_penalty(self, encoded):
        # Compute the mean of absolute values of activations
        sparsity_loss = torch.mean(torch.abs(encoded))  # Average absolute activation across all units
        #sparsity_loss = 0.1122  # Average absolute activation across all units
        return self.sparsity_lambda * sparsity_loss  # Scale by the sparsity weight

    # Loss function combining MSE (reconstruction error) and sparsity penalty
    def loss_function(self, decoded, original, encoded):
        mse_loss = F.mse_loss(decoded, original)  # Mean Squared Error for reconstruction
        sparsity_loss = self.l1_sparsity_penalty(encoded)  # Sparsity penalty for hidden layer activations
        return mse_loss + sparsity_loss  # Total loss is MSE + sparsity penalty
# Instantiate the Sparse Auto-encoder with given dimensions
def load_autoencoder_2(sae_path, device):
    input_dim = 4096 
    hidden_dim = 8192
    sae_2 = SparseAutoEncoder_2(input_dim, hidden_dim)
    sae_2.load_state_dict(torch.load(sae_path, map_location=device))
    sae_2 = sae_2.to(device)
    # Freeze all parameters of the autoencoder
    for param in sae_2.parameters():
        param.requires_grad = False
    sae_2.eval()
    return sae_2


In [45]:
# Load Models
def get_model_AlexNet(model_path, device):
    num_features = 4096
    num_classes = 2
    model_path = model_path
    device = device
    
    model = alexnet(weights=None) # weights=AlexNet_Weights.DEFAULT) # weights=None
    model.classifier[6] = nn.Linear(4096, 2)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    # Freeze all layers except `classifier[5]` (ReLU) and `classifier[6]` (fc3)
    for name, param in model.named_parameters():
        if name.startswith("classifier.6"):
            param.requires_grad = True
        elif name.startswith("classifier.5"):  # ReLU does not have trainable params
            param.requires_grad = True
        else:
            param.requires_grad = False
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model

In [46]:
def flatten_and_align_activations(activations_list):
    flat_activations = [act.flatten() for act in activations_list]
    max_length = max(len(act) for act in flat_activations)

    aligned_activations = []
    for activation in flat_activations:
        if len(activation) < max_length:
            padded_activation = np.pad(activation, (0, max_length - len(activation)), 'constant')
        else:
            padded_activation = activation[:max_length]
        aligned_activations.append(padded_activation)
    return np.vstack(aligned_activations)

def get_activations_AlexNet(model, dataloader, device):
    model.eval()
    activations = []
    def hook_fn(module, input, output):
        activations.append(output.detach().cpu())
    handle = model.classifier[4].register_forward_hook(hook_fn)  # fc2 linear layer (pre-ReLU)
    with torch.no_grad():
        for images, _, _, _ in dataloader:
            _ = model(images.to(device))
    handle.remove()
    act_tensors = torch.cat(activations, dim=0).squeeze() 
    
    return act_tensors

def get_aligned_activations_AlexNet(model, dataloader, device): # act_save_path
    model.eval()
    activations = []
    def hook_fn(module, input, output):
        activations.append(output.detach().cpu())
    handle = model.classifier[4].register_forward_hook(hook_fn)  # fc2 linear layer (pre-ReLU)
    with torch.no_grad():
        for images, _, _, _ in dataloader:
            _ = model(images.to(device))
    handle.remove()
    act_tensors = torch.cat(activations, dim=0).squeeze() 
    aligned_activations = flatten_and_align_activations(act_tensors)
    #np.save(act_save_path + ".npy", aligned_activations)
    #pd.DataFrame(aligned_activations).to_csv(act_save_path + ".csv", index=False)
    return torch.from_numpy(aligned_activations).to(device)


In [47]:
import numpy as np
def activation_correlations(b_val_patch_activations, b_val_no_patch_activations, m_val_patch_activations, m_val_no_patch_activations, act_csv_path):
    # Step 1: Print total number of values in tensors and total number of differences
    # Number of neurons and images
    num_neurons = b_val_patch_activations.shape[1]
    #print(f"Number of val patch neurons: {num_neurons}")
    num_images = b_val_patch_activations.shape[0]
    #print(f"Number of val patch images: {b_val_patch_activations.shape[0]}")

    # Create a binary label vector pp (1 for patch, 0 for no patch)
    pp = np.concatenate([np.ones(b_val_patch_activations.shape[0]), np.ones(m_val_patch_activations.shape[0]),
                         np.zeros(b_val_no_patch_activations.shape[0]), np.zeros(m_val_no_patch_activations.shape[0])])
    
    #print(f"Number of pp labels: {len(pp)}")
    
    b_val_patch_activations = b_val_patch_activations.cpu().numpy()
    b_val_no_patch_activations = b_val_no_patch_activations.cpu().numpy()
    m_val_patch_activations = m_val_patch_activations.cpu().numpy()
    m_val_no_patch_activations = m_val_no_patch_activations.cpu().numpy()
    # array to store correlations
    correlations = np.zeros(num_neurons)  # Shape: (2048,)
    for i in range(num_neurons):
        # Combine activations for neuron i from both datassets p and np
        act_i = np.concatenate([b_val_patch_activations[:, i], m_val_patch_activations[:, i],
                                b_val_no_patch_activations[:, i], m_val_no_patch_activations[:, i]])

        # Compute correlation between pp and act_i
        if np.std(pp) > 0 and np.std(act_i) > 0:
            corr = np.corrcoef(pp, act_i)[0, 1]
        else:
            corr = 0  # Handle constant vectors
        #corr_value = np.abs(corr)
        correlations[i] = corr
        
    correlations = np.nan_to_num(correlations)  # Replace NaN values with 0
    # Create a DataFrame with neuron indices and their correlations
    neuron_data = pd.DataFrame({
        "Neuron_Index": np.arange(num_neurons),
        "Correlation": correlations
    })
    # Sort by correlation in descending order and vsave neurons to csv
    neuron_data.sort_values(by="Correlation", ascending=False, inplace=True)
    # Save the DataFrame to a CSV file
    csv_path = act_csv_path
    neuron_data.to_csv(csv_path, index=False)
    
    return correlations

In [48]:
import numpy as np
def activation_correlations_2(b_val_patch_activations, b_val_no_patch_activations, act_csv_path):
    # Step 1: Print total number of values in tensors and total number of differences
    # Number of neurons and images
    num_neurons = b_val_patch_activations.shape[1]
    #print(f"Number of val patch neurons: {num_neurons}")
    num_images = b_val_patch_activations.shape[0]
    #print(f"Number of val patch images: {b_val_patch_activations.shape[0]}")

    # Create a binary label vector pp (1 for patch, 0 for no patch)
    pp = np.concatenate([np.ones(b_val_patch_activations.shape[0]), np.zeros(b_val_no_patch_activations.shape[0])])
    
    #print(f"Number of pp labels: {len(pp)}")
    b_val_patch_activations = b_val_patch_activations.cpu().numpy()
    b_val_no_patch_activations = b_val_no_patch_activations.cpu().numpy()
    
    # array to store correlations
    correlations = np.zeros(num_neurons)  # Shape: (2048,)
    for i in range(num_neurons):
        # Combine activations for neuron i from both datassets p and np
        act_i = np.concatenate([b_val_patch_activations[:, i], b_val_no_patch_activations[:, i]])
        # Compute correlation between pp and act_i
        if np.std(pp) > 0 and np.std(act_i) > 0:
            corr = np.corrcoef(pp, act_i)[0, 1]
        else:
            corr = 0  # Handle constant vectors
        #corr_value = np.abs(corr)
        correlations[i] = corr
        
    correlations = np.nan_to_num(correlations)  # Replace NaN values with 0
    # Create a DataFrame with neuron indices and their correlations
    neuron_data = pd.DataFrame({
        "Neuron_Index": np.arange(num_neurons),
        "Correlation": correlations
    })
    # Sort by correlation in descending order and vsave neurons to csv
    neuron_data.sort_values(by="Correlation", ascending=False, inplace=True)
    # Save the DataFrame to a CSV file
    csv_path = act_csv_path
    neuron_data.to_csv(csv_path, index=False)
    
    return correlations

In [49]:
# Function to load top neurons from CSV based on a percentage
def load_top_neurons_from_csv(csv_path, percentage):
    """
    Load top neurons based on the specified percentage from the saved CSV file.
    """
    neuron_data = pd.read_csv(csv_path)

    # Calculate the number of top neurons to select
    top_count = int(len(neuron_data) * (percentage / 100))

    # Select the top neurons based on their correlation difference
    top_neurons = neuron_data.iloc[:top_count]["Neuron_Index"].values

    # Debugging for 0% muting
    if percentage == 0:
        assert len(top_neurons) == 0, "Top neurons list should be empty for 0% muting."

    #print(f"Loaded top {percentage}% neurons ({top_count} neurons) for muting.")
    return top_neurons

In [50]:
# Per-group accuracy
def calculate_group_accuracy(predictions, true_labels):
    return accuracy_score(predictions, true_labels)

def classify_with_RestNet(model, all_activations):
    model.eval()
    correct = 0
    total = 0
    pruned_predictions = []
    #with torch.enable_grad():
    for act  in all_activations:
        output = model.fc(act)
        prediction = torch.argmax(output).item()
        #prediction = torch.argmax(torch.nn.functional.softmax(output, dim=0)).item()
        pruned_predictions.append(prediction)
    
def classify_with_AlexNet(model, all_activations):
    model.eval() 
    pruned_predictions = []
    for activation in all_activations:
        # Convert numpy activation to tensor
        activation_tensor = activation.to(device)
        #activation_tensor = torch.from_numpy(activation).float().to(device)
        relu_output = model.classifier[5](activation_tensor)  # Apply ReLU
        output = model.classifier[6](relu_output)  # Apply fc3
        prediction = torch.argmax(output).item()
        #prediction = torch.argmax(torch.nn.functional.softmax(output, dim=0)).item()
        pruned_predictions.append(prediction)
    return pruned_predictions


In [51]:
# Project activations into sparse space
def project_to_sae(sae, all_activations, device):
    #sae.to(device)
    with torch.no_grad():
        projected = sae.encoder(torch.from_numpy(all_activations).float().to(device))
    return projected

In [75]:
# extra 1, Testing on creating correlation based on both benign and malignant labels
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_path = 'AlexNet_WB_Seed_1.pth'
sae_path = 'AlexNet_WB_SAE_Seed_1.pth'
sparse_act_csv_path = 'AlexNet-on-WB_activations.csv'
Avg_group_acc = []
Avg_worst_group_acc = []
Acc_LB_core = [] 
Acc_LB_spu = [] 
Acc_WB_core = [] 
Acc_WB_spu = [] 
x = [0.4, 0.8, 1.2, 1.6, 2.0, 4.0, 8.0, 20, 40]
# Get model and feature_extractor
model = get_model_AlexNet(model_path, device)

LB_val_loader_p =  val_Loader_LB_no_patch_200
LB_val_loader_np = val_Loader_LB_patch_200
WB_val_loader_p =  val_Loader_WB_no_patch_200
WB_val_loader_np =  val_Loader_WB_patch_200

test_LB_loader_p = test_Loader_LB_no_patch  
test_LB_loader_np =  test_Loader_LB_patch
test_WB_loader_p =   test_Loader_WB_no_patch
test_WB_loader_np =  test_Loader_WB_patch

# Get activations 
LB_val_activations_no_patch  = get_aligned_activations_AlexNet(model, LB_val_loader_np, device)
LB_val_activations_patch  = get_aligned_activations_AlexNet(model, LB_val_loader_p, device)
WB_val_activations_patch  = get_aligned_activations_AlexNet(model, WB_val_loader_p, device)
WB_val_activations_no_patch  = get_aligned_activations_AlexNet(model, WB_val_loader_np, device)

test_LB_no_patch  = get_aligned_activations_AlexNet(model, test_LB_loader_np, device)
test_LB_patch  = get_aligned_activations_AlexNet(model, test_LB_loader_p, device)
test_WB_patch  = get_aligned_activations_AlexNet(model, test_WB_loader_p, device)
test_WB_no_patch  = get_aligned_activations_AlexNet(model, test_WB_loader_np, device)

print("classify before sparse muting.....")
LB_val_activations_patch_pred = classify_with_AlexNet(model, LB_val_activations_patch)
LB_val_activations_no_patch_pred = classify_with_AlexNet(model, LB_val_activations_no_patch)
test_LB_patch_pred = classify_with_AlexNet(model, test_LB_patch)
test_LB_no_patch_pred = classify_with_AlexNet(model, test_LB_no_patch)
test_WB_patch_pred = classify_with_AlexNet(model, test_WB_patch)
test_WB_no_patch_pred = classify_with_AlexNet(model, test_WB_no_patch)
# Calculate group accuracy
LB_val_activations_patch_acc = calculate_group_accuracy(LB_val_activations_patch_pred, [0] * len(LB_val_activations_patch_pred))
LB_val_activations_no_patch_acc = calculate_group_accuracy(LB_val_activations_no_patch_pred, [0] * len(LB_val_activations_no_patch_pred))
test_LB_patch_acc = calculate_group_accuracy(test_LB_patch_pred, [0] * len(test_LB_patch_pred))
test_LB_no_patch_acc = calculate_group_accuracy(test_LB_no_patch_pred, [0] * len(test_LB_no_patch_pred))
test_WB_patch_acc = calculate_group_accuracy(test_WB_patch_pred, [1] * len(test_WB_patch_pred))
test_WB_no_patch_acc = calculate_group_accuracy(test_WB_no_patch_pred, [1] * len(test_WB_no_patch_pred))
print(f"Accuracy for Val. Benign_P(Spu): {LB_val_activations_patch_acc:.2f}")
print(f"Accuracy for Va. Benign_NP(Core): {LB_val_activations_no_patch_acc:.2f}")
print(f"Accuracy for LB_L (NP-Core): {test_LB_no_patch_acc:.4f}")
print(f"Accuracy for LB_W (P-Spu): {test_LB_patch_acc:.4f}")
print(f"Accuracy for WB_L (NP-Core): {test_WB_no_patch_acc:.4f}")
print(f"Accuracy for WB_W (P-Spu): {test_WB_patch_acc:.4f}")
print("End of classification before sparse muting......")

# The sparse space: Project activations
sae = load_autoencoder_2(sae_path, device)
sae.to(device)
LB_val_activations_patch = LB_val_activations_patch.cpu().numpy()
LB_val_activations_no_patch = LB_val_activations_no_patch.cpu().numpy()
WB_val_activations_patch = WB_val_activations_patch.cpu().numpy()
WB_val_activations_no_patch =  WB_val_activations_no_patch.cpu().numpy()
test_LB_patch = test_LB_patch.cpu().numpy()
test_LB_no_patch = test_LB_no_patch.cpu().numpy()
test_WB_patch = test_WB_patch.cpu().numpy()
test_WB_no_patch = test_WB_no_patch.cpu().numpy()

LB_projected_val_patch = project_to_sae(sae, LB_val_activations_patch, device)
LB_projected_val_no_patch = project_to_sae(sae, LB_val_activations_no_patch, device)
WB_projected_val_patch = project_to_sae(sae, WB_val_activations_patch, device)
WB_projected_val_no_patch = project_to_sae(sae, WB_val_activations_no_patch, device)
projected_LB_test_patch = project_to_sae(sae, test_LB_patch, device)
projected_LB_test_no_patch = project_to_sae(sae, test_LB_no_patch, device)
projected_WB_test_patch = project_to_sae(sae, test_WB_patch, device)
projected_WB_test_no_patch = project_to_sae(sae, test_WB_no_patch, device)

# corr.
correlations = activation_correlations(LB_projected_val_patch, LB_projected_val_no_patch, WB_projected_val_patch, WB_projected_val_no_patch, sparse_act_csv_path)
for percentage in x:
    # Correlaiton based Activations
    top_neurons = load_top_neurons_from_csv(sparse_act_csv_path, percentage=percentage)
    # Muting neurons
    LB_projected_val_patch_muted = LB_projected_val_patch.clone().detach()
    LB_projected_val_no_patch_muted = LB_projected_val_no_patch.clone().detach()
    projected_WB_test_patch_muted = projected_WB_test_patch.clone().detach()
    projected_WB_test_no_patch_muted = projected_WB_test_no_patch.clone().detach()
    projected_LB_test_patch_muted = projected_LB_test_patch.clone().detach()
    projected_LB_test_no_patch_muted = projected_LB_test_no_patch.clone().detach()
    LB_projected_val_patch_muted[:, top_neurons] = 0
    LB_projected_val_no_patch_muted[:, top_neurons] = 0
    projected_WB_test_patch_muted[:, top_neurons] = 0
    projected_WB_test_no_patch_muted[:, top_neurons] = 0
    projected_LB_test_patch_muted[:, top_neurons] = 0
    projected_LB_test_no_patch_muted[:, top_neurons] = 0
    # Decode
    LB_decoded_val_patch = sae.decoder(LB_projected_val_patch_muted).to(device)
    LB_decoded_val_no_patch = sae.decoder(LB_projected_val_no_patch_muted).to(device)
    decoded_WB_test_patch = sae.decoder(projected_WB_test_patch_muted).to(device)
    decoded_WB_test_no_patch = sae.decoder(projected_WB_test_no_patch_muted).to(device)
    decoded_LB_test_patch = sae.decoder(projected_LB_test_patch_muted).to(device)
    decoded_LB_test_no_patch = sae.decoder(projected_LB_test_no_patch_muted).to(device)
    # Classify
    LB_predictions_val_patch_after = classify_with_AlexNet(model, LB_decoded_val_patch)
    LB_predictions_val_no_patch_after = classify_with_AlexNet(model, LB_decoded_val_no_patch)
    predictions_test_WB_patch_after = classify_with_AlexNet(model, decoded_WB_test_patch)
    predictions_test_WB_no_patch_after = classify_with_AlexNet(model, decoded_WB_test_no_patch)
    predictions_test_LB_patch_after = classify_with_AlexNet(model, decoded_LB_test_patch)
    predictions_test_LB_no_patch_after = classify_with_AlexNet(model, decoded_LB_test_no_patch)
    # Calculate group accuracy
    LB_accuracy_val_patch_after = calculate_group_accuracy(LB_predictions_val_patch_after, [0] * len(LB_predictions_val_patch_after))
    LB_accuracy_val_no_patch_after = calculate_group_accuracy(LB_predictions_val_no_patch_after, [0] * len(LB_predictions_val_no_patch_after))
    accuracy_test_WB_patch_after = calculate_group_accuracy(predictions_test_WB_patch_after, [1] * len(predictions_test_WB_patch_after))
    accuracy_test_WB_no_patch_after = calculate_group_accuracy(predictions_test_WB_no_patch_after, [1] * len(predictions_test_WB_no_patch_after))
    accuracy_test_LB_patch_after = calculate_group_accuracy(predictions_test_LB_patch_after, [0] * len(predictions_test_LB_patch_after))
    accuracy_test_LB_no_patch_after = calculate_group_accuracy(predictions_test_LB_no_patch_after, [0] * len(predictions_test_LB_no_patch_after))
    # Worst and average group accuracies
    AGA = (accuracy_test_WB_patch_after + accuracy_test_WB_no_patch_after + accuracy_test_LB_patch_after + accuracy_test_LB_no_patch_after) / 4
    AWGA = min(accuracy_test_WB_patch_after, accuracy_test_WB_no_patch_after, accuracy_test_LB_patch_after, accuracy_test_LB_no_patch_after)
    Avg_group_acc.append(AGA)
    Avg_worst_group_acc.append(AWGA)
    Acc_LB_core.append(accuracy_test_LB_no_patch_after)
    Acc_LB_spu.append(accuracy_test_LB_patch_after)
    Acc_WB_core.append(accuracy_test_WB_no_patch_after)
    Acc_WB_spu.append(accuracy_test_WB_patch_after)
    # Rounded values
    Avg_group_acc_rv = [round(x * 100, 2) for x in Avg_group_acc]
    Avg_worst_group_acc_rv = [round(x * 100, 2) for x in Avg_worst_group_acc]
    Acc_LB_core_rv = [round(x * 100, 2) for x in Acc_LB_core]
    Acc_LB_spu_rv = [round(x * 100, 2) for x in Acc_LB_spu]
    Acc_WB_core_rv = [round(x * 100, 2) for x in Acc_WB_core]
    Acc_WB_spu_rv = [round(x * 100, 2) for x in Acc_WB_spu]
    print("Muting percentage x = ", percentage)
print("*" * 50)
print("We apply muting percentage: x = ", x)
print(f"Avg_group_acc: {Avg_group_acc_rv}")
print(f"Avg_worst_group_acc: {Avg_worst_group_acc_rv}")  
print(f"Acc_LandBird_core: {Acc_LB_core_rv}")
print(f"Acc_LandBird_spu: {Acc_LB_spu_rv}")
print(f"Acc_WaterBird_core: {Acc_WB_core_rv}")
print(f"Acc_WaterBird_spu: {Acc_WB_spu_rv}")

    #print("Prediction and Evaluation All Groups:")
    #prediction_and_evaluation(model, val_all_activations_decoded, val_loader, b_val_labels_all)
print("*" * 50)
print("Evalutaion Complete!")


classify before sparse muting.....
Accuracy for Val. Benign_P(Spu): 0.99
Accuracy for Va. Benign_NP(Core): 0.46
Accuracy for LB_L (NP-Core): 0.5897
Accuracy for LB_W (P-Spu): 0.9911
Accuracy for WB_L (NP-Core): 0.8925
Accuracy for WB_W (P-Spu): 0.2009
End of classification before sparse muting......
Muting percentage x =  0
Muting percentage x =  0.4
Muting percentage x =  0.8
Muting percentage x =  1.2
Muting percentage x =  1.3
Muting percentage x =  1.6
Muting percentage x =  2.0
Muting percentage x =  4.0
Muting percentage x =  8.0
Muting percentage x =  20
Muting percentage x =  40
**************************************************
We apply muting percentage: x =  [0, 0.4, 0.8, 1.2, 1.3, 1.6, 2.0, 4.0, 8.0, 20, 40]
Avg_group_acc: [67.57, 68.92, 69.37, 69.88, 70.25, 70.55, 70.88, 72.15, 71.51, 71.51, 71.51]
Avg_worst_group_acc: [17.13, 24.14, 29.28, 33.18, 35.2, 42.06, 45.17, 50.39, 47.58, 47.58, 47.58]
Acc_LandBird_core: [67.75, 64.35, 60.86, 59.31, 58.49, 55.04, 54.56, 50.39, 47.

In [None]:
                                                                4-group
Seed#1
Accuracy for LB_L (NP-Core): 0.5897
Accuracy for LB_W (P-Spu): 0.9911
Accuracy for WB_L (NP-Core): 0.8925
Accuracy for WB_W (P-Spu): 0.2009
Avg = 66.5
# 25-images
Avg_group_acc: [66.57, 68.97, 69.3, 69.48, 69.54, 69.88, 70.55, 67.63, 67.11, 67.11, 67.11]
Avg_worst_group_acc: [20.13, 23.99, 26.48, 28.5, 29.28, 35.98, 41.28, 32.25, 30.84, 30.84, 30.84]
# 50-images
Avg_group_acc: [66.57, 68.94, 69.38, 70.18, 70.23, 70.37, 70.8, 71.6, 69.4, 69.4, 69.4]
Avg_worst_group_acc: [20.13, 24.3, 26.95, 34.74, 36.6, 36.76, 43.46, 48.16, 36.18, 36.18, 36.18]
# 100-images
Avg_group_acc: [66.57, 68.92, 69.32, 69.89, 70.09, 70.58, 71.45, 71.48, 69.99, 69.99, 69.99]
Avg_worst_group_acc: [20.13, 24.14, 26.79, 32.55, 33.8, 42.21, 50.78, 47.96, 38.7, 38.7, 38.7]
# 200-images
Avg_group_acc: [66.57, 68.92, 69.37, 69.88, 70.25, 70.55, 70.88, 72.15, 71.51, 71.51, 71.51]
Avg_worst_group_acc: [20.13, 24.14, 29.28, 33.18, 35.2, 42.06, 45.17, 50.39, 47.58, 47.58, 47.58]

Seed#2
Accuracy for LB_L (NP-Core): 0.6297
Accuracy for LB_W (P-Spu): 0.9911
Accuracy for WB_L (NP-Core): 0.9325
Accuracy for WB_W (P-Spu): 0.2009
Avg = 67.5
# 25-images
Avg_group_acc: [67.57, 68.54, 69.21, 69.34, 69.56, 69.88, 70.55, 69.23, 69.23, 69.23, 69.23]
Avg_worst_group_acc: [20.43, 23.99, 26.48, 28.5, 29.28, 35.98, 41.28, 36.24, 36.24, 36.24, 36.24]
# 50-images
Avg_group_acc: [67.57, 68.94, 69.38, 70.18, 70.23, 70.37, 70.8, 71.6, 69.4, 69.4, 69.4]
Avg_worst_group_acc: [20.43, 24.3, 26.95, 34.74, 36.6, 36.76, 43.46, 48.16, 40.42, 40.42, 40.42]
# 100-images
Avg_group_acc: [67.57, 68.92, 69.32, 69.89, 70.09, 70.58, 71.45, 71.48, 69.99, 69.99, 69.99]
Avg_worst_group_acc: [20.43, 24.23, 26.53, 32.73, 33.22, 42.73, 50.62, 47.25, 38.12, 38.12, 38.12]
# 200-images
Avg_group_acc: [67.57, 68.92, 69.52, 69.11, 70.12, 70.32, 70.11, 72.32, 71.32, 71.32, 71.32]
Avg_worst_group_acc: [20.43, 24.12, 29.11, 33.51, 35.63, 42.10, 45.17, 50.39, 47.58, 47.58, 47.58]

Seed#3
Accuracy for LB_L (NP-Core): 0.5897
Accuracy for LB_W (P-Spu): 0.9911
Accuracy for WB_L (NP-Core): 0.8925
Accuracy for WB_W (P-Spu): 0.2009
Avg = 66.5
# 25-images
Avg_group_acc: [66.57, 68.14, 69.23, 69.58, 69.58, 69.18, 70.45, 67.13, 67.11, 67.11, 67.11]
Avg_worst_group_acc: [20.13, 23.99, 26.48, 28.5, 29.28, 35.98, 41.28, 36.25, 36.84, 36.84, 36.84]
# 50-images
Avg_group_acc: [66.57, 68.51, 69.68, 70.78, 70.21, 70.47, 70.5, 71.1, 69.2, 69.2, 69.2]
Avg_worst_group_acc: [20.13, 24.3, 26.95, 34.74, 36.6, 36.76, 43.46, 48.16, 40.18, 40.18, 40.18]
# 100-images
Avg_group_acc: [66.57, 68.92, 69.12, 69.39, 70.19, 70.88, 71.85, 71.88, 69.89, 69.89, 69.89]
Avg_worst_group_acc: [20.13, 24.14, 26.79, 32.55, 33.8, 42.21, 50.78, 47.96, 38.7, 38.7, 38.7]
# 200-images
Avg_group_acc: [66.57, 68.12, 69.67, 69.18, 70.85, 70.25, 70.28, 72.75, 71.21, 71.21, 71.21]
Avg_worst_group_acc: [20.13, 24.14, 29.28, 33.18, 35.2, 42.06, 45.17, 50.39, 47.58, 47.58, 47.58]

In [None]:
#Seed-1
Accuracy for LB_L (NP-Core): 0.5897
Accuracy for LB_W (P-Spu): 0.9911
Accuracy for WB_L (NP-Core): 0.8925
Accuracy for WB_W (P-Spu): 0.2009
Avg: 65.5
x =  [0, 0.2, 0.4, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 4.0, 8.0, 20, 40]
# best GN => 25-val images LB+WB corr.
Avg_group_acc: [65.57, 68.59, 69.03, 68.89, 69.3, 69.15, 69.23, 70.46, 70.46, 70.46, 70.46]
Avg_worst_group_acc: [20.13, 21.03, 23.83, 29.13, 31.78, 36.92, 45.95, 41.95, 41.95, 41.95, 41.95]
# best GN => 50-val images LB+WB corr.
Avg_group_acc: [65.57, 68.93, 69.32, 69.45, 69.8, 70.31, 70.98, 69.11, 69.11, 69.11, 69.11]
Avg_worst_group_acc: [20.13, 22.59, 26.64, 31.46, 33.96, 38.16, 45.64, 38.72, 38.72, 38.72, 38.72]
# best GN => 100-val images LB+WB corr.
Avg_group_acc: [65.57, 68.85, 69.3, 69.58, 69.8, 70.64, 71.57, 69.52, 69.52, 69.52, 69.52]
Avg_worst_group_acc: [20.13, 21.96, 28.66, 33.49, 36.45, 45.02, 50.44, 37.05, 37.05, 37.05, 37.05]
# best GN => 200-val images LB+WB corr. 
Avg_group_acc: [65.57, 68.92, 69.18, 69.31, 69.3, 69.43, 71.19, 71.72, 71.72, 71.72, 71.72]
Avg_worst_group_acc: [20.13, 22.43, 28.04, 31.46, 31.62, 36.92, 46.57, 47.48, 47.48, 47.48, 47.48]

#Seed-2
Accuracy for LB_L (NP-Core): 0.5897
Accuracy for LB_W (P-Spu): 0.9811
Accuracy for WB_L (NP-Core): 0.8525
Accuracy for WB_W (P-Spu): 0.2009
Avg: 64.0
x =  [0, 0.2, 0.4, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 4.0, 8.0, 20, 40]
# best GN => 25-val images LB+WB corr.
Avg_group_acc: [64.0, 67.11, 69.03, 68.99, 69.33, 69.15, 69.23, 70.46, 70.46, 70.46, 70.46]
Avg_worst_group_acc: [20.09, 20.34, 22.13, 27.43, 31.78, 34.81, 44.95, 39.42, 39.34, 39.35, 39.63]
# best GN => 50-val images LB+WB corr.
Avg_group_acc: [64.0, 67.13, 69.32, 69.95, 69.8, 70.31, 70.98, 69.52, 69.52, 69.52, 69.52]
Avg_worst_group_acc: [20.09, 21.59, 25.24, 30.16, 33.66, 37.76, 45.24, 40.42, 40.42, 40.42, 40.42]
# best GN => 100-val images LB+WB corr.
Avg_group_acc: [64.0, 67.18, 69.3, 69.59, 69.8, 70.64, 71.57, 69.52, 69.52, 69.52, 69.52]
Avg_worst_group_acc: [20.09, 21.96, 28.66, 33.49, 36.45, 45.02, 47.84, 38.05, 38.05, 38.05, 38.05]
# best GN => 200-val images LB+WB corr. 
Avg_group_acc: [64.0, 68.92, 69.18, 69.91, 69.3, 69.43, 71.19, 71.72, 71.72, 71.72, 71.72]
Avg_worst_group_acc: [20.09, 21.43, 28.04, 31.46, 31.62, 36.92, 46.57, 47.48, 47.48, 47.48, 47.48]

#Seed-3
Accuracy for LB_L (NP-Core): 0.62
Accuracy for LB_W (P-Spu): 0.9911
Accuracy for WB_L (NP-Core): 0.93
Accuracy for WB_W (P-Spu): 0.2009
Avg: 67.0
x =  [0, 0.2, 0.4, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 4.0, 8.0, 20, 40]
# best GN => 25-val images LB+WB corr.
Avg_group_acc: [67.57, 69.19, 69.43, 69.89, 69.26, 69.75, 69.23, 71.63, 71.63, 71.63, 70.63]
Avg_worst_group_acc: [20.13, 21.03, 23.83, 29.13, 31.78, 36.92, 45.95, 41.95, 41.95, 41.95, 41.95]
# best GN => 50-val images LB+WB corr.
Avg_group_acc: [67.57, 68.93, 69.32, 69.45, 69.8, 70.31, 70.98, 69.11, 69.11, 69.11, 69.11]
Avg_worst_group_acc: [20.13, 21.59, 26.64, 31.46, 33.96, 38.16, 45.64, 40.12, 40.12, 40.12, 40.12]
# best GN => 100-val images LB+WB corr.
Avg_group_acc: [67.57, 68.85, 69.3, 69.58, 69.8, 70.64, 71.57, 70.12, 70.12, 70.12, 70.12]
Avg_worst_group_acc: [17.13, 21.96, 28.66, 33.49, 36.45, 45.02, 50.44, 37.05, 37.05, 37.05, 37.05]
# best GN => 200-val images LB+WB corr. 
Avg_group_acc: [67.57, 68.92, 69.18, 69.31, 69.3, 69.43, 71.19, 71.72, 71.72, 71.72, 71.72]
Avg_worst_group_acc: [20.13, 22.43, 28.04, 31.46, 31.62, 36.92, 46.57, 47.48, 47.48, 47.48, 47.48]
