# PyTorch Dataloader for the International Tree-Ring Data Bank (ITRDB)

This dataloader wraps our parsed ITRDB data in a PyTorch Dataset for use with PyTorch neural networks. The parsed data is publicly hosted in an AWS S3 bucket and is retrieved with the Python requests library. The Dataset caches the resulting dataframe, so the request only needs to be made once per session; this lets you create multiple Datasets (train, test, and validate) with little to no wait time. For simplicity, this Dataset also limits the tree-ring widths to the years 1900-2023 and drops any series that has no measurements in that period (the row is all NaN between 1900 and 2023).
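The year filter described above can be sketched on a small synthetic frame. This is only an illustration, not the notebook's actual code: the column names (`site_id` and year labels stored as strings) and the helper `limit_to_years` are assumptions, since the layout of the hosted CSV is not shown here.

```python
import numpy as np
import pandas as pd

def limit_to_years(df, first_year, last_year, id_col="site_id"):
    """Keep id_col plus the year columns in [first_year, last_year],
    then drop rows that have no measurement at all in that window."""
    year_cols = [c for c in df.columns
                 if c != id_col and first_year <= int(c) <= last_year]
    out = df[[id_col] + year_cols]
    # how="all" removes a row only when every year in the window is NaN.
    return out.dropna(subset=year_cols, how="all")

# Hypothetical toy data: three series, one pre-1900 column, two in-window columns.
toy = pd.DataFrame({
    "site_id": ["a", "b", "c"],
    "1899": [0.5, 0.4, 0.3],
    "1900": [1.2, np.nan, np.nan],
    "1901": [np.nan, np.nan, 0.9],
})

filtered = limit_to_years(toy, 1900, 2023)
print(list(filtered.columns))     # ['site_id', '1900', '1901']
print(list(filtered["site_id"]))  # ['a', 'c']
```

Passing `how="any"` to `dropna` instead would drop a row as soon as a single year in the window is missing, which is a much stricter filter.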

Import the necessary dependencies:

In [82]:
import torch
import requests
import pandas as pd
from io import StringIO
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

The Dataset class itself: we retrieve the data from S3, cache it, and define a split function that produces a 70-15-15 split of the data, keeping only the subset requested through the constructor's set_type parameter:

In [133]:
class TreeRingDataset(Dataset):

  _cache = None
  _train = None
  _test = None
  _validate = None

  def __init__(self, set_type="train"):

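    # Download the CSV only once; later instances reuse the class-level cache.
    if TreeRingDataset._cache is None:
      res = requests.get("https://paleo-data.s3.amazonaws.com/data.csv")
      res.raise_for_status()  # fail early on an HTTP error instead of parsing an error page
      TreeRingDataset._cache = pd.read_csv(StringIO(res.text), sep=",")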

    self.df = TreeRingDataset._cache.copy()

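    # Drop the columns at positions 1-1939, keeping column 0 and everything after.
    self.df.drop(self.df.columns[list(range(1, 1940))], axis=1, inplace=True)
    # how='any' drops a row if any of the 51 feature columns (positions 1-51) is NaN.
    self.df.dropna(subset=self.df.columns[1:52], how='any', inplace=True)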

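    # Split once and share the result, so every instance sees the same partition.
    if TreeRingDataset._train is None:
      print("Performing a Split")
      TreeRingDataset._train, TreeRingDataset._test, TreeRingDataset._validate = self.__split()

    # Keep only the requested subset; any value other than "train" or "test" selects validation.
    if set_type == "train":
      self.df = TreeRingDataset._train
    elif set_type == "test":
      self.df = TreeRingDataset._test
    else:
      self.df = TreeRingDataset._validate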


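  def __split(self):
    # 70% train, then halve the remaining 30%: 15% validate and 15% test.
    # No random_state is set, so the partition changes between sessions.
    train, test = train_test_split(self.df, train_size=.70)
    validate, test = train_test_split(test, train_size=.5)
    return (train, test, validate)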

  def __len__(self):
    return self.df.shape[0]

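  def __getitem__(self, index):
    # Features: the 51 columns at positions 1-51, converted explicitly to float64.
    x = torch.tensor(self.df.iloc[index, 1:52].to_numpy(dtype="float64"))
    # Label: the midpoint of the four coordinate columns, as (latitude, longitude).
    latn, lats, lone, lonw = self.df.iloc[index, 87:91]
    label = torch.tensor([(latn+lats)/2, (lone+lonw)/2])

    return x, label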
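The 70-15-15 split can be checked on synthetic data. The sketch below mirrors the two-stage use of `train_test_split`; the `random_state` argument and the helper name `split_70_15_15` are additions for illustration, chosen so that the partition is reproducible between runs.

```python
from sklearn.model_selection import train_test_split

def split_70_15_15(rows, seed=0):
    """Return (train, validate, test) holding roughly 70/15/15 percent of rows."""
    # First stage: 70% for training, 30% held out.
    train, rest = train_test_split(rows, train_size=0.70, random_state=seed)
    # Second stage: halve the held-out part into validation and test.
    validate, test = train_test_split(rest, train_size=0.50, random_state=seed)
    return train, validate, test

rows = list(range(100))  # stand-in for the dataframe rows
train, validate, test = split_70_15_15(rows)
print(len(train), len(validate), len(test))
```

Fixing the seed makes experiments comparable across sessions, at the cost of always evaluating on the same held-out rows.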

The class can then be instantiated once each for the train, test, and validate sets. Only one request is sent in total, so this finishes fairly quickly:

In [134]:
train_data = TreeRingDataset(
    set_type="train"
)

test_data = TreeRingDataset(
    set_type="test"
)

validate_data = TreeRingDataset(
    set_type="validate"
)

Performing a Split
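
Only one request is sent because the dataframe is stored on the class rather than on each instance. A stripped-down sketch of that pattern follows; the class `CachedSource` and the function `fake_download` are hypothetical stand-ins for the dataset class and the real HTTP call.

```python
download_calls = 0

def fake_download():
    """Hypothetical stand-in for the HTTP request; counts how often it runs."""
    global download_calls
    download_calls += 1
    return [0.62, 0.71, 0.46]

class CachedSource:
    _cache = None  # shared by every instance of the class

    def __init__(self):
        # Only the first instance triggers the download.
        if CachedSource._cache is None:
            CachedSource._cache = fake_download()
        # Each instance works on its own copy of the cached data.
        self.data = list(CachedSource._cache)

first = CachedSource()
second = CachedSource()
third = CachedSource()
print(download_calls)  # 1
```

Because each instance copies the cached data, changing `first.data` does not affect `second.data` or the cache itself.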


Finally, we can wrap each set in a DataLoader:

In [135]:
train_dataloader = DataLoader(train_data, batch_size=32)
test_dataloader = DataLoader(test_data, batch_size=32)
validate_dataloader = DataLoader(validate_data, batch_size=32)
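
For training it is common to reshuffle the samples every epoch, which `DataLoader` supports through `shuffle=True`. Below is a small sketch on random tensors shaped like this dataset's items (51 features and 2 label values); the synthetic `TensorDataset` is a stand-in and not part of the notebook.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 synthetic samples: 51 ring-width features and a (lat, lon) label each.
features = torch.rand(100, 51, dtype=torch.float64)
labels = torch.rand(100, 2)
toy_data = TensorDataset(features, labels)

# shuffle=True draws the samples in a new random order on every pass.
loader = DataLoader(toy_data, batch_size=32, shuffle=True)
batch_shapes = [(x.shape, y.shape) for x, y in loader]

print(len(loader))       # 4 batches: 32 + 32 + 32 + 4 samples
print(batch_shapes[0])   # shapes of the first batch
print(batch_shapes[-1])  # the last batch is smaller
```

If the smaller final batch is a problem for a model, `drop_last=True` discards it.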

As an example of what the data looks like, we can display the first batch:

In [136]:
train_features, train_labels = next(iter(train_dataloader))
print(train_features)
print(train_labels)

tensor([[2.0700, 1.8200, 1.9400,  ..., 1.5900, 0.8500, 1.3300],
        [0.6320, 1.0950, 1.0200,  ..., 0.5740, 0.6440, 0.5190],
        [0.6200, 0.7100, 0.4600,  ..., 0.4900, 0.5500, 0.5000],
        ...,
        [0.2700, 0.2500, 0.2400,  ..., 0.4400, 0.3200, 0.3200],
        [0.4000, 0.3400, 0.4400,  ..., 0.2300, 0.0900, 0.3700],
        [0.6030, 0.6760, 0.3060,  ..., 0.2580, 0.4550, 0.3760]],
       dtype=torch.float64)
tensor([[  36.6000, -118.7000],
        [  38.3800, -108.0200],
        [  41.8700, -110.8000],
        [  40.1417, -111.3333],
        [  40.0500, -108.3000],
        [  48.6800, -120.6300],
        [  24.7140,  -81.3850],
        [  39.8300, -108.2000],
        [  37.8300, -119.2200],
        [  37.2000, -112.8000],
        [  48.6800, -120.6300],
        [  43.3042, -110.6711],
        [  58.4410, -135.6090],
        [  37.6596, -112.8560],
        [  34.9738,  -77.1201],
        [  44.6000, -110.4000],
        [  44.9140, -109.5730],
        [  34.8370, -119.0480]