In [60]:
import xarray as xr
import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn as nn
#import xrft
import matplotlib.pyplot as plt
#import pandas as pd
from tqdm import tqdm
import os
import glob
from pathlib import Path
from typing import Union, List
from sklearn.decomposition import PCA
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
import torch.nn.functional as F
import cartopy.crs as ccrs
import pickle

In [61]:
from src.utils import load_ssf_acoustic_variables, load_sound_speed_fields
from src.data import TrainingItem, AutoEncoderDatamodule, BaseDatamodule
from src.acoustic_predictor import ConvBlock, AcousticPredictor
from src.autoencoder import AutoEncoder

In [62]:

sound_speed_path = "/DATASET/eNATL/eNATL60_BLB002_sound_speed_regrid_0_1000m.nc"
ecs_path = "/DATASET/eNATL/eNATL60_BLB002_cutoff_freq_regrid_0_1000m.nc"

In [63]:
device = 'cuda:1'

In [64]:
ss_ds = load_ssf_acoustic_variables(sound_speed_path,ecs_path)[0]


In [65]:
ae_path = "/homes/o23gauvr/Documents/thèse/code/FASCINATION/outputs/AE_without_AP/4_4_sigmoid_lr_0.001/2024-04-15_17-37"
model_ae_path = glob.glob(f"{ae_path}/**/*.ckpt",recursive=True)[0]
pickle_ae_path = glob.glob(f"{ae_path}/**/*.pickle",recursive=True)[0]




In [66]:

with open(pickle_ae_path, 'rb') as file:
    time_idx_split = pickle.load(file)


ss_ds_test = ss_ds.sel(time=time_idx_split['test'])



del ss_ds


In [67]:
nan_idx = np.argwhere(np.isnan(ss_ds_test.celerity.data))
nan_idx[:, 0], nan_idx[:, 1] = nan_idx[:, 1].copy(), nan_idx[:, 0].copy()
nan_index_time_lat_lon, nan_counts = np.unique(nan_idx[:, -3:], axis=0, return_counts=True)
nan_index_time_lat_lon = np.column_stack((nan_index_time_lat_lon, nan_counts))
nan_index_time_lat_lon = nan_index_time_lat_lon[np.argsort(nan_index_time_lat_lon[:,-1])]

nan_max = nan_index_time_lat_lon[-1,-1]


percentile = 30

