# 1. Vision Transformer Architecture

In [195]:
import warnings
warnings.filterwarnings('ignore')

import copy
import datetime
import h5py
import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import sys

from cv2 import resize
from datetime import datetime
from gc import collect
from os import cpu_count
from scipy.io import savemat, loadmat
from sklearn.model_selection import train_test_split
from time import sleep
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from tqdm import tqdm

sys.path.append(f"{os.getcwd()}/ViT architecture/working ViT")
sys.path.append(f"{os.getcwd()}/scripts/")
from VisionTransformer_working import VisionTransformer as Vit_old

sys.path.append(f"{os.getcwd()}/ViT architecture/Architecture tryouts/DPT/")
from VisionTransformer_working_for_DPT import VisionTransformer2 as Vit_2

In [196]:
random_seed = 1
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
cudnn.benchmark = True

In [197]:
collect()
torch.cuda.empty_cache()
device = torch.device('cpu')
print(f"Running on device: {device}")

Running on device: cpu


### 1. Comparison of model performances on in-distribution data

Load CNN and ViT model.

In [198]:
vit = Vit_2(dspl_size=104, 
              patch_size=8, 
              embed_dim=128,
              depth=12,
              n_heads=8,
              mlp_ratio=4.,
              p=0.,
              attn_p=0.,
              drop_path=0.).float()
path_to_vit_new = '/home/alexrichard/PycharmProjects/UQ_DL-TFM/ViT-TFM/ViT architecture/Architecture tryouts/logs_and_weights/ViT_with_shifted_patch-2023-Jan-18 09:54:37/ViT_with_shifted_patch-2023-Jan-18 09:54:37.pth'
if torch.cuda.is_available():
    vit.load_state_dict(torch.load(path_to_vit_new)['best_model_weights'], strict=True)
else:
    vit.load_state_dict(torch.load(path_to_vit_new, map_location=torch.device('cpu'))['best_model_weights'], strict=True)        

In [199]:
cnn = keras.models.load_model('/home/alexrichard/PycharmProjects/UQ_DL-TFM/mltfm/model.h5')

Load test data and corrupt it with Gaussian noise. This mimicks the training conditions for both models.

In [200]:
dspl_test = np.array(h5py.File('/home/alexrichard/PycharmProjects/UQ_DL-TFM/ViT-TFM/data/Test data/resolution_104/allDisplacements.h5', 'r')['dspl'])
trac_test = np.array(h5py.File('/home/alexrichard/PycharmProjects/UQ_DL-TFM/ViT-TFM/data/Test data/resolution_104/allTractions.h5', 'r')['trac'])
dspl_test = np.concatenate([dspl_test[i] for i in range(dspl_test.shape[0])], axis=0, dtype=np.float32)
trac_test = np.concatenate([trac_test[i] for i in range(trac_test.shape[0])], axis=0, dtype=np.float32)
dspl_test = dspl_test[:100]
Y_test = torch.from_numpy(np.moveaxis(trac_test[:100], 3, 1))

In [201]:
cov = [[1e-04 ** 2, 0], [0, 1e-04 ** 2]]
X_train_noise = np.random.multivariate_normal(np.array([0, 0]), cov, (dspl_test.shape[0], dspl_test.shape[1], dspl_test.shape[2]))
dspl_test_prime = dspl_test + X_train_noise
X_test = torch.from_numpy(np.moveaxis(dspl_test_prime, 3, 1)).float()
X_train_noise = torch.from_numpy(np.moveaxis(X_train_noise, 3, 1)).float()

Export the data to .mat files.

In [26]:
save_files_here = '/home/alexrichard/LRZ Sync+Share/ML in Physics/Repos/Easy-to-use_TFM_package-master/test_data/Artificial patch data/Raw Samples'
xx, yy = np.meshgrid(np.arange(104), np.arange(104), indexing='ij')
pos = np.vstack([xx.reshape(-1), yy.reshape(-1)], dtype=np.double).T
path_to_dir = f'{save_files_here}/Gaussian_noise_1e-04'
os.makedirs(path_to_dir)
for j, sample in enumerate(X_test):
    file_name = f'{path_to_dir}/test_sample_{j + 1}.mat'
    vec_dspl = np.vstack([sample[0].reshape(-1), sample[1].reshape(-1)], dtype=np.double).T
    vec_noise = np.vstack([X_train_noise[j, 0].reshape(-1), X_train_noise[j, 1].reshape(-1)], dtype=np.double).T
    mdict = {'input_data': {'noise': [{'vec': vec_noise, 'pos': pos}, {'vec': vec_noise, 'pos': pos}], 'displacement': [{'vec': vec_dspl, 'pos': pos}, {'vec': vec_dspl, 'pos': pos}]}}
    savemat(file_name, mdict=mdict)

Load predictions of the BFFTC method.

In [202]:
bfftc_predictions = torch.zeros((100, 2, 102, 102))
bfftc_displacements = torch.zeros((100, 2, 102, 102))

directory = "/home/alexrichard/LRZ Sync+Share/ML in Physics/Repos/Easy-to-use_TFM_package-master/test_data/Artificial patch data/Predictions/Gaussian_noise_1e-04"
    
for j, file in enumerate(os.listdir(f'{directory}')):
    filename = os.fsdecode(file)
    if filename.endswith(".mat"):
        bfft_pred = loadmat(f'{directory}/pred_of_sample_{j + 1}')['TFM_results']['traction'][0][0].T.reshape((2, 102, 102), order='F')
        bfft_dspl = loadmat(f'{directory}/pred_of_sample_{j + 1}')['TFM_results']['displacement'][0][0].T.reshape((2, 102, 102), order='F')
        bfftc_predictions[j] = torch.tensor(bfft_pred)
        bfftc_displacements[j] = torch.tensor(bfft_dspl)

Allocate test samples to ground truths as they got mixed up when BFFTC was performed.

In [204]:
bfftc_predictions_trimmed = torch.zeros((100, 2, 98, 98))
Y_test_trimmed = torch.zeros((100, 3, 98, 98))
X_test_trimmed = torch.zeros((100, 2, 98, 98))

for i, sample in enumerate(bfftc_displacements):
    for j, dspl in enumerate(X_test):
        if torch.allclose(dspl[:, 1:103, 1:103].float(), torch.tensor(sample).float(), atol=1e-02, rtol=1):
            print(f'bfftc dspl {i} matches dspl {j}')
            bfftc_predictions_trimmed[i] = torch.tensor(bfftc_predictions[i, :, 3:101, 3:101]).float()
            Y_test_trimmed[i] = Y_test[j, :, 3:101, 3:101].float()
            X_test_trimmed[i] = dspl[:, 3:101, 3:101].float()

bfftc dspl 0 matches dspl 0
bfftc dspl 1 matches dspl 16
bfftc dspl 2 matches dspl 98
bfftc dspl 3 matches dspl 17
bfftc dspl 4 matches dspl 18
bfftc dspl 5 matches dspl 1
bfftc dspl 6 matches dspl 19
bfftc dspl 7 matches dspl 20
bfftc dspl 8 matches dspl 21
bfftc dspl 9 matches dspl 22
bfftc dspl 10 matches dspl 23
bfftc dspl 11 matches dspl 24
bfftc dspl 12 matches dspl 9
bfftc dspl 13 matches dspl 25
bfftc dspl 14 matches dspl 26
bfftc dspl 15 matches dspl 27
bfftc dspl 16 matches dspl 28
bfftc dspl 17 matches dspl 2
bfftc dspl 18 matches dspl 29
bfftc dspl 19 matches dspl 30
bfftc dspl 20 matches dspl 31
bfftc dspl 21 matches dspl 32
bfftc dspl 22 matches dspl 33
bfftc dspl 23 matches dspl 99
bfftc dspl 24 matches dspl 34
bfftc dspl 25 matches dspl 35
bfftc dspl 26 matches dspl 36
bfftc dspl 27 matches dspl 37
bfftc dspl 28 matches dspl 38
bfftc dspl 29 matches dspl 3
bfftc dspl 30 matches dspl 39
bfftc dspl 31 matches dspl 40
bfftc dspl 32 matches dspl 41
bfftc dspl 33 matches dsp

Calculate CNN and ViT prediction on test set.

In [157]:
vit.eval()
vit_predictions = vit(X_test, device=device)
cnn_predictions = cnn.predict(np.moveaxis(np.array(X_test), 1, 3))



Calculate losses.

In [158]:
mse = torch.nn.MSELoss(reduction='none')
vit_mse = torch.mean(mse(vit_predictions[:, :, 3:101, 3:101], Y_test[:, 0:2, 3:101, 3:101]), (1,2,3))
cnn_mse = torch.mean(mse(torch.tensor(np.moveaxis(cnn_predictions, 3, 1)[:, :, 3:101, 3:101]), Y_test[:, 0:2, 3:101, 3:101]), (1,2,3))
bfftc_mse = torch.mean(mse(bfftc_predictions[:, :, 3:101, 3:101], Y_test[:, 0:2, 3:101, 3:101]), (1,2,3))

In [194]:
bfftc_mse

tensor([0.0021, 0.0043, 0.0067, 0.0057, 0.0050, 0.0041, 0.0046, 0.0052, 0.0066,
        0.0041, 0.0043, 0.0039, 0.0065, 0.0066, 0.0059, 0.0038, 0.0055, 0.0044,
        0.0046, 0.0058, 0.0055, 0.0046, 0.0035, 0.0033, 0.0081, 0.0062, 0.0089,
        0.0049, 0.0054, 0.0061, 0.0061, 0.0058, 0.0044, 0.0033, 0.0063, 0.0054,
        0.0078, 0.0048, 0.0056, 0.0044, 0.0070, 0.0049, 0.0028, 0.0045, 0.0047,
        0.0038, 0.0050, 0.0040, 0.0059, 0.0041, 0.0054, 0.0041, 0.0054, 0.0031,
        0.0067, 0.0047, 0.0076, 0.0046, 0.0054, 0.0071, 0.0043, 0.0037, 0.0031,
        0.0068, 0.0046, 0.0045, 0.0022, 0.0071, 0.0057, 0.0044, 0.0057, 0.0048,
        0.0056, 0.0038, 0.0058, 0.0049, 0.0052, 0.0045, 0.0044, 0.0007, 0.0007,
        0.0006, 0.0013, 0.0012, 0.0008, 0.0010, 0.0014, 0.0014, 0.0016, 0.0041,
        0.0050, 0.0055, 0.0051, 0.0051, 0.0058, 0.0041, 0.0060, 0.0046, 0.0056,
        0.0035])

