## Gulf of Mexico vortex

### Importing packages

In [1]:
import numpy as np

import pyro.contrib.gp as gp
from SBIRR.SDE_solver import solve_sde_RK
from SBIRR.utils import plot_trajectories_2
import matplotlib.pyplot as plt
from SBIRR.Schrodinger import gpIPFP, SchrodingerTrain, gpdrift, nndrift
from SBIRR.GP import spatMultitaskGPModel
from SBIRR.gradient_field_NN import train_nn_gradient, GradFieldMLP2
import torch
import torch.nn as nn
import sys
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import math
import time
import csv

from SBIRR.experiments_helper import get_settings # in ../package/SBIRR/experiments_helper.py

## set up and load data

In [5]:
task_name = "GoMvortex" # or "GoMvortexsmall" 
# Load every snapshot for the chosen task and move each one to `device`.
# NOTE: allow_pickle=True lets np.load unpickle arbitrary Python objects, which can
# execute code — only point this at data files you trust.
Xs_val = np.load(f"./data/{task_name}_data.npy", allow_pickle=True)
Xs_val = [torch.tensor(X).to(device) for X in Xs_val]
n_snap = len(Xs_val)
# Train on the even-indexed snapshots; the odd-indexed ones are kept for validation.
# Slicing gives the same result as indexing 0, 2, ..., for an odd number of snapshots
# and, unlike range(int(n_snap/2)+1), does not run past the end when n_snap is even.
Xs = Xs_val[::2]
N_steps = len(Xs)
# take SDE settings and reference family
sigma, dt, N, dts, oursnndrift, vanillanndrift = get_settings(task_name, N_steps)

In [6]:
# Seed PyTorch's global random number generator (NumPy's RNG is not seeded here).
torch.manual_seed(0)

<torch._C.Generator at 0x7fd2aac64b30>

### running our method

In [12]:
def gpIPFPmaker(ref_drift = None, N = N, device = device):
    """Factory that builds a `gpIPFP` object for a given reference drift.

    `N` and `device` default to the module-level values that existed when this
    function was defined; all three arguments are forwarded to `gpIPFP` unchanged.
    """
    solver = gpIPFP(ref_drift=ref_drift, N=N, device=device)
    return solver
# Trainer built from the training snapshots `Xs`, together with `dts` and `sigma`.
oursSchrodinger_train = SchrodingerTrain(Xs, dts, sigma)

# IRR: fit using our reference family (`oursnndrift`) and record wall-clock time.
# NOTE(review): `ipfpiter` / `iteration` are presumably IPFP sweeps per round and
# number of refinement rounds — confirm against SchrodingerTrain.iter_drift_fit.
start_time = time.time()
_, backnforth = oursSchrodinger_train.iter_drift_fit(
    gpIPFPmaker,
    oursnndrift,
    ipfpiter=10,
    iteration=10,
)
end_time = time.time()
elapsed_time_irr = end_time - start_time

### vanilla-SB

In [13]:
# Naive (vanilla-SB) baseline: same call as the IRR run, but with `vanillanndrift`
# and a single iteration; wall-clock time is recorded separately.
# NOTE(review): this reuses `oursSchrodinger_train` from the IRR cell — confirm that
# iter_drift_fit leaves no state behind that could leak into the baseline.
start_time = time.time()
_, backnforth_vSB = oursSchrodinger_train.iter_drift_fit(
    gpIPFPmaker,
    vanillanndrift,
    ipfpiter=10,
    iteration=1,
)
end_time = time.time()
elapsed_time_naive = end_time - start_time

### earth mover distance

In [None]:
from TrajectoryNet.optimal_transport.emd import earth_mover_distance


def _held_out_emd(trajectories, snapshot_idx):
    """Earth mover distance between snapshot `snapshot_idx` and `trajectories` at the matching step.

    The number of solver steps between consecutive snapshots is derived from the
    time axis (axis 1) of `trajectories` itself, so each method is indexed with
    its own trajectory length rather than another array's.
    """
    steps_per_snapshot = (trajectories.shape[1] - 1) // (len(Xs_val) - 1)
    # The last column is dropped before comparison — presumably a time coordinate; confirm.
    return earth_mover_distance(
        Xs_val[snapshot_idx],
        trajectories[:, steps_per_snapshot * snapshot_idx, :-1],
    )


