In [None]:
import os
# Project root relative to the kernel's working directory; used below for
# sys.path, the saved artificial experiments and the figure output folder.
root = os.path.join(os.pardir, os.pardir)

In [None]:
import sys
# Make the project's `utils` package importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
if root not in sys.path:
    sys.path.append(root)

In [None]:

import dataclasses
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Literal, Optional, Tuple

import matplotlib.pyplot as plt
import patchworklib as pw
import torch
from torch import Tensor

from utils.classifiers import OneHiddenNet
from utils.decision_map import (get_axis_vec, get_decision_map,
                                get_inputs_for_decision_map)
from utils.fig import Axes, Figure
from utils.utils import freeze, gpu

# Global Settings & Variables

In [None]:
# Project-wide plotting defaults (helpers from utils.fig).
Figure.set_tex()
Figure.set_high_dpi()

# Every classifier and decision-map input grid is moved to this device.
device = gpu(0)
# Forwarded to get_inputs_for_decision_map; `resolution` also sizes the map tensors.
resolution, limit = 500, 3.4
# Row labels of each block, in DataUtil's classifier order.
ylabels = ('Standard', 'Adversarial', 'Noise')

In [None]:
# Global font scale used while rendering the four-block figures (all_blocks).
# NOTE(review): this is kernel-global state; the two-block section later raises
# it to 1.9, so re-running the four-block cells afterwards changes their look.
Figure.set_font_scale(1.4)

# Data Utils

In [None]:
@dataclasses.dataclass
class DataUtil:
    """One saved artificial-data experiment and the decision maps it yields.

    The experiment is read from ``<root>/artificial/<fname>``, where ``fname``
    joins every field below with underscores in declaration order, so each
    field must match the value used when the file was written.
    """

    in_dim: int
    hidden_dim: int
    n_sample: int
    n_noise_sample: int
    norm: Literal['L0', 'L2', 'Linf']
    mode: Literal['uniform', 'gauss']
    perturbation_constraint: float
    seed: int

    def __post_init__(self) -> None:
        # Load eagerly so a missing experiment file fails at construction time.
        self.d = self._load_data()

    def _load_data(self) -> Dict[str, Any]:
        """Read the saved experiment dict (state dicts and accuracies) onto the CPU."""
        name_parts = (
            self.in_dim, self.hidden_dim, self.n_sample, self.n_noise_sample,
            self.norm, self.mode, self.perturbation_constraint, self.seed,
        )
        fname = '_'.join(map(str, name_parts))
        # torch.load unpickles the file: only use it on trusted, locally generated artifacts.
        return torch.load(os.path.join(root, 'artificial', fname), map_location='cpu')

    def _define_classifier(self) -> OneHiddenNet:
        """Build a frozen, eval-mode OneHiddenNet on `device` sized for this experiment."""
        net = OneHiddenNet(self.in_dim, self.hidden_dim)
        net.to(device)
        freeze(net)
        net.eval()
        return net

    @torch.no_grad()
    def get_decision_maps_and_acc_list(self) -> Tuple[Tensor, Tensor]:
        """Compute the standard, adversarial and noise classifiers' decision maps.

        All three maps are evaluated on the same input grid, whose axes come from
        the standard classifier's ``linear.weight`` entry.

        Returns:
            decision_maps: ``(3, resolution, resolution)`` tensor, one map per classifier.
            acc_list: ``(3,)`` tensor of the matching accuracies stored in the file
                under ``'acc'``, ``'adv_acc_for_natural'`` and ``'noise_acc_for_natural'``.
        """
        axis_vec_1, axis_vec_2 = get_axis_vec(self.d['classifier']['linear.weight'])
        grid = get_inputs_for_decision_map(axis_vec_1, axis_vec_2, resolution, limit).to(device)

        net = self._define_classifier()

        # (state-dict key, accuracy key) per output row.
        rows = (
            ('classifier', 'acc'),
            ('adv_classifier', 'adv_acc_for_natural'),
            ('noise_classifier', 'noise_acc_for_natural'),
        )
        decision_maps = torch.empty(len(rows), resolution, resolution)
        acc_list = torch.empty(len(rows))

        for row, (weight_key, acc_key) in enumerate(rows):
            # One network instance is reused; only its weights are swapped.
            net.load_state_dict(self.d[weight_key])
            decision_maps[row] = get_decision_map(net, grid)
            acc_list[row] = self.d[acc_key]

        return decision_maps, acc_list

