# Machine learning example with OSTIA

This is based on the ERA example and isn't working yet: it is just a proof of concept for loading and using the data, and the results are garbage.

There is no download script yet. The data can be downloaded manually from https://data.marine.copernicus.eu/product/SST_GLO_SST_L4_NRT_OBSERVATIONS_010_001/files?subdataset=METOFFICE-GLO-SST-L4-NRT-OBS-SST-V2 or via the Copernicus API (not tested yet).

In [1]:
import xarray as xr
import glob
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
from tqdm import tqdm

In [None]:
# Combine the per-date NetCDF files found in `path` into one dataset along `time`
path = "./data"
# glob returns matches in arbitrary, OS-dependent order; sort for a reproducible time axis
nc_files = sorted(glob.glob(os.path.join(path, "*.nc")))
if not nc_files:
    # Fail early with an actionable message instead of a generic error from xr.concat
    raise FileNotFoundError(f"No .nc files found in {path!r}; download the OSTIA data first")

datasets = [xr.open_dataset(f, engine='netcdf4') for f in nc_files]

combined_ds = xr.concat(datasets, dim='time')
# File names may not sort chronologically, so also order by the time coordinate when present
if 'time' in combined_ds.coords:
    combined_ds = combined_ds.sortby('time')
combined_ds

In [None]:
# Thin the grid to save space: keep every 20th point along both lat and lon
# (the original note describes the result as a one-degree grid)
every_20th = slice(None, None, 20)
subset_ds = combined_ds.isel({"lat": every_20th, "lon": every_20th})
subset_ds

In [None]:
# Map of the analysed SST field at time index 5 of the subsetted dataset
sst_sample = subset_ds["analysed_sst"][5, :, :]

fig, ax = plt.subplots(
    figsize=(8, 5),
    subplot_kw={'projection': ccrs.PlateCarree()},
)

# The data are on a regular lat/lon grid, hence the PlateCarree transform
pcolor_plot = ax.pcolormesh(
    subset_ds["lon"],
    subset_ds["lat"],
    sst_sample,
    transform=ccrs.PlateCarree(),
    cmap='viridis',
)

# Map decorations
ax.coastlines()
ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)
cbar = fig.colorbar(pcolor_plot, ax=ax, shrink=0.75, orientation='horizontal')
ax.set_title('Subsetted SST')

In [None]:
# Polar map of the sea ice fraction at time index 5 of the subsetted dataset
ice_sample = subset_ds["sea_ice_fraction"][5, :, :]

fig, ax = plt.subplots(
    figsize=(8, 5),
    subplot_kw={'projection': ccrs.NorthPolarStereo()},
)

# The data are on a regular lat/lon grid, so they are transformed from PlateCarree
pcolor_plot = ax.pcolormesh(
    subset_ds["lon"],
    subset_ds["lat"],
    ice_sample,
    transform=ccrs.PlateCarree(),
    cmap='jet',
)

# Map decorations; the view is restricted to latitudes 60 to 90
ax.coastlines()
ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)
ax.set_extent([-180, 180, 60, 90], ccrs.PlateCarree())
cbar = fig.colorbar(pcolor_plot, ax=ax, shrink=0.75, orientation='horizontal')
ax.set_title('Subsetted sea ice fraction')

In [None]:
# Hyperparameters shared by the data-loading and training cells
config = dict(
    batch_size=8,        # samples per DataLoader batch
    num_epochs=50,       # passes over the training set
    learning_rate=1e-3,  # optimizer step size
    test_size=0.2,       # fraction of time steps held out for testing
)

In [None]:
# Dataset variables fed to the model, in channel order
# (the mask and the analysis error are deliberately left out)
data_vars = [
    'analysed_sst',
    'sea_ice_fraction',
]

In [None]:
# Standardize the selected variables, split them into train/test sets and wrap them in DataLoaders

# Initialize the StandardScaler
scaler = StandardScaler()

# Extract variables from dataset and stack them into a numpy array (time, lat, long, vars)
data = np.stack([subset_ds[var].values for var in data_vars], axis=-1)