# Générer les indices à rechercher
indices_to_search = [nan_index_time_lat_lon[np.searchsorted(nan_index_time_lat_lon[:,-1], 
                                                            i * nan_max // percentile)] 
                     for i in range(1, percentile)]

# Empiler les indices
nan_index_lat_lon_repartition = np.stack([nan_index_time_lat_lon[0]] +
                                         indices_to_search +
                                         [nan_index_time_lat_lon[-1]], axis=0)

In [68]:
del nan_idx
del nan_index_time_lat_lon
del nan_counts

In [69]:
ss_ds_test = ss_ds_test.dropna(dim='lat')


coords_test = ss_ds_test.coords
print(coords_test)

Coordinates:
  * lon      (lon) float64 -65.95 -65.9 -65.85 -65.8 ... -54.1 -54.05 -54.0
  * lat      (lat) float64 32.6 32.65 32.7 32.75 32.8 ... 41.2 41.25 41.3 41.35
  * z        (z) float64 0.4805 1.559 2.794 4.187 ... 968.4 985.3 1.002e+03
  * time     (time) datetime64[ns] 2009-11-14T12:00:00 ... 2010-06-01T12:00:00


In [70]:
coords_test['time']

In [71]:
#depth_std = torch.tensor(ss_arr_test.std(dim=1))
depth_std = torch.tensor(ss_ds_test.celerity.std(dim='z').values)
flatten_max_std_idx = torch.topk(depth_std.nan_to_num(0).flatten(), k = 30).indices
max_ss_std_idx = torch.stack(torch.unravel_index(flatten_max_std_idx, depth_std.shape),dim=1)
flatten_min_std_idx = torch.topk(depth_std.nan_to_num(1e5).flatten(), k = 30, largest = False).indices
min_ss_std_idx = torch.stack(torch.unravel_index(flatten_min_std_idx, depth_std.shape),dim=1)

In [72]:
del depth_std 
del flatten_max_std_idx
del flatten_min_std_idx

In [73]:


depth_std_150m = torch.tensor(ss_ds_test.where(ss_ds_test.z<=150, drop=True).celerity.std(dim='z').values)

flatten_max_std_150_idx = torch.topk(depth_std_150m.nan_to_num(0).flatten(), k = 30).indices
max_ss_std_150_idx = torch.stack(torch.unravel_index(flatten_max_std_150_idx, depth_std_150m.shape),dim=1)
flatten_min_std_150_idx = torch.topk(depth_std_150m.nan_to_num(1e5).flatten(), k = 30, largest = False).indices
min_ss_std_150_idx = torch.stack(torch.unravel_index(flatten_min_std_150_idx, depth_std_150m.shape),dim=1)

In [74]:

del depth_std_150m
del flatten_max_std_150_idx
del flatten_min_std_150_idx

In [75]:
time_lon_idx = np.unravel_index(ss_ds_test.std(dim=('z','lat')).argmax().celerity.data, (len(coords_test['time']),len(coords_test['lon'])))
time_lat_idx = np.unravel_index(ss_ds_test.std(dim=('z','lon')).argmax().celerity.data, (len(coords_test['time']),len(coords_test['lat'])))

max_std_lat=np.column_stack((np.full_like(np.arange(240), time_lon_idx[0]), np.arange(240), np.full_like(np.arange(240), time_lon_idx[1])))
max_std_lon=np.column_stack((np.full_like(np.arange(240), time_lat_idx[0]), np.arange(240), np.full_like(np.arange(240), time_lat_idx[1])))

  time_lon_idx = np.unravel_index(ss_ds_test.std(dim=('z','lat')).argmax().celerity.data, (len(coords_test['time']),len(coords_test['lon'])))
  time_lat_idx = np.unravel_index(ss_ds_test.std(dim=('z','lon')).argmax().celerity.data, (len(coords_test['time']),len(coords_test['lat'])))


In [76]:
del ss_ds_test

In [77]:
n = 1000
sizes = [len(coords_test['time']), len(coords_test['lat']), len(coords_test['lon'])]
random_t_lat_lon = np.array([np.random.randint(0, size, size=(n)) for size in sizes]).T

In [78]:
acc_ds = load_ssf_acoustic_variables(sound_speed_path,ecs_path)[1]
acc_ds_test = acc_ds.sel(time=coords_test['time'].data, lat = coords_test['lat'].data)

del acc_ds

In [79]:
sorted_ecs_idx = np.argsort(acc_ds_test.ecs.values,axis=None)  
max_ecs_idx = np.stack(np.unravel_index(sorted_ecs_idx[-30:], acc_ds_test.ecs.shape),axis=1)
min_ecs_idx = np.stack(np.unravel_index(sorted_ecs_idx[:30], acc_ds_test.ecs.shape),axis=1)


In [80]:
del acc_ds_test
del sorted_ecs_idx

In [81]:
profile_idx_dict = dict(
    min_std=min_ss_std_idx,
    max_std=max_ss_std_idx,
    min_std_150=min_ss_std_150_idx,
    max_std_150=max_ss_std_150_idx,
    min_ecs=min_ecs_idx,
    max_ecs=max_ecs_idx,
    nan_profile_idx=nan_index_lat_lon_repartition[:-1,:-1],
    max_std_lat=max_std_lat,
    max_std_lon=max_std_lon,
    random_1000=random_t_lat_lon   
    
)

profile_idx_dict

{'min_std': tensor([[ 18, 100,  83],
         [ 18, 100,  82],
         [ 18, 100,  84],
         [ 18, 101,  83],
         [ 35, 100,  85],
         [ 18,  99,  83],
         [ 35,  99,  85],
         [ 35, 100,  86],
         [ 18, 101,  84],
         [ 35,  99,  84],
         [ 18,  99,  82],
         [ 35,  99,  86],
         [ 18, 101,  82],
         [ 35, 100,  84],
         [ 18,  99,  84],
         [ 18, 100,  81],
         [ 35, 100,  87],
         [ 35,  99,  83],
         [ 18, 100,  85],
         [ 18,  99,  81],
         [ 18, 101,  85],
         [ 18, 101,  81],
         [ 35, 100,  83],
         [ 35,  99,  87],
         [ 18,  98,  83],
         [ 18, 102,  83],
         [ 35, 101,  85],
         [ 35,  98,  84],
         [ 18, 102,  84],
         [ 35, 101,  86]]),
 'max_std': tensor([[ 22, 107,   0],
         [ 22, 107,   1],
         [ 22, 108,   1],
         [ 22, 106,   0],
         [ 22, 110,   4],
         [ 22, 108,   2],
         [ 22, 109,   3],
         [ 22,

In [82]:
with open("/homes/o23gauvr/Documents/thèse/code/FASCINATION/pickle/profiles_of_interest_idx.pkl", "wb") as file:
    pickle.dump(profile_idx_dict, file)

In [83]:
with open("/homes/o23gauvr/Documents/thèse/code/FASCINATION/pickle/profiles_of_interest_idx.pkl", "rb") as file:
    pickle.load(file)