In [None]:
import logging
import sys

import torch


LOGGER = logging.getLogger('detector')
LOGGER.setLevel(logging.INFO)

# Re-running this cell used to attach one more stdout handler per run, so every
# message was printed once per handler. Drop whatever a previous run attached.
for _old_handler in list(LOGGER.handlers):
    LOGGER.removeHandler(_old_handler)

stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setLevel(logging.INFO)
LOGGER.addHandler(stream_handler)

# The recorded output shows each message twice ("device is cuda" and
# "INFO:detector:device is cuda"). The second copy presumably comes from a
# root-logger handler receiving propagated records. This logger has its own
# handler, so stop propagation.
LOGGER.propagate = False

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LOGGER.info(f'device is {DEVICE}')

device is cuda


INFO:detector:device is cuda


In [None]:
import torch.nn as nn
import torch.nn.functional as F



class PreActBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice, plus a skip path.

    The skip path is the identity when shapes already match. Otherwise it is a
    strided 1x1 convolution applied to the *pre-activated* input.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        out_planes = self.expansion * planes
        # Creation order is kept (bn1, conv1, bn2, conv2, shortcut) so parameter
        # initialisation consumes the RNG exactly as before.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        # Optional channel index applied to the shortcut before the residual add.
        # It is None here; presumably it is assigned from outside. Nothing in this
        # notebook sets it.
        self.ind = None

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            # Only defined when needed; forward() checks for it with hasattr().
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        skip = self.shortcut(pre) if hasattr(self, "shortcut") else x
        residual = self.conv2(F.relu(self.bn2(self.conv1(pre))))
        if self.ind is None:
            residual += skip
        else:
            residual += skip[:, self.ind, :, :]
        return residual

