In [None]:
import glob
import os
import random

import numpy as np
import timm
import torch
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from torch import nn
from torchvision import transforms
from tqdm import tqdm

In [None]:
%pip install captum

In [None]:
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz

### Loading models

In [None]:
class GenericBackbone(nn.Module):
    """Thin wrapper around a ``timm`` model, optionally restored from a checkpoint.

    Args:
        backbone_name: model name passed to ``timm.create_model``.
        pretrained: whether to request timm's pretrained weights. Ignored when a
            checkpoint path is given, because the checkpoint replaces those weights.
        model_path: path of a state dict for this wrapper (its keys must therefore be
            prefixed with ``backbone.``). ``""`` or ``None`` means "no checkpoint".
        num_classes: forwarded unchanged to ``timm.create_model``.
    """

    def __init__(self, backbone_name, pretrained, model_path="", num_classes=-1):
        super().__init__()
        # bool() treats both "" and None as "no checkpoint" (len(None) would raise).
        has_checkpoint = bool(model_path)
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained and not has_checkpoint,
            num_classes=num_classes,
        )

        if has_checkpoint:
            print(f"loading checkpoint from path {model_path}")
            # map_location="cpu": the notebook runs the models on CPU, and a checkpoint
            # saved from GPU would otherwise fail to load on a CPU-only machine.
            state_dict = torch.load(model_path, map_location="cpu")
            self.load_state_dict(state_dict)

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output unchanged."""
        return self.backbone(x)

In [None]:
# timm-pretrained mobilevit_xxs, used as the "before training" reference below.
imagenet_model = GenericBackbone("mobilevit_xxs", pretrained=True).eval()

### Visualisation of the WSL model
`mobilevit_xxs` trained on the k0 split with weakly supervised learning (WSL).

In [None]:
# Path of the WSL checkpoint to visualise. It can be found in the Parameters of the
# mlflow run, under model.model.model_path. Set the WSL_MODEL_PATH environment
# variable to point at it without editing this cell.
# A model can be trained with: "python train.py --config-name=wsl -m +experiment=wsl/mobilevit_s"
# NOTE(review): that command mentions mobilevit_s, but this cell builds mobilevit_xxs;
# the checkpoint architecture must match, so confirm the experiment config name.
# In our case we used the run "allvideo_wsl_mobilevit_xxs_k0randomcrop_blur_colorjitter_adamw" in test_k0.
wsl_path = os.environ.get("WSL_MODEL_PATH", "mlruns/xp/run_name/checkpoints/backbone_path.pth")
wsl_model = GenericBackbone("mobilevit_xxs", False, model_path=wsl_path).eval()

In [None]:
class ModelWrapper(nn.Module):
    """Reduce a model's output to one scalar per sample so it can be attributed.

    The wrapped model's output is flattened from dimension 1 onward and averaged,
    producing a tensor of shape ``(batch,)``. This lets captum attribute it without
    a ``target`` argument.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_data):
        per_sample_features = self.model(input_data).flatten(start_dim=1)
        return per_sample_features.mean(dim=1)

# One Integrated Gradients explainer per backbone, each attributing the scalar
# produced by ModelWrapper. The `feature_ablation_*` names are kept because later
# cells use them, even though the method is Integrated Gradients, not feature ablation.
model_wrapper_wsl = ModelWrapper(wsl_model).eval()
feature_ablation_wsl = IntegratedGradients(model_wrapper_wsl)

model_wrapper_imagenet = ModelWrapper(imagenet_model).eval()
feature_ablation_imagenet = IntegratedGradients(model_wrapper_imagenet)

In [None]:
# Resize straight to 224x224 (no centre crop), then convert the PIL image to a tensor.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# ImageNet mean/std normalisation. It is applied separately from `transform` so the
# un-normalised tensor can still be displayed next to its attribution map.
transform_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])

