In [1]:
import torch
import torch.nn as nn
import netCDF4 as nc
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from datetime import datetime, timedelta, date
import xarray as xr

# Diagnose an unusable MPS backend: report whether this PyTorch build lacks
# MPS support or whether the OS/hardware does. Prints nothing when MPS works.
if not torch.backends.mps.is_available():
    mps_reason = (
        "MPS not available because the current MacOS version is not 12.3+ "
        "and/or you do not have an MPS-enabled device on this machine."
        if torch.backends.mps.is_built()
        else "MPS not available because the current PyTorch install was not "
             "built with MPS enabled."
    )
    print(mps_reason)

In [2]:
# Run on CUDA when a GPU is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# device = torch.device("mps")
device

device(type='cuda')

In [3]:
class CompactAutoencoder(nn.Module):
    """Minimal 3D convolutional encoder/decoder pair.

    The encoder maps 10 input channels to 16 feature channels; the decoder
    maps those 16 channels to a single output channel. Both convolutions use
    a 3x3x3 kernel with padding 1 and the default stride, so the depth,
    height and width of the input are preserved. There is no spatial
    bottleneck. The output passes through a ReLU and is therefore
    non-negative.
    """

    def __init__(self):
        super().__init__()

        encoder_layers = [
            nn.Conv3d(10, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Second activation kept so the layer stack matches the existing
            # definition; applying ReLU twice yields the same values as once.
            nn.ReLU(inplace=True),
        ]
        decoder_layers = [
            nn.ConvTranspose3d(16, 1, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return the decoded result."""
        return self.decoder(self.encoder(x))

# Build the model, wrap it in DataParallel when more than one CUDA device is
# visible, then move it to the selected device.
autoencoder = CompactAutoencoder()
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    autoencoder = nn.DataParallel(autoencoder)
autoencoder = autoencoder.to(device)

Using 2 GPUs!


In [4]:
class NetCDFDataset(Dataset):
    """Dataset that serves one NetCDF file per sample as a stacked tensor.

    File names are generated (not discovered on disk) as
    ``<basename(root_dir)>.<YYYYMMDD><HH>.nc`` for every day from
    ``start_date`` to ``end_date`` inclusive and for the hours 00, 06, 12, 18.
    Each sample stacks every data variable of its file along a leading
    "variable" axis.
    """

    def __init__(self, root_dir, start_date, end_date, transform=False,
                 grid_shape=(129, 181, 360)):
        """
        Args:
            root_dir: Directory holding the files; its basename is also used
                as the file-name prefix.
            start_date: First day (inclusive).
            end_date: Last day (inclusive).
            transform: When truthy, samples are normalized with
                ``self.mean`` and ``self.std``.
            grid_shape: Shape of the array stored for each variable. Defaults
                to the previously hard-coded ``(129, 181, 360)``.
        """
        self.root_dir = root_dir
        self.file_list = self.create_file_list(root_dir, start_date, end_date)
        self.mean = 278.83097 #279.9124
        self.std = 56.02780 #107.1107
        self.transform = transform
        self.grid_shape = tuple(grid_shape)

    @staticmethod
    def create_file_list(root_dir, start_date, end_date):
        """Return the expected file names, four per day, in chronological order."""
        # normpath drops a trailing separator so '../data/GFS/' still yields
        # the prefix 'GFS' instead of an empty string.
        prefix = os.path.basename(os.path.normpath(root_dir))
        file_list = []
        time_step = timedelta(days=1)
        current_date = start_date

        while current_date <= end_date:
            day_stamp = current_date.strftime("%Y%m%d")
            for hh in ['00', '06', '12', '18']:
                file_list.append(f'{prefix}.{day_stamp}{hh}.nc')
            current_date += time_step

        return file_list

    def __len__(self):
        """Number of files (samples) in the dataset."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load file ``idx`` and return a float32 tensor of shape
        ``(n_variables, *grid_shape)`` with NaNs replaced by 0.

        The result is normalized when ``self.transform`` is truthy.
        """
        file_path = os.path.join(self.root_dir, self.file_list[idx])

        # The context manager closes the file even if reading a variable fails.
        with xr.open_dataset(file_path) as dataset:
            variable_names = list(dataset.data_vars)
            all_data = np.zeros((len(variable_names), *self.grid_shape), dtype=np.float32)

            # A separate loop index keeps the sample index `idx` intact.
            for var_idx, var_name in enumerate(variable_names):
                values = dataset.variables[var_name][:].astype(np.float32).fillna(0)
                all_data[var_idx, ...] = np.asarray(values, dtype=np.float32)

        if self.transform:
            # Normalize the full stack, not just the last variable read.
            all_data = self.normalize_data(all_data)

        # Return every variable; astype guards against a float64 mean/std
        # silently upcasting the sample.
        return torch.from_numpy(np.ascontiguousarray(all_data, dtype=np.float32))

    def normalize_data(self, data):
        """Return ``(data - self.mean) / self.std``."""
        data = (data - self.mean) / self.std
        return data

    def rescale_data(self, data):
        """Inverse of :meth:`normalize_data`: ``data * self.std + self.mean``."""
        data = (data * self.std) + self.mean
        return data


In [5]:
def check_missing_files(start_date, end_date, gfs_directory, era5_directory):
    """Print every expected GFS / ERA5 file that does not exist on disk.

    For each day from ``start_date`` to ``end_date`` (inclusive) and each of
    the hours 00, 06, 12 and 18, the files ``GFS.<YYYYMMDD><HH>.nc`` in
    ``gfs_directory`` and ``ERA5.<YYYYMMDD><HH>.nc`` in ``era5_directory`` are
    checked. One line is printed per missing file, followed by the total
    count. Nothing is returned.
    """
    sources = (("GFS", gfs_directory), ("ERA5", era5_directory))
    one_day = timedelta(days=1)
    missing_count = 0

    day = start_date
    while day <= end_date:
        stamp = day.strftime("%Y%m%d")
        for cycle in ('00', '06', '12', '18'):
            # GFS is checked before ERA5 for every timestamp.
            for label, directory in sources:
                name = f"{label}.{stamp}{cycle}.nc"
                if os.path.exists(os.path.join(directory, name)):
                    continue
                print(f"Missing file in {label} directory: {name}")
                missing_count += 1
        day += one_day

    print(f"Total number of missing files: {missing_count}")

In [6]:
def calculate_mean_and_std(root_dir, start_date, end_date, var_name='t2m', file_prefix='GFS.t2m'):
    """Compute the mean and population standard deviation of one variable
    over all files found in a date range.

    Files are looked up as ``<file_prefix>.<YYYYMMDD><HH>.nc`` inside
    ``root_dir`` for every day from ``start_date`` to ``end_date`` (inclusive)
    and the hours 00, 06, 12, 18. Files that do not exist are skipped.
    Masked and non-finite values are excluded from the statistics.

    Per-file statistics are merged with the pairwise (Chan et al.) update, so
    the result equals the statistics of all values concatenated, without
    holding more than one file in memory.

    Args:
        root_dir: Directory containing the NetCDF files.
        start_date: First day (inclusive).
        end_date: Last day (inclusive).
        var_name: Name of the variable to read from each file.
        file_prefix: Leading part of each file name.

    Returns:
        ``(mean, std)`` as Python floats.

    Raises:
        ValueError: If no valid value was found (for example, no file
            matched the expected names).
    """
    time_step = timedelta(days=1)
    current_date = start_date
    total_count = 0
    total_mean = 0.0
    total_m2 = 0.0  # running sum of squared deviations from total_mean

    while current_date <= end_date:
        day_stamp = current_date.strftime('%Y%m%d')
        for hour in ['00', '06', '12', '18']:
            file_path = os.path.join(root_dir, f"{file_prefix}.{day_stamp}{hour}.nc")
            if not os.path.exists(file_path):
                continue

            # The context manager closes the file even if the read fails.
            with nc.Dataset(file_path) as dataset:
                data = dataset.variables[var_name][:]

            # float64 avoids precision loss when summing many float32 values;
            # compressed() drops masked/invalid entries and flattens.
            values = np.ma.masked_invalid(np.ma.asarray(data, dtype=np.float64)).compressed()
            chunk_count = values.size  # number of elements, not len() of axis 0
            if chunk_count == 0:
                continue

            chunk_mean = values.mean()
            chunk_m2 = np.sum((values - chunk_mean) ** 2)

            # Merge this chunk into the running statistics; the delta term
            # accounts for the difference between the two means.
            delta = chunk_mean - total_mean
            new_count = total_count + chunk_count
            total_mean += delta * chunk_count / new_count
            total_m2 += chunk_m2 + delta ** 2 * total_count * chunk_count / new_count
            total_count = new_count

        current_date += time_step

    if total_count == 0:
        raise ValueError(
            f"No data found for variable '{var_name}' in '{root_dir}' "
            f"(expected files named '{file_prefix}.<YYYYMMDD><HH>.nc')"
        )

    total_std = np.sqrt(total_m2 / total_count)
    return float(total_mean), float(total_std)

In [7]:
# Root directories of the two data sources (paths relative to the working directory).
# Their basenames ("GFS", "ERA5") double as the file-name prefix used by NetCDFDataset.
gfs_root_dir = '../data/GFS'
era5_root_dir = '../data/ERA5'

# Date range covered by the dataset; both ends are inclusive and four files
# (00, 06, 12, 18) are expected per day.
start_date = date(2021, 3, 23)  # first day (inclusive)
end_date = date(2023, 3, 23)    # last day (inclusive)

# Print every expected GFS/ERA5 file that is absent, followed by the total count.
check_missing_files(start_date, end_date, gfs_root_dir, era5_root_dir)

Total number of missing files: 0


# Compute normalization statistics over the GFS files in the date range.
# NOTE(review): calculate_mean_and_std looks for files named
# 'GFS.t2m.<YYYYMMDD><HH>.nc', whereas check_missing_files and NetCDFDataset use
# 'GFS.<YYYYMMDD><HH>.nc' - confirm that matching files exist in gfs_root_dir.
# The values printed here are not passed to NetCDFDataset, which sets its own
# mean/std constants in __init__.
mean_value, std_value = calculate_mean_and_std(gfs_root_dir, start_date, end_date)
print(f"Mean: {mean_value}, Standard Deviation: {std_value}")

In [8]:
# Create the GFS (input) and ERA5 (target) datasets over the same date range.
# Both file lists are generated from the same dates, so index i refers to the
# same timestamp in each. `transform` is left at its default (False), so
# samples are not normalized.
gfs_dataset = NetCDFDataset(gfs_root_dir, start_date, end_date)
era5_dataset = NetCDFDataset(era5_root_dir, start_date, end_date)

# Display the generated GFS file names.
gfs_dataset.file_list

In [9]:
# Shuffle both datasets with ONE shared permutation so that index i still
# refers to the same timestamp in GFS and ERA5.
assert len(gfs_dataset) == len(era5_dataset), \
    "GFS and ERA5 file lists must have the same length to be shuffled together"

# A dedicated, seeded generator makes the order reproducible after a kernel
# restart; the global torch seed is only set in a later cell, so an unseeded
# randperm here would give a different pairing order on every run.
shuffle_generator = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(gfs_dataset), generator=shuffle_generator)

# Apply the same permutation to both file lists (plain ints for list indexing).
# Note: re-running this cell permutes the already-shuffled lists again.
gfs_dataset.file_list = [gfs_dataset.file_list[i] for i in shuffled_indices.tolist()]
era5_dataset.file_list = [era5_dataset.file_list[i] for i in shuffled_indices.tolist()]

era5_dataset.file_list

In [10]:
# Loader settings. `shuffle` must stay False: the GFS and ERA5 loaders are
# iterated in lockstep during training, so any per-loader shuffling would
# break the input/target pairing.
batch_size = 8
shuffle = False
num_workers = 0  # load data in the main process
seed = 42
torch.manual_seed(seed)

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.0001)

# One loader per data source, built with identical settings so that batch k
# of each loader covers the same file indices.
gfs_data_loader = DataLoader(gfs_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
era5_data_loader = DataLoader(era5_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

# Access the dataset from the DataLoader
dataset = gfs_data_loader.dataset

# Retrieve the file list from the dataset and print it in full
# (this is the order in which the GFS files will be read).
file_list = dataset.file_list  # file names held by the GFS dataset
print(file_list)  # Display the file list

In [None]:
# Train for a fixed number of epochs, writing a checkpoint after every epoch.
num_epochs = 50
for epoch in range(num_epochs):
    autoencoder.train()
    total_loss = 0.0

    # Pair GFS inputs with ERA5 targets batch by batch. zip() has no length,
    # so the number of batches is handed to tqdm explicitly.
    paired_batches = enumerate(zip(gfs_data_loader, era5_data_loader))
    progress_bar = tqdm(
        paired_batches,
        total=len(gfs_data_loader),
        desc=f'Epoch [{epoch+1}/{num_epochs}]',
        dynamic_ncols=True,
    )
    for batch_idx, (gfs_data, era5_data) in progress_bar:
        inputs = gfs_data.to(device)
        targets = era5_data.to(device)

        optimizer.zero_grad()
        outputs = autoencoder(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    progress_bar.close()

    # Mean of the per-batch losses for this epoch.
    avg_loss = total_loss / len(gfs_data_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

    # Checkpoint: when the model is wrapped in DataParallel, save the state
    # of the wrapped module itself.
    if isinstance(autoencoder, nn.DataParallel):
        weights = autoencoder.module.state_dict()
    else:
        weights = autoencoder.state_dict()
    torch.save(weights, f'autoencoder_model_epoch_{epoch+1}.pth')