In [1]:
# Builtin
import os
from pathlib import Path
from PIL import Image

# Third Party
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader

# Attacks
from attacks.fgsm import fgsm_l2_attack, fgsm_l2_attack_dct
from attacks.pgd import pgd_l2_attack, pgd_l2_attack_dct
from attacks.simba import simba_attack
from attacks.one_pixel import one_pixel_attack

# Models
from models.clip import CLIPImageEmbedder

%load_ext autoreload
%autoreload 2

2025-11-26 19:31:16.901630: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-26 19:31:17.388001: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-26 19:31:18.733359: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
class HAM10000PILDataset(Dataset):
    """HAM10000 images selected by file-name prefix, resized and cached on disk.

    On construction, every file below ``root_dir`` whose lower-cased name ends
    with one of ``extensions`` and whose name starts with one of ``ids`` is
    opened, converted to RGB, resized to 224x224 (bilinear) and written to
    ``save_dir`` under its original base name. Items are afterwards read back
    from those resized copies.

    Args:
        root_dir: Directory that is searched recursively for source images.
        save_dir: Directory that receives the resized copies. It is created
            if it does not exist yet.
        ids: A single prefix string or an iterable of prefix strings.
        extensions: A single suffix string or an iterable of suffix strings.

    Raises:
        ValueError: If no file below ``root_dir`` matches.

    Note:
        Files that share a base name in different sub-directories are written
        to the same target file; the last one processed wins.
    """

    def __init__(self, root_dir, save_dir, ids, extensions=(".jpg", ".jpeg", ".png", ".bmp")):

        self.root_dir = root_dir
        self.extensions = extensions

        # str.startswith / str.endswith only accept a str or a tuple of str,
        # so lists and other iterables are normalised to tuples first.
        prefixes = ids if isinstance(ids, str) else tuple(ids)
        suffixes = extensions if isinstance(extensions, str) else tuple(extensions)

        # os.walk yields entries in filesystem order; sorting makes the item
        # order (and therefore shuffle=False iteration) reproducible.
        source_paths = sorted(
            os.path.join(root, fname)
            for root, _, files in os.walk(root_dir)
            for fname in files
            if fname.lower().endswith(suffixes) and fname.startswith(prefixes)
        )

        if not source_paths:
            raise ValueError(f"No images found in {root_dir} with extensions {extensions}")

        os.makedirs(save_dir, exist_ok=True)

        self.image_paths = []
        for path in source_paths:
            new_path = os.path.join(save_dir, os.path.basename(path))

            # The context manager closes the source file deterministically.
            with Image.open(path) as source:
                resized = source.convert("RGB").resize((224, 224), Image.BILINEAR)

            # Saved after the source is closed so an identical source/target
            # path cannot collide with a still-open handle.
            resized.save(new_path)
            self.image_paths.append(new_path)

    def __len__(self):
        """Return the number of cached images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, path)`` for the cached image at position ``idx``.

        ``image`` is a float tensor in channel-first layout (C, H, W) with
        values scaled to [0, 1]; ``path`` is the location of the cached file.
        """
        path = self.image_paths[idx]
        with Image.open(path) as source:
            pixels = np.array(source.convert("RGB"))
        img = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
        return (img, path)

