# ASOS
## Imports

In [None]:
import os

# Root of the project checkout; override it with the SRC_DIR environment variable.
_DEFAULT_SRC_DIR = '/home/martin/Projekte/OGC_Testbed-18/software/testbed18-wilderness-workflow'
src_dir = os.environ.get('SRC_DIR', _DEFAULT_SRC_DIR)
src_dir = os.path.expanduser(src_dir)
# Switch into the checkout before the project imports below (presumably so that `tlib` and
# `projects` resolve relative to it — confirm against the environment setup).
os.chdir(src_dir)

import random

import matplotlib.pyplot as plt
import torch

from tlib import tlearn, ttorch, tutils
from projects.asos import config, utils

%load_ext autoreload
%autoreload 2

In [None]:
# Load the file infos. `fi.df` is used below as a table with 'datasplit' and 'correct' columns whose
# index holds tile file names relative to config.data_folder_tiles (see the vectorization cell).
fi = utils.load_file_infos()

## Setup ASOS

In [None]:
# Set up the ASOS performer.
# Load the trainer once: both the number of U-Net activation maps and the log directory come from
# it, and utils.load_trainer() re-loads the model every time it is called.
trainer = utils.load_trainer()
dims = trainer.model.unet.conv_out.out_channels  # number of unet activation maps
use_hypercube = True  # hypercube method (True) or expectation maximization (False)

output_folder = tutils.files.join_paths(trainer.log_dir, 'asos')

# Pick the performer class by method and dimensionality:
#   hypercube:              ASOSPerformer1d / ASOSPerformer2d / ASOSPerformer3d
#   expectation maximization: ASOSPerformerEM1d / ASOSPerformerEM2d / ASOSPerformerEM3d
# Fail loudly for unsupported dims instead of leaving `asos` undefined.
if dims not in (1, 2, 3):
    raise ValueError(f'ASOS supports 1 to 3 U-Net activation maps, got dims={dims}')
method_tag = '' if use_hypercube else 'EM'
# getattr keeps the lookup lazy: only the class actually needed is resolved
asos_class = getattr(tlearn.interpret.asos, f'ASOSPerformer{method_tag}{dims}d')
asos = asos_class(ax_range=(-1, 1), output_folder=output_folder)

asos.save()  # save asos with pickle

In [None]:
# Check content: display the folder the asos performer was created with (output_folder above)
asos.output_folder

In [None]:
# Check problem with folder: compare the trainer's data folder with the configured one
trainer = utils.load_trainer()
used_folder = trainer.datamodule.folder
configured_folder = config.data_folder_tiles

print('used:', used_folder)
# e.g. /home/timo/data/anthroprotect/tiles/s2
print('should be used as defined in config:', configured_folder)
# -> when the trainer is loaded from checkpoint_5.pt, its datamodule.folder attribute is set to the
#    folder used during training, which may differ from the folder used in later analysis

## Vectorization

In [None]:
# get unet maps of correctly classified training tiles
files = fi.df[(fi.df['datasplit'] == 'train') & (fi.df['correct'])].index.to_list()
print(len(files))

# Only a random fraction of the unet maps is vectorized; the fraction depends on the dataset.
frac_unet_maps_per_dataset = {
    'anthroprotect': 0.05,  # 0.25 was used as an alternative
    'places': 0.5,
}
if config.dataset not in frac_unet_maps_per_dataset:
    raise ValueError(f'no frac_unet_maps defined for dataset {config.dataset!r}')
frac_unet_maps = frac_unet_maps_per_dataset[config.dataset]

# Set to an int for a reproducible subset; None keeps the unseeded behaviour.
sample_seed = None
if sample_seed is not None:
    random.seed(sample_seed)

# random.sample picks positions, so sampling the list directly matches sampling its indices.
files = random.sample(files, int(len(files) * frac_unet_maps))
print(len(files))
if not files:
    raise ValueError('no unet maps selected for vectorization; increase frac_unet_maps')

# Prepend path (see the folder check cell for why the trainer's folder is not used)
files = [os.path.join(config.data_folder_tiles, file) for file in files]

unet_maps = utils.predict_activation_maps(*files)

# vectorize (random_frac is passed through to asos.vectorize)
random_frac = 1/1000
asos.vectorize(maps=unet_maps, map_ids=files, frame_size=10, random_frac=random_frac)
asos.save()  # save asos with pickle

# free the predicted maps before the memory-heavy steps below
del unet_maps

In [None]:
%matplotlib inline
if asos.dims in [1, 2]:
    asos.plot_chspace()
    plt.show()

In [None]:
%matplotlib widget
if asos.dims == 3:
    asos.plot_chspace(colors='rgb')  # colors=None to not color vectors in rgb
    plt.show()

## Groups

In [None]:
# define groups in the channel space
if use_hypercube:
    # Hypercube parameters per dataset. Alternatives tried for edge_length: 2/10, 2/2
    # (with ax_range=(-1, 1) the axis spans 2, so 2/20 presumably means 20 cubes per axis).
    if config.dataset in ['anthroprotect', 'places']:
        edge_length = 2/20
        consider_factor = 2
    else:
        # previously this fell through to a NameError on edge_length
        raise ValueError(f'no hypercube group parameters defined for dataset {config.dataset!r}')

    asos.fit_groups(edge_length=edge_length, consider_factor=consider_factor)
else:
    asos.fit_groups(n_groups=3)

