In [1]:
import openet.ssebop as model
import ee
import pprint
from IPython.display import Image
import io
import requests
import numpy as np
import zipfile
import requests
import pandas as pd
import pvcz
import torch
import xarray as xr
from datetime import date, datetime, timedelta
import math
import torch.nn as nn
from torch.nn.functional import avg_pool2d, interpolate
import os
from os import path
import torch.nn.functional as F
from torch.autograd import Variable
import rasterio as rio

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
def get_elevation(lon, lat, units='Meters', output='json', timeout=30):
    """Query the USGS Elevation Point Query Service for the elevation at one point.

    Args:
        lon: longitude, sent as the service's ``x`` parameter.
        lat: latitude, sent as the service's ``y`` parameter.
        units: unit name passed through to the service (default ``'Meters'``).
        output: response format requested from the service. The body is always parsed
            as JSON below, so only ``'json'`` works.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        The ``Elevation`` field of the service's JSON response.

    Raises:
        requests.HTTPError: if the service answers with an error status.
        requests.Timeout: if no response arrives within ``timeout`` seconds.
    """
    # NOTE(review): legacy EPQS endpoint -- confirm it is still served by USGS.
    URL = 'https://nationalmap.gov/epqs/pqs.php?'
    PARAMS = {'x': str(lon),
              'y': str(lat),
              'units': units,
              'output': output}

    response = requests.get(url=URL, params=PARAMS, timeout=timeout)
    # Fail with the HTTP error itself rather than a confusing JSON/KeyError below.
    response.raise_for_status()
    return response.json()['USGS_Elevation_Point_Query_Service']['Elevation_Query']['Elevation']

In [4]:
def get_koppen(lon, lat, timeout=30):
    """One-hot encode the Koppen-Geiger climate zone at (lon, lat).

    The zone is looked up through the climateapi.scottpinkelman.com location API and
    encoded against the fixed class list below. That list's order presumably has to
    match the one used when the Quench model was trained; do not reorder it.

    Args:
        lon: longitude of the site.
        lat: latitude of the site.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        An integer one-hot tensor of length ``len(clims)`` (10).

    Raises:
        requests.HTTPError: if the API answers with an error status.
        ValueError: if the returned zone is not one of the supported classes.
    """
    clims = ['DFB', 'BWK', 'CFA', 'CWA', 'DWB', 'DFC', 'DFA', 'BSK', 'CSA', 'BSH']
    URL = 'http://climateapi.scottpinkelman.com/api/v1/location/' + str(lat) + '/' + str(lon)
    response = requests.get(url=URL, timeout=timeout)
    response.raise_for_status()
    zone = response.json()["return_values"][0]['koppen_geiger_zone'].upper()
    if zone not in clims:
        # Same exception type as the bare list.index failure, but with a useful message.
        raise ValueError(
            f"Unsupported Koppen-Geiger zone {zone!r} at ({lon}, {lat}); expected one of {clims}"
        )
    return torch.nn.functional.one_hot(torch.tensor(clims.index(zone)), num_classes=len(clims))

In [5]:
# ee.Authenticate()

In [6]:
# Start the Earth Engine session. Assumes credentials already exist on this machine
# (the ee.Authenticate() call in the previous cell is commented out).
ee.Initialize()

In [7]:
# Earth Engine collections used by this notebook: gridMET meteorology and
# Landsat 8 Collection 2 Level-2 scenes (SR_* reflectance and ST_B10 bands are read below).
gridmet = ee.ImageCollection("IDAHO_EPSCOR/GRIDMET")
lc = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')

In [8]:
# Study window: initial date of interest and final date of interest (exclusive).
i_date, f_date = '2020-05-01', '2020-09-30'

# Restrict the Landsat collection to the study window.
lc = lc.filterDate(i_date, f_date)

# Sampling scale in meters for the getRegion query below.
scale = 1000

In [9]:
# Study site as (lon, lat). Other candidate sites that were tried:
#   (-110.8661, 31.8214), (-108.6730, 37.2246), (-102.3020, 39.7312)
site1 = (-105.0000, 40.6525)
# Zero-area rectangle with both corners at the site, used as the sampling geometry.
site1point = ee.Geometry.Rectangle(*site1, *site1)
# site2point = ee.Geometry.Point(site2[0], site2[1])
# site3point = ee.Geometry.Point(site3[0], site3[1])

In [10]:
# Sample the filtered collection at the site. The table's first row is the column
# header, so scene IDs are the first field of every following row.
lc_r_poi = lc.getRegion(site1point, scale).getInfo()
id_list = [region_row[0] for region_row in lc_r_poi[1:]]
id_list

['LC08_033032_20200508',
 'LC08_033032_20200524',
 'LC08_033032_20200609',
 'LC08_033032_20200625',
 'LC08_033032_20200711',
 'LC08_033032_20200727',
 'LC08_033032_20200812',
 'LC08_033032_20200828',
 'LC08_033032_20200913',
 'LC08_033032_20200929',
 'LC08_034032_20200515',
 'LC08_034032_20200531',
 'LC08_034032_20200616',
 'LC08_034032_20200702',
 'LC08_034032_20200718',
 'LC08_034032_20200803',
 'LC08_034032_20200819',
 'LC08_034032_20200904',
 'LC08_034032_20200920']

In [11]:
# Scene processed in the rest of the notebook, chosen by hand from id_list above.
# The trailing 8 characters (YYYYMMDD) are parsed into the output folder name below.
img_id = 'LC08_033032_20200929'

