# Imports

In [None]:
%pdb

from pathlib import Path
from itertools import chain
from collections import OrderedDict, defaultdict
from io import StringIO
from math import inf
import json

from PIL import Image

from tqdm.notebook import tqdm

import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch.nn.functional as F
import torchvision.transforms.functional as TF

import cv2
import numpy as np

from brisque import BRISQUE
import kornia

from lib.models import build_model
from lib.attacks import BiasFieldAttack

# Constants

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ATTACK_RANDOM_SEED is the default seed attack() uses for intra-group sampling;
# TEST_RANDOM_SEED is not referenced in this section.
ATTACK_RANDOM_SEED, TEST_RANDOM_SEED = 2020, 1997

# Placeholder path: point this at the directory that holds the dataset folders.
DATA_ROOT = Path('The dataset directory')

COSOD_DATASET_NAMES = ['Cosal2015', 'iCoseg', 'CoSOD3k', 'CoCA']

# Alias (same list object); SOD dataset names could be appended here later.
DATASET_NAMES = COSOD_DATASET_NAMES

RESIZE_TARGET = "224"

# Folder names of the resized copies, e.g. 'Cosal2015_224'.
RESIZED_COSOD_DATASET_NAMES = ['_'.join((name, RESIZE_TARGET)) for name in COSOD_DATASET_NAMES]

# Unused: NOISE_LEVELS = [4, 8, 16, 32, 64]

# File extensions (without the dot) that are globbed for images.
IMAGE_EXTENSIONS = ['jpg', 'png', 'bmp']
# In 'group' co-attack mode a batch holds the source image plus up to
# GROUP_SIZE - 1 images drawn from the same folder.
GROUP_SIZE = 5

# Sub-folder names inside every dataset directory.
IMAGE_FOLDER, DEFENSE_PREFIX, GT_FOLDER, RESULT_FOLDER = 'img', 'defense', 'gt', 'result'

# Not referenced in this section; BETA_SQUARE presumably is the F-measure
# beta^2 used by evaluation code elsewhere -- confirm.
BETA_SQUARE, EPSILON = .3, 1e-5

## Attack Configs

In [None]:
def sstd(f):
    """Return the mean spatial standard deviation of tensor *f*.

    The last two dimensions are merged into one axis (assumed to be H, W of a
    feature map -- confirm). The (unbiased) standard deviation is taken along
    that axis and then averaged over every remaining entry.
    """
    spatial = f.flatten(start_dim=-2)
    per_map_std = spatial.std(dim=-1)
    return per_map_std.mean()

def nsstd(f):
    """Return the mean standard deviation of each dim-1 slice of *f*.

    Dims 0 and 1 are swapped, and everything after the new leading dim is
    flattened. The std is therefore pooled over dim 0 and all trailing dims
    (presumably batch and spatial positions, i.e. per channel across the whole
    batch -- confirm). The per-slice stds are then averaged.
    """
    channels_first = f.transpose(0, 1)
    pooled = channels_first.flatten(start_dim=1)
    return pooled.std(dim=-1).mean()

# Attack objectives, looked up by name in the attack configs below.
# The two-argument entries measure the distance between the prediction `r` and
# the inverted ground truth `1 - gt`. The one-argument entries average a feature
# statistic over the iterable `r` (presumably the list of feature maps produced
# by the surrogate model -- confirm).
CRITERIONS = dict(
    # Supervised, first sample of the batch only.
    l1=lambda r, gt: F.l1_loss(r[0], 1 - gt[0]),
    l2=lambda r, gt: F.mse_loss(r[0], 1 - gt[0]),
    bce=lambda r, gt: F.binary_cross_entropy(r[0], 1 - gt[0]),
    # Supervised, whole batch.
    l1_all=lambda r, gt: F.l1_loss(r, 1 - gt),
    l2_all=lambda r, gt: F.mse_loss(r, 1 - gt),
    bce_all=lambda r, gt: F.binary_cross_entropy(r, 1 - gt),
    # Unsupervised: mean of sstd / nsstd over the entries of r.
    sstd=lambda r: sum(sstd(f) for f in r).div_(len(r)),
    nsstd=lambda r: sum(nsstd(f) for f in r).div_(len(r)),
)

