In [465]:
import warnings
warnings.filterwarnings('ignore')

import copy
import datetime
import h5py
import keras
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import scienceplots
import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import sys

from cv2 import resize
from datetime import datetime
from gc import collect
from os import cpu_count
from scipy.io import savemat, loadmat
from sklearn.model_selection import train_test_split
from time import sleep
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from tqdm import tqdm

sys.path.append(f"{os.getcwd()}/ViT architecture/working ViT")
sys.path.append(f"{os.getcwd()}/scripts/")
from VisionTransformer_working import VisionTransformer as Vit_old

sys.path.append(f"{os.getcwd()}/ViT architecture/Architecture tryouts/DPT/")
from VisionTransformer_working_for_DPT import VisionTransformer as Vit
from VisionTransformer_working_for_DPT import VisionTransformer3 as Vit3

In [466]:
random_seed = 2
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
cudnn.benchmark = True

In [467]:
collect()
torch.cuda.empty_cache()
torch.set_printoptions(precision=6)
device = torch.device('cpu')
print(f"Running on device: {device}")

Running on device: cpu


## Load models

In [468]:
vit = Vit(dspl_size=104, 
          patch_size=8, 
          embed_dim=128,
          depth=6,
          n_heads=4,
          mlp_ratio=1.0,
          qkv_bias=False,
          p=0.1,
          attn_p=0.1,
          drop_path=0.).float()

path_to_vit = '/home/alexrichard/PycharmProjects/UQ_DL-TFM/ViT-TFM/ViT architecture/Architecture tryouts/DPT/logs_and_weights/ViT-final_2023-Jul-27 10:18:31/ViT-final_2023-Jul-27 10:18:31.pth'

vit.load_state_dict(torch.load(path_to_vit, map_location=torch.device('cpu'))['best_model_weights'], strict=True)        

<All keys matched successfully>

In [469]:
prob_vit = Vit3(dspl_size=104, 
                patch_size=8, 
                embed_dim=128,
                depth=4,
                n_heads=4,
                mlp_ratio=1.0,
                qkv_bias=False,
                p=0.1,
                attn_p=0.1,
                drop_path=0.).float()

path_to_prob_vit = '/home/alexrichard/PycharmProjects/UQ_DL-TFM/ViT-TFM/ViT architecture/Architecture tryouts/DPT/logs_and_weights/ViT-GNLL_final2023-May-17 23:02:08/ViT-GNLL_final2023-May-17 23:02:08.pth'

prob_vit.load_state_dict(torch.load(path_to_prob_vit, map_location=torch.device('cpu'))['best_model_weights'], strict=True)        

<All keys matched successfully>

In [470]:
# cnn = keras.models.load_model('/home/alexrichard/PycharmProjects/UQ_DL-TFM/mltfm/CNN_noisy-2023-Mar-21 18:13:25_checkpoint.h5')
cnn = keras.models.load_model('/home/alexrichard/PycharmProjects/UQ_DL-TFM/mltfm/CNN_noisy_final-2023-May-18 23:46:36_checkpoint.h5')

## Load data

In [7]:
real_cells = []
real_dspl = torch.zeros((14, 2, 104, 104))
directory = '/home/alexrichard/LRZ Sync+Share/ML in Physics/Repos/DL-TFM-main/cells/cells/dspl'

for i, file in enumerate(os.listdir(directory)):
    filename = os.fsdecode(file)
    if filename.endswith(".mat"):
        real_cell = {}
        file = loadmat(directory+'/'+filename)
        real_cell['brdx'] = np.rot90(file['brdx'], 2)
        real_cell['brdy'] = np.rot90(file['brdy'], 2)
        real_cell['dspl'] = torch.tensor((1 / 200.2) * np.transpose(file['dspl']))
        real_cell['vit_pred'] = vit(real_cell['dspl'].unsqueeze(0).float()) * 10670
        real_cell['prob_vit_pred'] = prob_vit(real_cell['dspl'].unsqueeze(0).float()) * 10670
        real_cell['cnn_pred'] = cnn.predict(np.expand_dims(np.moveaxis(np.array(real_cell['dspl']), 0, 2), 0)) * 10670
        real_cells.append(real_cell)