In [12]:
# Acquisition dates ('YYYY_MM_DD') of every scene covering the site, in chronological order.
# The scene ID ends in '_YYYYMMDD'; split it once per row instead of three times. This
# also drops the earlier assignment whose result was immediately overwritten.
dats = sorted(
    f"{stamp[:4]}_{stamp[4:6]}_{stamp[6:]}"
    for stamp in (region_row[0].split("_")[-1] for region_row in lc_r_poi[1:])
)
dats

['2020_05_08',
 '2020_05_15',
 '2020_05_24',
 '2020_05_31',
 '2020_06_09',
 '2020_06_16',
 '2020_06_25',
 '2020_07_02',
 '2020_07_11',
 '2020_07_18',
 '2020_07_27',
 '2020_08_03',
 '2020_08_12',
 '2020_08_19',
 '2020_08_28',
 '2020_09_04',
 '2020_09_13',
 '2020_09_20',
 '2020_09_29']

In [13]:
# Colour palettes (low -> high) and thumbnail size for the getThumbURL previews.
viridis_palette = ['440154', '433982', '30678D', '218F8B', '36B677', '8ED542', 'FDE725']
et_palette = [
    'DEC29B', 'E6CDA1', 'EDD9A6', 'F5E4A9', 'FFF4AD',
    'C3E683', '6BCC5C', '3BB369', '20998F', '1C8691',
    '16678A', '114982', '0B2C7A',
]
ndvi_palette = ['#000000', '#FFFFFF']  # grayscale; alternative: ['#EFE7E1', '#003300']

image_size = 768  # thumbnail 'dimensions' in pixels

In [14]:
# Output folder named after the scene's acquisition date,
# e.g. 'LC08_033032_20200929' -> '2020_09_29'.
# NOTE(review): this rebinds `date`, shadowing the `datetime.date` class imported at the
# top of the notebook; later cells rely on `date` being this string.
date = "_".join((img_id[-8:-4], img_id[-4:-2], img_id[-2:]))
os.makedirs(date, exist_ok=True)

In [15]:
# Scene to process, plus two bounding boxes around the site:
#   landsat_region  - 225 m buffer, sampled pixel-by-pixel for the model input stack
#   landsat_region2 - 3834 m buffer, used for the GeoTIFF downloads
landsat_img = ee.Image('LANDSAT/LC08/C02/T1_L2/' + img_id)
landsat_region, landsat_region2 = (site1point.buffer(radius).bounds() for radius in (225, 3834))
dats1 = [date]

In [16]:
# image_url = landsat_img.select(['SR_B4', 'SR_B3', 'SR_B2'])\
#     .multiply([0.0000275, 0.0000275, 0.0000275])\
#     .add([-0.2, -0.2, -0.2])\
#     .getThumbURL({'min': 0.0, 'max': 0.3, 
#                   'region': landsat_region, 'dimensions': image_size})
# Image(image_url, embed=True, format='png')

In [17]:
# Build the SSEBop model object for this scene. Reference ET is gridMET `etr`
# scaled by 0.85 with nearest-neighbour resampling; Tcorr comes from FANO.
# Alternative reference tried earlier:
#   et_reference_source='projects/climate-engine/cimis/daily', et_reference_band='ETr_ASCE'
model_obj = model.Image.from_landsat_c2_sr(
    landsat_img,
    et_reference_source='IDAHO_EPSCOR/GRIDMET',
    et_reference_band='etr',
    et_reference_factor=0.85,
    et_reference_resample='nearest',
    tcorr_source='FANO',
)

In [18]:
# image_url = model_obj.et\
#     .getThumbURL({'min': 0.0, 'max': 8, 'palette': et_palette, 
#                   'region': landsat_region2, 'dimensions': image_size})
# Image(image_url, embed=True, format='png')

In [19]:
# Download the SSEBop land-surface temperature over landsat_region2 as a GeoTIFF.
lst_path = date + '/LST_' + img_id + '.tif'
url = model_obj.lst.getDownloadUrl({
    'region': landsat_region2,
    'format': 'GEO_TIFF'
})
response = requests.get(url, timeout=300)
# Fail loudly rather than saving an Earth Engine error page under a .tif name.
response.raise_for_status()
with open(lst_path, 'wb') as lst_fd:
    lst_fd.write(response.content)
# Read the raster into memory and release the dataset handle.
with rio.open(lst_path) as lst_src:
    lst_file = lst_src.read()
lst_file.shape

(1, 256, 256)

In [20]:
# Download the SSEBop NDVI over landsat_region2 as a GeoTIFF.
ndvi_path = date + '/NDVI_' + img_id + '.tif'
url = model_obj.ndvi.getDownloadUrl({
    'region': landsat_region2,
    'format': 'GEO_TIFF'
})
response = requests.get(url, timeout=300)
# Fail loudly rather than saving an Earth Engine error page under a .tif name.
response.raise_for_status()
with open(ndvi_path, 'wb') as ndvi_fd:
    ndvi_fd.write(response.content)
# with rio.open(ndvi_path) as ndvi_src:
#     ndvi_file = ndvi_src.read()
# ndvi_file.shape

In [21]:
# Download the uncorrected SSEBop ET over landsat_region2 as a GeoTIFF; the Quench
# correction is applied to this raster at the end of the notebook.
openet_path = date + '/OPENET_' + img_id + '.tif'
url = model_obj.et.getDownloadUrl({
    'region': landsat_region2,
    'format': 'GEO_TIFF'
})
response = requests.get(url, timeout=300)
# Fail loudly rather than saving an Earth Engine error page under a .tif name.
response.raise_for_status()
with open(openet_path, 'wb') as openet_fd:
    openet_fd.write(response.content)
# with rio.open(openet_path) as openet_src:
#     openet_file = openet_src.read()
# openet_file.shape