# Figure Utils

## Superclass

In [None]:
class Block(ABC):
    """One 3 x C grid of decision maps that shares a single suptitle.

    Concrete subclasses are dataclasses that provide ``suptitle``, ``ylabels``
    and ``variables`` (one entry per column), and implement
    ``_define_artificial_instance`` to turn a column variable into a
    ``DataUtil``. Rows follow the classifier order of
    ``DataUtil.get_decision_maps_and_acc_list`` (standard, adversarial, noise).
    """

    suptitle: str
    ylabels: Optional[Tuple[str, str, str]]
    variables: Any

    def __post_init__(self) -> None:
        # Default: label each column with its raw variable value.
        self.titles = self.variables

    def _embed_decision_map_into_brick(
        self,
        brick: pw.Brick,
        decision_map: Tensor,
        acc: float,
        title: Optional[float] = None,
        ylabel: Optional[str] = None,
    ) -> None:
        """Draw ``decision_map`` into ``brick`` and label it.

        Args:
            brick: Target patchworklib panel.
            decision_map: Map handed to the project ``Axes.imshow`` helper.
            acc: Accuracy, presumably a fraction in [0, 1]; shown on the x-axis
                as a truncated integer percentage.
            title: Optional column title, formatted with thousands separators.
            ylabel: Optional row label.
        """
        ax = Axes(brick)
        ax.imshow(decision_map, True)
        # Remove binary floating-point error before truncating. With a plain
        # int(100 * acc), acc == 0.29 evaluates to 28.999999999999996 -> "28%".
        percent = int(round(100 * acc, 4))
        ax.set_xlabel(f'{percent}' + r'\%')
        if title is not None:
            ax.set_title(f'{title:,}')
        if ylabel is not None:
            ax.set_ylabel(ylabel)

    def _construct_block(
        self,
        decision_map_2dlist: Tensor,
        acc_2dlist: Tensor,
        suptitle: str,
        titles: Tuple[float, ...],
        ylabels: Optional[Tuple[str, str, str]],
    ) -> pw.Bricks:
        """Lay out a ``(rows, cols, H, W)`` map tensor as a patchworklib grid.

        Column titles are drawn on the first row only and row labels on the
        first column only; ``titles`` needs at least one entry per column.
        """
        row_bricks_list: List[pw.Bricks] = []
        for i, (decision_map_list, acc_list) in enumerate(zip(decision_map_2dlist, acc_2dlist)):

            col_brick_list: List[pw.Brick] = []
            for j, (decision_map, acc) in enumerate(zip(decision_map_list, acc_list)):

                brick = pw.Brick()
                col_brick_list.append(brick)

                title = titles[j] if i == 0 else None
                ylabel = ylabels[i] if j == 0 and ylabels is not None else None

                self._embed_decision_map_into_brick(brick, decision_map, acc.item(), title, ylabel)

            row_bricks = pw.stack(col_brick_list, 0.05, '|')
            row_bricks_list.append(row_bricks)

        bricks: pw.Bricks = pw.stack(row_bricks_list, 0.05, '/')
        bricks.set_suptitle(suptitle)
        return bricks

    @abstractmethod
    def _define_artificial_instance(self, var: Any) -> DataUtil:
        """Return the experiment for one column variable."""

    def __call__(self) -> pw.Bricks:
        """Load every column's experiment and assemble the block.

        Raises:
            ValueError: if ``variables`` is empty.
        """
        n_cols = len(self.variables)
        if n_cols == 0:
            raise ValueError('Block requires at least one column variable.')

        # Sized from `variables` instead of a fixed 4: any column count works and
        # no uninitialised (torch.empty) column can leak into the figure.
        decision_map_block = torch.empty(3, n_cols, resolution, resolution)
        acc_block = torch.empty(3, n_cols)

        for i, var in enumerate(self.variables):
            n = self._define_artificial_instance(var)
            decision_maps, acc_list = n.get_decision_maps_and_acc_list()
            decision_map_block[:, i] = decision_maps
            acc_block[:, i] = acc_list

        return self._construct_block(decision_map_block, acc_block, self.suptitle, self.titles, self.ylabels)

