In [12]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, Subset
import pandas as pd
import numpy as np
from PIL import *
import PIL.Image
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.models import AlexNet_Weights, resnet50, ResNet50_Weights, resnet18, ResNet18_Weights, resnet101, ResNet101_Weights, VGG19_Weights, vgg19
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.decomposition import PCA
from torch.utils.tensorboard import SummaryWriter


# Best settings for most CNN training
torch.backends.cudnn.benchmark = True     
torch.backends.cudnn.deterministic = False  
torch.backends.cudnn.enabled = True 

resnet_writer = SummaryWriter(log_dir='ResNet-50-full-model/resnet_WB')
sae_writer = SummaryWriter(log_dir='ResNet-50-full-model/sae_WB')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1111)
np.random.seed(1111)

class SparseAutoEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, sparsity_lambda=0.70, xavier_norm_init=True):
        super(SparseAutoEncoder, self).__init__()
        self.sparsity_lambda = sparsity_lambda
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GroupNorm(num_groups=16, num_channels=hidden_dim),
            nn.ReLU(),
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.encoder[0].weight)  # Xavier initialization
            
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            #nn.ReLU() #nn.Sigmoid()  # Output between 0-1
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.decoder[0].weight)
        

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
    
    def kl_sparsity_penalty(self, encoded):
        # Penalize the average absolute activation
        rho_hat = torch.mean(torch.abs(encoded), dim=0)  # Average absolute activation per hidden unit
        rho = torch.ones_like(rho_hat) * self.sparsity_target  # Target sparsity value
        epsilon = 1e-8  # Small value to avoid log(0)
        
        # KL-divergence computation for sparsity
        kl_divergence = rho * torch.log(rho / (rho_hat + epsilon)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + epsilon))
        kl_divergence = torch.sum(kl_divergence)  # Sum over all hidden units

        return self.sparsity_lambda * kl_divergence

    # L1-norm sparsity penalty calculation
    def l1_sparsity_penalty(self, encoded):
        # Compute the mean of absolute values of activations
        sparsity_loss = torch.mean(torch.abs(encoded))  # Average absolute activation across all units
        return self.sparsity_lambda * sparsity_loss  # Scale by the sparsity weight
    
    def kl_divergence(self, rho_hat, rho=0.05): # added
        rho_hat = torch.clamp(rho_hat, 1e-8, 1-1e-8)
        return torch.sum(rho * torch.log(rho / rho_hat) + (1-rho)*torch.log((1-rho)/(1-rho_hat)))

    # Loss function combining MSE (reconstruction error) and sparsity penalty
    def loss_function(self, decoded, original, encoded):
        mse_loss = F.mse_loss(decoded, original)  # Mean Squared Error for reconstruction
        sparsity_loss = self.l1_sparsity_penalty(encoded)  # Sparsity penalty for hidden layer activations
        return mse_loss + sparsity_loss  # Total loss is MSE + sparsity penalty

# 1. Define Dataset with Metadata
class WaterbirdsDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.metadata = pd.read_csv(csv_file)  # Metadata file with bird type and background
        self.root_dir = root_dir
        self.transform = transform
    
    def __len__(self):
        return len(self.metadata)
    
    def __getitem__(self, idx):
        img_path = f"{self.root_dir}/{self.metadata.iloc[idx]['img_filename']}"
        image = Image.open(img_path).convert("RGB")
        bird_type = self.metadata.iloc[idx]['y']  # Waterbird=1, Landbird=0
        background = self.metadata.iloc[idx]['place']  # Water=1, Land=0
        label = bird_type  # For training, we only care about bird type
        
        if self.transform:
            image = self.transform(image)
        return image, label, bird_type, background
        
