In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset
import optuna
import netron
import xarray as xr

# Pick the compute device once, up front. `device` is always defined so that
# later cells using it do not hit a NameError on machines without a GPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"CUDA Device: {torch.cuda.get_device_name(device)}")
else:
    # Fall back to the CPU instead of leaving `device` unbound.
    device = torch.device("cpu")
    print("CUDA is not available")


In [None]:
from datetime import datetime
def log(msg):
    """Print ``msg`` prefixed with the current local time as ``[YYYY-MM-DD HH:MM:SS]``."""
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"[{stamp}] {msg}")

In [None]:
import xarray as xr
import numpy as np

log("Starting data loading...")

# Glob pattern covering every ERA5 model-level NetCDF file to ingest
paths = "/mnt/Elements/training_data/era5_model_levels_*.nc"

# Hours of day that define the two splits
train_hours = [0, 6, 12, 18]
test_hours = [1, 7, 13, 19]

# Open all matching files as one dataset, merged by their coordinates
log("Opening and combining NetCDF files...")
ds = xr.open_mfdataset(paths, combine='by_coords')

# Temperature field; change the key if the variable is named differently
log("Extracting temperature variable...")
temp = ds['t']

# Positional selection of the last 77 entries along model_level.
# NOTE(review): this equals levels 61-137 only if all 137 levels are present
# in ascending order - confirm against the files.
log("Selecting model levels 61 to 137...")
temp = temp.isel(model_level=slice(-77, None))

# Hour of day for each timestamp, used to build the split masks below
log("Extracting valid_time hours...")
hours = temp['valid_time'].dt.hour

# Training split: keep timestamps whose hour is in train_hours, then select by label
log("Filtering training times (00, 06, 12, 18)...")
train_times = temp['valid_time'].where(hours.isin(train_hours), drop=True)
train = temp.sel(valid_time=train_times)

# Testing split: same procedure with test_hours
log("Filtering testing times (01, 07, 13, 19)...")
test_times = temp['valid_time'].where(hours.isin(test_hours), drop=True)
test = temp.sel(valid_time=test_times)

log("Done.")