Visualize ground truth and predictions for first test sample.

In [180]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(2, 2, figsize=(8, 8))
fig.tight_layout(pad=2, w_pad=2, h_pad=2)

axs[0, 0].quiver(Y_test_trimmed[0, 0, :, :].detach().numpy(), Y_test_trimmed[0, 1, :, :].detach().numpy(), scale=6)
axs[0, 0].set_title('Ground truth', {'fontsize': 11})

axs[0, 1].quiver(vit_predictions[0, 0, 3:101, 3:101].detach().numpy(), vit_predictions[0, 1, 3:101, 3:101].detach().numpy(), scale=1)
axs[0, 1].set_title(f'ViT prediction, loss: {vit_mse[0]:9.6f}', {'fontsize': 11})

axs[1, 0].quiver(cnn_predictions[0, 3:101, 3:101, 0], cnn_predictions[0, 3:101, 3:101, 1], scale=1)
axs[1, 0].set_title(f'CNN prediction, loss: {cnn_mse[0]:9.6f}', {'fontsize': 11})

axs[1, 1].quiver(bfftc_predictions_trimmed[0, 0, 3:101, 3:101].detach().numpy(), bfftc_predictions_trimmed[0, 1, 3:101, 3:101].detach().numpy(), scale=1)
axs[1, 1].set_title(f'BFFTC prediction, loss: {bfftc_mse[0]:9.6f}', {'fontsize': 11})

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'BFFTC prediction, loss:  0.002130')

Compute metrics.

In [184]:
vit_predictions.shape

torch.Size([100, 2, 104, 104])

In [210]:
from MultiTask import append_predictions_and_targets, cosine_sim, adtma, dtma, dma, snr, dtmb

appended_vit_predictions, appended_vit_targets = append_predictions_and_targets(vit_predictions[:, :, 3:101, 3:101], Y_test[:, :, 3:101, 3:101], device)
dtma_vit = dtma(appended_vit_predictions, appended_vit_targets, device, True)
dtmb_vit = dtmb(vit_predictions[:, :, 3:101, 3:101], Y_test[:, :, 3:101, 3:101], appended_vit_predictions, appended_vit_targets, device=device, per_sample=True)
snr_vit = snr(vit_predictions[:, :, 3:101, 3:101], Y_test[:, :, 3:101, 3:101], appended_vit_predictions, appended_vit_targets, device=device, per_sample=True)
dma_vit = dma(vit_predictions[:, :, 3:101, 3:101], Y_test[:, :, 3:101, 3:101], appended_vit_predictions, appended_vit_targets, device=device, per_sample=True)
adtma_vit = adtma(appended_vit_predictions, appended_vit_targets, device, True)
cosine_sim_vit = cosine_sim(appended_vit_predictions, appended_vit_targets, device, True)

In [211]:
appended_cnn_predictions, appended_cnn_targets = append_predictions_and_targets(torch.tensor(np.moveaxis(cnn_predictions, 3, 1))[:, :, 3:101, 3:101], Y_test[:, :, 3:101, 3:101], device)
dtma_cnn = dtma(appended_cnn_predictions, appended_cnn_targets, device, True)
dtmb_cnn = dtmb(torch.tensor(np.moveaxis(cnn_predictions, 3, 1))[:, :, 3:101, 3:101], Y_test[:, :, 3:101, 3:101], appended_cnn_predictions, appended_cnn_targets, device=device, per_sample=True)
snr_cnn = snr(torch.tensor(np.moveaxis(cnn_predictions, 3, 1))[:, :, 3:101, 3:101], Y_test[:, :, 3:101, 3:101], appended_cnn_predictions, appended_cnn_targets, device=device, per_sample=True)
dma_cnn = dma(torch.tensor(np.moveaxis(cnn_predictions, 3, 1))[:, :, 3:101, 3:101], Y_test[:, :, 3:101, 3:101], appended_cnn_predictions, appended_cnn_targets, device=device, per_sample=True)
adtma_cnn = adtma(appended_cnn_predictions, appended_cnn_targets, device, True)
cosine_sim_cnn = cosine_sim(appended_cnn_predictions, appended_cnn_targets, device, True)

In [212]:
appended_bfftc_predictions, appended_bfftc_targets = append_predictions_and_targets(bfftc_predictions_trimmed, Y_test[:, :, 3:101, 3:101], device)
dtma_bfftc = dtma(appended_bfftc_predictions, appended_bfftc_targets, device, True)
dtmb_bfftc = dtmb(bfftc_predictions_trimmed, Y_test[:, :, 3:101, 3:101], appended_bfftc_predictions, appended_bfftc_targets, device=device, per_sample=True)
snr_bfftc = snr(bfftc_predictions_trimmed, Y_test[:, :, 3:101, 3:101], appended_bfftc_predictions, appended_bfftc_targets, device=device, per_sample=True)
dma_bfftc = dma(bfftc_predictions_trimmed, Y_test[:, :, 3:101, 3:101], appended_bfftc_predictions, appended_bfftc_targets, device=device, per_sample=True)
adtma_bfftc = adtma(appended_bfftc_predictions, appended_bfftc_targets, device, True)
cosine_sim_bfftc = cosine_sim(appended_bfftc_predictions, appended_bfftc_targets, device, True)

In [209]:
'''
dtma_vit = torch.reshape(dtma_vit, (Y_test_trimmed.shape[0], 1))
dtma_cnn = torch.reshape(dtma_cnn, (Y_test_trimmed.shape[0], 1))
dtma_bfftc = torch.reshape(dtma_bfftc, (Y_test_trimmed.shape[0], 1))

cosine_sim_vit = torch.reshape(cosine_sim_vit, (Y_test_trimmed.shape[0], 1))
cosine_sim_cnn = torch.reshape(cosine_sim_cnn, (Y_test_trimmed.shape[0], 1))
cosine_sim_bfftc = torch.reshape(cosine_sim_bfftc, (Y_test_trimmed.shape[0], 1))

dtmb_vit_new = torch.reshape(dtmb_vit, (Y_test_trimmed.shape[0], 1))
dtmb_cnn = torch.reshape(dtmb_cnn, (Y_test_trimmed.shape[0], 1))
dtmb_bfftc = torch.reshape(dtmb_bfftc, (Y_test_trimmed.shape[0], 1))

snr_vit_new = torch.reshape(snr_vit, (Y_test_trimmed.shape[0], 1))
snr_cnn = torch.reshape(snr_cnn, (Y_test_trimmed.shape[0], 1))
snr_bfftc = torch.reshape(snr_bfftc, (Y_test_trimmed.shape[0], 1))

dma_vit_new = torch.reshape(dma_vit, (Y_test_trimmed.shape[0], 1))
dma_cnn = torch.reshape(dma_cnn, (Y_test_trimmed.shape[0], 1))
dma_bfftc = torch.reshape(dma_bfftc, (Y_test_trimmed.shape[0], 1))

adtma_vit_new = torch.reshape(adtma_vit, (Y_test_trimmed.shape[0], 1))
adtma_cnn = torch.reshape(adtma_cnn, (Y_test_trimmed.shape[0], 1))
adtma_bfftc = torch.reshape(adtma_bfftc, (Y_test_trimmed.shape[0], 1))
'''

In [220]:
mses = torch.stack((torch.mean(vit_mse), torch.mean(cnn_mse), torch.mean(bfftc_mse)))
dtmas = torch.stack((torch.mean(dtma_vit), torch.mean(dtma_cnn), torch.mean(dtma_bfftc)))
dtmbs = torch.stack((torch.mean(dtmb_vit), torch.mean(dtmb_cnn), torch.mean(dtmb_bfftc)))
snrs = torch.stack((torch.mean(snr_vit), torch.mean(snr_cnn), torch.mean(snr_bfftc)))
dmas = torch.stack((torch.mean(dma_vit), torch.mean(dma_cnn), torch.mean(dma_bfftc)))
adtmas = torch.stack((torch.mean(adtma_vit), torch.mean(adtma_cnn), torch.mean(adtma_bfftc)))
cos_sims = torch.stack((torch.mean(cosine_sim_vit), torch.mean(cosine_sim_cnn), torch.mean(cosine_sim_bfftc)))

In [253]:
columns = ['MSE', 
           'DTMA', 
           'Cosine sim', 
           'DTMB',
           'SNR',
           'DMA',
           'ADTMA']

metrics = pd.DataFrame(
        np.array(torch.stack((mses, 
                        dtmas,
                        cos_sims,
                        dtmbs, 
                        snrs,
                        dmas,
                        adtmas, 
                       ), 1).detach().numpy()), columns=columns)

float_formatter = "{:.6f}".format

averages = {}
for column in columns:
    averages[column] = f"avg: {float_formatter(round(metrics[column].mean(), ndigits=7))}"
    
averages = pd.DataFrame(averages, index=[0])

In [254]:
def highlight_closes_to_zero(s, props=''):
    return np.where(np.absolute(s) == np.nanmin(np.absolute(s.values)), props, '')

def highlight_max(x, props=''):
    return np.where(x == np.nanmax(x.to_numpy()), props, '')