## Input Dimension Block

In [None]:
@dataclasses.dataclass
class InputDimensionBlock(Block):
    """Block whose columns vary the input dimension d.

    Column k uses ``in_dims[k]`` together with ``perturbation_constraints[k]``,
    so the constraint may change with d. Columns are titled by d alone.

    Raises:
        ValueError: if ``in_dims`` and ``perturbation_constraints`` differ in length.
    """

    in_dims: Tuple[int, int, int, int]  # one input dimension per column
    hidden_dim: int
    n_sample: int
    n_noise_sample: int
    norm: Literal['L0', 'L2', 'Linf']
    mode: Literal['uniform', 'gauss']
    perturbation_constraints: Tuple[float, float, float, float]  # paired with in_dims
    seed: int
    ylabels: Optional[Tuple[str, str, str]] = None
    suptitle: str = r'Input dimension $d$'

    def __post_init__(self) -> None:
        # zip() would silently drop unmatched columns; fail loudly instead.
        if len(self.in_dims) != len(self.perturbation_constraints):
            raise ValueError(
                f'in_dims ({len(self.in_dims)} entries) and perturbation_constraints '
                f'({len(self.perturbation_constraints)} entries) must have the same length'
            )
        self.variables = list(zip(self.in_dims, self.perturbation_constraints))
        # Titles show only d, not the (d, constraint) pair, so the base default is not used.
        self.titles = self.in_dims

    def _define_artificial_instance(self, var: Tuple[int, float]) -> DataUtil:
        """Experiment for one ``(in_dim, perturbation_constraint)`` column pair."""
        return DataUtil(var[0], self.hidden_dim, self.n_sample, self.n_noise_sample,
                        self.norm, self.mode, var[1], self.seed)

## Natural Sample Block

In [None]:
@dataclasses.dataclass
class NaturalSampleBlock(Block):
    """Block whose columns vary the number of natural samples N."""

    in_dim: int
    hidden_dim: int
    n_samples: Tuple[int, int, int, int]  # one N per column
    n_noise_sample: int
    norm: Literal['L0', 'L2', 'Linf']
    mode: Literal['uniform', 'gauss']
    perturbation_constraint: float
    seed: int
    ylabels: Optional[Tuple[str, str, str]] = None
    suptitle: str = r'Natural sample $N$'

    def __post_init__(self) -> None:
        # Each column is one N; the base class then reuses these values as titles.
        self.variables = self.n_samples
        super().__post_init__()

    def _define_artificial_instance(self, var: int) -> DataUtil:
        """Experiment with this block's settings and ``n_sample=var``."""
        return DataUtil(
            in_dim=self.in_dim,
            hidden_dim=self.hidden_dim,
            n_sample=var,
            n_noise_sample=self.n_noise_sample,
            norm=self.norm,
            mode=self.mode,
            perturbation_constraint=self.perturbation_constraint,
            seed=self.seed,
        )

## Noise Sample Block

In [None]:
@dataclasses.dataclass
class NoiseSampleBlock(Block):
    """Block whose columns vary the number of (adversarial) noise samples."""

    in_dim: int
    hidden_dim: int
    n_sample: int
    n_noise_samples: Tuple[int, int, int, int]  # one noise-sample count per column
    norm: Literal['L0', 'L2', 'Linf']
    mode: Literal['uniform', 'gauss']
    perturbation_constraint: float
    seed: int
    ylabels: Optional[Tuple[str, str, str]] = None
    suptitle: str = r'(Adversarial) noise sample $N^{\mathrm{adv}}$'

    def __post_init__(self) -> None:
        # Each column is one noise-sample count; the base class reuses them as titles.
        self.variables = self.n_noise_samples
        super().__post_init__()

    def _define_artificial_instance(self, var: int) -> DataUtil:
        """Experiment with this block's settings and ``n_noise_sample=var``."""
        return DataUtil(
            in_dim=self.in_dim,
            hidden_dim=self.hidden_dim,
            n_sample=self.n_sample,
            n_noise_sample=var,
            norm=self.norm,
            mode=self.mode,
            perturbation_constraint=self.perturbation_constraint,
            seed=self.seed,
        )