class PreActResNet(nn.Module):
    """Pre-activation ResNet for 3-channel inputs: a stem conv, four residual stages,
    global average pooling and a linear classifier.

    Args:
        block: residual block class. It must expose a class attribute ``expansion``
            and be constructible as ``block(in_planes, planes, stride)``.
        num_blocks: four ints giving the number of blocks in each stage.
        num_classes: output size of the final linear layer.

    Device placement is the caller's job (e.g. ``model.to(DEVICE)``).
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(PreActResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage. Only its first block uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            # Blocks are no longer moved with a hard-coded `.to('cuda')`. That crashed
            # on CPU-only machines and left the stem/classifier on a different device.
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return self.linear(out)


def PreActResNet18(num_classes=10):
    """Build a PreActResNet-18: four stages of two ``PreActBlock`` each."""
    blocks_per_stage = [2] * 4
    return PreActResNet(PreActBlock, blocks_per_stage, num_classes=num_classes)

In [None]:
import glob
import os


# Root of the evaluation set; load_test(idx) reads files from ROOT_DIR/<idx>/.
ROOT_DIR = 'eval_dataset'

def load_model(num_classes, model_path):
    """Build a PreActResNet-18, load the weights at ``model_path``, move it to
    ``DEVICE`` and switch it to eval mode.

    Args:
        num_classes: number of output classes of the checkpoint.
        model_path: path to a saved ``state_dict``.

    Returns:
        The ready-to-use model.

    NOTE: ``torch.load`` unpickles the file. Only load trusted checkpoints.
    """
    model = PreActResNet18(num_classes)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)
    model.eval()

    return model


def load_test(idx: int):
    """Load evaluation case ``idx`` from ``ROOT_DIR/<idx>``.

    Reads ``metadata.pt`` (keys: ``num_classes``, ``ground_truth``,
    ``test_images_folder_address``, ``transformation``) and ``model.pt``.

    Returns:
        ``(model, num_classes, ground_truth, transformation, images_root_dir)``.
        ``images_root_dir`` is resolved relative to the test directory.
    """
    test_root_dir = os.path.join(ROOT_DIR, str(idx))

    # NOTE(review): metadata presumably holds a pickled transform object, so this
    # torch.load unpickles arbitrary objects. Use trusted data only.
    metadata = torch.load(os.path.join(test_root_dir, 'metadata.pt'))

    num_classes = metadata['num_classes']
    ground_truth = metadata['ground_truth']
    images_root_dir = metadata['test_images_folder_address']
    transformation = metadata['transformation']

    model_path = os.path.join(test_root_dir, 'model.pt')

    # Strip only a literal "./" (or ".\") prefix. The old `[0] == '.'` test also
    # turned "../x" into "/x", an absolute path that os.path.join keeps as-is.
    # It also raised IndexError on an empty string.
    if images_root_dir.startswith(('./', '.\\')):
        images_root_dir = images_root_dir[2:]

    images_root_dir = os.path.join(test_root_dir, images_root_dir)

    model = load_model(num_classes, model_path)

    return model, num_classes, ground_truth, transformation, images_root_dir


In [None]:
from typing import List, Tuple

from torchvision import transforms
from PIL import Image


def transform_images(images_path: List[str], transformation: transforms.Compose):
    """Load each image as RGB, apply ``transformation`` and stack the results on ``DEVICE``.

    Images that fail to open or transform are logged and skipped.

    Args:
        images_path: paths of the image files.
        transformation: callable applied to each PIL image. It must return
            same-shaped tensors so they can be stacked.

    Returns:
        A tensor of the stacked transformed images, on ``DEVICE``.

    Raises:
        RuntimeError: if no image could be loaded. This is the same exception type
            ``torch.stack`` raised before, but now with a clear message.
    """
    transformed_images = []
    for img_path in images_path:
        try:
            # `with` closes the file handle. convert() yields an independent image,
            # so nothing still needs the file afterwards.
            with Image.open(img_path) as raw:
                image = raw.convert('RGB')
            transformed_images.append(transformation(image))
        except Exception as e:
            LOGGER.error(f"Error loading image {img_path}: {e}")

    if not transformed_images:
        LOGGER.error("No images were loaded. Please check the images_path list.")
        raise RuntimeError("transform_images: none of the given images could be loaded.")

    return torch.stack(transformed_images).to(DEVICE)



In [None]:

def extract_normalization_params(transformation: transforms.Compose):
    """Return ``(mean, std)`` of the last ``transforms.Normalize`` in the pipeline.

    Returns ``(None, None)`` when the pipeline contains no ``Normalize`` step.
    """
    found = (None, None)
    for step in transformation.transforms:
        if not isinstance(step, transforms.Normalize):
            continue
        # Later Normalize steps overwrite earlier ones.
        found = (step.mean, step.std)
    return found


def get_logits_and_probs(model: PreActResNet, transformed_images: torch.Tensor):
    """Run ``model`` on a batch and return ``(probabilities, logits)``.

    The probabilities come first, despite the function name. Softmax is taken
    over dim 1, presumably the class axis of a (batch, num_classes) output.
    """
    raw_scores = model(transformed_images)
    return F.softmax(raw_scores, dim=1), raw_scores


In [None]:
from collections import defaultdict

def calculate_margins(probs: List[torch.Tensor],
                      labels: List[int]) -> Tuple[defaultdict[int, List[float]],
                                                           defaultdict[int, List[float]]]:
    """Split top-1 minus top-2 probability margins by whether the prediction was right.

    Args:
        probs: one 1-D probability tensor per sample (at least two entries each).
        labels: true label of each sample, indexed in parallel with ``probs``.

    Returns:
        ``(accepted, failed)``. ``accepted[label]`` collects margins of correctly
        classified samples. ``failed[predicted]`` collects margins of misclassified
        samples, keyed by the (wrong) predicted class.
    """
    accepted_margins = defaultdict(list)
    failed_margins = defaultdict(list)
    for i, prob in enumerate(probs):
        first, second = torch.topk(prob, k=2, largest=True, sorted=True).values
        margin = first.item() - second.item()
        predicted = torch.argmax(prob).item()
        if labels[i] == predicted:
            accepted_margins[labels[i]].append(margin)
        else:
            failed_margins[predicted].append(margin)
    return accepted_margins, failed_margins

def find_safe_margin(accepted_margins: defaultdict[int, List[float]],
                     failed_margins: defaultdict[int, List[float]]):
    """For every class with accepted samples, return the smaller of
    (largest failed margin for that class) and (smallest accepted margin).

    A class with no failed margins uses ``inf`` for the failed term. So an empty
    accepted list together with no failures yields ``inf``.

    Args:
        accepted_margins: class -> margins of correctly classified samples.
        failed_margins: class -> margins of samples wrongly predicted as that class.
            Plain dicts are accepted, and the mapping is never modified.

    Returns:
        dict mapping each key of ``accepted_margins`` to its combined minimum.
    """
    min_accepted_margins = dict()
    for c, accepted in accepted_margins.items():
        # `.get` instead of `[]`. Indexing a defaultdict inserts an empty list for `c`,
        # which silently mutates the caller's dict; a plain dict would raise KeyError.
        failed = failed_margins.get(c)
        max_failed = max(failed) if failed else float('inf')
        min_accepted_margins[c] = min([max_failed, *accepted])
    return min_accepted_margins


In [None]:
def project_image(optimized_img: torch.Tensor, mean: float, std: float):
    """Clamp a normalised image (or batch) so that every pixel maps back into [0, 1].

    ``mean`` and ``std`` are per-channel sequences. They are reshaped to
    (C, 1, 1), so they broadcast over (C, H, W) and (B, C, H, W) inputs.
    """
    dev = optimized_img.device
    mean_t = torch.tensor(mean).view(-1, 1, 1).to(dev)
    std_t = torch.tensor(std).view(-1, 1, 1).to(dev)

    # Pixel values 0 and 1 expressed in normalised units.
    lower = (0.0 - mean_t) / std_t
    upper = (1.0 - mean_t) / std_t

    return torch.clamp(optimized_img, min=lower, max=upper)

import torch

def project_image_with_epsilon(
    optimized_img: torch.Tensor,
    reference_img: torch.Tensor,
    mean: float,
    std: float,
    epsilon: float
) -> torch.Tensor:
    """Clamp a normalised image into the valid pixel range, then into an L-inf ball
    of radius ``epsilon`` around the (normalised) reference image.

    ``reference_img`` is normalised here with ``mean``/``std``, while
    ``optimized_img`` is treated as already normalised.
    NOTE(review): confirm callers pass the reference in raw [0, 1] pixel space.
    If it is already normalised, it gets normalised twice.
    ``epsilon`` is measured in normalised units.
    """
    dev = optimized_img.device
    mean_t = torch.tensor(mean).view(-1, 1, 1).to(dev)
    std_t = torch.tensor(std).view(-1, 1, 1).to(dev)
    ref_norm = (reference_img - mean_t) / std_t

    lo = (0.0 - mean_t) / std_t
    hi = (1.0 - mean_t) / std_t
    in_range = torch.clamp(optimized_img, min=lo, max=hi)

    raised = torch.max(in_range, ref_norm - epsilon)
    return torch.min(raised, ref_norm + epsilon)


In [None]:

def select_top_images_per_class(probs: List[torch.Tensor],
                                images: List[torch.Tensor],
                                labels: List[int],
                                num_classes: int, top_k=3) -> defaultdict[int, List[torch.Tensor]]:
    """Keep, for each class, the ``top_k`` correctly classified images with the highest
    confidence in that class.

    Args:
        probs: per-sample probability tensors.
        images: per-sample images, indexed in parallel with ``probs``.
        labels: per-sample true labels.
        num_classes: classes ``0 .. num_classes-1`` are ranked and trimmed.
        top_k: maximum number of images kept per class.

    Returns:
        defaultdict mapping class -> images, best first. Every class in
        ``range(num_classes)`` is present, possibly with an empty list.
    """
    ranked = defaultdict(list)

    with torch.no_grad():
        for idx, (prob, label) in enumerate(zip(probs, labels)):
            predicted = torch.argmax(prob).item()
            if predicted != label:
                continue
            ranked[label].append((prob[predicted].item(), images[idx]))

    for cls in range(num_classes):
        best_first = sorted(ranked[cls], key=lambda pair: pair[0], reverse=True)
        ranked[cls] = [img for _, img in best_first[:top_k]]
    return ranked




In [None]:
def generate_random_image(transformation, image_size=(220, 220)):
    """Create a uniformly random RGB image and run it through ``transformation``.

    ``ToPILImage`` is prepended so the random tensor becomes a PIL image before the
    caller's pipeline (which starts from PIL images) is applied.

    Args:
        transformation: a ``transforms.Compose``. Its ``.transforms`` list is reused.
        image_size: ``(height, width)`` of the generated image. Defaults to (220, 220).

    Returns:
        Whatever ``transformation`` returns for the random image.
    """
    pipeline = transforms.Compose([transforms.ToPILImage()] + list(transformation.transforms))
    # ToPILImage already scales float tensors by 255 and expects values in [0, 1].
    # The old `torch.randn(...) * 255` passed it negative and huge values, which
    # overflow the uint8 conversion.
    noise = torch.rand(3, *image_size)
    return pipeline(noise)


In [None]:
import random

def compute_max_margin(model: PreActResNet,
                       selected_images_per_class: dict[int, List],
                       num_classes: int,
                       projection_mean: float, projection_std: float,
                       max_iterations=1000, lr=0.01, tolerance=1e-5, max_img_per_class=3):
    """Sequential (one image at a time) maximum-margin search for every class.

    For each target class ``c``, up to ``max_img_per_class`` images are drawn (after
    shuffling) from the images selected for the *other* classes. Each image is
    optimised with Adam to maximise ``logit_c - max_{k != c} logit_k``. After every
    step the image is clamped back into the valid normalised pixel range. The search
    stops early once the relative change of the margin drops below ``tolerance``.

    NOTE(review): a batched definition with the same name appears in a later cell and
    shadows this one once that cell is run.

    Returns:
        ``(max_margins, all_triggers_per_class, all_margins_per_class)``.
        ``max_margins`` is a list of floats indexed by class (0.0 for classes with no
        images). The dicts map class -> per-image best trigger / best margin (float).
    """
    model.eval()
    max_margins = {}
    all_margins_per_class = defaultdict(list)
    all_triggers_per_class = defaultdict(list)

    for c in range(num_classes):
        margins = []
        triggers = []

        LOGGER.info(f"\nProcessing Class {c}/{num_classes - 1}")
        images_to_optimize = []
        for k in range(num_classes):
            if k == c:
                continue
            images_to_optimize.extend(selected_images_per_class.get(k, []))
        random.shuffle(images_to_optimize)
        images_to_optimize = images_to_optimize[:max_img_per_class]

        LOGGER.info(f"  Total images to optimize for class {c}: {len(images_to_optimize)}")

        # Selects every class except the target. This is loop-invariant for class c.
        other_classes = torch.ones(num_classes, dtype=torch.bool, device=DEVICE)
        other_classes[c] = False

        for img in images_to_optimize:
            # Tracked as a plain float so the best margin does not keep an autograd graph alive.
            max_margin = -float('inf')
            trigger = None

            optimized_img = img.clone().detach().to(DEVICE)
            optimized_img.requires_grad_(True)

            if lr is not None:
                optimizer = torch.optim.Adam([optimized_img], lr=lr)
            else:
                optimizer = torch.optim.Adam([optimized_img])

            f_old = None
            for _ in range(max_iterations):
                optimizer.zero_grad()
                _, logits = get_logits_and_probs(model, optimized_img.unsqueeze(0))
                logits = logits.squeeze(0)
                margin = logits[c] - torch.max(logits[other_classes])
                f_new = margin.item()

                if f_new > max_margin:
                    max_margin = f_new
                    # The image that produced this margin (before this step's update).
                    trigger = optimized_img.detach().clone()

                (-margin).backward()
                optimizer.step()

                with torch.no_grad():
                    optimized_img.copy_(project_image(optimized_img, projection_mean, projection_std))

                if f_old is not None and abs(f_new - f_old) / (abs(f_old) + 1e-8) < tolerance:
                    break
                f_old = f_new
            margins.append(max_margin)
            triggers.append(trigger)

        if margins:
            all_triggers_per_class[c] = list(triggers)
            all_margins_per_class[c] = list(margins)
            max_margins[c] = max(margins)
            LOGGER.info(f"  Maximum Margin for class {c}: {max_margins[c]:.4f}")
        else:
            max_margins[c] = 0.0
            LOGGER.info(f"  No margins computed for class {c}.")

    # Every entry is a plain float now. The old `.item()` call crashed on the 0.0
    # stored for classes without images.
    max_margins = [float(max_margins[c]) for c in sorted(max_margins)]

    return max_margins, all_triggers_per_class, all_margins_per_class



In [None]:
from typing import Dict

from torch.optim import Adam



def compute_max_margin(
    model: PreActResNet,
    selected_images_per_class: Dict[int, List[torch.Tensor]],
    num_classes: int,
    projection_mean: float,
    projection_std: float,
    max_iterations: int = 1000,
    lr: float = 0.01,
    tolerance: float = 1e-5,
    max_img_per_class: int = 3,
    device: str = DEVICE  # Ensure DEVICE is defined
) -> Tuple[List[float], Dict[int, List[torch.Tensor]], Dict[int, List[float]]]:
    """Maximum-margin search, one target class at a time, with that class's images batched.

    For each class ``c``, up to ``max_img_per_class`` images selected for other classes
    are stacked and optimised together with Adam. Each row maximises
    ``logit_c - max_{k != c} logit_k``. The batch is clamped to the valid normalised
    pixel range after each step. Optimisation stops when every row's relative margin
    change is below ``tolerance``.

    NOTE(review): this re-defines ``compute_max_margin`` from an earlier cell.

    Returns:
        ``(max_margins, triggers, margins)``:
        - ``max_margins``: best margin per class, indexed by class, 0.0 if a class had no images.
        - ``triggers``: dict of per-class lists of best trigger images, on CPU.
        - ``margins``: dict of per-class lists of best margins.
    """
    model.eval()
    max_margins = {}
    all_margins_per_class = defaultdict(list)
    all_triggers_per_class = defaultdict(list)

    for c in range(num_classes):
        LOGGER.info(f"\nProcessing Class {c}/{num_classes - 1}")
        images_to_optimize = []
        for k in range(num_classes):
            if k == c:
                continue
            images_to_optimize.extend(selected_images_per_class.get(k, []))
        random.shuffle(images_to_optimize)
        images_to_optimize = images_to_optimize[:max_img_per_class]

        if not images_to_optimize:
            max_margins[c] = 0.0
            LOGGER.info(f"  No images to optimize for class {c}.")
            continue

        LOGGER.info(f"  Total images to optimize for class {c}: {len(images_to_optimize)}")

        batch_size = len(images_to_optimize)
        optimized_imgs = torch.stack([img.clone().detach() for img in images_to_optimize]).to(device)
        optimized_imgs.requires_grad_(True)

        optimizer = Adam([optimized_imgs], lr=lr) if lr is not None else Adam([optimized_imgs])

        max_margins_batch = torch.full((batch_size,), -float('inf'), device=device)
        triggers_batch = optimized_imgs.clone().detach()
        f_old = torch.full((batch_size,), float('inf'), device=device)

        for iteration in range(max_iterations):
            optimizer.zero_grad()
            _, logits = get_logits_and_probs(model, optimized_imgs)
            g_c = logits[:, c]

            # True at the target column, so that column is excluded from the runner-up max.
            is_target = torch.zeros_like(logits, dtype=torch.bool)
            is_target[:, c] = True
            g_k, _ = torch.max(logits.masked_fill(is_target, -float('inf')), dim=1)

            margin = g_c - g_k
            # Bookkeeping uses a detached copy. Otherwise torch.maximum / clone chain
            # every iteration's autograd graph together across the whole loop.
            margin_val = margin.detach()

            update_mask = margin_val > max_margins_batch
            max_margins_batch = torch.maximum(max_margins_batch, margin_val)
            triggers_batch[update_mask] = optimized_imgs.detach()[update_mask]

            loss = -margin.mean()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                optimized_imgs.copy_(project_image(optimized_imgs, projection_mean, projection_std))

            relative_change = torch.abs(margin_val - f_old) / (torch.abs(f_old) + 1e-8)
            if torch.all(relative_change < tolerance):
                LOGGER.info(f"  Converged at iteration {iteration} for class {c}.")
                break
            f_old = margin_val

        max_margins[c] = max_margins_batch.max().item()
        all_margins_per_class[c] = max_margins_batch.cpu().tolist()
        all_triggers_per_class[c] = [triggers_batch[i].cpu() for i in range(batch_size)]
        LOGGER.info(f"  Maximum Margin for class {c}: {max_margins[c]:.4f}")

    max_margins_sorted = [max_margins.get(c, 0.0) for c in range(num_classes)]

    return max_margins_sorted, all_triggers_per_class, all_margins_per_class

In [None]:
import torch
from torch.optim import Adam
from collections import defaultdict
import random
from typing import List, Dict, Tuple
from tqdm import tqdm  # For progress bars

def compute_max_margin_parallel(
    model: PreActResNet,
    selected_images_per_class: Dict[int, List[torch.Tensor]],
    num_classes: int,
    projection_mean: float,
    projection_std: float,
    max_iterations: int = 1000,
    lr: float = 0.01,
    tolerance: float = 1e-5,
    max_img_per_class: int = 3,
    device: str = DEVICE,  # Ensure DEVICE is defined
    parallel_classes: int = 15  # Number of classes to process in parallel
) -> Tuple[List[float], Dict[int, List[torch.Tensor]], Dict[int, List[float]]]:
    """Maximum-margin search that optimises several target classes in one batch.

    Classes are grouped into chunks of ``parallel_classes``. Within a chunk, every class
    ``c`` contributes up to ``max_img_per_class`` images (shuffled) taken from the
    images selected for other classes. All of them are stacked and optimised together
    with Adam; each row maximises ``logit[target] - max_{k != target} logit[k]`` for its
    own target. After each step the batch is clamped to the valid normalised pixel range.

    Returns:
        ``(max_margins, triggers, margins)``:
        - ``max_margins``: best margin per class, indexed by class, 0.0 if a class had no images.
        - ``triggers``: dict of per-class lists of best trigger images, on CPU.
        - ``margins``: dict of per-class lists of best margins.
    """
    model.eval()
    max_margins = {}
    all_margins_per_class = defaultdict(list)
    all_triggers_per_class = defaultdict(list)

    class_indices = list(range(num_classes))
    class_batches = [
        class_indices[i:i + parallel_classes]
        for i in range(0, len(class_indices), parallel_classes)
    ]

    for batch_num, class_batch in enumerate(class_batches):
        LOGGER.info(f"\nProcessing Batch {batch_num + 1}/{len(class_batches)}: Classes {class_batch}")

        batch_images = []
        batch_class_labels = []
        class_to_image_indices = defaultdict(list)  # class -> row indices in batch_images

        for c in class_batch:
            images_to_optimize = []
            for k in range(num_classes):
                if k == c:
                    continue
                images_to_optimize.extend(selected_images_per_class.get(k, []))
            random.shuffle(images_to_optimize)
            images_to_optimize = images_to_optimize[:max_img_per_class]

            if not images_to_optimize:
                max_margins[c] = 0.0
                LOGGER.info(f"  No images to optimize for class {c}.")
                continue

            LOGGER.info(f"  Class {c}: {len(images_to_optimize)} images to optimize.")

            start_idx = len(batch_images)
            batch_images.extend(img.clone().detach() for img in images_to_optimize)
            batch_class_labels.extend([c] * len(images_to_optimize))
            class_to_image_indices[c].extend(range(start_idx, start_idx + len(images_to_optimize)))

        if not batch_images:
            LOGGER.info("  No images to optimize in this batch.")
            continue

        optimized_imgs = torch.stack(batch_images).to(device)
        optimized_imgs.requires_grad_(True)

        optimizer = Adam([optimized_imgs], lr=lr) if lr is not None else Adam([optimized_imgs])

        batch_size = optimized_imgs.size(0)
        max_margins_batch = torch.full((batch_size,), -float('inf'), device=device)
        triggers_batch = optimized_imgs.clone().detach()
        f_old = torch.full((batch_size,), float('inf'), device=device)

        # Loop-invariant index tensors, built once on the optimisation device.
        rows = torch.arange(batch_size, device=device)
        target_classes = torch.tensor(batch_class_labels, device=device)

        for iteration in tqdm(range(max_iterations)):
            optimizer.zero_grad()
            _, logits = get_logits_and_probs(model, optimized_imgs)

            g_c = logits[rows, target_classes]

            # True at each row's own target, so it is excluded from the runner-up max.
            is_target = torch.zeros_like(logits, dtype=torch.bool)
            is_target[rows, target_classes] = True
            g_k, _ = torch.max(logits.masked_fill(is_target, -float('inf')), dim=1)

            margin = g_c - g_k
            # Detached copy for bookkeeping, so iterations do not chain autograd graphs.
            margin_val = margin.detach()

            update_mask = margin_val > max_margins_batch
            max_margins_batch = torch.maximum(max_margins_batch, margin_val)
            triggers_batch[update_mask] = optimized_imgs.detach()[update_mask]

            loss = -margin.mean()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                optimized_imgs.copy_(
                    project_image(optimized_imgs, projection_mean, projection_std)
                )

            relative_change = torch.abs(margin_val - f_old) / (torch.abs(f_old) + 1e-8)
            if torch.all(relative_change < tolerance):
                LOGGER.info(f"  Converged at iteration {iteration} for batch {batch_num + 1}.")
                break
            f_old = margin_val

        for c in class_batch:
            image_indices = class_to_image_indices.get(c, [])
            if not image_indices:
                continue
            class_margins = max_margins_batch[image_indices].cpu().tolist()
            class_triggers = triggers_batch[image_indices].detach().cpu()

            all_margins_per_class[c].extend(class_margins)
            all_triggers_per_class[c].extend(class_triggers)

            max_margins[c] = max(class_margins)
            LOGGER.info(f"  Maximum Margin for class {c}: {max_margins[c]:.4f}")

    max_margins_sorted = [max_margins.get(c, 0.0) for c in range(num_classes)]

    return max_margins_sorted, all_triggers_per_class, all_margins_per_class


In [None]:
import torch
from torch.optim import Adam
from collections import defaultdict
import random
from typing import List, Dict, Tuple
from tqdm import tqdm  # Optional: For progress bars

def compute_max_margin_parallel_with_epsilon(
    model: PreActResNet,
    selected_images_per_class: Dict[int, List[torch.Tensor]],
    reference_images_per_class: Dict[int, List[torch.Tensor]],
    num_classes: int,
    projection_mean: float,
    projection_std: float,
    epsilon: float,
    max_iterations: int = 1000,
    lr: float = 0.01,
    tolerance: float = 1e-5,
    max_img_per_class: int = 3,
    device: str= DEVICE,  # Ensure DEVICE is defined
    parallel_classes: int = 50  # Number of classes to process in parallel
) -> Tuple[List[float], Dict[int, List[torch.Tensor]], Dict[int, List[float]]]:
    """
    Computes the maximum margin for each class by optimizing multiple classes in parallel,
    keeping every optimized image within ``epsilon`` (L-inf, normalized units) of its
    reference image.

    Args:
        model (PreActResNet): The neural network model.
        selected_images_per_class (Dict[int, List[torch.Tensor]]): class -> images to optimize.
        reference_images_per_class (Dict[int, List[torch.Tensor]]): class -> reference images.
            ``reference_images_per_class[k][i]`` is paired with ``selected_images_per_class[k][i]``.
        num_classes (int): Total number of classes.
        projection_mean (float): Per-channel mean used for projection.
        projection_std (float): Per-channel std used for projection.
        epsilon (float): Maximum allowed per-pixel distance from the reference image.
        max_iterations (int, optional): Maximum optimization iterations. Defaults to 1000.
        lr (float, optional): Adam learning rate. Defaults to 0.01.
        tolerance (float, optional): Relative-change convergence tolerance. Defaults to 1e-5.
        max_img_per_class (int, optional): Maximum images optimized per target class. Defaults to 3.
        device (torch.device, optional): Computation device. Defaults to DEVICE.
        parallel_classes (int, optional): Number of classes optimized in one batch. Defaults to 50.

    Returns:
        Tuple[List[float], Dict[int, List[torch.Tensor]], Dict[int, List[float]]]:
            - List of maximum margins sorted by class (0.0 for classes with no images).
            - Dictionary mapping class indices to lists of trigger images (CPU).
            - Dictionary mapping class indices to lists of margins.

    Raises:
        ValueError: if some class has fewer reference images than selected images.
    """
    model.eval()
    max_margins = {}
    all_margins_per_class = defaultdict(list)
    all_triggers_per_class = defaultdict(list)

    class_indices = list(range(num_classes))
    class_batches = [
        class_indices[i:i + parallel_classes]
        for i in range(0, len(class_indices), parallel_classes)
    ]

    for batch_num, class_batch in enumerate(class_batches):
        LOGGER.info(f"\nProcessing Batch {batch_num + 1}/{len(class_batches)}: Classes {class_batch}")

        batch_images = []
        batch_reference_images = []
        batch_class_labels = []
        class_to_image_indices = defaultdict(list)  # class -> row indices in batch_images

        for c in class_batch:
            # (image, reference) pairs are shuffled *together*. The old code shuffled
            # only the images, so references no longer matched their images.
            pairs = []
            for k in range(num_classes):
                if k == c:
                    continue
                images = selected_images_per_class.get(k, [])
                references = reference_images_per_class.get(k, [])
                if not images:
                    continue
                if len(references) < len(images):
                    raise ValueError(f"Not enough reference images for class {k}.")
                pairs.extend(zip(images, references))
            random.shuffle(pairs)
            pairs = pairs[:max_img_per_class]

            if not pairs:
                max_margins[c] = 0.0
                LOGGER.info(f"  No images to optimize for class {c}.")
                continue

            LOGGER.info(f"  Class {c}: {len(pairs)} images to optimize.")

            start_idx = len(batch_images)
            for img, ref in pairs:
                batch_images.append(img.clone().detach())
                batch_reference_images.append(ref.clone().detach())
            batch_class_labels.extend([c] * len(pairs))
            class_to_image_indices[c].extend(range(start_idx, start_idx + len(pairs)))

        if not batch_images:
            LOGGER.info("  No images to optimize in this batch.")
            continue

        optimized_imgs = torch.stack(batch_images).to(device)
        reference_imgs = torch.stack(batch_reference_images).to(device)
        optimized_imgs.requires_grad_(True)

        optimizer = Adam([optimized_imgs], lr=lr) if lr is not None else Adam([optimized_imgs])

        batch_size = optimized_imgs.size(0)
        max_margins_batch = torch.full((batch_size,), -float('inf'), device=device)
        triggers_batch = optimized_imgs.clone().detach()
        f_old = torch.full((batch_size,), float('inf'), device=device)

        # Loop-invariant index tensors, built once on the optimisation device.
        rows = torch.arange(batch_size, device=device)
        target_classes = torch.tensor(batch_class_labels, device=device)

        for _ in tqdm(range(max_iterations)):
            optimizer.zero_grad()
            _, logits = get_logits_and_probs(model, optimized_imgs)

            g_c = logits[rows, target_classes]

            # True at each row's own target, so it is excluded from the runner-up max.
            is_target = torch.zeros_like(logits, dtype=torch.bool)
            is_target[rows, target_classes] = True
            g_k, _ = torch.max(logits.masked_fill(is_target, -float('inf')), dim=1)

            margin = g_c - g_k
            # Detached copy for bookkeeping, so iterations do not chain autograd graphs.
            margin_val = margin.detach()

            update_mask = margin_val > max_margins_batch
            max_margins_batch = torch.maximum(max_margins_batch, margin_val)
            triggers_batch[update_mask] = optimized_imgs.detach()[update_mask]

            loss = -margin.mean()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                optimized_imgs.copy_(project_image_with_epsilon(
                    optimized_imgs,
                    reference_imgs,
                    projection_mean,
                    projection_std,
                    epsilon
                ))

            relative_change = torch.abs(margin_val - f_old) / (torch.abs(f_old) + 1e-8)
            if torch.all(relative_change < tolerance):
                break
            f_old = margin_val

        for c in class_batch:
            image_indices = class_to_image_indices.get(c, [])
            if not image_indices:
                continue
            class_margins = max_margins_batch[image_indices].cpu().tolist()
            class_triggers = triggers_batch[image_indices].detach().cpu()

            all_margins_per_class[c].extend(class_margins)
            all_triggers_per_class[c].extend(class_triggers)

            max_margins[c] = max(class_margins)
            LOGGER.info(f"  Maximum Margin for class {c}: {max_margins[c]:.4f}")

    max_margins_sorted = [max_margins.get(c, 0.0) for c in range(num_classes)]

    return max_margins_sorted, all_triggers_per_class, all_margins_per_class


In [None]:
import warnings

import numpy as np
from scipy import stats


def compute_p_values(gamma_list: List[float],
                     distributions=('gamma', 'norm'), p_value_type='standard') -> dict[str, float]:
    """Score the largest margin against distributions fitted to the remaining margins.

    ``r_max`` is the maximum of ``gamma_list``. Every value different from ``r_max``
    forms the null sample. For each requested distribution, a fit gives
    ``H0(r_max)`` (the fitted CDF at ``r_max``), which becomes a p-value with
    ``n = len(gamma_list)``:

    * ``'standard'``: ``1 - H0(r_max) ** n``
    * ``'user_specified'``: ``H0(r_max) ** (n - 1)``

    Args:
        gamma_list: margins. Any 1-D sequence or numpy array of numbers is accepted.
        distributions: names among 'gamma', 'norm', 'expon', 'beta', 'lognorm'.
            Unsupported names are logged and skipped.
        p_value_type: 'standard' or 'user_specified'. Anything else is logged and skipped.

    Returns:
        Mapping distribution name -> p-value as a Python float. The value is ``None``
        when fitting raised an exception.

    Raises:
        ValueError: if ``gamma_list`` is empty or all its values are identical.
    """
    # asarray + size check also works for numpy input, where `not gamma_list` raised.
    gamma_array = np.asarray(gamma_list, dtype=float)
    if gamma_array.size == 0:
        raise ValueError("gamma_list is empty.")

    r_max = np.max(gamma_array)
    n = len(gamma_array)

    # All occurrences of the maximum are excluded from the null sample.
    null_data = gamma_array[gamma_array != r_max]
    if len(null_data) == 0:
        raise ValueError("All values in gamma_list are identical.")

    p_values = {}

    for dist_name in distributions:
        try:
            with warnings.catch_warnings():
                # scipy's optimisers warn noisily on small samples.
                warnings.simplefilter("ignore")

                if dist_name == 'gamma':
                    a, loc, scale = stats.gamma.fit(null_data, floc=0)
                    fitted_dist = stats.gamma(a, loc=loc, scale=scale)
                elif dist_name == 'norm':
                    mu, sigma = stats.norm.fit(null_data)
                    fitted_dist = stats.norm(loc=mu, scale=sigma)
                elif dist_name == 'expon':
                    loc, scale = stats.expon.fit(null_data)
                    fitted_dist = stats.expon(loc=loc, scale=scale)
                elif dist_name == 'beta':
                    a, b, loc, scale = stats.beta.fit(null_data, floc=0, fscale=1)
                    fitted_dist = stats.beta(a, b, loc=loc, scale=scale)
                elif dist_name == 'lognorm':
                    s, loc, scale = stats.lognorm.fit(null_data, floc=0)
                    fitted_dist = stats.lognorm(s, loc=loc, scale=scale)
                else:
                    LOGGER.error(f"Distribution '{dist_name}' is not supported.")
                    continue

                H0_r_max = fitted_dist.cdf(r_max)

                if p_value_type == 'standard':
                    p_val = 1 - H0_r_max**n
                elif p_value_type == 'user_specified':
                    p_val = H0_r_max**(n-1)
                else:
                    LOGGER.error(f"p_value_type '{p_value_type}' is not recognized. Choose 'standard' or 'user_specified'.")
                    continue

                p_values[dist_name] = float(p_val)

        except Exception as e:
            LOGGER.error(f"An error occurred while fitting distribution '{dist_name}': {e}")
            p_values[dist_name] = None

    return p_values


In [None]:
class GradCAM:
    """Grad-CAM (gradient-weighted class activation map) for one target layer.

    A forward hook on ``target_layer`` stores a copy of the layer's output
    activations and attaches a tensor hook that stores the gradient flowing
    back into that output. ``generate_heatmap`` weights each activation
    channel by its spatially averaged gradient, averages over channels,
    applies ReLU and scales the map to [0, 1].

    Call ``remove_hooks`` when done; otherwise the forward hook stays
    attached to ``target_layer``.
    """

    def __init__(self, model, target_layer):
        # NOTE: puts ``model`` into eval mode as a side effect.
        self.model = model.eval()
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the forward hook that captures activations and their gradient."""

        def save_gradient(grad):
            self.gradients = grad.detach()

        def forward_hook(module, input, output):
            # Clone: a bare ``detach()`` shares storage with ``output``, so any
            # in-place op applied to the output downstream (e.g. a residual
            # ``out += shortcut``) would silently alter the stored activations.
            self.activations = output.detach().clone()
            # A tensor hook replaces the deprecated ``Module.register_backward_hook``
            # and still reports d(score)/d(output) when ``output`` is modified
            # in place afterwards. Only possible while autograd is recording.
            if output.requires_grad:
                output.register_hook(save_gradient)

        self.hook_handles.append(
            self.target_layer.register_forward_hook(forward_hook)
        )

    def generate_heatmap(self, input_tensor, class_idx=None):
        """Return the Grad-CAM heatmap of ``input_tensor`` as a numpy array.

        Args:
            input_tensor: model input batch; only the score of the first
                sample is back-propagated.
            class_idx: class whose score is explained; defaults to the
                arg-max prediction of the first sample.

        Returns:
            Channel-mean of the gradient-weighted activations after ReLU,
            ``squeeze``d (so (H, W) for a single-sample batch) and divided by
            its maximum. All zeros (instead of NaN) when no location has a
            positive weighted activation.

        Raises:
            RuntimeError: if the hooks captured no activations/gradient
                (e.g. hooks already removed, or autograd disabled).
        """
        # Reset so a stale result from a previous call can never be reused.
        self.gradients = None
        self.activations = None

        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output[0].argmax().item()

        self.model.zero_grad()
        output[0, class_idx].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients for the target layer."
            )

        # One weight per channel: gradient averaged over batch and spatial dims.
        pooled_gradients = self.gradients.mean(dim=[0, 2, 3])
        weighted = self.activations * pooled_gradients.view(1, -1, 1, 1)

        heatmap = F.relu(weighted.mean(dim=1).squeeze())
        max_value = heatmap.max()
        if max_value > 0:
            heatmap = heatmap / max_value

        return heatmap.cpu().numpy()

    def remove_hooks(self):
        """Detach every hook registered by this instance (safe to call twice)."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()

def get_heatmap(model, inp):
    """Compute a Grad-CAM heatmap for one image on the last block of ``model.layer2``.

    The target layer is the 4th child module of the final block in
    ``model.layer2``; for ``PreActBlock`` (children registered as bn1, conv1,
    bn2, conv2[, shortcut]) that is ``conv2``.

    Args:
        model: network exposing a ``layer2`` container (e.g. ``PreActResNet``).
        inp: image tensor of shape (C, H, W); an already-batched
            (1, C, H, W) tensor is also accepted.

    Returns:
        The numpy heatmap produced by ``GradCAM.generate_heatmap``.
    """
    last_block = list(model.layer2.children())[-1]
    target_layer = list(last_block.children())[3]

    grad_cam = GradCAM(model, target_layer)
    try:
        input_tensor = inp if inp.dim() == 4 else inp.unsqueeze(0)
        return grad_cam.generate_heatmap(input_tensor.to(DEVICE))
    finally:
        # Always detach the hooks, even on error; otherwise they stay on the
        # model and keep firing on every later forward pass.
        grad_cam.remove_hooks()

def plot_heatmap(heatmap):
    """Draw ``heatmap`` with the 'jet' colormap at 50% opacity, hide the axes and show it."""
    render_options = {'cmap': 'jet', 'alpha': 0.5}
    plt.imshow(heatmap, **render_options)
    plt.axis('off')
    plt.show()


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

def _expand_channel_stats(values, num_channels, label):
    """Return ``values`` as a float32 CPU tensor with exactly one entry per channel.

    ``values`` may be a scalar, a length-1 sequence (broadcast to every
    channel), or a list/tuple/numpy array/tensor with ``num_channels`` entries.

    Raises:
        ValueError: if the number of entries is neither 1 nor ``num_channels``.
    """
    stats = torch.as_tensor(values, dtype=torch.float32).detach().cpu().flatten()
    if stats.numel() == 1:
        return stats.repeat(num_channels)
    if stats.numel() != num_channels:
        raise ValueError(f"Length of {label} ({stats.numel()}) does not match number of channels ({num_channels}).")
    return stats


def imshow(tensor, mean, std, title=None, save_dir=None, filename=None):
    """Un-normalize an image tensor and display it, optionally saving it as PNG.

    Args:
        tensor: image of shape (C, H, W), or a batch (N, C, H, W) of which
            only the first image is shown. Not modified.
        mean, std: per-channel normalization statistics to undo
            (``x = x_norm * std + mean``): a scalar, a length-1 sequence, or
            one value per channel as list, tuple, numpy array or tensor.
        title: optional plot title (also used in the log message).
        save_dir, filename: when both are given, the figure is saved to
            ``save_dir/filename.png``; the directory is created if needed.

    Raises:
        ValueError: if ``mean`` or ``std`` has neither 1 nor C entries.
    """
    # Detached CPU view: tensors that require grad (e.g. optimized triggers)
    # can then be converted to numpy; the arithmetic below makes a new tensor.
    img = tensor.detach().cpu()
    if img.dim() == 4:
        img = img[0]
    num_channels = img.size(0)
    mean_t = _expand_channel_stats(mean, num_channels, "mean")
    std_t = _expand_channel_stats(std, num_channels, "std")

    # Undo the normalization per channel, broadcasting over H and W.
    img = torch.clamp(img * std_t.view(-1, 1, 1) + mean_t.view(-1, 1, 1), 0, 1)
    np_img = img.numpy()

    # Create the figure only after validation so a bad mean/std leaves no empty figure.
    plt.figure()
    if num_channels == 1:
        plt.imshow(np_img.squeeze(0), cmap='gray')
    else:
        plt.imshow(np.transpose(np_img, (1, 2, 0)))
    if title:
        plt.title(title)
    plt.axis('off')
    if save_dir and filename:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f'{filename}.png')
        plt.savefig(save_path, bbox_inches='tight')
        LOGGER.info(f"Saved image plot to {save_path}")
    plt.show()
    LOGGER.info(f"Plotted image: {title}")



In [None]:
from tqdm import tqdm


def select_neurons_and_get_activations_per_each(model, percent_per_layer, inputs: List[torch.Tensor], rng=None):
    """Sample neurons from every Conv2d/Linear layer and collect their activations per input.

    Each input is fed to the model on its own (batch of 1) under
    ``torch.no_grad()``. A forward hook on every Conv2d/Linear module records
    its output, averaged over all dims after the channel dim (for a conv: the
    spatial mean of each channel), flattened to one value per neuron/channel.
    Then, per layer, ``max(1, int(num_neurons * percent_per_layer))`` neurons
    (capped at ``num_neurons``) are sampled without replacement.

    Args:
        model: torch module to probe; put into eval mode as a side effect.
        percent_per_layer: fraction of each layer's neurons to keep.
        inputs: single-image tensors without a batch dim; each is moved to the
            device of the model's first parameter before the forward pass.
        rng: optional ``numpy.random.Generator``/``RandomState`` for
            reproducible sampling; ``None`` uses the global ``np.random`` state.

    Returns:
        One list per input holding the activations of the selected neurons,
        concatenated layer by layer in ``named_modules`` order. ``[]`` when
        ``inputs`` is empty.

    Raises:
        ValueError: if a layer fired a different number of times than there
            are inputs (which happens e.g. when a module is called more than
            once per forward pass).
    """
    LOGGER.info("Selecting neurons and collecting activations")
    if len(inputs) == 0:
        LOGGER.info("No inputs given; nothing to select")
        return []

    activation = defaultdict(list)
    layers = [name for name, module in model.named_modules()
              if isinstance(module, (nn.Conv2d, nn.Linear))]

    def make_hook(layer_name):
        def hook(module, module_input, output):
            act = output.detach()
            if act.dim() > 2:
                # Average every trailing (e.g. spatial) dim -> one value per channel.
                act = act.mean(dim=tuple(range(2, act.dim())))
            activation[layer_name].append(act.squeeze(0).cpu().numpy().reshape(-1))
        return hook

    hooks = [model.get_submodule(name).register_forward_hook(make_hook(name)) for name in layers]
    LOGGER.info("Registered hooks")

    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    try:
        with torch.no_grad():
            for img in tqdm(inputs, desc="Processing inputs"):
                batch = img.unsqueeze(0)
                if device is not None:
                    batch = batch.to(device)
                model(batch)
    finally:
        # Remove hooks even if a forward pass fails, so they never leak onto the model.
        for hook in hooks:
            hook.remove()
    LOGGER.info("Collected activations")

    sampler = np.random if rng is None else rng
    selected_neurons = {}
    for name in layers:
        acts = activation[name]
        if len(acts) != len(inputs):
            LOGGER.error(f"Activation length mismatch for layer {name}: expected {len(inputs)}, got {len(acts)}")
            raise ValueError(f"Activation length mismatch for layer {name}")
        # np.stack also verifies that every input produced the same number of neurons.
        num_neurons = np.stack(acts, axis=0).shape[1]
        # Cap at num_neurons: sampling without replacement cannot exceed the population.
        num_selected = min(num_neurons, max(1, int(num_neurons * percent_per_layer)))
        selected_neurons[name] = sampler.choice(num_neurons, num_selected, replace=False)
    LOGGER.info("Selected neurons per layer")

    activations_per_input = []
    for img_idx in tqdm(range(len(inputs)), desc="Aggregating activations"):
        img_activations = []
        for name in layers:
            img_activations.extend(activation[name][img_idx][selected_neurons[name]])
        activations_per_input.append(img_activations)
    LOGGER.info("Aggregated activations per input")
    return activations_per_input

def compute_neuron_pairwise_correlation(activations: List[List[float]]) -> np.ndarray:
    """Pearson correlation between every pair of neurons.

    Args:
        activations: one row per input, one column per neuron.

    Returns:
        The ``np.corrcoef(..., rowvar=False)`` matrix, i.e. columns (neurons)
        are the variables; neurons with zero variance yield NaN entries.

    Raises:
        ValueError: if ``activations`` does not form a 2-D array.
    """
    LOGGER.info("Computing neuron pairwise correlation")
    matrix = np.array(activations)
    if matrix.ndim == 2:
        # rowvar=False: rows are observations (inputs), columns are variables (neurons).
        result = np.corrcoef(matrix, rowvar=False)
        LOGGER.info("Computed neuron pairwise correlation")
        return result
    LOGGER.error("Activation matrix is not 2D")
    raise ValueError("Activation matrix must be 2D")

In [None]:
# Download the backdoor-attack evaluation dataset archive (eval_dataset.zip, about 1.5G) from Hugging Face.
! wget https://huggingface.co/datasets/abbasfar/backdoor_attack_evaluation_dataset/resolve/main/eval_dataset.zip

--2024-12-15 14:39:27--  https://huggingface.co/datasets/abbasfar/backdoor_attack_evaluation_dataset/resolve/main/eval_dataset.zip
Resolving huggingface.co (huggingface.co)... 3.165.160.11, 3.165.160.61, 3.165.160.12, ...
Connecting to huggingface.co (huggingface.co)|3.165.160.11|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.hf.co/repos/44/d3/44d3f2884b0ce4544d589087f4734c7d7d5713c4209ec3c9f040c4bd2c393c1e/951aa8d3ee66c892f3c2cb660c7af40a74c274ece5d6d71cbe93385f8d7e0783?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27eval_dataset.zip%3B+filename%3D%22eval_dataset.zip%22%3B&response-content-type=application%2Fzip&Expires=1734532767&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczNDUzMjc2N319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zLzQ0L2QzLzQ0ZDNmMjg4NGIwY2U0NTQ0ZDU4OTA4N2Y0NzM0YzdkN2Q1NzEzYzQyMDllYzNjOWYwNDBjNGJkMmMzOTNjMWUvOTUxYWE4ZDNlZTY2Yzg5MmYzYzJjYjY2MG

In [None]:
import time
from statistics import mean


def _plot_target_class(model, target_class_c, confident_images_per_class, triggers_per_class,
                       all_margins_per_class, projection_mean, projection_std):
    """Show one class's clean images and adversarial triggers, each followed by its Grad-CAM heatmap."""
    LOGGER.info('plotting confident test imags')
    for i, img in enumerate(confident_images_per_class[target_class_c]):
        imshow(img, projection_mean, projection_std, title=f'clean from class {target_class_c} num {i}')
        plot_heatmap(get_heatmap(model, img))

    LOGGER.info('plotting adversary imags')
    for i, trigger in enumerate(triggers_per_class[target_class_c]):
        margin = all_margins_per_class[target_class_c][i]
        imshow(trigger, projection_mean, projection_std,
               title=f'adversary from class {target_class_c} num {i} with margin {margin}')
        plot_heatmap(get_heatmap(model, trigger))