In [256]:
metrics = metrics.rename(index={0: 'ViT', 1: 'CNN', 2: 'BFFTC'})
slice_ = ['MSE', 'DTMA', 'DTMB', 'DMA', 'ADTMA']
metrics = metrics.style.apply(highlight_closes_to_zero, props='color:white; background-color:purple', subset=slice_, axis=0)
slice_ = ['SNR', 'Cosine sim']
metrics = metrics.apply(highlight_max, props='color:white; background-color:purple', subset=slice_, axis=0)
metrics = metrics.set_table_styles([dict(selector='th', props=[('text-align', 'center')])])
metrics.set_properties(**{'text-align': 'center'})

metrics.set_properties(**{"border": "0.5px solid black"})

Unnamed: 0,MSE,DTMA,Cosine sim,DTMB,SNR,DMA,ADTMA
ViT,6e-05,-0.095683,0.943508,0.004416,49.706181,0.06545,0.096325
CNN,7.5e-05,-0.205903,0.914052,0.007989,47.141068,-0.075841,0.206614
BFFTC,0.004674,-0.361674,0.094725,0.142906,5.02633,0.306055,0.641559


### Comparison of models on out-of-distribution data

Load CNN and ViT models.

In [276]:
if torch.cuda.is_available():
    vit_new.load_state_dict(torch.load(path_to_vit_new)['best_model_weights'], strict=True)
else:
    vit_new.load_state_dict(torch.load(path_to_vit_new, map_location=torch.device('cpu'))['best_model_weights'], strict=True)        

In [277]:
vit_2_noises = Vit_2(dspl_size=104, 
              patch_size=8, 
              embed_dim=128,
              depth=12,
              n_heads=8,
              mlp_ratio=4.,
              p=0.,
              attn_p=0.,
              drop_path=0.).float()
path_to_vit_2_noises = '/home/alexrichard/PycharmProjects/UQ_DL-TFM/ViT-TFM/ViT architecture/Architecture tryouts/logs_and_weights/ViT_with_shifted_patch-2023-Jan-18 09:54:37/ViT_with_shifted_patch-2023-Jan-18 09:54:37.pth'
if torch.cuda.is_available():
    vit_2_noises.load_state_dict(torch.load(path_to_vit_2_noises)['best_model_weights'], strict=True)
else:
    vit_2_noises.load_state_dict(torch.load(path_to_vit_2_noises, map_location=torch.device('cpu'))['best_model_weights'], strict=True)        

In [278]:
cnn = keras.models.load_model('/home/alexrichard/PycharmProjects/UQ_DL-TFM/mltfm/models/model_noise_1e-4.h5')

Load test set with clean samples.

In [279]:
dspl_test = np.array(h5py.File('/home/alexrichard/PycharmProjects/UQ_DL-TFM/ViT-TFM/data/Test data/resolution_104/allDisplacements.h5', 'r')['dspl'])
trac_test = np.array(h5py.File('/home/alexrichard/PycharmProjects/UQ_DL-TFM/ViT-TFM/data/Test data/resolution_104/allTractions.h5', 'r')['trac'])
dspl_test = np.concatenate([dspl_test[i] for i in range(dspl_test.shape[0])], axis=0, dtype=np.float32)
trac_test = np.concatenate([trac_test[i] for i in range(trac_test.shape[0])], axis=0, dtype=np.float32)
#dspl_test = np.moveaxis(dspl_test, 3, 1)
dspl_test = dspl_test[:100]
Y_test = torch.from_numpy(np.moveaxis(trac_test[:100], 3, 1))
#X_test = torch.from_numpy(dspl_test).float()

Corrupt test set with ten different levels of Gaussian noise.

In [280]:
std_dspl = np.std(dspl_test, axis=(1,2,3))
test_sets = {}
noise_sets = {}
for i in range(1, 11):
    test_set = np.zeros((dspl_test.shape))
    noise_set = np.zeros((dspl_test.shape))
    for j in range(len(dspl_test)):
        sigma = (i/100) * std_dspl[j]
        cov = [[sigma**2,0],[0,sigma**2]]
        noise = np.random.multivariate_normal(np.array([0,0]), cov, (104, 104))
        test_set[j] = dspl_test[j] + noise
        noise_set[j] = noise
    test_sets[f'{i}'] = np.moveaxis(test_set, 3, 1)
    noise_sets[f'{i}'] = np.moveaxis(noise_set, 3, 1)

Visualize test sample with different noise floors.

In [281]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(2, 3, figsize=(10, 8))
fig.tight_layout(pad=2, w_pad=2, h_pad=2)

axs[0, 0].quiver(test_sets['1'][0, 0, :, :], test_sets['1'][0, 1, :, :], scale=1)
axs[0, 0].set_title('First test sample: 1% noise', {'fontsize': 11})

axs[0, 1].quiver(test_sets['3'][0, 0, :, :], test_sets['3'][0, 1, :, :], scale=1)
axs[0, 1].set_title('First test sample: 3% noise', {'fontsize': 11})

axs[0, 2].quiver(test_sets['5'][0, 0, :, :], test_sets['5'][0, 1, :, :], scale=1)
axs[0, 2].set_title(f'First test sample: 5% noise', {'fontsize': 11})

axs[1, 0].quiver(test_sets['7'][0, 0, :, :], test_sets['7'][0, 1, :, :], scale=1)
axs[1, 0].set_title(f'First test sample: 7% noise)', {'fontsize': 11})

axs[1, 1].quiver(test_sets['9'][0, 0, :, :], test_sets['9'][0, 1, :, :], scale=1)
axs[1, 1].set_title(f'First test sample: 9% noise)', {'fontsize': 11})

axs[1, 2].quiver(test_sets['10'][0, 0, :, :], test_sets['10'][0, 1, :, :], scale=1)
axs[1, 2].set_title(f'First test sample: 10% noise)', {'fontsize': 11})

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'First test sample: 10% noise)')

For comparison: the test corrupted with the same noise level as during training.

In [271]:
cov = [[1e-04 ** 2, 0], [0, 1e-04 ** 2]]
X_train_noise = np.transpose(np.random.multivariate_normal(np.array([0, 0]), cov, (dspl_test.shape[0], dspl_test.shape[2], dspl_test.shape[3])), (0, 3, 2, 1))
dspl_test_prime = dspl_test + dspl_test

In [17]:
for i in range(1,11):
    print(f'{(i/100) * std_dspl[3]}')

6.046496331691742e-05
0.00012092992663383484
0.00018139488995075226
0.00024185985326766969
0.00030232481658458714
0.0003627897799015045
0.000423254743218422
0.00048371970653533937
0.0005441846698522567
0.0006046496331691743


In [18]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(1, 1, figsize=(4, 4))
fig.tight_layout(pad=2, w_pad=2, h_pad=2)

axs.quiver(dspl_test_prime[0, :, :, 0], dspl_test_prime[0, :, :, 1], scale=1)
axs.set_title('First test sample: 1% noise', {'fontsize': 11})

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'First test sample: 1% noise')

Predictions by ViT and CNN for test sample.

In [282]:
vit_new.eval()
vit_2_noises.eval()
vit_predictions = {}
vit_2_noises_predictions = {}
cnn_predictions = {}
for num, test_set in test_sets.items():
    if int(num)%2 == 0:
        #vit_predictions[num] = vit_new(torch.tensor(test_set).float(), device=device)
        vit_2_noises_predictions[num] = vit_2_noises(torch.tensor(test_set).float(), device=device)
        cnn_predictions[num] = cnn.predict(np.moveaxis(np.array(test_set), 1, 3))



In [283]:
vit_mse = {}
cnn_mse = {}
noise_levels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
mse = torch.nn.MSELoss(reduction='none')
for noise_level in noise_levels:
    if int(noise_level) % 2 == 0:
        vit_mse[noise_level] = torch.mean(mse(vit_2_noises_predictions[noise_level][:, :, 3:101, 3:101], Y_test[:, 0:2, 3:101, 3:101]), (1,2,3))
        cnn_mse[noise_level] = torch.mean(mse(torch.tensor(np.moveaxis(cnn_predictions[noise_level], 3, 1)[:, :, 3:101, 3:101]), Y_test[:, 0:2, 3:101, 3:101]), (1,2,3))

In [284]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(2, 3, figsize=(10, 8))
fig.tight_layout(pad=2, w_pad=2, h_pad=2)

axs[0, 0].quiver(vit_2_noises_predictions['2'][1, 0, :, :].detach().numpy(), vit_2_noises_predictions['2'][1, 1, :, :].detach().numpy(), scale=1)
axs[0, 0].set_title(f'ViT Pred: 1% noise (loss: {vit_mse["2"][0]:9.6f})', {'fontsize': 11})

axs[0, 1].quiver(vit_2_noises_predictions['4'][1, 0, :, :].detach().numpy(), vit_2_noises_predictions['4'][1, 1, :, :].detach().numpy(), scale=1)
axs[0, 1].set_title(f'ViT Pred: 5% noise (loss: {vit_mse["4"][0]:9.6f})', {'fontsize': 11})

axs[0, 2].quiver(vit_2_noises_predictions['10'][1, 0, :, :].detach().numpy(), vit_2_noises_predictions['10'][1, 1, :, :].detach().numpy(), scale=1)
axs[0, 2].set_title(f'ViT Pred: 10% noise) (loss: {vit_mse["10"][0]:9.6f}', {'fontsize': 11})

axs[1, 0].quiver(cnn_predictions['2'][1, :, :, 0], cnn_predictions['2'][1, :, :, 1], scale=1)
axs[1, 0].set_title(f'CNN Pred: 1% noise (loss: {cnn_mse["2"][0]:9.6f})', {'fontsize': 11})

axs[1, 1].quiver(cnn_predictions['4'][1, :, :, 0], cnn_predictions['4'][1, :, :, 1], scale=1)
axs[1, 1].set_title(f'CNN Pred: 5% noise (loss: {cnn_mse["4"][0]:9.6f})', {'fontsize': 11})