# Reshape the data for StandardScaler (it expects 2D, so combine lat, lon, and vars)
# Reshape to (time, lat*lon*vars) for scaling, later we'll reshape back.
# Each column is one (grid point, variable) pair and is standardized over time.
n_time, n_lat, n_lon, n_vars = data.shape
reshaped_data = data.reshape(n_time, -1)

# Ignore nans when standardizing
# AS: this is added compared to the ERA example and hasn't been properly debugged yet
# Only columns that are NaN at *every* time step are dropped before scaling. Checking just the
# first time step would let NaNs that appear later slip through into the tensors.
nan_mask = np.isnan(reshaped_data).all(axis=0)
reshaped_data_clean = reshaped_data[:, ~nan_mask]

# Fit and transform the data using StandardScaler
# (NaNs are disregarded when fitting and kept as NaN in the transformed output)
# NOTE(review): the scaler is fitted on all time steps, including those that end up in the
# test set; fit on the training split only if test-set leakage matters.
scaled_data_clean = scaler.fit_transform(reshaped_data_clean)

# Put rescaled data back with nans
scaled_data = reshaped_data.copy()
scaled_data[:, ~nan_mask] = scaled_data_clean
# put zeros where there were nans, both always-missing columns and isolated gaps,
# so that no NaN reaches the model and the loss
scaled_data[np.isnan(scaled_data)] = 0.0

# Reshape back to original (time, lat, long, vars) shape
standardized_data = scaled_data.reshape(n_time, n_lat, n_lon, n_vars)

# Split the data into training and test sets (random split over the time axis)
X_train, X_test = train_test_split(standardized_data, test_size=config['test_size'])

# Convert to float32 PyTorch tensors and change dimensions to (time, vars, lat, long)
tensor_data_train = torch.as_tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)
tensor_data_test = torch.as_tensor(X_test, dtype=torch.float32).permute(0, 3, 1, 2)

# Create TensorDataset
tensor_dataset_train = TensorDataset(tensor_data_train)
tensor_dataset_test = TensorDataset(tensor_data_test)

# Create DataLoaders
# Only the training data is shuffled; a fixed order keeps evaluation batches reproducible
train_loader = DataLoader(tensor_dataset_train, batch_size=config['batch_size'], shuffle=True)
test_loader = DataLoader(tensor_dataset_test, batch_size=config['batch_size'], shuffle=False)

In [None]:
# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU
if torch.cuda.is_available():
    backend_name = "cuda"
elif torch.backends.mps.is_available():
    backend_name = "mps"
else:
    backend_name = "cpu"

device = torch.device(backend_name)
print(f"Using {backend_name.upper()} device")

