I will attempt my filter for the first time. 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

In [2]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential( # like the Composition layer you built
            nn.Conv2d(2, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 2, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [3]:
def diff_central_x(x, f):
    diff=torch.zeros((x.shape[0],x.shape[1]-2))
  
    for i,_ in enumerate(f): 
        x_diff = x[i,2:] - x[i,:-2]
        f_diff = f[i,2:] - f[i,:-2]
        diff[i,:]= f_diff/x_diff
    difft=torch.zeros(x.shape)
    difft[:,1:-1]=diff
    return difft


def diff_central_y(y, f):
    diff=torch.zeros((x.shape[0]-2,x.shape[1]))
  
    for i,_ in enumerate(f.T): 
        y_diff = y[2:,i] - y[:-2,i]
        f_diff = f[2:,i] - f[:-2,i]
        diff[:,i]= f_diff/y_diff
    difft=torch.zeros(x.shape)
    difft[1:-1,:]=diff
    return difft

def omega(x,y,u,v):
    omega=diff_central_x(x,u)-diff_central_y(y,v)
    return omega

def advection(x,y,u,v,q):
    adv=diff_central_x(x,(q*u))+diff_central_y(y,(q*v))
    return adv



def loss_physics(wind0,wind,model,dt):
    wind_cnn=wind+model.forward(wind)
    u=wind_cnn[:,0,:,:]
    v=wind_cnn[:,1,:,:]
    u=torch.squeeze(u)
    v=torch.squeeze(v)
    
    u0=wind0[:,0,:,:]
    v0=wind0[:,1,:,:]
    u0=torch.squeeze(u0)
    v0=torch.squeeze(v0)
    vor0 = omega(x,y,u0,v0)
    vor = omega(x,y,u,v)
    
    vor_adv = advection(x,y,u0,v0,vor0)
    print(vor_adv.mean())
    reg=torch.sqrt((u-u0)**2+(v-v0)**2)
    func=vor-vor0-(vor_adv)*dt+reg
    return abs(func.mean())

In [8]:
x=torch.rand([640*2,640*2])
y=torch.rand([640*2,640*2])
wind0=torch.rand([1,2,640*2,640*2])
wind=torch.rand([1,2,640*2,640*2])
dt=3600

In [5]:
def train(model, num_epochs=5, batch_size=1, learning_rate=1e-3):
    torch.manual_seed(42)
    criterion = nn.MSELoss() # mean square error loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate, 
                                 weight_decay=1e-5) # <--

    outputs = []
    
    for epoch in range(num_epochs):
        loss = loss_physics(wind0,wind,model,dt)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        print('Epoch:{}, Loss:{:.4f}'.format(epoch+1, float(loss)))
        #outputs.append((epoch, img, recon),)
    #return outputs

In [9]:
model = Autoencoder()
max_epochs = 10
outputs = train(model, num_epochs=max_epochs)

tensor(nan)
Epoch:1, Loss:nan
tensor(nan)


KeyboardInterrupt: 

In [7]:
model.forward(wind)

tensor([[[[0.4396, 0.3880, 0.4309,  ..., 0.4404, 0.4448, 0.4424],
          [0.3610, 0.3762, 0.4376,  ..., 0.4170, 0.4555, 0.5002],
          [0.3064, 0.4183, 0.5026,  ..., 0.4745, 0.4523, 0.4399],
          ...,
          [0.4437, 0.3390, 0.3876,  ..., 0.4283, 0.4687, 0.5027],
          [0.4106, 0.3891, 0.4320,  ..., 0.4711, 0.4370, 0.4491],
          [0.4473, 0.4509, 0.5227,  ..., 0.4565, 0.4625, 0.4658]],

         [[0.4891, 0.5350, 0.4642,  ..., 0.5096, 0.4822, 0.5032],
          [0.5327, 0.3707, 0.5835,  ..., 0.4489, 0.4928, 0.5172],
          [0.4450, 0.4923, 0.4487,  ..., 0.5069, 0.4886, 0.5116],
          ...,
          [0.5182, 0.4341, 0.5645,  ..., 0.4932, 0.5014, 0.4999],
          [0.4886, 0.5036, 0.4380,  ..., 0.5049, 0.4943, 0.5229],
          [0.5114, 0.4511, 0.5310,  ..., 0.4960, 0.5096, 0.5044]]]],
       grad_fn=<SigmoidBackward>)