# **Adversarial-Training-Bracs**

In [None]:
!git clone https://github.com/hila-chefer/Transformer-Explainability.git

import os
os.chdir(f'./Transformer-Explainability')

In [None]:
# Extra dependencies: einops (presumably needed by the repo's ViT code — verify)
# and torchattacks (provides the APGD/PGD attacks imported below).
!pip install einops
!pip install torchattacks

In [None]:
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2
import random
import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchattacks.attacks.apgd import APGD 
from torchvision.utils import save_image
from PIL import Image
from IPython.display import display
from torchattacks import PGD
from torchvision.utils import save_image
from baselines.ViT.ViT_LRP import deit_base_patch16_224 as vit_LRP
#from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_explanation_generator import LRP
from baselines.ViT.ViT_LRP import VisionTransformer
from baselines.ViT.helpers import load_pretrained

In [None]:
#@title BRACS class indices to names
# Maps each model output index to its BRACS class folder name
# ("<label>_<abbreviation>", e.g. 3 -> '3_ADH').
CLS2IDX = {0: '0_N',
 1: '1_PB',
 2: '2_UDH',
 3: '3_ADH',
 4: '4_FEA',
 5: '5_DCIS',
 6: '6_IC'}

In [None]:
def safe_divide(a, b):
    """Element-wise ``a / b`` that yields 0 wherever ``b`` is exactly 0.

    The denominator is ``b + 1e-9`` (written as the sum of two clamps); any
    entry that is still exactly zero gets another 1e-9. The quotient is then
    multiplied by the mask ``b != 0``.
    """
    denominator = b.clamp(min=1e-9) + b.clamp(max=1e-9)
    still_zero = denominator.eq(0).type(denominator.type())
    denominator = denominator + still_zero * 1e-9
    keep = b.ne(0).type(b.type())
    return a / denominator * keep
    
def forward_hook(self, input, output):
    """Forward hook that stores detached, grad-enabled copies of a module's input.

    Sets on the hooked module:
      * ``self.X`` -- the first positional input, detached and marked
        ``requires_grad``; if that input is a list/tuple, a list holding one
        such copy per element.
      * ``self.Y`` -- the module output, stored as-is.
    """
    first = input[0]
    if type(first) not in (list, tuple):
        self.X = first.detach()
        self.X.requires_grad = True
    else:
        self.X = []
        for item in first:
            leaf = item.detach()
            leaf.requires_grad = True
            self.X.append(leaf)

    self.Y = output

class RelProp(nn.Module):
    """Base module for layers that support relevance propagation (LRP).

    Registers ``forward_hook`` on itself, so every forward call records the
    detached input in ``self.X`` and the output in ``self.Y``.
    """

    def __init__(self):
        super().__init__()
        self.register_forward_hook(forward_hook)

    def gradprop(self, Z, X, S):
        """Return the gradients of ``Z`` w.r.t. ``X`` weighted by ``S``; the graph is retained."""
        return torch.autograd.grad(Z, X, grad_outputs=S, retain_graph=True)

    def relprop(self, R, alpha):
        """Default rule: pass relevance ``R`` through unchanged (``alpha`` is ignored)."""
        return R

class IndexSelect(RelProp):
    """``torch.index_select`` as a module that supports relevance propagation.

    ``forward`` remembers ``dim`` and ``indices`` so that ``relprop`` can replay
    the same selection on the input captured by the forward hook (``self.X``).
    """

    def forward(self, inputs, dim, indices):
        # Stored for relprop, which re-runs the selection on self.X.
        self.dim = dim
        self.indices = indices
        return torch.index_select(inputs, dim, indices)

    def relprop(self, R, alpha):
        """Redistribute relevance ``R`` of the selected output onto ``self.X``.

        Args:
            R: relevance of this module's output.
            alpha: unused; kept for the common relprop signature.

        Returns:
            ``self.X * grad`` for a tensor input, or one such product per
            element when ``self.X`` is a list/tuple.
        """
        # Use the indices saved by forward (previously an undefined global name).
        Z = self.forward(self.X, self.dim, self.indices)
        S = safe_divide(R, Z)
        C = self.gradprop(Z, self.X, S)

        if torch.is_tensor(self.X):
            return self.X * C[0]
        # One relevance tensor per captured input (works for any number of inputs).
        return [x * c for x, c in zip(self.X, C)]