def get_image_dataloader_ham(root_dir, save_dir, ids, batch_size=8, num_workers=4, shuffle=False):
    """Build a ``DataLoader`` over a ``HAM10000PILDataset``.

    Both ``root_dir`` and ``save_dir`` must already be existing directories;
    otherwise an ``AssertionError`` is raised before any image is touched.
    The remaining arguments are forwarded to the dataset (``ids``) and to the
    ``DataLoader`` (``batch_size``, ``num_workers``, ``shuffle``).
    """
    for required_dir in (root_dir, save_dir):
        assert Path(required_dir).is_dir()

    loader_options = dict(batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
    return DataLoader(HAM10000PILDataset(root_dir, save_dir, ids), **loader_options)

In [3]:
class DMFPILDataset(Dataset):
    """DMF images selected by file stem, resized and cached on disk.

    On construction, every file below ``root_dir`` whose lower-cased name ends
    with one of ``extensions`` and whose stem (the part of the file name before
    the first ``"."``) is contained in ``ids`` is opened, converted to RGB,
    resized to 224x224 (bilinear) and written to ``save_dir`` under its
    original base name. Items are afterwards read back from those copies.

    Args:
        root_dir: Directory that is searched recursively for source images.
        save_dir: Directory that receives the resized copies. It is created
            if it does not exist yet.
        ids: Iterable of accepted file stems. A plain string is used as-is,
            i.e. a stem matches when it is a substring of that string.
        extensions: A single suffix string or an iterable of suffix strings.

    Raises:
        ValueError: If no file below ``root_dir`` matches.

    Note:
        The number of matching files is printed during construction. Files
        that share a base name in different sub-directories are written to the
        same target file; the last one processed wins.
    """

    def __init__(self, root_dir, save_dir, ids, extensions=(".jpg", ".jpeg", ".png", ".bmp")):

        self.root_dir = root_dir
        self.extensions = extensions

        suffixes = extensions if isinstance(extensions, str) else tuple(extensions)

        # A frozenset gives constant-time lookups and always tests the *values*
        # of the iterable (``in`` on a pandas Series would test its index).
        wanted = ids if isinstance(ids, str) else frozenset(ids)

        # os.walk yields entries in filesystem order; sorting makes the item
        # order (and therefore shuffle=False iteration) reproducible.
        source_paths = sorted(
            os.path.join(root, fname)
            for root, _, files in os.walk(root_dir)
            for fname in files
            if fname.lower().endswith(suffixes) and (fname.split(".")[0] in wanted)
        )

        print(len(source_paths))

        if not source_paths:
            raise ValueError(f"No images found in {root_dir} with extensions {extensions}")

        os.makedirs(save_dir, exist_ok=True)

        self.image_paths = []
        for path in source_paths:
            new_path = os.path.join(save_dir, os.path.basename(path))

            # The context manager closes the source file deterministically.
            with Image.open(path) as source:
                resized = source.convert("RGB").resize((224, 224), Image.BILINEAR)

            # Saved after the source is closed so an identical source/target
            # path cannot collide with a still-open handle.
            resized.save(new_path)
            self.image_paths.append(new_path)

    def __len__(self):
        """Return the number of cached images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, path)`` for the cached image at position ``idx``.

        ``image`` is a float tensor in channel-first layout (C, H, W) with
        values scaled to [0, 1]; ``path`` is the location of the cached file.
        """
        path = self.image_paths[idx]
        with Image.open(path) as source:
            pixels = np.array(source.convert("RGB"))
        img = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
        return (img, path)

def get_image_dataloader_dmf(root_dir, save_dir, ids, batch_size=8, num_workers=4, shuffle=False):
    """Build a ``DataLoader`` over a ``DMFPILDataset``.

    ``root_dir`` and ``save_dir`` are checked (in that order) to be existing
    directories; an ``AssertionError`` is raised otherwise. ``ids`` goes to the
    dataset, the remaining options go to the ``DataLoader``.
    """
    source_dir, cache_dir = Path(root_dir), Path(save_dir)
    assert source_dir.is_dir()
    assert cache_dir.is_dir()

    dmf_dataset = DMFPILDataset(root_dir, save_dir, ids)
    return DataLoader(
        dmf_dataset,
        shuffle=shuffle,
        num_workers=num_workers,
        batch_size=batch_size,
    )

In [None]:
class DERM7PTPILDataset(Dataset):
    """Derm7pt images selected through a metadata table, resized and cached.

    For every value in ``test_ids['image_id']`` the file name stored in the
    ``'derm'`` column of ``meta`` at that *positional* row (``iloc``) is looked
    up below ``<root_dir>/images``. Each image is converted to RGB, resized to
    224x224 (bilinear) and written to ``save_dir`` under its base name. Items
    are afterwards read back from those copies, in the order of ``test_ids``.

    Args:
        root_dir: Dataset root; images are read from its ``images`` sub-folder.
        save_dir: Directory that receives the resized copies. It is created
            if it does not exist yet.
        test_ids: Table with an ``'image_id'`` column of row positions in
            ``meta``.
            NOTE(review): the ids are used positionally, not as index labels —
            confirm this matches how the id file was produced.
        meta: Table with a ``'derm'`` column of image paths relative to the
            ``images`` folder.
        extensions: Stored on the instance only; it is not used to filter
            files in this dataset.

    Raises:
        ValueError: If ``test_ids['image_id']`` is empty.
    """

    def __init__(self, root_dir, save_dir, test_ids, meta, extensions=(".jpg", ".jpeg", ".png", ".bmp")):

        self.root_dir = root_dir
        self.extensions = extensions

        image_dir = os.path.join(root_dir, "images")
        # Select the column once instead of materialising a full row per id.
        derm_files = meta['derm']
        source_paths = [
            os.path.join(image_dir, derm_files.iloc[case_num])
            for case_num in test_ids['image_id']
        ]

        if not source_paths:
            # Extensions play no role in the selection here, so the message
            # names the actual cause.
            raise ValueError(f"No images selected from {root_dir}: test_ids['image_id'] is empty")

        os.makedirs(save_dir, exist_ok=True)

        self.image_paths = []
        for path in source_paths:
            new_path = os.path.join(save_dir, os.path.basename(path))

            # The context manager closes the source file deterministically.
            with Image.open(path) as source:
                resized = source.convert("RGB").resize((224, 224), Image.BILINEAR)

            # Saved after the source is closed so an identical source/target
            # path cannot collide with a still-open handle.
            resized.save(new_path)
            self.image_paths.append(new_path)

    def __len__(self):
        """Return the number of cached images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, path)`` for the cached image at position ``idx``.

        ``image`` is a float tensor in channel-first layout (C, H, W) with
        values scaled to [0, 1]; ``path`` is the location of the cached file.
        """
        path = self.image_paths[idx]
        with Image.open(path) as source:
            pixels = np.array(source.convert("RGB"))
        img = torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0
        return (img, path)

def get_image_dataloader_d7p(root_dir, save_dir, ids, meta, batch_size=8, num_workers=4, shuffle=False):
    """Build a ``DataLoader`` over a ``DERM7PTPILDataset``.

    ``root_dir`` and ``save_dir`` must be existing directories (checked in that
    order with ``assert``). ``ids`` and ``meta`` are handed to the dataset as
    its ``test_ids`` and ``meta`` arguments; the remaining options configure
    the ``DataLoader``.
    """
    # Fail early, before any image is read or written.
    assert Path(root_dir).is_dir()
    assert Path(save_dir).is_dir()

    return DataLoader(
        DERM7PTPILDataset(root_dir, save_dir, ids, meta),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )

In [5]:
def save_adv_batch(batch, save_dir="datasets/adv_simba/busi", pth=None):
    """
    Save every image tensor in ``batch`` as an image file inside ``save_dir``.

    Args:
        batch: Iterable of image tensors (e.g. a ``[B, C, H, W]`` tensor or a
            list of ``[C, H, W]`` / ``[1, C, H, W]`` tensors). Float tensors
            are expected in the [0, 1] range, ``uint8`` tensors in [0, 255].
        save_dir: Target directory; created if it does not exist.
        pth: Iterable with one path per image. Only the base name is used, so
            the output keeps the source file's name and format. When ``None``,
            files are named ``NNNNNN.bmp``, numbered upwards from the number of
            entries already present in ``save_dir`` so numbering continues
            across calls.

    Raises:
        ValueError: If ``pth`` is given but its length differs from the number
            of images (silently dropping images would lose results).
    """

    os.makedirs(save_dir, exist_ok=True)

    images = list(batch)

    if pth is None:
        first_index = len(os.listdir(save_dir))
        file_names = [f"{first_index + offset:06d}.bmp" for offset in range(len(images))]
    else:
        file_names = [os.path.basename(p) for p in pth]
        if len(file_names) != len(images):
            raise ValueError(
                f"Got {len(images)} images but {len(file_names)} paths; they must match"
            )

    for img, file_name in zip(images, file_names):
        img = img.detach().cpu()

        # Remove batch dimension if present: [1, C, H, W] -> [C, H, W]
        if img.dim() == 4:
            img = img.squeeze(0)

        # If image is [C, H, W], convert to [H, W, C]
        if img.dim() == 3:
            img = img.permute(1, 2, 0)
            # A trailing single channel is dropped so the array is a plain
            # [H, W] grayscale image.
            if img.shape[-1] == 1:
                img = img.squeeze(-1)

        # Scale to 0-255 and convert to uint8 if needed. Rounding (instead of
        # truncating) keeps the quantisation error within half a grey level.
        if img.dtype != torch.uint8:
            img = (img * 255).round().clamp(0, 255).byte()

        img_pil = Image.fromarray(img.numpy())
        img_pil.save(os.path.join(save_dir, file_name))

In [6]:
# GLOBAL VARIABLES

# Directory holding the per-dataset id CSVs (ham_ids.csv, d7p_ids.csv, dmf_ids.csv).
TEST_IDS_PATH = "test_sets/test_ids"
# Parent directory for the resized test-image copies (one sub-folder per dataset).
TEST_IMAGES_PATH = "test_sets/test_images"

# Root directories of the raw datasets, passed as root_dir to the loaders.
HAM_DATASET = "datasets/HAM"
DERM7PT_DATASET = "datasets/DERM7PT"
DMF_DATASET = "datasets/DMF"

# Passed to the attacks as their epsilon (the SimBA call also uses it as step_size).
EPSILON = 0.05
# NOTE(review): not referenced by any cell visible in this notebook — confirm it is still needed.
TEMPERATURE = 0.5

# Output-directory templates; "{dataset}" is filled via str.format(dataset=...)
# and the result is used as save_dir for save_adv_batch.
FGSM_L2 = "results/{dataset}/FGSM"
FGSM_DCT_L2 = "results/{dataset}/FGSM_DCT"
FGSM_L2_TEST = "results/{dataset}/FGSM_TEST"

PGD_L2 = "results/{dataset}/PGD"
PGD_DCT_L2 = "results/{dataset}/PGD_DCT"

ONE_PIXEL_L2 = "results/{dataset}/ONE_PIXEL"
SIMBA_L2 = "results/{dataset}/SIMBA"

In [7]:
# Configure PyTorch's CUDA caching allocator through its environment variable.
# NOTE(review): presumably this must be set before CUDA is first initialised
# to have any effect — confirm none of the imports above touch the GPU.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Image embedder handed to the attacks below; runs on the GPU when one is available, otherwise on the CPU.
model = CLIPImageEmbedder(device= 'cuda' if torch.cuda.is_available() else 'cpu')


In [8]:
# HAM Dataset
# Each CSV row is reduced to its single value; the resulting tuple of ids is
# what the HAM dataset uses to select files by name prefix.
ham_test_ids = tuple(
    row.item()
    for row in pd.read_csv(os.path.join(TEST_IDS_PATH, "ham_ids.csv")).to_numpy()
)
ham_loader = get_image_dataloader_ham(
    HAM_DATASET,
    os.path.join(TEST_IMAGES_PATH, "ham"),
    ham_test_ids,
    batch_size=1,
)


# DERM7PT Dataset
# Ids and metadata stay DataFrames; the dataset resolves file names from them.
derm7pt_ids = pd.read_csv(os.path.join(TEST_IDS_PATH, "d7p_ids.csv"))
derm7pt_meta = pd.read_csv(os.path.join(DERM7PT_DATASET, "meta/meta.csv"))
derm7pt_loader = get_image_dataloader_d7p(
    DERM7PT_DATASET,
    os.path.join(TEST_IMAGES_PATH, "derm7pt"),
    ids=derm7pt_ids,
    meta=derm7pt_meta,
    batch_size=1,
)

# DMF Dataset
# Same flattening as for HAM; here the ids are matched against file stems.
dmf_test_ids = tuple(
    row.item()
    for row in pd.read_csv(os.path.join(TEST_IDS_PATH, "dmf_ids.csv")).to_numpy()
)
dmf_loader = get_image_dataloader_dmf(
    DMF_DATASET,
    os.path.join(TEST_IMAGES_PATH, "dmf"),
    ids=dmf_test_ids,
    batch_size=1,
)

243


In [None]:
# Bare expression: when run as a notebook cell it shows the repr of the dataset wrapped by the DMF loader.
dmf_loader.dataset

In [9]:
# Iterates over every Derm7pt batch. All attacks are currently disabled, so the
# loop only pulls each batch through the loader and discards it.
for batch in derm7pt_loader:

    pass

    # Disabled attack variants: uncomment one pair to generate adversarial
    # images and write them to the matching results/derm7pt/<ATTACK> folder.

    # result, pth = simba_attack(batch, model, epsilon=EPSILON, step_size=EPSILON)
    # save_adv_batch(result, SIMBA_L2.format(dataset="derm7pt"), pth=pth)

    # result, pth = pgd_l2_attack(model, batch, EPSILON, EPSILON/4, 200)
    # save_adv_batch(result, PGD_L2.format(dataset="derm7pt"), pth=pth)

    # result, pth = pgd_l2_attack_dct(model, batch, EPSILON, EPSILON/4, 200)
    # save_adv_batch(result, PGD_DCT_L2.format(dataset="derm7pt"), pth=pth)

    # result, pth = fgsm_l2_attack(model, x=batch, epsilon=EPSILON)
    # save_adv_batch(result, FGSM_L2.format(dataset="derm7pt"), pth=pth)

    # result, pth = fgsm_l2_attack_dct(model, x=batch, epsilon=EPSILON)
    # save_adv_batch(result, FGSM_DCT_L2.format(dataset="derm7pt"), pth=pth)


In [None]:
# Resumable pass over the DMF batches: batches whose SimBA outputs are already
# on disk are skipped. All attacks are currently disabled below.
for batch in dmf_loader:

    # batch[1] holds the image paths of the batch. A batch counts as done only
    # when *every* one of its images has an output file; treating a single
    # existing file as "done" would leave partially written batches incomplete
    # for any batch_size > 1.
    batch_exist = all(
        os.path.exists(os.path.join(SIMBA_L2.format(dataset="dmf"), os.path.basename(pth)))
        for pth in batch[1]
    )

    if batch_exist:
        print("skipped")
        continue

    # Disabled attack variants: uncomment one pair to generate adversarial
    # images and write them to the matching results/dmf/<ATTACK> folder.

    # result, pth = simba_attack(batch, model, epsilon=EPSILON, step_size=EPSILON)
    # save_adv_batch(result, SIMBA_L2.format(dataset="dmf"), pth=pth)

    # result, pth = pgd_l2_attack(model, batch, EPSILON, EPSILON/4, 200)
    # save_adv_batch(result, PGD_L2.format(dataset="dmf"), pth=pth)

    # result, pth = pgd_l2_attack_dct(model, batch, EPSILON, EPSILON/4, 200)
    # save_adv_batch(result, PGD_DCT_L2.format(dataset="dmf"), pth=pth)

    # result, pth = fgsm_l2_attack(model, x=batch, epsilon=EPSILON)
    # save_adv_batch(result, FGSM_L2.format(dataset="dmf"), pth=pth)

    # result, pth = fgsm_l2_attack_dct(model, x=batch, epsilon=EPSILON)
    # save_adv_batch(result, FGSM_DCT_L2.format(dataset="dmf"), pth=pth)

skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped


In [None]:
# Resumable SimBA run over the HAM batches: batches whose outputs are already
# on disk are skipped, the rest are attacked and saved.
for batch in ham_loader:

    # batch[1] holds the image paths of the batch. A batch counts as done only
    # when *every* one of its images has an output file; treating a single
    # existing file as "done" would leave partially written batches incomplete
    # for any batch_size > 1.
    batch_exist = all(
        os.path.exists(os.path.join(SIMBA_L2.format(dataset="ham"), os.path.basename(pth)))
        for pth in batch[1]
    )

    if batch_exist:
        print("skipped")
        continue

    result, pth = simba_attack(batch, model, epsilon=EPSILON, step_size=EPSILON)
    save_adv_batch(result, SIMBA_L2.format(dataset="ham"), pth=pth)


skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
skipped
