In [1]:
import xarray as xr
import numpy as np
import torch
from torchvision import transforms
import os
from p_drought_indices.functions.function_clns import load_config

CONFIG_PATH= "../config.yaml"

config_file = load_config(CONFIG_PATH=CONFIG_PATH)

# Open the NetCDF file with xarray
dataset = xr.open_dataset(os.path.join(config_file['NDVI']['ndvi_path'], 'smoothed_ndvi_1.nc'))

time_end = "2008-12-31"
time_start = "2008-06-01"

dim=64
ds = dataset.sel(time=slice(time_start,time_end)).isel(lat=slice(0,dim), lon=slice(0,dim))

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


In [2]:
config_directories = [config_file['SPI']['IMERG']['path'], config_file['SPI']['GPCC']['path'], config_file['SPI']['CHIRPS']['path'], config_file['SPI']['ERA5']['path'], config_file['SPI']['MSWEP']['path'] ]
config_dir_precp = [config_file['PRECIP']['IMERG']['path'],config_file['PRECIP']['CHIRPS_05']['path'], config_file['PRECIP']['GPCC']['path'], config_file['PRECIP']['CHIRPS']['path'], config_file['PRECIP']['ERA5']['path'],  config_file['PRECIP']['TAMSTAT']['path'],config_file['PRECIP']['MSWEP']['path']]
        
prod = "CHIRPS"
late = 60
product_dir = [f for f in config_dir_precp if prod in f][0]
list_files = [f for f in os.listdir(product_dir) if (f.endswith(".nc")) and ("merged" in f)]
precp_ds = xr.open_dataset(os.path.join(product_dir, list_files[0]))
variable = [var for var in precp_ds.data_vars if var!= "spatial_ref"][0]

sub_precp= precp_ds.sel(time=slice(time_start,time_end)).isel(lat=slice(0,dim), lon=slice(0,dim))

In [3]:
from p_drought_indices.functions.function_clns import prepare

ds = prepare(ds)
sub_precp = prepare(sub_precp)

veg_repr = ds["ndvi"].rio.reproject_match(sub_precp[variable]).rename({'x':'lon','y':'lat'})

### converting null values to -99
sub_veg = veg_repr.where(veg_repr.notnull(), -99)
sub_precp = sub_precp.assign(null_precp = sub_precp[variable].where(sub_precp[variable].notnull(), -99))

# Read the data as a numpy array
target = sub_veg.transpose("lat","lon","time").values
data = sub_precp["null_precp"].transpose("lat","lon","time").values

target = np.array(target)
data = np.array(data)

In [4]:
n_samples = data.shape[-1]

split = 0.8
train_samples = int(round(split*n_samples, 0))

In [5]:
import numpy as np

def add_channel(data, n_samples):

    # define the desired size of the time steps and number of channels 
    # ##output: (num_samples, num_frames, num_channels, height, width)
    n_timesteps = n_samples
    n_channels = 1

    # determine the number of samples based on the desired number of time steps
    n_samples = data.shape[-1] // n_timesteps

    # reshape the input data into a 4D tensor
    input_data = np.reshape(data, (data.shape[0], data.shape[1], n_timesteps, n_samples))

    # add an extra dimension for the channels
    input_data = np.reshape(input_data, (n_samples, n_timesteps,n_channels, input_data.shape[0], input_data.shape[1]))

    # check the shape of the input data
    print(input_data.shape) # should print (n_samples, n_timesteps, lat, lon, n_channels)
    return input_data


In [6]:
input_data = add_channel(data, n_samples)
target_data = add_channel(target, n_samples)
train_data = input_data[:,:train_samples,:,:]
test_data =  input_data[:,train_samples:,:,:]
train_label = target_data[:,:train_samples,:,:]
test_label =  target_data[:,train_samples:,:,:]

(1, 214, 1, 64, 64)
(1, 214, 1, 64, 64)


In [7]:
import torch
from torch.utils.data import Dataset

batch_size=4

class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        x = self.data[index]
        y = self.targets[index]
        return x, y
    
from torch.utils.data import DataLoader