asos.save()  # save asos with pickle

In [None]:
%matplotlib inline
if asos.dims in [1, 2]:
    asos.plot_chspace(colors='groups')
    plt.show()

In [None]:
%matplotlib widget
if asos.dims == 3:
    asos.plot_chspace(colors='groups')
    plt.show()

## Sensitivities

In [None]:
# We cannot predict all U-Net maps up front, as in the vectorization cell, because holding the maps
# of all training tiles at once would overflow the memory:
# files = fi.df[(fi.df['datasplit'] == 'train') & (fi.df['correct'])].index.to_list()
# unet_maps = utils.predict_activation_maps(*files)
# Instead we define an object that behaves like a list and predicts each map lazily in __getitem__:

class UNetMaps:
    """Lazy, list-like view of the U-Net activation maps of the training dataset.

    Each map is predicted on demand in ``__getitem__``, so the maps of all training tiles are
    never held in memory at the same time.
    """

    def __init__(self):
        trainer = utils.load_trainer()
        self.dataset = trainer.datamodule.train_dataset

        # Fix file paths: the checkpointed dataset points at the folder used during training,
        # so re-root every file onto the tile folder from the config.
        self.dataset.files = [
            os.path.join(config.data_folder_tiles, os.path.basename(file))
            for file in self.dataset.files
        ]

        self.dataset.eval()
        self.unet = ttorch.modules.wrapper.AutoMoveData(trainer.model.unet)

    def __getitem__(self, index):
        """Predict the activation map of dataset item ``index`` (batch dim removed, on the CPU)."""
        x = self.dataset[index]['x']
        # Pure inference: don't build an autograd graph that would be detached right away.
        with torch.no_grad():
            unet_map = self.unet(x.unsqueeze(0))
        return unet_map.detach().cpu()[0]

    def __len__(self):
        return len(self.dataset)

unet_maps = UNetMaps()

if unet_maps.dataset.files:
    print(unet_maps.dataset.files[0])

# Subset of maps: restrict the sensitivity fit to the first n_sensitivity_maps training maps;
# set to None to use all of them. (Assumes the dataset's length follows `files` — confirm.)
n_sensitivity_maps = 100
if n_sensitivity_maps is not None:
    unet_maps.dataset.files = unet_maps.dataset.files[:n_sensitivity_maps]

In [None]:
# Display the first activation map (predicted on demand by UNetMaps.__getitem__)
unet_maps[0]

In [None]:
import numpy as np

# Preview the first activation map as an image. The ASOS axes above use ax_range=(-1, 1), so
# activations are assumed to lie in [-1, 1]: rescale to [0, 1] and clip out-of-range values.
# The map is (channels, H, W); move channels last for imshow.
preview = np.clip((np.asarray(unet_maps[0]) + 1) / 2, 0, 1).transpose(1, 2, 0)
n_channels = preview.shape[-1]
if n_channels == 3:
    # three activation maps -> show them directly as RGB
    plt.imshow(preview)
else:
    # imshow rejects (H, W, 2) arrays, so show every activation map separately
    fig, axes = plt.subplots(1, n_channels, squeeze=False)
    for channel, ax in enumerate(axes[0]):
        ax.imshow(preview[..., channel], cmap='gray', vmin=0, vmax=1)
        ax.set_title(f'map {channel}')
plt.show()

In [None]:
%matplotlib inline

# get model
model = utils.load_trainer().model.classify_unet_map

# Note: this step takes  ~6 hours
# fit sensitivities
#asos.fit_sensitivities(maps=unet_maps, model=model, fill_value=0, move_data_to_gpu=True)
asos.fit_sensitivities(maps=unet_maps, model=model, fill_value=0, move_data_to_gpu=True)
asos.save()  # save asos with pickle

In [None]:
# adapt valid deviations and plotting limits per dataset
#   min_n_occluded_pixels: only deviations calculated from at least this many occluded pixels
#                          (when occluding a map) are taken for further calculations
#   q: passed to asos.set_vlim (presumably a quantile for the value limits — confirm)
valid_deviation_params = {
    'anthroprotect': {'min_n_occluded_pixels': 10, 'q': 0.02},
    'places': {'min_n_occluded_pixels': 10, 'q': 0.001},
}
if config.dataset not in valid_deviation_params:
    # previously this fell through to a NameError on min_n_occluded_pixels
    raise ValueError(f'no valid-deviation parameters defined for dataset {config.dataset!r}')
min_n_occluded_pixels = valid_deviation_params[config.dataset]['min_n_occluded_pixels']
q = valid_deviation_params[config.dataset]['q']

asos.adapt_valid_deviations(min_n_occluded_pixels=min_n_occluded_pixels)

asos.set_vlim(q=q)
asos.save()  # save asos with pickle

In [None]:
%matplotlib inline
# Plot the histograms provided by the asos performer (presumably of the fitted sensitivities — confirm)
asos.plot_histograms()

In [None]:
%matplotlib inline
if asos.dims in [1, 2]:
    asos.plot_chspace(colors='sensitivities')
    plt.show()

In [None]:
%matplotlib widget
if asos.dims == 3:
    asos.plot_chspace(colors='sensitivities')
    plt.show()

In [None]:
%matplotlib inline
# plot sample
index = 99
sensitivity_map = asos.predict_sensitivities(unet_maps[index].unsqueeze(0))[0]
asos.plot_sensitivity_map(sensitivity_map)
plt.show()