In [4]:
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import torch


In [6]:
from utils.setup_cifar import CIFAR, ResNet18_model
import random
from torch.utils.data import DataLoader, Dataset, TensorDataset
from typing import Optional, Tuple, Union

class CustomDataset(Dataset):
    def __init__(self, data: torch.Tensor, label: torch.Tensor) -> None:
        self.data = data
        self.label = label

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return (self.data[idx], self.label[idx])

    def __len__(
        self,
    ) -> int:
        return len(self.label)


class ImageProcessor:
    def __init__(self, dataset, labels, original_shape=(28, 28, 1), device='cpu'):
        self.dataset = torch.tensor(dataset, dtype=torch.float32, device=device) if isinstance(dataset, np.ndarray) else dataset
        self.original_shape = original_shape
        self.device = device
        self.original_dataset = self.dataset.clone()  
        self.original_labels = torch.tensor(labels,  dtype=torch.float32, device=device) if isinstance(labels, np.ndarray) else labels

    def flatten(self):
        self.dataset = self.dataset.reshape(self.dataset.size(0), -1)
        return self

    def reshape(self, shape):
        self.dataset = self.dataset.view(shape)
        return self
    
    def restore_shape(self):
        self.dataset = self.dataset.view(-1, *self.original_shape)
        return self

    def to_tensor(self, input=None):
        if input is None:
            self.dataset = torch.tensor(self.dataset, dtype=torch.float32, device=self.device)
            self.original_labels = torch.tensor(self.original_labels, dtype=torch.float32, device=self.device)
            return self
        else:
            return torch.tensor(input, dtype=torch.float32, device=self.device)

    @staticmethod
    def to_numpy(input):
        return input.cpu().numpy()

    @staticmethod
    def apply_arctanh(input):
        print("Applying arctanh transformation.")
        return torch.arctanh(1.9999999 * input)

    @staticmethod
    def apply_tanh(input):
        print("Applying tanh transformation.")
        return torch.tanh(0.5 * input)

    def apply_tanh_arctanh_noise(self, noise):
        if isinstance(noise, np.ndarray):
            noise = torch.tensor(noise, dtype=torch.float32, device=self.device)
        
        noise = noise.unsqueeze(0).expand(self.original_dataset.size(0), -1, -1, -1)
        return 0.5 * torch.tanh(torch.arctanh(1.999999 * self.original_dataset) + noise)

    def compute_l2_distortion(self, noise):
        noisy_transformed = self.apply_tanh_arctanh_noise(noise)
        distortion = torch.norm(noisy_transformed - self.original_dataset, p=2, dim=(1, 2, 3))
        return distortion.tolist()

    def create_data_loader(self, batch_size=32, shuffle=True):
        self.data_set = CustomDataset(self.dataset, self.original_labels)
        data_loader = DataLoader(self.data_set, batch_size=batch_size, shuffle=shuffle)
        return data_loader

    def transform(self, flatten=False, reshape_shape=None, apply_arctanh=False, apply_tanh=False, to_numpy=False, to_tensor=False):
        output = self.dataset
        if reshape_shape:
            output = self.reshape(reshape_shape).dataset
        if flatten:
            output = self.flatten().dataset
        if apply_arctanh:
            output = self.apply_arctanh(output)
        if apply_tanh:
            output = self.apply_tanh(output)
        if to_tensor:
            output = self.to_tensor(output)
        if to_numpy:
            output = self.to_numpy(output)
        
        return output


def generate_cifer_data(data, model, target_label, num_sample, device, random_seed=None, batch_size = 64):

    if random_seed is not None:
        np.random.seed(random_seed)
        random.seed(random_seed)
        torch.manual_seed(random_seed)
    data_tensor = data.test_data 
    model_input = (( data_tensor  + 0.5) - data.mean)/data.std
    
    pred_labels = []
    
    with torch.no_grad():
        for i in range(0, len(model_input), batch_size):
            batch_input = model_input[i:i + batch_size]
            pred_logits = model(batch_input)
            batch_pred_labels = torch.argmax(pred_logits, dim=1)
            pred_labels.append(batch_pred_labels)

    pred_labels = torch.cat(pred_labels, dim=0)

    true_labels = data.test_labels  

    print('the pred label (for first 10 samples)= ', pred_labels[0:10])
    print('the true label (for first 10 samples)= ', true_labels[0:10])

    correct_mask = (pred_labels == true_labels)
    
    if target_label == 'All':
        correct_data_indices = torch.nonzero(correct_mask).squeeze().tolist()
        total_target_data = data_tensor.shape[0]
    else:
        target_mask = (true_labels == target_label)
        correct_mask = correct_mask & target_mask
        correct_data_indices = torch.nonzero(correct_mask).squeeze().tolist()
        total_target_data = target_mask.sum().item()

    if len(correct_data_indices) < num_sample:
        selected_indices = correct_data_indices
    else:
        selected_indices = random.sample(correct_data_indices, num_sample)

    selected_indices = torch.tensor(selected_indices, device=device)
    selected_data = data_tensor[selected_indices]
    selected_labels = true_labels[selected_indices]


    selected_labels_one_hot = torch.nn.functional.one_hot(selected_labels, num_classes=len(data.label_name))

    print(
        f"Generated CIFAR dataset with target label = {target_label},  total correct predictions = {len(correct_data_indices)}/{total_target_data}, prediction accuracy = {len(correct_data_indices) / total_target_data:.4f}."
    )
    return selected_data, selected_labels_one_hot, selected_indices

# fix random seed
random_seed = 42
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

attack_target_label = 'All'
device = 'mps'
data_all, pred_model =  CIFAR( device ), ResNet18_model()
pred_model.load_state_dict(torch.load('./models/cifar10-resnet18.pth', map_location=torch.device(  device ),  weights_only=True ))  #, map_location=torch.device( cfg.device ) , weights_only=True))
pred_model.to(device)

origImgs, origLabels, origImgID = generate_cifer_data(data_all, pred_model, target_label = attack_target_label, num_sample = 100000, device = device)
img_shape = (3, 32, 32)
processor = ImageProcessor(origImgs, origLabels, original_shape = img_shape, device =  device )

flattened_data = processor.transform( flatten=True, apply_arctanh=True)

labels = processor.original_labels
print(f"the flattened data with shape {flattened_data.shape} and the dtype {flattened_data.dtype}" )


Data shape: (50000, 32, 32, 3)
Labels length: 50000
the pred label (for first 10 samples)=  tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3], device='mps:0')
the true label (for first 10 samples)=  tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3], device='mps:0')
Generated CIFAR dataset with target label = All,  total correct predictions = 45628/50000, prediction accuracy = 0.9126.
Applying arctanh transformation.
the flattened data with shape torch.Size([45628, 3072]) and the dtype torch.float32
