In [1]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, Subset
import pandas as pd
import numpy as np
from PIL import *
import PIL.Image
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.models import AlexNet_Weights, resnet50, ResNet50_Weights, resnet18, ResNet18_Weights, resnet101, ResNet101_Weights, VGG19_Weights, vgg19
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Best settings for most CNN training
torch.backends.cudnn.benchmark = True     # Enable auto-tuner
torch.backends.cudnn.deterministic = False  # Allow non-deterministic ops
torch.backends.cudnn.enabled = True        # Enable cuDNN (default)

resnet_writer = SummaryWriter(log_dir='ResNet-50-full-model/resnet_isic')
sae_writer = SummaryWriter(log_dir='ResNet-50-full-model/sae_isic')


device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
torch.manual_seed(23)
np.random.seed(23)

class SparseAutoEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, sparsity_lambda=0.70, xavier_norm_init=True):
        super(SparseAutoEncoder, self).__init__()
        self.sparsity_lambda = sparsity_lambda
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.encoder[0].weight)  # Xavier initialization
            
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            #nn.ReLU() #nn.Sigmoid()  # Output between 0-1
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.decoder[0].weight)
        

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
    
    def kl_sparsity_penalty(self, encoded):
        # Penalize the average absolute activation
        rho_hat = torch.mean(torch.abs(encoded), dim=0)  # Average absolute activation per hidden unit
        rho = torch.ones_like(rho_hat) * self.sparsity_target  # Target sparsity value
        epsilon = 1e-8  # Small value to avoid log(0)

        # KL-divergence computation for sparsity
        kl_divergence = rho * torch.log(rho / (rho_hat + epsilon)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + epsilon))
        kl_divergence = torch.sum(kl_divergence)  # Sum over all hidden units

        return self.sparsity_lambda * kl_divergence

    # L1-norm sparsity penalty calculation
    def l1_sparsity_penalty(self, encoded):
        # Compute the mean of absolute values of activations
        sparsity_loss = torch.mean(torch.abs(encoded))  # Average absolute activation across all units
        return self.sparsity_lambda * sparsity_loss  # Scale by the sparsity weight

    # Loss function combining MSE (reconstruction error) and sparsity penalty
    def loss_function(self, decoded, original, encoded):
        mse_loss = F.mse_loss(decoded, original)  # Mean Squared Error for reconstruction
        sparsity_loss = self.l1_sparsity_penalty(encoded)  # Sparsity penalty for hidden layer activations
        return mse_loss + sparsity_loss  # Total loss is MSE + sparsity penalty

class ISIC_Dataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        
        self.metadata = pd.read_csv(csv_file)  # Load metadata
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        img_path = f"{self.root_dir}/{self.metadata.iloc[idx]['isic_id']}"  # Image filename
        image = Image.open(img_path).convert("RGB")  # Load image and convert to RGB
        # Extract label (benign=0, malignant=1)
        benign_malignant = int(self.metadata.iloc[idx]['benign_malignant'])
        patches = int(self.metadata.iloc[idx]['patches'])  # Assuming 1=Has patches, 0=No patches
        label = benign_malignant # For training, we only care about benign_malignant
        # Apply transforms
        if self.transform:
            image = self.transform(image)

        return image, label, benign_malignant, patches # Return transformed image and label
# 2. Transforms and Data Preparation
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load training and testing datasets for ISIC_dataset
ISIC_train_dataset = ISIC_Dataset(
    csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/train_metadata.csv',
    #csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/train_metadata_sample_data.csv',
    root_dir='/home/ahsan/test-project/fss/ISIC/ISIC_224_Dataset/isic_train',
    transform=transform
)
ISIC_test_dataset = ISIC_Dataset(
    csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/test_metadata.csv',
    #csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/test_metadata_sample_data.csv',
    root_dir='/home/ahsan/test-project/fss/ISIC/ISIC_224_Dataset/isic_test',
    transform=transform
)