TypeError: forward() missing 2 required positional arguments: 'src_size' and 'tgt_size'

In [433]:
Bay_FTTC_results = loadmat('/home/alexrichard/LRZ Sync+Share/ML in Physics/Repos/Easy-to-use_TFM_package-master/test_data/real_cells/Samples/Bay-FTTC_results_24-05-23.mat')
TFM_results = Bay_FTTC_results['TFM_results']['traction'][0][0].T.reshape((2, 102, 102), order='F')
TFM_results[:, 0:3, :] = 0
TFM_results[:, :, 0:3] = 0
TFM_results[:, 99:, :] = 0
TFM_results[:, :, 99:] = 0
TFM_results_padded = F.pad(input=torch.tensor(TFM_results).float(), pad=(1, 1, 1, 1), mode='constant', value=0)

In [435]:
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25

get_ipython().run_line_magic('matplotlib', 'notebook')
plt.style.use(['science', 'grid', 'muted'])

fig, axs = plt.subplots(2, 1, figsize=(6, 9))
fig.tight_layout(pad=2, w_pad=0, h_pad=3)

i = 0

#C = np.sqrt(real_cells[i]['dspl'][0, :, :].detach().numpy() ** 2 + real_cells[i]['dspl'][1, :, :].detach().numpy() ** 2)
#argmax = np.unravel_index(np.argmax(C, axis=None), C.shape)
#real_cells[i]['dspl'][0, 10, 10] = np.linalg.norm(C[argmax])
#real_cells[i]['dspl'][1, 10, 10] = 0

axs[0].quiver(real_cells[i]['dspl'][0, :, :].detach().numpy(), real_cells[i]['dspl'][1, :, :].detach().numpy(), scale=1/5)
axs[0].set(adjustable='box', aspect='equal')

axs[0].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='grey')
axs[0].set_title('Displacement', pad=10, fontsize=30)
#axs[0].text(80, 110, 'Displacement', fontsize=35)
dx = 5 * 0.51948
dy = 0
axs[0].arrow(x=10, y=10, dx=dx, dy=dy, head_width=1.3, color='black')
axs[0].annotate(u'5 \xb5m',
                fontsize=13,
                xy=(0.1, 0.1), 
                xytext=(11, 8),
                textcoords='offset points')

C = np.sqrt(real_cells[i]['dspl'][0, :, :].detach().numpy() ** 2 + real_cells[i]['dspl'][1, :, :].detach().numpy() ** 2)
im = axs[1].pcolormesh(C, shading='gouraud')
axs[1].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='white')
#axs[1].set_title(r'Displacement', pad=10, fontsize=35)
colorbar = fig.colorbar(im, ax=axs[1], pad=0.02)
colorbar.ax.set_title(u'\xb5m', fontsize=17)
axs[1].set(adjustable='box', aspect='equal')
#cb_ax = fig.add_axes([.91,.124,.04,.754])
#fig.colorbar(im,orientation='vertical',cax=cb_ax)

fig.subplots_adjust(top=0.82)

image_name = 'Real_dspl_vert.jpeg'
#fig.savefig(f'/home/alexrichard/LRZ Sync+Share/ML in Physics/Thesis/Plots/{image_name}', format='jpeg', dpi=1000)

<IPython.core.display.Javascript object>

In [436]:
norm = matplotlib.colors.Normalize(vmin=0.0,vmax=np.max(np.sqrt(TFM_results_padded[0, :, :].detach().numpy() ** 2 + TFM_results_padded[1, :, :].detach().numpy() ** 2)))

In [441]:
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25

get_ipython().run_line_magic('matplotlib', 'notebook')
plt.style.use(['science', 'grid', 'muted'])

fig, axs = plt.subplots(2, 1, figsize=(6, 9))
fig.tight_layout(pad=2, w_pad=0, h_pad=3)

