In [70]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader, random_split, Subset
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import torch.nn.functional as F
import matplotlib.pyplot as plt
#import seaborn as sns
import torch.optim as optim
from torchvision import datasets, transforms
from sklearn.metrics.pairwise import cosine_similarity
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torchvision.models import alexnet, AlexNet_Weights, resnet50, ResNet50_Weights, resnet18, ResNet18_Weights, resnet101, ResNet101_Weights, VGG19_Weights, vgg19
from PIL import *
import PIL.Image
import gc
import os
from torch.utils.data import Dataset, DataLoader

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Seed the global NumPy and PyTorch random number generators.
np.random.seed(23)
torch.manual_seed(23)

In [101]:
# 1. Define Datasets and Dataloader
class WaterbirdsDataset(torch.utils.data.Dataset):
    """Image dataset described by a metadata CSV.

    Every CSV row is expected to carry three columns: ``img_filename`` (image
    path relative to ``root_dir``), ``y`` (bird type: waterbird=1, landbird=0)
    and ``place`` (background: water=1, land=0).
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """Read the metadata table and remember where the images live.

        Args:
            csv_file: Path of the metadata CSV.
            root_dir: Directory that ``img_filename`` entries are relative to.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.metadata = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of metadata rows (one per image)."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``(image, label, bird_type, background)`` for row ``idx``.

        ``label`` and ``bird_type`` are the same value: the training target is
        the bird type only; the background is returned for group analysis.
        """
        row = self.metadata.iloc[idx]

        image = Image.open(f"{self.root_dir}/{row['img_filename']}").convert("RGB")
        bird_type = row['y']
        background = row['place']

        if self.transform:
            image = self.transform(image)
        return image, bird_type, bird_type, background

# 2. Transforms and Data Preparation
# Preprocessing shared by every dataset below: resize to 224x224, convert to a
# tensor, then normalize each channel with the mean/std values commonly used
# for ImageNet-pretrained torchvision models.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)

# Load training and testing datasets
# Full training split.
train_dataset = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/train_metadata_updated.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/train_DB/all_birds_train',
    transform=transform
)
# Full test split. Every test_* dataset below reads images from the same
# test image directory and differs only in the metadata CSV it uses.
test_dataset = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
# Test subsets selected by the "_LB" metadata files
# (presumably landbird-only rows; verify against the CSV contents).
test_dataset_LB = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_LB_no_patch = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_LB_no_patch.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_LB_patch = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_LB_patch.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)

test_dataset_LB_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_LB_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
# Test subsets selected by the "_WB" metadata files
# (presumably waterbird-only rows; verify against the CSV contents).
test_dataset_WB = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_WB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_WB_no_patch = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_WB_no_patch.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
test_dataset_WB_patch = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_WB_patch.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
) 
test_dataset_WB_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated_WB_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)

# Validation datasets. All of them use the top-level image directory as
# root_dir (not a split-specific subfolder), so the img_filename entries in
# these CSVs presumably include the split subfolder; verify against the CSVs.
val_dataset_LB = WaterbirdsDataset(
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB.csv',
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_LB = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_WB+LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
# Variational DB
# Validation subsets in four size variants (_25, _50, _100, _200). For each
# bird group (LB / WB) there is a combined "yes_no" file plus separate
# "patch" and "no_patch" files.
# --- size 25 ---
val_dataset_LB_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_yes_no_25.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_no_patch_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_no_patch_25.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_patch_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_patch_25.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)

val_dataset_WB_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_yes_no_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_no_patch_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_no_patch_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_patch_25 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_patch_25.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)

# --- size 50 ---
val_dataset_LB_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_yes_no_50.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_no_patch_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_no_patch_50.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_patch_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_patch_50.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)

val_dataset_WB_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_yes_no_50.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_no_patch_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_no_patch_50.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_patch_50 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_patch_50.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)


# --- size 100 ---
val_dataset_LB_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_yes_no_100.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_patch_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_patch_100.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_no_patch_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_no_patch_100.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_yes_no_100.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_patch_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_patch_100.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_no_patch_100 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_no_patch_100.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
# --- size 200 ---
val_dataset_LB_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_yes_no_200.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_patch_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_patch_200.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_LB_no_patch_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_LB_no_patch_200.csv',
    #csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_0_1_LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_yes_no_200.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_patch_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_patch_200.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
val_dataset_WB_no_patch_200 = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_updated_WB_no_patch_200.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)

batch_size = 64


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the settings shared by every loader here."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
        pin_memory=True,
    )


# Combined patch/no-patch validation subsets, one per size variant (fixed order).
val_loader_LB_25 = _make_loader(val_dataset_LB_25, shuffle=False)
val_loader_WB_25 = _make_loader(val_dataset_WB_25, shuffle=False)
val_loader_LB_50 = _make_loader(val_dataset_LB_50, shuffle=False)
val_loader_WB_50 = _make_loader(val_dataset_WB_50, shuffle=False)
val_loader_LB_100 = _make_loader(val_dataset_LB_100, shuffle=False)
val_loader_WB_100 = _make_loader(val_dataset_WB_100, shuffle=False)
val_loader_LB_200 = _make_loader(val_dataset_LB_200, shuffle=False)
val_loader_WB_200 = _make_loader(val_dataset_WB_200, shuffle=False)

# Training data is reshuffled on every pass; the test loaders keep a fixed order.
train_loader = _make_loader(train_dataset, shuffle=True)
test_loader = _make_loader(test_dataset, shuffle=False)
test_loader_LB = _make_loader(test_dataset_LB, shuffle=False)
test_loader_LB_25 = _make_loader(test_dataset_LB_25, shuffle=False)
test_loader_WB = _make_loader(test_dataset_WB, shuffle=False)
test_loader_WB_25 = _make_loader(test_dataset_WB_25, shuffle=False)

# NOTE(review): every patch / no-patch loader below is built with shuffle=True,
# so each pass yields samples in a different order. Activations collected in
# one pass cannot be matched by position with labels collected in another.
test_Loader_LB_no_patch = _make_loader(test_dataset_LB_no_patch, shuffle=True)
test_Loader_LB_patch = _make_loader(test_dataset_LB_patch, shuffle=True)
test_Loader_WB_no_patch = _make_loader(test_dataset_WB_no_patch, shuffle=True)
test_Loader_WB_patch = _make_loader(test_dataset_WB_patch, shuffle=True)

val_Loader_LB_no_patch_200 = _make_loader(val_dataset_LB_no_patch_200, shuffle=True)
val_Loader_LB_patch_200 = _make_loader(val_dataset_LB_patch_200, shuffle=True)
val_Loader_WB_no_patch_200 = _make_loader(val_dataset_WB_no_patch_200, shuffle=True)
val_Loader_WB_patch_200 = _make_loader(val_dataset_WB_patch_200, shuffle=True)

val_Loader_LB_no_patch_100 = _make_loader(val_dataset_LB_no_patch_100, shuffle=True)
val_Loader_LB_patch_100 = _make_loader(val_dataset_LB_patch_100, shuffle=True)
val_Loader_WB_no_patch_100 = _make_loader(val_dataset_WB_no_patch_100, shuffle=True)
val_Loader_WB_patch_100 = _make_loader(val_dataset_WB_patch_100, shuffle=True)

val_Loader_LB_no_patch_50 = _make_loader(val_dataset_LB_no_patch_50, shuffle=True)
val_Loader_LB_patch_50 = _make_loader(val_dataset_LB_patch_50, shuffle=True)
val_Loader_WB_no_patch_50 = _make_loader(val_dataset_WB_no_patch_50, shuffle=True)
val_Loader_WB_patch_50 = _make_loader(val_dataset_WB_patch_50, shuffle=True)

val_Loader_LB_no_patch_25 = _make_loader(val_dataset_LB_no_patch_25, shuffle=True)
val_Loader_LB_patch_25 = _make_loader(val_dataset_LB_patch_25, shuffle=True)
val_Loader_WB_no_patch_25 = _make_loader(val_dataset_WB_no_patch_25, shuffle=True)
val_Loader_WB_patch_25 = _make_loader(val_dataset_WB_patch_25, shuffle=True)

# Full per-group validation sets (fixed order).
val_loader_LB = _make_loader(val_dataset_LB, shuffle=False)
val_loader_WB = _make_loader(val_dataset_WB, shuffle=False)
val_loader_WB_LB = _make_loader(val_dataset_WB_LB, shuffle=False)


In [102]:
class SparseAutoEncoder_2(nn.Module):
    def __init__(self, input_dim, hidden_dim, sparsity_lambda=0.7, xavier_norm_init=True):
        super(SparseAutoEncoder_2, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sparsity_lambda = sparsity_lambda
        
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim), #nn.BatchNorm1d(hidden_dim)
            nn.GroupNorm(num_groups=16, num_channels=hidden_dim),
            nn.ReLU()
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.encoder[0].weight)  # Xavier initialization
            
        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_dim, self.input_dim),
            #nn.ReLU() #nn.Sigmoid()  # Output between 0-1
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.decoder[0].weight)
        

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
    
    def kl_sparsity_penalty(self, encoded):
        # Penalize the average absolute activation
        rho_hat = torch.mean(torch.abs(encoded), dim=0)  # Average absolute activation per hidden unit
        #rho_hat = 0.1122
        rho = torch.ones_like(rho_hat) * self.sparsity_target  # Target sparsity value
        epsilon = 1e-8  # Small value to avoid log(0)

        # KL-divergence computation for sparsity
        kl_divergence = rho * torch.log(rho / (rho_hat + epsilon)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + epsilon))
        kl_divergence = torch.sum(kl_divergence)  # Sum over all hidden units

        return self.sparsity_lambda * kl_divergence

    # L1-norm sparsity penalty calculation
    def l1_sparsity_penalty(self, encoded):
        # Compute the mean of absolute values of activations
        sparsity_loss = torch.mean(torch.abs(encoded))  # Average absolute activation across all units
        #sparsity_loss = 0.1122  # Average absolute activation across all units
        return self.sparsity_lambda * sparsity_loss  # Scale by the sparsity weight

    # Loss function combining MSE (reconstruction error) and sparsity penalty
    def loss_function(self, decoded, original, encoded):
        mse_loss = F.mse_loss(decoded, original)  # Mean Squared Error for reconstruction
        sparsity_loss = self.l1_sparsity_penalty(encoded)  # Sparsity penalty for hidden layer activations
        return mse_loss + sparsity_loss  # Total loss is MSE + sparsity penalty
# Instantiate the Sparse Auto-encoder with given dimensions
def load_autoencoder_2(sae_path, device, input_dim=2048, hidden_dim=8000):
    """Load a trained ``SparseAutoEncoder_2`` checkpoint as a frozen, eval-mode model.

    Args:
        sae_path: Path of the saved ``state_dict``.
        device: Device the weights are mapped to and the model is moved to.
        input_dim: Input width the checkpoint was trained with (default 2048).
        hidden_dim: Hidden width the checkpoint was trained with (default 8000).

    Returns:
        The autoencoder on ``device``, in eval mode, with every parameter's
        ``requires_grad`` set to False.
    """
    sae_2 = SparseAutoEncoder_2(input_dim, hidden_dim)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(sae_path, map_location=device)
    sae_2.load_state_dict(state_dict)
    sae_2 = sae_2.to(device)

    # Freeze the autoencoder: it is only used for projection, never trained here.
    for param in sae_2.parameters():
        param.requires_grad = False
    sae_2.eval()
    return sae_2


In [103]:
# Load Models
def get_model_AlexNet(model_path, device):
    """Build a 2-class AlexNet and load fine-tuned weights from ``model_path``.

    Args:
        model_path: Path of the saved ``state_dict`` for the 2-class model.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The AlexNet model on ``device`` with the checkpoint weights loaded.
    """
    num_classes = 2

    # The checkpoint is loaded strictly and replaces every parameter, so
    # fetching the pretrained ImageNet weights first would only add a
    # needless download/cache dependency.
    model = alexnet(weights=None)
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    return model

In [104]:
def get_model_ResNet(model_path, device):
    """Build a 2-class ResNet-50 and load its weights from ``model_path``.

    The network is created without pretrained weights, its final fully
    connected layer is replaced by a 2-output layer, and the checkpoint
    ``state_dict`` is then loaded. The model is returned on ``device``.
    """
    net = models.resnet50(weights=None)
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, 2)  # two output classes

    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint)
    return net.to(device)

In [105]:
import numpy as np
def activation_correlations(b_val_patch_activations, b_val_no_patch_activations, m_val_patch_activations, m_val_no_patch_activations, act_csv_path):
    """Correlate every neuron's activation with a patch / no-patch indicator.

    The four activation tensors (rows = images, columns = neurons) are stacked
    in the order: first patch group, second patch group, first no-patch group,
    second no-patch group. A 0/1 indicator (1 = patch) is built in the same
    order, and the Pearson correlation between that indicator and each neuron's
    activations is computed. Neurons with constant activations get 0.

    The result is also written to ``act_csv_path`` as a table with columns
    ``Neuron_Index`` and ``Correlation``, sorted by descending correlation.

    Returns:
        NumPy array of per-neuron correlations, in neuron-index order.
    """
    n_units = b_val_patch_activations.shape[1]

    # Move everything to host memory once; keep the original stacking order.
    patch_groups = [
        b_val_patch_activations.cpu().numpy(),
        m_val_patch_activations.cpu().numpy(),
    ]
    plain_groups = [
        b_val_no_patch_activations.cpu().numpy(),
        m_val_no_patch_activations.cpu().numpy(),
    ]
    ordered_groups = patch_groups + plain_groups

    # 1 for every patch image, 0 for every no-patch image.
    indicator = np.concatenate(
        [np.ones(group.shape[0]) for group in patch_groups]
        + [np.zeros(group.shape[0]) for group in plain_groups]
    )
    indicator_varies = np.std(indicator) > 0

    scores = np.zeros(n_units)
    for unit in range(n_units):
        column = np.concatenate([group[:, unit] for group in ordered_groups])
        # A constant vector has no defined correlation; leave the score at 0.
        if indicator_varies and np.std(column) > 0:
            scores[unit] = np.corrcoef(indicator, column)[0, 1]

    scores = np.nan_to_num(scores)  # any remaining NaN becomes 0

    ranking = pd.DataFrame({
        "Neuron_Index": np.arange(n_units),
        "Correlation": scores
    })
    ranking.sort_values(by="Correlation", ascending=False, inplace=True)
    ranking.to_csv(act_csv_path, index=False)

    return scores

In [106]:
import numpy as np
def activation_correlations_2(b_val_patch_activations, b_val_no_patch_activations, act_csv_path):
    """Two-group variant: correlate each neuron with a patch / no-patch indicator.

    Patch activations are placed first (indicator 1), no-patch activations
    second (indicator 0). For each neuron the Pearson correlation between the
    indicator and the neuron's activations is computed; neurons with constant
    activations get 0.

    The ranking (columns ``Neuron_Index`` and ``Correlation``, highest
    correlation first) is saved to ``act_csv_path``.

    Returns:
        NumPy array of per-neuron correlations, in neuron-index order.
    """
    unit_count = b_val_patch_activations.shape[1]
    with_patch = b_val_patch_activations.cpu().numpy()
    without_patch = b_val_no_patch_activations.cpu().numpy()

    is_patch = np.concatenate([
        np.ones(with_patch.shape[0]),
        np.zeros(without_patch.shape[0]),
    ])
    label_is_constant = not (np.std(is_patch) > 0)

    def _unit_correlation(unit):
        # Activations of one neuron over all images, patch images first.
        values = np.concatenate([with_patch[:, unit], without_patch[:, unit]])
        if label_is_constant or not (np.std(values) > 0):
            return 0  # correlation is undefined for a constant vector
        return np.corrcoef(is_patch, values)[0, 1]

    correlations = np.zeros(unit_count)
    for unit in range(unit_count):
        correlations[unit] = _unit_correlation(unit)
    correlations = np.nan_to_num(correlations)

    table = pd.DataFrame({
        "Neuron_Index": np.arange(unit_count),
        "Correlation": correlations
    })
    table.sort_values(by="Correlation", ascending=False, inplace=True)
    table.to_csv(act_csv_path, index=False)

    return correlations

In [107]:
# Get activations # 2
def get_activations_2(model, feature_extractor, dataloader, device=None):
    """Extract flattened features and split them by the per-sample patch flag.

    Each batch is passed through ``feature_extractor`` exactly once. Rows whose
    patch flag is 0 go to the "all" group, rows whose flag is 1 go to the
    "spu" group; other flag values are ignored. Activations and labels of a
    group always have the same length and the same order.

    (The earlier version re-ran the extractor on the whole batch once per
    sample and appended the whole batch's activations each time, so the
    activation rows neither matched the labels nor the patch flag.)

    Args:
        model: Unused; kept so existing call sites keep working.
        feature_extractor: Callable mapping an image batch to features.
        dataloader: Iterable yielding ``(images, labels, class_labels, patches)``.
        device: Device to run on. When omitted, the device of the extractor's
            first parameter is used, falling back to CUDA-if-available.

    Returns:
        ``(all_activations, labels_all, spu_activations, labels_spu)`` as
        tensors on ``device``. A group without samples yields empty tensors.
    """
    if device is None:
        parameters = getattr(feature_extractor, "parameters", None)
        first_param = next(parameters(), None) if callable(parameters) else None
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # flag value -> (activation chunks, label chunks)
    buckets = {0: ([], []), 1: ([], [])}
    feature_dim = 0

    with torch.no_grad():  # inference only
        for images, _labels, class_labels, patches in dataloader:
            features = feature_extractor(images.to(device))
            features = features.reshape(features.size(0), -1).cpu()
            feature_dim = features.size(1)

            class_labels = torch.as_tensor(class_labels).cpu()
            patches = torch.as_tensor(patches).cpu()

            for flag, (act_chunks, label_chunks) in buckets.items():
                mask = patches == flag
                if mask.any():
                    act_chunks.append(features[mask])
                    label_chunks.append(class_labels[mask])

    def _collect(chunks, empty):
        # Concatenate collected chunks, or return the typed empty tensor.
        return (torch.cat(chunks, dim=0) if chunks else empty).to(device)

    all_activations = _collect(buckets[0][0], torch.empty((0, feature_dim)))
    labels_all = _collect(buckets[0][1], torch.empty(0, dtype=torch.long))
    spu_activations = _collect(buckets[1][0], torch.empty((0, feature_dim)))
    labels_spu = _collect(buckets[1][1], torch.empty(0, dtype=torch.long))
    return all_activations, labels_all, spu_activations, labels_spu

In [108]:
# Get activations
def get_activations_AlexNet(model, dataloader, device):
    """Collect the outputs of ``model.classifier[4]`` for every image in ``dataloader``.

    ``classifier[4]`` is the second fully connected layer; the hook captures
    its output before the following ReLU.

    Args:
        model: AlexNet-style model exposing ``classifier[4]``; it is switched
            to eval mode. It is expected to already live on ``device``.
        dataloader: Iterable yielding 4-tuples whose first element is the
            image batch.
        device: Device the image batches are moved to.

    Returns:
        CPU tensor of shape ``(num_images, features)``. The batch dimension is
        kept even when only one image is processed.
    """
    model.eval()
    captured = []

    def hook_fn(module, input, output):
        captured.append(output.detach().cpu())

    handle = model.classifier[4].register_forward_hook(hook_fn)  # fc2 linear layer (pre-ReLU)
    try:
        with torch.no_grad():
            for images, _, _, _ in dataloader:
                model(images.to(device))
    finally:
        # Always detach the hook, even if a forward pass fails; a leaked hook
        # would keep appending to `captured` on later forward passes.
        handle.remove()

    # flatten(start_dim=1) keeps the batch axis; a bare squeeze() would drop
    # it when exactly one image was processed.
    return torch.cat(captured, dim=0).flatten(start_dim=1)

In [109]:
# Function to load top neurons from CSV based on a percentage
def load_top_neurons_from_csv(csv_path, percentage):
    """Return the ``Neuron_Index`` values of the top ``percentage`` % of rows.

    The CSV is taken in its stored order, so the rows must already be sorted
    with the most relevant neurons first.

    Args:
        csv_path: CSV file containing a ``Neuron_Index`` column.
        percentage: Share of rows to select, in percent (may be fractional).
            The count is rounded down; values outside 0..100 are clamped to
            "no rows" / "all rows".

    Returns:
        NumPy array with the selected neuron indices (empty for 0 %).
    """
    neuron_data = pd.read_csv(csv_path)
    total = len(neuron_data)

    # Multiply before dividing: total * (percentage / 100) loses a neuron to
    # float truncation in cases like 100 * (29 / 100) == 28.999999999999996.
    top_count = int(total * percentage / 100)
    # A negative count would make iloc drop rows from the end instead.
    top_count = max(0, min(total, top_count))

    return neuron_data.iloc[:top_count]["Neuron_Index"].values

In [110]:
def get_activations_ResNet(model, dataloader, device):
    """Collect the ``model.avgpool`` outputs for every image in ``dataloader``.

    Args:
        model: ResNet-style model exposing ``avgpool``; it is moved to
            ``device`` and switched to eval mode.
        dataloader: Iterable yielding 4-tuples whose first element is the
            image batch.
        device: Device used for the forward passes.

    Returns:
        CPU tensor of shape ``(num_images, features)``. The batch dimension is
        kept even when only one image is processed.
    """
    model.to(device)
    model.eval()
    captured = []

    def hook_fn(module, input, output):
        captured.append(output.detach().cpu())

    handle = model.avgpool.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            for images, _, _, _ in dataloader:
                model(images.to(device))
    finally:
        # Always detach the hook, even if a forward pass fails.
        handle.remove()

    # Flatten the pooled (N, C, 1, 1) maps to (N, C). flatten(start_dim=1)
    # keeps the batch axis, which a bare squeeze() drops when N == 1.
    act_tensor = torch.cat(captured, dim=0).flatten(start_dim=1)

    # Release cached GPU memory. The model itself is owned by the caller, so
    # it is not (and cannot be) freed from here.
    torch.cuda.empty_cache()
    gc.collect()
    return act_tensor
    

In [111]:
# Per-group accuracy
def calculate_group_accuracy(predictions, true_labels):
    """Return the accuracy of ``predictions`` against ``true_labels``.

    Both arguments are forwarded unchanged to scikit-learn's ``accuracy_score``.
    """
    group_accuracy = accuracy_score(predictions, true_labels)
    return group_accuracy

def classify_with_RestNet(model, all_activations, device=None):  # changed Original
    """Classify pre-computed feature vectors with the model's final ``fc`` layer.

    Args:
        model: Model exposing a ``fc`` linear layer.
        all_activations: Iterable of 1-D feature tensors, one per image
            (iterating a 2-D ``(N, features)`` tensor works as well).
        device: Device to run on. Defaults to the device that holds
            ``model.fc``, so the function no longer depends on a module-level
            ``device`` variable.

    Returns:
        List with one predicted class index (int) per feature vector.
    """
    if device is None:
        device = model.fc.weight.device

    pruned_predictions = []
    with torch.no_grad():  # inference only; no autograd graph is needed
        for act in all_activations:
            output = model.fc(act.to(device))
            probabilities = torch.nn.functional.softmax(output, dim=-1)
            pruned_predictions.append(torch.argmax(probabilities).item())
    return pruned_predictions

def classify_with_RestNet_2(model, all_activations, device=None):
    """Classify pre-computed feature vectors with the model's final ``fc`` layer.

    Each feature vector gets a leading batch axis, is passed through
    ``model.fc`` and the arg-max logit is taken as the prediction.

    The earlier version contained an ``ndim`` branch whose result was
    overwritten unconditionally on the next line; that dead branch is removed
    and the behavior that actually ran (always add a batch axis) is kept.

    Args:
        model: Model exposing a ``fc`` linear layer.
        all_activations: Iterable of 1-D feature tensors, one per image
            (iterating a 2-D ``(N, features)`` tensor works as well).
        device: Device to run on. Defaults to the device that holds
            ``model.fc``, so the function no longer depends on a module-level
            ``device`` variable.

    Returns:
        List with one predicted class index (int) per feature vector.
    """
    if device is None:
        device = model.fc.weight.device

    pruned_predictions = []
    with torch.no_grad():
        for act in all_activations:
            activation_tensor = act.unsqueeze(0).to(device)  # [1, features]
            output = model.fc(activation_tensor)  # classification logits
            pruned_predictions.append(torch.argmax(output, dim=1).item())
    return pruned_predictions

def classify_with_AlexNet(model, all_activations):
    """Classify activations with the last two layers of ``model.classifier``.

    Each activation is passed through ``model.classifier[5]`` and then
    ``model.classifier[6]``; the prediction is the index of the largest
    output value. The model is switched to eval mode as a side effect.

    Args:
        model: Module exposing an indexable ``classifier`` whose entries 5
            and 6 are applied in that order.
        all_activations: Iterable of activation tensors, one per sample.

    Returns:
        list[int]: Predicted class index for each activation, in input order.
    """
    model.eval()
    head = model.classifier[6]
    # Run on the device the head's weights live on rather than depending on a
    # module-level ``device`` global.
    head_param = next(head.parameters(), None)
    pruned_predictions = []
    # Inference only: do not record an autograd graph for every forward pass.
    with torch.no_grad():
        for activation in all_activations:
            # Work on a private copy. torchvision's AlexNet defines
            # classifier[5] as nn.ReLU(inplace=True), and ``.to()`` returns
            # the very same tensor when it is already on the target device,
            # so without the clone the caller's activations get overwritten.
            features = activation.clone()
            if head_param is not None:
                features = features.to(head_param.device)
            relu_output = model.classifier[5](features)  # Apply ReLU
            output = head(relu_output)  # Apply fc3
            pruned_predictions.append(torch.argmax(output).item())
    return pruned_predictions


In [112]:
# Project activations into sparse space
def project_to_sae(sae, all_activations, device):
    """Encode activations with the sparse autoencoder's encoder.

    Args:
        sae: Object exposing the encoder as ``sae.encoder``. It is not moved
            here, so it is presumably already on ``device`` - confirm at the
            call site.
        all_activations: Activations as a NumPy array or a torch tensor.
        device: Device the inputs are moved to before encoding.

    Returns:
        torch.Tensor: Output of ``sae.encoder``, computed without gradient
        tracking.
    """
    # torch.as_tensor wraps an ndarray without copying (like from_numpy) and
    # passes an existing tensor straight through, so callers no longer have
    # to round-trip tensors through NumPy first.
    inputs = torch.as_tensor(all_activations).float().to(device)
    with torch.no_grad():
        return sae.encoder(inputs)

In [None]:
# extra 1: build the neuron correlations from both the landbird (LB) and the
# waterbird (WB) validation groups, then mute the selected sparse neurons and
# measure per-group test accuracy for several muting percentages.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_path = 'ResNet_WB_best_seed-2.pth'
sae_path = 'ResNet_WB_SAE_seed-2.pth'
sparse_act_csv_path = 'ResNet-on-WB_activations.csv'
# One entry is appended to each list per muting percentage.
Avg_group_acc = []
Avg_worst_group_acc = []
Acc_LB_core = []
Acc_LB_spu = []
Acc_WB_core = []
Acc_WB_spu = []
# Values handed to load_top_neurons_from_csv(percentage=...), one sweep step each.
x = [0.3, 0.6, 1.0, 1.5, 2, 5, 10, 40]
# Get model and feature_extractor
model = get_model_ResNet(model_path, device)

LB_val_loader_p =   val_Loader_LB_patch_200
LB_val_loader_np =  val_Loader_LB_no_patch_200
WB_val_loader_p =   val_Loader_WB_patch_200
WB_val_loader_np =  val_Loader_WB_no_patch_200

test_LB_loader_p =  test_Loader_LB_patch
test_LB_loader_np = test_Loader_LB_no_patch
test_WB_loader_p =  test_Loader_WB_patch
test_WB_loader_np = test_Loader_WB_no_patch

# Get activations
LB_val_activations_no_patch  = get_activations_ResNet(model, LB_val_loader_np, device)
LB_val_activations_patch  = get_activations_ResNet(model, LB_val_loader_p, device)
WB_val_activations_no_patch  = get_activations_ResNet(model, WB_val_loader_np, device)
WB_val_activations_patch  = get_activations_ResNet(model, WB_val_loader_p, device)

test_LB_no_patch  = get_activations_ResNet(model, test_LB_loader_np, device)
test_LB_patch  = get_activations_ResNet(model, test_LB_loader_p, device)
test_WB_no_patch  = get_activations_ResNet(model, test_WB_loader_np, device)
test_WB_patch  = get_activations_ResNet(model, test_WB_loader_p, device)

# Baseline: classify the raw activations before anything is muted.
print("classify before sparse muting.....")
LB_val_activations_patch_pred = classify_with_RestNet(model, LB_val_activations_patch)
LB_val_activations_no_patch_pred = classify_with_RestNet(model, LB_val_activations_no_patch)
test_LB_patch_pred = classify_with_RestNet(model, test_LB_patch)
test_LB_no_patch_pred = classify_with_RestNet(model, test_LB_no_patch)
test_WB_patch_pred = classify_with_RestNet(model, test_WB_patch)
test_WB_no_patch_pred = classify_with_RestNet(model, test_WB_no_patch)
# Calculate group accuracy (every LB sample has label 0, every WB sample label 1)
LB_val_activations_patch_acc = calculate_group_accuracy(LB_val_activations_patch_pred, [0] * len(LB_val_activations_patch_pred))
LB_val_activations_no_patch_acc = calculate_group_accuracy(LB_val_activations_no_patch_pred, [0] * len(LB_val_activations_no_patch_pred))
test_LB_patch_acc = calculate_group_accuracy(test_LB_patch_pred, [0] * len(test_LB_patch_pred))
test_LB_no_patch_acc = calculate_group_accuracy(test_LB_no_patch_pred, [0] * len(test_LB_no_patch_pred))
test_WB_patch_acc = calculate_group_accuracy(test_WB_patch_pred, [1] * len(test_WB_patch_pred))
test_WB_no_patch_acc = calculate_group_accuracy(test_WB_no_patch_pred, [1] * len(test_WB_no_patch_pred))
# Labels name the groups actually evaluated (LB); the former "Benign" wording
# did not match the values being printed.
print("Accuracy for Val. LB_P(Spu): ", LB_val_activations_patch_acc)
print("Accuracy for Val. LB_NP(Core): ", LB_val_activations_no_patch_acc)

print("Accuracy for LB_NP(Core): ", test_LB_no_patch_acc)
print("Accuracy for LB_P(Spu): ", test_LB_patch_acc)
print("Accuracy for WB_NP(Core): ", test_WB_no_patch_acc)
print("Accuracy for WB_P(Spu): ", test_WB_patch_acc)
print(" End of classification before sparse muting......")

# The sparse space: Project activations
sae = load_autoencoder_2(sae_path, device)
sae.to(device)
# project_to_sae is fed NumPy arrays here, so the tensors are converted first.
LB_val_activations_patch = LB_val_activations_patch.cpu().numpy()
LB_val_activations_no_patch = LB_val_activations_no_patch.cpu().numpy()
WB_val_activations_patch = WB_val_activations_patch.cpu().numpy()
WB_val_activations_no_patch =  WB_val_activations_no_patch.cpu().numpy()
test_LB_patch = test_LB_patch.cpu().numpy()
test_LB_no_patch = test_LB_no_patch.cpu().numpy()
test_WB_patch = test_WB_patch.cpu().numpy()
test_WB_no_patch = test_WB_no_patch.cpu().numpy()

LB_projected_val_patch = project_to_sae(sae, LB_val_activations_patch, device)
LB_projected_val_no_patch = project_to_sae(sae, LB_val_activations_no_patch, device)
WB_projected_val_patch = project_to_sae(sae, WB_val_activations_patch, device)
WB_projected_val_no_patch = project_to_sae(sae, WB_val_activations_no_patch, device)
projected_LB_test_patch = project_to_sae(sae, test_LB_patch, device)
projected_LB_test_no_patch = project_to_sae(sae, test_LB_no_patch, device)
projected_WB_test_patch = project_to_sae(sae, test_WB_patch, device)
projected_WB_test_no_patch = project_to_sae(sae, test_WB_no_patch, device)

# corr.
# NOTE(review): activation_correlations presumably writes its result to
# sparse_act_csv_path, which load_top_neurons_from_csv reads back below - confirm.
correlations = activation_correlations(LB_projected_val_patch, LB_projected_val_no_patch, WB_projected_val_patch, WB_projected_val_no_patch, sparse_act_csv_path)
for percentage in x:
    # Correlation based Activations
    top_neurons = load_top_neurons_from_csv(sparse_act_csv_path, percentage=percentage)
    # Muting neurons (on copies, so every percentage starts from the same projections)
    LB_projected_val_patch_muted = LB_projected_val_patch.clone().detach()
    LB_projected_val_no_patch_muted = LB_projected_val_no_patch.clone().detach()
    projected_WB_test_patch_muted = projected_WB_test_patch.clone().detach()
    projected_WB_test_no_patch_muted = projected_WB_test_no_patch.clone().detach()
    projected_LB_test_patch_muted = projected_LB_test_patch.clone().detach()
    projected_LB_test_no_patch_muted = projected_LB_test_no_patch.clone().detach()
    LB_projected_val_patch_muted[:, top_neurons] = 0
    LB_projected_val_no_patch_muted[:, top_neurons] = 0
    projected_WB_test_patch_muted[:, top_neurons] = 0
    projected_WB_test_no_patch_muted[:, top_neurons] = 0
    projected_LB_test_patch_muted[:, top_neurons] = 0
    projected_LB_test_no_patch_muted[:, top_neurons] = 0
    # Decode. Evaluation only, so no autograd graph is recorded on the
    # decoded tensors (previously each one kept a graph alive).
    with torch.no_grad():
        LB_decoded_val_patch = sae.decoder(LB_projected_val_patch_muted).to(device)
        LB_decoded_val_no_patch = sae.decoder(LB_projected_val_no_patch_muted).to(device)
        decoded_WB_test_patch = sae.decoder(projected_WB_test_patch_muted).to(device)
        decoded_WB_test_no_patch = sae.decoder(projected_WB_test_no_patch_muted).to(device)
        decoded_LB_test_patch = sae.decoder(projected_LB_test_patch_muted).to(device)
        decoded_LB_test_no_patch = sae.decoder(projected_LB_test_no_patch_muted).to(device)
    # Classify the decoded (muted) activations
    LB_predictions_val_patch_after = classify_with_RestNet(model, LB_decoded_val_patch)
    LB_predictions_val_no_patch_after = classify_with_RestNet(model, LB_decoded_val_no_patch)
    predictions_test_WB_patch_after = classify_with_RestNet(model, decoded_WB_test_patch)
    predictions_test_WB_no_patch_after = classify_with_RestNet(model, decoded_WB_test_no_patch)
    predictions_test_LB_patch_after = classify_with_RestNet(model, decoded_LB_test_patch)
    predictions_test_LB_no_patch_after = classify_with_RestNet(model, decoded_LB_test_no_patch)
    # Calculate group accuracy
    LB_accuracy_val_patch_after = calculate_group_accuracy(LB_predictions_val_patch_after, [0] * len(LB_predictions_val_patch_after))
    LB_accuracy_val_no_patch_after = calculate_group_accuracy(LB_predictions_val_no_patch_after, [0] * len(LB_predictions_val_no_patch_after))
    accuracy_test_WB_patch_after = calculate_group_accuracy(predictions_test_WB_patch_after, [1] * len(predictions_test_WB_patch_after))
    accuracy_test_WB_no_patch_after = calculate_group_accuracy(predictions_test_WB_no_patch_after, [1] * len(predictions_test_WB_no_patch_after))
    accuracy_test_LB_patch_after = calculate_group_accuracy(predictions_test_LB_patch_after, [0] * len(predictions_test_LB_patch_after))
    accuracy_test_LB_no_patch_after = calculate_group_accuracy(predictions_test_LB_no_patch_after, [0] * len(predictions_test_LB_no_patch_after))
    # Worst and average group accuracies over the four test groups
    AGA = (accuracy_test_WB_patch_after + accuracy_test_WB_no_patch_after + accuracy_test_LB_patch_after + accuracy_test_LB_no_patch_after) / 4
    AWGA = min(accuracy_test_WB_patch_after, accuracy_test_WB_no_patch_after, accuracy_test_LB_patch_after, accuracy_test_LB_no_patch_after)
    Avg_group_acc.append(AGA)
    Avg_worst_group_acc.append(AWGA)
    Acc_LB_core.append(accuracy_test_LB_no_patch_after)
    Acc_LB_spu.append(accuracy_test_LB_patch_after)
    Acc_WB_core.append(accuracy_test_WB_no_patch_after)
    Acc_WB_spu.append(accuracy_test_WB_patch_after)
    print("Muting percentage x = ", percentage)

# Rounded values (in percent). Computed once after the sweep rather than on
# every iteration, so the names also exist when the sweep list is empty.
Avg_group_acc_rv = [round(acc * 100, 2) for acc in Avg_group_acc]
Avg_worst_group_acc_rv = [round(acc * 100, 2) for acc in Avg_worst_group_acc]
Acc_LB_core_rv = [round(acc * 100, 2) for acc in Acc_LB_core]
Acc_LB_spu_rv = [round(acc * 100, 2) for acc in Acc_LB_spu]
Acc_WB_core_rv = [round(acc * 100, 2) for acc in Acc_WB_core]
Acc_WB_spu_rv = [round(acc * 100, 2) for acc in Acc_WB_spu]
print("*" * 50)
print("We apply muting percentage: x = ", x)
print(f"Avg_group_acc: {Avg_group_acc_rv}")
print(f"Avg_worst_group_acc: {Avg_worst_group_acc_rv}")
# Labels match the LB / WB lists being printed; the former
# "benign" / "malignant" wording did not describe these groups.
print(f"Acc_LB_core: {Acc_LB_core_rv}")
print(f"Acc_LB_spu: {Acc_LB_spu_rv}")
print(f"Acc_WB_core: {Acc_WB_core_rv}")
print(f"Acc_WB_spu: {Acc_WB_spu_rv}")

#print("Prediction and Evaluation All Groups:")
#prediction_and_evaluation(model, val_all_activations_decoded, val_loader, b_val_labels_all)
print("*" * 50)
print("Evaluation Complete!")
