In [1]:
import xarray as xr
import numpy as np

import os
import torch
import math
import tqdm

from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

In [2]:
class ClimateData(Dataset):
    """Map-style dataset yielding one latitude/longitude cell per time step.

    Sample ``idx`` is decoded time-major: all ``len(lat_bins) * len(lon_bins)``
    cells of time step 0 come first, then those of time step 1, and so on.

    Each sample is a dict of float tensors with keys ``temperature``,
    ``climatology``, ``latitude_info``, ``land_mask``, ``co2``, ``month_sin``
    and ``month_cos``. A sample whose temperature patch contains any NaN is
    returned as ``None``; PyTorch's default collate function cannot batch
    ``None``, so a ``DataLoader`` over this dataset needs a ``collate_fn``
    that filters those entries out.

    NOTE(review): ``grid_size`` is given in degrees but is also used as the
    number of array cells per patch, which presumes both files are on a
    1-degree grid whose first row/column sits at -90 / -180 -- confirm
    against the actual NetCDF files (in particular the CO2 file's ordering).
    """

    def __init__(self, temperature_path, co2_path, grid_size=(5, 5)):
        """Open both NetCDF files and precompute the cell layout.

        Args:
            temperature_path: Path of the temperature dataset; it must expose
                ``temperature``, ``land_mask`` and ``climatology`` variables.
            co2_path: Path of the CO2 dataset; it must expose a ``value``
                variable indexed by ``Times``, ``LatDim`` and ``LonDim``.
            grid_size: ``(lat_degrees, lon_degrees)`` extent of one cell.
        """
        self.temperature = xr.open_dataset(temperature_path)
        self.co2 = xr.open_dataset(co2_path)
        self.grid_size = grid_size
        self.lat_bins = np.arange(-90, 90, self.grid_size[0])
        self.lon_bins = np.arange(-180, 180, self.grid_size[1])

        # Reindex time for both datasets to be zero-based
        self.temperature = self.temperature.assign_coords(time=np.arange(len(self.temperature.time)))
        self.co2 = self.co2.assign_coords(Times=np.arange(len(self.co2.Times)))

        # Only time steps present in both datasets are exposed through __len__.
        self.max_time_index = min(len(self.temperature.time), len(self.co2.Times))

    def __len__(self):
        """Return the number of (time step, cell) samples."""
        return self.max_time_index * (len(self.lat_bins) * len(self.lon_bins))

    def __getitem__(self, idx):
        """Return the feature dict for sample ``idx`` (``None`` if it has NaNs).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        cells_per_step = len(self.lat_bins) * len(self.lon_bins)
        total = self.max_time_index * cells_per_step

        # Normalise negative indices up front so the time/month decoding below
        # always works on a non-negative time index.
        idx = int(idx)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index {idx} is out of range for {total} samples")

        time_index, spatial_index = divmod(idx, cells_per_step)
        # presumably time step 0 is a January -- TODO confirm against the file
        month_index = time_index % 12  # Month index from 0 to 11
        lat_index, lon_index = divmod(spatial_index, len(self.lon_bins))

        lat_start = self.lat_bins[lat_index]

        # ``isel`` selects by integer position, not by coordinate value.
        # Slicing with the degree values themselves (e.g. slice(-90, -85))
        # would be read as negative positions and wrap around to the wrong
        # rows, or select nothing at all for slice(-5, 0). Convert the cell
        # number into an array offset instead.
        lat_offset = int(lat_index) * int(self.grid_size[0])
        lon_offset = int(lon_index) * int(self.grid_size[1])
        lat_slice = slice(lat_offset, lat_offset + int(self.grid_size[0]))
        lon_slice = slice(lon_offset, lon_offset + int(self.grid_size[1]))

        temp_data = self.temperature.isel(time=time_index, latitude=lat_slice, longitude=lon_slice)
        if np.any(np.isnan(temp_data['temperature'].values)):
            return None  # Skip null data

        # Handle missing CO2 data by using the last available data
        if time_index < len(self.co2.Times):
            co2_data = self.co2.isel(Times=time_index, LatDim=lat_slice, LonDim=lon_slice)
        else:
            co2_data = self.co2.isel(Times=-1, LatDim=lat_slice, LonDim=lon_slice)

        # Positional selection again: ``month_index`` already runs 0..11, so
        # adding 1 would return the following month's climatology and step
        # past the end of a 12-entry axis every December.
        clim_data = self.temperature['climatology'].isel(
            month_number=month_index, latitude=lat_slice, longitude=lon_slice
        )

        # Cyclic representation of month for seasonal effect
        month_angle = 2 * np.pi * month_index / 12
        month_sin = np.sin(month_angle)
        month_cos = np.cos(month_angle)

        # Combine temperature, climatology, latitude info, land mask, CO2, and cyclical month into a single state dictionary
        state = {
            'temperature': torch.tensor(temp_data['temperature'].values, dtype=torch.float),
            'climatology': torch.tensor(clim_data.values, dtype=torch.float),
            # Latitude of the cell centre, in degrees.
            'latitude_info': torch.tensor(lat_start + self.grid_size[0] / 2, dtype=torch.float),
            'land_mask': torch.tensor(temp_data['land_mask'].values, dtype=torch.float),
            'co2': torch.tensor(co2_data['value'].values, dtype=torch.float),
            'month_sin': torch.tensor(month_sin, dtype=torch.float),
            'month_cos': torch.tensor(month_cos, dtype=torch.float)
        }
        return state


In [3]:
class ClimateLSTM(nn.Module):
    """Stacked LSTM followed by a linear head applied to the final time step.

    Input is batch-first, ``(batch, seq_len, input_dim)``; the result has
    shape ``(batch, output_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        """Create the recurrent stack and the linear read-out layer."""
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Run the sequence through the LSTM and project its last hidden output."""
        hidden_seq, _state = self.lstm(x)
        # Only the output at the final sequence position feeds the head.
        final_step = hidden_seq[:, -1, :]
        return self.fc(final_step)


In [4]:
import os
import tqdm


# Batch keys in the order they become input channels (one channel per key).
_FEATURE_KEYS = (
    "temperature",
    "climatology",
    "latitude_info",
    "land_mask",
    "co2",
    "month_sin",
    "month_cos",
)


def _assemble_features(batch, device):
    """Turn one collated batch dict into ``(features, cell_temps)`` on ``device``.

    The ``temperature`` entry defines the per-sample grid shape. Entries with
    fewer dimensions (per-sample scalars such as ``latitude_info``) are
    broadcast over that grid, every entry becomes one channel, and the grid
    is flattened into the sequence axis, giving ``features`` of shape
    ``(batch, cells, len(_FEATURE_KEYS))`` and ``cell_temps`` of shape
    ``(batch, cells)``.
    """
    grid = batch["temperature"]
    channels = []
    for key in _FEATURE_KEYS:
        value = batch[key]
        # Append singleton axes so per-sample scalars line up with the grid.
        while value.dim() < grid.dim():
            value = value.unsqueeze(-1)
        # expand_as fails loudly if a field's grid does not match temperature's.
        channels.append(value.expand_as(grid))

    batch_size = grid.shape[0]
    features = torch.stack(channels, dim=-1).reshape(batch_size, -1, len(channels))
    cell_temps = grid.reshape(batch_size, -1)
    return features.to(device), cell_temps.to(device)


def train_model(model, data_loader, criterion, optimizer, num_epochs, checkpoint_path="checkpoints"):
    """Train ``model`` and write a checkpoint after every epoch.

    Args:
        model: Module taking ``(batch, seq, 7)`` input and returning
            ``(batch, output_dim)``.
        data_loader: Iterable of batch dicts holding the keys listed in
            ``_FEATURE_KEYS``; ``None`` entries are skipped.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``data_loader``.
        checkpoint_path: Directory for ``model_epoch_<n>.pth`` and
            ``final_model.pth``; created if missing.

    The label is the temperature grid itself when the model emits one value
    per cell, otherwise the mean temperature of the cell.
    NOTE(review): temperature is both an input channel and the label, as in
    the original feature list -- the target should probably come from a later
    time step; confirm the intended prediction task.
    """
    model.train()
    os.makedirs(checkpoint_path, exist_ok=True)  # Create checkpoint directory if it doesn't exist

    # Batches come off the loader on the CPU; send them wherever the model lives.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        loss_sum = 0.0
        batches_seen = 0

        for batch in tqdm.tqdm(data_loader):
            if batch is None:
                continue  # Skip batches with null data

            features, cell_temps = _assemble_features(batch, device)

            optimizer.zero_grad()
            outputs = model(features)

            # Match the label to the model head so the loss never broadcasts.
            if outputs.shape == cell_temps.shape:
                labels = cell_temps
            else:
                labels = cell_temps.mean(dim=1, keepdim=True)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            batches_seen += 1

        # Report the epoch mean; guard the case where every batch was skipped.
        if batches_seen:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss_sum / batches_seen}")
        else:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: n/a (no usable batches)")

        # Save checkpoint after each epoch
        checkpoint_filename = os.path.join(
            checkpoint_path, f"model_epoch_{epoch+1}.pth"
        )
        torch.save(model.state_dict(), checkpoint_filename)
        print(f"Checkpoint saved to {checkpoint_filename}")

    # Save final model
    final_model_path = os.path.join(checkpoint_path, "final_model.pth")
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

In [5]:
# Root of the raw data tree. Set the CLIMATE_DATA_DIR environment variable to
# run the notebook on another machine; the default is the original location.
DATA_DIR = os.environ.get(
    "CLIMATE_DATA_DIR", "/Users/jinho/Desktop/climatePrediction/data/raw"
)
temperature_path = os.path.join(DATA_DIR, "globalTemperature", "Land_and_Ocean_LatLong1.nc")
co2_path = os.path.join(DATA_DIR, "globalGhgEmissions", "CO2_1deg_month_1850-2013.nc")

In [6]:
# Define the dataset and DataLoader
dataset = ClimateData(temperature_path, co2_path, grid_size=(5, 5))
data_loader = DataLoader(dataset, batch_size=10, shuffle=True)

# Define the model
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
input_dim = 7  # Update based on number of features used: temp, climatology, lat, land_mask, co2, month_sin, month_cos
hidden_dim = 64
num_layers = 2
output_dim = 1

model = ClimateLSTM(input_dim, hidden_dim, num_layers, output_dim).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print('check')

# Start training
train_model(model, data_loader, criterion, optimizer, num_epochs=10)

Using mps device
check


  0%|          | 0/510106 [00:00<?, ?it/s]

: 