In [22]:
# image_url = model_obj.ndvi\
#     .getThumbURL({'min': 0.0, 'max': 1, 'palette': ndvi_palette, 
#                   'region': landsat_region, 'dimensions': image_size})
# Image(image_url, embed=True, format='png')

In [23]:
# model_obj.et.reduceRegion(ee.Reducer.count(), landsat_region).getInfo()

In [24]:
# model_obj.et.reduceRegion(ee.Reducer.max(), landsat_region).getInfo()

In [25]:
# model_obj.et.reduceRegion(ee.Reducer.mean(), landsat_region).getInfo()

In [26]:
# model_obj.ndvi.reduceRegion(ee.Reducer.max(), landsat_region).getInfo()

In [27]:
# model_obj.ndvi.reduceRegion(ee.Reducer.mean(), landsat_region).getInfo()

In [28]:
# model_obj.lst.reduceRegion(ee.Reducer.max(), landsat_region).getInfo()

In [29]:
# model_obj.lst.reduceRegion(ee.Reducer.mean(), landsat_region).getInfo()

In [30]:
# landsat_img.sampleRectangle(landsat_region).getInfo()

In [31]:
# landsat_img.sampleRectangle(landsat_region).getInfo()["properties"].keys()

In [32]:
# Fetch the 225 m window ONCE and derive every model input channel from it.
# (This used to be nine cells, each issuing the same sampleRectangle().getInfo() request.)
rect_info = landsat_img.sampleRectangle(landsat_region).getInfo()
rect_props = rect_info["properties"]


def _band_array(name):
    """Return band `name` of the sampled window as a float array with a leading channel axis."""
    return np.expand_dims(np.array(rect_props[name]), axis=0).astype(float)


# Collection 2 Level-2 surface-reflectance scaling: DN * 0.0000275 - 0.2.
blue = (_band_array("SR_B2") * 0.0000275) - 0.2
green = (_band_array("SR_B3") * 0.0000275) - 0.2
red = (_band_array("SR_B4") * 0.0000275) - 0.2
nir = (_band_array("SR_B5") * 0.0000275) - 0.2
swir1 = (_band_array("SR_B6") * 0.0000275) - 0.2
# Surface temperature: DN * 0.00341802 + 149.0 (presumably Kelvin), then divided by 400.
lst = ((_band_array("ST_B10") * 0.00341802) + 149.0) / 400.0
# Raw QA bit-mask, left unscaled.
qa = _band_array("QA_PIXEL")

# Per-pixel coordinate channels spanning the window's bounding ring. The grid size follows
# the sampled bands (the previous code hard-coded 16x16).
n_rows, n_cols = blue.shape[1], blue.shape[2]
window_ring = rect_info["geometry"]["coordinates"][0]
ring_lons = np.array([vertex[0] for vertex in window_ring])
ring_lats = np.array([vertex[1] for vertex in window_ring])
lons = np.expand_dims(np.tile(np.expand_dims(np.linspace(ring_lons.min(), ring_lons.max(), n_cols), axis=0), (n_rows, 1)), axis=0)
# NOTE(review): row 0 gets the MINIMUM latitude; if sampled rows run north->south this is
# flipped. It does not affect the model call below, which only uses channels 0:5.
lats = np.expand_dims(np.tile(np.expand_dims(np.linspace(ring_lats.min(), ring_lats.max(), n_rows), axis=-1), (1, n_cols)), axis=0)

In [41]:
# Stack the nine channels (5 reflectance bands, LST, QA, lon, lat) and add a batch axis.
img_seq = torch.tensor(
    np.concatenate((blue, green, red, nir, swir1, lst, qa, lons, lats))
)[None, ...]

In [42]:
# Site metadata: Koppen one-hot (network lookup first, so a failure leaves nothing half-set),
# then the date list and the coordinates as 1-element tensors.
clim = get_koppen(*site1)
dat, lon, lat = dats1, torch.tensor(site1[:1]), torch.tensor(site1[1:])

In [43]:
# Calendar and location features for the Quench model's metadata vector.
# dats1[0] is already 'YYYY_MM_DD', so it is parsed directly (the old code split it on '_'
# and re-joined the same three parts before parsing).
date_time_obj = datetime.strptime(dats1[0], '%Y_%m_%d')
day_of_year = date_time_obj.timetuple().tm_yday
# Cyclic day-of-year encoding.
# NOTE(review): the period is 364 days, not 365/366. It is left unchanged because the
# trained checkpoint presumably saw exactly this encoding.
day_sin = torch.tensor([np.sin(2 * np.pi * day_of_year/364.0)])
day_cos = torch.tensor([np.cos(2 * np.pi * day_of_year/364.0)])
# Site position on the unit sphere (polar angle pi/2 - lat, azimuth lon).
x_coord = torch.tensor([np.sin(math.pi/2-np.deg2rad(site1[1])) * np.cos(np.deg2rad(site1[0]))])
y_coord = torch.tensor([np.sin(math.pi/2-np.deg2rad(site1[1])) * np.sin(np.deg2rad(site1[0]))])
z_coord = torch.tensor([np.cos(math.pi/2-np.deg2rad(site1[1]))])

In [44]:
# tmin = np.expand_dims(np.tile(np.array(gridmet_img.sampleRectangle(gridmet_region).getInfo()["properties"]['tmmn']), [16,16]), axis=0)

In [45]:
# tmax = np.expand_dims(np.tile(np.array(gridmet_img.sampleRectangle(gridmet_region).getInfo()["properties"]['tmmx']), [16,16]), axis=0)

In [46]:
# srad = np.expand_dims(np.tile(np.array(gridmet_img.sampleRectangle(gridmet_region).getInfo()["properties"]['srad']), [16,16]), axis=0)

