In [None]:
import numpy as np
import xarray as xr
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

class AE(nn.Module):
    """Decoder-only low-rank model with a convolutional stage on the spatial bases.

    A bias-free linear layer fed a constant input holds ``n_comps`` learnable
    images of ``(size + kernel_size - 1)`` pixels per side.  A 'valid'
    convolution followed by ``tanh`` turns them into ``2 * n_comps`` feature
    maps of ``size x size`` pixels, which are mixed by a learned
    ``(2 * n_comps, n_coeffs)`` coefficient matrix.

    Args:
        n_comps: number of learnable base images (the conv produces twice as
            many feature maps).
        n_coeffs: number of output rows (one per sample being reconstructed).
        size: side length, in pixels, of the reconstructed square images.
            Defaults to 400, the value previously hard-coded.
        kernel_size: side length of the square convolution kernel.
            Defaults to 5, the value previously hard-coded.
    """

    def __init__(self, n_comps, n_coeffs, size=400, kernel_size=5):
        super().__init__()
        self.n_coeffs = n_coeffs
        self.n_comps = n_comps
        self.size = size
        self.kernel_size = kernel_size
        # An unpadded convolution trims kernel_size - 1 pixels per axis, so the
        # learnable images are made that much larger than the output
        # (404 for the default 400 / 5 configuration).
        self.padded = size + kernel_size - 1
        self.D = nn.Linear(1, self.n_comps * self.padded * self.padded, bias=False)
        self.conv1 = nn.Conv2d(self.n_comps, 2 * self.n_comps, kernel_size=kernel_size, bias=False)
        self.coeffs = nn.Linear(1, 2 * self.n_comps * self.n_coeffs, bias=False)

    def forward(self, x):
        """Return the ``(n_coeffs, size * size)`` reconstruction.

        Args:
            x: constant driver input whose last dimension is 1.  The reshape
                below only works for a single sample (the notebook passes
                ``torch.ones(1)``).
        """
        base = self.D(x).reshape(1, self.n_comps, self.padded, self.padded)
        feats = torch.tanh(self.conv1(base)).reshape(2 * self.n_comps, self.size * self.size)
        coeffs = self.coeffs(x).view(2 * self.n_comps, self.n_coeffs)
        # out[j, i] = sum_k feats[k, i] * coeffs[k, j]
        return torch.einsum('ki,kj->ji', feats, coeffs)


def nan_mse_loss(output, target):
    """Mean squared error over the entries where ``target`` is not NaN.

    Args:
        output: predicted tensor, same shape as ``target``.
        target: reference tensor; NaN entries are excluded from the mean.

    Returns:
        Scalar tensor holding the mean squared difference over valid entries.
    """
    valid = ~torch.isnan(target)
    residual = output[valid] - target[valid]
    return (residual ** 2).mean()


ds = xr.open_dataset("/data/pca_act/000_clean.nc")
# Keep only the time steps where fewer than 66% of the 160000 (= 400 x 400)
# `nbart_blue` pixels are NaN.
ti_nan = (np.count_nonzero(np.isnan(ds.nbart_blue.values), axis=(1,2)))<.66*160000
ds = ds.isel(time=ti_nan)

# Stack every data variable along the first axis, scaling by 1e4.
# All bands are collected first and concatenated once: calling np.append inside
# the loop re-copies the whole growing stack for every band.  The empty float64
# seed is kept so the result dtype and the 400x400 shape check match the
# previous behaviour exactly.
bands = [ds[fname].values/1e4 for fname in ds]
stack = np.concatenate([np.empty((0,400,400)), *bands], axis=0)

# Flatten each image to a row: (n_images, 400*400).
stack = stack.reshape(stack.shape[0], -1)

ncomps = 12
ncoeffs = stack.shape[0]

# Constant driver input for the bias-free linear layers of the model.
# NOTE: this name shadows the builtin `input`; it is kept because a later cell
# calls `net(input)`.
input = torch.ones(1, device=device)
# Per-pixel mean over all stacked images (NaNs ignored); the model fits the anomalies.
tmean = np.nanmean(stack, axis=0)
target = torch.from_numpy(stack-tmean).float().to(device)

net = AE(ncomps, ncoeffs)
net.to(device)
optimizer = optim.AdamW(net.parameters(), lr=0.1)