axs[1, 2].quiver(cnn_predictions['10'][1, :, :, 0], cnn_predictions['10'][1, :, :, 1], scale=1)
axs[1, 2].set_title(f'CNN Pred: 10% noise (loss: {cnn_mse["10"][0]:9.6f})', {'fontsize': 11})

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'CNN Pred: 10% noise (loss:  0.007227)')

In [None]:
'''
cov = [[1e-04 ** 2, 0], [0, 1e-04 ** 2]]
X_test_noise = np.transpose(np.random.multivariate_normal(np.array([0, 0]), cov, (X_test.shape[0], X_test.shape[2], X_test.shape[3])), (0, 3, 2, 1))
X_test_noisy = X_test + X_test_noise
X_test_noisy = X_test_noisy.float()
Y_test = torch.from_numpy(trac_test).float()
'''

Export test sets to mat files.

In [None]:
save_files_here = '/home/alexrichard/LRZ Sync+Share/ML in Physics/Repos/Easy-to-use_TFM_package-master/test_data/Artificial patch data/Raw Samples'
xx, yy = np.meshgrid(np.arange(104), np.arange(104), indexing='ij')
pos = np.vstack([xx.reshape(-1), yy.reshape(-1)], dtype=np.double).T
for i in range(1, 11):
    path_to_dir = f'{save_files_here}/{i}'
    os.makedirs(path_to_dir)
    for j, sample in enumerate(test_sets[f'{i}']):
        file_name = f'{path_to_dir}/test_sample_{j + 1}.mat'
        vec_dspl = np.vstack([sample[0].reshape(-1), sample[1].reshape(-1)], dtype=np.double).T
        vec_noise = np.vstack([noise_sets[f'{i}'][j, 0].reshape(-1), noise_sets[f'{i}'][j, 1].reshape(-1)], dtype=np.double).T
        mdict = {'input_data': {'noise': [{'vec': vec_noise, 'pos': pos}, {'vec': vec_noise, 'pos': pos}], 'displacement': [{'vec': vec_dspl, 'pos': pos}, {'vec': vec_dspl, 'pos': pos}]}}
        savemat(file_name, mdict=mdict)

Load BFFTC predictions for each test set.

In [None]:
bfftc_prediction_sets = {}
bfftc_displacement_sets = {}

directory = "/home/alexrichard/LRZ Sync+Share/ML in Physics/Repos/Easy-to-use_TFM_package-master/test_data/Artificial patch data/Predictions"
    
for i, _ in enumerate(os.listdir(directory), 1):
    bfftc_prediction = np.zeros((Y_test.shape[0], 2, 102, 102))
    bfftc_displacement = np.zeros((Y_test.shape[0], 2, 102, 102))
    for j, file in enumerate(os.listdir(f'{directory}/{i}')):
        filename = os.fsdecode(file)
        if filename.endswith(".mat"):
            bfft_pred = loadmat(f'{directory}/{i}/{filename}')['TFM_results']['traction'][0][0].T.reshape((2, 102, 102), order='F')
            bfft_dspl = loadmat(f'{directory}/{i}/{filename}')['TFM_results']['displacement'][0][0].T.reshape((2, 102, 102), order='F')
            bfftc_prediction[j] = bfft_pred
            bfftc_displacement[j] = bfft_dspl
    bfftc_prediction_sets[f'{i}'] = bfftc_prediction
    bfftc_displacement_sets[f'{i}'] = bfftc_displacement

Find correct samples in test set and remove outer rims.

In [None]:
'''
def allcocate_and_trim(bfftc_prediction_sets, num_processes):
    bfftc_prediction_sets_trimmed = {}
    ground_truth_sets_trimmed = {}
    noisy_X_test_sets_trimmed = {}
    pool = mp.Pool(processes=num_processes)
    for num, pred in bfftc_prediction_sets.items():
        pool.apply_async(one_alloc, args=(num, pred))
    pool.close()
    pool.join()
    
def one_alloc(num, pred):
    bfftc_prediction_set_trimmed = torch.zeros((test_sets[num].shape[0], test_sets[num].shape[1], 98, 98))
    ground_truths_trimmed = torch.zeros((test_sets[num].shape[0], 3, 98, 98))
    X_test_noisy = torch.zeros(test_sets[num].shape)
    for i, sample in tqdm(enumerate(bfftc_displacement_sets[num])):
        for j, dspl in enumerate(torch.tensor(test_sets[num])):
            if torch.allclose(dspl[:, 1:103, 1:103].float(), torch.tensor(sample).float(), atol=1e-02, rtol=1):
                # print(f'Set {num}: sample {i} matches with dspl {j}')
                bfftc_prediction_set_trimmed[i] = torch.tensor(bfftc_prediction_sets[num][i, :, 3:101, 3:101]).float()
                ground_truths_trimmed[i] = Y_test[j, :, 3:101, 3:101].float()
                X_test_noisy[i] = dspl.float()
    bfftc_prediction_sets_trimmed[num] = bfftc_prediction_set_trimmed
    ground_truth_sets_trimmed[num] = ground_truths_trimmed
    noisy_X_test_sets_trimmed[num] = X_test_noisy 
    
'''

In [None]:
bfftc_prediction_sets_trimmed = {}
ground_truth_sets_trimmed = {}
noisy_X_test_sets_trimmed = {}

for num, pred in tqdm(bfftc_prediction_sets.items()):
    bfftc_prediction_set_trimmed = torch.zeros((test_sets[num].shape[0], test_sets[num].shape[1], 98, 98))
    ground_truths_trimmed = torch.zeros((test_sets[num].shape[0], 3, 98, 98))
    X_test_noisy = torch.zeros(test_sets[num].shape)
    for i, sample in enumerate(bfftc_displacement_sets[num]):
        for j, dspl in enumerate(torch.tensor(test_sets[num])):
            if torch.allclose(dspl[:, 1:103, 1:103].float(), torch.tensor(sample).float(), atol=1e-02, rtol=1):
                # print(f'Set {num}: sample {i} matches with dspl {j}')
                bfftc_prediction_set_trimmed[i] = torch.tensor(bfftc_prediction_sets[num][i, :, 3:101, 3:101]).float()
                ground_truths_trimmed[i] = Y_test[j, :, 3:101, 3:101].float()
                X_test_noisy[i] = dspl.float()
    bfftc_prediction_sets_trimmed[num] = bfftc_prediction_set_trimmed
    ground_truth_sets_trimmed[num] = ground_truths_trimmed
    noisy_X_test_sets_trimmed[num] = X_test_noisy

Sanity check.

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(2, 3, figsize=(10, 8))
fig.tight_layout(pad=2, w_pad=2, h_pad=2)

axs[0, 0].quiver(noisy_X_test_sets_trimmed['1'][0, 0, :, :].detach().numpy(), noisy_X_test_sets_trimmed['1'][0, 1, :, :].detach().numpy(), scale=1)
axs[0, 0].set_title(f'First test sample: 1% noise', {'fontsize': 11})

axs[0, 1].quiver(noisy_X_test_sets_trimmed['5'][0, 0, :, :].detach().numpy(), noisy_X_test_sets_trimmed['5'][0, 1, :, :].detach().numpy(), scale=1)
axs[0, 1].set_title(f'First test sample: 5% noise', {'fontsize': 11})

axs[0, 2].quiver(noisy_X_test_sets_trimmed['10'][0, 0, :, :].detach().numpy(), noisy_X_test_sets_trimmed['10'][0, 1, :, :].detach().numpy(), scale=1)
axs[0, 2].set_title(f'First test sample: 10% noise', {'fontsize': 11})

axs[1, 0].quiver(ground_truth_sets_trimmed['1'][0, 0, :, :].detach().numpy(), ground_truth_sets_trimmed['1'][0, 1, :, :].detach().numpy(), scale=1)
axs[1, 0].set_title(f'CNN Pred: 1% noise (loss: {cnn_mse["1"][0]:9.6f})', {'fontsize': 11})

axs[1, 1].quiver(ground_truth_sets_trimmed['5'][0, 0, :, :].detach().numpy(), ground_truth_sets_trimmed['5'][0, 1, :, :], scale=1)
axs[1, 1].set_title(f'CNN Pred: 5% noise (loss: {cnn_mse["5"][0]:9.6f})', {'fontsize': 11})

axs[1, 2].quiver(ground_truth_sets_trimmed['10'][0, 0, :, :].detach().numpy(), ground_truth_sets_trimmed['10'][0, 1, :, :], scale=1)
axs[1, 2].set_title(f'CNN Pred: 10% noise (loss: {cnn_mse["10"][0]:9.6f})', {'fontsize': 11})

In [None]:
'''
bfftc_predictions_trimmed = torch.zeros((X_test_noisy.shape[0], X_test_noisy.shape[1], 98, 98))
ground_truths_trimmed = torch.zeros((Y_test.shape[0], Y_test.shape[1], 98, 98))
X_test_noisy_ = torch.zeros(X_test_noisy.shape)

for i, sample in enumerate(bfftc_displacements):
    for j, dspl in enumerate(X_test_noisy):
        if torch.allclose(dspl[:, 1:103, 1:103].float(), torch.tensor(sample).float(), atol=1e-02, rtol=1):
            print(f'sample {i} matches with dspl {j}')
            bfftc_predictions_trimmed[i] = torch.tensor(bfftc_predictions[i, :, 3:101, 3:101]).float()
            ground_truths_trimmed[i] = Y_test[j, :, 3:101, 3:101].float()
            X_test_noisy_[i] = dspl.float()
'''

In [None]:
vit_new.eval()
vit_predictions = {}
cnn_predictions = {}
for num, test_set in test_sets.items():
    vit_predictions[num] = vit_new(torch.tensor(test_set).float(), device=device)
    cnn_predictions[num] = cnn.predict(np.moveaxis(np.array(test_set), 1, 3))

