# Visualization 

### Goal
- Interpretability of predictions
    - Why the model is making this prediction

### Discussion
- For simplicity, only the modulator attention map is shown, but we could also extract the gating maps for further analysis
- We could use Gradient SHAP to add interpretability for the tabular features
- With Gradient SHAP on the images, we could also show which features of an image are most characteristic of each class


### Notes


## Imports

In [1]:
import gradio as gr

import skimage
import cv2
from PIL import Image

import pandas as pd
import numpy as np

from src.model import focalnet_tiny_srf
from src.data import IMAGE_SIZE, DATASET_MEAN, DATASET_STD

import torch
import torch.nn as nn
from torchvision import transforms

from src.data import NuweDataset

### Load model

In [2]:
# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class NuweModel(nn.Module):
    '''
    Three-class classifier that fuses an image embedding with two tabular inputs.

    The image goes through a FocalNet backbone whose own head is replaced by an
    identity, then is projected to 30 features. The year index is embedded into
    5 features and the 3-value neighbours context is projected to 5 features.
    The 40 concatenated features feed a small MLP that outputs 3 scores.
    '''

    def __init__(self):
        super().__init__()

        # Image branch: backbone features (768 wide, as expected by image_proj) -> 30
        self.image_backbone = focalnet_tiny_srf(pretrained=True)
        self.image_backbone.head = nn.Identity()
        self.image_proj = nn.Linear(768, 30)

        # Year branch: lookup table with 16 entries, 5 features each
        self.year = nn.Embedding(16, 5)

        # Neighbours branch: 3 input values -> 5 features
        self.neighbors_proj = nn.Linear(3, 5)

        # Classification head over the 30 + 5 + 5 = 40 fused features
        self.head = nn.Sequential(
            nn.Linear(40, 40),
            nn.BatchNorm1d(40),
            nn.ReLU(),
            nn.Linear(40, 10),
            nn.BatchNorm1d(10),
            nn.ReLU(),
            nn.Linear(10, 3),
        )

    def forward(self, image, year, neighbors_ctx):
        '''
        Compute the class scores for a batch.

        image: batch of images in whatever layout the backbone expects
            (presumably (batch, 3, H, W) -- confirm against src.model).
        year: integer indices into the 16-entry year embedding, one per sample.
        neighbors_ctx: float tensor whose last dimension has 3 values.

        Returns the raw (un-normalised) scores, 3 per sample.
        '''
        image_features = self.image_proj(self.image_backbone(image))
        year_features = self.year(year)
        neighbors_features = self.neighbors_proj(neighbors_ctx)

        # Fuse the three branches along the feature dimension
        fused = torch.cat((image_features, year_features, neighbors_features), dim=1)

        return self.head(fused)