i = 0
axs[0].quiver(real_cells[i]['vit_pred'][0, 0, :, :].detach().numpy(), real_cells[i]['vit_pred'][0, 1, :, :].detach().numpy(), scale=10670)
axs[0].set(adjustable='box', aspect='equal')
axs[0].set_title('ViT', pad=10, fontsize=30)
axs[0].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='grey')

dx = 9
dy = 0
axs[0].arrow(x=10, y=10, dx=dx, dy=dy, head_width=1, color='black')
axs[0].annotate(u'1000 Pa',
                fontsize=13,
                xy=(0.2, 0.2), 
                xytext=(11, 8),
                textcoords='offset points')

C = np.sqrt(real_cells[i]['vit_pred'][0, 0, :, :].detach().numpy() ** 2 + real_cells[i]['vit_pred'][0, 1, :, :].detach().numpy() ** 2)
im = axs[1].pcolormesh(C, shading='gouraud', norm=norm)
axs[1].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='white')
colorbar = fig.colorbar(im, ax=axs[1], pad=0.02, norm=norm)
colorbar.ax.set_title(u'Pa', fontsize=17)
axs[1].set(adjustable='box', aspect='equal')

fig.subplots_adjust(top=0.825)

image_name = 'Real_vit_vert.jpeg'
fig.savefig(f'/home/alexrichard/LRZ Sync+Share/ML in Physics/Thesis/Plots/{image_name}', format='jpeg', dpi=1000)

<IPython.core.display.Javascript object>

In [442]:
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25

get_ipython().run_line_magic('matplotlib', 'notebook')
plt.style.use(['science', 'grid', 'muted'])

fig, axs = plt.subplots(2, 1, figsize=(6, 9))
fig.tight_layout(pad=2, w_pad=0, h_pad=3)

i = 0
axs[0].quiver(real_cells[i]['prob_vit_pred'][0, 0, :, :].detach().numpy(), real_cells[i]['prob_vit_pred'][0, 1, :, :].detach().numpy()  , scale=10670)
axs[0].set(adjustable='box', aspect='equal')
axs[0].set_title('Prob-ViT', pad=10, fontsize=30)
axs[0].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='grey')

dx = 9
dy = 0
axs[0].arrow(x=10, y=10, dx=dx, dy=dy, head_width=1, color='black')
axs[0].annotate(u'1000 Pa',
                fontsize=13,
                xy=(0.2, 0.2), 
                xytext=(11, 8),
                textcoords='offset points')

C = np.sqrt(real_cells[i]['prob_vit_pred'][0, 0, :, :].detach().numpy() ** 2 + real_cells[i]['prob_vit_pred'][0, 1, :, :].detach().numpy() ** 2)
im = axs[1].pcolormesh(C, shading='gouraud', norm=norm)
axs[1].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='white')
colorbar = fig.colorbar(im, ax=axs[1], pad=0.02, norm=norm)
colorbar.ax.set_title(u'Pa', fontsize=17)
axs[1].set(adjustable='box', aspect='equal')

fig.subplots_adjust(top=0.825)

image_name = 'Real_prob-vit_vert.jpeg'
fig.savefig(f'/home/alexrichard/LRZ Sync+Share/ML in Physics/Thesis/Plots/{image_name}', format='jpeg', dpi=1000)

<IPython.core.display.Javascript object>

In [443]:
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25

get_ipython().run_line_magic('matplotlib', 'notebook')
plt.style.use(['science', 'grid', 'muted'])

fig, axs = plt.subplots(2, 1, figsize=(6, 9))
fig.tight_layout(pad=2, w_pad=0, h_pad=3)

i = 0
axs[0].quiver(real_cells[i]['cnn_pred'][0, :, :, 0], real_cells[i]['cnn_pred'][0, :, :, 1], scale=10670)
axs[0].set(adjustable='box', aspect='equal')
axs[0].set_title('CNN', pad=10, fontsize=30)
axs[0].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='grey')

dx = 9
dy = 0
axs[0].arrow(x=10, y=10, dx=dx, dy=dy, head_width=1, color='black')
axs[0].annotate(u'1000 Pa',
                fontsize=13,
                xy=(0.2, 0.2), 
                xytext=(11, 8),
                textcoords='offset points')