def get_attribution(path, feature_ablation, n_steps=200):
    """Compute an attribution map for a single image file.

    The image is resized with the module-level ``transform``, normalised with
    ``transform_normalize`` and passed as a batch of one to
    ``feature_ablation.attribute``.

    Args:
        path: path of the image to explain.
        feature_ablation: captum attribution object (an ``IntegratedGradients``
            instance in this notebook).
        n_steps: number of integration steps forwarded to ``attribute``.

    Returns:
        ``(attributions, transformed_img)``: the attributions for the batched input,
        and the resized but un-normalised image tensor (used for display).
    """
    # Force 3 channels: grayscale/palette/RGBA/CMYK files would otherwise not match
    # the 3-channel normalisation. The `with` block closes the underlying file.
    with Image.open(path) as img:
        rgb_img = img.convert("RGB")

    transformed_img = transform(rgb_img)
    model_input = transform_normalize(transformed_img).unsqueeze(0)

    # No target: the ModelWrapper-based explainers here output one scalar per sample.
    attributions = feature_ablation.attribute(model_input, n_steps=n_steps)
    return attributions, transformed_img

# Despite its 'custom blue' name, this colormap goes from white at 0 to black at 0.25
# and stays black up to 1. The visu_attr* helpers use it when called with cmap=True.
_default_cmap_stops = [
    (0, '#ffffff'),
    (0.25, '#000000'),
    (1, '#000000'),
]
default_cmap = LinearSegmentedColormap.from_list('custom blue', _default_cmap_stops, N=224)

def visu_attr(path, feature_ablation, n_steps=20, cmap=False, use_pyplot=True):
    """Plot an image next to the positive part of its attribution map.

    Args:
        path: image file to explain.
        feature_ablation: captum attribution object passed to ``get_attribution``.
        n_steps: integration steps forwarded to ``get_attribution``.
        cmap: if truthy, draw the heat map with ``default_cmap``; otherwise use
            captum's default colormap.
        use_pyplot: forwarded to captum. With ``False``, captum builds a standalone
            figure without showing it, which is convenient for ``savefig``.

    Returns:
        The ``(figure, axes)`` pair returned by ``viz.visualize_image_attr_multiple``.
    """
    attributions, transformed_img = get_attribution(path, feature_ablation, n_steps)

    # captum expects channel-last (H, W, C) numpy arrays.
    attr_hwc = attributions.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
    image_hwc = transformed_img.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
    heatmap_cmap = default_cmap if cmap else None

    return viz.visualize_image_attr_multiple(
        attr_hwc,
        image_hwc,
        methods=["original_image", "heat_map"],
        signs=["all", "positive"],
        cmap=heatmap_cmap,
        show_colorbar=True,
        use_pyplot=use_pyplot,
    )
def get_attribution_mirror(path, feature_ablation, n_steps=200):
    """Compute an attribution map for the horizontally mirrored version of an image.

    Identical to ``get_attribution`` except that the image is flipped left-right
    before the resize/normalisation pipeline.

    Args:
        path: path of the image to explain.
        feature_ablation: captum attribution object (an ``IntegratedGradients``
            instance in this notebook).
        n_steps: number of integration steps forwarded to ``attribute``.

    Returns:
        ``(attributions, transformed_img)`` for the mirrored image; ``transformed_img``
        is resized but not normalised.
    """
    # Force 3 channels (matches the 3-channel normalisation); `with` closes the file.
    with Image.open(path) as img:
        mirrored_img = img.convert("RGB").transpose(Image.FLIP_LEFT_RIGHT)

    transformed_img = transform(mirrored_img)
    model_input = transform_normalize(transformed_img).unsqueeze(0)

    # No target: the ModelWrapper-based explainers here output one scalar per sample.
    attributions = feature_ablation.attribute(model_input, n_steps=n_steps)
    return attributions, transformed_img