In [None]:
vit_metrics = {}
cnn_metrics = {}
bfftc_metrics = {}

In [None]:
def compute_mse_for_noise_levels(vit_predictions, cnn_predictions, bfftc_prediction_sets_trimmed, ground_truth_sets_trimmed, noisy_X_test_sets_trimmed):
    vit_mse = {}
    cnn_mse = {}
    #bfftc_mse = {}
    noise_levels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    mse = torch.nn.MSELoss(reduction='none')
    for noise_level in noise_levels:
        vit_mse[noise_level] = torch.mean(torch.mean(mse(vit_predictions[noise_level][:, :, 3:101, 3:101], Y_test[:, 0:2, 3:101, 3:101]), (1,2,3)))
        cnn_mse[noise_level] = torch.mean(torch.mean(mse(torch.tensor(np.moveaxis(cnn_predictions[noise_level], 3, 1)[:, :, 3:101, 3:101]), Y_test[:, 0:2, 3:101, 3:101]), (1,2,3)))
        #bfftc_mse[noise_level] = torch.mean(torch.mean(mse(bfftc_prediction_sets_trimmed[noise_level], ground_truth_sets_trimmed[noise_level][:, 0:2]), (1,2,3)))
    
    return vit_mse, cnn_mse

In [None]:
vit_metrics['mse'], cnn_metrics['mse'] = compute_mse_for_noise_levels(vit_predictions, cnn_predictions, bfftc_prediction_sets_trimmed, ground_truth_sets_trimmed, noisy_X_test_sets_trimmed)

In [None]:
lists = vit_metrics['mse'].items() # sorted by key, return a list of tuples
x, y_vit = zip(*lists) # unpack a list of pairs into two tuples
x = sorted([int(i) for i in x])
y_vit = [i.item() for i in y_vit]
lists = cnn_metrics['mse'].items() # sorted by key, return a list of tuples
_, y_cnn = zip(*lists) # unpack a list of pairs into two tuples
y_cnn = [i.item() for i in y_cnn]

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')
fig, ax = plt.subplots(nrows=1, ncols=1)
plt.plot(x, y_vit, 'o-', color='green', label='Vit')
plt.plot(x, y_cnn, 'o-', color='blue', label='CNN')
plt.legend(loc="upper right")
ax.set_facecolor('gainsboro')
ax.grid(color='white', linestyle='-', linewidth=1)
ax.set_xlabel(r'$\mathit{\tilde\sigma}$ / %')
plt.title('MSE', )

In [None]:
from MultiTask import append_predictions_and_targets, cosine_sim, adtma, dtma, dma, snr, dtmb

appended_vit_predictions = {}
appended_vit_targets = {}
appended_cnn_predictions = {}
appended_cnn_targets = {}
appended_bfftc_predictions = {}
appended_bfftc_targets = {}

noise_levels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
for noise_level in noise_levels:
    appended_vit_predictions[noise_level], appended_vit_targets[noise_level] = append_predictions_and_targets(vit_predictions[noise_level][:, :, 3:101, 3:101], ground_truth_sets_trimmed[noise_level], device)

In [None]:
for noise_level in noise_levels:
    appended_cnn_predictions[noise_level], appended_cnn_targets[noise_level] = append_predictions_and_targets(torch.tensor(np.moveaxis(cnn_predictions[noise_level], 3, 1))[:, :, 3:101, 3:101], ground_truth_sets_trimmed[noise_level], device)

In [None]:
'''
for noise_level in noise_levels:
    appended_bfftc_predictions[noise_level], appended_bfftc_targets[noise_level] = append_predictions_and_targets(bfftc_prediction_sets_trimmed[noise_level], ground_truth_sets_trimmed[noise_level], device)
'''

In [None]:
def compute_dtma_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets):
    vit_dtma = {}
    cnn_dtma = {}
    #bfftc_dtma = {}
    noise_levels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    for noise_level in tqdm(noise_levels):
        vit_dtma[noise_level] = dtma(appended_vit_predictions[noise_level], appended_vit_targets[noise_level], device, False)
        cnn_dtma[noise_level] =  dtma(appended_cnn_predictions[noise_level], appended_cnn_targets[noise_level], device, False)
        #bfftc_dtma[noise_level] =  dtma(appended_bfftc_predictions[noise_level], appended_bfftc_targets[noise_level], device, False)
    
    return vit_dtma, cnn_dtma

In [None]:
vit_metrics['dtma'], cnn_metrics['dtma'] = compute_dtma_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets)

In [None]:
lists = vit_metrics['dtma'].items() # sorted by key, return a list of tuples
x1, y_vit = zip(*lists) # unpack a list of pairs into two tuples

In [None]:
x1 = sorted([int(i) for i in x1])
y_vit = [i.item() for i in y_vit]
lists = cnn_metrics['dtma'].items() # sorted by key, return a list of tuples
x2, y_cnn = zip(*lists) # unpack a list of pairs into two tuples
y_cnn = [i.item() for i in y_cnn]
x2 = sorted([int(i) for i in x2])

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')
fig, ax = plt.subplots(nrows=1, ncols=1)
plt.plot(x1, y_vit, 'o-', color='green', label='Vit')
plt.plot(x2, y_cnn, 'o-', color='blue', label='CNN')
plt.legend(loc="upper right")
ax.set_facecolor('gainsboro')
ax.grid(color='white', linestyle='-', linewidth=1)
ax.set_xlabel(r'$\mathit{\tilde\sigma}$ / %')
plt.title('DTMA')

In [264]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(2, 3, figsize=(10, 8))
fig.tight_layout(pad=2, w_pad=2, h_pad=2)

axs[0, 0].quiver(vit_predictions['1'][0, 0, :, :].detach().numpy(), vit_predictions['1'][0, 1, :, :].detach().numpy(), scale=1)
axs[0, 0].set_title('Vit prediction: low noise', {'fontsize': 11})

axs[0, 1].quiver(vit_predictions['5'][0, 0, :, :].detach().numpy(), vit_predictions['5'][0, 1, :, :].detach().numpy(), scale=1)
axs[0, 1].set_title('Vit prediction: medium noise', {'fontsize': 11})

axs[0, 2].quiver(vit_predictions['10'][0, 0, :, :].detach().numpy(), vit_predictions['10'][0, 1, :, :].detach().numpy(), scale=1)
axs[0, 2].set_title(f'Vit prediction: high noise', {'fontsize': 11})

axs[1, 0].quiver(cnn_predictions['1'][0, :, :, 0], cnn_predictions['1'][0, :, :, 1], scale=1)
axs[1, 0].set_title(f'CNN prediction: low noise)', {'fontsize': 11})

axs[1, 1].quiver(cnn_predictions['5'][0, :, :, 0], cnn_predictions['5'][0, :, :, 1], scale=1)
axs[1, 1].set_title(f'CNN prediction: medium noise)', {'fontsize': 11})

axs[1, 2].quiver(cnn_predictions['10'][0, :, :, 0], cnn_predictions['10'][0, :, :, 1], scale=1)
axs[1, 2].set_title(f'CNN prediction: high noise)', {'fontsize': 11})

<IPython.core.display.Javascript object>

TypeError: new(): invalid data type 'str'

In [None]:
def compute_dtmb_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets):
    vit_dtmb = {}
    cnn_dtmb = {}
    # bfftc_dtmb = {}
    noise_levels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    for noise_level in tqdm(noise_levels):
        vit_dtmb[noise_level] = dtmb(vit_predictions[noise_level][:, :, 3:101, 3:101], ground_truth_sets_trimmed[noise_level], appended_vit_predictions[noise_level], appended_vit_targets[noise_level], device, False)
        cnn_dtmb[noise_level] =  dtmb(torch.tensor(np.moveaxis(cnn_predictions[noise_level], 3, 1))[:, :, 3:101, 3:101], ground_truth_sets_trimmed[noise_level], appended_cnn_predictions[noise_level], appended_cnn_targets[noise_level], device, False)
        # bfftc_dtmb[noise_level] =  dtmb(appended_bfftc_predictions[noise_level], appended_bfftc_targets[noise_level], device, False)
    
    return vit_dtmb, cnn_dtmb

In [None]:
vit_metrics['dtmb'], cnn_metrics['dtmb'] = compute_dtmb_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets)

In [None]:
vit_metrics

In [None]:
lists = vit_metrics['dtmb'].items() # sorted by key, return a list of tuples
x, y_vit = zip(*lists) # unpack a list of pairs into two tuples
x = sorted([int(i) for i in x])
y_vit = [i.item() for i in y_vit]
lists = cnn_metrics['dtmb'].items() # sorted by key, return a list of tuples
_, y_cnn = zip(*lists) # unpack a list of pairs into two tuples
y_cnn = [i.item() for i in y_cnn]

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')
fig, ax = plt.subplots(nrows=1, ncols=1)
plt.plot(x, y_vit, 'o-', color='green', label='Vit')
plt.plot(x, y_cnn, 'o-', color='blue', label='CNN')
plt.legend(loc="upper right")
ax.set_facecolor('gainsboro')
ax.grid(color='white', linestyle='-', linewidth=1)
ax.set_xlabel(r'$\mathit{\tilde\sigma}$ / %')
plt.title('DTMB')

In [None]:
def compute_snr_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets):
    vit_snr = {}
    cnn_snr = {}
    # bfftc_snr = {}
    noise_levels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    for noise_level in tqdm(noise_levels):
        vit_snr[noise_level] = snr(vit_predictions[noise_level][:, :, 3:101, 3:101], ground_truth_sets_trimmed[noise_level], appended_vit_predictions[noise_level], appended_vit_targets[noise_level], device, False)
        cnn_snr[noise_level] =  snr(torch.tensor(np.moveaxis(cnn_predictions[noise_level], 3, 1))[:, :, 3:101, 3:101], ground_truth_sets_trimmed[noise_level], appended_cnn_predictions[noise_level], appended_cnn_targets[noise_level], device, False)
        # bfftc_snr[noise_level] =  snr(appended_bfftc_predictions[noise_level], appended_bfftc_targets[noise_level], device, False)
    
    return vit_snr, cnn_snr