[2025-04-13 09:08:20] Starting data loading...
[2025-04-13 09:08:20] Opening and combining NetCDF files...


  result = blockwise(


[2025-04-13 09:09:55] Extracting temperature variable...
[2025-04-13 09:09:55] Selecting model levels 61 to 137...
[2025-04-13 09:09:55] Extracting valid_time hours...
[2025-04-13 09:09:55] Filtering training times (00, 06, 12, 18)...
[2025-04-13 09:10:05] Filtering testing times (01, 07, 13, 19)...
[2025-04-13 09:10:16] Done.


In [None]:
import xarray as xr
import numpy as np
import xesmf as xe
import os
import hashlib

# Read the ClimT latitude/longitude axes from the saved .npz archive
log("Loading ClimT grid...")
climt_load = np.load('/run/media/adamh/X6/test_levels/climt_lat_lon.npz')
climt_lat = climt_load['latitude']
climt_lon = climt_load['longitude']

# Target grid for regridding: 'lat' and 'lon' are 1-D entries, each defined
# on a dimension of the same name
climt_grid = xr.Dataset({
    'lat': (['lat'], climt_lat),
    'lon': (['lon'], climt_lon),
})

def xr_regrid_all(ds, climt_grid, method='bilinear', weight_dir='/mnt/Elements/regrid_weights'):
    """
    Regrid the entire dataset at once, without looping over individual time steps.

    Parameters
    ----------
    ds : xarray object
        Source data; must expose 'latitude' and 'longitude' entries.
    climt_grid : xarray.Dataset
        Target grid; must expose 'lat' and 'lon' entries.
    method : str
        Regridding method passed to ``xe.Regridder`` (default 'bilinear').
    weight_dir : str
        Directory where regridding weights are stored and looked up.
        Created if it does not exist.

    Returns
    -------
    The regridded object produced by the regridder. If it carries a
    'pressure_level' coordinate, that coordinate is multiplied by 100 and,
    when it is a dimension, data and labels are reversed together along it.
    """
    # Fingerprint the source/target axes and the method so each distinct grid
    # pairing gets its own weight file. str() of the raw bytes is kept on
    # purpose: it reproduces the existing file names, so weights already
    # cached on disk are still found.
    fingerprint = hashlib.md5()
    for axis in (ds['latitude'], ds['longitude'], climt_grid['lat'], climt_grid['lon']):
        fingerprint.update(str(axis.values.tobytes()).encode())
    fingerprint.update(method.encode())
    grid_id = fingerprint.hexdigest()

    os.makedirs(weight_dir, exist_ok=True)
    weight_path = os.path.join(weight_dir, f'{method}_{grid_id}.nc')

    log(f"Creating regridder with weights file: {weight_path}")
    # Reuse the weights only when the file is already present; otherwise the
    # regridder is asked to build them under this file name.
    regridder = xe.Regridder(ds, climt_grid, method=method, periodic=True,
                             filename=weight_path,
                             reuse_weights=os.path.exists(weight_path))

    log("Regridding entire dataset...")
    regridded_ds = regridder(ds)

    # Convert pressure levels (if present)
    if 'pressure_level' in regridded_ds.coords:
        # Reverse the data together with its labels. Flipping only the
        # coordinate values (as assign_coords(np.flip(...)) would) leaves the
        # data in its old order and mislabels every level.
        if 'pressure_level' in regridded_ds.dims:
            regridded_ds = regridded_ds.isel(pressure_level=slice(None, None, -1))
        # Scale the labels by 100 (presumably hPa -> Pa; confirm the source units).
        regridded_ds = regridded_ds.assign_coords(
            pressure_level=regridded_ds.pressure_level.values * 100
        )

    log("Regridding complete.")
    return regridded_ds

# Interpolate the training split onto the ClimT grid
log("Regridding training data...")
train_regridded = xr_regrid_all(ds=train, climt_grid=climt_grid)

# Interpolate the testing split onto the same grid
log("Regridding testing data...")
test_regridded = xr_regrid_all(ds=test, climt_grid=climt_grid)

log("Process complete.")


[2025-04-13 09:11:07] Loading ClimT grid...
[2025-04-13 09:11:07] Regridding training data...
[2025-04-13 09:11:07] Creating regridder with weights file: /mnt/Elements/regrid_weights/bilinear_0def5f10f091d2681bf0f92676d969c7.nc
[2025-04-13 09:11:20] Regridding entire dataset...


  result_var = func(*data_vars)


[2025-04-13 09:11:26] Regridding complete.
[2025-04-13 09:11:26] Regridding testing data...
[2025-04-13 09:11:26] Creating regridder with weights file: /mnt/Elements/regrid_weights/bilinear_0def5f10f091d2681bf0f92676d969c7.nc
[2025-04-13 09:11:26] Regridding entire dataset...


  result_var = func(*data_vars)


[2025-04-13 09:11:38] Regridding complete.
[2025-04-13 09:11:38] Process complete.


In [None]:
# Quick look at the first five time steps of the regridded training data
bleh = train_regridded.isel(valid_time=slice(5))
print(bleh.values)

In [None]:
import torch

def convert_to_tensor(temp_data):
    """
    Convert regridded xarray data to a tensor suitable for U-Net input.

    Parameters
    ----------
    temp_data : object exposing ``.values`` (e.g. an ``xarray.DataArray``)
        Either one sample shaped (levels, lat, lon), e.g. (77, 64, 128), or a
        stack of samples shaped (time, levels, lat, lon).

    Returns
    -------
    torch.Tensor
        For 3-D input, a tensor with a leading batch axis of size 1, e.g.
        (1, 77, 64, 128). For 4-D input, the tensor is returned as-is, so the
        existing leading axis serves as the batch axis, e.g. (T, 77, 64, 128).
        Input of any other rank gets a leading axis of size 1, as before.
    """
    # Materialise the data as a NumPy array.
    # NOTE(review): for lazily loaded input this likely triggers the full
    # computation and can be slow or memory-hungry - confirm before running on
    # the complete dataset.
    array = temp_data.values
    print(array.shape)

    tensor = torch.tensor(array)

    # A 4-D array already has a leading sample axis. Adding another would give
    # a 5-D (1, T, levels, lat, lon) tensor, which is not the
    # (batch, channels, H, W) layout described above.
    if array.ndim == 4:
        return tensor

    # Single sample (or any other rank): add the batch dimension in front.
    return tensor.unsqueeze(0)

# Turn each regridded split into a tensor, logging before each conversion
log("Converting train_regridded...")
train_tensor = convert_to_tensor(train_regridded)

log("Converting test_regridded...")
test_tensor = convert_to_tensor(test_regridded)

# Report the resulting shapes as a sanity check
for split_label, split_tensor in (("Train", train_tensor), ("Test", test_tensor)):
    print(f"{split_label} tensor shape: {split_tensor.shape}")


[2025-04-13 09:15:54] Converting train_regridded...


KeyboardInterrupt: 