def visu_attr_mirror(path, feature_ablation, n_steps=20, cmap=False, use_pyplot=True):
    """Plot the mirrored image next to the positive part of its attribution map.

    Mirrors ``visu_attr``: same arguments and same return value, but it runs
    ``get_attribution_mirror``, so the figure can also be saved.

    Args:
        path: image file to explain (flipped left-right before attribution).
        feature_ablation: captum attribution object.
        n_steps: integration steps forwarded to ``get_attribution_mirror``.
        cmap: if truthy, draw the heat map with ``default_cmap``.
        use_pyplot: forwarded to captum; ``False`` returns a figure without showing it.

    Returns:
        The ``(figure, axes)`` pair returned by ``viz.visualize_image_attr_multiple``.
    """
    attributions, transformed_img = get_attribution_mirror(path, feature_ablation, n_steps)

    # captum expects channel-last (H, W, C) numpy arrays.
    return viz.visualize_image_attr_multiple(
        np.transpose(attributions.squeeze().cpu().detach().numpy(), (1, 2, 0)),
        np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1, 2, 0)),
        ["original_image", "heat_map"],
        ["all", "positive"],
        cmap=default_cmap if cmap else None,
        show_colorbar=True,
        use_pyplot=use_pyplot,
    )
    

In [None]:
def visu_attr_wsl(path, n_steps=20, cmap=False, use_pyplot=True):
    """Run ``visu_attr`` with the Integrated Gradients explainer of the WSL model."""
    return visu_attr(
        path,
        feature_ablation=feature_ablation_wsl,
        n_steps=n_steps,
        cmap=cmap,
        use_pyplot=use_pyplot,
    )

def visu_attr_imagenet(path, n_steps=20, cmap=False, use_pyplot=True):
    """Run ``visu_attr`` with the Integrated Gradients explainer of the ImageNet model."""
    return visu_attr(
        path,
        feature_ablation=feature_ablation_imagenet,
        n_steps=n_steps,
        cmap=cmap,
        use_pyplot=use_pyplot,
    )

## MIDV-Holo fake documents
Samples with glare.

In [None]:
print("before training")
# Fraudulent MIDV-Holo frames explained with the ImageNet-pretrained model.
# A loop (instead of a bare last call) avoids echoing the (figure, axes) tuple.
fraud_samples_imagenet = [
    "../../data/midv-holo/crop_ovds/fraud/photo_holo_copy/ID/id06_05_01/img_0042.jpg",
    "../../data/midv-holo/crop_ovds/fraud/photo_holo_copy/ID/id06_05_01/img_0021.jpg",
    "../../data/midv-holo/crop_ovds/fraud/copy_without_holo/ID/id06_05_01/img_0021.jpg",
]
for sample_path in fraud_samples_imagenet:
    visu_attr_imagenet(sample_path, 20)

In [None]:
print("after training")
# Same fraudulent frames, explained with the WSL-trained model.
# A loop (instead of a bare last call) avoids echoing the (figure, axes) tuple.
fraud_samples_wsl = [
    "../../data/midv-holo/crop_ovds/fraud/photo_holo_copy/ID/id06_05_01/img_0042.jpg",
    "../../data/midv-holo/crop_ovds/fraud/photo_holo_copy/ID/id06_05_01/img_0021.jpg",
    "../../data/midv-holo/crop_ovds/fraud/copy_without_holo/ID/id06_05_01/img_0021.jpg",
]
for sample_path in fraud_samples_wsl:
    visu_attr_wsl(sample_path, 20)

In [None]:
print("wsl")
# Genuine (origins) MIDV-Holo frames explained with the WSL-trained model.
# A loop (instead of a bare last call) avoids echoing the (figure, axes) tuple.
origin_samples_wsl = [
    "../../data/midv-holo/crop_ovds/origins/ID/id10_03_02/img_0038.jpg",
    "../../data/midv-holo/crop_ovds/origins/ID/id10_03_02/img_0033.jpg",
    "../../data/midv-holo/crop_ovds/origins/ID/id10_03_03/img_0037.jpg",
    "../../data/midv-holo/crop_ovds/origins/ID/id10_03_03/img_0011.jpg",
]
for sample_path in origin_samples_wsl:
    visu_attr_wsl(sample_path, 20)

