In [None]:
!pip install -q grad-cam==1.4.3
!pip install -q wandb
!pip install -q segmentation_models_pytorch
!pip install -q torchattacks
!pip install -q monai
!pip install -q torchsummary

# from kaggle_datasets import KaggleDatasets

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
import albumentations as A

import segmentation_models_pytorch as smp
# import torchsummary

import pandas as pd
import numpy as np
import random, shutil, time, os

import sklearn
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import albumentations as A

from glob import glob
from tqdm.notebook import tqdm
from sklearn.model_selection import KFold, GroupKFold
from sklearn.metrics import roc_auc_score
# from skimage import color
from IPython import display as ipd

import scipy
import pdb
import gc

import torchattacks
import monai

from pytorch_grad_cam import GradCAM


from torch.cuda import amp

import warnings
warnings.filterwarnings('ignore')

print('done')

In [None]:
# Run-wide settings: learning rate and the (square) size every slice is resized to.
CFG = dict(lr=3e-4, shape=(224, 224))
TRAIN = True

# First CUDA device when one is visible, otherwise CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def seed_everything(seed=44):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + every CUDA device)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def clear_cache():
    """Free unreachable Python objects, then hand cached CUDA memory back to the driver."""
    # Collect first: tensors released by the garbage collector can only be returned
    # to the driver by an empty_cache() call that runs *after* they are freed.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
DATA_DIR = '../input/uw-madison-gi-tract-image-segmentation'

# Training scans folder and the RLE annotation table; load the raw table.
TRAIN_DIR, TRAIN_CSV = (os.path.join(DATA_DIR, leaf) for leaf in ("train", "train.csv"))
train_df = pd.read_csv(TRAIN_CSV)

# Every PNG anywhere under the training folder.
all_train_images = glob(os.path.join(TRAIN_DIR, "**", "*.png"), recursive=True)


def get_filepath_from_partial_identifier(_ident, file_list):
    """Return the first entry of ``file_list`` that contains ``_ident`` as a substring.

    Raises IndexError when no entry matches.
    """
    matches = list(filter(lambda path: _ident in path, file_list))
    return matches[0]

