# Channel Prediction Model for Quality Control

This notebook contains the code for, and outlines the development of, the
SSMI Channel Prediction Neural Network used for quality control (QC) of
SSMIS L1C data. Simple Feed-Forward Neural Networks (FFNNs) are trained on
good-quality SSMIS brightness temperature observation vectors; each network
predicts one channel using all the other channels as predictors.

The end product is a tree of FFNNs, split by surface type (ocean or
land) and by channel, since each combination is a fully distinct problem.

In [1]:


import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import glob
import torch
from torch import nn
import cartopy.crs as ccrs
from util_funcs.L1C import scantime2datetime
from util_funcs import data2xarray, array_funcs
import geography
from tqdm import tqdm

from dataset_class import dataset
from model_class import channel_predictor
import local_functions
import sensor_info


#General parameters:

# Satellite / sensor identifiers come from the project-local sensor_info module;
# later cells use them to build the training-data and model file names.
satellite = sensor_info.satellite
sensor = sensor_info.sensor

nchans = sensor_info.nchannels   # number of channels (value is displayed in the next cell)
batch_size = 1000                # samples per training mini-batch
input_size = nchans - 1          # predictors = every channel except the one being predicted
hidden_size = 256                # hidden size handed to channel_predictor
output_size = 1                  # one predicted channel per model

In [2]:
# Sanity check: display the channel count read from sensor_info.
nchans

12

In [3]:
'''
Function Definitions:
'''

# Channel description strings from sensor_info. The training cells iterate over
# this sequence and train one model per entry.
chan_desc = sensor_info.channel_descriptions


def extract_channel(Tb_array, chan):

    '''
    Split a Tb matrix into one target channel and the remaining predictors.

    Tb_array is indexed as [samples, channels]; the column order is taken
    to follow sensor_info.channel_descriptions.

    Inputs:
        Tb_array    |  ndarray of Tbs, [m x n] (m samples, n channels)
        chan        |  channel description string selecting the target
    Outputs:
        x           |  [m x (n-1)] matrix holding every other channel
        y           |  [m x 1] column holding the selected channel

    Raises:
        ValueError if chan is not one of the channel descriptions.
    '''

    descriptions = np.array(sensor_info.channel_descriptions)

    # One flag per channel: True where the description matches `chan`.
    is_target = chan == descriptions
    target_cols = np.flatnonzero(is_target)

    if target_cols.size == 0:
        raise ValueError(f'Channel description must be in list {descriptions}.')

    # Target column(s) first, then every remaining column in original order.
    y = Tb_array[:, target_cols]
    x = Tb_array[:, np.flatnonzero(~is_target)]

    return x, y