In [None]:
to_export = [
    "../../data/midv-holo/crop_ovds/origins/ID/id10_03_02/img_0038.jpg",
    "../../data/midv-holo/crop_ovds/origins/ID/id10_03_02/img_0033.jpg",
    "../../data/midv-holo/crop_ovds/origins/ID/id10_03_03/img_0037.jpg",
    "../../data/midv-holo/crop_ovds/origins/ID/id10_03_03/img_0011.jpg",
    "../../data/midv-holo/crop_ovds/fraud/photo_holo_copy/ID/id06_05_01/img_0042.jpg",
    "../../data/midv-holo/crop_ovds/fraud/photo_holo_copy/ID/id06_05_01/img_0021.jpg",
    "../../data/midv-holo/crop_ovds/fraud/copy_without_holo/ID/id06_05_01/img_0021.jpg",
]
export_crop_root = "../../data/midv-holo/crop_ovds"
wsl_figure_dir = "samples/train/figure/wsl"
imagenet_figure_dir = "samples/train/figure/imagenet"
os.makedirs(wsl_figure_dir, exist_ok=True)
os.makedirs(imagenet_figure_dir, exist_ok=True)

for p in to_export:
    # The basename alone is not unique (two different img_0021.jpg above), so keep
    # the split/video part of the path in the output file name.
    unique_name = os.path.relpath(p, export_crop_root).replace(os.sep, "_")
    fig, _ = visu_attr_wsl(p, 20, use_pyplot=False)
    fig.savefig(os.path.join(wsl_figure_dir, f"wsl_{unique_name}"))
    fig, _ = visu_attr_imagenet(p, 20, use_pyplot=False)
    fig.savefig(os.path.join(imagenet_figure_dir, f"imagenet_{unique_name}"))
    

## Random selection of origin samples from the train and test sets
Some exported images are available in `notebooks/visualisation/samples/train` and `notebooks/visualisation/samples/test`.

In [None]:
# Video lists of split k0, one entry per line.
# NOTE(review): the train list comes from the fraud/copy_without_holo split, while the
# test list comes from origins. Frames are later read from the origins folder, which
# assumes both splits share the same video ids; confirm this is intended.
with open("../../data/splits_kfold_s0/k0/fraud/copy_without_holo/trainval/train_train.txt") as split_file:
    train_video = split_file.read().splitlines()

with open("../../data/splits_kfold_s0/k0/origins/test.txt") as split_file:
    test_video = split_file.read().splitlines()

In [None]:
# One random frame from each of 10 randomly chosen training videos.
origins_root = "../../data/midv-holo/crop_ovds/origins"
train_rng = random.Random(0)  # fixed seed so the selection is reproducible
frames = []
for video in train_rng.sample(train_video, 10):
    video_dir = os.path.join(origins_root, os.path.dirname(video))
    # glob already returns paths that start with video_dir; joining them onto
    # video_dir again would produce a non-existent path. Sort because glob order
    # is arbitrary, which would defeat the seed.
    candidates = sorted(glob.glob(os.path.join(video_dir, "*.jpg")))
    if not candidates:
        raise FileNotFoundError(f"no .jpg frames found in {video_dir}")
    frames.append(train_rng.choice(candidates))
        

In [None]:
# Iterate over the sampled frames directly instead of repeating the sample size.
for frame_path in frames:
    visu_attr_wsl(frame_path, 20)

In [None]:
for i in range(10):
    f, _ = visu_attr_wsl(frames[i], 20, use_pyplot=False)
    f.savefig(f"samples/train/wsl_{i}.jpg")