def analyze(idx, plot=False, alpha=0.05):
    """Run the max-margin backdoor detector on evaluation model ``idx`` and log the verdict.

    Loads the model and its test images via ``load_test``, computes per-class
    adversarial margins with ``compute_max_margin_parallel_with_epsilon``, fits
    null distributions with ``compute_p_values``, and flags the model as
    malicious when the gamma-fit p-value of the maximum margins is ``<= alpha``.

    Args:
        idx: model index passed to ``load_test``.
        plot: if True, also plot the confident clean images and the optimized
            triggers of the class with the largest margin, with Grad-CAM
            heatmaps. (This code previously sat after an unconditional
            ``break`` and never ran; the default keeps it off.)
        alpha: significance level for the gamma p-value (0.05, as before).

    Returns:
        True if flagged malicious, False if flagged clean, or None when the
        gamma fit failed and no verdict could be made.
    """
    model, num_classes, ground_truth, transformation, images_root_dir = load_test(idx)
    truth = "is malicous" if ground_truth else "is clean"
    LOGGER.info(f'\n\n\nmodel id: {idx} {truth} with {num_classes} classes\n\n\n')

    start = time.time()

    model = model.to(DEVICE)
    model.eval()

    # Sorted for a deterministic image order across file systems.
    images_path = sorted(glob.glob(os.path.join(images_root_dir, '*.jpg')))
    # The label is the last "_"-separated token of the file name ("..._<label>.jpg");
    # parse the basename so underscores in directory names cannot interfere.
    labels = [int(os.path.basename(path).split('_')[-1].split('.')[0]) for path in images_path]

    projection_mean, projection_std = extract_normalization_params(transformation)
    transformed_images = transform_images(images_path, transformation)

    probs, logits = get_logits_and_probs(model, transformed_images)
    accepted_margins, failed_margins = calculate_margins(logits, labels)
    # Safe margins ordered by the dict key (presumably the class index).
    safe_margins = [margin for _, margin in
                    sorted(find_safe_margin(accepted_margins, failed_margins).items(),
                           key=lambda item: item[0])]

    # Budgets scale with 10 / num_classes.
    k = 9
    max_iterations = int(1000 * (10 / num_classes))
    max_img_per_class = max(int(9 * (10 / num_classes)), 1)

    if max_img_per_class <= 2:
        # Tiny per-class budget: start from k random images per class with a
        # much larger epsilon and an explicit learning rate.
        confident_images_per_class = {c: [generate_random_image(transformation) for _ in range(k)]
                                      for c in range(num_classes)}
        epsilon = 100
        lr = 0.1
        tolerance = 1e-3
    else:
        confident_images_per_class = select_top_images_per_class(probs, transformed_images,
                                                                 labels, num_classes, top_k=k)
        epsilon = 3
        lr = None
        tolerance = 1e-5

    max_margins, triggers_per_class, all_margins_per_class = compute_max_margin_parallel_with_epsilon(
        model,
        confident_images_per_class, confident_images_per_class,
        num_classes, projection_mean, projection_std,
        max_iterations=max_iterations,
        lr=lr,
        tolerance=tolerance,
        epsilon=epsilon,
        max_img_per_class=max_img_per_class)

    # Clip negative margins to 0 before fitting.
    max_margins = [max(margin, 0) for margin in max_margins]
    average_margins = []
    for c in range(num_classes):
        all_margins_per_class[c] = [max(margin, 0) for margin in all_margins_per_class[c]]
        average_margins.append(mean(all_margins_per_class[c]))

    LOGGER.info(f'finding adversaries has taken {round(time.time() - start, 2)} seconds')

    target_class_c = max_margins.index(max(max_margins))
    LOGGER.info(f'working on class {target_class_c}')
    distributions = ['gamma', 'norm', 'expon']
    p_values_standard = compute_p_values(max_margins, distributions=distributions,
                                         p_value_type='standard')
    avg_p_values_standard = compute_p_values(average_margins, distributions=distributions,
                                             p_value_type='standard')
    safe_p_values_standard = compute_p_values(safe_margins, distributions=distributions,
                                              p_value_type='standard')
    LOGGER.info(f'max margin p values {p_values_standard}')
    LOGGER.info(f'average maximum margins p_values {avg_p_values_standard}')
    LOGGER.info(f'safe margins p values {safe_p_values_standard}')

    # compute_p_values stores None for a distribution whose fit raised;
    # comparing None with a float would raise TypeError.
    gamma_p_value = p_values_standard.get('gamma')
    if gamma_p_value is None:
        LOGGER.error(f'gamma fit failed for model {idx}; no verdict')
        return None

    is_malicious = bool(gamma_p_value <= alpha)
    if is_malicious:
        LOGGER.info(f'detected as a malicious model and it {truth}')
    else:
        LOGGER.info(f'detected as a clean model and it {truth}')

    if plot:
        _plot_target_class(model, target_class_c, confident_images_per_class,
                           triggers_per_class, all_margins_per_class,
                           projection_mean, projection_std)
    return is_malicious






