In [1]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=2


# Datamodule

In [None]:
import sys
sys.path.append('..')
from src.data_notebook import *
from src.utils import *
from src.models import *
import matplotlib.pyplot as plt
import torch
import itertools

# Geographic window of the study area, used as the data module's domain limits.
lon_min = -66
lon_max = -54
lat_min = 32
lat_max = 44

datadir = "/dmidata/users/maxb/NATL_dataset"

# Time periods assigned to each split.
# NOTE(review): the 'val' and 'test' periods overlap (2012-12-15 .. 2012-12-20),
# and 'val' ends on the day 'train' starts — confirm this is intended.
time_domains = {
    'train': {'time': slice('2013-02-24', '2013-09-30')},
    'val': {'time': slice('2012-12-15', '2013-02-24')},
    'test': {'time': slice('2012-10-01', '2012-12-20')},
}

# Patch extraction settings: 29 time steps x 240 x 240 points per patch,
# time stride of 1 and a lat/lon stride equal to the patch size.
patching_kw = {
    'patch_dims': {'time': 29, 'lat': 240, 'lon': 240},
    'strides': {'time': 1, 'lat': 240, 'lon': 240},
    'domain_limits': dict(lon=slice(lon_min, lon_max),
                          lat=slice(lat_min, lat_max)),
}

# Keyword arguments handed to the data module as `dl_kw`.
loader_kw = {'batch_size': 2, 'num_workers': 1}

datamodule = BaseDataModule(
    input_da=load_altimetry_data(datadir + "/natl_gf_w_5nadirs_swot.nc"),
    domains=time_domains,
    xrds_kw=patching_kw,
    dl_kw=loader_kw,
)
datamodule.setup()

# Solver

In [None]:
device = 'cuda'

# state only
# The three components are built in the same order as they are passed to the solver.
# dim_in=29 equals the 'time' patch length configured for the data module.
prior_cost_module = BilinAEPriorCost(dim_in=29, dim_hidden=64,
                                     bilin_quad=False, downsamp=2)
obs_cost_module = BaseObsCost()
grad_model = ConvLstmGradModel(dim_in=29, dim_hidden=96)

solver = GradSolver(
    n_step=15,
    lr_grad=1e-3,
    prior_cost=prior_cost_module,
    obs_cost=obs_cost_module,
    grad_mod=grad_model,
).to(device)

# Training

In [5]:
from IPython.display import clear_output
from torch.optim import Adam
from tqdm.autonotebook import tqdm

# Both weight maps are built from the data module's patch dimensions with the
# same crop ({'time': 0, 'lat': 50, 'lon': 50}) and offset (1), then moved to `device`.
weight_patch_dims = datamodule.xrds_kw['patch_dims']

# NOTE(review): `rec_weight` is not referenced by the loss functions visible in
# this notebook (they use `optim_weight`) — confirm whether it is still needed.
rec_weight = torch.from_numpy(
    get_last_time_wei(patch_dims=weight_patch_dims,
                      crop={'time': 0, 'lat': 50, 'lon': 50},
                      offset=1)
).to(device)

optim_weight = torch.from_numpy(
    get_linear_time_wei(patch_dims={dim: weight_patch_dims[dim]
                                    for dim in ('time', 'lat', 'lon')},
                        crop={'time': 0, 'lat': 50, 'lon': 50},
                        offset=1)
).to(device)

def step(batch):
    """Compute the full training loss and the reconstruction for one batch.

    Returns ``(None, None)`` when fewer than 5% of the values in ``batch.tgt``
    are finite. Otherwise returns ``(training_loss, out)``, where ``out`` is the
    reconstruction from ``base_step`` and the loss is
    ``50 * mse + 10000 * sobel_gradient_mse + 10 * prior_cost``.
    The three weighted terms are printed on every call.
    """
    finite_fraction = batch.tgt.isfinite().float().mean()
    if finite_fraction < 0.05:
        return None, None

    loss, out = base_step(batch)

    # Weighted MSE between the Sobel-filtered reconstruction and target.
    sobel_err = kfilts.sobel(out) - kfilts.sobel(batch.tgt)
    grad_loss = weighted_mse(sobel_err, optim_weight)
    # Prior cost evaluated on the state initialised from the reconstruction.
    prior_cost = solver.prior_cost(solver.init_state(batch, out))

    mse_term = 50 * loss
    grad_term = 10000 * grad_loss
    prior_term = 10 * prior_cost
    training_loss = mse_term + grad_term + prior_term
    print(mse_term, grad_term, prior_term)

    return training_loss, out

