In [None]:
import xarray as xr
import numpy as np

import os
import torch
import tqdm
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

In [None]:
class TemperatureData():
    """Lazy accessor for monthly temperature fields split across two NetCDF files.

    A single global monthly ``time`` index runs through the historical file
    first and then continues into the projection file:

    * historical: January 1850 - December 2013
    * projection: January 2014 - December 2150

    Items are requested with a dict ``{'time': int, 'latitude': ..., 'longitude': ...}``.
    The values are handed to xarray's positional ``isel`` indexing.
    """

    def __init__(self, historical_path, future_path):
        self.historical = xr.open_dataset(historical_path)
        self.projection = xr.open_dataset(future_path)

    def __len__(self):
        return self.historical.time.size + self.projection.time.size

    @staticmethod
    def _month_dim(climatology):
        """Return the name of the climatology's month axis (its non-spatial dimension)."""
        # NOTE(review): the original indexed this axis as `time`. Some gridded
        # climatologies presumably name it differently (e.g. `month_number`),
        # so use whichever non-latitude/longitude dimension is present and fall
        # back to `time`. Confirm against the actual file.
        return next(
            (dim for dim in climatology.dims if dim not in ('latitude', 'longitude')),
            'time',
        )

    def __getitem__(self, idx):
        """Return anomaly, monthly climatology and land mask for one selection.

        Args:
            idx: dict with an integer ``'time'`` (global month index) plus
                ``'latitude'`` and ``'longitude'`` positional indexers.

        Returns:
            dict with keys ``'deviate_temp'``, ``'average_temp'`` and
            ``'land_mask'``, holding the selected xarray objects.

        Raises:
            IndexError: if ``idx['time']`` is outside ``[0, len(self))``.
        """
        t = idx['time']
        lat, lon = idx['latitude'], idx['longitude']
        n_hist = self.historical.time.size

        if t < 0 or t >= n_hist + self.projection.time.size:
            raise IndexError("Time out of range")

        # TODO: time format conversion is needed
        if t < n_hist:
            deviate_temp = self.historical['temperature'].isel(time=t, latitude=lat, longitude=lon)
        else:
            deviate_temp = self.projection['temperature'].isel(time=t - n_hist, latitude=lat, longitude=lon)

        # The historical record starts in January and holds whole years (it ends
        # in December 2013), so `t % 12` is the 0-based calendar month that the
        # positional `isel` needs: January -> 0 ... December -> 11.
        # The former `t % 12 + 1` read the following month's climatology and
        # pushed December to 12, past the end of a 12-month axis.
        climatology = self.historical['climatology']
        average_temp = climatology.isel(
            {self._month_dim(climatology): t % 12, 'latitude': lat, 'longitude': lon}
        )
        land_mask = self.historical['land_mask'].isel(latitude=lat, longitude=lon)

        return {
            'deviate_temp': deviate_temp,
            'average_temp': average_temp,
            'land_mask': land_mask,
        }

