---
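# Nowcasting with a pretrained DGMR model

Loads a sample dataset for a single timestamp, runs the pretrained `openclimatefix/dgmr` model on it, and plots the forecast next to the ground truth.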
---

In [11]:
from datetime import datetime

import numpy as np
import geoviews as gv
import holoviews as hv
import xarray as xr
import hvplot.xarray
import torch
from dgmr import DGMR, Sampler, Generator, Discriminator, LatentConditioningStack, ContextConditioningStack

from src.dataio import getDgmrDataset
from src.util import plot

## Load Sample Dataset

In [12]:
# Reference time of the sample to forecast; passed to getDgmrDataset below.
# Alternative sample times kept for reference:
#   2018-01-22 01:00, 2019-06-23 10:00, 2018-07-06 00:00
dt = datetime(year=2020, month=10, day=8, hour=3, minute=0)

In [13]:
# Load the model input frames (`inp`) and the ground-truth frames (`out`) for `dt`.
inp, out = getDgmrDataset(dt)

In [14]:
# Fill missing values with 0 and convert the input frames into a batched tensor for DGMR.
# NOTE(review): assumes each data variable in `inp` is laid out (time, lat, lon);
# Dataset.to_array() prepends a 'variable' axis -> (variable, time, lat, lon).
inp_filled = inp.fillna(0)  # new name: keep the raw `inp` untouched so this cell is idempotent
imgs = (
    # Explicit float32: torch module weights are presumably float32, and a float64
    # input would fail with a dtype mismatch in the first conv layer.
    torch.tensor(inp_filled.to_array().values, dtype=torch.float32)
    .transpose(0, 1)  # -> (time, variable, lat, lon)
    .unsqueeze(0)     # -> (1, time, variable, lat, lon): add the batch axis
)

## Load DGMR model

In [15]:
# Only the full DGMR model is used for inference in this notebook. The separately
# published sub-module checkpoints (sampler, discriminator, latent/context
# conditioning stacks) are not needed here, so they are not loaded.
model = DGMR.from_pretrained("openclimatefix/dgmr")

## Predict

In [16]:
# The generator presumably samples random latents, so seed torch's RNG to make the
# forecast reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Inference only: eval mode and no autograd bookkeeping.
model.eval()
with torch.no_grad():
    pred_raw = model(imgs)

# NOTE(review): assumes the output mirrors the input layout (batch, time, channel, lat, lon)
# with batch == channel == 1 — TODO confirm. Index the singleton axes explicitly instead of
# calling .squeeze(), which would also drop the time axis for a 1-step forecast.
assert pred_raw.dim() == 5 and pred_raw.shape[0] == 1 and pred_raw.shape[2] == 1, tuple(pred_raw.shape)
pred = pred_raw[0, :, 0]  # -> (time, lat, lon)



In [17]:
# All-zero copy of the ground truth `out` (same variables and coordinates);
# the next cell replaces it with the forecast.
predict = xr.zeros_like(out)

In [19]:
# Wrap the forecast frames in a Dataset on the ground truth's time/lat/lon grid so both
# can be plotted the same way. Coordinates come straight from `out`.
pred_np = pred.cpu().numpy()
assert pred_np.shape[0] == out.time.size, (
    f"forecast has {pred_np.shape[0]} steps but ground truth has {out.time.size}"
)
predict = xr.Dataset(
    {'data': (['time', 'lat', 'lon'], pred_np)},
    coords={'time': out.time.data, 'lat': out.lat.data, 'lon': out.lon.data},
)

## Results: forecast vs. ground truth

In [20]:
# Forecast frames. NOTE(review): class_num presumably sets the number of colour classes
# in `plot`; keep it equal to the ground-truth plot so the two are comparable.
plot(predict, title='Predict', class_num=40)

In [21]:
# Observed frames for the same time steps, plotted with the same settings as the forecast.
plot(out, title='Ground Truth', class_num=40)