## In the test set

In [None]:
# One random frame from each of 10 randomly chosen *test* videos.
test_origins_root = "../../data/midv-holo/crop_ovds/origins"
test_rng = random.Random(0)  # fixed seed so the selection is reproducible
frames_test = []
for video in test_rng.sample(test_video, 10):
    video_dir = os.path.join(test_origins_root, os.path.dirname(video))
    # glob already returns paths that start with video_dir; do not join them again.
    candidates = sorted(glob.glob(os.path.join(video_dir, "*.jpg")))
    if not candidates:
        raise FileNotFoundError(f"no .jpg frames found in {video_dir}")
    frames_test.append(test_rng.choice(candidates))

In [None]:
# Iterate over the sampled frames directly instead of repeating the sample size.
for frame_path in frames_test:
    visu_attr_wsl(frame_path, 20)

In [None]:
test_sample_dir = "samples/test"
os.makedirs(test_sample_dir, exist_ok=True)  # savefig does not create directories
for i, frame_path in enumerate(frames_test):
    fig, _ = visu_attr_wsl(frame_path, 20, use_pyplot=False)
    fig.savefig(os.path.join(test_sample_dir, f"wsl_{i}.jpg"))

### MIDV-2020

In [None]:
# First frame of three MIDV-2020 clips, explained with the WSL-trained model.
# A loop (instead of a bare last call) avoids echoing the (figure, axes) tuple.
midv2020_samples = [
    "../../data/midv-2020/clips/crop_ovd/alb_id/01/000001.jpg",
    "../../data/midv-2020/clips/crop_ovd/aze_passport/01/000001.jpg",
    "../../data/midv-2020/clips/crop_ovd/fin_id/01/000001.jpg",
]
for sample_path in midv2020_samples:
    visu_attr_wsl(sample_path, 20)

## Video of attribution maps

In [None]:
# Output folders for the per-frame attribution figures; safe to re-run.
os.makedirs("activation_map", exist_ok=True)
os.makedirs("activation_map_fake", exist_ok=True)

In [None]:
# Export one attribution figure per frame of a genuine (origins) video, in frame
# order, so the frames can be assembled into a video.
n_steps = 20  # Integrated Gradients steps; the next cell reuses this value
base_path = "../../data/midv-holo/crop_ovds/origins/ID/id03_04_02/"
out_dir = "activation_map"

# Sort numerically on N in "img_<N>.jpg" (glob order is arbitrary).
paths = sorted(
    glob.glob(os.path.join(base_path, "*.jpg")),
    key=lambda p: int(os.path.splitext(os.path.basename(p))[0][len("img_"):]),
)

os.makedirs(out_dir, exist_ok=True)
for i, path in enumerate(tqdm(paths)):
    # visu_attr with cmap=False/use_pyplot=False renders the same two panels
    # (original image + positive heat map) as a standalone figure.
    fig, _ = visu_attr(path, feature_ablation_wsl, n_steps, use_pyplot=False)
    fig.savefig(os.path.join(out_dir, f"{i}.jpg"))

In [None]:
# Same export for a fraudulent (photo of hologram copy) video.
# n_steps comes from the previous cell.
base_path = "../../data/midv-holo/crop_ovds/fraud/photo_holo_copy/ID/id03_04_01/"
out_dir = "activation_map_fake"

# Sort numerically on N in "img_<N>.jpg" (glob order is arbitrary).
paths = sorted(
    glob.glob(os.path.join(base_path, "*.jpg")),
    key=lambda p: int(os.path.splitext(os.path.basename(p))[0][len("img_"):]),
)

os.makedirs(out_dir, exist_ok=True)
for i, path in enumerate(tqdm(paths)):
    fig, _ = visu_attr(path, feature_ablation_wsl, n_steps, use_pyplot=False)
    fig.savefig(os.path.join(out_dir, f"{i}.jpg"))