In [1]:
import torch.nn.functional as F
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from tqdm.notebook import tqdm

sys.path.append("..")

from gisalgo import *
from helpers import *

import plotly.express as px
import plotly.graph_objects as go
import plotly.figure_factory as ff
import matplotlib as mpl
import datetime
from math import atan, pi
from scipy.optimize import minimize
import pickle
%matplotlib inline

device = 'cpu'
NAN_VAL = -100
WINDOW_SIZE = (15, 15)
VICINITY_SIZE = (40, 40)
dX = 500
dY = 500
STEP = 10
X0 = 190
Y0 = 940
dT = 22810
EPOCHS = 10

In [2]:
b0, data1 = parse('20060504_072852_NOAA_12.m.pro')
data1 = data1.astype(float)
data1[data1 < 0] = -100

In [3]:
b0, data2 = parse('20060504_125118_NOAA_17.m.pro')
data2 = data2.astype(float)
data2[data2 < 0] = -100

In [4]:
point_coors = generate_points(data1, X0, Y0, dX, dY, STEP, WINDOW_SIZE, NAN_VAL)

In [5]:
data1 = torch.tensor(data1).to(device)
data2 = torch.tensor(data2).to(device)

In [6]:
def swish(x):
    return x*torch.sigmoid(x)

class Net(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.fc2 = nn.Linear(in_features, out_features)
        self.fc_out = nn.Linear(out_features*3, 1)
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x: torch.FloatTensor, y: torch.FloatTensor) -> torch.FloatTensor:
        x = x.reshape(1, -1).float()*torch.ones(y.shape)
        y = y.float()
        x = swish(self.fc1(x))
        y = swish(self.fc2(y))
        xy = torch.cat((x, y, torch.abs(x-y)), 1)
        xy = self.fc_out(xy)
        return xy.view(-1)

In [7]:
metric = Net(31**2, 512)

In [8]:
with open('metric.pickle', 'wb') as f:
    pickle.dump(metric, f)

In [9]:
point_coors = generate_points(data1, X0, Y0, dX, dY, STEP, WINDOW_SIZE, NAN_VAL)

In [10]:
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(metric.parameters(), lr=1e-3)
EPOCHS = 10

In [11]:
for epoch in tqdm(range(EPOCHS)):
    losses = []
    y_pred = None
    y_true = None
    
    for i, point_coor in enumerate(tqdm(point_coors)):
        # Forward pass
        idx_forward, scores_forward = backforward_find(data1, data2, point_coor, 
                              WINDOW_SIZE, VICINITY_SIZE, metric,
                              device)
        best_forward_point = idx_forward[torch.argmax(scores_forward)]

        # Backward pass
        idx_backward, scores_backward = backforward_find(data2, data1, best_forward_point, 
                              WINDOW_SIZE, VICINITY_SIZE, metric,
                              device)

        try:
            ground_truth = idx_backward.index(point_coor)
            if y_pred is None:
                y_pred = scores_backward
                y_true = F.one_hot(torch.tensor([ground_truth]), len(idx_backward)).view(-1).float()
            else:
                y_pred = torch.cat((y_pred, scores_backward), 0)
                y_true = torch.cat((y_true, F.one_hot(torch.tensor([ground_truth]), len(idx_backward)).view(-1).float()), 0)
        except:
            pass
        if (i+1)%32 == 0 or i+1 == len(point_coors) and not y_pred is None:
            loss = loss_fn(y_pred, y_true)
            loss.backward()
            losses.append(loss.detach().item())
            optimizer.step()
            optimizer.zero_grad()
            y_pred = None
            y_true = None
            print(losses[-1])
    print(sum(losses)/len(losses))

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/2401 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [32]:
y_true.shape

AttributeError: 'NoneType' object has no attribute 'shape'

In [13]:
tmp_point_coors = []
new_coors = []
mask = []
losses = []
velocities = []

for point_coor in tqdm(point_coors):
    try:
        idx, velocity, lss, msk = inference(data1, data2, b0, dT, point_coor, 
                      WINDOW_SIZE, VICINITY_SIZE, metric,
                      device, 'max', 'pix',  coefs=None, bef=-1, sf=.5)
        tmp_point_coors.append(point_coor)
        new_coors.append(idx)
        mask.append(msk)
        losses.append(lss)
        velocities.append(velocity)
    except AssertionError as e:
        continue

  0%|          | 0/2401 [00:00<?, ?it/s]

TypeError: forward() takes 3 positional arguments but 4 were given

In [None]:
metric