In [None]:
# Run the detector on evaluation models 0-49.
for model_id in range(50):
    analyze(model_id)




model id: 0 is malicous with 10 classes





  metadata = torch.load(os.path.join(test_root_dir, 'metadata.pt'))
  model.load_state_dict(torch.load(model_path))
INFO:detector:


model id: 0 is malicous with 10 classes






Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


INFO:detector:
Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


  Class 0: 9 images to optimize.


INFO:detector:  Class 0: 9 images to optimize.


  Class 1: 9 images to optimize.


INFO:detector:  Class 1: 9 images to optimize.


  Class 2: 9 images to optimize.


INFO:detector:  Class 2: 9 images to optimize.


  Class 3: 9 images to optimize.


INFO:detector:  Class 3: 9 images to optimize.


  Class 4: 9 images to optimize.


INFO:detector:  Class 4: 9 images to optimize.


  Class 5: 9 images to optimize.


INFO:detector:  Class 5: 9 images to optimize.


  Class 6: 9 images to optimize.


INFO:detector:  Class 6: 9 images to optimize.


  Class 7: 9 images to optimize.


INFO:detector:  Class 7: 9 images to optimize.


  Class 8: 9 images to optimize.


INFO:detector:  Class 8: 9 images to optimize.


  Class 9: 9 images to optimize.