# 2. Transforms and Data Preparation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load training and testing datasets for WaterBirds_Dataset
WB_train_dataset = WaterbirdsDataset(
    csv_file = '/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/train_metadata_updated.csv',
    #csv_file = '/home/ahsan/test-project/fss/split-metadata/output_metadata/train_metadata_updated_samples.csv',
    root_dir= '/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/train_DB/all_birds_train',
    transform=transform
)
WB_test_dataset = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/test_metadata_updated.csv',
    #csv_file='/home/ahsan/test-project/fss/split-metadata/output_metadata/test_metadata_updated_samples.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB/test_DB/all_birds_test',
    transform=transform
)
WB_val_dataset_WB = WaterbirdsDataset(
    csv_file='/run/determined/workdir/SCLearning_WB/split-metadata/output_metadata/val_metadata_WB+LB.csv',
    root_dir='/run/determined/workdir/SCLearning_WB/WB_DB/all_images_DB',
    transform=transform
)
batch_size = 128
WB_train_loader = DataLoader(WB_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
WB_test_loader = DataLoader(WB_test_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
WB_val_loader_WB = DataLoader(WB_val_dataset_WB, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# 3. Define the ResNet Model for Binary Classification
num_classes = 2
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)  # Output for 2 classes
model = model.to(device)
# 4. Loss and Optimizer
criterion = nn.CrossEntropyLoss()
#optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4, nesterov=True)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

# 5. Training Loop
def train_model(model, train_loader, val_loader, criterion, optimizer, writer, num_epochs):
    best_val_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for images, labels, _, _ in train_loader:
            images, labels = images.to(device=device), labels.to(device=device)
            optimizer.zero_grad()
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            loss.backward()
            # Apply gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
        
        train_loss = total_loss / len(train_loader)
        train_acc = correct / total
        # Validation... 
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels, _, _ in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_correct += (outputs.argmax(1) == labels).sum().item()
                val_total += labels.size(0)
        val_loss /= len(val_loader)
        val_acc = val_correct / val_total
        #scheduler.step(val_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)
        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'ResNet_WB_best.pth')
            print(f"Best model saved at epoch {epoch+1} with Val Acc: {val_acc:.4f}")

print("Training Model::")
print("Size of WB_train_loader : ", len(WB_train_loader)*batch_size)
train_model(model, WB_train_loader, WB_val_loader_WB, criterion, optimizer, resnet_writer, num_epochs=200)
print("Training complete:: ")
torch.save(model.state_dict(), 'ResNet_WB.pth')
# 6. Testing for Four Classes
def test_model(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    model_predictions = []
    #with torch.enable_grad():
    for images, lables, bird_type, background  in dataloader:
        images = images.to(device)
        output = model(images)
        prediction = torch.argmax(torch.nn.functional.softmax(output, dim=1)).item() # predictions = torch.argmax(outputs, dim=1) 
        model_predictions.append(prediction)
    
    print("Results:", "*" * 50)
    results = {"Landbird_Land": 0, "Landbird_Water": 0, "Waterbird_Land": 0, "Waterbird_Water": 0}
    counts = {"Landbird_Land": 0, "Landbird_Water": 0, "Waterbird_Land": 0, "Waterbird_Water": 0}
    
    with torch.no_grad():
        for (images, labels, bird_type, background), pred in zip(dataloader, model_predictions):
            if bird_type.item() == 0 and background.item() == 0:
                results["Landbird_Land"] += (pred == 0)
                counts["Landbird_Land"] += 1
            elif bird_type.item() == 0 and background.item() == 1:
                results["Landbird_Water"] += (pred == 0)
                counts["Landbird_Water"] += 1
            elif bird_type.item() == 1 and background.item() == 0:
                results["Waterbird_Land"] += (pred == 1)
                counts["Waterbird_Land"] += 1
            elif bird_type.item() == 1 and background.item() == 1:
                results["Waterbird_Water"] += (pred == 1)
                counts["Waterbird_Water"] += 1

        # Calculate accuracies for each group
        for key in results:
            if counts[key] > 0:
                print(f"Accuracy for {key}: {results[key] / counts[key]:.2f}")
            else:
                print(f"No samples for {key}")
                
# Example usage
print("Testing Model::")
print("size of WB_test_loader : ", len(WB_test_loader))
test_model(model, WB_test_loader, device=device)

def save_activations(model, dataloader, save_path):
    model.eval()
    activations = []
    def hook_fn(module, input, output):
        activations.append(output.detach().cpu())
    handle = model.avgpool.register_forward_hook(hook_fn)
    with torch.no_grad():
        for images, _, _, _ in dataloader:
            model(images.to(device))
    handle.remove()    
    act_tensor = torch.cat(activations, dim=0).squeeze()  # shape: (N, 2048)
    np.save(save_path + ".npy", act_tensor.numpy())
    pd.DataFrame(act_tensor.numpy()).to_csv(save_path + ".csv", index=False)

# Load best model
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, num_classes) 
model.load_state_dict(torch.load("ResNet_WB.pth", map_location=device))
model.eval()
model.to(device)
# Save activations for test set
save_activations(model, WB_train_loader, "ResNet_WB_act")
print("WB_train_activations save successfully!")
# SAE Train Loop
def train_sae(model, data, epochs, lr, batch_size, writer):
    model.train()
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    
    for epoch in range(epochs):
        epoch_loss = 0
        for batch in dataloader:
            #batch = torch.stack(batch).to(device)
            batch = batch[0].to(device)
            optimizer.zero_grad()
            encoded, decoded = model(batch)
            loss = model.loss_function(decoded, batch, encoded)
            #loss = criterion(decoded, batch) + sparsity_loss(encoded, sparsity_lambda)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(dataloader)
        sae_writer.add_scalar("Loss/train", avg_loss, epoch)
        print(f"Epoch [{epoch+1}/{epochs}], SAE Loss: {avg_loss:.4f}")
        
# Load previously saved activations
activations_np = np.load("ResNet_WB_act.npy")
activations = torch.tensor(activations_np, dtype=torch.float32)
# Wrap into a TensorDataset
activation_dataset = TensorDataset(activations)
# Define model
sae = SparseAutoEncoder(input_dim=2048, hidden_dim=8000)
train_sae(sae, activation_dataset, epochs=300, lr=0.001, batch_size=32, writer=sae_writer)
torch.save(sae.state_dict(), 'ResNet_WB_SAE.pth')
print("Training SAE complete!!")
resnet_writer.close()
sae_writer.close()

Training Model::
Size of WB_train_loader :  4864
Epoch [1/200] Train Loss: 0.1534, Val Loss: 0.4413, Train Acc: 0.9347, Val Acc: 0.8180


RuntimeError: [enforce fail at inline_container.cc:595] . unexpected pos 2760192 vs 2760088