def df_preprocessing(df, globbed_file_list, is_test=False):
    """ The preprocessing steps applied to get column information.

    Args:
        df: frame with an ``id`` column of the form ``case<N>_day<M>_slice_<K>``;
            when ``is_test`` is False it must also have ``class`` and
            ``segmentation``. The caller's frame is NOT modified.
        globbed_file_list: scan PNG paths. The last four ``_``-separated fields of
            each file name (before ``.png``) are parsed into ``slice_h``,
            ``slice_w``, ``px_spacing_h`` and ``px_spacing_w``.
        is_test: when False, the three per-class rows of each id are merged into
            one row with ``*_seg_rle`` / ``*_seg_flag`` columns and ``n_segs``.

    Returns:
        A new frame with the columns listed in step 8 that exist. Rows whose id has
        no matching file in ``globbed_file_list`` are dropped (inner merge).

    Raises:
        ValueError: if ``globbed_file_list`` is empty.
    """
    if len(globbed_file_list) == 0:
        raise ValueError("globbed_file_list is empty; cannot resolve scan file paths")

    # Work on a copy so the caller's frame does not gain the helper columns below.
    df = df.copy()

    # 1. Get Case-ID as a column (str and int)
    df["case_id_str"] = df["id"].apply(lambda x: x.split("_", 2)[0])
    df["case_id"] = df["id"].apply(lambda x: int(x.split("_", 2)[0].replace("case", "")))

    # 2. Get Day as a column
    df["day_num_str"] = df["id"].apply(lambda x: x.split("_", 2)[1])
    df["day_num"] = df["id"].apply(lambda x: int(x.split("_", 2)[1].replace("day", "")))

    # 3. Get Slice Identifier as a column
    df["slice_id"] = df["id"].apply(lambda x: x.split("_", 2)[2])

    # 4. Get full file paths for the representative scans: build the path prefix
    #    (everything before the size/spacing suffix) and join it to the globbed files.
    df["_partial_ident"] = (globbed_file_list[0].rsplit("/", 4)[0]+"/"+ # /kaggle/input/uw-madison-gi-tract-image-segmentation/train/
                           df["case_id_str"]+"/"+ # .../case###/
                           df["case_id_str"]+"_"+df["day_num_str"]+ # .../case###_day##/
                           "/scans/"+df["slice_id"]) # .../slice_####
    _tmp_merge_df = pd.DataFrame({"_partial_ident":[x.rsplit("_",4)[0] for x in globbed_file_list], "f_path":globbed_file_list})
    df = df.merge(_tmp_merge_df, on="_partial_ident").drop(columns=["_partial_ident"])

    # 5. Get slice dimensions from filepath (int in pixels)
    df["slice_h"] = df["f_path"].apply(lambda x: int(x[:-4].rsplit("_",4)[1]))
    df["slice_w"] = df["f_path"].apply(lambda x: int(x[:-4].rsplit("_",4)[2]))

    # 6. Pixel spacing from filepath (float in mm)
    df["px_spacing_h"] = df["f_path"].apply(lambda x: float(x[:-4].rsplit("_",4)[3]))
    df["px_spacing_w"] = df["f_path"].apply(lambda x: float(x[:-4].rsplit("_",4)[4]))

    if not is_test:
        # 7. Merge 3 Rows Into A Single Row (As This/Segmentation-RLE Is The Only Unique Information Across Those Rows)
        l_bowel_df = df[df["class"]=="large_bowel"][["id", "segmentation"]].rename(columns={"segmentation":"lb_seg_rle"})
        s_bowel_df = df[df["class"]=="small_bowel"][["id", "segmentation"]].rename(columns={"segmentation":"sb_seg_rle"})
        stomach_df = df[df["class"]=="stomach"][["id", "segmentation"]].rename(columns={"segmentation":"st_seg_rle"})
        df = df.merge(l_bowel_df, on="id", how="left")
        df = df.merge(s_bowel_df, on="id", how="left")
        df = df.merge(stomach_df, on="id", how="left")
        df = df.drop_duplicates(subset=["id",]).reset_index(drop=True)
        df["lb_seg_flag"] = df["lb_seg_rle"].apply(lambda x: not pd.isna(x))
        df["sb_seg_flag"] = df["sb_seg_rle"].apply(lambda x: not pd.isna(x))
        df["st_seg_flag"] = df["st_seg_rle"].apply(lambda x: not pd.isna(x))
        df["n_segs"] = df["lb_seg_flag"].astype(int)+df["sb_seg_flag"].astype(int)+df["st_seg_flag"].astype(int)

    # 8. Reorder columns to the new ordering (drops class and segmentation as no longer necessary)
    new_col_order = ["id", "f_path", "n_segs",
                     "lb_seg_rle", "lb_seg_flag",
                     "sb_seg_rle", "sb_seg_flag",
                     "st_seg_rle", "st_seg_flag",
                     "slice_h", "slice_w", "px_spacing_h",
                     "px_spacing_w", "case_id_str", "case_id",
                     "day_num_str", "day_num", "slice_id",]
    if is_test: new_col_order.insert(1, "class")
    new_col_order = [_c for _c in new_col_order if _c in df.columns]
    df = df[new_col_order]

    return df


# all_test_images = glob(os.path.join(TEST_DIR, "**", "*.png"), recursive=True)

# One-row-per-slice frame from df_preprocessing.
# NOTE(review): `train_df` is not used further down; the notebook works with `df` below.
train_df = df_preprocessing(train_df, all_train_images)

# Slice-level frame with the three class RLE strings gathered into one list per id.
df = pd.read_csv(TRAIN_CSV)
df['segmentation'] = df.segmentation.fillna('')
df['rle_len'] = df.segmentation.map(len)  # length of each RLE string (0 = empty mask)