class GHGData():
    """Lazy accessor for monthly CO2 concentration split across two NetCDF files.

    Global monthly ``time`` index layout (``H`` = months in the historical file):

    * ``[0, H)``: historical file, January 1850 - December 2013, indexed along
      its ``Times`` / ``LatDim`` / ``LonDim`` dimensions.
    * ``[H, H + 12)``: 2014, present in neither file; filled with the same
      calendar month of 2013.
    * ``[H + 12, len(self))``: projection (SSP) file, starting January 2015,
      indexed along ``time`` / ``latitude`` / ``longitude``.
    """

    def __init__(self, historical_path, ssp_path):
        self.historical = xr.open_dataset(historical_path)
        self.projection = xr.open_dataset(ssp_path)

    def _n_historical(self):
        """Return the number of monthly steps in the historical file."""
        # The historical file is indexed along the `Times` dimension (see the
        # isel calls below). It is not known to expose a `time` attribute.
        return self.historical.sizes['Times']

    def __len__(self):
        # historical months + the 12 filled-in 2014 months + projection months
        return self._n_historical() + 12 + self.projection.sizes['time']

    def __getitem__(self, idx):
        """Return ``{'co2_ppm': ...}`` for one global month and grid selection.

        Args:
            idx: dict with an integer ``'time'`` (global month index) plus
                ``'latitude'`` and ``'longitude'`` positional indexers.

        Raises:
            IndexError: if ``idx['time']`` is outside ``[0, len(self))``.
        """
        n_hist = self._n_historical()
        t = idx['time']
        lat, lon = idx['latitude'], idx['longitude']

        if t < 0 or t >= len(self):
            raise IndexError("Time out of range")

        # TODO: time format conversion is needed
        if t < n_hist:
            ppm = self.historical['value'].isel(Times=t, LatDim=lat, LonDim=lon)
        elif t < n_hist + 12:
            # 2014 is missing: reuse the same calendar month of 2013 (12 steps
            # back) instead of always January 2013.
            ppm = self.historical['value'].isel(Times=t - 12, LatDim=lat, LonDim=lon)
        else:
            # The projection starts in 2015, i.e. 12 months after the end of
            # the historical file, so subtract (not add) the 2014 gap.
            # NOTE(review): this returns a whole Dataset, whereas the branches
            # above return the 'value' DataArray. Select the CO2 variable once
            # its name in the SSP file is confirmed.
            ppm = self.projection.isel(time=t - n_hist - 12, latitude=lat, longitude=lon)

        return {'co2_ppm': ppm}

class GlobalClimateData(Dataset):
    """PyTorch ``Dataset`` that merges a temperature sample and a CO2 sample per index.

    Args:
        temperature_path: mapping with ``'historical'`` and ``'projection'``
            file paths, forwarded to ``TemperatureData``.
        ghg_path: mapping with ``'historical'`` and ``'projection'`` file
            paths, forwarded to ``GHGData``.
    """

    def __init__(self, temperature_path, ghg_path):
        temp_hist, temp_proj = temperature_path['historical'], temperature_path['projection']
        self.temperatureData = TemperatureData(temp_hist, temp_proj)
        ghg_hist, ghg_proj = ghg_path['historical'], ghg_path['projection']
        self.ghgData = GHGData(ghg_hist, ghg_proj)

    def __len__(self):
        # The dataset length follows the temperature record only.
        n_samples = len(self.temperatureData)
        return n_samples

    def __getitem__(self, idx):
        """Return one merged sample dict; on a key clash the GHG entry wins."""
        merged = {}
        merged.update(self.temperatureData[idx])
        merged.update(self.ghgData[idx])
        return merged

In [None]:
import torch
import torch.nn as nn

class ClimateLSTM(nn.Module):
    """Stacked LSTM that maps an input sequence to one output vector.

    The LSTM is built with ``batch_first=True`` and ``forward`` indexes
    ``out[:, -1, :]``, so inputs are expected as ``(batch, seq_len, input_dim)``.
    Only the final time step's hidden output goes through the linear head.

    Args:
        input_dim: number of features per time step.
        hidden_dim: size of each LSTM layer's hidden state.
        num_layers: number of stacked LSTM layers.
        output_dim: size of the prediction produced per sequence.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return predictions of shape ``(batch, output_dim)`` for ``x``."""
        # If (h0, c0) is omitted, nn.LSTM starts from zero states created with
        # x's dtype and device. The hand-built float32 zeros broke models and
        # inputs in other dtypes (e.g. after `.double()`), and detaching
        # freshly created zeros was a no-op.
        out, _ = self.lstm(x)
        return self.fc(out[:, -1, :])  # only the last time step feeds the head


