# Linear Regression Model

In [None]:
import os
import torch
import requests
import numpy as np
import pandas as pd
import xarray as xr
import xbatcher
import regionmask
from torch.utils.data import DataLoader

## Data preparation

Use this configuration dictionary to define the training dataset.

In [None]:
# Training-data configuration read by ELUC_Dataset and the training/test cells.
train_config = dict(
    # Remote zipped Zarr store, opened only when data_source_path does not exist.
    data_source_url="zip:///::https://huggingface.co/datasets/jacobbieker/project-resilience/resolve/main/merged_aggregated_dataset.zarr.zip",
    # Local Zarr store; written from data_source_url the first time it is missing.
    data_source_path="../data/gcb/raw/zarr/merged_aggregated_dataset.zarr",
    # Inclusive year range kept for training (filtered with >= / <=).
    start_year=2000,
    end_year=2020,
    # Each year here becomes the last (target) year of one evaluation window.
    test_years=[2021, 2022],
    # Data variable predicted for the final year of each window.
    target="ELUC_diff",
    batch_size=1,
    # Input years per sample; the dataset adds one more year for the target.
    n_timepoint_in_each_sample=10,
    # NOTE: the two *_in_each_batch keys are not read by the visible code
    # (the xbatcher batch_dims option that used them is commented out).
    n_timepoint_in_each_batch=5,
    n_latlon_in_each_sample=128,
    n_latlon_in_each_batch=1,
    # Number of time steps shared by consecutive windows.
    input_overlap=1,
    # Alternatives:
    #   regionmask.defined_regions.natural_earth_v5_0_0.countries_110.names  # select all
    #   ["United States of America", "France", "Canada", "United Kingdom"]  # list of countries
    countries=["United States of America"],
)

If the ELUC dataset is already on your disk, just update the configuration above (if necessary).
If not, the following steps will save you some time:

- [Download the dataset as a Zip file](https://huggingface.co/datasets/jacobbieker/project-resilience/resolve/main/merged_aggregated_dataset.zarr.zip)
- Extract the zip file to the following folder: `MVP` > `data` > `gcb` > `raw` > `zarr`

In [None]:
class ELUC_Dataset(torch.utils.data.IterableDataset):
    """Stream ``(features, target)`` windows from the merged ELUC Zarr store.

    The store is opened lazily (dask chunks), stacked into a ``latlon`` axis,
    restricted to the configured years and countries, merged into a single
    ``(time, latlon, variable)`` DataArray and cut into windows with
    ``xbatcher``. Each window spans ``n_timepoint_in_each_sample + 1`` time
    steps: all but the last are features, and the ``target`` variable at the
    last step is the label.

    Args:
        conf: configuration dict. Keys read: ``data_source_path``,
            ``data_source_url``, ``countries``, ``target``, ``start_year``,
            ``end_year``, ``n_timepoint_in_each_sample``,
            ``n_latlon_in_each_sample`` and ``input_overlap``.
    """

    def __init__(self, conf):
        # Zero-argument super() actually reaches the parent initializer;
        # ``super(ELUC_Dataset)`` alone would build an unbound super object.
        super().__init__()
        self.path = conf["data_source_path"]
        self.countries = conf["countries"]
        self.target = conf["target"]

        ds = self._open_dataset(conf)
        # Capture the data-variable names before stacking/masking adds coords.
        self.var_names = list(ds.data_vars.keys())

        ds = self._select_years_and_countries(ds, conf)
        # Merge every data variable into one DataArray along a "variable" axis.
        ds = xr.concat([ds[var] for var in ds.data_vars], dim="variable")
        ds = ds.assign_coords(variable=self.var_names)
        self.ds = ds.transpose("time", "latlon", "variable")

        self.bgen = xbatcher.BatchGenerator(
            ds=self.ds,
            # One extra time step per window: the last one provides the target.
            input_dims={"time": conf["n_timepoint_in_each_sample"] + 1,
                        "latlon": conf["n_latlon_in_each_sample"],
                        "variable": len(self.var_names)},
            input_overlap={"time": conf["input_overlap"]})
        display(self.ds)

    @staticmethod
    def _open_dataset(conf):
        """Open the local Zarr store, downloading it from the URL first if missing."""
        path = conf["data_source_path"]
        if not os.path.exists(path):
            remote = xr.open_dataset(conf["data_source_url"], engine="zarr", chunks={})
            remote.to_zarr(path, consolidated=True, compute=True)
        # Always read from the local copy so later iteration does not go back
        # to the remote URL.
        return xr.open_dataset(path, engine="zarr", chunks={})

    def _select_years_and_countries(self, ds, conf):
        """Stack lat/lon into ``latlon`` and keep only configured years and countries."""
        ds = ds.stack(latlon=("lat", "lon"))
        countries_110 = regionmask.defined_regions.natural_earth_v5_0_0.countries_110
        ds = ds.assign_coords({"country": countries_110.mask(ds)})
        df_countries = countries_110.to_dataframe()
        country_numbers = list(
            df_countries.loc[df_countries.names.isin(self.countries)].index.values)
        return ds.where((ds.time >= conf["start_year"]) &
                        (ds.time <= conf["end_year"]) &
                        (ds.country.isin(country_numbers)),
                        drop=True)

    def __iter__(self):
        """Yield ``(features, target)`` tensor pairs, skipping all-zero windows.

        Each window is read as ``(time, latlon, variable)``. This assumes no
        window dimension has size 1, because ``np.squeeze`` would drop it.
        ``features`` has shape ``(latlon, time - 1, variable)``. ``target`` has
        shape ``(latlon, 1)`` and holds the target variable at the last step.
        NaN and +/-inf values are replaced by 0.
        """
        for batch in self.bgen:
            arr = np.squeeze(batch.fillna(0.0).values)
            arr = np.nan_to_num(arr, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
            # Cast to float32 so samples match the default dtype of torch.nn layers.
            window = torch.from_numpy(arr).float()
            if torch.count_nonzero(window) == 0:  # nothing but padding/missing data
                continue
            target_idx = self.var_names.index(self.target)
            features = window[:-1, :, :].permute(1, 0, 2)  # all steps but the last
            target = window[-1:, :, target_idx].permute(1, 0)  # target var, last step
            yield features, target

Use `ELUC_Dataset` to iterate over data samples; it streams windows of data read from the Zarr store instead of loading everything into memory.

More information available:

- [Pytorch IterableDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset)
- [xbatcher: Batch Generation from Xarray Datasets](https://xbatcher.readthedocs.io/en/latest/)

Load `merged_aggregated_dataset.zarr` and iterate ELUC dataset in batches.  

In [None]:
# ELUC dataset: stream training windows. The batch size comes from the config,
# the same way the test DataLoader is built.
training_data = ELUC_Dataset(train_config)
train_dataloader = DataLoader(training_data, batch_size=train_config["batch_size"])

## Model training

Train a linear regression model to forecast the change in ELUC (`ELUC_diff`) for the next year.
The training period runs from 2000 to 2020.

In [None]:
# Linear regression model
class LinearRegressionModel(torch.nn.Module):
    """A single linear layer mapping a flattened feature row to one output.

    Args:
        n_features: length of each input row. The default (280) corresponds to
            10 time steps x 28 variables per location, the layout the
            training cell flattens samples into.
    """

    def __init__(self, n_features=280):
        super().__init__()
        self.linear = torch.nn.Linear(n_features, 1)

    def forward(self, x):
        """Return ``self.linear(x)``; ``x`` must have ``n_features`` as its last dim."""
        return self.linear(x)

In [None]:
crit = torch.nn.MSELoss()

# Seed torch (which drives nn.Linear's weight init) and NumPy *before* the
# model is built, so that weight initialisation is reproducible.
torch.manual_seed(12)
np.random.seed(12)

model_lin = LinearRegressionModel()
optim_lin = torch.optim.Adam(model_lin.parameters(), lr=0.001)

epochs = 2
log_every = 10  # print the loss of batch 1 and every log_every-th batch
model_lin.train()
for epoch in range(epochs):
    print(f"epoch: {epoch+1}")
    for i, (ts, target) in enumerate(train_dataloader, 1):
        # Flatten each location's (time, variable) window into one feature row.
        # Using -1 keeps this valid for any DataLoader batch size.
        ts = ts.reshape(-1, ts.shape[-2] * ts.shape[-1])
        target = target.reshape(-1, 1)
        optim_lin.zero_grad()               # zero the parameter gradients
        outputs_lin = model_lin(ts)         # forward
        loss_lin = crit(outputs_lin, target)
        loss_lin.backward()                 # backward
        optim_lin.step()                    # optimize
        if i == 1 or i % log_every == 0:
            print(f"batch {i} | loss: {loss_lin.item()} ")

## Model testing

Check `ELUC_diff` forecasts for the years 2021 and 2022.

In [None]:
test_config = train_config.copy()

model_lin.eval()
for year in test_config["test_years"]:
    # One window per test year: n_timepoint_in_each_sample input years
    # followed by the target year itself (both bounds are inclusive).
    test_config["end_year"] = year
    test_config["start_year"] = test_config["end_year"] - test_config["n_timepoint_in_each_sample"]
    test_data = ELUC_Dataset(test_config)
    test_dataloader = DataLoader(test_data, batch_size=test_config["batch_size"])
    test_loss = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for ts, target in test_dataloader:
            ts = ts.reshape(-1, ts.shape[-2] * ts.shape[-1])
            target = target.reshape(-1, 1)
            outputs_test = model_lin(ts)    # forward
            loss = crit(outputs_test, target)
            test_loss.append(loss.item())
    if test_loss:
        print(f"Average loss for {year}: {np.mean(test_loss)}")
    else:
        # np.mean([]) would print nan with a RuntimeWarning; say why instead.
        print(f"Average loss for {year}: no non-empty batches")

## Features importance

In [None]:
# Generate feature names in the same order as the flattened (time, variable)
# rows fed to the model: time steps oldest first and, within each step, every
# variable in var_names. The oldest step is suffixed "_{n-1}" and the most
# recent input step has no suffix (for n = 10: var_9 ... var_1, var).
n_lags = train_config["n_timepoint_in_each_sample"]
shifted_vars = [f"_{s}" for s in range(n_lags - 1, 0, -1)] + [""]

variable_names = [var + shift for shift in shifted_vars for var in test_data.var_names]

In [None]:
# Features importance dataframe: one row per flattened input feature, holding
# the weight the linear layer learned for that feature.
learned_weights = model_lin.linear.weight.detach()[0].tolist()
features_importance = pd.DataFrame(
    data={"Features": variable_names, "Importance": learned_weights},
    index=variable_names,
)

In [None]:
# Plot the features whose weights fall in the bottom 5% or the top 5%
# (quantiles 0.05 and 0.95), sorted by weight.
low_q, high_q = 0.05, 0.95
c05 = features_importance.Importance.quantile(q=low_q)
c95 = features_importance.Importance.quantile(q=high_q)
features_importance.loc[(features_importance.Importance <= c05) |
                        (features_importance.Importance >= c95)
                       ].sort_values("Importance").plot(kind='bar')