In [None]:
vit_metrics['snr'], cnn_metrics['snr'] = compute_snr_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets)

In [None]:
lists = vit_metrics['snr'].items() # sorted by key, return a list of tuples
x, y_vit = zip(*lists) # unpack a list of pairs into two tuples
x = sorted([int(i) for i in x])
y_vit = [i.item() for i in y_vit]
lists = cnn_metrics['snr'].items() # sorted by key, return a list of tuples
_, y_cnn = zip(*lists) # unpack a list of pairs into two tuples
y_cnn = [i.item() for i in y_cnn]

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')
fig, ax = plt.subplots(nrows=1, ncols=1)
plt.plot(x, y_vit, 'o-', color='green', label='Vit')
plt.plot(x, y_cnn, 'o-', color='blue', label='CNN')
plt.legend(loc="upper right")
ax.set_facecolor('gainsboro')
ax.grid(color='white', linestyle='-', linewidth=1)
ax.set_xlabel(r'$\mathit{\tilde\sigma}$ / %')
plt.title('SNR')

In [None]:
lists = vit_metrics['dma'].items() # sorted by key, return a list of tuples
x, y_vit = zip(*lists) # unpack a list of pairs into two tuples
x = sorted([int(i) for i in x])
y_vit = [i.item() for i in y_vit]
lists = cnn_metrics['dma'].items() # sorted by key, return a list of tuples
_, y_cnn = zip(*lists) # unpack a list of pairs into two tuples
y_cnn = [i.item() for i in y_cnn]

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')
fig, ax = plt.subplots(nrows=1, ncols=1)
plt.plot(x, y_vit, 'o-', color='green', label='Vit')
plt.plot(x, y_cnn, 'o-', color='blue', label='CNN')
plt.legend(loc="upper right")
ax.set_facecolor('gainsboro')
ax.grid(color='white', linestyle='-', linewidth=1)
ax.set_xlabel(r'$\mathit{\tilde\sigma}$ / %')
plt.title('DMA')

In [None]:
def compute_dma_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets):
    vit_dma = {}
    cnn_dma = {}
    # bfftc_dma = {}
    noise_levels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    for noise_level in tqdm(noise_levels):
        vit_dma[noise_level] = dma(vit_predictions[noise_level][:, :, 3:101, 3:101], ground_truth_sets_trimmed[noise_level], appended_vit_predictions[noise_level], appended_vit_targets[noise_level], device, False)
        cnn_dma[noise_level] =  dma(torch.tensor(np.moveaxis(cnn_predictions[noise_level], 3, 1))[:, :, 3:101, 3:101], ground_truth_sets_trimmed[noise_level], appended_cnn_predictions[noise_level], appended_cnn_targets[noise_level], device, False)
        # bfftc_dma[noise_level] =  dma(appended_bfftc_predictions[noise_level], appended_bfftc_targets[noise_level], device, False)
    
    return vit_dma, cnn_dma

In [None]:
def compute_adtma_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets):
    vit_adtma = {}
    cnn_adtma = {}
    bfftc_adtma = {}
    noise_levels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    for noise_level in tqdm(noise_levels):
        vit_adtma[noise_level] = adtma(appended_vit_predictions[noise_level], appended_vit_targets[noise_level], device, False)
        cnn_adtma[noise_level] = adtma(appended_cnn_predictions[noise_level], appended_cnn_targets[noise_level], device, False)
        bfftc_adtma[noise_level] = adtma(appended_bfftc_predictions[noise_level], appended_bfftc_targets[noise_level], device, False)
    
    return vit_adtma, cnn_adtma, bfftc_adtma

In [None]:
vit_metrics['adtma'], cnn_metrics['adtma'], bfftc_metrics['adtma'] = compute_adtma_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets)

In [None]:
def compute_cosine_sim_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets):
    vit_cos_sim = {}
    cnn_cos_sim = {}
    # bfftc_cos_sim = {}
    noise_levels = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
    for noise_level in tqdm(noise_levels):
        vit_cos_sim[noise_level] = cosine_sim(appended_vit_predictions[noise_level], appended_vit_targets[noise_level], device, False)
        cnn_cos_sim[noise_level] =  cosine_sim(appended_cnn_predictions[noise_level], appended_cnn_targets[noise_level], device, False)
        #bfftc_cos_sim[noise_level] =  cosine_sim(appended_bfftc_predictions[noise_level], appended_bfftc_targets[noise_level], device, False)
    
    return vit_cos_sim, cnn_cos_sim

In [None]:
vit_metrics['cos_sim'], cnn_metrics['cos_sim']= compute_cosine_sim_for_noise_levels(appended_vit_predictions, appended_vit_targets, appended_cnn_predictions, appended_cnn_targets, appended_bfftc_predictions, appended_bfftc_targets)

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(2, 8, figsize=(10, 8))
fig.tight_layout(pad=2, w_pad=2, h_pad=2)

axs[0, 0].quiver(X_test_noisy_[0, 0, :, :].detach().numpy(), X_test_noisy_[0, 1, :, :].detach().numpy(), scale=1)
axs[0, 0].set_title('Test sample: displacement', {'fontsize': 11})

axs[0, 1].quiver(ground_truths_trimmed[0, 0, :, :].detach().numpy(), ground_truths_trimmed[0, 1, :, :].detach().numpy(), scale=2)
axs[0, 1].set_title('Test sample: traction', {'fontsize': 11})

axs[0, 2].quiver(X_test_noise[0, 0, :, :], X_test_noise[0, 1, :, :], scale=0.05)
axs[0, 2].set_title(f'Test sample: Gaussian noise floor', {'fontsize': 11})

axs[1, 0].quiver(pred_vit_new[0, 0, :, :].detach().numpy(), pred_vit_new[0, 1, :, :].detach().numpy(), scale=1)
axs[1, 0].set_title(f'ViT pred (loss:{mse_vit_new[0, 0]:9.6f})', {'fontsize': 11})

axs[1, 1].quiver(pred_cnn[0, :, :, 0], pred_cnn[0, :, :, 1], scale=1)
axs[1, 1].set_title(f'CNN pred (loss:{mse_cnn[0, 0]:9.6f})', {'fontsize': 11})

axs[1, 2].quiver(bfftc_predictions_trimmed[0, 0, :, :], bfftc_predictions_trimmed[0, 1, :, :], scale=1)
axs[1, 2].set_title(f'BFFTC pred (loss:{mse_bfftc[0, 0]:9.6f})', {'fontsize': 11})

Compute MSE, DTMA and DDA for the predictions.

In [None]:
vit_new.eval()
pred_vit_new = vit_new(X_test_noisy, device=device)
pred_cnn = cnn.predict(np.moveaxis(np.array(X_test_noisy), 1, 3))

In [None]:
from MultiTask import append_predictions_and_targets, cosine_sim, adtma, dtma, dma, snr, dtmb

mse = torch.nn.MSELoss(reduction='none')
mse_vit_new = torch.mean(mse(pred_vit_new[:, :, 3:101, 3:101], ground_truths_trimmed[:, 0:2]), (1, 2, 3))
mse_cnn = torch.mean(mse(torch.tensor(np.moveaxis(pred_cnn, 3, 1)[:, :, 3:101, 3:101]), ground_truths_trimmed[:, 0:2]), (1, 2, 3))
mse_bfftc = torch.mean(mse(bfftc_predictions_trimmed, ground_truths_trimmed[:, 0:2]), (1, 2, 3))

mse_vit_new = torch.reshape(mse_vit_new, (ground_truths_trimmed.shape[0], 1))
mse_cnn = torch.reshape(mse_cnn, (ground_truths_trimmed.shape[0], 1))
mse_bfftc = torch.reshape(mse_bfftc, (ground_truths_trimmed.shape[0], 1))

In [None]:
appended_vit_predictions, appended_vit_targets = append_predictions_and_targets(pred_vit_new[:, :, 3:101, 3:101], ground_truths_trimmed, device)
dtma_vit_new = dtma(appended_vit_predictions, appended_vit_targets, device, True)
dtmb_vit_new = dtmb(pred_vit_new[:, :, 3:101, 3:101], ground_truths_trimmed, appended_vit_predictions, appended_vit_targets, device=device, per_sample=True)
snr_vit_new = snr(pred_vit_new[:, :, 3:101, 3:101], ground_truths_trimmed, appended_vit_predictions, appended_vit_targets, device=device, per_sample=True)
dma_vit_new = dma(pred_vit_new[:, :, 3:101, 3:101], ground_truths_trimmed, appended_vit_predictions, appended_vit_targets, device=device, per_sample=True)
adtma_vit_new = adtma(appended_vit_predictions, appended_vit_targets, device, True)

In [None]:
cosine_sim_vit_new = cosine_sim(appended_vit_predictions, appended_vit_targets, device, True)

In [None]:
appended_cnn_predictions, appended_cnn_targets = append_predictions_and_targets(torch.tensor(np.moveaxis(pred_cnn, 3, 1))[:, :, 3:101, 3:101], ground_truths_trimmed, device)
dtma_cnn = dtma(appended_cnn_predictions, appended_cnn_targets, device, True)
dtmb_cnn = dtmb(torch.tensor(np.moveaxis(pred_cnn, 3, 1))[:, :, 3:101, 3:101], ground_truths_trimmed, appended_cnn_predictions, appended_cnn_targets, device=device, per_sample=True)
snr_cnn = snr(torch.tensor(np.moveaxis(pred_cnn, 3, 1))[:, :, 3:101, 3:101], ground_truths_trimmed, appended_cnn_predictions, appended_cnn_targets, device=device, per_sample=True)
dma_cnn = dma(torch.tensor(np.moveaxis(pred_cnn, 3, 1))[:, :, 3:101, 3:101], ground_truths_trimmed, appended_cnn_predictions, appended_cnn_targets, device=device, per_sample=True)
adtma_cnn = adtma(appended_cnn_predictions, appended_cnn_targets, device, True)