def train_model(model, nepochs, dataloader, learning_rate=0.001, quiet=False, stage=None, validation_dataloader=None, channel=None):

    '''
    Train a model with the Adam optimizer and an MSE loss, annealing the
    learning rate with a cosine schedule over `nepochs` epochs.

    Every 1000 batches the first batch of the validation loader (if one is
    given) is evaluated and its loss printed. Between validation passes the
    most recent validation loss is carried forward in the returned array.

    Inputs:
        model                  |  torch module to train (updated in place)
        nepochs                |  number of passes over dataloader; also T_max of the LR schedule
        dataloader             |  sized iterable of (predictors, targets) batches
        learning_rate          |  initial Adam learning rate
        quiet                  |  if True, suppress the training-loss progress lines
        stage                  |  optional label, printed once before training starts
        validation_dataloader  |  optional iterable of (predictors, targets); only its first batch is used
        channel                |  optional label for the progress lines; if omitted, a global
                               |  named `channel` is used when present
    Outputs:
        loss_arr               |  [nbatches x nepochs] float32 training loss of every batch
        valloss_arr            |  [nbatches x nepochs] float32 validation loss (NaN where
                               |  no validation loss has been computed yet)
    '''

    report_every = 1000  # batches between validation passes / progress lines

    nbatches = len(dataloader)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=nepochs)

    # Run on whatever device the model lives on. Adam has already rejected a
    # parameter-less model above, so there is always a first parameter.
    device = next(model.parameters()).device

    # Existing notebook cells do not pass `channel`; they rely on the loop
    # variable of the calling cell, so fall back to that global if it exists.
    if channel is None:
        channel = globals().get('channel', 'n/a')

    if stage:
        print(f'Training stage: {stage}')

    loss_arr    = np.zeros([nbatches,nepochs], dtype='f')
    valloss_arr = np.full([nbatches,nepochs], np.nan, dtype='f')

    # Most recent validation loss; stays NaN when validation never runs.
    last_valloss = float('nan')

    for epoch in range(nepochs):
        for i, (profs, obs) in enumerate(dataloader):
            profs, obs = profs.to(device), obs.to(device)

            if validation_dataloader is not None and i % report_every == 0:
                # A fresh iterator each time: always the loader's first batch.
                val_batch = next(iter(validation_dataloader), None)
                if val_batch is not None:
                    valprofs, valobs = val_batch
                    valprofs, valobs = valprofs.to(device), valobs.to(device)

                    # Evaluate without building an autograd graph and in eval
                    # mode, then restore the mode the model was in.
                    was_training = model.training
                    model.eval()
                    with torch.no_grad():
                        last_valloss = criterion(model(valprofs), valobs).item()
                    model.train(was_training)

                    print(f'Validation Loss = {last_valloss:.3f}')

            #Forward pass:
            obs_pred = model(profs)
            loss     = criterion(obs_pred, obs)

            #Backward pass:
            optimizer.zero_grad()
            loss.backward()

            #Update neurons:
            optimizer.step()

            loss_arr[i,epoch] = loss.item()
            valloss_arr[i,epoch] = last_valloss

            if not quiet:
                if i % report_every == 0:
                    print(f'Channel {channel}, Epoch={epoch+1}, batch = {i} of {nbatches}, loss={loss.item():.3f}, LR={scheduler.get_last_lr()[0]}')

        # One cosine-annealing step per epoch.
        scheduler.step()

    return loss_arr, valloss_arr



In [4]:
'''

AMSR2 CHANNEL PREDICTION MODEL:
    1: Ocean

'''

# NOTE(review): this header says AMSR2 while the notebook title refers to SSMIS;
# the sensor actually used is whatever sensor_info provides — confirm and align.

# Surface-type code(s) to keep; per the header above, 1 = ocean.
sfc = [1]

with xr.open_dataset(f'training_data/{satellite}_training_data.nc') as f:
    
    sfctype = f.sfctype.values

    # Boolean mask of samples whose surface type is in `sfc`, then their indices.
    correct_sfc = np.isin(sfctype, sfc)
    sfcindcs = np.where(correct_sfc)[0]

    # Keep the Tbs of the selected samples only.
    Tbs = f.Tbs.values[sfcindcs]

#---The data are deliberately NOT shuffled before the train/test/val split;
#   only the training subset is shuffled, after the split (further down).

# Cap the data set at 5 million samples by keeping the first 5,000,000 rows
# (a plain truncation, not a random subsample).
if Tbs.shape[0] > 5.0e+06:
    Tbs = Tbs[:5_000_000,:]


print(Tbs.shape)

#---Split data into train/test/val:
# split_data_indcs is a project-local helper; it returns three sets of row indices.
train_indcs, test_indcs, val_indcs = local_functions.split_data_indcs(Tbs)

Tbs_train = Tbs[train_indcs]
Tbs_test  = Tbs[test_indcs]
Tbs_val   = Tbs[val_indcs]

#---Shuffle before converting to tensors:
# Fixed seed for a repeatable shuffle (assumes shuffle_data draws from NumPy's
# global RNG — TODO confirm against array_funcs).
np.random.seed(40)
Tbs_train = array_funcs.shuffle_data(Tbs_train, axis=0)


(5000000, 12)


In [7]:
'''
Predict channels: Train all models
'''

torch.set_num_threads(10)

#---Epochs per training stage. Defined once, outside the loop, and used both for
#   the train_model calls and for the dimensions written to the diagnostics file,
#   so the two can no longer drift apart.
nepochs_stage1 = 5
nepochs_stage2 = 10
nepochs_stage3 = 20