INFO:detector:  Class 9: 9 images to optimize.
100%|██████████| 1000/1000 [01:07<00:00, 14.71it/s]

  Maximum Margin for class 0: 34.0931



INFO:detector:  Maximum Margin for class 0: 34.0931


  Maximum Margin for class 1: 23.3818


INFO:detector:  Maximum Margin for class 1: 23.3818


  Maximum Margin for class 2: 39.0376


INFO:detector:  Maximum Margin for class 2: 39.0376


  Maximum Margin for class 3: 43.1500


INFO:detector:  Maximum Margin for class 3: 43.1500


  Maximum Margin for class 4: 42.8557


INFO:detector:  Maximum Margin for class 4: 42.8557


  Maximum Margin for class 5: 31.9943


INFO:detector:  Maximum Margin for class 5: 31.9943


  Maximum Margin for class 6: 41.5265


INFO:detector:  Maximum Margin for class 6: 41.5265


  Maximum Margin for class 7: 50.9616


INFO:detector:  Maximum Margin for class 7: 50.9616


  Maximum Margin for class 8: 60.9890


INFO:detector:  Maximum Margin for class 8: 60.9890


  Maximum Margin for class 9: 28.2808


INFO:detector:  Maximum Margin for class 9: 28.2808


finding adversaries has taken 68.24 seconds