In [None]:
cosine_cnn = cosine_sim(appended_cnn_predictions, appended_cnn_targets, device, True)

In [None]:
appended_bfftc_predictions, appended_bfftc_targets = append_predictions_and_targets(bfftc_predictions_trimmed, ground_truths_trimmed, device)
dtma_bfftc = dtma(appended_bfftc_predictions, appended_bfftc_targets, device, True)
dda_bfftc = dda(appended_bfftc_predictions, appended_bfftc_targets, device, True)
dtmb_bfftc = dtmb(bfftc_predictions_trimmed, ground_truths_trimmed, appended_bfftc_predictions, appended_bfftc_targets, device=device, per_sample=True)
snr_bfftc = snr(bfftc_predictions_trimmed, ground_truths_trimmed, appended_bfftc_predictions, appended_bfftc_targets, device=device, per_sample=True)
dma_bfftc = dma(bfftc_predictions_trimmed, ground_truths_trimmed, appended_bfftc_predictions, appended_bfftc_targets, device=device, per_sample=True)
adtma_bfftc = adtma(appended_bfftc_predictions, appended_bfftc_targets, device, True)

In [None]:
cosine_bfftc = cosine_sim(appended_bfftc_predictions, appended_bfftc_targets, device, True)

In [None]:
dtma_vit_new = torch.reshape(dtma_vit_new, (ground_truths_trimmed.shape[0], 1))
dtma_cnn = torch.reshape(dtma_cnn, (ground_truths_trimmed.shape[0], 1))
dtma_bfftc = torch.reshape(dtma_bfftc, (ground_truths_trimmed.shape[0], 1))

cosine_sim_vit_new = torch.reshape(cosine_sim_vit_new, (ground_truths_trimmed.shape[0], 1))
cosine_cnn = torch.reshape(cosine_cnn, (ground_truths_trimmed.shape[0], 1))
cosine_bfftc = torch.reshape(cosine_bfftc, (ground_truths_trimmed.shape[0], 1))

dtmb_vit_new = torch.reshape(dtmb_vit_new, (ground_truths_trimmed.shape[0], 1))
dtmb_cnn = torch.reshape(dtmb_cnn, (ground_truths_trimmed.shape[0], 1))
dtmb_bfftc = torch.reshape(dtmb_bfftc, (ground_truths_trimmed.shape[0], 1))

snr_vit_new = torch.reshape(snr_vit_new, (ground_truths_trimmed.shape[0], 1))
snr_cnn = torch.reshape(snr_cnn, (ground_truths_trimmed.shape[0], 1))
snr_bfftc = torch.reshape(snr_bfftc, (ground_truths_trimmed.shape[0], 1))

dma_vit_new = torch.reshape(dma_vit_new, (ground_truths_trimmed.shape[0], 1))
dma_cnn = torch.reshape(dma_cnn, (ground_truths_trimmed.shape[0], 1))
dma_bfftc = torch.reshape(dma_bfftc, (ground_truths_trimmed.shape[0], 1))

adtma_vit_new = torch.reshape(adtma_vit_new, (ground_truths_trimmed.shape[0], 1))
adtma_cnn = torch.reshape(adtma_cnn, (ground_truths_trimmed.shape[0], 1))
adtma_bfftc = torch.reshape(adtma_bfftc, (ground_truths_trimmed.shape[0], 1))

In [None]:
def highlight_closes_to_zero(s, props=''):
    return np.where(np.absolute(s) == np.nanmin(np.absolute(s.values)), props, '')

In [None]:
def highlight_max(x, props=''):
    return np.where(x == np.nanmax(x.to_numpy()), props, '')

In [None]:
columns = ['MSE ViT', 
           'MSE CNN',
           'MSE BFFTC',
           'DTMA ViT', 
           'DTMA CNN',
           'DTMA BFFTC',
           'Cosine sim ViT', 
           'Cosine sim CNN',
           'Cosine sim BFFTC',
           'DTMB ViT',
           'DTMB CNN',
           'DTMB BFFTC',
           'SNR ViT',
           'SNR CNN',
           'SNR BFFTC',
           'DMA ViT',
           'DMA CNN',
           'DMA BFFTC',
           'ADTMA ViT',
           'ADTMA CNN',
           'ADTMA BFFTC']

metrics = pd.DataFrame(
    np.array(torch.cat((mse_vit_new, 
                        mse_cnn,
                        mse_bfftc,
                        dtma_vit_new, 
                        dtma_cnn,
                        dtma_bfftc,
                        cosine_sim_vit_new, 
                        cosine_cnn,
                        cosine_bfftc,
                        dtmb_vit_new,
                        dtmb_cnn,
                        dtmb_bfftc,
                        snr_vit_new,
                        snr_cnn,
                        snr_bfftc,
                        dma_vit_new,
                        dma_cnn,
                        dma_bfftc,
                        adtma_vit_new,
                        adtma_cnn,
                        adtma_bfftc
                       ), 1).detach().numpy()), columns=columns)

float_formatter = "{:.6f}".format

averages = {}
for column in columns:
    averages[column] = f"avg: {float_formatter(round(metrics[column].mean(), ndigits=7))}"
    
averages = pd.DataFrame(averages, index=[0])

In [None]:
metrics
metrics.index += 1 
metrics.loc['mean'] = metrics.mean()
metrics = metrics.tail(32)
slice_ = ['MSE ViT', 'MSE CNN', 'MSE BFFTC']
metrics = metrics.style.apply(highlight_closes_to_zero, props='color:white; background-color:purple', subset=slice_, axis=1)
slice_ = ['DTMA ViT', 'DTMA CNN', 'DTMA BFFTC']
metrics = metrics.apply(highlight_closes_to_zero, props='color:white; background-color:purple', subset=slice_, axis=1)
slice_ = ['Cosine sim ViT', 'Cosine sim CNN','Cosine sim BFFTC']
metrics = metrics.apply(highlight_max, props='color:white; background-color:purple', subset=slice_, axis=1)
slice_ = ['DTMB ViT', 'DTMB CNN','DTMB BFFTC']
metrics = metrics.apply(highlight_closes_to_zero, props='color:white; background-color:purple', subset=slice_, axis=1)
slice_ = ['SNR ViT', 'SNR CNN', 'SNR BFFTC']
metrics = metrics.apply(highlight_max, props='color:white; background-color:purple', subset=slice_, axis=1)
slice_ = ['DMA ViT', 'DMA CNN', 'DMA BFFTC']
metrics = metrics.apply(highlight_closes_to_zero, props='color:white; background-color:purple', subset=slice_, axis=1)
slice_ = ['ADTMA ViT', 'ADTMA CNN', 'ADTMA BFFTC']
metrics = metrics.apply(highlight_closes_to_zero, props='color:white; background-color:purple', subset=slice_, axis=1)
metrics = metrics.set_table_styles([dict(selector='th', props=[('text-align', 'center')])])
metrics.set_properties(**{'text-align': 'center'})

metrics.set_properties(**{"border": "0.5px solid black"})

Plot example cell.

In [None]:
torch.set_printoptions(sci_mode=False)
torch.set_printoptions(precision=5)

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(2, 3, figsize=(10, 8))
fig.tight_layout(pad=2, w_pad=2, h_pad=2)

axs[0, 0].quiver(X_test_noisy_[0, 0, :, :].detach().numpy(), X_test_noisy_[0, 1, :, :].detach().numpy(), scale=1)
axs[0, 0].set_title('Test sample: displacement', {'fontsize': 11})

axs[0, 1].quiver(ground_truths_trimmed[0, 0, :, :].detach().numpy(), ground_truths_trimmed[0, 1, :, :].detach().numpy(), scale=2)
axs[0, 1].set_title('Test sample: traction', {'fontsize': 11})

axs[0, 2].quiver(X_test_noise[0, 0, :, :], X_test_noise[0, 1, :, :], scale=0.05)
axs[0, 2].set_title(f'Test sample: Gaussian noise floor', {'fontsize': 11})

axs[1, 0].quiver(pred_vit_new[0, 0, :, :].detach().numpy(), pred_vit_new[0, 1, :, :].detach().numpy(), scale=1)
axs[1, 0].set_title(f'ViT pred (loss:{mse_vit_new[0, 0]:9.6f})', {'fontsize': 11})

axs[1, 1].quiver(pred_cnn[0, :, :, 0], pred_cnn[0, :, :, 1], scale=1)
axs[1, 1].set_title(f'CNN pred (loss:{mse_cnn[0, 0]:9.6f})', {'fontsize': 11})

axs[1, 2].quiver(bfftc_predictions_trimmed[0, 0, :, :], bfftc_predictions_trimmed[0, 1, :, :], scale=1)
axs[1, 2].set_title(f'BFFTC pred (loss:{mse_bfftc[0, 0]:9.6f})', {'fontsize': 11})

Implementation of DTMB

In [None]:
Y_test.shape

In [None]:
appended_multi_targets[0, 2, :, :, :]

Visualize the first test sample, its ground truth and the models' predictions.

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(2, 2, figsize=(9, 9))
fig.tight_layout(pad=3, w_pad=3, h_pad=3)

axs[0, 0].quiver(bfftc_displacements[0, 0, :, :], bfftc_displacements[0, 1, :, :], scale=1)
axs[0, 0].set_title('Displacement field returned by etutfm', {'fontsize': 11})