# Define the model
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder built from three strided convolutions and their transposes.

    Args:
        input_size: Shape-like sequence whose first entry is the number of
            input channels; only ``input_size[0]`` is read.
    """

    def __init__(self, input_size):
        super().__init__()
        n_channels = input_size[0]

        # Encoder: channels -> 16 -> 32 -> 64. The first layer has stride 1 along the first
        # spatial axis and 2 along the second; the other two layers use stride 2 on both.
        encoder_layers = [
            nn.Conv2d(n_channels, 16, kernel_size=3, stride=(1, 2), padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: mirrors the encoder with transposed convolutions, 64 -> 32 -> 16 -> channels.
        # There is no activation after the last layer.
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, n_channels, kernel_size=3, stride=(1, 2), padding=1,
                               output_padding=(0, 1)),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        latent = self.encoder(x)
        return self.decoder(latent)

# Build the model from the shape of one test sample and check that the
# output shape matches the input shape
dummy_input = next(iter(test_loader))[0].to(device)
sample_shape = dummy_input[0].shape  # shape of a single sample, without the batch axis
model = ConvAutoencoder(sample_shape).to(device)
print("Input shape:", dummy_input.shape)
print("Output shape:", model(dummy_input).shape)

In [None]:
%%time

# Reconstruction objective (mean squared error between input and output) and its optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])

num_epochs = config['num_epochs']

# One training pass and one evaluation pass per epoch
for epoch in range(num_epochs):
    # Training pass
    model.train()
    train_batch_losses = []

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{config['num_epochs']-1}", leave=False)
    for batch in progress_bar:
        inputs = batch[0].to(device)
        optimizer.zero_grad()

        # Forward pass: the autoencoder is trained to reproduce its own input
        reconstruction = model(inputs)
        batch_loss = criterion(reconstruction, inputs)
        batch_loss_value = batch_loss.item()
        train_batch_losses.append(batch_loss_value)

        # Backward pass and parameter update
        batch_loss.backward()
        optimizer.step()

        progress_bar.set_postfix({'Train Loss': batch_loss_value})

    # Mean of the per-batch training losses
    train_loss = sum(train_batch_losses) / len(train_loader)

    # Evaluation pass on the held-out set, without gradient tracking
    model.eval()
    test_batch_losses = []
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch[0].to(device)
            reconstruction = model(inputs)
            test_batch_losses.append(criterion(reconstruction, inputs).item())

    # Mean of the per-batch test losses
    test_loss = sum(test_batch_losses) / len(test_loader)

    # TODO: write test loss to tensorboard (not implemented)

    print(f"Epoch [{epoch}/{num_epochs-1}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

print("Training completed.")

In [None]:
# Reconstruct one batch from the test loader with the trained model

# Evaluation mode; gradients are not needed for inference
model.eval()

first_batch = next(iter(test_loader))
test_data = first_batch[0].to(device)

with torch.no_grad():
    # Forward pass through the autoencoder gives the reconstruction
    reconstructed_data = model(test_data)

def data2xarray(dataset, data):
    """Write per-variable fields from ``data`` into ``dataset`` and return it.

    ``data[i]`` is assigned to the values of the variable named
    ``data_vars[i]``. The dataset is modified in place, so pass a copy if
    the original must stay untouched.

    Args:
        dataset: xarray dataset containing every variable listed in ``data_vars``.
        data: Indexable along its first axis, one entry per variable in
            ``data_vars``. Entries may be array-likes or torch tensors; tensors
            are detached and moved to the CPU before conversion.

    Returns:
        The same ``dataset`` object, with the listed variables overwritten.
    """
    for i, var in enumerate(data_vars):
        field = data[i]
        # Duck-typed tensor check: NumPy cannot convert tensors that require grad or live on
        # an accelerator, so detach and copy them to host memory first
        if hasattr(field, "detach"):
            field = field.detach().cpu().numpy()
        # Dictionary-style access cannot be shadowed by Dataset attributes or methods
        dataset[var].values = np.asarray(field)
    return dataset

# Compare the first sample of the test batch with its reconstruction, one map per row
fig, axes = plt.subplots(
    2, 1,
    figsize=(10, 12),
    subplot_kw={'projection': ccrs.PlateCarree()},
)

# Use the first time step of the subsetted dataset as a template for both datasets;
# each gets its own copy so the template itself is not modified
template_ds = subset_ds.isel(time=0)
original_data_xr = data2xarray(template_ds.copy(), test_data[0].to('cpu'))
reconstructed_data_xr = data2xarray(template_ds.copy(), reconstructed_data[0].to('cpu'))

# (axis, dataset, title) for each panel: original on top, reconstruction below
panels = (
    (axes[0], original_data_xr, 'Original Data'),
    (axes[1], reconstructed_data_xr, 'Reconstructed Data'),
)
for panel_ax, panel_ds, panel_title in panels:
    # Only the analysed_sst variable is drawn, with the usual Cartopy decorations
    panel_ds.analysed_sst.plot(ax=panel_ax, transform=ccrs.PlateCarree(), cmap='viridis')
    panel_ax.coastlines()
    panel_ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False)
    panel_ax.set_title(panel_title)

# Display the plot
plt.tight_layout()
plt.show()