ISIC_val_dataset = ISIC_Dataset(
    #csv_file='/home/ahsan/test-project/fss/ISIC/val/metadata_ISIC_test/benign_no_yes_patch_100.csv',
    csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/val_metadata_all.csv',
    root_dir= '/home/ahsan/test-project/fss/ISIC/ISIC_224_Dataset/isic_224/raw_224',
    transform=transform
)
batch_size= 128
ISIC_train_loader = DataLoader(ISIC_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
ISIC_test_loader = DataLoader(ISIC_test_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
ISIC_val_loader = DataLoader(ISIC_val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
# 3. Define the ResNet Model for Binary Classification

In [None]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, Subset
import pandas as pd
import numpy as np
from PIL import *
import PIL.Image
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.models import AlexNet_Weights, resnet50, ResNet50_Weights, resnet18, ResNet18_Weights, resnet101, ResNet101_Weights, VGG19_Weights, vgg19
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Best settings for most CNN training
torch.backends.cudnn.benchmark = True     # Enable auto-tuner
torch.backends.cudnn.deterministic = False  # Allow non-deterministic ops
torch.backends.cudnn.enabled = True        # Enable cuDNN (default)

resnet_writer = SummaryWriter(log_dir='ResNet-50-full-model/resnet_isic')
sae_writer = SummaryWriter(log_dir='ResNet-50-full-model/sae_isic')


device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
torch.manual_seed(23)
np.random.seed(23)

class SparseAutoEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, sparsity_lambda=0.70, xavier_norm_init=True):
        super(SparseAutoEncoder, self).__init__()
        self.sparsity_lambda = sparsity_lambda
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.encoder[0].weight)  # Xavier initialization
            
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            #nn.ReLU() #nn.Sigmoid()  # Output between 0-1
        )
        if xavier_norm_init:
            nn.init.xavier_uniform_(self.decoder[0].weight)
        

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
    
    def kl_sparsity_penalty(self, encoded):
        # Penalize the average absolute activation
        rho_hat = torch.mean(torch.abs(encoded), dim=0)  # Average absolute activation per hidden unit
        rho = torch.ones_like(rho_hat) * self.sparsity_target  # Target sparsity value
        epsilon = 1e-8  # Small value to avoid log(0)

        # KL-divergence computation for sparsity
        kl_divergence = rho * torch.log(rho / (rho_hat + epsilon)) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + epsilon))
        kl_divergence = torch.sum(kl_divergence)  # Sum over all hidden units

        return self.sparsity_lambda * kl_divergence

    # L1-norm sparsity penalty calculation
    def l1_sparsity_penalty(self, encoded):
        # Compute the mean of absolute values of activations
        sparsity_loss = torch.mean(torch.abs(encoded))  # Average absolute activation across all units
        return self.sparsity_lambda * sparsity_loss  # Scale by the sparsity weight

    # Loss function combining MSE (reconstruction error) and sparsity penalty
    def loss_function(self, decoded, original, encoded):
        mse_loss = F.mse_loss(decoded, original)  # Mean Squared Error for reconstruction
        sparsity_loss = self.l1_sparsity_penalty(encoded)  # Sparsity penalty for hidden layer activations
        return mse_loss + sparsity_loss  # Total loss is MSE + sparsity penalty

class ISIC_Dataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        
        self.metadata = pd.read_csv(csv_file)  # Load metadata
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        img_path = f"{self.root_dir}/{self.metadata.iloc[idx]['isic_id']}"  # Image filename
        image = Image.open(img_path).convert("RGB")  # Load image and convert to RGB
        # Extract label (benign=0, malignant=1)
        benign_malignant = int(self.metadata.iloc[idx]['benign_malignant'])
        patches = int(self.metadata.iloc[idx]['patches'])  # Assuming 1=Has patches, 0=No patches
        label = benign_malignant # For training, we only care about benign_malignant
        # Apply transforms
        if self.transform:
            image = self.transform(image)

        return image, label, benign_malignant, patches # Return transformed image and label
# 2. Transforms and Data Preparation
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load training and testing datasets for ISIC_dataset
ISIC_train_dataset = ISIC_Dataset(
    csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/train_metadata.csv',
    #csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/train_metadata_sample_data.csv',
    root_dir='/home/ahsan/test-project/fss/ISIC/ISIC_224_Dataset/isic_train',
    transform=transform
)
ISIC_test_dataset = ISIC_Dataset(
    csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/test_metadata.csv',
    #csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/test_metadata_sample_data.csv',
    root_dir='/home/ahsan/test-project/fss/ISIC/ISIC_224_Dataset/isic_test',
    transform=transform
)