## Perturbation Constraint Block

In [None]:
@dataclasses.dataclass
class PerturbationConstraintBlock(Block):
    """Block whose columns vary the perturbation constraint.

    Unless ``suptitle`` is given, 'L2'/'Linf' blocks are titled as the
    perturbation constraint epsilon. Every other norm (i.e. 'L0') is titled as
    the modified pixel ratio d_delta/d.
    """

    in_dim: int
    hidden_dim: int
    n_sample: int
    # A single noise-sample count, not a tuple, despite the plural name.
    # The name is kept so existing keyword callers keep working.
    n_noise_samples: int
    norm: Literal['L0', 'L2', 'Linf']
    mode: Literal['uniform', 'gauss']
    perturbation_constraints: Tuple[float, float, float, float]  # one per column
    seed: int
    ylabels: Optional[Tuple[str, str, str]] = None
    # None selects a norm-dependent default in __post_init__.
    suptitle: Optional[str] = None

    def __post_init__(self) -> None:
        if self.suptitle is None:
            if self.norm in ('L2', 'Linf'):
                self.suptitle = r'Perturbation constraint $\epsilon$'
            else:
                self.suptitle = r'Modified pixel ratio $d_\delta/d$'
        self.variables = self.perturbation_constraints
        super().__post_init__()

    def _define_artificial_instance(self, var: float) -> DataUtil:
        """Experiment with this block's settings and ``perturbation_constraint=var``."""
        # `var` is a float constraint (e.g. 0.0001), matching DataUtil.perturbation_constraint.
        return DataUtil(self.in_dim, self.hidden_dim, self.n_sample, self.n_noise_samples,
                        self.norm, self.mode, var, self.seed)

# All Blocks

In [None]:
def all_blocks(
    in_dim: int,
    in_dims: Tuple[int, int, int, int],
    perturbation_constraints_along_with_in_dims: Tuple[float, float, float, float],
    hidden_dim: int,
    n_sample: int,
    n_samples: Tuple[int, int, int, int],
    n_noise_sample: int,
    n_noise_samples: Tuple[int, int, int, int],
    norm: Literal['L0', 'L2', 'Linf'],
    mode: Literal['uniform', 'gauss'],
    perturbation_constraint: float,
    perturbation_constraints: Tuple[float, float, float, float],
    seed: int,
) -> None:
    """Render and save the four-block decision-map figure for one (norm, mode).

    Layout: the input-dimension and noise-sample blocks are on top; the
    natural-sample and perturbation-constraint blocks are below. Only the
    left-hand blocks receive the row labels ``ylabels``. The figure is written
    to ``<root>/figs/decision_maps_{norm}_{mode}.pdf``, and the directory is
    created if missing.

    Args:
        in_dim: Input dimension for every block except the input-dimension block.
        in_dims: Column values of the input-dimension block.
        perturbation_constraints_along_with_in_dims: Constraint paired with each ``in_dims`` entry.
        hidden_dim: Hidden width of every classifier.
        n_sample: Natural sample count for every block except the natural-sample block.
        n_samples: Column values of the natural-sample block.
        n_noise_sample: Noise sample count for every block except the noise-sample block.
        n_noise_samples: Column values of the noise-sample block.
        norm: Norm of the saved experiments; also part of the output file name.
        mode: Mode of the saved experiments; also part of the output file name.
        perturbation_constraint: Constraint for the noise- and natural-sample blocks.
        perturbation_constraints: Column values of the perturbation-constraint block.
        seed: Seed identifying the saved experiments.
    """
    try:
        in_dim_block = InputDimensionBlock(in_dims, hidden_dim, n_sample, n_noise_sample,
                                           norm, mode, perturbation_constraints_along_with_in_dims, seed, ylabels)()
        noise_sample_block = NoiseSampleBlock(in_dim, hidden_dim, n_sample, n_noise_samples,
                                              norm, mode, perturbation_constraint, seed)()
        natural_sample_block = NaturalSampleBlock(in_dim, hidden_dim, n_samples, n_noise_sample,
                                                  norm, mode, perturbation_constraint, seed, ylabels)()
        perturbation_constraint_block = PerturbationConstraintBlock(in_dim, hidden_dim, n_sample, n_noise_sample,
                                                                    norm, mode, perturbation_constraints, seed)()

        top = pw.stack([in_dim_block, noise_sample_block], margin=0.2, operator='|')
        bottom = pw.stack([natural_sample_block, perturbation_constraint_block], margin=0.2, operator='|')
        # Named `grid` rather than `all` so the builtin is not shadowed.
        grid = pw.stack([top, bottom], margin=0.2, operator='/')

        out_dir = os.path.join(root, 'figs')
        os.makedirs(out_dir, exist_ok=True)
        grid.savefig(os.path.join(out_dir, f'decision_maps_{norm}_{mode}.pdf'),
                     bbox_inches='tight', pad_inches=0.025)
    finally:
        # Reset patchworklib and close the figure even if loading or saving
        # failed, so leftover state does not leak into the next figure.
        pw.clear()
        plt.close()