for ichan, channel in enumerate(chan_desc):

    print(channel)

    #---Extract channel, x = predictors, y = channel to predict
    x_train, y_train = extract_channel(Tbs_train, channel)
    x_test,  y_test  = extract_channel(Tbs_test, channel)
    x_val,   y_val   = extract_channel(Tbs_val, channel)
    
    x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)
    x_test,  y_test  = torch.tensor(x_test), torch.tensor(y_test)
    x_val,   y_val   = torch.tensor(x_val), torch.tensor(y_val)


    #---Set up dataloaders:
    train_loader = torch.utils.data.DataLoader(dataset=dataset(x_train,y_train), 
                                               batch_size=batch_size, 
                                               shuffle=True, drop_last=True)
    # NOTE(review): batch_size=None disables DataLoader's automatic batching, so
    # these two loaders yield items exactly as dataset.__getitem__ returns them.
    # If `dataset` returns one sample per index, the validation loss printed by
    # train_model is computed on a single sample — confirm against dataset_class.
    test_loader = torch.utils.data.DataLoader(dataset=dataset(x_test,y_test), 
                                               batch_size=None, 
                                               shuffle=False, drop_last=False)
    val_loader = torch.utils.data.DataLoader(dataset=dataset(x_val,y_val), 
                                               batch_size=None, 
                                               shuffle=False, drop_last=False)

    #---Create model:
    model = channel_predictor(input_size, hidden_size, output_size)

    #---Train_model:
    # Each stage calls train_model afresh on the same model, so every stage
    # starts again from learning_rate with a new optimizer and LR schedule.
    nbatches = len(train_loader)

    loss_stage1, valloss_stage1 = train_model(model, nepochs=nepochs_stage1, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=1, validation_dataloader=val_loader)
    loss_stage2, valloss_stage2 = train_model(model, nepochs=nepochs_stage2, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=2, validation_dataloader=val_loader)
    loss_stage3, valloss_stage3 = train_model(model, nepochs=nepochs_stage3, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=3, validation_dataloader=val_loader)

    #---Save the trained weights, one file per channel:
    torch.save(model.state_dict(), f'models/{sensor}_{satellite}_channel_predictor_{channel}_ocean_no89.pt')

    #---Save the per-batch training / validation losses of all three stages:
    loss_data = data2xarray(data_vars = (loss_stage1, loss_stage2, loss_stage3, 
                                     valloss_stage1, valloss_stage2, valloss_stage3),
                        var_names = ('LossStage1','LossStage2','LossStage3',
                                     'ValidationLossStage1', 'ValidationLossStage2', 'ValidationLossStage3'),
                        dims = (nbatches,nepochs_stage1,nepochs_stage2, nepochs_stage3),
                        dim_names = ('training_batches', 'epochs_stage1', 'epochs_stage2', 'epochs_stage3'))

    loss_data.to_netcdf(f'diagnostics/loss_data_{channel}_ocean_no89.nc', engine='netcdf4')

    print(f'Finished training model for channel {channel}.')

6V
Training stage: 1
Validation Loss = 31840.910
Channel 6V, Epoch=1, batch = 0 of 4000, loss=28001.924, LR=0.001
Validation Loss = 0.061
Channel 6V, Epoch=1, batch = 1000 of 4000, loss=0.624, LR=0.001
Validation Loss = 0.155
Channel 6V, Epoch=1, batch = 2000 of 4000, loss=0.389, LR=0.001
Validation Loss = 0.411
Channel 6V, Epoch=1, batch = 3000 of 4000, loss=0.332, LR=0.001
Validation Loss = 4.540
Channel 6V, Epoch=2, batch = 0 of 4000, loss=3.290, LR=0.0009045084971874737
Validation Loss = 0.717
Channel 6V, Epoch=2, batch = 1000 of 4000, loss=0.902, LR=0.0009045084971874737
Validation Loss = 3.651
Channel 6V, Epoch=2, batch = 2000 of 4000, loss=2.593, LR=0.0009045084971874737
Validation Loss = 1.624
Channel 6V, Epoch=2, batch = 3000 of 4000, loss=1.214, LR=0.0009045084971874737
Validation Loss = 0.069
Channel 6V, Epoch=3, batch = 0 of 4000, loss=0.243, LR=0.0006545084971874737
Validation Loss = 0.061
Channel 6V, Epoch=3, batch = 1000 of 4000, loss=0.237, LR=0.0006545084971874737
Vali