In [47]:
# sph = np.expand_dims(np.tile(np.array(gridmet_img.sampleRectangle(gridmet_region).getInfo()["properties"]['sph']), [16,16]), axis=0)

In [48]:
# etr = np.expand_dims(np.tile(np.array(gridmet_img.sampleRectangle(gridmet_region).getInfo()["properties"]['etr']), [16,16]), axis=0)

In [49]:
# Site elevation normalised by 8848 (presumably Everest's height in meters, i.e. a max-elevation scale).
elev = torch.tensor([get_elevation(*site1) / 8848.0])

In [50]:
class SSEBop(torch.nn.Module):
    """Parameter-free module returning OpenET SSEBop ET for (date, lat, lon) queries.

    For each query it picks the Landsat 8 C2 L2 scene acquired on that date over the
    point, runs ``openet.ssebop`` on it, and takes the maximum ET within a 225 m buffer.
    Results are memoised in ``self.dict``, so repeated queries skip Earth Engine.
    """

    def __init__(self, device):
        super().__init__()
        self.lc = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
        self.scale = 1000  # getRegion sampling scale in meters
        self.device = device
        self.dict = {}  # memo: _cache_key(date, lon, lat) -> ET value

    @staticmethod
    def _cache_key(day, lon_value, lat_value):
        """Memo key built from plain Python numbers.

        The previous key used ``str(tensor)``, whose text embeds device formatting
        (e.g. ``device='cuda:0'``). The same site could then produce different keys.
        """
        return str(day) + "_" + str(lon_value) + "_" + str(lat_value)

    def _fetch_et(self, day, lon_value, lat_value):
        """Compute SSEBop ET for one 'YYYY-MM-DD' date and point through Earth Engine.

        Returns the max ET in a 225 m buffer around the point, or 0.0 when Earth
        Engine returns no (or a falsy) value.

        Raises:
            IndexError: if no scene in the collection matches the date and point.
        """
        next_day = (datetime.strptime(day, "%Y-%m-%d") + timedelta(days=1)).strftime("%Y-%m-%d")
        # The end date is exclusive, so [day, next_day) selects only `day`.
        scenes = self.lc.filterDate(day, next_day)
        point = ee.Geometry.Rectangle(lon_value, lat_value, lon_value, lat_value)
        region_rows = scenes.getRegion(point, self.scale).getInfo()
        scene_ids = sorted(row[0] for row in region_rows[1:])  # row 0 is the header
        if not scene_ids:
            raise IndexError(
                f"No Landsat 8 C2 L2 scene found for {day} at ({lon_value}, {lat_value})"
            )
        landsat_img = ee.Image('LANDSAT/LC08/C02/T1_L2/' + scene_ids[-1])
        model_obj = model.Image.from_landsat_c2_sr(
            landsat_img,
            tcorr_source='FANO',
            et_reference_source='IDAHO_EPSCOR/GRIDMET',
            et_reference_band='etr',
            et_reference_factor=0.85,
            et_reference_resample='nearest',
        )
        et = model_obj.et.reduceRegion(ee.Reducer.max(), point.buffer(225)).getInfo()["et"]
        return et if et else 0.0

    def _openet(self, date, lon, lat):
        """Return a list with one ET value per entry of ``date``.

        ``date`` holds 'YYYY-MM-DD' strings. ``lon``/``lat`` are indexed by position and
        each element must support ``.item()`` (e.g. 1-element tensors).
        """
        ets = []
        for i1 in range(len(date)):
            lon_value, lat_value = lon[i1].item(), lat[i1].item()
            key = self._cache_key(date[i1], lon_value, lat_value)
            if key not in self.dict:
                self.dict[key] = self._fetch_et(date[i1], lon_value, lat_value)
            ets.append(self.dict[key])
        return ets

    def forward(self, date, lat, lon):
        """ET for each query as a tensor on ``self.device``.

        ``date`` holds 'YYYY_MM_DD' strings. Note the (lat, lon) argument order.
        """
        date_str = [dstr.replace("_", "-") for dstr in date]
        return torch.tensor(self._openet(date_str, lon, lat), device=self.device)

In [51]:
# Instantiate the SSEBop ET wrapper on the selected device.
ssebop_model = SSEBop(device=device).to(device=device, dtype=torch.float32)

In [52]:
# Site longitude as a (1, 1) tensor: one query with one coordinate.
lons1 = torch.tensor([site1[0]]).unsqueeze(0)

In [53]:
# Site latitude as a (1, 1) tensor: one query with one coordinate.
lats1 = torch.tensor([site1[1]]).unsqueeze(0)

In [54]:
# Uncorrected SSEBop ET for the scene date at the site (forward takes date, lat, lon).
output_ET = ssebop_model(
    dats1,
    lat=lats1.to(device, dtype=torch.float32),
    lon=lons1.to(device, dtype=torch.float32),
)

In [55]:
def conv1x1(in_channels, out_channels, stride = 1):
    """Bias-free 1x1 convolution with no padding (a pointwise channel projection)."""
    pointwise_opts = dict(kernel_size=1, stride=stride, padding=0, bias=False)
    return nn.Conv2d(in_channels, out_channels, **pointwise_opts)

def conv3x3(in_channels, out_channels, stride = 1):
    """Bias-free 3x3 convolution with 1-pixel padding (size-preserving when stride is 1)."""
    spatial_opts = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **spatial_opts)