# L0 / Uniform

In [None]:
in_dim = 10_000
in_dims = (100, 500, 1_000, 10_000)
perturbation_constraints_along_with_in_dims = (0.05, 0.05, 0.05, 0.05)
hidden_dim = 1_000
n_sample = 1_000
n_samples = (1_000, 2_000, 5_000, 10_000)
n_noise_sample = 10_000
n_noise_samples = (1, 10, 100, 10_000)
norm = 'L0'
mode = 'uniform'
perturbation_constraint = 0.05
perturbation_constraints = (0.0001, 0.0004, 0.001, 0.05)
seed = 5

# Keyword arguments keep these many same-typed settings from being mismatched by position.
all_blocks(
    in_dim=in_dim,
    in_dims=in_dims,
    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,
    hidden_dim=hidden_dim,
    n_sample=n_sample,
    n_samples=n_samples,
    n_noise_sample=n_noise_sample,
    n_noise_samples=n_noise_samples,
    norm=norm,
    mode=mode,
    perturbation_constraint=perturbation_constraint,
    perturbation_constraints=perturbation_constraints,
    seed=seed,
)

# L0 / Gauss

In [None]:
in_dim = 10_000
in_dims = (100, 500, 1_000, 10_000)
perturbation_constraints_along_with_in_dims = (0.05, 0.05, 0.05, 0.05)
hidden_dim = 1_000
n_sample = 1_000
n_samples = (1_000, 2_000, 5_000, 10_000)
n_noise_sample = 10_000
n_noise_samples = (1, 10, 100, 10_000)
norm = 'L0'
mode = 'gauss'
perturbation_constraint = 0.05
perturbation_constraints = (0.0001, 0.0004, 0.001, 0.05)
seed = 2

# Keyword arguments keep these many same-typed settings from being mismatched by position.
all_blocks(
    in_dim=in_dim,
    in_dims=in_dims,
    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,
    hidden_dim=hidden_dim,
    n_sample=n_sample,
    n_samples=n_samples,
    n_noise_sample=n_noise_sample,
    n_noise_samples=n_noise_samples,
    norm=norm,
    mode=mode,
    perturbation_constraint=perturbation_constraint,
    perturbation_constraints=perturbation_constraints,
    seed=seed,
)

# L2 / Uniform