epochs = 1500
# training loop:
for it in range(epochs):
    output = net(input)

    loss = nan_mse_loss(output, target)# + sparsity

    optimizer.zero_grad()   # zero the gradient buffers
    loss.backward()
    optimizer.step()    # does the update

    prev_loss = loss.item()

    if it % 100 == 0:
        # Both logged columns are the same pre-step loss; the value is reused
        # rather than recomputed, and the two-column format is kept so new
        # logs stay comparable with earlier runs.
        print(it, prev_loss, prev_loss)

cuda:0
0 0.753046989440918 0.753046989440918
100 0.00030304433312267065 0.00030304433312267065
200 0.00018829345935955644 0.00018829345935955644
300 0.00013131002197042108 0.00013131002197042108
400 9.09754671738483e-05 9.09754671738483e-05
500 7.90017656981945e-05 7.90017656981945e-05
600 7.052533328533173e-05 7.052533328533173e-05
700 6.422973092412576e-05 6.422973092412576e-05
800 5.803992462460883e-05 5.803992462460883e-05
900 5.349063576431945e-05 5.349063576431945e-05
1000 4.939973587170243e-05 4.939973587170243e-05
1100 4.571126191876829e-05 4.571126191876829e-05
1200 4.285776594770141e-05 4.285776594770141e-05
1300 4.0531645936425775e-05 4.0531645936425775e-05
1400 3.898710201610811e-05 3.898710201610811e-05


In [83]:
from torch.autograd import grad

class AE(nn.Module):
    """Plain low-rank factorisation ``coeffs @ base`` with both factors learned.

    The factors are registered as ``nn.Parameter`` so that ``net.parameters()``,
    ``state_dict()``, ``net.to(...)`` and optimizers see them, and so that they
    are leaf tensors on every device.  Previously they were plain tensor
    attributes, and on CUDA the ``.to(device)`` copy was a non-leaf tensor.

    Args:
        n_comps: rank of the factorisation (number of base images).
        n_coeffs: number of rows of the reconstruction (one per sample).
        n_pixels: number of columns of ``base``.  Defaults to ``400 * 400``,
            the value previously hard-coded.
        device: device on which the factors are created.  When omitted, CUDA
            device 0 is used if available, else the CPU -- the same rule the
            notebook uses to define its global ``device``.
    """

    def __init__(self, n_comps, n_coeffs, n_pixels=400 * 400, device=None):
        super().__init__()
        if device is None:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Initialise on the CPU and then move, exactly as before, so the values
        # drawn for a given seed do not change.
        self.base = nn.Parameter(
            torch.nn.init.xavier_uniform_(torch.empty(n_comps, n_pixels)).to(device)
        )
        self.coeffs = nn.Parameter(
            torch.nn.init.xavier_uniform_(torch.empty(n_coeffs, n_comps)).to(device)
        )

    def forward(self):
        """Return the ``(n_coeffs, n_pixels)`` reconstruction ``coeffs @ base``."""
        return torch.mm(self.coeffs, self.base)
    
net = AE(ncomps, ncoeffs)
# No net.to(device) here: AE.__init__ already places its tensors on `device`.
target = torch.from_numpy(stack - tmean).float().to(device)

# Alternating gradient descent with a fixed step of 10000: update `base` with
# `coeffs` held fixed, re-evaluate the loss, then update `coeffs`.
for i in range(100):
    loss = nan_mse_loss(net(), target)
    g_base = grad(loss, net.base)[0]
    with torch.no_grad():
        net.base.sub_(10000 * g_base)

    # The loss is recomputed so the coeffs gradient sees the updated base.
    loss = nan_mse_loss(net(), target)
    g_coeffs = grad(loss, net.coeffs)[0]
    with torch.no_grad():
        net.coeffs.sub_(10000 * g_coeffs)

    # Report every 10th iteration; this is the loss measured between the two updates.
    if not i % 10:
        print(loss.item())

0.006083517801016569
0.006080227438360453
0.005994681268930435
0.004141979850828648
0.0006812543724663556
0.0005685978685505688
0.0005585349281318486
0.0005523429717868567
0.000541805406101048
0.0005210671224631369


In [78]:
# Further alternating gradient-descent passes on the existing `net`, this time
# with a step of 100: update `base`, re-evaluate, then update `coeffs`.
for i in range(100):
    output = net()
    loss = nan_mse_loss(output, target)
    g_base = grad(loss, net.base)[0]
    with torch.no_grad():
        net.base.sub_(100 * g_base)

    # A fresh forward pass so the coeffs gradient sees the updated base.
    output = net()
    loss = nan_mse_loss(output, target)
    g_coeffs = grad(loss, net.coeffs)[0]
    with torch.no_grad():
        net.coeffs.sub_(100 * g_coeffs)

    # Report every 10th iteration; this is the loss measured between the two updates.
    if not i % 10:
        print(loss.item())