In [None]:
# Download the BRACS RoI images (Version1_MedIA) from the BRACS FTP server
# (wget -c resumes partial downloads, -r mirrors the folder recursively).
!wget -c -r ftp://histoimage.na.icar.cnr.it/BRACS_RoI/previous_versions/Version1_MedIA/Images
# Resize the downloaded dataset to 224x224 and save it to Google Drive.
# NOTE(review): the resizing code is not part of this notebook; the string below is only a placeholder.
'''
code
'''

In [None]:
# Mount Google Drive at /content/drive so the dataset and checkpoints stored there are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Unzip the saved images into this account's Drive (-u: only extract new or newer files).
# The images were prepared under a different Google account because of Colab's usage
# limits (sharing the notebook between accounts did not work either).
# NOTE(review): this extracts into "/content/drive/My Drive/Bracs", while the loaders read
# "/content/drive/MyDrive/Bracs/Images/..." — presumably the same folder via Colab's
# "My Drive" alias; confirm.
!unzip -u "/content/drive/MyDrive/Images-20221110T104340Z-001.zip" -d "/content/drive/My Drive/Bracs"

In [None]:
# Resize to exactly 224x224 and convert to a [0, 1] float tensor.
# An int size would only fix the shorter side. Non-square images would then
# produce differently sized tensors, which DataLoader cannot stack and which do
# not fit the 14x14 attribution reshape used for the heatmaps.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [None]:
def prepareDatasets(direction, share=1):
    """Load a class-per-folder image split into a list of samples.

    Every sub-folder of ``direction`` is one class. Its integer label is the
    part of the folder name before the first '_' (e.g. '3_ADH' -> 3). Each
    readable image is passed through the module-level ``transform``.

    Args:
        direction: path of the split folder (e.g. ``.../Images/Train``).
        share: fraction of each class folder to load. When ``share != 1``, a
            random subset of ``int(share * n)`` distinct files is taken per
            folder (clamped to ``[0, n]``).

    Returns:
        list of ``{"image": Tensor, "label": Tensor}`` dicts.
    """
    dataset = []

    for folder in sorted(os.listdir(direction)):
        folder_path = os.path.join(direction, folder)
        if not os.path.isdir(folder_path):
            continue  # ignore stray files next to the class folders
        print(folder)

        label_id = int(folder.split('_')[0])
        file_names = sorted(os.listdir(folder_path))

        if share != 1:
            # Sample distinct files so exactly int(share * n) images are kept.
            k = min(max(int(share * len(file_names)), 0), len(file_names))
            file_names = random.sample(file_names, k)

        for file_name in file_names:
            path = os.path.join(folder_path, file_name)
            try:
                # Opening inside the try also catches unreadable / non-image files.
                # The context manager closes the file once the tensor is built.
                with Image.open(path) as im:
                    img = transform(im)
            except OSError:
                print(f"Skipping unreadable image: {path}")
                continue
            dataset.append({"image": img, "label": torch.tensor(label_id)})

    return dataset

# Split directories; each holds one sub-folder per class (folder names are printed while loading).
train_dir = "/content/drive/MyDrive/Bracs/Images/Train"
val_dir = "/content/drive/MyDrive/Bracs/Images/Validation"

# Load the full training split and print the number of samples (as a shape tuple).
train_dataset = prepareDatasets(train_dir)
print(np.shape(train_dataset))
# To subsample, pass a fraction: prepareDatasets(train_dir, share=0.3) loads roughly 30% of each class folder.

validation_dataset = prepareDatasets(val_dir)
print(np.shape(validation_dataset))