In [None]:
in_dim = 10_000
in_dims = (100, 500, 1_000, 10_000)
perturbation_constraints_along_with_in_dims = (0.078, 0.17, 0.24, 0.78)
hidden_dim = 1_000
n_sample = 1_000
n_samples = (1_000, 2_000, 5_000, 10_000)
n_noise_sample = 10_000
n_noise_samples = (1, 10, 100, 10_000)
norm = 'L2'
mode = 'uniform'
perturbation_constraint = 0.78
perturbation_constraints = (0.01, 0.05, 0.1, 0.78)
seed = 5

# Keyword arguments keep these many same-typed settings from being mismatched by position.
all_blocks(
    in_dim=in_dim,
    in_dims=in_dims,
    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,
    hidden_dim=hidden_dim,
    n_sample=n_sample,
    n_samples=n_samples,
    n_noise_sample=n_noise_sample,
    n_noise_samples=n_noise_samples,
    norm=norm,
    mode=mode,
    perturbation_constraint=perturbation_constraint,
    perturbation_constraints=perturbation_constraints,
    seed=seed,
)

# L2 / Gauss

In [None]:
in_dim = 10_000
in_dims = (100, 500, 1_000, 10_000)
perturbation_constraints_along_with_in_dims = (0.078, 0.17, 0.24, 0.78)
hidden_dim = 1_000
n_sample = 1_000
n_samples = (1_000, 2_000, 5_000, 10_000)
n_noise_sample = 10_000
n_noise_samples = (1, 10, 100, 10_000)
norm = 'L2'
mode = 'gauss'
perturbation_constraint = 0.78
perturbation_constraints = (0.01, 0.05, 0.1, 0.78)
seed = 2

# Keyword arguments keep these many same-typed settings from being mismatched by position.
all_blocks(
    in_dim=in_dim,
    in_dims=in_dims,
    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,
    hidden_dim=hidden_dim,
    n_sample=n_sample,
    n_samples=n_samples,
    n_noise_sample=n_noise_sample,
    n_noise_samples=n_noise_samples,
    norm=norm,
    mode=mode,
    perturbation_constraint=perturbation_constraint,
    perturbation_constraints=perturbation_constraints,
    seed=seed,
)

# Linf / Uniform

In [None]:
in_dim = 10_000
in_dims = (100, 500, 1_000, 10_000)
perturbation_constraints_along_with_in_dims = (0.03, 0.03, 0.03, 0.03)
hidden_dim = 1_000
n_sample = 1_000
n_samples = (1_000, 2_000, 5_000, 10_000)
n_noise_sample = 10_000
n_noise_samples = (1, 10, 100, 10_000)
norm = 'Linf'
mode = 'uniform'
perturbation_constraint = 0.03
perturbation_constraints = (0.001, 0.005, 0.01, 0.03)
seed = 5

# Keyword arguments keep these many same-typed settings from being mismatched by position.
all_blocks(
    in_dim=in_dim,
    in_dims=in_dims,
    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,
    hidden_dim=hidden_dim,
    n_sample=n_sample,
    n_samples=n_samples,
    n_noise_sample=n_noise_sample,
    n_noise_samples=n_noise_samples,
    norm=norm,
    mode=mode,
    perturbation_constraint=perturbation_constraint,
    perturbation_constraints=perturbation_constraints,
    seed=seed,
)

# Linf / Gauss

In [None]:
in_dim = 10_000
in_dims = (100, 500, 1_000, 10_000)
perturbation_constraints_along_with_in_dims = (0.03, 0.03, 0.03, 0.03)
hidden_dim = 1_000
n_sample = 1_000
n_samples = (1_000, 2_000, 5_000, 10_000)
n_noise_sample = 10_000
n_noise_samples = (1, 10, 100, 10_000)
norm = 'Linf'
mode = 'gauss'
perturbation_constraint = 0.03
perturbation_constraints = (0.001, 0.005, 0.01, 0.03)
seed = 2

