In [4]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, Subset
import pandas as pd
import numpy as np
from PIL import *
import PIL.Image
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.models import AlexNet_Weights, resnet50, ResNet50_Weights, resnet18, ResNet18_Weights, resnet101, ResNet101_Weights, VGG19_Weights, vgg19
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Best settings for most CNN training
torch.backends.cudnn.benchmark = True     # Enable auto-tuner
torch.backends.cudnn.deterministic = False  # Allow non-deterministic ops
torch.backends.cudnn.enabled = True        # Enable cuDNN (default)

resnet_writer = SummaryWriter(log_dir='AlexNet-full-model/AlexNet-WB')
sae_writer = SummaryWriter(log_dir='AlexNet-full-model/AlexNet-WB')


device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
torch.manual_seed(23)
np.random.seed(23)

class SparseAutoEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, sparsity_lambda=0.7, xavier_norm_init=True):
        super(SparseAutoEncoder, self).__init__()
        self.sparsity_lambda = sparsity_lambda
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GroupNorm(num_groups=16, num_channels=hidden_dim),
            nn.ReLU()
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.encoder[0].weight)  # Xavier initialization
            
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            #nn.ReLU() #nn.Sigmoid()  # Output between 0-1
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.decoder[0].weight)
        

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
    
    def kl_sparsity_penalty(self, encoded):
        # Penalize the average absolute activation
        rho_hat = torch.mean(torch.abs(encoded), dim=0)  # Average absolute activation per hidden unit
        rho = torch.ones_like(rho_hat) * self.sparsity_target  # Target sparsity value
        epsilon = 1e-8  # Small value to avoid log(0)

        # KL-divergence computation for sparsity
        kl_divergence = rho * torch.log(rho / (rho_hat + epsilon)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + epsilon))
        kl_divergence = torch.sum(kl_divergence)  # Sum over all hidden units

        return self.sparsity_lambda * kl_divergence

    # L1-norm sparsity penalty calculation
    def l1_sparsity_penalty(self, encoded):
        # Compute the mean of absolute values of activations
        sparsity_loss = torch.mean(torch.abs(encoded))  # Average absolute activation across all units
        return self.sparsity_lambda * sparsity_loss  # Scale by the sparsity weight

    # Loss function combining MSE (reconstruction error) and sparsity penalty
    def loss_function(self, decoded, original, encoded):
        mse_loss = F.mse_loss(decoded, original)  # Mean Squared Error for reconstruction
        sparsity_loss = self.l1_sparsity_penalty(encoded)  # Sparsity penalty for hidden layer activations
        return mse_loss + sparsity_loss  # Total loss is MSE + sparsity penalty

# 1. Define Dataset with Metadata
class WaterbirdsDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.metadata = pd.read_csv(csv_file)  # Metadata file with bird type and background
        self.root_dir = root_dir
        self.transform = transform
    
    def __len__(self):
        return len(self.metadata)
    
    def __getitem__(self, idx):
        img_path = f"{self.root_dir}/{self.metadata.iloc[idx]['img_filename']}"
        image = Image.open(img_path).convert("RGB")
        bird_type = int(self.metadata.iloc[idx]['y'])  # Waterbird=1, Landbird=0
        background = int(self.metadata.iloc[idx]['place'])  # Water=1, Land=0
        label = bird_type  # For training, we only care about bird type
        
        if self.transform:
            image = self.transform(image)
        return image, label, bird_type, background
        