C = np.sqrt(real_cells[i]['cnn_pred'][0, :, :, 0] ** 2 + real_cells[i]['cnn_pred'][0, :, :, 1] ** 2)
im = axs[1].pcolormesh(C, shading='gouraud', norm=norm)
axs[1].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='white')
colorbar = fig.colorbar(im, ax=axs[1], pad=0.02, norm=norm)
colorbar.ax.set_title(u'Pa', fontsize=17)
axs[1].set(adjustable='box', aspect='equal')

fig.subplots_adjust(top=0.825)

image_name = 'Real_cnn_vert.jpeg'
fig.savefig(f'/home/alexrichard/LRZ Sync+Share/ML in Physics/Thesis/Plots/{image_name}', format='jpeg', dpi=1000)

<IPython.core.display.Javascript object>

In [444]:
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25

get_ipython().run_line_magic('matplotlib', 'notebook')
plt.style.use(['science', 'grid', 'muted'])

fig, axs = plt.subplots(2, 1, figsize=(6, 9))
fig.tight_layout(pad=2, w_pad=0, h_pad=3)

i = 0
axs[0].quiver(TFM_results_padded[0, :, :].detach().numpy().T, TFM_results_padded[1, :, :].detach().numpy().T, scale=10670)
axs[0].set(adjustable='box', aspect='equal')
axs[0].set_title('BFTTC', pad=10, fontsize=30)
axs[0].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='grey')

dx = 9
dy = 0
axs[0].arrow(x=10, y=10, dx=dx, dy=dy, head_width=1, color='black')
axs[0].annotate(u'1000 Pa',
                fontsize=13,
                xy=(0.2, 0.2), 
                xytext=(11, 8),
                textcoords='offset points')

C = np.sqrt(TFM_results_padded[0, :, :].detach().numpy().T ** 2 + TFM_results_padded[1, :, :].detach().numpy().T ** 2)
im = axs[1].pcolormesh(C, shading='gouraud', norm=norm)
axs[1].plot(real_cells[i]['brdx'][0], real_cells[i]['brdy'][0], c='white')
colorbar = fig.colorbar(im, ax=axs[1], pad=0.02, norm=norm)
colorbar.ax.set_title(u'Pa', fontsize=17)
axs[1].set(adjustable='box', aspect='equal')

fig.subplots_adjust(top=0.825)

image_name = 'Real_bfttc_vert.jpeg'
fig.savefig(f'/home/alexrichard/LRZ Sync+Share/ML in Physics/Thesis/Plots/{image_name}', format='jpeg', dpi=1000)

<IPython.core.display.Javascript object>

scale conversion:
1 model pixel = 7 camera pixel, 
104mp = 7 * 104 * 0.275 micrometer = 200.2 micrometer = 0.2002 mm

### Fibroblast

In [471]:
Bay_FTTC_output = loadmat('/home/alexrichard/LRZ Sync+Share/ML in Physics/Thesis/Plots/Real cell/Cell4/Bay-FTTC_results_24-07-23.mat')
Bay_FTTC_output['TFM_settings']

array([[(array([[0.5]]), array([[10000]], dtype=uint16), array([[0.108]]), array([[230.82756012]]), array([[16]], dtype=uint8), array([[0]], dtype=uint8), array([[54]], dtype=uint8), array([[54]], dtype=uint8), array(['Region of interest for noise selected manually'], dtype='<U46'))]],
      dtype=[('poisson', 'O'), ('young', 'O'), ('micrometer_per_pix', 'O'), ('regularization_parameter', 'O'), ('meshsize', 'O'), ('zdepth', 'O'), ('i_max', 'O'), ('j_max', 'O'), ('type_noise', 'O')])

Poisson's ratio: 0.5 <br>
Young's modulus: 10000 Pa <br>
Micrometer per pixel : 0.108 <br>
Number of pixel per dimension: 54 * 16 <br>