In [4]:
#General parameters:
# NOTE: this overrides the input_size set in the first parameter cell. The
# non-ocean models take one extra predictor (surface type), so re-running the
# ocean training cell after this cell would build models with the wrong input size.
input_size = nchans - 1 + 1 #Remaining channels + surface type

In [5]:
'''

SSMI CHANNEL PREDICTION MODEL:
    2: Non-Ocean Surfaces

'''

# NOTE: this cell rebinds Tbs, sfctype and the train/test/val arrays, replacing
# whatever an earlier cell left in those names.

with xr.open_dataset(f'training_data/{satellite}_training_data.nc') as f:
    
    sfctype = f.sfctype.values

    # Non-ocean = every surface-type code greater than 1.
    correct_sfc = sfctype > 1

    Tbs = f.Tbs.values[correct_sfc,:]
    # Keep the surface type of the selected samples; it is appended to the
    # predictors in the training cell.
    sfctype = sfctype[correct_sfc]

#---Split data into train/test/val:
# split_data_indcs is a project-local helper; it returns three sets of row indices.
train_indcs, test_indcs, val_indcs = local_functions.split_data_indcs(Tbs)

Tbs_train = Tbs[train_indcs]
Tbs_test  = Tbs[test_indcs]
Tbs_val   = Tbs[val_indcs]

# Same split applied to the surface types, cast to float32 so they can be
# concatenated with the Tb predictors.
sfctype_train = sfctype[train_indcs].astype(np.float32)
sfctype_test  = sfctype[test_indcs].astype(np.float32)
sfctype_val   = sfctype[val_indcs].astype(np.float32)

#---Shuffle before converting to tensors:
# Fixed seed for a repeatable shuffle (assumes shuffle_data draws from NumPy's
# global RNG — TODO confirm against array_funcs).
np.random.seed(40)
Tbs_train, shuffled_indcs = array_funcs.shuffle_data(Tbs_train, axis=0, return_indcs=True)
# Apply the same permutation to the surface types so rows stay aligned with Tbs_train.
sfctype_train = sfctype_train[shuffled_indcs]

In [6]:
# Inspect the surface-type arrays of the three splits (values and sizes).
sfctype_train, sfctype_test, sfctype_val

(array([14., 11., 11., ..., 10., 15.,  2.],
       shape=(32384092,), dtype=float32),
 array([4., 4., 4., ..., 2., 2., 2.], shape=(4048011,), dtype=float32),
 array([ 2.,  2.,  2., ..., 11., 11., 11.], shape=(4048013,), dtype=float32))

In [8]:
# Look up which channel sits at index 10; the non-ocean training loop skips
# every channel with a lower index.
chan_desc[10]

'37V'

In [None]:
'''
Predict channels: Train all models
'''

torch.set_num_threads(10)

#---Resume point: channels with an index below this are skipped, so only the
#   remaining channels are trained here. Set to 0 to train every channel.
first_chan_to_train = 10

#---Epochs per training stage. Defined once, outside the loop, and used both for
#   the train_model calls and for the dimensions written to the diagnostics file,
#   so the two can no longer drift apart.
nepochs_stage1 = 5
nepochs_stage2 = 10
nepochs_stage3 = 20

