In [86]:
import numpy as np
import xarray as xr
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

class AE(nn.Module):
    def __init__(self, n_comps, n_coeffs):
        super(AE, self).__init__()
        self.n_coeffs = n_coeffs
        self.n_comps = n_comps
        self.D = nn.Linear(1, self.n_comps*404*404, bias=False)
        self.conv1 = nn.Conv2d(self.n_comps, 2*self.n_comps, kernel_size=5, bias=False)
        self.coeffs = nn.Linear(1, 2*self.n_comps*self.n_coeffs, bias=False)

    def forward(self, x):
        base = self.D(x)
        conv1 =  torch.tanh(self.conv1(base.reshape(1, self.n_comps, 404, 404)))
        coeffs = self.coeffs(x).view(2*self.n_comps, self.n_coeffs)
        return torch.einsum('ki,kj->ji', conv1.view(2*self.n_comps,400*400), coeffs)


def nan_mse_loss(output, target):
    loss = torch.mean((output[target == target] - target[target == target])**2)
    return loss


ds = xr.open_dataset("/data/pca_act/000_clean.nc")
ti_nan = (np.count_nonzero(np.isnan(ds.nbart_blue.values), axis=(1,2)))<.66*160000
ds = ds.isel(time=ti_nan)

stack = np.empty((0,400,400))
for fname in ds:
    band = ds[fname].values/1e4
    stack = np.append(stack, band, axis=0)

stack = stack.reshape(stack.shape[0], -1)

ncomps = 12
ncoeffs = stack.shape[0]

input = torch.ones(1, device=device)
tmean = np.nanmean(stack, axis=0)
target = torch.from_numpy(stack-tmean).float().to(device)

net = AE(ncomps, ncoeffs)
net.to(device)
optimizer = optim.AdamW(net.parameters(), lr=0.1)

epochs = 3000
# training loop:
for it in range(epochs):
    output = net(input)

    loss = nan_mse_loss(output, target)# + sparsity

    optimizer.zero_grad()   # zero the gradient buffers
    loss.backward()
    optimizer.step()    # does the update

    prev_loss = loss.item()

    if it % 100 == 0:
        print(it, loss.item(), nan_mse_loss(output, target).item())

cuda:0
0 0.734163224697113 0.734163224697113
100 0.0005438955849967897 0.0005438955849967897
200 0.00016100543143693358 0.00016100543143693358
300 0.00012727727880701423 0.00012727727880701423
400 0.00010232954809907824 0.00010232954809907824
500 8.325299131684005e-05 8.325299131684005e-05
600 7.523925160057843e-05 7.523925160057843e-05
700 6.876718543935567e-05 6.876718543935567e-05
800 6.202853546710685e-05 6.202853546710685e-05
900 5.60833650524728e-05 5.60833650524728e-05
1000 5.081657218397595e-05 5.081657218397595e-05
1100 4.695257302955724e-05 4.695257302955724e-05
1200 4.4398955651558936e-05 4.4398955651558936e-05
1300 4.26465121563524e-05 4.26465121563524e-05
1400 4.1008766856975853e-05 4.1008766856975853e-05
1500 3.925561759388074e-05 3.925561759388074e-05
1600 3.7734080251539126e-05 3.7734080251539126e-05
1700 7.869963155826554e-05 7.869963155826554e-05
1800 6.348744500428438e-05 6.348744500428438e-05
1900 4.4255804823478684e-05 4.4255804823478684e-05
2000 3.359564652782865e

KeyboardInterrupt: 

In [95]:
from torch.autograd import grad

class AE(nn.Module):
    def __init__(self, n_comps, n_coeffs):
        super(AE, self).__init__()
        self.base = torch.nn.Parameter(torch.empty(n_comps, 400*400, requires_grad=True))#.to(device)
        torch.nn.init.xavier_uniform_(self.base)
        self.coeffs = torch.nn.Parameter(torch.empty(n_coeffs, n_comps, requires_grad=True))#.to(device)
        torch.nn.init.xavier_uniform_(self.coeffs)
    def forward(self):
        return torch.mm(self.coeffs, self.base)
    
net = AE(ncomps, ncoeffs)
list(net.parameters())

[Parameter containing:
 tensor([[ 0.0003, -0.0010, -0.0033,  ...,  0.0015,  0.0048, -0.0022],
         [-0.0059,  0.0052,  0.0010,  ...,  0.0032,  0.0010,  0.0013],
         [-0.0056, -0.0002, -0.0011,  ...,  0.0025, -0.0038,  0.0007],
         ...,
         [ 0.0046, -0.0006, -0.0018,  ..., -0.0039,  0.0041,  0.0043],
         [ 0.0039, -0.0019,  0.0059,  ..., -0.0042,  0.0024,  0.0021],
         [-0.0004,  0.0035,  0.0053,  ..., -0.0056, -0.0027, -0.0057]],
        requires_grad=True),
 Parameter containing:
 tensor([[-0.0405,  0.0590,  0.0506,  ...,  0.0668, -0.0500,  0.0060],
         [-0.0047, -0.0524, -0.0117,  ...,  0.0526,  0.0260,  0.0896],
         [-0.0240, -0.0182,  0.0787,  ...,  0.0547,  0.0570, -0.0848],
         ...,
         [-0.0180, -0.0366, -0.0491,  ...,  0.0141,  0.0815,  0.0371],
         [ 0.1068,  0.0304, -0.0873,  ...,  0.1058,  0.0052,  0.0790],
         [ 0.0033,  0.0068,  0.0543,  ...,  0.0299, -0.0566, -0.0940]],
        requires_grad=True)]