class irnn_layer(nn.Module):
    """One step of four-direction neighbour propagation (IRNN-style).

    Each direction has a depthwise 1x1 convolution (``groups=in_channels``, i.e. a
    per-channel scale plus bias). That convolution is applied to the neighbouring
    pixel in the given direction, added to the pixel itself, and passed through ReLU.
    Only a single neighbour step is taken here; there is no full recurrence across
    the image.
    """
    def __init__(self,in_channels):
        super(irnn_layer,self).__init__()
        # Per-direction depthwise 1x1 convolutions (one weight + bias per channel).
        self.left_weight = nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,groups=in_channels,padding=0)
        self.right_weight = nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,groups=in_channels,padding=0)
        self.up_weight = nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,groups=in_channels,padding=0)
        self.down_weight = nn.Conv2d(in_channels,in_channels,kernel_size=1,stride=1,groups=in_channels,padding=0)
        
    def forward(self,x):
        """Return ``(top_up, top_right, top_down, top_left)``, each shaped like ``x``.

        ``x`` is indexed as (batch, channel, H, W). For each direction, the border
        row/column with no neighbour in that direction keeps the value of ``x``.
        """
        _,_,H,W = x.shape
        # Start from copies so the border row/column keeps x's value.
        top_left = x.clone()
        top_right = x.clone()
        top_up = x.clone()
        top_down = x.clone()
        # Column j >= 1 receives the weighted pixel at column j-1 (its left neighbour).
        top_left[:,:,:,1:] = F.relu(self.left_weight(x)[:,:,:,:W-1]+x[:,:,:,1:],inplace=False)
        # Column j <= W-2 receives the weighted pixel at column j+1 (its right neighbour).
        top_right[:,:,:,:-1] = F.relu(self.right_weight(x)[:,:,:,1:]+x[:,:,:,:W-1],inplace=False)
        # Row i >= 1 receives the weighted pixel at row i-1 (the row above).
        top_up[:,:,1:,:] = F.relu(self.up_weight(x)[:,:,:H-1,:]+x[:,:,1:,:],inplace=False)
        # Row i <= H-2 receives the weighted pixel at row i+1 (the row below).
        top_down[:,:,:-1,:] = F.relu(self.down_weight(x)[:,:,1:,:]+x[:,:,:H-1,:],inplace=False)
        return (top_up,top_right,top_down,top_left)

class Attention(nn.Module):
    """Predict four per-pixel direction weights in [0, 1] from a feature map.

    Pipeline: 3x3 conv (C -> C/2) -> ReLU -> 3x3 conv (C/2 -> C/2) -> ReLU ->
    1x1 conv (C/2 -> 4) -> sigmoid. Spatial size is preserved.
    """
    def __init__(self,in_channels):
        super(Attention,self).__init__()
        self.out_channels = int(in_channels / 2)
        hidden = self.out_channels
        # Layers are created in the same order as before, so seeded init is unchanged.
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, stride=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, padding=1, stride=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(hidden, 4, kernel_size=1, padding=0, stride=1)
        self.sigmod = nn.Sigmoid()

    def forward(self,x):
        features = self.relu1(self.conv1(x))
        features = self.relu2(self.conv2(features))
        return self.sigmod(self.conv3(features))

class SAM(nn.Module):
    """Spatial attention module producing a single-channel gating mask in (0, 1).

    Two rounds of four-direction ``irnn_layer`` propagation (each followed by a 1x1
    conv that fuses the 4 directions), then ReLU, a 1x1 conv to one channel, and a
    sigmoid. ``forward`` returns a ``(batch, 1, H, W)`` mask.

    NOTE(review): ``conv1`` takes ``out_channels`` inputs but is applied to ``x``
    directly, so this assumes ``in_channels == out_channels`` (as in the
    convolutionalCapsule caller). ``conv_in`` and ``relu1`` are built but unused in
    ``forward``.
    """
    def __init__(self,in_channels,out_channels,attention=1):
        super(SAM,self).__init__()
        self.out_channels = out_channels
        self.irnn1 = irnn_layer(self.out_channels)
        self.irnn2 = irnn_layer(self.out_channels)
        self.conv_in = conv3x3(in_channels,self.out_channels)  # unused in forward
        self.relu1 = nn.ReLU(True)  # unused in forward
        
        self.conv1 = nn.Conv2d(self.out_channels,self.out_channels,kernel_size=1,stride=1,padding=0)
        # conv2 / conv3 fuse the 4 concatenated direction maps back to out_channels.
        self.conv2 = nn.Conv2d(self.out_channels*4,self.out_channels,kernel_size=1,stride=1,padding=0)
        self.conv3 = nn.Conv2d(self.out_channels*4,self.out_channels,kernel_size=1,stride=1,padding=0)
        self.relu2 = nn.ReLU(True)
        self.attention = attention
        if self.attention:
            self.attention_layer = Attention(in_channels)
        self.conv_out = conv1x1(self.out_channels,1)
        self.sigmod = nn.Sigmoid()
    
    def forward(self,x):
        if self.attention:
            # Four direction weights, one channel per direction.
            weight = self.attention_layer(x)
        out = self.conv1(x)
        top_up,top_right,top_down,top_left = self.irnn1(out)
        
        # direction attention
        # NOTE(review): Tensor.mul is NOT in-place and its result is discarded, so these
        # four lines (and the identical block below) have no effect; `weight` never
        # reaches the output. Fixing this (e.g. `top_up = top_up.mul(...)`) changes the
        # outputs of the already-trained pickled checkpoint, so it needs retraining.
        if self.attention:
            top_up.mul(weight[:,0:1,:,:])
            top_right.mul(weight[:,1:2,:,:])
            top_down.mul(weight[:,2:3,:,:])
            top_left.mul(weight[:,3:4,:,:])
        out = torch.cat([top_up,top_right,top_down,top_left],dim=1)
        out = self.conv2(out)
        top_up,top_right,top_down,top_left = self.irnn2(out)
        
        # direction attention (no-op for the same reason as above)
        if self.attention:
            top_up.mul(weight[:,0:1,:,:])
            top_right.mul(weight[:,1:2,:,:])
            top_down.mul(weight[:,2:3,:,:])
            top_left.mul(weight[:,3:4,:,:])
        
        out = torch.cat([top_up,top_right,top_down,top_left],dim=1)
        out = self.conv3(out)
        out = self.relu2(out)
        # Single-channel mask squashed into (0, 1).
        mask = self.sigmod(self.conv_out(out))
        return mask

