I will attempt my filter for the first time. 

In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt


In [83]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential( # like the Composition layer you built
            nn.Conv2d(2, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 2, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [84]:
from viz import amv_analysis as aa
from viz import dataframe_calculators as dfc 
import datetime
import pickle 
import metpy
import numpy as np
import torch
dict_path = '../data/interim/dictionaries/dataframes.pkl'
dataframes_dict = pickle.load(open(dict_path, 'rb'))


start_date=datetime.datetime(2006,7,1,6,0,0,0)
end_date=datetime.datetime(2006,7,1,7,0,0,0)
df0 = aa.df_concatenator(dataframes_dict, start_date, end_date, False, True)

#df=df.dropna()
df=df0.loc[end_date:end_date]

lon = df.pivot('y', 'x', 'lon').values
lat = df.pivot('y', 'x', 'lat').values
u=df.pivot('y', 'x', 'u_scaled_approx').values
v=df.pivot('y', 'x', 'v_scaled_approx').values
u=np.nan_to_num(u)
v=np.nan_to_num(v)

df=df0.loc[start_date:start_date]
u0=df.pivot('y', 'x', 'u_scaled_approx').values
v0=df.pivot('y', 'x', 'v_scaled_approx').values
u0=np.nan_to_num(u0)
v0=np.nan_to_num(v0)

dx, dy = metpy.calc.lat_lon_grid_deltas(lon, lat)
dx=dx.magnitude 
mask = np.isnan(dx)
dx[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), dx[~mask])


dy=dy.magnitude 
mask = np.isnan(dy)
dy[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), dy[~mask])


mask = dx<1e-15
dx[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), dx[~mask])

mask = dy<1e-15
dy[mask] = np.interp(np.flatnonzero(mask), np.flatnonzero(~mask), dy[~mask])

dx=torch.from_numpy(dx)
dy=torch.from_numpy(dy)

hello world
concatenating dataframes for all dates for further analysis:
2006-07-01 06:00:00
2006-07-01 06:00:00
2006-07-01 07:00:00
2006-07-01 07:00:00


In [85]:

wind0=torch.rand([1,2,lon.shape[0],lon.shape[1]])
wind=torch.rand([1,2,lon.shape[0],lon.shape[1]])
wind0[0,0,:,:]=torch.from_numpy(u0)
wind0[0,1,:,:]=torch.from_numpy(v0)
wind[0,0,:,:]=torch.from_numpy(u)
wind[0,1,:,:]=torch.from_numpy(v)
print(wind.shape)
dt=3600

torch.Size([1, 2, 360, 720])


In [86]:
def diff_central_x(dx, f):
    diff=torch.zeros((f.shape[0],f.shape[1]-2))
  
    for i,_ in enumerate(f): 
        x_diff = dx[i,1:] + dx[i,:-1]
        f_diff = f[i,2:] - f[i,:-2]
        diff[i,:]= f_diff/x_diff
    difft=torch.zeros(f.shape)
    difft[:,1:-1]=diff
    return difft


def diff_central_y(dy, f):
    diff=torch.zeros((f.shape[0]-2,f.shape[1]))
  
    for i,_ in enumerate(f.T): 
        y_diff = dy[1:,i] + dy[:-1,i]
        f_diff = f[2:,i] - f[:-2,i]
        diff[:,i]= f_diff/y_diff
    difft=torch.zeros(f.shape)
    difft[1:-1,:]=diff
    return difft

def omega(x,y,u,v):
    omega=diff_central_x(x,v)-diff_central_y(y,u)
    return omega

def advection(x,y,u,v,q):
    adv=diff_central_x(x,(q*u))+diff_central_y(y,(q*v))
    return adv



def loss_physics(wind0,wind,model,dt):
    wind_cnn=wind+model.forward(wind)
    u=wind_cnn[:,0,:,:]
    v=wind_cnn[:,1,:,:]
    ug=wind[:,0,:,:]
    vg=wind[:,1,:,:]
    u=torch.squeeze(u)
    v=torch.squeeze(v)
    
    u0=wind0[:,0,:,:]
    v0=wind0[:,1,:,:]
    u0=torch.squeeze(u0)
    v0=torch.squeeze(v0)
    vor0 = omega(dx,dy,u0,v0)
    vor = omega(dx,dy,u,v)
    
    vor_adv = advection(dx,dy,u0,v0,vor0)
    print(vor_adv.mean())
    reg=torch.sqrt((u-ug)**2+(v-vg)**2)
    func=vor-vor0-(vor_adv)*dt+reg
    return abs(func.mean())

In [87]:
def train(model, num_epochs=5, batch_size=1, learning_rate=1e-3):
    torch.manual_seed(42)
    criterion = nn.MSELoss() # mean square error loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate, 
                                 weight_decay=1e-5) # <--

    outputs = []
    
    for epoch in range(num_epochs):
        loss = loss_physics(wind0,wind,model,dt)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        print('Epoch:{}, Loss:{:.4f}'.format(epoch+1, float(loss)))
        #outputs.append((epoch, img, recon),)
    #return outputs

In [88]:
model = Autoencoder()
max_epochs = 10
outputs = train(model, num_epochs=max_epochs)

tensor(5.0250e-12)
Epoch:1, Loss:1.0253
tensor(5.0250e-12)
Epoch:2, Loss:0.9523
tensor(5.0250e-12)
Epoch:3, Loss:0.8325
tensor(5.0250e-12)
Epoch:4, Loss:0.7343
tensor(5.0250e-12)
Epoch:5, Loss:0.6578
tensor(5.0250e-12)
Epoch:6, Loss:0.6263
tensor(5.0250e-12)
Epoch:7, Loss:0.5900
tensor(5.0250e-12)
Epoch:8, Loss:0.5844
tensor(5.0250e-12)
Epoch:9, Loss:0.5819
tensor(5.0250e-12)
Epoch:10, Loss:0.5805


In [89]:
model.forward(wind)

tensor([[[[4.2498e-01, 3.9790e-01, 2.8350e-01,  ..., 4.4009e-01,
           3.7090e-01, 4.5737e-01],
          [2.8784e-01, 3.6262e-01, 2.1448e-01,  ..., 5.6864e-01,
           3.3247e-01, 4.7843e-01],
          [3.3643e-01, 3.0184e-01, 1.8331e-01,  ..., 4.3169e-01,
           3.4845e-01, 4.9109e-01],
          ...,
          [1.7107e-01, 1.4779e-02, 1.1098e-01,  ..., 2.2301e-01,
           4.2695e-01, 3.1059e-01],
          [3.7019e-01, 4.2646e-01, 5.2093e-02,  ..., 5.9807e-01,
           2.0361e-01, 5.5275e-01],
          [4.4709e-01, 1.7522e-01, 1.5941e-01,  ..., 2.5414e-01,
           3.7213e-01, 5.4528e-01]],

         [[4.7283e-01, 4.5406e-01, 4.3002e-01,  ..., 4.7897e-01,
           4.0839e-01, 4.8283e-01],
          [4.0792e-01, 1.2313e-01, 5.4220e-01,  ..., 1.6317e-01,
           4.4263e-01, 3.3837e-01],
          [4.3434e-01, 3.8571e-01, 2.3727e-01,  ..., 3.6027e-01,
           3.8962e-01, 5.0020e-01],
          ...,
          [3.4260e-01, 3.6782e-04, 4.5531e-02,  ..., 3.1958