axs[0, 1].quiver(X_test_noisy[0, 0, :, :].detach().numpy(), X_test_noisy[0, 1, :, :].detach().numpy(), scale=1)
axs[0, 1].set_title(f'Displacement field of X_test_noisy', {'fontsize': 11})

axs[1, 0].quiver(bfftc_predictions_trimmed[0, 0, :, :], bfftc_predictions_trimmed[0, 1, :, :], scale=50)
axs[1, 0].set_title(f'TFM prediction (loss: None)', {'fontsize': 11})

axs[1, 1].quiver(pred_vit_new[0, 0, :, :].detach().numpy(), pred_vit_new[0, 1, :, :].detach().numpy(), scale=1)
axs[1, 1].set_title(f'ViT prediction (loss: {torch.round(mse_vit_new[0, 0], decimals=5)})', {'fontsize': 11})

ym = 0,1 kPa

In [None]:
path_to_sample_1 = '/home/alexrichard/LRZ Sync+Share/ML in Physics/Repos/Easy-to-use_TFM_package-master/test_data/Bay-FTTC_results_01-02-23.mat'

In [None]:
bfft_pred = loadmat(path_to_sample_1)['TFM_results']['traction'][0][0].T.reshape((2, 102, 102), order='F')
bfft_dspl = loadmat(path_to_sample_1)['TFM_results']['displacement'][0][0].T.reshape((2, 102, 102), order='F')

In [None]:
bfft_pred.shape

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(1, 2, figsize=(9, 5))
fig.tight_layout(pad=3, w_pad=3, h_pad=3)

axs[0].quiver(bfft_pred[0, :, :], bfft_pred[1, :, :], scale=0.01)
axs[0].set_title('bfft_pred', {'fontsize': 11})

axs[1].quiver(bfft_dspl[0, :, :], bfft_dspl[1, :, :], scale=1)
axs[1].set_title(f'bfft_dspl', {'fontsize': 11})

ym = 0,01 kPa

In [None]:
path_to_sample_1 = '/home/alexrichard/LRZ Sync+Share/ML in Physics/Repos/Easy-to-use_TFM_package-master/test_data/Bay-FTTC_results_01-02-23.mat'
bfft_pred = loadmat(path_to_sample_1)['TFM_results']['traction'][0][0].T.reshape((2, 102, 102), order='F')
bfft_dspl = loadmat(path_to_sample_1)['TFM_results']['displacement'][0][0].T.reshape((2, 102, 102), order='F')

In [None]:
get_ipython().run_line_magic('matplotlib', 'notebook')

fig, axs = plt.subplots(1, 2, figsize=(9, 5))
fig.tight_layout(pad=3, w_pad=3, h_pad=3)

axs[0].quiver(bfft_pred[0, :, :], bfft_pred[1, :, :], scale=1)
axs[0].set_title('bfft_pred', {'fontsize': 11})

axs[1].quiver(bfft_dspl[0, :, :], bfft_dspl[1, :, :], scale=1)
axs[1].set_title(f'bfft_dspl', {'fontsize': 11})

### Visualization of attention maps

In [None]:
def compute_attention_map(attn_scores):
    """
    Compute the attention rollout tensors for each layer in a ViT architecture. We assume a batch size of 1.
    
    Parameters
    __________
    attn_scores : List
        List of raw attention tensors for each encoder block.
        
    Returns
    _______
    joint_attentions : List
        List of attention rollouts for each layer.
    
    grid_size : int
        Number of patches per dimension of original (quadratic) input.
    """
    # Stack raw attention tensors in a matrix, get rid of singleton (batch) dimension and average over all attention heads per layer.
    attn_mat = torch.stack(attn_scores) # (depth, n_samples, n_heads, n_patches, n_patches)
    attn_mat = attn_mat.squeeze(1) # (depth, n_heads, n_patches, n_patches)
    attn_mat = torch.mean(attn_mat, dim=1) # (depth, n_heads, n_patches, n_patches)
    
    # Account for skip connections in the architecture
    residual_attn = torch.eye(attn_mat.size(1))
    aug_attn_mat = attn_mat + residual_attn
    aug_attn_mat = aug_attn_mat / aug_attn_mat.sum(dim=-1).unsqueeze(-1)
    
    # Recursively compute attention rollouts 
    joint_attentions = torch.zeros(aug_attn_mat.size())
    joint_attentions[0] = aug_attn_mat[0].double()
    for n in range(1, aug_attn_mat.size(0)):
        joint_attentions[n] = aug_attn_mat[n].double() @ joint_attentions[n-1].double()
    
    grid_size = int(np.sqrt(aug_attn_mat.size(-1)))
    
    return joint_attentions, grid_size

We extract the attention weights of each encoder block for the first test sample and compute the attention rollouts.

In [None]:
_, attn_scores = vit((X_test[0, :, :, :][np.newaxis, ...]), return_attention=True)
joint_attentions, grid_size = compute_attention_map(attn_scores)

Visualize the test sample and select one patch. We will then track the attention rollout of this patch among the network.

In [None]:
from matplotlib.patches import Rectangle

fig, axs = plt.subplots(1, 2, figsize=(9, 5))
fig.tight_layout(pad=3, w_pad=3, h_pad=3)

axs[0].quiver(X_test[0, 0, :, :], X_test[0 ,1 ,: , :], scale=1)
axs[0].set_title('Input as strain map', {'fontsize': 11})

C = np.sqrt(X_test[0, 0, :, :] **2 + X_test[0, 1, :, :] ** 2)
im = axs[1].pcolormesh(C, cmap='jet', shading='gouraud')
axs[1].pcolormesh(C, cmap='jet', shading='gouraud')
axs[1].set_title('Input as heatmap', {'fontsize': 11})


for row in range(0, 13):
    for column in range(0, 13):
        if row == 3 and column == 8:
            axs[0].add_patch(Rectangle(xy=(column*8, row*8), width=8, height=8, linewidth=1, color='red', fill=True))
        else:
            axs[0].add_patch(Rectangle(xy=(column*8, row*8), width=8, height=8, linewidth=1, color='red', fill=False))
    
    
axs[1].add_patch(Rectangle(xy=(8*8, 3*8), width=8, height=8, linewidth=2, color='red', fill=False))

Plot rolled out attention map of chosen patch at each layer of the encoder.

In [None]:
fig, axs = plt.subplots(4, 3, figsize=(9, 9))
fig.tight_layout(pad=1, w_pad=1, h_pad=1)
ind = 0
for i in range(0, 4):
    for j in range(0, 3):
        mask = joint_attentions[ind][47, 0:].reshape(grid_size, grid_size).detach().numpy()
        mask = resize(mask / mask.max(), (104, 104))
        mask = mask[np.newaxis, ...]
        result = torch.tensor(mask) * X_test[0,:,:,:]

        C = np.sqrt(result[0,:,:] **2 + result[1,:,:] ** 2)
        im = axs[i, j].pcolormesh(C, cmap='jet', shading='gouraud')
        axs[i, j].pcolormesh(C, cmap='jet', shading='gouraud')
        axs[i, j].set_title(f'Attention rollout of encoder block {ind}', {'fontsize': 8})
        ind += 1

Next, another test sample with a more "cell-like" geometry is analysed.

In [None]:
os.chdir('..')
sys.path.append(f"{os.getcwd()}/DL_TFM/scripts/")
from data_preparation import matFiles_to_npArray

In [None]:
X_test_ = matFiles_to_npArray('comparables/generic')[1]
Y_test_ = matFiles_to_npArray('comparables/generic')[0]
X_test_ = np.moveaxis(np.array(X_test_['dspl'])[np.newaxis, ...], 3, 1)
Y_test_ = np.moveaxis(np.array(Y_test_['trac'])[np.newaxis, ...], 3, 1)

In [None]:
_, attn_scores = vit(torch.tensor(X_test_).double(), True)
joint_attentions, grid_size = compute_attention_map(attn_scores)

In [None]:
from matplotlib.patches import Rectangle

fig, axs = plt.subplots(1,2, figsize=(9, 5))
fig.tight_layout(pad=3, w_pad=3, h_pad=3)

axs[0].quiver(X_test_[0,0,:,:], X_test_[0,1,:,:], scale=20)
axs[0].set_title('Input (strain vectors)', {'fontsize': 11})

C = np.sqrt(X_test_[0,0,:,:] **2 + X_test_[0,1,:,:] ** 2)
im = axs[1].pcolormesh(C, cmap='jet', shading='gouraud')
axs[1].pcolormesh(C, cmap='jet', shading='gouraud')
axs[1].set_title('Input (heatmap)', {'fontsize': 11})


for row in range(0, 13):
    for column in range(0, 13):
        if row == 7 and column == 9:
            axs[0].add_patch(Rectangle(xy=(column*8, row*8), width=8, height=8, linewidth=1, color='red', fill=True))
        else:
            axs[0].add_patch(Rectangle(xy=(column*8, row*8), width=8, height=8, linewidth=1, color='red', fill=False))
    
    
axs[1].add_patch(Rectangle(xy=(9*8, 7*8), width=8, height=8, linewidth=1, color='red', fill=False))

In [None]:
fig, axs = plt.subplots(4, 3, figsize=(9, 9))
fig.tight_layout(pad=2, w_pad=2, h_pad=2)
ind = 0
for i in range(0, 4):
    for j in range(0, 3):
        mask = joint_attentions[ind][100, 0:].reshape(grid_size, grid_size).detach().numpy()
        mask = resize(mask / mask.max(), (104, 104))
        mask = mask[np.newaxis, ...]
        result = torch.tensor(mask) * X_test_[0,:,:,:]

        C = np.sqrt(result[0,:,:] **2 + result[1,:,:] ** 2)
        im = axs[i, j].pcolormesh(C, cmap='jet', shading='gouraud')
        axs[i, j].pcolormesh(C, cmap='jet', shading='gouraud')
        axs[i, j].set_title(f'Attention rollout of encoder block {ind}', {'fontsize': 10})
        ind += 1