ISIC_val_dataset = ISIC_Dataset(
    #csv_file='/home/ahsan/test-project/fss/ISIC/val/metadata_ISIC_test/benign_no_yes_patch_100.csv',
    csv_file='/home/ahsan/test-project/fss/ISIC/ISIC_metadata/val_metadata_all.csv',
    root_dir= '/home/ahsan/test-project/fss/ISIC/ISIC_224_Dataset/isic_224/raw_224',
    transform=transform
)
batch_size= 128
ISIC_train_loader = DataLoader(ISIC_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
ISIC_test_loader = DataLoader(ISIC_test_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
ISIC_val_loader = DataLoader(ISIC_val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
# 3. Define the ResNet Model for Binary Classification
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 2)  # Output for 2 classes
model = model.to(device)
# 4. Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4, nesterov=True)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

# 5. Training Loop
def train_model(model, train_loader, val_loader, criterion, optimizer, writer, num_epochs):
    best_val_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for images, labels, _, _ in train_loader:
            images, labels = images.to(device=device), labels.to(device=device)
            optimizer.zero_grad()
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            loss.backward()
            # Apply gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
        
        train_loss = total_loss / len(train_loader)
        train_acc = correct / total
        # Validation... 
        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for images, labels, _, _ in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_correct += (outputs.argmax(1) == labels).sum().item()
                val_total += labels.size(0)
        val_loss /= len(val_loader)
        val_acc = val_correct / val_total
        scheduler.step(val_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)
        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)
        # Save the best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'ResNet-50-full-model-without-relu_best_model.pth')
            print(f"Best model saved at epoch {epoch+1} with Val Acc: {val_acc:.4f}")

print("Training Model::")
print("Size of ISIC_train_loader : ", len(ISIC_train_loader)*batch_size)
train_model(model, ISIC_train_loader, ISIC_val_loader, criterion, optimizer, resnet_writer, num_epochs=50)
print("Training complete:: ")
torch.save(model.state_dict(), 'ResNet-50-full-model-without-relu_100_epochs.pth')
# 6. Testing for Four Classes
def test_model(model, test_loader, device):
    model.eval()
    results = {
        "Benign_NoPatches": 0, "Benign_Patches": 0, 
        "Malignant_NoPatches": 0, "Malignant_Patches": 0
    }
    counts = {
        "Benign_NoPatches": 0, "Benign_Patches": 0, 
        "Malignant_NoPatches": 0, "Malignant_Patches": 0
    }
    with torch.no_grad():
        for images, lables, benign_malignant, patches in test_loader:
            images = images.to(device)
            benign_malignant = benign_malignant.to(device).long()
            patches = patches.to(device).long()

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)  # Correct per-sample predictions

            for label, patch, pred in zip(benign_malignant, patches, predictions):
                if label == 0 and patch == 0:
                    results["Benign_NoPatches"] += (pred == 0).item()
                    counts["Benign_NoPatches"] += 1
                elif label == 0 and patch == 1:
                    results["Benign_Patches"] += (pred == 0).item()
                    counts["Benign_Patches"] += 1
                elif label == 1 and patch == 0:
                    results["Malignant_NoPatches"] += (pred == 1).item()
                    counts["Malignant_NoPatches"] += 1
                elif label == 1 and patch == 1:
                    results["Malignant_Patches"] += (pred == 1).item()
                    counts["Malignant_Patches"] += 1
            
    # Print subgroup-wise accuracy
    for key in results:
        if counts[key] > 0:
            accuracy = results[key] / counts[key]
            print(f"Accuracy for {key}: {accuracy:.4f}")
        else:
            print(f"No samples for {key}")
# Example usage
print("Testing Model::")
print("size of ISIC_test_loader : ", len(ISIC_test_loader)*batch_size)
test_model(model, ISIC_test_loader, device=device)

# Optional: Save activations (for last layer before fc)
def save_activations(model, dataloader, save_path):
    model.eval()
    activations = []
    def hook_fn(module, input, output):
        activations.append(output.detach().cpu())
    handle = model.avgpool.register_forward_hook(hook_fn)
    with torch.no_grad():
        for images, _, _, _ in dataloader:
            model(images.to(device))
    handle.remove()
    act_tensor = torch.cat(activations, dim=0).squeeze()  # shape: (N, 2048)
    np.save(save_path + ".npy", act_tensor.numpy())
    pd.DataFrame(act_tensor.numpy()).to_csv(save_path + ".csv", index=False)

# Load best model
model.load_state_dict(torch.load("/home/ahsan/test-project/fss/split-metadata/testing-resNet-model/ResNet-results-without-relu/ResNet-50-full-model/ResNet-on_ISIC/ResNet-ISIC-full-training-testing/ResNet-50-full-model-without-relu_best_model.pth", map_location=device))
model.eval()
# Save activations for test set
save_activations(model, ISIC_train_loader, "ResNet-50-full-model-without-relu-activations_100_epochs")
print("isic_train_activations save successfully!")
# SAE Train Loop
def train_sae(model, data, epochs, lr, batch_size, writer):
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    
    for epoch in range(epochs):
        epoch_loss = 0
        for batch in dataloader:
            batch = torch.stack(batch).to(device)
            optimizer.zero_grad()
            encoded, decoded = model(batch)
            loss = model.loss_function(decoded, batch, encoded)
            #loss = criterion(decoded, batch) + sparsity_loss(encoded, sparsity_lambda)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(dataloader)
        sae_writer.add_scalar("Loss/train", avg_loss, epoch)
        print(f"Epoch [{epoch+1}/{epochs}], SAE Loss: {avg_loss:.4f}")
        