# create a CustomDataset object using the reshaped input data
train_dataset = CustomDataset(train_data, train_label)
test_dataset = CustomDataset(test_data, test_label)

# create a DataLoader object that uses the dataset
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [8]:
for batch_idx, (inputs, targets) in enumerate(train_dataloader):
    inputs = inputs.float()
    targets = targets.float()
    print(inputs.shape, targets.shape, inputs.max(), inputs.min())


for batch_idx, (inputs, targets) in enumerate(test_dataloader):
    inputs = inputs.float()
    targets = targets.float()
    print(inputs.shape, targets.shape, inputs.max(), inputs.min())

torch.Size([1, 171, 1, 64, 64]) torch.Size([1, 171, 1, 64, 64]) tensor(24.5388) tensor(-99.)
torch.Size([1, 43, 1, 64, 64]) torch.Size([1, 43, 1, 64, 64]) tensor(51.9347) tensor(-99.)


In [9]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn

class ConvLSTMBlock(nn.Module):
    def __init__(self, in_channels, num_features, kernel_size=3, padding=1, stride=1):
        super().__init__()
        self.num_features = num_features
        self.conv = self._make_layer(in_channels+num_features, num_features*4,
                                       kernel_size, padding, stride)

    def _make_layer(self, in_channels, out_channels, kernel_size, padding, stride):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size, padding=padding, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels))

    def forward(self, inputs):
        '''
        :param inputs: (B, S, C, H, W)
        :param hidden_state: (hx: (B, S, C, H, W), cx: (B, S, C, H, W))
        :return:
        '''
        outputs = []
        B, S, C, H, W = inputs.shape
        hx = torch.zeros(B, self.num_features, H, W).to(inputs.device)
        cx = torch.zeros(B, self.num_features, H, W).to(inputs.device)
        for t in range(S):
            combined = torch.cat([inputs[:, t], # (B, C, H, W)
                                  hx], dim=1)
            gates = self.conv(combined)
            ingate, forgetgate, cellgate, outgate = torch.split(gates, self.num_features, dim=1)
            ingate = torch.sigmoid(ingate)
            forgetgate = torch.sigmoid(forgetgate)
            outgate = torch.sigmoid(outgate)

            cy = (forgetgate * cx) + (ingate * cellgate)
            hy = outgate * torch.tanh(cy)
            outputs.append(hy)
            hx = hy
            cx = cy

        return torch.stack(outputs).permute(1, 0, 2, 3, 4).contiguous() # (S, B, C, H, W) -> (B, S, C, H, W)

class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = []
        for idx, params in enumerate(config.encoder):
            setattr(self, params[0]+'_'+str(idx), self._make_layer(*params))
            self.layers.append(params[0]+'_'+str(idx))

    def _make_layer(self, type, activation, in_ch, out_ch, kernel_size, padding, stride):
        layers = []
        if type == 'conv':
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, stride=stride, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            if activation == 'leaky': layers.append(nn.LeakyReLU(inplace=True))
            elif activation == 'relu': layers.append(nn.ReLU(inplace=True))
        elif type == 'convlstm':
            layers.append(ConvLSTMBlock(in_ch, out_ch, kernel_size=kernel_size, padding=padding, stride=stride))
        return nn.Sequential(*layers)

    def forward(self, x):
        '''
        :param x: (B, S, C, H, W)
        :return:
        '''
        outputs = [x]
        for layer in self.layers:
            if 'conv_' in layer:
                B, S, C, H, W = x.shape
                x = x.view(B*S, C, H, W)
            x = getattr(self, layer)(x)
            if 'conv_' in layer: x = x.view(B, S, x.shape[1], x.shape[2], x.shape[3])
            if 'convlstm' in layer: outputs.append(x)
        return outputs