# Build the network and place it on the target device.
model = NuweModel()
model = model.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different device (e.g. another GPU, or CUDA when only a CPU is
# available) still loads.
model.load_state_dict(torch.load('artifacts/model2.pt', map_location=device))
# Inference mode: BatchNorm uses its running statistics and Dropout is disabled.
model.eval()

NuweModel(
  (image_backbone): FocalNet(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0): BasicLayer(
        dim=96, input_resolution=(56, 56), depth=2
        (blocks): ModuleList(
          (0): FocalNetBlock(
            dim=96, input_resolution=(56, 56), mlp_ratio=4.0
            (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (modulation): FocalModulation(
              dim=96
              (f): Linear(in_features=96, out_features=195, bias=True)
              (h): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))
              (act): GELU(approximate='none')
              (proj): Linear(in_features=96, out_features=96, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (focal_layers): ModuleList(
                (0): Sequenti

### Data pipeline

In [4]:
# Root folder holding the CSV indexes; image paths are built relative to it.
DATA_PATH = 'data/raw/'

train = pd.read_csv(DATA_PATH + 'train.csv')
# The test split has no ground truth, so it gets a placeholder label of -1.
test = pd.read_csv(DATA_PATH + 'test.csv').assign(label=-1)

# One frame with train rows first, then test rows, indexed 0..N-1 so that an
# image number entered in the UI addresses a row directly.
df = pd.concat([train, test], axis=0, ignore_index=True)

full_dataset = NuweDataset(data=df, directory=DATA_PATH)

# Class names, looked up by the position of each score in the model output.
labels = ['Plantation', 'Grassland/Shrubland', 'Smallholder Agriculture']

# Index into the backbone's `layers` list selecting which stage's attention to show.
LAYER = -1

### Gradio visualization

In [5]:
def get_blend_map(img, att_map, kernel_size=(5,5), colormap=cv2.COLORMAP_JET, a1=0.7, a2=0.3):
    '''
    Overlay a colour-mapped attention map on top of an image.

    img: image to draw on; anything `np.asarray` can convert (e.g. a PIL image).
        Assumed to be a 3-channel RGB image -- TODO confirm for non-RGB inputs.
    att_map: numpy array of attention values (any range); it is min-max
        normalised and resized to the image's height and width.
    kernel_size: Gaussian blur kernel applied to the resized map.
    colormap: OpenCV colormap id used to colour the map.
    a1, a2: blending weights of the image and of the coloured map.

    Returns the blended image as a uint8 numpy array.
    '''
    img = np.asarray(img)

    # Min-max normalise to [0, 1]. A constant map has zero range; skip the
    # division in that case to avoid a 0/0 that would fill the map with NaNs.
    att_map = att_map - att_map.min()
    peak = att_map.max()
    if peak > 0:
        att_map = att_map / peak

    att_map = att_map * 255

    # Bicubic interpolation (order=3) up to the image resolution
    att_map = skimage.transform.resize(att_map, (img.shape[:2]), order = 3).astype(np.uint8)

    im_cloud_blur = cv2.GaussianBlur(att_map, kernel_size, 0)
    im_cloud_clr = cv2.applyColorMap(im_cloud_blur, colormap)
    # applyColorMap produces BGR channel order; convert to RGB so the colours
    # line up with the RGB image they are blended with.
    im_cloud_clr = cv2.cvtColor(im_cloud_clr, cv2.COLOR_BGR2RGB)

    return (a1*img + a2*im_cloud_clr).astype(np.uint8)


def segment(image_number):
    '''
    Get the attention map and the model predictions for one dataset image.

    image_number: row position in `df` / `full_dataset` (converted with `int`).

    Returns a tuple of
      - the original image as a PIL image,
      - the image blended with the attention heat map (numpy array),
      - a dict mapping each class label to its predicted probability.

    Raises IndexError when `image_number` is outside the dataset.
    '''
    image_number = int(image_number)
    if not 0 <= image_number < len(df):
        raise IndexError(
            f'image_number must be between 0 and {len(df) - 1}, got {image_number}'
        )

    path = DATA_PATH + df.example_path[image_number]
    img_d = Image.open(path)

    # Input to model: add a batch dimension of 1 and move to the model's device
    i_image, i_year, i_neightbours_context, _ = full_dataset[image_number]
    b_image = i_image.unsqueeze(0).to(device)
    b_year = i_year.unsqueeze(0).to(device)
    b_neightbours_context = i_neightbours_context.unsqueeze(0).to(device)

    with torch.no_grad():
        out = model(b_image, b_year, b_neightbours_context)
        y_probas = torch.softmax(out, dim=1)

    probabilities = y_probas.cpu().numpy()[0]
    confidences = {label: float(p) for label, p in zip(labels, probabilities)}

    # The `modulator` attribute is read after the forward pass above; it is
    # presumably stored by the modulation module during forward -- confirm in
    # src.model. Averaging |modulator| over dim 1 gives one map per sample.
    modulator = torch.abs((model.image_backbone.layers[LAYER].blocks[-1].modulation.modulator)).mean(1, keepdim=True)
    x = modulator.squeeze(1).permute(1, 2, 0).cpu().detach().contiguous().numpy()

    x = get_blend_map(img_d, x)

    return img_d, x, confidences


In [9]:
# Demo UI: the user enters an image number and gets back the original image,
# the attention overlay and the class probabilities.
gr.Interface(
    fn=segment,
    inputs=[gr.Number()],
    examples=[8, 13, 148],
    outputs=[
        gr.Image(type="pil"),
        gr.Image(type="numpy"),
        gr.Label(num_top_classes=3),
    ],
).launch()

Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.