INFO:detector:finding adversaries has taken 68.24 seconds


working on class 8


INFO:detector:working on class 8


max margin p values {'gamma': 0.06466686368661678, 'norm': 0.016222460737541522, 'expon': 0.4973051202143991}


INFO:detector:max margin p values {'gamma': 0.06466686368661678, 'norm': 0.016222460737541522, 'expon': 0.4973051202143991}


average maximum margins p_values {'gamma': 0.015278860248490012, 'norm': 0.0013928470569556373, 'expon': 0.352606338027749}


INFO:detector:average maximum margins p_values {'gamma': 0.015278860248490012, 'norm': 0.0013928470569556373, 'expon': 0.352606338027749}


safe margins p values {'gamma': 0.5122724966100904, 'norm': 0.2758160092490499, 'expon': 0.7948359785275755}


INFO:detector:safe margins p values {'gamma': 0.5122724966100904, 'norm': 0.2758160092490499, 'expon': 0.7948359785275755}


detected as a clean model and it is malicous


INFO:detector:detected as a clean model and it is malicous
  metadata = torch.load(os.path.join(test_root_dir, 'metadata.pt'))





model id: 1 is malicous with 10 classes





  model.load_state_dict(torch.load(model_path))
INFO:detector:


model id: 1 is malicous with 10 classes






Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