# Keyword arguments keep these many same-typed settings from being mismatched by position.
all_blocks(
    in_dim=in_dim,
    in_dims=in_dims,
    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,
    hidden_dim=hidden_dim,
    n_sample=n_sample,
    n_samples=n_samples,
    n_noise_sample=n_noise_sample,
    n_noise_samples=n_noise_samples,
    norm=norm,
    mode=mode,
    perturbation_constraint=perturbation_constraint,
    perturbation_constraints=perturbation_constraints,
    seed=seed,
)

# Two Blocks

In [None]:
# Larger global font scale for the two-block figure, which holds two blocks
# instead of four (presumably so text stays legible at the same page size).
# NOTE(review): kernel-global; affects any four-block cell re-run after this.
Figure.set_font_scale(1.9)

In [None]:
def two_blocks(
    in_dim: int,
    in_dims: Tuple[int, int, int, int],
    perturbation_constraints_along_with_in_dims: Tuple[float, float, float, float],
    hidden_dim: int,
    n_sample: int,
    n_noise_sample: int,
    n_noise_samples: Tuple[int, int, int, int],
    norm: Literal['L0', 'L2', 'Linf'],
    mode: Literal['uniform', 'gauss'],
    perturbation_constraint: float,
    seed: int,
) -> None:
    """Render and save the input-dimension | noise-sample figure for one (norm, mode).

    Only the left (input-dimension) block receives the row labels ``ylabels``.
    The figure is written to ``<root>/figs/decision_maps_{norm}_{mode}_two.pdf``,
    and the directory is created if missing.

    Args:
        in_dim: Input dimension for the noise-sample block.
        in_dims: Column values of the input-dimension block.
        perturbation_constraints_along_with_in_dims: Constraint paired with each ``in_dims`` entry.
        hidden_dim: Hidden width of every classifier.
        n_sample: Natural sample count for both blocks.
        n_noise_sample: Noise sample count for the input-dimension block.
        n_noise_samples: Column values of the noise-sample block.
        norm: Norm of the saved experiments; also part of the output file name.
        mode: Mode of the saved experiments; also part of the output file name.
        perturbation_constraint: Constraint for the noise-sample block.
        seed: Seed identifying the saved experiments.
    """
    try:
        in_dim_block = InputDimensionBlock(in_dims, hidden_dim, n_sample, n_noise_sample,
                                           norm, mode, perturbation_constraints_along_with_in_dims, seed, ylabels)()
        noise_sample_block = NoiseSampleBlock(in_dim, hidden_dim, n_sample, n_noise_samples,
                                              norm, mode, perturbation_constraint, seed)()

        grid = pw.stack([in_dim_block, noise_sample_block], margin=0.2, operator='|')

        out_dir = os.path.join(root, 'figs')
        os.makedirs(out_dir, exist_ok=True)
        grid.savefig(os.path.join(out_dir, f'decision_maps_{norm}_{mode}_two.pdf'),
                     bbox_inches='tight', pad_inches=0.025)
    finally:
        # Reset patchworklib and close the figure even if loading or saving
        # failed, so leftover state does not leak into the next figure.
        pw.clear()
        Figure.close()

In [None]:
in_dim = 10_000
in_dims = (100, 500, 1_000, 10_000)
perturbation_constraints_along_with_in_dims = (0.078, 0.17, 0.24, 0.78)
hidden_dim = 1_000
n_sample = 1_000
n_samples = (1_000, 2_000, 5_000, 10_000)  # not used by two_blocks
n_noise_sample = 10_000
n_noise_samples = (1, 10, 100, 10_000)
norm = 'L2'
mode = 'uniform'
perturbation_constraint = 0.78
perturbation_constraints = (0.01, 0.05, 0.1, 0.78)  # not used by two_blocks
seed = 5

# Keyword arguments keep these many same-typed settings from being mismatched by position.
two_blocks(
    in_dim=in_dim,
    in_dims=in_dims,
    perturbation_constraints_along_with_in_dims=perturbation_constraints_along_with_in_dims,
    hidden_dim=hidden_dim,
    n_sample=n_sample,
    n_noise_sample=n_noise_sample,
    n_noise_samples=n_noise_samples,
    norm=norm,
    mode=mode,
    perturbation_constraint=perturbation_constraint,
    seed=seed,
)