In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
%pylab inline
import xarray as xr
from tqdm.autonotebook import tqdm
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

# Set a bigger default plot size
mpl.rcParams['figure.figsize'] = (10, 8)
mpl.rcParams['font.size'] = 16

from hydrogen_pg.dataloaders.taylor_example_dataloader import Conv2dDataset
from hydrogen_pg.dataloaders.taylor_example_dataloader import Conv2dDataModule
from hydrogen_pg.models.taylor_example_model import RMM_NN_2D_B1
from hydrogen_pg.utils.callbacks import MetricsCallback

In [None]:
# Alternative input: the Taylor 1990 simulation metadata under /hydrodata/PFCLM.
# NOTE: even if uncommented, the next cell assigns `pfmeta_file` again and overrides this.
#pfmeta_file = '/hydrodata/PFCLM/Taylor/Simulations/1990/Taylor_1990.out.pfmetadata'

In [None]:
import os

# Path to the `.pfmetadata` file describing the run. Set the PFMETA_FILE
# environment variable to point elsewhere instead of editing this
# user-specific default path.
pfmeta_file = os.environ.get(
    'PFMETA_FILE', '/home/ab6361/small_CONUS1_2003_fake.out.pfmetadata'
)
in_vars = ['precipitation', 'temperature', 'saturation']
out_vars = ['saturation']

# Surface saturation
z_strategy = -1
patch_sizes = {'x': 50, 'y': 50}

dataset = Conv2dDataset(
    pfmeta_file,
    in_vars=in_vars,
    out_vars=out_vars,
    z_strategy=z_strategy,
    patch_sizes=patch_sizes,
)

In [None]:
# Sanity-check the dataset: show the target field of a few random patches.
# A fixed seed keeps the preview identical across Restart & Run All, and
# sampling without replacement avoids showing the same patch twice.
rng = np.random.default_rng(seed=42)
n_samples = min(9, len(dataset))
sample_inds = rng.choice(len(dataset), size=n_samples, replace=False)


def to_np(tensor):
    """Drop singleton dims and convert a (CPU) torch tensor to a numpy array."""
    return tensor.squeeze().detach().numpy()


fig, axes = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(12, 12))
axes = axes.flatten()
for ax in axes:
    # Turn every panel off up front so unused panels (tiny datasets) stay blank.
    ax.axis('off')
for idx, ax in tqdm(zip(sample_inds, axes), total=n_samples):
    _, y = dataset[idx]
    ax.imshow(to_np(y), vmin=0, vmax=1)
plt.tight_layout()

In [None]:
# Datamodule used for training/validation, built from the same file and variables.
# NOTE(review): unlike `dataset` above, no `patch_sizes` is passed here, so the
# training patches may differ from the previewed ones — confirm this is intended.
datamodule = Conv2dDataModule(
    pfmeta_file,
    in_vars=in_vars,
    out_vars=out_vars,
    z_strategy=z_strategy,
)

In [None]:
# Build the network on the datamodule's grid, with the same input/output variables.
model = RMM_NN_2D_B1(
    grid_size=datamodule.shape,
    in_vars=in_vars,
    out_vars=out_vars,
)
# NOTE(review): the Lightning Trainer calls configure_optimizers() itself during fit and
# the return value here is discarded — presumably just a smoke test; confirm.
model.configure_optimizers()
model.configure_loss()

## Validate dataloader and model can operate together

**TODO:** consider moving this check into `utils` as a reusable `validate(model, datamodule)` helper.

In [None]:
# Fail fast, with a readable message, if the data and model disagree on the
# grid shape or on the input/output feature lists.
data_shape = datamodule.shape
data_in_features, data_out_features = datamodule.feature_names
model_shape = model.shape
model_in_features, model_out_features = model.feature_names

assert data_shape == model_shape, (
    f'shape mismatch: data {data_shape} vs model {model_shape}'
)
assert data_in_features == model_in_features, (
    f'input features mismatch: data {data_in_features} vs model {model_in_features}'
)
assert data_out_features == model_out_features, (
    f'output features mismatch: data {data_out_features} vs model {model_out_features}'
)

In [None]:
# Train, collecting the losses plotted in the next cell via MetricsCallback.
# NOTE(review): `gpus=` is deprecated in Lightning 1.7 and removed in 2.0
# (use `accelerator='gpu', devices=1` there); this cell also needs a CUDA device.
metrics = MetricsCallback()
trainer = pl.Trainer(
    max_epochs=10,
    gpus=1,
    callbacks=[metrics],
)
trainer.fit(model, datamodule)

In [None]:
# Training vs. validation loss curves recorded by the callback.
fig, ax = plt.subplots()
for key, label in (('train_loss', 'Train loss'), ('val_loss', 'Validation loss')):
    ax.plot(metrics.metrics[key], label=label)
ax.legend()
ax.set_xlabel('Epoch #')
ax.set_ylabel('MSE Loss')
ax.set_ylim([0.00, 0.02])

In [None]:
import torch

# Predict the first validation sample with the trained model and compute the error.
# (The original referenced an undefined `dataloader`; the datamodule is what exists.)
val_x, val_y = next(iter(datamodule.val_dataloader()))
val_x = val_x[0:1]
val_y = val_y[0:1].squeeze().cpu().detach().numpy()

# Inference mode: disables train-only behaviour (dropout etc., if any) and autograd.
model.eval()
with torch.no_grad():
    # The model may still live on the GPU after fit; put the input on the same device.
    val_yhat = model(val_x.to(model.device))
val_yhat = val_yhat.squeeze().cpu().numpy()
err = val_y - val_yhat

In [None]:
# Truth and prediction share a fixed 0-1 colorbar; the error map gets its own.
fig, axes = plt.subplots(1, 5, figsize=(22, 7),
                         gridspec_kw={'width_ratios': [1, 1, 0.1, 1, 0.1], 'height_ratios': [1]})

sm = axes[0].imshow(val_y, vmin=0, vmax=1)
axes[1].imshow(val_yhat, vmin=0, vmax=1)
plt.colorbar(sm, cax=axes[2])
# Centre the diverging colormap on zero so its neutral midpoint means "no error"
# and over-/under-prediction are shown on the same scale. nanmax ignores masked cells.
err_lim = np.nanmax(np.abs(err))
sm = axes[3].imshow(err, cmap='coolwarm_r', vmin=-err_lim, vmax=err_lim)
plt.colorbar(sm, cax=axes[-1])

for ax in axes[[0, 1, 3]]:
    ax.axis('off')
axes[0].set_title('True Saturation')
axes[1].set_title('Predicted Saturation')
axes[3].set_title('Error')