df2 = df.groupby(['id'])['segmentation'].agg(list).to_frame().reset_index()  # RLE list per id, in CSV row order
df2 = df2.merge(df.groupby(['id'])['rle_len'].agg('sum').to_frame().reset_index())  # total RLE length per id
df = df.drop(columns=['segmentation', 'class', 'rle_len'])
df = df.groupby(['id']).head(1).reset_index(drop=True)
df = df.merge(df2, on=['id'])
df['empty'] = (df.rle_len == 0)  # no class annotated on this slice

# 1. Get Case-ID as a column (str and int)
df["case_id_str"] = df["id"].apply(lambda x: x.split("_", 2)[0])
df["case_id"] = df["id"].apply(lambda x: int(x.split("_", 2)[0].replace("case", "")))

# 2. Get Day as a column
df["day_num_str"] = df["id"].apply(lambda x: x.split("_", 2)[1])
df["day_num"] = df["id"].apply(lambda x: int(x.split("_", 2)[1].replace("day", "")))

# 3. Get Slice Identifier as a column
df["slice_id"] = df["id"].apply(lambda x: x.split("_", 2)[2])

# 4. Get full file paths for the representative scans
df["_partial_ident"] = (all_train_images[0].rsplit("/", 4)[0]+"/"+ # /kaggle/input/uw-madison-gi-tract-image-segmentation/train/
                       df["case_id_str"]+"/"+ # .../case###/
                       df["case_id_str"]+"_"+df["day_num_str"]+ # .../case###_day##/
                       "/scans/"+df["slice_id"]) # .../slice_####
_tmp_merge_df = pd.DataFrame({"_partial_ident":[x.rsplit("_",4)[0] for x in all_train_images], "f_path":all_train_images})
df = df.merge(_tmp_merge_df, on="_partial_ident").drop(columns=["_partial_ident"])

# 5. Get slice dimensions from filepath (int in pixels)
df["slice_h"] = df["f_path"].apply(lambda x: int(x[:-4].rsplit("_",4)[1]))
df["slice_w"] = df["f_path"].apply(lambda x: int(x[:-4].rsplit("_",4)[2]))

df = df.rename(columns={
    'f_path': 'path',
    'slice_h': 'img_height',
    'slice_w': 'img_width',
    'case_id': 'case',
    'day_num': 'day',
})

df['slice'] = df['slice_id'].apply(lambda a: int(a.split('_')[1]))
df = df.drop(columns=['slice_id', 'case_id_str', 'day_num_str'])

if TRAIN:
    # Exclude two case/day scans; match on the id prefix rather than a regex substring.
    fault1 = 'case7_day0'
    fault2 = 'case81_day30'
    is_faulty = df['id'].str.startswith(fault1 + '_') | df['id'].str.startswith(fault2 + '_')
    df = df[~is_faulty].reset_index(drop=True)

# Per-class RLE strings ('' when the class is absent).
# NOTE(review): assumes train.csv lists large_bowel, small_bowel, stomach in that
# order for every id, so list positions 0/1/2 map to lb/sb/st. TODO confirm.
df['lb'] = df['segmentation'].map(lambda a: a[0])
df['sb'] = df['segmentation'].map(lambda a: a[1])
df['st'] = df['segmentation'].map(lambda a: a[2])

# 0/1 presence flag per class
df['classes'] = df['segmentation'].map(lambda a: [int(a[i] != '') for i in range(3)])
np.random.seed(80)
df = df.sample(frac=1).reset_index(drop=True)
print('done')