class convolutionalCapsule(nn.Module):
    """Convolutional capsule layer gated by a SAM spatial mask, with dynamic routing.

    Assumes ``x`` is ``(batch, in_capsules, in_channels, H, W)``. It is viewed as
    ``(batch * in_capsules, in_channels, H, W)`` before the convolution. The result is
    ``(batch, out_capsules, out_channels, H_out, W_out)``.

    ``cuda`` is accepted and stored as ``self.cuda`` for backward compatibility, but
    routing logits now follow the input's device, so the layer also runs on CPU.
    NOTE(review): the ``self.cuda`` attribute shadows ``nn.Module.cuda()`` on instances
    of this class. It is kept because pickled checkpoints presumably carry it.
    """
    def __init__(self, in_capsules, out_capsules, in_channels, out_channels, stride=1, padding=2,
                 kernel=5, num_routes=3, nonlinearity='sqaush', batch_norm=False, dynamic_routing='local', cuda=False):
        super(convolutionalCapsule, self).__init__()
        self.num_routes = num_routes
        self.in_channels = in_channels
        self.in_capsules = in_capsules
        self.out_capsules = out_capsules
        self.out_channels = out_channels
        self.nonlinearity = nonlinearity
        self.batch_norm = batch_norm
        self.bn = nn.BatchNorm2d(in_capsules*out_capsules*out_channels)
        self.conv2d = nn.Conv2d(kernel_size=(kernel, kernel), stride=stride, padding=padding,
                                in_channels=in_channels, out_channels=out_channels*out_capsules)
        self.dynamic_routing = dynamic_routing
        self.cuda = cuda
        self.SAM1 = SAM(self.in_channels,self.in_channels,1)

    def forward(self, x):
        batch_size = x.size(0)
        in_width, in_height = x.size(3), x.size(4)
        x = x.view(batch_size*self.in_capsules, self.in_channels, in_width, in_height)
        # Per-capsule predictions, gated by the SAM spatial mask (broadcast over channels).
        u_hat = self.conv2d(x) * self.SAM1(x)

        out_width, out_height = u_hat.size(2), u_hat.size(3)

        # batch norm layer
        if self.batch_norm:
            u_hat = u_hat.view(batch_size, self.in_capsules, self.out_capsules * self.out_channels, out_width, out_height)
            u_hat = u_hat.view(batch_size, self.in_capsules * self.out_capsules * self.out_channels, out_width, out_height)
            u_hat = self.bn(u_hat)
            u_hat = u_hat.view(batch_size, self.in_capsules, self.out_capsules*self.out_channels, out_width, out_height)
            u_hat = u_hat.permute(0,1,3,4,2).contiguous()
            u_hat = u_hat.view(batch_size, self.in_capsules, out_width, out_height, self.out_capsules, self.out_channels)

        else:
            u_hat = u_hat.permute(0,2,3,1).contiguous()
            u_hat = u_hat.view(batch_size, self.in_capsules, out_width, out_height, self.out_capsules*self.out_channels)
            u_hat = u_hat.view(batch_size, self.in_capsules, out_width, out_height, self.out_capsules, self.out_channels)

        # Routing logits start at zero on u_hat's device/dtype. The old
        # Variable(...).cuda() branch was taken whenever cuda=True and failed on CPU-only machines.
        b_ij = torch.zeros(1, self.in_capsules, out_width, out_height, self.out_capsules,
                           device=u_hat.device, dtype=u_hat.dtype)
        for iteration in range(self.num_routes):
            # Coupling coefficients: softmax over input capsules.
            c_ij = F.softmax(b_ij, dim=1)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(5)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)

            # The relu/leakyRelu variants apply only on the final routing iteration.
            if (self.nonlinearity == 'relu') and (iteration == self.num_routes - 1):
                v_j = F.relu(s_j)
            elif (self.nonlinearity == 'leakyRelu') and (iteration == self.num_routes - 1):
                v_j = F.leaky_relu(s_j)
            else:
                v_j = self.squash(s_j)

            v_j = v_j.squeeze(1)

            if iteration < self.num_routes - 1:
                temp = u_hat.permute(0, 2, 3, 4, 1, 5)
                temp2 = v_j.unsqueeze(5)
                a_ij = torch.matmul(temp, temp2).squeeze(5) # dot product here
                a_ij = a_ij.permute(0, 4, 1, 2, 3)
                b_ij = b_ij + a_ij.mean(dim=0)

        v_j = v_j.permute(0, 3, 4, 1, 2).contiguous()

        return v_j

    def squash(self, input_tensor):
        """Capsule squash along the last axis: ``|s|^2 / (1 + |s|^2) * s / |s|``.

        Zero-length vectors now map to zero instead of NaN (0/0). Clamping the squared
        norm to the dtype's smallest normal value only affects exact zeros and subnormal
        norms, where the result is ~0 either way.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        norm = torch.sqrt(squared_norm.clamp_min(torch.finfo(input_tensor.dtype).tiny))
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * norm)
        return output_tensor

class _DenseLayer(nn.Sequential):
    """Residual pair of 3x3 capsule layers: ``forward(x) = x + concap2(concap1(x))``.

    The constructor arguments are accepted for interface compatibility with
    ``_DenseBlock`` but ignored. Both capsule layers are fixed at 8 input and 8 output
    capsules of 8 channels each.
    """

    def __init__(self, num_caps, num_input_features, growth_rate, bn_size, drop_rate, actvec_size):
        super().__init__()
        # Hyper-parameters shared by both capsule layers (created in the same order).
        capsule_opts = dict(in_capsules=8, out_capsules=8, in_channels=8, out_channels=8,
                            stride=1, padding=1, kernel=3, num_routes=3,
                            nonlinearity='sqaush', batch_norm=True,
                            dynamic_routing='local', cuda=True)
        self.concap1 = convolutionalCapsule(**capsule_opts)
        self.concap2 = convolutionalCapsule(**capsule_opts)

    def forward(self, x):
        return x + self.concap2(self.concap1(x))

class _DenseBlock(nn.Sequential):

    def __init__(self, num_layers, num_caps, num_input_features, bn_size, growth_rate,
                 drop_rate, actvec_size):
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_caps, num_input_features,
                                growth_rate, bn_size, drop_rate, actvec_size)
            self.add_module('denselayer{}'.format(i + 1), layer)

class _Transition(nn.Sequential):
    """1x1 capsule layer followed by 2x2 average pooling that halves H and W.

    ``num_caps``, ``in_vect`` and ``out_vect`` are accepted for interface compatibility
    but ignored; the capsule layer is fixed at 8 capsules of 8 channels.
    """

    def __init__(self, num_caps, in_vect, out_vect):
        super().__init__()
        self.skip = convolutionalCapsule(in_capsules=8, out_capsules=8,
                                         in_channels=8, out_channels=8,
                                         stride=1, kernel=1, padding=0, num_routes=3,
                                         nonlinearity='sqaush', batch_norm=True,
                                         dynamic_routing='local', cuda=True)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        routed = self.skip(x)
        batch_size, num_caps = routed.size(0), routed.size(1)
        # Fold capsules into the batch axis so AvgPool2d sees 4-D input, then unfold.
        flat = routed.view(batch_size * num_caps, routed.shape[2], routed.shape[3], routed.shape[4])
        pooled = self.avgpool(flat)
        return pooled.view(batch_size, num_caps, pooled.shape[1], int(pooled.shape[2]), int(pooled.shape[3]))

class DenseNetModel(nn.Module):
    """Capsule/DenseNet-style regressor combining an image stack with a metadata vector.

    Args:
        n_input_channels (int) - image channels fed to the stem convolution
        conv1_t_size, conv1_t_stride - accepted for interface compatibility; unused
        no_max_pool (bool) - skip the 3x3/stride-2 max-pool after the stem when True
        growth_rate (int) - forwarded to each _DenseBlock
        block_config (sequence of ints) - number of layers in each dense block; a
          _Transition (2x2 pooling) is inserted between consecutive blocks
        num_init_features (int) - stem output channels, regrouped into
          ``num_init_features / actvec_size`` capsules of length ``actvec_size``
        bn_size (int) - forwarded to each _DenseBlock
        drop_rate (float) - forwarded to each _DenseBlock
        num_classes (int) - accepted for interface compatibility; the head has 1 output
        actvec_size (int) - capsule activation-vector length
    """

    def __init__(self,
                 n_input_channels=5,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 growth_rate=32,
                 block_config=(6, 12, 24, 16),
                 num_init_features=64,
                 bn_size=32,
                 drop_rate=0,
                 num_classes=1,
                 actvec_size=8):
        # Local import: OrderedDict is not imported at the top of the notebook, so the
        # constructor previously raised NameError whenever a fresh model was built.
        from collections import OrderedDict

        super().__init__()

        self.num_init_features = num_init_features
        self.actvec_size = actvec_size
        self.num_caps = int(num_init_features/actvec_size)

        # Stem: 4x4/stride-2 conv -> BN -> ReLU (-> optional 3x3/stride-2 max-pool).
        stem_layers = [('conv1',
                        nn.Conv2d(n_input_channels,
                                  num_init_features,
                                  kernel_size=4,
                                  stride=2,
                                  padding=1,
                                  bias=False)),
                       ('norm1', nn.BatchNorm2d(num_init_features)),
                       ('relu1', nn.ReLU(inplace=True))]
        if not no_max_pool:
            stem_layers.append(
                ('pool1', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)))
        self.input_features = nn.Sequential(OrderedDict(stem_layers))

        # Dense capsule blocks with a pooling transition between consecutive blocks.
        self.features = nn.Sequential()
        num_actvec = actvec_size
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                num_caps=self.num_caps,
                                num_input_features=num_actvec,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate,
                                actvec_size=actvec_size)
            self.features.add_module('denseblock{}'.format(i + 1), block)
            num_actvec = num_actvec + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_caps = self.num_caps,
                                    in_vect=actvec_size,
                                    out_vect=actvec_size)
                self.features.add_module('transition{}'.format(i + 1), trans)
                num_actvec = num_actvec // 2

        # 17-element metadata vector -> 128 features.
        self.metadata_network = torch.nn.Sequential(
            torch.nn.Linear(17, 64),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64, 128)
        )

        # Regression head over [flattened capsule features (64) ++ metadata features (128)].
        self.classifier = nn.Linear(64 + 128, 1)

        # Single initialisation pass. A redundant earlier pass using the deprecated
        # nn.init.kaiming_normal was removed; this pass overwrote all its Conv/BN values anyway.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x, metadata):
        """Predict one value per sample.

        ``x`` is an image batch ``(batch, n_input_channels, H, W)``; ``metadata`` is
        ``(batch, 17)``. The flattened capsule features must be 64-dim to match the
        classifier (e.g. a 32x32 input with the default 121 preset, as used here).
        """
        input_features = self.input_features(x)
        # Regroup stem channels into capsules: (batch, num_caps, actvec_size, H', W').
        input_features = input_features.view(input_features.shape[0], int(self.num_init_features/self.actvec_size), self.actvec_size, input_features.shape[-2], input_features.shape[-1])

        out = self.features(input_features)
        y = self.metadata_network(metadata)
        out = self.classifier(torch.cat((out.view(out.size(0), -1), y), dim=1))
        return out

def generate_model(model_depth, **kwargs):
    """Build a DenseNetModel from a named depth preset (121, 169, 201 or 264).

    Extra keyword arguments are forwarded to DenseNetModel. Raises AssertionError for
    an unknown depth.
    """
    assert model_depth in [121, 169, 201, 264]

    # depth -> preset hyper-parameters (every preset uses 64 initial features).
    presets = {
        121: dict(growth_rate=4, block_config=(6, 12, 24, 16)),
        169: dict(growth_rate=32, block_config=(6, 12, 32, 32)),
        201: dict(growth_rate=32, block_config=(6, 12, 48, 32)),
        264: dict(growth_rate=32, block_config=(6, 12, 64, 48)),
    }
    net = DenseNetModel(num_init_features=64, **presets[model_depth], **kwargs)
    return net

class Model(nn.Module):
    """Top-level Quench network: the depth-121 DenseNetModel preset with flattened output."""

    def __init__(self):
        super().__init__()
        self.SubModel = generate_model(121)

    def forward(self, x, y):
        """Run the sub-model on images ``x`` and metadata ``y``; return a 1-D tensor."""
        return self.SubModel(x, y).flatten()

In [56]:
# Load the trained Quench correction network.
# NOTE: torch.load unpickles the whole model object (arbitrary code) -- only load trusted
# checkpoints. map_location lets a GPU-saved checkpoint load on a CPU-only machine.
quench_model = torch.load("../Checkpoints/Quench_50.pt", map_location=device).to(device, dtype=torch.float32)
quench_model.eval()

# Image input: the 5 reflectance channels (blue, green, red, NIR, SWIR1) resized to 32x32.
image_input = interpolate(img_seq[:, 0:5].to(device, dtype=torch.float32), size=32)
# Metadata input (17 values): SSEBop ET, Koppen one-hot (10), day sin/cos, x/y/z position,
# normalised elevation. Each part is cast to float32 before concatenation.
metadata_input = torch.cat([
    part.to(device, dtype=torch.float32)
    for part in (output_ET, clim, day_sin, day_cos, x_coord, y_coord, z_coord, elev)
]).unsqueeze(0)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = quench_model(image_input, metadata_input)
print(output)
print(output_ET)
print(output + output_ET)
print(model_obj.lst.reduceRegion(ee.Reducer.max(), landsat_region).getInfo())
print(model_obj.ndvi.reduceRegion(ee.Reducer.max(), landsat_region).getInfo())

tensor([-1.5094], device='cuda:0', grad_fn=<ReshapeAliasBackward0>)
tensor([5.9784], device='cuda:0')
tensor([4.4689], device='cuda:0', grad_fn=<AddBackward0>)
{'lst': 308.84028728}
{'ndvi': 0.8939027184253256}


In [57]:
# Apply the Quench correction to the downloaded SSEBop ET raster and save it.
# Read data and profile inside the context manager so the dataset handle is released.
with rio.open(date + '/OPENET_' + img_id + '.tif') as openet_data:
    openet_values = openet_data.read(1)
    openet_profile = openet_data.profile

# NOTE(review): pixels exactly equal to 0.0 are treated as "no data" and left uncorrected.
openet_mask = openet_values != 0.0
openet_values[openet_mask] = openet_values[openet_mask] + output.item()

with rio.open(date + '/ET_' + img_id + '.tif', 'w', **openet_profile) as et_dst:
    et_dst.write(openet_values, 1)


In [58]:
# quench_model = torch.load("../Checkpoints/Quench_250.pt").to(device, dtype=torch.float32)
# quench_model.eval()
# output = quench_model(interpolate(img_seq[:, 0:5].to(device, dtype=torch.float32), size=32), torch.cat((output_ET.to(device, dtype=torch.float32), clim.to(device, dtype=torch.float32), day_sin.to(device, dtype=torch.float32), day_cos.to(device, dtype=torch.float32), x_coord.to(device, dtype=torch.float32), y_coord.to(device, dtype=torch.float32), z_coord.to(device, dtype=torch.float32), elev.to(device, dtype=torch.float32))).unsqueeze(0))
# print(output)
# print(output_ET)
# print(model_obj.lst.reduceRegion(ee.Reducer.max(), landsat_region).getInfo())
# print(model_obj.ndvi.reduceRegion(ee.Reducer.max(), landsat_region).getInfo())

In [59]:
# quench_model = torch.load("../Checkpoints/Quench_500.pt").to(device, dtype=torch.float32)
# quench_model.eval()
# output = quench_model(interpolate(img_seq[:, 0:5].to(device, dtype=torch.float32), size=32), torch.cat((output_ET.to(device, dtype=torch.float32), clim.to(device, dtype=torch.float32), day_sin.to(device, dtype=torch.float32), day_cos.to(device, dtype=torch.float32), x_coord.to(device, dtype=torch.float32), y_coord.to(device, dtype=torch.float32), z_coord.to(device, dtype=torch.float32), elev.to(device, dtype=torch.float32))).unsqueeze(0))
# print(output)
# print(output_ET)
# print(model_obj.lst.reduceRegion(ee.Reducer.max(), landsat_region).getInfo())
# print(model_obj.ndvi.reduceRegion(ee.Reducer.max(), landsat_region).getInfo())