1_PB
2_UDH
5_DCIS
3_ADH
0_N
6_IC
4_FEA
(3163,)
4_FEA
0_N
3_ADH
5_DCIS
1_PB
2_UDH
6_IC
(602,)


In [None]:
# Batches of 20 samples; training order is reshuffled every epoch, validation order is fixed.
data = DataLoader(train_dataset, batch_size=20, shuffle=True, num_workers=2)
data_val = DataLoader(validation_dataset, batch_size=20, shuffle=False, num_workers=2)

In [None]:
model = vit_LRP(pretrained=True).cuda()

# Turn the pretrained ImageNet head into a 7-way BRACS head.
# Assigning ``out_features`` alone only changes an attribute: weight and bias
# keep their ImageNet rows, so the model would still emit all of those logits,
# and the attack and the LRP heatmaps would work over them too. Keep the first
# 7 rows as real parameters; these are the logits the training loop already
# reads with ``p[:, :7]``.
# NOTE: state dicts saved before this change contain the full-size head and
# will not load into this model without the same slicing.
model.head.weight = nn.Parameter(model.head.weight.detach()[:7].clone())
if model.head.bias is not None:
    model.head.bias = nn.Parameter(model.head.bias.detach()[:7].clone())
model.head.out_features = 7
model.num_classes = 7
# Use this notebook's IndexSelect for the pooling step (presumably invoked by the
# model's relprop when generating LRP explanations).
model.pool = IndexSelect()

In [None]:
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 7, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


# Per-architecture configs (checkpoint URL plus input normalization), keyed by model name.
default_cfgs = dict(
    # patch models
    vit_small_patch16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
    ),
    vit_base_patch16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    vit_large_patch16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
)

def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict

def vit_base_patch16_224_7class(pretrained=False, **kwargs):
    """Build a ViT-Base/16 (224px) VisionTransformer with a 7-class head.

    With ``pretrained=True`` the 'vit_base_patch16_224' checkpoint is loaded via
    ``load_pretrained``, with ``_conv_filter`` reshaping the patch embedding.

    NOTE(review): ``_cfg`` reports num_classes=7, while the checkpoint presumably
    has a 1000-way ImageNet head. If ``load_pretrained`` only drops the classifier
    when the requested num_classes differs from the cfg's, the load will fail on
    a head size mismatch. Confirm before using pretrained=True.
    """
    net = VisionTransformer(
        patch_size=16, num_classes=7, embed_dim=768, depth=12,
        num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
    net.default_cfg = default_cfgs['vit_base_patch16_224']
    if not pretrained:
        return net
    load_pretrained(
        net,
        num_classes=net.num_classes,
        in_chans=kwargs.get('in_chans', 3),
        filter_fn=_conv_filter,
    )
    return net

def vit_large_patch16_224_7class(pretrained=False, **kwargs):
    """Build a ViT-Large/16 (224px) VisionTransformer with a 7-class head.

    With ``pretrained=True`` the 'vit_large_patch16_224' checkpoint is loaded via
    ``load_pretrained`` (no weight filter).

    NOTE(review): the same head-size caveat as for the base model applies: the cfg
    says 7 classes, while the checkpoint presumably holds a 1000-way head. Confirm
    ``load_pretrained`` drops the classifier before relying on pretrained=True.
    """
    net = VisionTransformer(
        patch_size=16, num_classes=7, embed_dim=1024, depth=24,
        num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs)
    net.default_cfg = default_cfgs['vit_large_patch16_224']
    if not pretrained:
        return net
    load_pretrained(net, num_classes=net.num_classes, in_chans=kwargs.get('in_chans', 3))
    return net

def deit_base_patch16_224_7class(pretrained=False, **kwargs):
    """Build a DeiT-Base/16 (224px) VisionTransformer with a 7-class head.

    With ``pretrained=True`` the official DeiT-Base checkpoint is loaded into
    every layer except the classifier. The checkpoint's ``head`` has a different
    number of outputs, so a strict load would fail with a size mismatch.
    ``head.*`` tensors are therefore dropped, and the 7-way head keeps its fresh
    initialization.

    Raises:
        RuntimeError: if the checkpoint does not match the model apart from the head.
    """
    model = VisionTransformer(
        patch_size=16, num_classes=7, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        state_dict = {k: v for k, v in checkpoint["model"].items() if not k.startswith("head.")}
        result = model.load_state_dict(state_dict, strict=False)
        # Keep the original strictness for everything except the replaced head.
        unexpected_missing = [k for k in result.missing_keys if not k.startswith("head.")]
        if unexpected_missing or result.unexpected_keys:
            raise RuntimeError(
                f"DeiT checkpoint mismatch: missing={unexpected_missing}, "
                f"unexpected={result.unexpected_keys}")
    return model

In [None]:
#model = deit_base_patch16_224_7class(pretrained=False).cuda()
#model = vit_base_patch16_224_7class(pretrained=False).cuda()
#model = torch.load('/content/drive/MyDrive/Models_saves/model5.pth')

In [None]:
# Display the model architecture as the cell output.
model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# APGD attack: 10 steps inside an L2 ball of radius 6/255.
# NOTE(review): 6/255 looks very small for an L2 ball over a 3x224x224 image
# (it reads like an Linf budget); confirm the intended norm/eps.
attack = APGD(model, norm='L2', eps=6 / 255, steps=10)

print(attack.eps)

In [None]:
# Optional diagnostic, not needed for training: for every training batch, save a PNG grid
# with rows of clean images, adversarial images and their difference (the noise).
with tqdm.tqdm(enumerate(data, start=1), total=len(data)) as pbar:
    for cnt, batch in pbar:
        image = batch["image"].to(device)
        label = batch["label"].to(device)

        ad_images = attack(image, label)

        result = torch.cat((image, ad_images, ad_images - image))
        save_image(result, fp=f'/content/out{cnt}.png', nrow=len(label), scale_each=True, normalize=True)


In [None]:
# Loss and optimizer used by train_val: cross-entropy (mean over the batch), SGD with momentum.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=4e-3, momentum=0.8)

In [None]:
def train_val(mode, adv, losses, epoch):
    """Run one pass over a data split, optionally on adversarial inputs.

    Args:
        mode: 'train' iterates the training loader ``data`` and updates the
            model with ``optimizer``. 'val' iterates ``data_val`` without weight
            updates and without autograd for the forward pass.
        adv: if True, each batch is replaced by ``attack(image, label)``. In
            'train' mode this starts at epoch 2; epochs 0-1 use clean images.
        losses: list to which this pass's summed per-sample loss is appended.
        epoch: epoch number, used for the clean-epoch rule, the checkpoint
            file name and the log line.

    Returns:
        ``losses`` (the same list, with one value appended).

    Side effects: switches the global ``model`` between eval/train mode, prints
    overall and per-class accuracy, and in 'train' mode saves a state dict to Drive.
    """
    num_classes = 7  # BRACS classes; logits are sliced to this width
    training = mode == 'train'
    # Evaluate on the validation loader, not on the training data.
    loader = data if training else data_val
    use_adv = adv and (mode == 'val' or (training and epoch > 1))

    loss_sum = 0.0
    correct = 0
    seen = 0
    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    with tqdm.tqdm(loader, total=len(loader)) as pbar:
        for batch in pbar:
            image = batch["image"].to(device)
            label = batch["label"].to(device)

            # Adversarial examples are crafted against the model in eval mode.
            model.eval()
            if use_adv:
                image = attack(image, label)
            if training:
                model.train()

            # Only the training pass needs an autograd graph through the weights.
            with torch.set_grad_enabled(training):
                logits = model(image)[:, :num_classes]
                loss = criterion(logits, label)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # criterion averages over the batch; weight by the real batch size
            # so a smaller final batch is not over-counted.
            batch_size = label.size(0)
            loss_sum += loss.item() * batch_size

            predictions = logits.argmax(-1)
            hits = predictions == label
            correct += int(hits.sum())
            seen += batch_size
            for cls in range(num_classes):
                in_cls = label == cls
                class_total[cls] += int(in_cls.sum())
                class_correct[cls] += int((hits & in_cls).sum())

            pbar.set_description(f'Acc: {correct * 100. / seen:.2f}%')

        for cls in range(num_classes):
            if class_total[cls] > 0:
                print(f'Acc class{cls}: {class_correct[cls] * 100. / class_total[cls]:.2f}%')

    losses.append(loss_sum)
    if training:
        torch.save(model.state_dict(), f'/content/drive/MyDrive/Models_saves/model_state_dict{epoch}.pth')

    state = 'train' if training else ('adversarial validation' if adv else 'validation')
    print(f'{state} : Epoch {epoch} \n')

    return losses

In [None]:
num_epochs = 15
train_losses, val_losses, adv_val_losses = [], [], []

for epoch in range(num_epochs):
    # One training epoch, then clean and adversarial validation of the resulting weights.
    train_losses = train_val(mode='train', adv=True, losses=train_losses, epoch=epoch)
    val_losses = train_val(mode='val', adv=False, losses=val_losses, epoch=epoch)
    adv_val_losses = train_val(mode='val', adv=True, losses=adv_val_losses, epoch=epoch)

In [None]:
# Plot the per-epoch summed losses recorded by train_val.
# Each x-axis is derived from its history's length instead of a hard-coded 15, so
# the plot still works for a different num_epochs or an interrupted run.
epochs = np.arange(len(train_losses))

fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(20, 30))
panels = (
    (ax1, "Training loss", train_losses),
    (ax2, "Normal Validation loss", val_losses),
    (ax3, "Adversarial Validation loss", adv_val_losses),
)
for ax, title, history in panels:
    ax.set_title(title)
    ax.plot(np.arange(len(history)), history)
    ax.set_xlabel("epoch")

In [None]:
#model.load_state_dict(torch.load('/content/drive/MyDrive/Models_saves/model_state_dict2.pth'))
#torch.save(model.state_dict(), '/content/drive/MyDrive/Models_saves/model_state_dict2.pth')

In [None]:
def show_cam_on_image(img, mask):
    """Blend a JET-colormapped ``mask`` onto ``img`` and rescale so the maximum is 1.

    ``mask`` is scaled by 255 and cast to uint8 before colormapping, so it is
    presumably expected in [0, 1]. The colormapped heatmap (divided by 255) is
    added to ``img`` as float32, then the sum is divided by its maximum.
    """
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    return overlay / np.max(overlay)

# LRP explanation generator wrapping the model; generate_visualization calls its generate_LRP.
attribution_generator = LRP(model)

def _normalize_01(arr):
    """Scale ``arr`` linearly to [0, 1]; a constant array maps to all zeros instead of NaN."""
    lo, hi = arr.min(), arr.max()
    span = hi - lo
    if span == 0:
        return np.zeros_like(arr)
    return (arr - lo) / span


def generate_visualization(original_image, class_index=None):
    """Overlay the transformer-attribution LRP map of one image on that image.

    Args:
        original_image: CHW image tensor. The fixed 14x14 -> 224x224 reshapes
            below assume a 3x224x224 input (16x16 patches).
        class_index: class to explain; None is passed through to
            ``generate_LRP``, which presumably explains the predicted class.

    Returns:
        uint8 HxWx3 array of the blended heatmap (channel order swapped with
        RGB2BGR, as in the original Transformer-Explainability demo).
    """
    attribution = attribution_generator.generate_LRP(
        original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index
    ).detach()
    # One relevance value per 16x16 patch -> upsample to pixel resolution.
    attribution = attribution.reshape(1, 1, 14, 14)
    attribution = torch.nn.functional.interpolate(attribution, scale_factor=16, mode='bilinear')
    attribution = _normalize_01(attribution.reshape(224, 224).cpu().numpy())

    image_np = _normalize_01(original_image.permute(1, 2, 0).data.cpu().numpy())

    vis = show_cam_on_image(image_np, attribution)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis

def print_top_classes(predictions, **kwargs):
    """Print the top-5 classes (index, name, logit, softmax %) for the first row of ``predictions``.

    Only the first ``len(CLS2IDX)`` logits are considered. If the model still
    emits more logits (e.g. an un-resized ImageNet head), this prevents indices
    that have no entry in CLS2IDX. The softmax is taken over those same logits.

    Args:
        predictions: 2-D logits tensor (batch, classes); only row 0 is printed.
        **kwargs: accepted for compatibility and ignored.
    """
    logits = predictions[:, :len(CLS2IDX)]
    prob = torch.softmax(logits, dim=1)
    k = min(5, logits.size(1))
    class_indices = logits.data.topk(k, dim=1)[1][0].tolist()
    # Pad names to a common width so the value columns line up.
    max_str_len = max(len(CLS2IDX[cls_idx]) for cls_idx in class_indices)

    print('Top 5 classes:')
    for cls_idx in class_indices:
        name = CLS2IDX[cls_idx]
        output_string = '\t{} : {}'.format(cls_idx, name)
        output_string += ' ' * (max_str_len - len(name)) + '\t\t'
        output_string += 'value = {:.3f}\t prob = {:.1f}%'.format(logits[0, cls_idx], 100 * prob[0, cls_idx])
        print(output_string)

In [None]:
def drawHeatmap(direction, index):
    """Show every image in ``direction`` next to two LRP heatmaps.

    For each file, in sorted name order: left = the original image, middle =
    ``generate_visualization`` with no class index, right = the heatmap for
    class ``index`` (the folder's ground-truth class in the calls below).

    Args:
        direction: folder containing only image files.
        index: class index to explain in the right-hand panel.
    """
    # Explanations should reflect inference behaviour (e.g. no dropout).
    model.eval()

    for file_name in sorted(os.listdir(direction)):
        # Copy the pixels into memory and close the file right away; a validation
        # folder holds many images.
        with Image.open(os.path.join(direction, file_name)) as pil_image:
            original = pil_image.copy()
            transformed_image = transform(pil_image)

        # generate_LRP runs its own forward pass, so no separate model call is needed.
        predicted = generate_visualization(transformed_image)
        expected = generate_visualization(transformed_image, class_index=index)

        fig, axs = plt.subplots(1, 3)
        for ax, picture in zip(axs, (original, predicted, expected)):
            ax.imshow(picture)
            ax.axis('off')
        # Display this figure now (in a notebook this also closes it) rather than
        # keeping one open figure per image until the cell finishes.
        plt.show()

In [None]:
# Per-class validation folders; folder names match the CLS2IDX values ("<label>_<abbreviation>").
val_dir0 = "/content/drive/MyDrive/Bracs/Images/Validation/0_N"
val_dir1 = "/content/drive/MyDrive/Bracs/Images/Validation/1_PB"
val_dir2 = "/content/drive/MyDrive/Bracs/Images/Validation/2_UDH"
val_dir3 = "/content/drive/MyDrive/Bracs/Images/Validation/3_ADH"
val_dir4 = "/content/drive/MyDrive/Bracs/Images/Validation/4_FEA"
val_dir5 = "/content/drive/MyDrive/Bracs/Images/Validation/5_DCIS"
val_dir6 = "/content/drive/MyDrive/Bracs/Images/Validation/6_IC"

In [None]:
# Draw heatmaps for every validation image of each class. The second argument is
# that folder's ground-truth class index, used for the right-hand heatmap.
drawHeatmap(val_dir0, 0)

In [None]:
drawHeatmap(val_dir1, 1)

In [None]:
drawHeatmap(val_dir2, 2)

In [None]:
drawHeatmap(val_dir3, 3)

In [None]:
drawHeatmap(val_dir4, 4)

In [None]:
drawHeatmap(val_dir5, 5)

In [None]:
drawHeatmap(val_dir6, 6)