# Surrogate models that perturbations are optimized against. Each key is a short
# tag embedded in the attack names built from it. Each value is the config that
# attack() later hands to build_model().
ATTACK_MODEL_CONFIGS = OrderedDict([
    ('resneta50b_123', dict(
        name='Cls',
        cls_name='resneta50b',  # other backbones mentioned here: resneta101b, hrnetv2_w48
        stages=(1, 2, 3),
    )),
])
# Disabled alternatives; uncomment an entry to enable it.
# ATTACK_MODEL_CONFIGS['resneta50b_1234'] = dict(name='Cls', cls_name='resneta50b', stages=(1, 2, 3, 4))
# ATTACK_MODEL_CONFIGS['resneta50b_12345'] = dict(name='Cls', cls_name='resneta50b', stages=(1, 2, 3, 4, 5))
# ATTACK_MODEL_CONFIGS['resneta50b_2'] = dict(name='Cls', cls_name='resneta50b', stages=(2,))
# ATTACK_MODEL_CONFIGS['inceptionv3_123'] = dict(name='Cls', cls_name='inceptionv3', stages=(1, 2, 3))
# ATTACK_MODEL_CONFIGS['gicd'] = dict(name='GICD', mode='cosal', detach_cls=False,
#                                     weights_path='weights/GICD-GINet.pth')
# ATTACK_MODEL_CONFIGS['gicdd'] = dict(name='GICD', mode='cosal', detach_cls=True,
#                                      weights_path='weights/GICD-GINet.pth')
# ATTACK_MODEL_CONFIGS['poolnet'] = dict(name='PoolNet', backbone='resnet', joint=True, mode='sal',
#                                        weights_path='weights/PoolNet-ResNet50-w-edge.pth')
# ATTACK_MODEL_CONFIGS['gcagc_cosal'] = dict(name='GCAGC', backbone='hrnet', mode='cosal',
#                                            weights_path='weights/GCAGC-HRNet.pth')

# Registries mapping attack name -> attack config. They are populated by the
# config cells that follow and consumed by attack().
COSOD_ATTACK_CONFIGS = OrderedDict()
SOD_ATTACK_CONFIGS = OrderedDict()

### Resnet50 Black-Box

In [None]:
model = 'resneta50b_123'

# Jadena, single-image variant: the 'sstd' criterion, one image per attack step.
COSOD_ATTACK_CONFIGS[f'{model}_sstd_bf10o_mifgsm_single_neps16_step20_blr1e-1_b5e-1_s1e-2'] = dict(
    model=ATTACK_MODEL_CONFIGS[model],
    attack=dict(
        criterion=CRITERIONS['sstd'],
        step=20,
        noise_mode='add',
        bias_mode='same',
        spatial_mode='optical_flow',
        noise_lr=1. / 255.,
        bias_lr=1e-1,
        spatial_lr=1e-2,
        lambda_b=5e-1,
        lambda_s=1e-2,
        momentum_decay=1.0,
        epsilon_n=16. / 255.,
        degree=10,
    ),
)

# Jadena, co-attack variants: the 'nsstd' criterion (std pooled over the batch)
# applied to a batch of augmented copies ('augment') or of same-group images
# ('group'). Only lambda_b and the criterion differ from the single variant.
for co_mode in ('augment', 'group'):
    COSOD_ATTACK_CONFIGS[f'{model}_nsstd_bf10o_mifgsm_{co_mode}_neps16_step20_blr1e-1_b1e-2_s1e-2'] = dict(
        model=ATTACK_MODEL_CONFIGS[model],
        attack=dict(
            criterion=CRITERIONS['nsstd'],
            step=20,
            noise_mode='add',
            bias_mode='same',
            spatial_mode='optical_flow',
            noise_lr=1. / 255.,
            bias_lr=1e-1,
            spatial_lr=1e-2,
            lambda_b=1e-2,
            lambda_s=1e-2,
            momentum_decay=1.0,
            epsilon_n=16. / 255.,
            degree=10,
        ),
        co_mode=co_mode,
    )
    

