In [5]:
from sklearn.metrics import accuracy_score
from torchvision import models
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, transforms, Normalize, InterpolationMode
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from PIL import Image
# Remove Pillow's maximum-pixel (decompression-bomb) limit so that very large
# images in the datasets can be opened without raising.
# NOTE(review): this disables a safety check; it is presumably acceptable only
# because the datasets are local and trusted — confirm.
Image.MAX_IMAGE_PIXELS = None

# Run on the GPU when CUDA is available; otherwise fall back to the CPU so the
# script still works (slowly) on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class CustomImageFolder(ImageFolder):
    """``ImageFolder`` subclass that yields ``(sample, target)`` pairs unchanged.

    Construction forwards ``root`` and ``transform`` to ``ImageFolder``, and
    indexing returns the parent's sample/target pair as-is. The class adds no
    behavior of its own; it is presumably kept as an extension point for
    custom loading.
    """

    def __init__(self, root, transform=None):
        # Only ``root`` and ``transform`` are exposed by this wrapper.
        super().__init__(root, transform=transform)

    def __getitem__(self, index):
        """Return the ``(sample, target)`` pair stored at ``index``."""
        item = super().__getitem__(index)
        # Unpacking enforces that the parent returned exactly two elements.
        sample, target = item
        return (sample, target)

def _to_rgb(img):
    """Convert a PIL image to 3-channel ``'RGB'`` mode.

    Images whose bands include ``'P'`` (palette images) are first converted to
    ``'RGBA'``. Every image is then converted to ``'RGB'``, which drops any
    alpha channel.
    """
    if 'P' in img.getbands():
        img = img.convert('RGBA')
    return img.convert('RGB')


# Evaluation preprocessing pipeline:
#   force RGB -> resize to 232 (bilinear) -> center-crop 224 -> tensor ->
#   per-channel normalization with the standard ImageNet mean/std values.
transform = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize(232, interpolation=InterpolationMode.BILINEAR),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def test_model(model, test_loader, device=None):
    """Evaluate a classifier and return its top-1 accuracy as a percentage.

    The function puts ``model`` into eval mode and leaves it there. Every batch
    from ``test_loader`` is run under ``torch.no_grad()``, and the function
    counts how often the arg-max over dimension 1 of the model output equals
    the label.

    Args:
        model: Module called as ``model(inputs)``. Its output is assumed to be
            ``(batch, num_classes)`` scores — TODO confirm for other models.
        test_loader: Iterable of ``(inputs, labels)`` tensor batches.
        device: Device that each batch is moved to. Defaults to the device of
            the model's first parameter (CPU for a parameterless model), so
            batches always land next to the weights.

    Returns:
        float: Accuracy in percent (0-100).

    Raises:
        ValueError: If ``test_loader`` yields no samples. Accuracy is
            undefined in that case, and raising avoids a bare division by zero.
    """
    model.eval()  # disable dropout / use running batch-norm statistics

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():  # no autograd graph is needed for evaluation
        for inputs, labels in tqdm(test_loader, desc="Testing"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total_predictions += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    if total_predictions == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")

    return (correct_predictions / total_predictions) * 100

# model_paths = ['./unsafe_models_CB/best_model.pth', './unsafe_models_NSFW/best_model.pth', './unsafe_models_SH/best_model.pth']
# test_paths = ['/workspace/adv_robustness/CSE/datasets/cyberbullying/',
#               '/workspace/adv_robustness/CSE/datasets/nsfw/',
#               '/workspace/adv_robustness/CSE/datasets/self_harm/']

# Checkpoints to evaluate, paired index-by-index with the dataset roots below.
model_paths = ['./unsafe_models_CB/best_model.pth', './unsafe_models_NSFW/best_model_masked.pth', './unsafe_models_SH/best_model_masked.pth']
test_paths = ['/workspace/adv_robustness/CSE/datasets/cyberbullying/',
              '/workspace/adv_robustness/CSE/datasets/nsfw/',
              '/workspace/adv_robustness/CSE/datasets/self_harm/']

# zip() below would silently drop unmatched entries; fail loudly instead.
if len(model_paths) != len(test_paths):
    raise ValueError(
        f"model_paths ({len(model_paths)}) and test_paths ({len(test_paths)}) "
        "must have the same length"
    )


for model_path, test_path in zip(model_paths, test_paths):
    # Rebuild the architecture the checkpoint must match: a ResNet-50 whose
    # final fully connected layer is replaced by a 2-way linear head.
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, 2)
    # Deserialize onto the CPU first. This works regardless of which device
    # the checkpoint was saved from, and it avoids a GPU round-trip before
    # load_state_dict copies the weights into the (CPU) model.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)

    # Dataset rooted at test_path, using the evaluation preprocessing pipeline.
    dataset = CustomImageFolder(root=test_path, transform=transform)

    # 80/10/10 train/valid/test sizes; the test slice absorbs rounding.
    total_size = len(dataset)
    train_size = int(total_size * 0.8)
    valid_size = int(total_size * 0.1)
    test_size = total_size - train_size - valid_size

    # Reseed the global RNG right before each split so the partition is
    # deterministic.
    # NOTE(review): the test slice is only truly held out if training used the
    # same sizes, seed and split call with no RNG use in between — confirm
    # against the training script.
    torch.manual_seed(0)
    _, _, test_dataset = random_split(dataset, [train_size, valid_size, test_size])

    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # test_model returns accuracy as a percentage.
    accuracy = test_model(model, test_loader)
    print(f'Accuracy of model {model_path} on test data: {accuracy}')
    
    
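# Earlier results for the unmasked checkpoints (the commented-out paths above),
# reported as fractions rather than the percentages printed by the current code:
# Accuracy of model ./unsafe_models_CB/best_model.pth on test data: 0.9194461925739459
# Accuracy of model ./unsafe_models_NSFW/best_model.pth on test data: 0.9889080459770115
# Accuracy of model ./unsafe_models_SH/best_model.pth on test data: 0.9757575757575757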


Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:27<00:00,  1.12s/it]


Accuracy of model ./unsafe_models_CB/best_model.pth on test data: 91.94461925739459


Testing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:55<00:00,  4.88it/s]


Accuracy of model ./unsafe_models_NSFW/best_model_masked.pth on test data: 97.82183908045977


Testing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:07<00:00,  1.05it/s]

Accuracy of model ./unsafe_models_SH/best_model_masked.pth on test data: 94.14141414141413



