In [1]:
# Imports for data handling:
import pandas as pd
import numpy as np

from torch.utils.data import Dataset

from data import grab_dataset

# Imports for model training:
import torch
from torch import nn

In [2]:
class EarthSystemsDataset(Dataset):
    '''
    pyTorch Dataset to supply a neural network with time series data
    '''

    def __init__(self, data_var_names, timeframe='monthly'):
        self.data_var_names = data_var_names
        raw_datasets = [grab_dataset(var_name, timeframe=timeframe) for var_name in data_var_names]
        self.data = EarthSystemsDataset.trim_data(raw_datasets) 


    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        return self.data.iloc[index]

    @staticmethod
    def trim_data(all_data):
        '''
        This function trims the time series data so that they start and end on the same date.

        :all_data (list-like of pd.DataFrame): List of DataFrames to trim

        :return: unified DataFrame of all the trimmed data
        '''

        trimmed_data = []
        for df in all_data:
            trimmed_data.append(
                df.set_index(['year', 'month'])
            )

        return trimmed_data[0].join(trimmed_data[1:], how='inner')
            

In [3]:
class EarthSystemsNN(nn.Module):
    def __init__(self, in_size, sequence):
        self.network = sequence

    def forward(self, x):
        return self.network(x)

In [4]:
data_var_names = ['global_temp', 'electricity', 'co2', 'ch4']
d = EarthSystemsDataset(data_var_names)

In [5]:
d.data

Unnamed: 0_level_0,Unnamed: 1_level_0,temp_change,elec_generation,co2_average,ch4_average
year,month,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1984,Jan,0.31,216.924,344.32,1638.76
1985,Jan,0.22,228.148,345.35,1655.58
1986,Jan,0.26,217.761,347.11,1666.27
1987,Jan,0.32,223.041,348.02,1679.37
1988,Jan,0.57,238.188,350.91,1692.08
...,...,...,...,...,...
2018,Dec,0.91,342.292,409.19,1866.04
2019,Dec,1.09,338.536,411.76,1874.69
2020,Dec,0.80,344.523,414.14,1891.80
2021,Dec,0.86,337.104,416.60,1908.78