# Load previously saved activations
activations_np = np.load("ResNet-50-full-model-without-relu-activations_100_epochs.npy")
activations = torch.tensor(activations_np, dtype=torch.float32)
# Wrap into a TensorDataset
activation_dataset = TensorDataset(activations)

sae = SparseAutoEncoder(input_dim=2048, hidden_dim=8000)
train_sae(sae, activation_dataset, epochs=100, lr=0.001, batch_size=128, writer=sae_writer)
torch.save(sae.state_dict(), 'ResNet-50-full-model-without-relu-SAE_100_epochs.pth')
print("Training SAE complete!!")
resnet_writer.close()
sae_writer.close()

In [2]:
model = models.resnet50(weights=None)  # Don't load pretrained weights  # Binary classification
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load('ResNet_ISIC_seed-1.pth', map_location=device))  # or 'cuda' if using GPU
model.eval()
model = model.to(device)

  model.load_state_dict(torch.load('ResNet_ISIC_seed-1.pth', map_location=device))  # or 'cuda' if using GPU


In [3]:
def save_activations(model, dataloader, save_path):
    model.eval()
    activations = []
    def hook_fn(module, input, output):
        activations.append(output.detach().cpu())
    handle = model.avgpool.register_forward_hook(hook_fn)
    with torch.no_grad():
        for images, _, _, _ in dataloader:
            model(images.to(device))
    handle.remove()
    act_tensor = torch.cat(activations, dim=0).squeeze()  # shape: (N, 2048)
    np.save(save_path + ".npy", act_tensor.numpy())
    pd.DataFrame(act_tensor.numpy()).to_csv(save_path + ".csv", index=False)

# Load best model
model.load_state_dict(torch.load("ResNet_ISIC_seed-1.pth", map_location=device))
model.eval()
# Save activations for test set
save_activations(model, ISIC_train_loader, "ResNet-50-full-model-without-relu-activations_100_epochs")
print("isic_train_activations save successfully!")
# SAE Train Loop
def train_sae(model, data, epochs, lr, batch_size, writer):
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

    dataloader = DataLoader(data, batch_size=batch_size, shuffle=True)
    
    for epoch in range(epochs):
        epoch_loss = 0
        for batch in dataloader:
            batch = torch.stack(batch).to(device)
            optimizer.zero_grad()
            encoded, decoded = model(batch)
            loss = model.loss_function(decoded, batch, encoded)
            #loss = criterion(decoded, batch) + sparsity_loss(encoded, sparsity_lambda)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(dataloader)
        sae_writer.add_scalar("Loss/train", avg_loss, epoch)
        print(f"Epoch [{epoch+1}/{epochs}], SAE Loss: {avg_loss:.4f}")
        

# Load previously saved activations
activations_np = np.load("ResNet-50-full-model-without-relu-activations_100_epochs.npy")
activations = torch.tensor(activations_np, dtype=torch.float32)
# Wrap into a TensorDataset
activation_dataset = TensorDataset(activations)

sae = SparseAutoEncoder(input_dim=2048, hidden_dim=8000)
train_sae(sae, activation_dataset, epochs=100, lr=0.001, batch_size=128, writer=sae_writer)
torch.save(sae.state_dict(), 'ResNet-50-full-model-without-relu-SAE_100_epochs.pth')
print("Training SAE complete!!")
resnet_writer.close()
sae_writer.close()

  model.load_state_dict(torch.load("ResNet_ISIC_seed-1.pth", map_location=device))