class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = []
        for idx, params in enumerate(config.decoder):
            setattr(self, params[0]+'_'+str(idx), self._make_layer(*params))
            self.layers.append(params[0]+'_'+str(idx))

    def _make_layer(self, type, activation, in_ch, out_ch, kernel_size, padding, stride):
        layers = []
        if type == 'conv':
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, stride=stride, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            if activation == 'leaky': layers.append(nn.LeakyReLU(inplace=True))
            elif activation == 'relu': layers.append(nn.ReLU(inplace=True))
            elif activation == 'sigmoid': layers.append(nn.Sigmoid())
        elif type == 'convlstm':
            layers.append(ConvLSTMBlock(in_ch, out_ch, kernel_size=kernel_size, padding=padding, stride=stride))
        elif type == 'deconv':
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, stride=stride, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            if activation == 'leaky': layers.append(nn.LeakyReLU(inplace=True))
            elif activation == 'relu': layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, encoder_outputs):
        '''
        :param x: (B, S, C, H, W)
        :return:
        '''
        idx = len(encoder_outputs)-1
        for layer in self.layers:
            if 'conv_' in layer or 'deconv_' in layer:
                x = encoder_outputs[idx]
                B, S, C, H, W = x.shape
                x = x.view(B*S, C, H, W)
                x = getattr(self, layer)(x)
                x = x.view(B, S, x.shape[1], x.shape[2], x.shape[3])
            elif 'convlstm' in layer:
                idx -= 1
                x = torch.cat([encoder_outputs[idx], x], dim=2)
                x = getattr(self, layer)(x)
                encoder_outputs[idx] = x
        return x

class ConvLSTM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [10]:
def train(config, logger, epoch, model, train_loader, criterion, optimizer):
    model.train()
    epoch_records = {'loss': []}
    num_batchs = len(train_loader)
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        print(inputs.max())
        inputs = inputs.float().to(config.device)
        targets = targets.float().to(config.device)
        outputs = model(inputs)
        losses = criterion(outputs, targets)
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        epoch_records['loss'].append(losses.item())
        if batch_idx and batch_idx % config.display == 0:
            logger.info('EP:{:03d}\tBI:{:05d}/{:05d}\tLoss:{:.6f}({:.6f})'.format(epoch, batch_idx, num_batchs,
                                                                                epoch_records['loss'][-1], np.mean(epoch_records['loss'])))
    return epoch_records

def valid(config, logger, epoch, model, valid_loader, criterion):
    model.eval()
    epoch_records = {'loss': []}
    num_batchs = len(valid_loader)
    for batch_idx, (inputs, targets) in enumerate(valid_loader):
        with torch.no_grad():
            inputs = inputs.float().to(config.device)
            targets = targets.float().to(config.device)
            outputs = model(inputs)
            losses = criterion(outputs, targets)
            epoch_records['loss'].append(losses.item())
            if batch_idx and batch_idx % config.display == 0:
                logger.info('[V] EP:{:03d}\tBI:{:05d}/{:05d}\tLoss:{:.6f}({:.6f})'.format(epoch, batch_idx, num_batchs,
                                                                                    epoch_records['loss'][-1], np.mean(epoch_records['loss'])))
    return epoch_records

import logging
import os
import time

def build_logging(config):
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
                        datefmt='%m-%d %H:%M',
                        filename=os.path.join(config.log_dir, time.strftime("%Y%d%m_%H%M") + '.log'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
    return logging

In [13]:
from p_drought_indices.configs.config_3x3_16_3x3_32_3x3_64 import config
from torch.nn import MSELoss
import matplotlib.pyplot as plt

name = '3x3_16_3x3_32_3x3_64'

logger = build_logging(config)
model = ConvLSTM(config).to(config.device)
#criterion = CrossEntropyLoss().to(config.device)
#criterion = torch.nn.MSELoss().to(config.device)
criterion = MSELoss().to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_records, valid_records, test_records = [], [], []
for epoch in range(config.epochs):
    epoch_records = train(config, logger, epoch, model, train_dataloader, criterion, optimizer)
    train_records.append(np.mean(epoch_records['loss']))
    epoch_records = valid(config, logger, epoch, model, test_dataloader, criterion)
    valid_records.append(np.mean(epoch_records['loss']))
    plt.plot(range(epoch + 1), train_records, label='train')
    plt.plot(range(epoch + 1), valid_records, label='valid')
    plt.legend()
    plt.savefig(os.path.join(config.output_dir, '{}.png'.format(name)))
    plt.close()

tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(24.5388)
tensor(2