In [None]:
# List the registered attack names, one per line (one block per registry).
print('\n'.join(map(str, COSOD_ATTACK_CONFIGS)))
print('\n'.join(map(str, SOD_ATTACK_CONFIGS)))

# Attack

In [None]:
def augment(im):
    """Return *im* followed by four flipped/rotated views of it.

    Order: the original tensor (same object), a flip along the last dim, a flip
    along the second-to-last dim, a 90-degree rot90 over dims (-2, -1), and the
    opposite rot90 over dims (-1, -2). All operations act on the last two dims,
    so leading dims (e.g. channels) are untouched.
    """
    flipped_last = im.flip(-1)
    flipped_second_last = im.flip(-2)
    rotated = im.rot90(dims=(-2, -1))
    rotated_back = im.rot90(dims=(-1, -2))
    return [im, flipped_last, flipped_second_last, rotated, rotated_back]

def attack(dataset_names, attack_configs, force=False):
    for attack_name, attack_config in tqdm(attack_configs.items(), desc='attack'):
        model = build_model(attack_config['model']).eval().to(DEVICE)
        attack = BiasFieldAttack(model, **attack_config['attack'])
        need_gt = attack_config.get('need_gt', False)
        co_mode = attack_config.get('co_mode', 'none')
        random_seed = attack_config.get('random_seed', ATTACK_RANDOM_SEED)
        dataset_keyword = attack_config.get('dataset_keyword', None)

        for dataset_name in tqdm(dataset_names, desc='dataset', leave=False):
            if not (dataset_keyword is None or dataset_keyword in dataset_name):
                continue
            source_root = DATA_ROOT / dataset_name / IMAGE_FOLDER
            target_root = DATA_ROOT / dataset_name / f'{IMAGE_FOLDER}_{attack_name}'
            gt_root = DATA_ROOT / dataset_name / GT_FOLDER

            source_paths = chain.from_iterable(source_root.rglob(f'*.{e}') for e in IMAGE_EXTENSIONS)
            source_paths = sorted(source_paths)

            if co_mode != 'none':
                torch.manual_seed(random_seed) # force the same intra-group image selection

            for source_path in tqdm(source_paths, leave=False):
                relative_path = source_path.relative_to(source_root)
                target_path = target_root.joinpath(relative_path).with_suffix('.png')

                source_paths = [source_path]
                if co_mode == 'group':
                    # consume random whether target_path exists or not
                    group_folder = source_path.parent
                    intragroup_paths = chain.from_iterable(group_folder.glob(f'*.{e}') for e in IMAGE_EXTENSIONS)
                    intragroup_paths = sorted(intragroup_paths)
                    random_index = torch.randperm(len(intragroup_paths))[:GROUP_SIZE - 1].tolist()
                    source_paths += [intragroup_paths[i] for i in random_index]

                if not force and target_path.exists():
                    continue

                ims = map(lambda p: Image.open(p).convert('RGB'), source_paths)
                ims = list(map(TF.to_tensor, ims))

                if co_mode == 'augment':
                    ims = augment(ims[0])

                ims = torch.stack(ims, dim=0).to(DEVICE)

                if need_gt:
                    relative_paths = [p.relative_to(source_root) for p in source_paths]
                    gt_paths = [gt_root.joinpath(p).with_suffix('.png') for p in relative_paths]
                    gts = map(lambda p: Image.open(p).convert('1'), source_paths)
                    gts = map(TF.to_tensor, gts)
                    if co_mode == 'augment':
                        gts = augment(gts[0])
                    gts = torch.stack(list(gts), dim=0).to(DEVICE)
                    pert, ex = attack(ims, gts)
                else:
                    pert, ex = attack(ims)
                pert = pert[0]

                pert = TF.to_pil_image(pert.cpu())
                target_path.parent.mkdir(exist_ok=True, parents=True)
                pert.save(target_path)

# Attack every resized CoSOD dataset with every registered CoSOD attack,
# overwriting perturbations that already exist.
attack(RESIZED_COSOD_DATASET_NAMES, COSOD_ATTACK_CONFIGS, force=True)