isic_train_activations save successfully!
Epoch [1/100], SAE Loss: 0.0472
Epoch [2/100], SAE Loss: 0.0408
Epoch [3/100], SAE Loss: 0.0401
Epoch [4/100], SAE Loss: 0.0397
Epoch [5/100], SAE Loss: 0.0395
Epoch [6/100], SAE Loss: 0.0394
Epoch [7/100], SAE Loss: 0.0393
Epoch [8/100], SAE Loss: 0.0393
Epoch [9/100], SAE Loss: 0.0393
Epoch [10/100], SAE Loss: 0.0392
Epoch [11/100], SAE Loss: 0.0393
Epoch [12/100], SAE Loss: 0.0392
Epoch [13/100], SAE Loss: 0.0392
Epoch [14/100], SAE Loss: 0.0392
Epoch [15/100], SAE Loss: 0.0392
Epoch [16/100], SAE Loss: 0.0393
Epoch [17/100], SAE Loss: 0.0392
Epoch [18/100], SAE Loss: 0.0392
Epoch [19/100], SAE Loss: 0.0392
Epoch [20/100], SAE Loss: 0.0392
Epoch [21/100], SAE Loss: 0.0392
Epoch [22/100], SAE Loss: 0.0392
Epoch [23/100], SAE Loss: 0.0392
Epoch [24/100], SAE Loss: 0.0392
Epoch [25/100], SAE Loss: 0.0392
Epoch [26/100], SAE Loss: 0.0392
Epoch [27/100], SAE Loss: 0.0392
Epoch [28/100], SAE Loss: 0.0392
Epoch [29/100], SAE Loss: 0.0392
Epoch [30/

In [None]:

#test
def test_model(model, dataloader):
    results = {
        "Benign_NoPatches": 0, "Benign_Patches": 0,
        "Malignant_NoPatches": 0, "Malignant_Patches": 0
    }
    counts = {
        "Benign_NoPatches": 0, "Benign_Patches": 0,
        "Malignant_NoPatches": 0, "Malignant_Patches": 0
    }

    with torch.no_grad():
        for images, labels, benign_malignant, patches in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)

            for i in range(labels.size(0)):
                label = int(benign_malignant[i])
                patch = int(patches[i])
                pred = int(predictions[i])

                key = f"{'Benign' if label == 0 else 'Malignant'}_{'NoPatches' if patch == 0 else 'Patches'}"
                correct_class = (pred == label)
                results[key] += int(correct_class)
                counts[key] += 1

    # Print subgroup-wise accuracy
    for key in results:
        if counts[key] > 0:
            accuracy = results[key] / counts[key]
            print(f"Accuracy for {key}: {accuracy:.4f}")
        else:
            print(f"No samples for {key}")
            
print("Evaluating model on ISIC test dataset...")
test_model(model, ISIC_test_loader)


Evaluating model on ISIC test dataset...
Accuracy for Benign_NoPatches: 0.8404
Accuracy for Benign_Patches: 0.9946
Accuracy for Malignant_NoPatches: 0.5389
Accuracy for Malignant_Patches: 0.2765


In [None]:
def test_model(model, test_loader, device):
    model.eval()
    results = {
        "Benign_NoPatches": 0, "Benign_Patches": 0, 
        "Malignant_NoPatches": 0, "Malignant_Patches": 0
    }
    counts = {
        "Benign_NoPatches": 0, "Benign_Patches": 0, 
        "Malignant_NoPatches": 0, "Malignant_Patches": 0
    }
    with torch.no_grad():
        for images, lables, benign_malignant, patches in test_loader:
            images = images.to(device)
            benign_malignant = benign_malignant.to(device).long()
            patches = patches.to(device).long()

            outputs = model(images)
            predictions = torch.argmax(outputs, dim=1)  # Correct per-sample predictions

            for label, patch, pred in zip(benign_malignant, patches, predictions):
                if label == 0 and patch == 0:
                    results["Benign_NoPatches"] += (pred == 0).item()
                    counts["Benign_NoPatches"] += 1
                elif label == 0 and patch == 1:
                    results["Benign_Patches"] += (pred == 0).item()
                    counts["Benign_Patches"] += 1
                elif label == 1 and patch == 0:
                    results["Malignant_NoPatches"] += (pred == 1).item()
                    counts["Malignant_NoPatches"] += 1
                elif label == 1 and patch == 1:
                    results["Malignant_Patches"] += (pred == 1).item()
                    counts["Malignant_Patches"] += 1
            
    # Print subgroup-wise accuracy
    for key in results:
        if counts[key] > 0:
            accuracy = results[key] / counts[key]
            print(f"Accuracy for {key}: {accuracy:.4f}")
        else:
            print(f"No samples for {key}")
# Example usage
print("Testing Model::")
print("size of ISIC_test_loader : ", len(ISIC_test_loader)*batch_size)
test_model(model, ISIC_test_loader, device=device)

Testing Model::
size of ISIC_test_loader :  968192
Accuracy for Benign_NoPatches: 0.7128
Accuracy for Benign_Patches: 0.9989
Accuracy for Malignant_NoPatches: 0.7494
Accuracy for Malignant_Patches: 0.2850