In [472]:
dspl = Bay_FTTC_output['TFM_results']['displacement'][0][0].T.reshape((2, 54, 54), order='F')
dspl_padded = F.pad(input=torch.tensor(dspl).float(), pad=(1, 1, 1, 1), mode='constant', value=0)
dspl_padded_big = F.pad(input=torch.tensor(dspl).float(), pad=(25, 25, 25, 25), mode='constant', value=0)

In [473]:
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25

get_ipython().run_line_magic('matplotlib', 'notebook')
plt.style.use(['science', 'grid', 'muted'])

fig, axs = plt.subplots(2, 1, figsize=(6, 9))
fig.tight_layout(pad=2, w_pad=0, h_pad=3)

sparse_dspl = dspl_padded * np.resize([[0, 0, 0, 0],[0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1]], (56,56))
sparse_dspl[0, 5, 5] = 5.787  # 5.787 * 16 * 0.108 [micrometer] = 10 [micrometer]
sparse_dspl[1, 5, 5] = 0
x = np.linspace(0, 56, 56)
y = np.linspace(0, 56, 56)
axs[0].quiver(x, y, sparse_dspl[0, :, :], sparse_dspl[1, :, :], angles='xy', scale_units='xy', scale=1)
axs[0].set(adjustable='box', aspect='equal')
axs[0].set_title('Displacement', pad=10, fontsize=30)
axs[0].annotate(u'10 \xb5m',
                fontsize=12,
                xy=(0.1, 0.1), 
                xytext=(11, 5),
                textcoords='offset points')

norm = matplotlib.colors.Normalize(vmin=0.0, vmax=np.max(np.sqrt(np.array(dspl_padded[0, :, :] ** 2 + dspl_padded[1, :, :] ** 2))))
C = np.sqrt(dspl_padded[0, :, :] ** 2 + dspl_padded[1, :, :] ** 2)
im = axs[1].pcolormesh(C, shading='gouraud', norm=norm)
colorbar = fig.colorbar(im, ax=axs[1], pad=0.02)
colorbar.ax.set_title(u'\xb5m', fontsize=17)
axs[1].set(adjustable='box', aspect='equal')

fig.subplots_adjust(top=0.82)

image_name = 'Real_dspl_vert.jpeg'
#fig.savefig(f'/home/alexrichard/LRZ Sync+Share/ML in Physics/Thesis/Plots/{image_name}', format='jpeg', dpi=1000)

<IPython.core.display.Javascript object>

In [474]:
bfttc_result = np.pad(
    array=Bay_FTTC_output['TFM_results']['traction'][0][0].T.reshape((2, 54, 54), order='F'), 
    pad_width=((0,0),(1,1),(1,1)),
    mode='constant',
    constant_values=0)
vit_pred = vit((1 / (104 * 16 * 0.108)) * dspl_padded_big.unsqueeze(0), src_size=(1, 13, 13, -1), tgt_size=(13, 13), interpolate=False) * 10000
cnn_pred = cnn.predict(np.expand_dims(np.moveaxis(np.array((1 / (104 * 16 * 0.108)) * dspl_padded_big), 0, 2), 0)) * 10000



In [475]:
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25

get_ipython().run_line_magic('matplotlib', 'notebook')
plt.style.use(['science', 'grid', 'muted'])

fig, axs = plt.subplots(2, 1, figsize=(6, 9))
fig.tight_layout(pad=2, w_pad=0, h_pad=3)

#sparse_pred = bfttc_result * np.resize([[0, 0, 0, 0],[0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1]], (56, 56))
#sparse_dspl[0, 5, 5] = 5.787  # 5.787 * 16 * 0.108 [micrometer] = 10 [micrometer]
#sparse_dspl[1, 5, 5] = 0

x = np.linspace(0, 56, 56)
y = np.linspace(0, 56, 56)
axs[0].quiver(x, y, bfttc_result[0, :, :], bfttc_result[1, :, :], angles='xy', scale=10000)
axs[0].set(adjustable='box', aspect='equal')
axs[0].set_title('BFTTC', pad=10, fontsize=30)
axs[0].annotate(u'10 \xb5m',
                fontsize=12,
                xy=(0.1, 0.1), 
                xytext=(11, 5),
                textcoords='offset points')