0.0008309215772897005
0.0008298605098389089
0.0008291483391076326
0.0008285042131319642
0.000827873358502984
0.0008272448903881013
0.0008266165386885405
0.0008259877213276923


KeyboardInterrupt: 

In [68]:
class AE2(nn.Module):
    """Low-rank factorisation whose two factors are bias-free linear layers.

    Both layers are driven by a constant input, so their weights act as the
    learnable factors: ``D`` holds ``n_comps`` base images of ``n_pixels``
    values each and ``coeffs`` holds an ``(n_comps, n_coeffs)`` mixing matrix.

    Args:
        n_comps: rank of the factorisation (number of base images).
        n_coeffs: number of rows of the reconstruction (one per sample).
        n_pixels: number of values per base image.  Defaults to ``400 * 400``,
            the value previously hard-coded.
    """

    def __init__(self, n_comps, n_coeffs, n_pixels=400 * 400):
        super().__init__()
        self.n_coeffs = n_coeffs
        self.n_comps = n_comps
        self.n_pixels = n_pixels
        self.D = nn.Linear(1, self.n_comps * self.n_pixels, bias=False)
        self.coeffs = nn.Linear(1, self.n_comps * self.n_coeffs, bias=False)

    def forward(self, x):
        """Return the ``(n_coeffs, n_pixels)`` reconstruction.

        Args:
            x: constant driver input whose last dimension is 1.  The views
                below only work for a single sample (the notebook passes
                ``torch.ones(1)``).
        """
        base = self.D(x).view(self.n_comps, self.n_pixels)
        coeffs = self.coeffs(x).view(self.n_comps, self.n_coeffs)
        # out[j, i] = sum_k base[k, i] * coeffs[k, j]
        return torch.einsum('ki,kj->ji', base, coeffs)


def nan_mse_loss(output, target):
    """Mean squared error over the entries where ``target`` is not NaN.

    Args:
        output: predicted tensor, same shape as ``target``.
        target: reference tensor; NaN entries are excluded from the mean.

    Returns:
        Scalar tensor holding the mean squared difference over valid entries.
    """
    valid = ~torch.isnan(target)
    residual = output[valid] - target[valid]
    return (residual ** 2).mean()


# Fit the plain factorisation model (AE2) with AdamW.
net = AE2(ncomps, ncoeffs)
net.to(device)
# Alternative left by the author: optim.Adam(net.parameters(), lr=0.1)
optimizer = optim.AdamW(net.parameters(), lr=0.1)

# Mean-removed data as float32 on the training device.
target = torch.from_numpy(stack - tmean).float().to(device)

for i in range(1500):
    # Clear gradients left over from the previous iteration.
    optimizer.zero_grad()

    # `input` is the constant driver tensor defined in an earlier cell.
    output = net(input)
    loss = nan_mse_loss(output, target)  # + sparsity

    loss.backward()
    optimizer.step()

    # Log the pre-step loss every 100 iterations.
    if not i % 100:
        print(i, loss.item())

0 1.353246808052063
10 0.014761357568204403
20 0.024680428206920624
30 0.011752721853554249
40 0.004274691455066204
50 0.001417412655428052
60 0.0008546735625714064
70 0.0006324619753286242
80 0.0005126555915921926
90 0.0003928365185856819
100 0.0003020345466211438
110 0.0002592696400824934
120 0.00023512639745604247
130 0.00021287819254212081
140 0.00019151439482811838
150 0.00017375645984429866
160 0.00015989817620720714
170 0.00014828499115537852
180 0.00013794457481708378
190 0.00012905357289128006
200 0.0001218087927554734
210 0.00011587271001189947
220 0.00011068936146330088
230 0.0001058629568433389
240 0.00010125763219548389
250 9.692672028904781e-05
260 9.299091470893472e-05
270 8.954280201578513e-05
280 8.660915773361921e-05
290 8.415898628300056e-05
300 8.212956163333729e-05
310 8.044940477702767e-05
320 7.905159873189405e-05
330 7.78788817115128e-05
340 7.688413461437449e-05
350 7.602925325045362e-05
360 7.528356945840642e-05
370 7.46224686736241e-05
380 7.402637129416689e-