# 2. Transforms and Data Preparation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load training and testing datasets for WaterBirds_Dataset
WB_train_dataset = WaterbirdsDataset(
    csv_file = '/home/ahsan/test-project/fss/split-metadata/output_metadata/train_metadata_updated.csv',
    #csv_file = '/home/ahsan/test-project/fss/split-metadata/output_metadata/train_metadata_updated_samples.csv',
    root_dir= '/home/ahsan/test-project/fss/waterbird_DB/all_images_DB/train_DB/all_birds_train',
    transform=transform
)
WB_test_dataset = WaterbirdsDataset(
    csv_file='/home/ahsan/test-project/fss/split-metadata/output_metadata/test_metadata_updated.csv',
    root_dir='/home/ahsan/test-project/fss/waterbird_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
WB_val_dataset_WB = WaterbirdsDataset(
    csv_file='/home/ahsan/test-project/fss/split-metadata/output_metadata/val_metadata_updated.csv',
    root_dir='/home/ahsan/test-project/fss/waterbird_DB/all_images_DB',
    transform=transform
)
batch_size = 128
WB_train_loader = DataLoader(WB_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
WB_test_loader = DataLoader(WB_test_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
WB_val_loader_WB = DataLoader(WB_val_dataset_WB, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
# 3. Define the AlexNet Model for Binary Classification
model = models.alexnet(weights=AlexNet_Weights.IMAGENET1K_V1) 
# Freeze earlier layers
"""for param in model.features.parameters():
    param.requires_grad = False"""
model.classifier[-1] = nn.Linear(4096, 2)
model = model.to(device)
# 4. Loss and Optimizer
criterion = nn.CrossEntropyLoss()
"""class_counts = [3693, 1102] #
total_samples = sum(class_counts)
weights = [total_samples / count for count in class_counts]
# Convert to tensor
weights_tensor = torch.tensor(weights, dtype=torch.float32).to(device)
criterion = torch.nn.CrossEntropyLoss(weight=weights_tensor, label_smoothing=0.1)
optimizer = optim.SGD([
    {'params': model.features.parameters(), 'lr': 0.0001},
    {'params': model.classifier.parameters(), 'lr': 0.001}
    ], momentum=0.9, weight_decay=5e-3)"""
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001, weight_decay=5e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
# 5. Training Loop
def train_model(model, train_loader, val_loader, criterion, optimizer, writer, num_epochs):
    best_val_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for images, labels, _, _ in train_loader:
            images, labels = images.to(device=device), labels.to(device=device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            loss.backward()
            # Apply gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
        
        train_loss = total_loss / len(train_loader)
        train_acc = correct / total
        # Validation... 
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels, _, _ in val_loader:
                images, labels = images.to(device=device), labels.to(device=device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_correct += (outputs.argmax(1) == labels).sum().item()
                val_total += labels.size(0)
        val_loss /= len(val_loader)
        val_acc = val_correct / val_total
        scheduler.step(val_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)
        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'AlexNet-full-model-without-relu_best_model.pth')
            print(f"Best model saved at epoch {epoch+1} with Val Acc: {val_acc:.4f}")

print("Training AlexNet full-model on WaterBirds dataset:.......................")
print("Size of WB_train_loader : ", len(WB_train_loader)*batch_size)
train_model(model, WB_train_loader, WB_val_loader_WB, criterion, optimizer, resnet_writer, num_epochs=50)
print("Training complete:: ")
torch.save(model.state_dict(), 'AlexNet-full-model-without-relu_100_epochs.pth')
# 6. Testing for Four sub-group Classes
def test_model(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    model_predictions = []
    #with torch.enable_grad():
    for images, lables, _, _  in dataloader:
        images = images.to(device)
        output = model(images)
        prediction = torch.argmax(torch.nn.functional.softmax(output, dim=1)).item() # predictions = torch.argmax(outputs, dim=1) 
        model_predictions.append(prediction)
    
    print("Results:", "*" * 50)
    results = {"Landbird_Land": 0, "Landbird_Water": 0, "Waterbird_Land": 0, "Waterbird_Water": 0}
    counts = {"Landbird_Land": 0, "Landbird_Water": 0, "Waterbird_Land": 0, "Waterbird_Water": 0}
    
    with torch.no_grad():
        for (images, labels, bird_type, background), pred in zip(dataloader, model_predictions):
            if bird_type.item() == 0 and background.item() == 0:
                results["Landbird_Land"] += (pred == 0)
                counts["Landbird_Land"] += 1
            elif bird_type.item() == 0 and background.item() == 1:
                results["Landbird_Water"] += (pred == 0)
                counts["Landbird_Water"] += 1
            elif bird_type.item() == 1 and background.item() == 0:
                results["Waterbird_Land"] += (pred == 1)
                counts["Waterbird_Land"] += 1
            elif bird_type.item() == 1 and background.item() == 1:
                results["Waterbird_Water"] += (pred == 1)
                counts["Waterbird_Water"] += 1

        # Calculate accuracies for each group
        for key in results:
            if counts[key] > 0:
                print(f"Accuracy for {key}: {results[key] / counts[key]:.2f}")
            else:
                print(f"No samples for {key}")
                
# Example usage
print("Testing Model::")
print("size of WB_test_loader : ", len(WB_test_loader)*batch_size)
test_model(model, WB_test_loader, device=device)

# Optional: Save activations (for last layer before fc)
def save_activations(model, dataloader, save_path):
    model.eval()
    activations = []
    def hook_fn(module, input, output):
        activations.append(output.detach().cpu())
    handle = model.classifier[4].register_forward_hook(hook_fn)  # fc2 linear layer (pre-ReLU)
    with torch.no_grad():
        for images, _, _, _ in dataloader:
            model(images.to(device))
    handle.remove()
    act_tensor = torch.cat(activations, dim=0).squeeze()  # shape: (N, 4096)
    np.save(save_path + ".npy", act_tensor.numpy())
    pd.DataFrame(act_tensor.numpy()).to_csv(save_path + ".csv", index=False)
# load best model
model = models.alexnet(pretrained=False)
model.classifier[-1] = nn.Linear(4096, 2)
model.load_state_dict(torch.load("AlexNet-full-model-without-relu_best_model.pth", map_location=device))
# Freeze all layers except `classifier[5]` (ReLU) and `classifier[6]` (fc3)
for name, param in model.named_parameters():
    if name.startswith("classifier.6"):
        param.requires_grad = True
    elif name.startswith("classifier.5"):  # ReLU does not have trainable params
        param.requires_grad = True
    else:
        param.requires_grad = False
# Set the model to evaluation mode
model.to(device)
model.eval()
# Save activations for test set
save_activations(model, WB_train_loader, "AlexNet-full-model-without-relu-activations_100_epochs")
print("WB_train_activations save successfully!")
# SAE Train Loop
def train_sae(model, data, epochs, lr, batch_size, writer):
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    
    for epoch in range(epochs):
        epoch_loss = 0
        for batch in dataloader:
            batch = batch[0].to(device)
            optimizer.zero_grad()
            encoded, decoded = model(batch)
            loss = model.loss_function(decoded, batch, encoded)
            #loss = criterion(decoded, batch) + sparsity_loss(encoded, sparsity_lambda)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(dataloader)
        sae_writer.add_scalar("Loss/train", avg_loss, epoch)
        print(f"Epoch [{epoch+1}/{epochs}], SAE Loss: {avg_loss:.4f}")
        
# Load previously saved activations
activations_np = np.load("AlexNet-full-model-without-relu-activations_100_epochs.npy")
activations = torch.tensor(activations_np, dtype=torch.float32)
# Wrap into a TensorDataset
activation_dataset = TensorDataset(activations)
# Define model
sae = SparseAutoEncoder(input_dim=4096, hidden_dim=8192)
train_sae(sae, activation_dataset, epochs=100, lr=0.001, batch_size=128, writer=sae_writer)
torch.save(sae.state_dict(), 'AlexNet-full-model-without-relu-SAE_100_epochs.pth')
print("Training SAE complete!!")
resnet_writer.close()
sae_writer.close()

Training AlexNet full-model on WaterBirds dataset:.......................
Size of WB_train_loader :  4864




Epoch [1/50] Train Loss: 0.3802, Val Loss: 0.6631, Train Acc: 0.8922, Val Acc: 0.7725
Best model saved at epoch 1 with Val Acc: 0.7725
Epoch [2/50] Train Loss: 0.1978, Val Loss: 0.8478, Train Acc: 0.9297, Val Acc: 0.6862
Epoch [3/50] Train Loss: 0.1782, Val Loss: 0.6739, Train Acc: 0.9391, Val Acc: 0.7234
Epoch [4/50] Train Loss: 0.1760, Val Loss: 0.6856, Train Acc: 0.9391, Val Acc: 0.6862
Epoch [5/50] Train Loss: 0.1614, Val Loss: 0.6523, Train Acc: 0.9431, Val Acc: 0.7198
Epoch [6/50] Train Loss: 0.1659, Val Loss: 0.6383, Train Acc: 0.9387, Val Acc: 0.7437
Epoch [7/50] Train Loss: 0.1566, Val Loss: 0.6795, Train Acc: 0.9416, Val Acc: 0.7090
Epoch [8/50] Train Loss: 0.1538, Val Loss: 0.7082, Train Acc: 0.9437, Val Acc: 0.7210
Epoch [9/50] Train Loss: 0.1577, Val Loss: 0.7755, Train Acc: 0.9426, Val Acc: 0.6838
Epoch [10/50] Train Loss: 0.1495, Val Loss: 0.5547, Train Acc: 0.9481, Val Acc: 0.7365
Epoch [11/50] Train Loss: 0.1511, Val Loss: 0.7437, Train Acc: 0.9477, Val Acc: 0.7162
Epo

  model.load_state_dict(torch.load("AlexNet-full-model-without-relu_best_model.pth", map_location=device))


WB_train_activations save successfully!
Epoch [1/100], SAE Loss: 2.0551
Epoch [2/100], SAE Loss: 1.0325
Epoch [3/100], SAE Loss: 0.9375
Epoch [4/100], SAE Loss: 0.9209
Epoch [5/100], SAE Loss: 0.8642
Epoch [6/100], SAE Loss: 0.8487
Epoch [7/100], SAE Loss: 0.7289
Epoch [8/100], SAE Loss: 0.7259
Epoch [9/100], SAE Loss: 0.8982
Epoch [10/100], SAE Loss: 0.7598
Epoch [11/100], SAE Loss: 0.6499
Epoch [12/100], SAE Loss: 0.6644
Epoch [13/100], SAE Loss: 0.6580
Epoch [14/100], SAE Loss: 0.6451
Epoch [15/100], SAE Loss: 0.5876
Epoch [16/100], SAE Loss: 0.6140
Epoch [17/100], SAE Loss: 0.5283
Epoch [18/100], SAE Loss: 0.5481
Epoch [19/100], SAE Loss: 0.5793
Epoch [20/100], SAE Loss: 0.5427
Epoch [21/100], SAE Loss: 0.5743
Epoch [22/100], SAE Loss: 0.5239
Epoch [23/100], SAE Loss: 0.5411
Epoch [24/100], SAE Loss: 0.6157
Epoch [25/100], SAE Loss: 0.5775
Epoch [26/100], SAE Loss: 0.4749
Epoch [27/100], SAE Loss: 0.5014
Epoch [28/100], SAE Loss: 0.5455
Epoch [29/100], SAE Loss: 0.5110
Epoch [30/10

In [None]:
"""when freeze earlier layers"""
Accuracy for Landbird_Land: 0.97
Accuracy for Landbird_Water: 0.32
Accuracy for Waterbird_Land: 0.29
Accuracy for Waterbird_Water: 0.95

In [6]:
# Save activations (for last layer before fc)
def save_activations(model, dataloader, save_path):
    model.eval()
    activations = []
    def hook_fn(module, input, output):
        activations.append(output.detach().cpu())
    handle = model.classifier[4].register_forward_hook(hook_fn)  # fc2 linear layer (pre-ReLU)
    with torch.no_grad():
        for images, _, _, _ in dataloader:
            model(images.to(device))
    handle.remove()
    act_tensor = torch.cat(activations, dim=0).squeeze()  # shape: (N, 4096)
    np.save(save_path + ".npy", act_tensor.numpy())
    pd.DataFrame(act_tensor.numpy()).to_csv(save_path + ".csv", index=False)
# load best model
model = models.alexnet(weights=None)
model.classifier[6] = nn.Linear(4096, 2)
model.load_state_dict(torch.load("AlexNet-full-model-without-relu_best_model.pth", map_location=device))
# Freeze all layers except `classifier[5]` (ReLU) and `classifier[6]` (fc3)
""" for name, param in model.named_parameters():
    if name.startswith("classifier.6"):
        param.requires_grad = True
    elif name.startswith("classifier.5"):  # ReLU does not have trainable params
        param.requires_grad = True
    else:
        param.requires_grad = False
# Set the model to evaluation mode """
model.eval()
model.to(device)
# Save activations for test set
save_activations(model, WB_train_loader, "AlexNet-full-model-without-relu-activations_100_epochs")
print("WB_train_activations save successfully!")
# SAE Train Loop
def train_sae(model, data, epochs, lr, batch_size, writer):
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        epoch_loss = 0
        for batch in dataloader:
            batch = batch[0].to(device)
            optimizer.zero_grad()
            encoded, decoded = model(batch)
            loss = model.loss_function(decoded, batch, encoded)
            #loss = criterion(decoded, batch) + sparsity_loss(encoded, sparsity_lambda)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(dataloader)
        sae_writer.add_scalar("Loss/train", avg_loss, epoch)
        print(f"Epoch [{epoch+1}/{epochs}], SAE Loss: {avg_loss:.4f}")
        
# Load previously saved activations
activations_np = np.load("AlexNet-full-model-without-relu-activations_100_epochs.npy")
activations = torch.tensor(activations_np, dtype=torch.float32)
# Wrap into a TensorDataset
activation_dataset = TensorDataset(activations)
# Define model
sae = SparseAutoEncoder(input_dim=4096, hidden_dim=8192)
train_sae(sae, activation_dataset, epochs=100, lr=0.001, batch_size=128, writer=sae_writer)
torch.save(sae.state_dict(), 'AlexNet-full-model-without-relu-SAE_100_epochs.pth')
print("Training SAE complete!!")
resnet_writer.close() 
sae_writer.close()

  model.load_state_dict(torch.load("AlexNet-full-model-without-relu_best_model.pth", map_location=device))


WB_train_activations save successfully!
Epoch [1/100], SAE Loss: 3.3522
Epoch [2/100], SAE Loss: 1.8870
Epoch [3/100], SAE Loss: 1.3149
Epoch [4/100], SAE Loss: 1.0690
Epoch [5/100], SAE Loss: 1.0842
Epoch [6/100], SAE Loss: 0.9398
Epoch [7/100], SAE Loss: 0.9240
Epoch [8/100], SAE Loss: 0.8543
Epoch [9/100], SAE Loss: 0.8721
Epoch [10/100], SAE Loss: 0.8116
Epoch [11/100], SAE Loss: 0.8646
Epoch [12/100], SAE Loss: 0.7938
Epoch [13/100], SAE Loss: 0.7932
Epoch [14/100], SAE Loss: 0.8029
Epoch [15/100], SAE Loss: 0.8503
Epoch [16/100], SAE Loss: 0.7692
Epoch [17/100], SAE Loss: 0.7935
Epoch [18/100], SAE Loss: 0.8370
Epoch [19/100], SAE Loss: 0.7039
Epoch [20/100], SAE Loss: 0.7399
Epoch [21/100], SAE Loss: 0.7452
Epoch [22/100], SAE Loss: 0.7966
Epoch [23/100], SAE Loss: 0.7234
Epoch [24/100], SAE Loss: 0.6854
Epoch [25/100], SAE Loss: 0.7583
Epoch [26/100], SAE Loss: 0.7632
Epoch [27/100], SAE Loss: 0.7741
Epoch [28/100], SAE Loss: 0.6767
Epoch [29/100], SAE Loss: 0.6793
Epoch [30/10

In [7]:
def test_model(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    model_predictions = []
    #with torch.enable_grad():
    for images, lables, _, _  in dataloader:
        images = images.to(device)
        output = model(images)
        prediction = torch.argmax(torch.nn.functional.softmax(output, dim=1)).item() # predictions = torch.argmax(outputs, dim=1) 
        model_predictions.append(prediction)
    
    print("Results:", "*" * 50)
    results = {"Landbird_Land": 0, "Landbird_Water": 0, "Waterbird_Land": 0, "Waterbird_Water": 0}
    counts = {"Landbird_Land": 0, "Landbird_Water": 0, "Waterbird_Land": 0, "Waterbird_Water": 0}
    
    with torch.no_grad():
        for (images, labels, bird_type, background), pred in zip(dataloader, model_predictions):
            if bird_type.item() == 0 and background.item() == 0:
                results["Landbird_Land"] += (pred == 0)
                counts["Landbird_Land"] += 1
            elif bird_type.item() == 0 and background.item() == 1:
                results["Landbird_Water"] += (pred == 0)
                counts["Landbird_Water"] += 1
            elif bird_type.item() == 1 and background.item() == 0:
                results["Waterbird_Land"] += (pred == 1)
                counts["Waterbird_Land"] += 1
            elif bird_type.item() == 1 and background.item() == 1:
                results["Waterbird_Water"] += (pred == 1)
                counts["Waterbird_Water"] += 1

        # Calculate accuracies for each group
        for key in results:
            if counts[key] > 0:
                print(f"Accuracy for {key}: {results[key] / counts[key]:.2f} ({results[key]}/{counts[key]})")
            else:
                print(f"No samples for {key}")
                
# Example usage
print("Testing Model::")
print("size of WB_test_loader : ", len(WB_test_loader))
test_model(model, WB_test_loader, device=device)

Testing Model::
size of WB_test_loader :  5601
Results: **************************************************
Accuracy for Landbird_Land: 0.99 (2229/2255)
Accuracy for Landbird_Water: 0.58 (1188/2062)
Accuracy for Waterbird_Land: 0.24 (151/642)
Accuracy for Waterbird_Water: 0.89 (571/642)


In [4]:
def test_model(model, dataloader, device):
    model.eval()
    
    results = {"Landbird_Land": 0, "Landbird_Water": 0, "Waterbird_Land": 0, "Waterbird_Water": 0}
    counts = {"Landbird_Land": 0, "Landbird_Water": 0, "Waterbird_Land": 0, "Waterbird_Water": 0}
    
    with torch.no_grad():
        for images, labels, bird_type, background in dataloader:
            images = images.to(device)
            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1).cpu()
            bird_type = bird_type.cpu()
            background = background.cpu()

            for i in range(len(predictions)):
                b_type = bird_type[i].item()
                bg = background[i].item()
                pred = predictions[i].item()

                if b_type == 0 and bg == 0:
                    results["Landbird_Land"] += (pred == 0)
                    counts["Landbird_Land"] += 1
                elif b_type == 0 and bg == 1:
                    results["Landbird_Water"] += (pred == 0)
                    counts["Landbird_Water"] += 1
                elif b_type == 1 and bg == 0:
                    results["Waterbird_Land"] += (pred == 1)
                    counts["Waterbird_Land"] += 1
                elif b_type == 1 and bg == 1:
                    results["Waterbird_Water"] += (pred == 1)
                    counts["Waterbird_Water"] += 1

    # Accuracy results
    print("\nResults:", "*" * 50)
    for key in results:
        if counts[key] > 0:
            accuracy = results[key] / counts[key]
            print(f"Accuracy for {key}: {accuracy:.2f} ({results[key]}/{counts[key]})")
        else:
            print(f"No samples for {key}")

# Example usage
print("Testing Model::")
print("size of WB_test_loader : ", len(WB_test_loader))
test_model(model, WB_test_loader, device=device)


Testing Model::
size of WB_test_loader :  5601

Results: **************************************************
Accuracy for Landbird_Land: 0.99 (2229/2255)
Accuracy for Landbird_Water: 0.58 (1188/2062)
Accuracy for Waterbird_Land: 0.24 (151/642)
Accuracy for Waterbird_Water: 0.89 (571/642)