INFO:detector:
Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


  Class 0: 9 images to optimize.


INFO:detector:  Class 0: 9 images to optimize.


  Class 1: 9 images to optimize.


INFO:detector:  Class 1: 9 images to optimize.


  Class 2: 9 images to optimize.


INFO:detector:  Class 2: 9 images to optimize.


  Class 3: 9 images to optimize.


INFO:detector:  Class 3: 9 images to optimize.


  Class 4: 9 images to optimize.


INFO:detector:  Class 4: 9 images to optimize.


  Class 5: 9 images to optimize.


INFO:detector:  Class 5: 9 images to optimize.


  Class 6: 9 images to optimize.


INFO:detector:  Class 6: 9 images to optimize.


  Class 7: 9 images to optimize.


INFO:detector:  Class 7: 9 images to optimize.


  Class 8: 9 images to optimize.


INFO:detector:  Class 8: 9 images to optimize.


  Class 9: 9 images to optimize.


INFO:detector:  Class 9: 9 images to optimize.
100%|██████████| 1000/1000 [01:07<00:00, 14.71it/s]

  Maximum Margin for class 0: 37.8012



INFO:detector:  Maximum Margin for class 0: 37.8012


  Maximum Margin for class 1: 21.7733


INFO:detector:  Maximum Margin for class 1: 21.7733


  Maximum Margin for class 2: 59.8119


INFO:detector:  Maximum Margin for class 2: 59.8119


  Maximum Margin for class 3: 59.6252


INFO:detector:  Maximum Margin for class 3: 59.6252


  Maximum Margin for class 4: 42.5500


INFO:detector:  Maximum Margin for class 4: 42.5500


  Maximum Margin for class 5: 29.4279


INFO:detector:  Maximum Margin for class 5: 29.4279


  Maximum Margin for class 6: 39.5594


INFO:detector:  Maximum Margin for class 6: 39.5594


  Maximum Margin for class 7: 48.5680


INFO:detector:  Maximum Margin for class 7: 48.5680


  Maximum Margin for class 8: 108.1978


INFO:detector:  Maximum Margin for class 8: 108.1978


  Maximum Margin for class 9: 32.6390


INFO:detector:  Maximum Margin for class 9: 32.6390


finding adversaries has taken 68.23 seconds


INFO:detector:finding adversaries has taken 68.23 seconds


working on class 8


INFO:detector:working on class 8


max margin p values {'gamma': 0.0005219856478486662, 'norm': 2.2378205144502061e-07, 'expon': 0.11353914971768908}


INFO:detector:max margin p values {'gamma': 0.0005219856478486662, 'norm': 2.2378205144502061e-07, 'expon': 0.11353914971768908}


average maximum margins p_values {'gamma': 4.2711373798964836e-05, 'norm': 1.571387464593954e-10, 'expon': 0.12961855491606888}


INFO:detector:average maximum margins p_values {'gamma': 4.2711373798964836e-05, 'norm': 1.571387464593954e-10, 'expon': 0.12961855491606888}


safe margins p values {'gamma': 0.6636802870468987, 'norm': 0.5351735453435438, 'expon': 0.8164729401067057}


