# MARU-Net: Multi-Scale Attention Gated Residual U-Net With Contrastive Loss for SAR-Optical Image Matching

```python
import SarOptMatch
import tensorflow as tf
import random
import numpy as np

# Fix the Python, NumPy and TensorFlow random seeds for reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)
```
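
Seeding all three generators is what makes repeated runs comparable. The sketch below wraps the same three calls in a reusable helper and checks that reseeding reproduces the same draws; `set_global_seed` is a hypothetical name (not part of `SarOptMatch`), and TensorFlow is seeded only if it can be imported. Seeding alone may not give bit-exact results on GPU, where some operations are nondeterministic.

```python
import random

import numpy as np


def set_global_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy's global RNG and, if available, TensorFlow."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import tensorflow as tf  # optional dependency in this sketch
    except ImportError:
        return
    tf.random.set_seed(seed)


def draw_samples():
    """Draw one value from each seeded generator."""
    return random.random(), float(np.random.rand())


set_global_seed(42)
first = draw_samples()
set_global_seed(42)
second = draw_samples()
print(first == second)  # True: reseeding reproduces the same draws
```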

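### Generate dataset

We first generate the SAR/optical file lists and their offsets from the SEN1-2 dataset, then process them and split them into training and validation sets: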

```python
# Generate dataset
sar_files, opt_files, offsets = SarOptMatch.dataset.sen1_2(
    data_path='C:/Users/miche/Downloads/SEN1-2', seed=seed, ims_per_folder=1)

# Process and split dataset
training_data, validation_data, validation_dataRGB = SarOptMatch.dataset.split_data(
    sar_files, opt_files, offsets, batch_size=4, seed=seed, masking_strategy="unet")
```

```
Time for sen1_2 is 4.643621444702148
Time for split_data is 1.163482427597046
```
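
Besides the file lists, `sen1_2` returns an `offsets` array, which presumably holds the ground-truth displacement of each SAR/optical pair. How the package derives these offsets is not shown here, so the sketch below only illustrates one common way to build such a pair, under our own assumptions: a square SAR template is cropped from the co-registered SAR patch at a random position, and the crop's top-left corner `(dy, dx)` inside the optical search image is the offset the network must recover. `make_matching_pair` and the 192-pixel template size are hypothetical; SEN1-2 patches are 256 × 256 pixels.

```python
import numpy as np


def make_matching_pair(sar, opt, template_size, rng):
    """Hypothetical pair construction: crop a SAR template at a random offset.

    `sar` and `opt` are co-registered images of the same shape (H, W[, C]).
    Returns the SAR template, the full optical search image and the
    ground-truth offset (dy, dx) of the template's top-left corner.
    """
    if sar.shape[:2] != opt.shape[:2]:
        raise ValueError("SAR and optical images must be co-registered")
    height, width = opt.shape[:2]
    if template_size > min(height, width):
        raise ValueError("template larger than search image")
    # Draw a top-left corner that keeps the whole crop inside the image.
    dy = int(rng.integers(0, height - template_size + 1))
    dx = int(rng.integers(0, width - template_size + 1))
    template = sar[dy:dy + template_size, dx:dx + template_size]
    return template, opt, (dy, dx)


rng = np.random.default_rng(42)
sar = rng.random((256, 256))
opt = sar.copy()  # toy case: identical images make the offset easy to check
template, search, (dy, dx) = make_matching_pair(sar, opt, 192, rng)
print(template.shape)  # (192, 192)
print(np.array_equal(template, search[dy:dy + 192, dx:dx + 192]))  # True
```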


### Model

We instantiate the `SAR_opt_Matcher()` class:

```python
matcher = SarOptMatch.architectures.SAR_opt_Matcher()
matcher.print_attributes()
```

```
<SAR_opt_Matcher> instantiated

---Printing class attributes:---
backbone = None
n_filters = 0
multiscale = None
attention = None
activation = None
model = None
```

To train a new model, we specify its behaviour through the *config* dictionary, build the model with `create_model()` and train it with `train()`:

```python
config = {'model_name': "marunet_vanilla",
          'backbone': 'marunet',
          'n_filters': 32,
          'multiscale': True,
          'attention': True,
          'activation': "elu"}

matcher.create_model(**config)
matcher.train(training_data, validation_data, epochs=5)
```
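
Because `create_model()` receives the dictionary as keyword arguments (`**config`), architecture variants can be described as plain dictionaries. The sketch below builds hypothetical ablation configurations that toggle the `multiscale` and `attention` flags around the vanilla setting; the helper `ablation_configs` and its naming scheme are our own assumptions, and only keys already used above are set. The `matcher` calls are left as comments so the snippet runs without the package.

```python
from itertools import product

BASE_CONFIG = {
    "model_name": "marunet_vanilla",
    "backbone": "marunet",
    "n_filters": 32,
    "multiscale": True,
    "attention": True,
    "activation": "elu",
}


def ablation_configs(base):
    """Yield one config per combination of the multiscale/attention flags.

    Hypothetical naming scheme: "<backbone>_ms<0|1>_att<0|1>".
    """
    for multiscale, attention in product([True, False], repeat=2):
        config = dict(base)  # copy so the base config is never mutated
        config["multiscale"] = multiscale
        config["attention"] = attention
        config["model_name"] = (
            f"{base['backbone']}_ms{int(multiscale)}_att{int(attention)}"
        )
        yield config


configs = list(ablation_configs(BASE_CONFIG))
for config in configs:
    print(config["model_name"])
    # With the package installed, each variant could then be built and trained:
    # matcher.create_model(**config)
    # matcher.train(training_data, validation_data, epochs=5)
```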

To load an existing model instead, we call the `load_model()` method, which lets us select the saved model to be loaded:

```python
matcher.load_model()
matcher.print_attributes()
```

```
--Loading

---Printing class attributes:---
backbone = marunet
n_filters = 32
multiscale = True
attention = True
activation = elu
model = <keras.engine.functional.Functional object at 0x000001DAC819C700>
model_name = C:/Users/miche/Dropbox/PhD/Projects/GridEyeS/SAR_optical_matching/Code/weights/marunet_vanilla.h5
```

### Inference

We use the model to generate the heatmaps and the feature maps:

```python
# Generate heatmaps
heatmaps = matcher.predict_heatmap(validation_data)

# Generate feature maps
feature_maps = matcher.calculate_features(validation_data)
```

```
--Calculating heatmaps
--Calculating feature maps
Features: psi_opt_o, psi_SAR_o, psi_opt_d,  psi_SAR_d
```
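
In heatmap-based matching, each heatmap value scores one candidate position of the SAR template inside the optical image, and the predicted offset is read off at the peak. We do not know exactly how `predict_heatmap` combines the four feature maps, so the sketch below only illustrates the general principle under stated assumptions: a zero-mean, normalized cross-correlation between a SAR feature map (template) and an optical feature map (search area), both of shape (height, width, channels), followed by an argmax that recovers `(dy, dx)`. `correlation_heatmap` and `heatmap_peak` are hypothetical helpers, and `sliding_window_view` requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def correlation_heatmap(search, template, eps=1e-8):
    """Normalized cross-correlation of `template` at every valid position of `search`.

    search:   (H, W, C) feature map of the optical image (hypothetical input)
    template: (h, w, C) feature map of the SAR template, with h <= H and w <= W
    Returns an (H - h + 1, W - w + 1) heatmap with values in [-1, 1].
    """
    h, w, c = template.shape
    # View of shape (H - h + 1, W - w + 1, h, w, C): one window per position.
    windows = sliding_window_view(search, (h, w, c))[:, :, 0]
    windows = windows.reshape(windows.shape[0], windows.shape[1], -1)
    # Zero-mean, unit-norm template and windows, so the dot product is a correlation.
    t = template.reshape(-1) - template.mean()
    t = t / (np.linalg.norm(t) + eps)
    windows = windows - windows.mean(axis=-1, keepdims=True)
    windows = windows / (np.linalg.norm(windows, axis=-1, keepdims=True) + eps)
    return windows @ t


def heatmap_peak(heatmap):
    """Return the (dy, dx) position of the heatmap maximum."""
    dy, dx = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    return int(dy), int(dx)


rng = np.random.default_rng(0)
search = rng.standard_normal((32, 32, 8))
template = search[5:5 + 16, 11:11 + 16]  # template cut at offset (5, 11)
heatmap = correlation_heatmap(search, template)
print(heatmap.shape)          # (17, 17)
print(heatmap_peak(heatmap))  # (5, 11)
```

Normalizing each window makes the score insensitive to local brightness and contrast, which matters when the two modalities have very different radiometry. Comparing the peak positions with the ground-truth offsets over the validation set would then give a matching error per pair.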


### Visualization

We can visualize all the outputs through a unified GUI, which opens in an external window:

```python
SarOptMatch.visualization.visualize_dataset_with_GUI(validation_dataRGB, heatmaps, feature_maps)
```

![GUI example](https://github.com/MicheleGazzea/Sar_opt_matching/blob/main/imgs/GUI_example.png)