In [1]:
import torch.nn.functional as F
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from tqdm.notebook import tqdm
from helpers import *
from matplotlib.cm import ScalarMappable
from joblib import Parallel, delayed, parallel_backend
import multiprocessing
import plotly.express as px
import plotly.graph_objects as go
import plotly.figure_factory as ff
import matplotlib as mpl
import datetime
from math import atan, pi

%matplotlib inline

# --- Runtime -----------------------------------------------------------------
device = 'cpu'            # torch device the scenes are moved to before inference

# --- Masking / matching ------------------------------------------------------
NAN_VAL = -100            # sentinel for invalid pixels; passed to generate_points
WINDOW_SIZE = (31, 31)    # passed to generate_points and inference (presumably template size in px — TODO confirm)
VICINITY_SIZE = (80, 80)  # passed to inference (presumably search-area size in px — TODO confirm)

# --- Sampling grid (arguments of generate_points) ----------------------------
X0, Y0 = 190, 940         # presumably grid origin in px — TODO confirm axis order
dX = dY = 500             # presumably grid extent in px
STEP = 10                 # presumably grid spacing in px (cf. the "10px" cache file names)

# --- Velocity / filtering ----------------------------------------------------
dT = 22810                # divisor turning distances into velocities; presumably seconds between scenes — confirm
DISTANCE_QUANTILE = 0.65  # gv_idx keeps vectors whose loss is below this quantile

In [2]:
# Load the first (NOAA-12) scene. Negative values are treated as invalid and
# replaced with the NAN_VAL sentinel -- the same value generate_points is given --
# rather than a duplicated -100 literal that could drift from the config.
b0, data1 = parse('20060504_072852_NOAA_12.m.pro')
data1 = data1.astype(float)
data1[data1 < 0] = NAN_VAL

In [3]:
# Load the second (NOAA-17) scene and mask invalid (negative) pixels with the
# NAN_VAL sentinel from the config cell instead of a hard-coded -100.
# NOTE(review): this rebinds `b0`, so the b0 returned for the NOAA-12 file is
# discarded and every later cell uses this one -- confirm both scenes share it.
b0, data2 = parse('20060504_125118_NOAA_17.m.pro')
data2 = data2.astype(float)
data2[data2 < 0] = NAN_VAL

In [4]:
# Sample the grid of points to track in the first scene.
# NOTE(review): argument meanings live in helpers.generate_points -- presumably
# origin (X0, Y0), extent (dX, dY), grid step, window size and the NaN sentinel.
point_coors = generate_points(
    data1,
    X0, Y0,
    dX, dY,
    STEP,
    WINDOW_SIZE,
    NAN_VAL,
)

In [5]:
# Convert both scenes from numpy arrays to torch tensors created directly on
# `device`; they are passed to `inference` in the tracking loop.
data1 = torch.tensor(data1, device=device)
data2 = torch.tensor(data2, device=device)

In [None]:
# Track every sampled point from data1 into data2 with helpers.inference.
# Points for which inference raises are skipped (best effort), but recorded in
# `failed_points` so data loss is visible instead of silently swallowed.
tmp_point_coors = []  # points whose inference succeeded; index-aligned with new_coors
new_coors = []        # matched coordinates, one per entry of tmp_point_coors
mask = []             # per-point masks returned by inference (previously overwritten each iteration)
failed_points = []    # (point_coor, exception) for points that could not be tracked

for point_coor in tqdm(point_coors):
    try:
        idx, velocity, point_mask = inference(data1, data2, b0, dT, point_coor,
                                              WINDOW_SIZE, VICINITY_SIZE, ssim,
                                              device, 'max', 'pix', None, None, .5)
    except Exception as exc:  # not bare `except:` -- let KeyboardInterrupt through
        failed_points.append((point_coor, exc))
        continue
    # Append only after inference succeeded so the lists stay index-aligned
    # (previously new_coors.append() had no argument, always raised TypeError,
    # and the bare except hid it -- leaving new_coors empty).
    tmp_point_coors.append(point_coor)
    # NOTE(review): assumes `idx` is the matched (row, col) position in data2 -- confirm
    # against helpers.inference.
    new_coors.append(idx)
    mask.append(point_mask)

# Later cells index these as 2-D arrays (e.g. new_coors[i, 1]).
tmp_point_coors = np.array(tmp_point_coors)
new_coors = np.array(new_coors)
print(f'Tracked {len(new_coors)} / {len(point_coors)} points; {len(failed_points)} failed.')

In [None]:
# Boolean mask of the "good" vectors used for statistics and plotting:
# score above 0.5 AND loss below its DISTANCE_QUANTILE-th quantile.
# NOTE(review): `filtered_scores` and `loss` are not created by any visible
# cell; they must already exist in the kernel -- TODO confirm their origin.
gv_idx = (
    (filtered_scores > .5)
    & (loss < np.quantile(loss, DISTANCE_QUANTILE))
)

In [None]:
# Distance between each sampled point and its matched coordinate (computed by
# helpers.calculate_distance using b0), turned into a speed by dividing by dT.
distances = np.array([
    calculate_distance(b0, point_coors[k], new_coor)
    for k, new_coor in enumerate(new_coors)
])
velocities = distances / dT