INFO:detector:safe margins p values {'gamma': 0.6636802870468987, 'norm': 0.5351735453435438, 'expon': 0.8164729401067057}


detected as a malicious model and it is malicous


INFO:detector:detected as a malicious model and it is malicous
  metadata = torch.load(os.path.join(test_root_dir, 'metadata.pt'))





model id: 2 is clean with 10 classes





  model.load_state_dict(torch.load(model_path))
INFO:detector:


model id: 2 is clean with 10 classes






Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


INFO:detector:
Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


  Class 0: 9 images to optimize.


INFO:detector:  Class 0: 9 images to optimize.


  Class 1: 9 images to optimize.


INFO:detector:  Class 1: 9 images to optimize.


  Class 2: 9 images to optimize.


INFO:detector:  Class 2: 9 images to optimize.


  Class 3: 9 images to optimize.


INFO:detector:  Class 3: 9 images to optimize.


  Class 4: 9 images to optimize.


INFO:detector:  Class 4: 9 images to optimize.


  Class 5: 9 images to optimize.


INFO:detector:  Class 5: 9 images to optimize.


  Class 6: 9 images to optimize.


INFO:detector:  Class 6: 9 images to optimize.


  Class 7: 9 images to optimize.


INFO:detector:  Class 7: 9 images to optimize.


  Class 8: 9 images to optimize.


INFO:detector:  Class 8: 9 images to optimize.


  Class 9: 9 images to optimize.


INFO:detector:  Class 9: 9 images to optimize.
100%|██████████| 1000/1000 [01:08<00:00, 14.65it/s]

  Maximum Margin for class 0: 26.0299



INFO:detector:  Maximum Margin for class 0: 26.0299


  Maximum Margin for class 1: 34.7713


INFO:detector:  Maximum Margin for class 1: 34.7713


  Maximum Margin for class 2: 19.7186


INFO:detector:  Maximum Margin for class 2: 19.7186


  Maximum Margin for class 3: 27.0187


INFO:detector:  Maximum Margin for class 3: 27.0187


  Maximum Margin for class 4: 24.0070


INFO:detector:  Maximum Margin for class 4: 24.0070


  Maximum Margin for class 5: 78.4388


INFO:detector:  Maximum Margin for class 5: 78.4388


  Maximum Margin for class 6: 21.7980


INFO:detector:  Maximum Margin for class 6: 21.7980


  Maximum Margin for class 7: 32.4394


INFO:detector:  Maximum Margin for class 7: 32.4394


  Maximum Margin for class 8: 38.1910


INFO:detector:  Maximum Margin for class 8: 38.1910


  Maximum Margin for class 9: 29.2207


INFO:detector:  Maximum Margin for class 9: 29.2207


finding adversaries has taken 68.49 seconds


INFO:detector:finding adversaries has taken 68.49 seconds


working on class 5


INFO:detector:working on class 5


max margin p values {'gamma': 5.141304049161022e-09, 'norm': 0.0, 'expon': 0.009275590900068575}


INFO:detector:max margin p values {'gamma': 5.141304049161022e-09, 'norm': 0.0, 'expon': 0.009275590900068575}


average maximum margins p_values {'gamma': 9.395163330339074e-05, 'norm': 3.928375336137435e-08, 'expon': 0.04507010512425069}


INFO:detector:average maximum margins p_values {'gamma': 9.395163330339074e-05, 'norm': 3.928375336137435e-08, 'expon': 0.04507010512425069}


safe margins p values {'gamma': 0.32574485658904206, 'norm': 0.06399903348242941, 'expon': 0.14458735802543055}


INFO:detector:safe margins p values {'gamma': 0.32574485658904206, 'norm': 0.06399903348242941, 'expon': 0.14458735802543055}


detected as a malicious model and it is clean


INFO:detector:detected as a malicious model and it is clean
  metadata = torch.load(os.path.join(test_root_dir, 'metadata.pt'))
  model.load_state_dict(torch.load(model_path))





model id: 3 is clean with 10 classes





INFO:detector:


model id: 3 is clean with 10 classes






Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


INFO:detector:
Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


  Class 0: 9 images to optimize.


INFO:detector:  Class 0: 9 images to optimize.


  Class 1: 9 images to optimize.


INFO:detector:  Class 1: 9 images to optimize.


  Class 2: 9 images to optimize.


INFO:detector:  Class 2: 9 images to optimize.


  Class 3: 9 images to optimize.


INFO:detector:  Class 3: 9 images to optimize.


  Class 4: 9 images to optimize.


INFO:detector:  Class 4: 9 images to optimize.


  Class 5: 9 images to optimize.


INFO:detector:  Class 5: 9 images to optimize.


  Class 6: 9 images to optimize.


INFO:detector:  Class 6: 9 images to optimize.


  Class 7: 9 images to optimize.


INFO:detector:  Class 7: 9 images to optimize.


  Class 8: 9 images to optimize.


INFO:detector:  Class 8: 9 images to optimize.


  Class 9: 9 images to optimize.


INFO:detector:  Class 9: 9 images to optimize.
 67%|██████▋   | 670/1000 [00:45<00:22, 14.82it/s]


KeyboardInterrupt: 

In [None]:
! python tester.py

device is cuda
  metadata = torch.load(os.path.join(test_root_dir, 'metadata.pt'))
  model.load_state_dict(torch.load(model_path))
eval_dataset/0/test_dataset

Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  Class 0: 9 images to optimize.
  Class 1: 9 images to optimize.
  Class 2: 9 images to optimize.
  Class 3: 9 images to optimize.
  Class 4: 9 images to optimize.
  Class 5: 9 images to optimize.
  Class 6: 9 images to optimize.
  Class 7: 9 images to optimize.
  Class 8: 9 images to optimize.
  Class 9: 9 images to optimize.
Traceback (most recent call last):
  File "/content/tester.py", line 49, in <module>
    pred.append(backdoor_model_detector(model, num_classes, images_root_dir, transformation))
  File "/content/main.py", line 73, in wrapper
    result = func(*args, **kwargs)
  File "/content/main.py", line 447, in backdoor_model_detector
    max_margins, triggers_per_class, all_margins_per_class = compute_max_margin_parallel_with_epsilon(model,
  File "/conten

In [None]:
! ls -lrth

total 1.5G
-rw-r--r-- 1 root root 1.5G Nov  1 05:13 eval_dataset.zip
drwxr-xr-x 1 root root 4.0K Dec 12 14:22 sample_data
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 41
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 13
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 44
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 43
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 19
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 40
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 14
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 17
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 32
drwxr-xr-x 3 root root 4.0K Dec 15 14:18 26


In [None]:
# Report every evaluation model whose number of classes differs from 10.
# Only the class count is read, so this loop no longer overwrites the notebook's
# global `model`, `transformation`, `ground_truth`, ... with the last model loaded.
models_not_10_classes = []
for test_idx in range(50):
    test_num_classes = load_test(test_idx)[1]
    if test_num_classes != 10:
        models_not_10_classes.append((test_idx, test_num_classes))
        print(test_idx, test_num_classes)

  metadata = torch.load(os.path.join(test_root_dir, 'metadata.pt'))
  model.load_state_dict(torch.load(model_path))


In [None]:
! python tester.py

device is cuda
  metadata = torch.load(os.path.join(test_root_dir, 'metadata.pt'))
  model.load_state_dict(torch.load(model_path))
eval_dataset/0/test_dataset

Processing Batch 1/1: Classes [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  Class 0: 9 images to optimize.
  Class 1: 9 images to optimize.
  Class 2: 9 images to optimize.
  Class 3: 9 images to optimize.
  Class 4: 9 images to optimize.
  Class 5: 9 images to optimize.
  Class 6: 9 images to optimize.
  Class 7: 9 images to optimize.
  Class 8: 9 images to optimize.
  Class 9: 9 images to optimize.
  Maximum Margin for class 0: 31.3799
  Maximum Margin for class 1: 23.2311
  Maximum Margin for class 2: 32.5545
  Maximum Margin for class 3: 41.6083
  Maximum Margin for class 4: 47.4671
  Maximum Margin for class 5: 35.0290
  Maximum Margin for class 6: 41.7243
  Maximum Margin for class 7: 50.1137
  Maximum Margin for class 8: 62.0549
  Maximum Margin for class 9: 30.8738
finding adversaries has taken 68.97 seconds
working on class 8
max ma

In [None]:
! git clone https://github.com/kuangliu/pytorch-cifar.git

Cloning into 'pytorch-cifar'...
remote: Enumerating objects: 382, done.[K
remote: Total 382 (delta 0), reused 0 (delta 0), pack-reused 382 (from 1)[K
Receiving objects: 100% (382/382), 81.31 KiB | 3.13 MiB/s, done.
Resolving deltas: 100% (198/198), done.


In [None]:
! ls

pytorch-cifar  sample_data


In [None]:
! cd pytorch-cifar ; python main.py

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..

Epoch: 0
Saving..

Epoch: 1