for ichan, channel in enumerate(chan_desc):

    if ichan < first_chan_to_train: continue

    print(channel)

    #---Extract channel, x = predictors, y = channel to predict
    x_train, y_train = extract_channel(Tbs_train, channel)
    x_test,  y_test  = extract_channel(Tbs_test, channel)
    x_val,   y_val   = extract_channel(Tbs_val, channel)

    #---Append surface type as one extra predictor column:
    x_train = np.concatenate((x_train, sfctype_train[:,None]), axis=1)
    x_test  = np.concatenate((x_test,  sfctype_test[:,None]), axis=1)
    x_val   = np.concatenate((x_val,   sfctype_val[:,None]), axis=1)
    
    x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)
    x_test,  y_test  = torch.tensor(x_test), torch.tensor(y_test)
    x_val,   y_val   = torch.tensor(x_val), torch.tensor(y_val)


    #---Set up dataloaders:
    train_loader = torch.utils.data.DataLoader(dataset=dataset(x_train,y_train), 
                                               batch_size=batch_size, 
                                               shuffle=True, drop_last=True)
    # NOTE(review): batch_size=None disables DataLoader's automatic batching, so
    # these two loaders yield items exactly as dataset.__getitem__ returns them.
    # If `dataset` returns one sample per index, the validation loss printed by
    # train_model is computed on a single sample — confirm against dataset_class.
    test_loader = torch.utils.data.DataLoader(dataset=dataset(x_test,y_test), 
                                               batch_size=None, 
                                               shuffle=False, drop_last=False)
    val_loader = torch.utils.data.DataLoader(dataset=dataset(x_val,y_val), 
                                               batch_size=None, 
                                               shuffle=False, drop_last=False)

    #---Create model:
    # input_size must already include the surface-type column (set in the
    # parameter cell preceding the non-ocean data cell).
    model = channel_predictor(input_size, hidden_size, output_size)

    #---Train_model:
    # Each stage calls train_model afresh on the same model, so every stage
    # starts again from learning_rate with a new optimizer and LR schedule.
    nbatches = len(train_loader)

    loss_stage1, valloss_stage1 = train_model(model, nepochs=nepochs_stage1, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=1, validation_dataloader=val_loader)
    loss_stage2, valloss_stage2 = train_model(model, nepochs=nepochs_stage2, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=2, validation_dataloader=val_loader)
    loss_stage3, valloss_stage3 = train_model(model, nepochs=nepochs_stage3, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=3, validation_dataloader=val_loader)

    #---Save the trained weights, one file per channel:
    torch.save(model.state_dict(), f'models/{sensor}_{satellite}_channel_predictor_{channel}_nonocean_no89.pt')

    #---Save the per-batch training / validation losses of all three stages:
    loss_data = data2xarray(data_vars = (loss_stage1, loss_stage2, loss_stage3, 
                                     valloss_stage1, valloss_stage2, valloss_stage3),
                        var_names = ('LossStage1','LossStage2','LossStage3',
                                     'ValidationLossStage1', 'ValidationLossStage2', 'ValidationLossStage3'),
                        dims = (nbatches,nepochs_stage1,nepochs_stage2, nepochs_stage3),
                        dim_names = ('training_batches', 'epochs_stage1', 'epochs_stage2', 'epochs_stage3'))

    loss_data.to_netcdf(f'diagnostics/loss_data_{channel}_nonocean_no89.nc', engine='netcdf4')

    print(f'Finished training model for channel {channel}.')

37V
Training stage: 1
Validation Loss = 44298.156
Channel 37V, Epoch=1, batch = 0 of 32384, loss=60272.070, LR=0.001
Validation Loss = 22.890
Channel 37V, Epoch=1, batch = 1000 of 32384, loss=5.276, LR=0.001
Validation Loss = 12.305
Channel 37V, Epoch=1, batch = 2000 of 32384, loss=5.592, LR=0.001
Validation Loss = 31.500
Channel 37V, Epoch=1, batch = 3000 of 32384, loss=4.799, LR=0.001
Validation Loss = 24.795
Channel 37V, Epoch=1, batch = 4000 of 32384, loss=3.815, LR=0.001
Validation Loss = 12.561
Channel 37V, Epoch=1, batch = 5000 of 32384, loss=5.312, LR=0.001
Validation Loss = 22.232
Channel 37V, Epoch=1, batch = 6000 of 32384, loss=4.074, LR=0.001
Validation Loss = 18.455
Channel 37V, Epoch=1, batch = 7000 of 32384, loss=4.297, LR=0.001
Validation Loss = 28.768
Channel 37V, Epoch=1, batch = 8000 of 32384, loss=4.007, LR=0.001
Validation Loss = 49.677
Channel 37V, Epoch=1, batch = 9000 of 32384, loss=6.916, LR=0.001
Validation Loss = 33.317
Channel 37V, Epoch=1, batch = 10000 of 