In [None]:
# https://www.kaggle.com/paulorzp/rle-functions-run-length-encode-decode
def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background (flattened in C order)
    Returns "start length start length ..." with 1-based starts ('' for an empty mask).
    '''
    # Zero-pad both ends so every run has a rising and a falling edge.
    padded = np.concatenate([[0], img.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Odd slots hold run ends; turn them into run lengths.
    edges[1::2] = edges[1::2] - edges[::2]
    return ' '.join(map(str, edges))


def rle_decode(mask_rle, wid, hei):
    """Decode a run-length string into a ``(wid, hei)`` uint8 mask.

    ``mask_rle`` holds whitespace-separated ``start length`` pairs with 1-based
    starts; runs are painted with 1 on a flat buffer that is then reshaped in C
    (row-major) order. An empty string yields an all-zero mask.
    """
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1
    lengths = np.asarray(tokens[1::2], dtype=int)

    flat = np.zeros(wid * hei, dtype=np.uint8)
    for begin, end in zip(starts, starts + lengths):
        flat[begin:end] = 1
    return flat.reshape((wid, hei))


def img_read(path):
    """Read an image with ``cv2.IMREAD_ANYDEPTH`` so bit depths above 8 are kept.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread`` returns
            None instead of raising, which would otherwise only fail later in resize.
    """
    img = cv2.imread(path, cv2.IMREAD_ANYDEPTH)
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")
    return img


class Dataset2D(torch.utils.data.Dataset):
    """Slice-level dataset returning ``(image, mask)`` float16 tensors on ``device``.

    ``df_sub`` must provide ``path``, ``segmentation`` (three RLE strings per row,
    '' for an absent class), ``classes``, ``img_width`` and ``img_height``.
    With ``train=True`` the albumentations pipeline is applied in ``data_prep_aug``.

    NOTE(review): tensors are created on ``device`` inside ``__getitem__``.
    Creating CUDA tensors in DataLoader worker processes is generally unsupported,
    so keep ``num_workers=0`` (the loaders in this notebook use the default) or move
    the device transfer into the training loop.
    """

    def __init__(self, df_sub, train=True):
        self.train = train

        self.paths = np.array(df_sub['path'])
        self.rles = np.array(df_sub['segmentation'])
        self.classes = np.array(df_sub['classes'])
        self.wid = np.array(df_sub['img_width'])
        self.hei = np.array(df_sub['img_height'])

        # Augmentation pipeline; built lazily on first use and reused afterwards.
        self._augment = None

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _build_augmentation():
        """Return the albumentations pipeline applied to training samples."""
        return A.Compose([
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(
                scale_limit=0.12,
                shift_limit=0.02,
                rotate_limit=15,
                border_mode=cv2.BORDER_CONSTANT,
                # NOTE(review): the border is filled with 1.0, the maximum of the
                # max-normalised image; confirm a bright border is intended.
                value=(1, 1, 1),
                always_apply=True,
                p=1,
            ),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
        ])

    def transform(self, img, mask):
        """Augment an image/mask pair; returns the albumentations output dict."""
        if getattr(self, '_augment', None) is None:
            self._augment = self._build_augmentation()
        return self._augment(image=img, mask=mask)

    def data_prep_aug(self, img, mask, classes):
        """Resize, max-normalise, optionally augment, and convert to tensors.

        Returns ``(img, mask)`` float16 tensors on ``device``: the grey image
        replicated into 3 channels, and the 3-channel mask, both with axes reversed
        from (rows, cols, C) to (C, cols, rows).
        """
        shape = CFG['shape']
        resized = cv2.resize(img, shape, interpolation=cv2.INTER_AREA)
        peak = img.max()
        # A blank scan (max 0) would otherwise divide by zero and produce NaNs.
        img = (resized / peak if peak > 0 else resized).astype('float32')
        mask = cv2.resize(mask, shape, interpolation=cv2.INTER_AREA).astype('float32')

        if self.train:
            augmented = self.transform(img, mask)
            # reshape guards against a trailing singleton channel in the output
            img = augmented['image'].reshape(img.shape)
            mask = augmented['mask']

        # Grey -> 3 identical channels, then reverse axes to channels-first.
        img = np.repeat(img[:, :, np.newaxis], 3, axis=2).transpose(2, 1, 0)
        mask = mask.transpose(2, 1, 0)

        return (torch.tensor(img, dtype=torch.float16, device=device),
                torch.tensor(mask, dtype=torch.float16, device=device))

    def __getitem__(self, idx):
        img = img_read(self.paths[idx])

        # Decode the three class RLEs into a float (wid, hei, 3) mask.
        width, height = self.wid[idx], self.hei[idx]
        mask = np.zeros((width, height, 3))
        for channel, rle in enumerate(self.rles[idx][:3]):
            mask[:, :, channel] = rle_decode(rle, width, height)

        # data preprocessing and augmentation
        return self.data_prep_aug(img, mask, self.classes[idx])

In [None]:
def imshow(img, return_only=False, pause=False, show_axis=True):
    """Display (or just convert) an image/mask array or tensor.

    Tensors are detached and moved to a CPU NumPy array first. A 4-D input keeps
    only its first batch item. A 3-D input with 3 leading channels has its axes
    reversed (``transpose(2, 1, 0)``) so channels come last. Any other shape is
    passed through unchanged, so 2-D inputs are never transposed.

    Args:
        img: ``np.ndarray`` or ``torch.Tensor``.
        return_only: return the converted array instead of plotting it.
        pause: call ``plt.pause(1)`` after drawing.
        show_axis: hide the axes when False.
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()

    if isinstance(img, np.ndarray):
        if img.ndim == 4:
            img = img[0]
        if img.ndim == 3 and img.shape[0] == 3:
            img = img.transpose(2, 1, 0)

    if return_only:
        return img

    fig, ax = plt.subplots()
    ax.imshow(img)
    # Hide the axes before pausing so the paused frame shows the final figure.
    if not show_axis:
        ax.axis('off')
    if pause:
        plt.pause(1)

In [None]:
idx = 62


def read(idx):
    """Load row ``idx`` of ``df`` as model-ready arrays (same layout as Dataset2D).

    Returns:
        img: ``(1, 3, ...)`` float64 array, the max-normalised grey slice repeated
            into 3 channels with axes reversed to channels-first.
        mask: ``(1, 3, ...)`` float32 array of the three decoded class masks,
            resized the same way.
    """
    img = cv2.imread(df['path'][idx], cv2.IMREAD_ANYDEPTH)
    shape = CFG['shape']
    resized = cv2.resize(img, shape, interpolation=cv2.INTER_AREA)
    peak = img.max()
    # A blank scan (max 0) would otherwise divide by zero and produce NaNs.
    img = (resized / peak if peak > 0 else resized).astype('float32')

    width, height = df.img_width[idx], df.img_height[idx]
    mask = np.zeros((width, height, 3))
    for channel in range(3):
        mask[:, :, channel] = rle_decode(df.segmentation[idx][channel], width, height)
    mask = cv2.resize(mask, shape, interpolation=cv2.INTER_AREA).astype('float32').transpose(2, 1, 0)[np.newaxis]

    blank_img = np.zeros(img.shape + (3,))
    blank_img[:, :, :] = img[:, :, np.newaxis]
    img = blank_img.transpose(2, 1, 0)[np.newaxis]

    return img, mask


def predict(idx, model, to_numpy=True, log=True):
    """Binary prediction of ``model`` on row ``idx`` of ``df`` (loaded via ``read``).

    Puts ``model`` in eval mode and runs the forward pass without autograd.
    Probabilities >= 0.5 become 1, everything else 0.

    Args:
        idx: row of ``df`` to load.
        model: segmentation network returning logits.
        to_numpy: return NumPy arrays instead of tensors on ``device``.
        log: print the Dice score (1 - MONAI DiceLoss) against the ground truth.

    Returns:
        ``(img, mask, pred)`` with the shapes produced by ``read``.
    """
    img, mask = read(idx)

    model.eval()

    mask = torch.tensor(mask, device=device, dtype=torch.float)
    img = torch.tensor(img, device=device, dtype=torch.float)
    with torch.no_grad():
        probs = torch.sigmoid(model(img))
    pred = (probs >= 0.5).float()

    if log:
        dice_loss = monai.losses.DiceLoss()(pred, mask)
        print(1 - dice_loss.cpu().numpy())

    if to_numpy:
        pred = pred.cpu().numpy()
        img = img.cpu().numpy()
        mask = mask.cpu().numpy()

    return img, mask, pred

In [None]:
def FGSM_attack(model, img, mask, eps=0.007, loss=None):
    """Fast Gradient Sign Method: ``clamp(img + eps * sign(d loss / d img), 0, 1)``.

    Args:
        model: network mapping ``img`` to logits.
        img: input batch (``np.ndarray`` or tensor). It is not modified; the attack
            works on a detached copy.
        mask: target passed to ``loss`` (``np.ndarray`` or tensor).
        eps: per-pixel step size (L-infinity budget).
        loss: callable ``loss(prediction, target)``. Defaults to
            ``monai.losses.DiceFocalLoss(sigmoid=True)``, created per call.

    Returns:
        The perturbed image as a detached tensor clamped to [0, 1].
    """
    if loss is None:
        loss = monai.losses.DiceFocalLoss(sigmoid=True)

    if isinstance(img, np.ndarray):
        img = torch.tensor(img, device=device, dtype=torch.float)

    if isinstance(mask, np.ndarray):
        mask = torch.tensor(mask, device=device, dtype=torch.float)

    # Detached leaf copy: leaves the caller's tensor (and its requires_grad) untouched.
    adv_input = img.detach().clone().requires_grad_(True)
    output = model(adv_input)

    # MONAI / PyTorch losses take (prediction, target); sigmoid=True applies to the prediction.
    loss_value = loss(output, mask)

    # Gradient w.r.t. the input only, so model parameter .grad is not touched.
    (input_grad,) = torch.autograd.grad(loss_value, adv_input)

    perturbed_image = adv_input.detach() + eps * input_grad.sign()

    return torch.clamp(perturbed_image, 0, 1)

In [None]:
# Four pretrained 3-class segmentation models, each loaded in eval mode on `device`.
activation = None


def _load_segmentation_model(arch, encoder_name, checkpoint_path):
    """Build ``arch`` (an smp model class) with 3 input channels / 3 classes and load weights.

    Returns the model on ``device`` in eval mode, so BatchNorm uses its running
    statistics in every later forward pass (predict, FGSM, Grad-CAM).
    """
    net = arch(
        encoder_weights=None,
        encoder_name=encoder_name,
        decoder_use_batchnorm=True,
        activation=activation,
        in_channels=3,
        classes=3,
    ).to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only runtime too.
    net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return net.eval()


# VGG13
model_vgg = _load_segmentation_model(smp.Unet, 'vgg13', '../input/unetvgg/unet_vgg13_12.15epochs_lr3e4.pt')
print(1)

# UNet ResNeXt101
model_res = _load_segmentation_model(smp.Unet, 'resnext101_32x8d', '../input/unet-resnext101-focaldice-1215epochs-lr3e4pt/unet_resnext101_focaldice_12.15epochs_lr3e4.pt')
print(2)

# UNet EFFB7
model_eff = _load_segmentation_model(smp.Unet, 'efficientnet-b7', '../input/new-unet-effb7/unet_effb7_NEW11.15.pt')
print(3)

# OLD UNet2p EFFB7
model_2p = _load_segmentation_model(smp.UnetPlusPlus, 'efficientnet-b7', '../input/unet2p-effb7-focaldice-1315epochs-lr3e4pt/unet2p_effb7_focaldice_13.15epochs_lr3e4.pt')
print('finished')

In [None]:
# Model under study for the cells below (the UnetPlusPlus / efficientnet-b7 model loaded above).
model = model_2p

In [None]:
# Number of trainable parameters in `model`.
total_p = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(total_p)

In [None]:
img, mask, pred = predict(62, model=model_2p)

# 60/40 blend of the input slice and the binary prediction, then the prediction alone.
overlay = img * 0.6 + pred * 0.4
imshow(overlay, show_axis=False)
imshow(pred, show_axis=False)

In [None]:
# Convert the NumPy outputs of `predict` back to float32 tensors on `device`.
img, mask, pred = (torch.tensor(arr, device=device, dtype=torch.float) for arr in (img, mask, pred))

In [None]:
def gradcam(c, img, mask, model):
    """Grad-CAM heat-map for output channel ``c`` of segmentation ``model``.

    The Grad-CAM target is the sum of channel ``c``'s output over the region
    given by ``mask``. The target layer is the last child of
    ``model.decoder.blocks``.

    Args:
        c: index of the output channel (class) to explain.
        img: input batch (``np.ndarray`` or tensor) fed to ``model``.
        mask: region mask. A 4-D input keeps its first batch item; a 3-D
            (classes, ...) mask is reduced to channel ``c``, so only that class's
            own region is used.
        model: smp segmentation model exposing ``decoder.blocks``.

    Returns:
        RGB jet-colormap heat-map as a float array with values in [0, 1]. The CAM
        is transposed (``.T``) to match the axis order that ``imshow`` produces.
    """
    if isinstance(img, np.ndarray):
        img = torch.tensor(img, device=device, dtype=torch.float)

    if isinstance(mask, np.ndarray):
        mask = torch.tensor(mask, device=device, dtype=torch.float)

    if mask.dim() == 4:
        mask = mask[0]
    # A (classes, ...) mask would otherwise broadcast channel c over every class's region.
    if mask.dim() == 3:
        mask = mask[c]

    class SemanticSegmentationTarget:
        """Scalar Grad-CAM target: masked sum of one output channel."""

        def __init__(self, category, mask):
            self.category = category
            self.mask = mask

        def __call__(self, model_output):
            return (model_output[self.category, :, :] * self.mask).sum()

    targets = [SemanticSegmentationTarget(c, mask)]
    target_layers = [list(model.decoder.blocks.children())[-1]]

    # The context manager releases the forward/backward hooks GradCAM registers on the model.
    with GradCAM(model=model,
                 target_layers=target_layers,
                 use_cuda=torch.cuda.is_available()) as cam:
        grayscale_cam = cam(input_tensor=img, targets=targets)

    cam_u8 = np.uint8(grayscale_cam[0].T * 255)
    cam_3c = np.dstack([cam_u8, cam_u8, cam_u8])

    jet_heatmap = cv2.cvtColor(cv2.applyColorMap(cam_3c, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB) / 255

    return jet_heatmap

In [None]:
# Grad-CAM of model_eff for class 0 on the poisoned input, blended 50/50 with the clean slice.
# NOTE: needs `poisoned_img` / `poisoned_pred` from the FGSM cell further down; run that cell first.
# imshow(..., return_only=True) works whether `img` is currently a NumPy array or a tensor.
clean_background = imshow(img, return_only=True)
imshow(0.5 * gradcam(0, poisoned_img, poisoned_pred, model_eff) + 0.5 * clean_background, show_axis=False)

In [None]:
def show_diff_v1(idx, model):
    """Predict slice ``idx`` with ``model`` and colour-code its errors via ``show_diff``.

    Drops the batch axis of ``predict``'s outputs and reverses the remaining axes so
    channels come last, then returns ``show_diff``'s ``(fp, fn, overlay)``.
    """
    img, mask, pred = predict(idx, model=model)

    def _channels_last(arr):
        return arr[0].transpose(2, 1, 0)

    return show_diff(_channels_last(img), _channels_last(mask), _channels_last(pred))

# Below is the version 2 of show_diff
def show_diff(img, mask, pred):
    """Colour-code prediction errors on a copy of ``img``.

    Args:
        img: channels-last image with 3 channels; it is NOT modified.
        mask: channels-last ground-truth masks (same spatial size as ``img``).
        pred: channels-last binary prediction, same shape as ``mask``.

    Returns:
        ``(fp, fn, overlay)``:
          - ``fp``: 1 where ``sum_c(pred - mask) >= 0.5``; added to channel 2 (blue in RGB).
          - ``fn``: 1 where ``sum_c(mask - pred) >= 0.5``; added to channel 0 (red).
          - ``overlay``: the image copy with those flags and the per-pixel overlap
            ``sum_c(mask * pred)`` added to channel 1 (green). Each channel is
            capped at 1.
    """
    overlay = img.copy()

    # False Positive
    fp_sum = np.sum(pred - mask, axis=2)
    fp = (fp_sum >= 0.5).astype(fp_sum.dtype)
    overlay[:, :, 2] = np.minimum(overlay[:, :, 2] + fp, 1)

    # False Negative
    fn_sum = np.sum(mask - pred, axis=2)
    fn = (fn_sum >= 0.5).astype(fn_sum.dtype)
    overlay[:, :, 0] = np.minimum(overlay[:, :, 0] + fn, 1)

    # True Positive overlap
    tp = np.sum(mask * pred, axis=2)
    overlay[:, :, 1] = np.minimum(overlay[:, :, 1] + tp, 1)

    return fp, fn, overlay

In [None]:
# Clean prediction (used as reference) vs. poisoned prediction, drawn on the input slice.
# NOTE: needs `poisoned_pred` from the FGSM cell below; run that cell first.
# imshow(..., return_only=True) gives channels-last NumPy arrays for both arrays and tensors;
# np.array copies them so the notebook's `img` is not modified.
clean_hwc = np.array(imshow(img, return_only=True))
reference_hwc = np.array(imshow(pred, return_only=True))
poisoned_hwc = np.array(imshow(poisoned_pred, return_only=True))
imshow(show_diff(clean_hwc, reference_hwc, poisoned_hwc)[2], show_axis=False)

In [None]:
model = model_2p
img, mask, pred = predict(86, model=model)

imshow(img * 0.6 + pred * 0.4, show_axis=False)
imshow(pred, show_axis=False)

# Attack towards the model's own clean prediction with MONAI's focal loss.
poisoned_img = FGSM_attack(model=model, img=img, mask=pred, eps=0.009, loss=monai.losses.FocalLoss())
with torch.no_grad():
    # >= 0.5 -> 1, otherwise 0
    poisoned_pred = (torch.sigmoid(model(poisoned_img)) >= 0.5).float()

imshow(poisoned_img, show_axis=False)
imshow(poisoned_pred, show_axis=False)
# Dice of the poisoned prediction against the ground truth
print((1 - monai.losses.DiceLoss()(poisoned_pred, torch.tensor(mask, device=device)).cpu().detach().numpy()))

In [None]:
gkf = GroupKFold(n_splits=10)

# Only the first of the ten case-grouped folds is used.
fold = 0
train_ind, val_ind = next(iter(gkf.split(df, df['empty'], groups=df['case'])))

train_ds = Dataset2D(df.iloc[train_ind], train=True)
train_ds_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)

val_ds = Dataset2D(df.iloc[val_ind], train=False)
val_ds_loader = torch.utils.data.DataLoader(val_ds, batch_size=8)


In [None]:
img, mask, pred = predict(86, model=model_2p)

# Uses FGSM_attack's default loss (the previous `loss=torch.monai.` argument was incomplete).
poisoned_img = FGSM_attack(model=model_2p, img=img, mask=pred, eps=0.009)
with torch.no_grad():
    poisoned_pred = torch.sigmoid(model_2p(poisoned_img))
imshow(pred)
imshow(poisoned_img)
imshow(poisoned_pred)

# Soft Dice of the poisoned probabilities against the ground truth
print((1 - monai.losses.DiceLoss()(poisoned_pred, torch.tensor(mask, device=device)).cpu().detach().numpy()))

In [None]:
# Mean Dice of `model` on FGSM-poisoned validation slices (fold 0).
poisoned_val = 0
poisoned_val_counts = 0
model.eval()

for a in tqdm(val_ind):
    img, mask = read(a)
    # Only slices with at least one annotated class are scored.
    if mask.max() == 0:
        continue

    # FGSM_attack already clamps its output to [0, 1].
    poisoned_img = FGSM_attack(model=model, img=img, mask=mask, eps=0.009)

    with torch.no_grad():
        # >= 0.5 -> 1, otherwise 0
        pred = (torch.sigmoid(model(poisoned_img)) >= 0.5).float()
    mask = torch.tensor(mask, device=device, dtype=torch.float)

    dl = monai.losses.DiceLoss()(pred, mask)
    poisoned_val += (1 - dl.cpu().detach().numpy())
    poisoned_val_counts += 1

mean_poisoned_dice = poisoned_val / poisoned_val_counts if poisoned_val_counts else float('nan')
print(f'mean Dice on {poisoned_val_counts} poisoned slices: {mean_poisoned_dice}')