In [89]:
from torch.autograd import grad

class AE(nn.Module):
    def __init__(self, n_comps, n_coeffs):
        super(AE, self).__init__()
        self.base = torch.nn.Parameter(torch.empty(n_comps, 400*400, requires_grad=True)).to(device)
        torch.nn.init.xavier_uniform_(self.base)
        self.coeffs = torch.nn.Parameter(torch.empty(n_coeffs, n_comps, requires_grad=True)).to(device)
        torch.nn.init.xavier_uniform_(self.coeffs)
    def forward(self):
        return torch.mm(self.coeffs, self.base)
    
net = AE(ncomps, ncoeffs)
#net.to(device)
target = torch.from_numpy(stack-tmean).float().to(device)

optimizer = optim.AdamW(net.parameters(), lr=0.1)
#optimizer = optim.Adam(net.parameters(), lr=0.1)


for i in range(3000):
    output = net(input)
    loss = nan_mse_loss(output, target)# + sparsity

    optimizer.zero_grad()   # zero the gradient buffers
    loss.backward()
    optimizer.step()    # does the update

    if i % 100 == 0:
        print(i, loss.item())

"""
for i in range(100):
    loss = nan_mse_loss(net(), target)
    g_base, = grad(loss, net.base)
    with torch.no_grad():
        net.base -= 10000*g_base
        
    loss = nan_mse_loss(net(), target)  
    g_coeffs, = grad(loss, net.coeffs)
    with torch.no_grad():
        net.coeffs -= 10000*g_coeffs
        
    if i%10 == 0:
        print(loss.item())
"""

ValueError: optimizer got an empty parameter list

In [78]:
for i in range(100):
    output = net()
    loss = nan_mse_loss(output, target)
    g_base, = grad(loss, net.base)
    with torch.no_grad():
        net.base -= 100*g_base
        
    output = net()
    loss = nan_mse_loss(output, target)  
    g_coeffs, = grad(loss, net.coeffs)
    with torch.no_grad():
        net.coeffs -= 100*g_coeffs
        
    if i%10 == 0:
        print(loss.item())

0.0008309215772897005
0.0008298605098389089
0.0008291483391076326
0.0008285042131319642
0.000827873358502984
0.0008272448903881013
0.0008266165386885405
0.0008259877213276923


KeyboardInterrupt: 

In [87]:
class AE2(nn.Module):
    def __init__(self, n_comps, n_coeffs):
        super(AE2, self).__init__()
        self.n_coeffs = n_coeffs
        self.n_comps = n_comps
        self.D = nn.Linear(1, self.n_comps*400*400, bias=False)
        self.coeffs = nn.Linear(1, self.n_comps*self.n_coeffs, bias=False)

    def forward(self, x):
        base = self.D(x)
        coeffs = self.coeffs(x).view(self.n_comps, self.n_coeffs)
        return torch.einsum('ki,kj->ji', base.view(self.n_comps,400*400), coeffs)


def nan_mse_loss(output, target):
    loss = torch.mean((output[target == target] - target[target == target])**2)
    return loss


net = AE2(ncomps, ncoeffs)
net.to(device)
optimizer = optim.AdamW(net.parameters(), lr=0.1)
#optimizer = optim.Adam(net.parameters(), lr=0.1)

target = torch.from_numpy(stack-tmean).float().to(device)

for i in range(3000):
    output = net(input)
    loss = nan_mse_loss(output, target)# + sparsity

    optimizer.zero_grad()   # zero the gradient buffers
    loss.backward()
    optimizer.step()    # does the update

    if i % 100 == 0:
        print(i, loss.item())

0 1.3293317556381226
100 0.0003069194790441543
200 0.00012027848424622789
300 8.65106558194384e-05
400 7.552202441729605e-05
500 6.944664346519858e-05
600 6.26373803243041e-05
700 5.832554597873241e-05
800 5.610979496850632e-05
900 5.295785376802087e-05
1000 5.0605838623596355e-05
1100 4.8676676669856533e-05
1200 4.73498112114612e-05
1300 4.697476833825931e-05
1400 4.684360101236962e-05
1500 4.676807293435559e-05
1600 4.6717213990632445e-05
1700 4.66806668555364e-05
1800 4.665327651309781e-05
1900 4.66321362182498e-05
2000 4.6615452447440475e-05
2100 4.6602017391705886e-05
2200 4.65909943159204e-05
2300 4.658179750549607e-05
2400 4.657400131691247e-05
2500 4.6567296521971e-05
2600 4.656145756598562e-05
2700 4.65563171019312e-05
2800 4.655174780054949e-05


KeyboardInterrupt: 