def base_step(batch):
    """Run the solver on ``batch`` and score the result against ``batch.tgt``.

    Returns ``(loss, out)``: ``out`` is the solver output moved to ``device`` and
    ``loss`` is ``weighted_mse`` of ``out - batch.tgt`` using ``optim_weight``.
    """
    reconstruction = solver(batch=batch).to(device)
    # mse loss
    residual = reconstruction - batch.tgt
    return weighted_mse(residual, optim_weight), reconstruction
    
def weighted_mse(err, weight):
    """Mean of the squared weighted error over valid entries.

    ``weight`` gets a new leading axis and is broadcast against ``err``
    (presumably the batch axis — confirm against callers). An entry is valid
    when ``err`` is finite there and its weight is non-zero. If no entry is
    valid, a scalar tensor of 1000.0 with ``requires_grad`` set is returned.
    """
    weight_b = weight[None, ...]
    # True where the broadcast weight is exactly zero.
    zero_weight = (torch.ones_like(err) * weight_b) == 0.0
    valid = err.isfinite() & ~zero_weight
    if valid.sum() == 0:
        return torch.scalar_tensor(1000.0, device=valid.device).requires_grad_()
    weighted_valid = (err * weight_b)[valid]
    return F.mse_loss(weighted_valid, torch.zeros_like(weighted_valid))

In [None]:
#@title Training (double click to expand or collapse)

from IPython.display import clear_output
from torch.optim import Adam
from tqdm.autonotebook import tqdm

import os

n_epochs = 300  #@param {'type':'integer'}
lr = 1e-3  #@param {'type':'number'}

ckpt_path = 'ckpt/ckpt_4dvarnet_base.pth'
# Create the checkpoint directory up front so the first torch.save cannot fail on a missing folder.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

optimizer = Adam(
        [
            {"params": solver.parameters(), "lr": lr},
        ])
# Halve the learning rate at epochs 50, 100, 150, 200 and 250.
# Referenced through `torch.optim` because `lr_scheduler` is not imported by name in this notebook.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 150, 200, 250], gamma=0.5)

# Training batches come from the data module (the test cell uses `test_dataloader()` the same way).
train_loader = datamodule.train_dataloader()

tqdm_epoch = tqdm(range(n_epochs))
# `.train()` sets training mode on the solver and all of its sub-modules;
# assigning `solver.training = True` would only flip the flag on the top-level module.
solver.train()
for epoch in tqdm_epoch:
    for batch in train_loader:
        # Replace NaNs in the input with zeros and move both tensors to the device:
        # the target is compared with the on-device solver output inside `step`.
        batch = batch._replace(input=batch.input.nan_to_num().to(device),
                               tgt=batch.tgt.to(device))
        loss, out = step(batch)
        if loss is None:
            # `step` rejected the batch (too few finite targets): nothing to optimise.
            continue
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Update the checkpoint after each epoch of training.
    torch.save(solver.state_dict(), ckpt_path)
    scheduler.step()

# Test of the model

In [None]:
# test
from IPython.display import clear_output


def plot_ssh_patch(field):
    """Plot every time step of one ``(time, lat, lon)`` patch as a grid of maps.

    The coordinates are rebuilt from the tensor's own shape, spanning the
    ``lon_min..lon_max`` / ``lat_min..lat_max`` window, so their lengths always
    match the data regardless of the patch resolution.

    Args:
        field: 3-D torch tensor ordered (time, lat, lon); may live on any device.

    Returns:
        The object returned by xarray's ``.plot(col='time', col_wrap=10)``.
    """
    values = field.detach().cpu().numpy()
    n_time, n_lat, n_lon = values.shape
    ds = xr.Dataset(
        data_vars={'ssh': (('time', 'lat', 'lon'), values)},
        coords={'time': np.arange(n_time),
                'lon': np.linspace(lon_min, lon_max, n_lon, endpoint=False),
                'lat': np.linspace(lat_min, lat_max, n_lat, endpoint=False)})
    return ds.ssh.plot(col='time', col_wrap=10)


ckpt = torch.load('ckpt/ckpt_4dvarnet_base.pth', map_location=device)
solver.load_state_dict(ckpt)
# Inference should run in evaluation mode (affects any train/eval-dependent layers).
# NOTE(review): confirm that the solver's eval-mode output is the one wanted for these plots.
solver.eval()

# Take the batch at index k from the test loader.
k = 1
test_batch = next(itertools.islice(datamodule.test_dataloader(), k, None))

# Observations and target for the first sample of the batch.
plot_ssh_patch(test_batch.input[0])
plot_ssh_patch(test_batch.tgt[0])

# The solver receives the input with NaNs replaced by zeros, on the compute device.
test_batch = test_batch._replace(input=test_batch.input.nan_to_num().to(device))
out = solver(test_batch)

# Reconstruction for the same sample.
plot_ssh_patch(out[0])