# Evaluate on an odd-indexed snapshot, i.e. one that was not part of the training set `Xs`.
i = 1
[_held_out_emd(backnforth, 2 * i + 1), _held_out_emd(backnforth_vSB, 2 * i + 1)]

### visualizing

Visualizing trajectories from our method (IRR)

In [None]:
import matplotlib.lines  # Line2D is used for the legend; only matplotlib.pyplot is imported at the top

# NOTE(review): `cmap` is not defined anywhere in the visible notebook; viridis is an
# assumed choice — swap in the intended colormap if it differs.
cmap = plt.get_cmap("viridis")

# Inferred trajectories from our method, one faint line each.
# NOTE(review): assumes `backnforth` is something matplotlib can plot directly (CPU array/tensor).
for i in range(backnforth.shape[0]): # using this for range to plot the same amount of trajectories as the baseline
    plt.plot(backnforth[i,:,0], backnforth[i,:,1], color="black", alpha = 0.1, lw=.8)

# `Xs_val` is a Python list of tensors, so its length comes from len(), not `.shape`.
n_snapshots = len(Xs_val)
for i, snapshot in enumerate(Xs_val):
    color = cmap(i / n_snapshots)
    # Snapshots were moved to `device` on load; bring them back to host memory for plotting.
    pts = snapshot.detach().cpu().numpy()
    # Odd-indexed snapshots are validation data ('x'); even-indexed ones are training data ('o').
    marker = 'x' if i % 2 else 'o'
    plt.scatter(pts[:,0], pts[:,1], color=color, s=3, marker=marker, edgecolors=color, lw=0.5)
# Add legend
#cax = fig.add_axes([0.91,0.11,0.01,0.77])
#fig.colorbar(cm.ScalarMappable(cmap=cmap, norm=matplotlib.colors.Normalize(vmin=0, vmax=10)), label="Time", cax=cax)

# Add legend markers
training_data = matplotlib.lines.Line2D([], [], color='black', marker='o', linestyle='None',
                          markersize=3, label='Training data')
val_data = matplotlib.lines.Line2D([], [], color='black', marker='x', linestyle='None',
                          markersize=3, label='Validation data')
# Attach the proxy handles; without this call they are never drawn.
plt.legend(handles=[training_data, val_data])


In [None]:
import matplotlib.lines  # Line2D is used for the legend; only matplotlib.pyplot is imported at the top

# NOTE(review): `cmap` is not defined anywhere in the visible notebook; viridis is an
# assumed choice — swap in the intended colormap if it differs.
cmap = plt.get_cmap("viridis")

# Inferred trajectories from the vanilla-SB baseline, one faint line each.
# NOTE(review): assumes `backnforth_vSB` is something matplotlib can plot directly (CPU array/tensor).
for i in range(backnforth_vSB.shape[0]): # using this for range to plot the same amount of trajectories as the baseline
    plt.plot(backnforth_vSB[i,:,0], backnforth_vSB[i,:,1], color="black", alpha = 0.1, lw=.8)

# `Xs_val` is a Python list of tensors, so its length comes from len(), not `.shape`.
n_snapshots = len(Xs_val)
for i, snapshot in enumerate(Xs_val):
    color = cmap(i / n_snapshots)
    # Snapshots were moved to `device` on load; bring them back to host memory for plotting.
    pts = snapshot.detach().cpu().numpy()
    # Odd-indexed snapshots are validation data ('x'); even-indexed ones are training data ('o').
    marker = 'x' if i % 2 else 'o'
    plt.scatter(pts[:,0], pts[:,1], color=color, s=3, marker=marker, edgecolors=color, lw=0.5)
# Add legend
#cax = fig.add_axes([0.91,0.11,0.01,0.77])
#fig.colorbar(cm.ScalarMappable(cmap=cmap, norm=matplotlib.colors.Normalize(vmin=0, vmax=10)), label="Time", cax=cax)

# Add legend markers
training_data = matplotlib.lines.Line2D([], [], color='black', marker='o', linestyle='None',
                          markersize=3, label='Training data')
val_data = matplotlib.lines.Line2D([], [], color='black', marker='x', linestyle='None',
                          markersize=3, label='Validation data')
# Attach the proxy handles; without this call they are never drawn.
plt.legend(handles=[training_data, val_data])