axs[0].annotate(u'10 \xb5m',
                fontsize=12,
                xy=(0.1, 0.1), 
                xytext=(11, 5),
                textcoords='offset points')

norm = matplotlib.colors.Normalize(vmin=0.0, vmax=np.max(np.sqrt(np.array(bfttc_result[0, :, :] ** 2 + bfttc_result[1, :, :] ** 2))))
C = np.sqrt(bfttc_result[0, :, :] ** 2 + bfttc_result[1, :, :] ** 2)
im = axs[1].pcolormesh(C, shading='gouraud', norm=norm)
colorbar = fig.colorbar(im, ax=axs[1], pad=0.02, norm=norm)
colorbar.ax.set_title(u'Pa', fontsize=17)
axs[1].set(adjustable='box', aspect='equal')

fig.subplots_adjust(top=0.825)

image_name = 'Real_bfttc_vert.jpeg'
#fig.savefig(f'/home/alexrichard/LRZ Sync+Share/ML in Physics/Thesis/Plots/{image_name}', format='jpeg', dpi=1000)

<IPython.core.display.Javascript object>

In [477]:
torch.max(torch.sqrt((vit_pred[0, 0, :, :] ** 2 + vit_pred[0, 1, :, :] ** 2)))

tensor(4452.343262, grad_fn=<MaxBackward1>)

In [478]:
np.max(np.sqrt(bfttc_result[0, :, :] ** 2 + bfttc_result[1, :, :] ** 2))

2470.570093563491

In [479]:
np.max(np.sqrt(cnn_pred[0, :, :, 0] ** 2 + cnn_pred[0, :, :, 1] ** 2))

11679.527

In [482]:
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25

get_ipython().run_line_magic('matplotlib', 'notebook')
plt.style.use(['science', 'grid', 'muted'])

fig, axs = plt.subplots(2, 1, figsize=(6, 9))
fig.tight_layout(pad=2, w_pad=0, h_pad=3)

sparse_pred = vit_pred
#sparse_pred = vit_pred * torch.tensor(np.resize([[0, 0, 0, 0],[0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1]], (104, 104)))
#sparse_dspl[0, 5, 5] = 5.787  # 5.787 * 16 * 0.108 [micrometer] = 10 [micrometer]
#sparse_dspl[1, 5, 5] = 0
x = np.linspace(0, 104, 104)
y = np.linspace(0, 104, 104)
axs[0].quiver(x, y, sparse_pred[0, 0, :, :].detach(), sparse_pred[0, 1, :, :].detach(), scale=3000, units='xy')
axs[0].set(adjustable='box', aspect='equal')
axs[0].set_title('ViT', pad=10, fontsize=30)

#dx = 9
#dy = 0
#axs[0].arrow(x=10, y=10, dx=dx, dy=dy, head_width=1, color='black')
#axs[0].annotate(u'1000 Pa',
#                fontsize=13,
#                xy=(0.2, 0.2), 
#                xytext=(11, 8),
#                textcoords='offset points')

norm = matplotlib.colors.Normalize(vmin=0.0, vmax=np.max(np.sqrt(np.array(sparse_pred[0, 0, :, :].detach()) ** 2 + np.array(sparse_pred[0, 1, :, :].detach()) ** 2)))
C = np.sqrt(vit_pred[0, 0, :, :].detach().numpy() ** 2 + vit_pred[0, 1, :, :].detach().numpy() ** 2)
im = axs[1].pcolormesh(C, shading='gouraud', norm=norm)
colorbar = fig.colorbar(im, ax=axs[1], pad=0.02, norm=norm)
colorbar.ax.set_title(u'Pa', fontsize=17)
axs[1].set(adjustable='box', aspect='equal')

fig.subplots_adjust(top=0.825)

#image_name = 'Real_vit_vert.jpeg'
#fig.savefig(f'/home/alexrichard/LRZ Sync+Share/ML in Physics/Thesis/Plots/{image_name}', format='jpeg', dpi=1000)

<IPython.core.display.Javascript object>