# Summary statistics restricted to the good vectors (gv_idx).
good_velocities = velocities[gv_idx]
max_velocity = np.max(good_velocities)
quant_velocity = np.quantile(good_velocities, .95)  # 95th percentile
min_velocity = np.min(good_velocities)
mean_velocity = np.mean(good_velocities)
median_velocity = np.median(good_velocities)

In [None]:
# Build one plotly arrow annotation per good vector (gv_idx). The arrow tail is
# the sampled point (ax/ay) and the head is its matched position (x/y); column
# (axis 1) maps to x and row (axis 0) to y. Arrow colour encodes velocity
# normalised to [0, 1], interpolated Inferno[0] -> Inferno[4] -> Inferno[-1].
inferno = px.colors.sequential.Inferno

# Normalise by the largest finite *good* velocity, computed once. The old code
# called the builtin max() over every velocity inside the loop (O(n^2)) and let
# filtered-out outliers compress the colour range of the plotted arrows.
finite_good_velocities = velocities[gv_idx]
finite_good_velocities = finite_good_velocities[np.isfinite(finite_good_velocities)]
velocity_scale = float(np.max(finite_good_velocities)) if finite_good_velocities.size else 0.0

annotations = []
for i in range(len(new_coors)):
    if not gv_idx[i]:
        continue
    if not np.isfinite(velocities[i]):
        # A NaN/inf speed has no meaningful colour; skip it rather than drawing
        # an arrow with a stale colour from a previous iteration.
        continue

    # Clip so colorFader always receives a fraction in [0, 1]; the previous
    # try/except printed on failure and then reused a stale (or undefined) `c`.
    if velocity_scale > 0:
        normalized_velocity = float(np.clip(velocities[i] / velocity_scale, 0.0, 1.0))
    else:
        normalized_velocity = 0.0

    if normalized_velocity < 0.5:
        c = colorFader(inferno[0], inferno[4], normalized_velocity / 0.5)
    else:
        c = colorFader(inferno[4], inferno[-1], (normalized_velocity - 0.5) / 0.5)

    annotations.append(dict(
        x=new_coors[i, 1],
        y=new_coors[i, 0],
        xref="x", yref="y",
        text="",
        showarrow=True,
        axref="x", ayref="y",
        ax=point_coors[i, 1],
        ay=point_coors[i, 0],
        arrowhead=1,
        arrowwidth=1,
        arrowcolor=c,
        hovertext=f'{velocities[i]:.4f}',  # same precision as the figure title
    ))

In [None]:
# Show the first scene in grayscale with the velocity arrows overlaid.
# zmin/zmax fix the grayscale contrast window to 500..719 (in data1 units).
fig = px.imshow(data1, width=800, height=800, color_continuous_scale='gray', zmin=500, zmax=719,
                title=f'Max velocity: {max_velocity:.4f} | Mean velocity: {mean_velocity:.4f} | Median velocity: {median_velocity:.4f}')

# Hide the grayscale colour scale and the pixel tick labels.
fig.update_layout(coloraxis_showscale=False)
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)

fig.update_layout(annotations=annotations)

# Empty scatter trace whose only job is to draw an Inferno colour bar for the
# normalised (0..1) arrow colours. cmin must be below cmax: the previous
# cmin=1, cmax=0 gave a decreasing colour domain, so the bar was inverted
# relative to the arrows (which use the darkest Inferno colour at 0).
colorbar_trace = go.Scatter(
    x=[None],
    y=[None],
    mode='markers',
    marker=dict(
        colorscale='inferno',
        showscale=True,
        cmin=0,
        cmax=1,
        colorbar=dict(thickness=30, tickvals=[0, 0.5, 1], outlinewidth=1,
                      title=dict(text='normalized velocity')),
    ),
    hoverinfo='none',
)
fig.add_trace(colorbar_trace)
fig.show()

In [101]:
import pickle

# Restore cached tracking results (the same file names the saving cell writes).
# NOTE: pickle.load can execute arbitrary code -- only load files this notebook
# produced itself, never files from an untrusted source.
def _load_cached(path):
    """Unpickle and return the object stored at `path`."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)

new_coors = _load_cached('10px_best_res_coors.pickle')
point_coors = _load_cached('10px_best_point_coors.pickle')
scores = _load_cached('10px_best_scores.pickle')

In [37]:
import pickle

# Cache the tracking results to disk so the expensive inference loop need not
# be re-run. All objects are gathered *before* any file is opened: open(..., 'wb')
# truncates immediately, so an undefined name previously left an empty,
# corrupt pickle behind (and earlier files already overwritten).
# NOTE(review): `distance_score` is not defined by any visible cell -- it must
# already exist in the kernel state. TODO confirm its origin.
to_save = {
    '10px_best_res_coors.pickle': new_coors,
    '10px_best_point_coors.pickle': point_coors,
    '10px_best_scores.pickle': scores,
    '10px_best_loss.pickle': distance_score,
}
for path, obj in to_save.items():
    with open(path, 'wb') as f:
        pickle.dump(obj, f)