In [None]:
# Example training loop
def train_model(model, data_loader, criterion, optimizer, num_epochs):
    """Train ``model`` in place for ``num_epochs`` passes over ``data_loader``.

    Args:
        model: the ``nn.Module`` to optimise; it is put into training mode here.
        data_loader: iterable yielding ``(inputs, labels)`` batches. Tensor
            batches are moved to the device of the model's parameters.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of full passes over ``data_loader``.

    Returns:
        list[float]: the mean batch loss of each epoch (also printed).

    Raises:
        ValueError: if ``data_loader`` yields no batches in an epoch.
    """
    # This file does `import tqdm`, which binds the *module*; calling it raised
    # "TypeError: 'module' object is not callable". Import the callable here.
    from tqdm.auto import tqdm

    # Batches usually come off the loader on the CPU; send them to the model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    def _to_device(value):
        # Only tensors are moved; anything else is passed through unchanged.
        if target_device is not None and torch.is_tensor(value):
            return value.to(target_device)
        return value

    model.train()
    epoch_losses = []
    for epoch in tqdm(range(num_epochs)):
        running_loss, n_batches = 0.0, 0
        for inputs, labels in data_loader:
            inputs, labels = _to_device(inputs), _to_device(labels)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            # Previously this surfaced as an UnboundLocalError on `loss`.
            raise ValueError("data_loader yielded no batches")
        epoch_loss = running_loss / n_batches
        epoch_losses.append(epoch_loss)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}')
    return epoch_losses

# Choose the compute backend: CUDA first, then Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# Model hyper-parameters
input_dim = 2    # Example: temperature and GHG concentration
hidden_dim = 64
num_layers = 2
output_dim = 1   # Example: predicting temperature

model = ClimateLSTM(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    output_dim=output_dim,
).to(device)


In [None]:
# Input NetCDF files, keyed the way GlobalClimateData reads them
# ('historical' / 'projection').
# NOTE(review): `new_file_path` is not defined in any cell shown here, so this
# cell raises NameError on a fresh kernel. Define it before running.
temp_path = dict(
    historical='../../data/raw/globalTemperature/Land_and_Ocean_LatLong1.nc',
    projection=new_file_path,
)

ghg_path = dict(
    historical='../../data/raw/globalGhgEmissions/CO2_1deg_month_1850-2013.nc',
    projection='../../data/raw/globalGhgEmissions/CO2_SSP119_2015_2150.nc',
)

In [None]:
def predict_future(model, initial_data, steps, update_fn=None):
    """Run ``model`` for ``steps`` steps, feeding it the latest input each time.

    Each step converts ``input_data[-1]`` to a tensor, adds a leading batch
    axis of size one, and runs the model on it. ``input_data[-1]`` is
    presumably the most recent input sequence, e.g. ``(seq_len, input_dim)``
    for ``ClimateLSTM`` (TODO confirm).

    Args:
        model: trained ``nn.Module``; it is switched to eval mode here.
        initial_data: indexable sequence whose last element is the first input.
        steps: number of predictions to make.
        update_fn: optional ``update_fn(input_data, prediction) -> input_data``
            used to roll each new prediction (a NumPy array) into the input.
            When omitted, the input is never updated, so every step repeats
            the same prediction; this was the original behaviour.

    Returns:
        list of NumPy arrays, one model output per step.
    """
    model.eval()

    # Match the model's parameters rather than the global `device`. Data from
    # NumPy/xarray is often float64, which `torch.tensor` kept and a float32
    # model rejects.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device, target_dtype = first_param.device, first_param.dtype
    else:
        target_device, target_dtype = torch.device("cpu"), torch.get_default_dtype()

    predictions = []
    input_data = initial_data
    with torch.no_grad():
        for _ in range(steps):
            input_tensor = torch.as_tensor(
                input_data[-1], dtype=target_dtype, device=target_device
            ).unsqueeze(0)
            prediction = model(input_tensor).cpu().numpy()
            predictions.append(prediction)
            if update_fn is not None:
                input_data = update_fn(input_data, prediction)

    return predictions
