# Channel Prediction Model for Quality Control

This notebook contains the code and outlines the development of the
SSMI Channel Prediction Neural Network for QC purposes of SSMIS L1C
data. The simple Feed-Forward Neural Networks (FFNNs) are trained on
good quality SSMIS brightness temperature observation vectors and are
trained to predict one channel from all the others as predictors.

The end product is a tree of FFNNs, discretized by surface (ocean or 
land), and by channel, since these are fully distinct problems.

In [1]:


import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import glob
import torch
from torch import nn
import cartopy.crs as ccrs
from util_funcs.L1C import scantime2datetime
from util_funcs import data2xarray, array_funcs
import geography
from tqdm import tqdm

from dataset_class import dataset
from model_class import channel_predictor
import local_functions
import sensor_info


#General parameters:

satellite = sensor_info.satellite
sensor = sensor_info.sensor

nchans = sensor_info.nchannels
batch_size = 1000
input_size = nchans - 1
hidden_size = 256
output_size = 1

In [2]:
nchans

12

In [3]:
'''
Function Definitions:
'''

chan_desc = sensor_info.channel_descriptions


def extract_channel(Tb_array, chan):

    '''
    Use in preparing training data.

    Assumes Tb array is [m x n] where m (rows) are samples
    and n (columns) is the number of channels.

    Passing in the channel description splits the data so
    that the specified channel is its own vector y and the
    rest are kept as predictors x.

    Inputs:
        Tb_array    |  ndarray of Tbs
        chan        |  string of channel name
    Outputs:
        x           |  matrix of predictors (other channels)
        y           |  vector of predictands (the missing channel)
        
    '''
    
    chan_desc = np.array(sensor_info.channel_descriptions)


    chan_indx = np.where(chan == chan_desc)[0]

    if np.size(chan_indx) == 0:
        raise ValueError(f'Channel description must be in list {chan_desc}.')

    y = Tb_array[:,chan_indx]
    x = Tb_array[:,np.delete(np.arange(0,len(chan_desc)),chan_indx)]

    return x, y


def train_model(model, nepochs, dataloader, learning_rate=0.001, quiet=False, stage=None, validation_dataloader=None):

    nbatches = len(dataloader)
    
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=nepochs)
    device = None

    if stage:
        print(f'Training stage: {stage}')

    loss_arr    = np.zeros([nbatches,nepochs], dtype='f')
    valloss_arr = np.zeros([nbatches,nepochs], dtype='f')
    
    for epoch in range(nepochs):
        for i, (profs, obs) in enumerate(dataloader):
            if device:
                profs, obs = profs.to(device), obs.to(device)

            if validation_dataloader and i%1000==0:
                valprofs, valobs = next(enumerate(validation_dataloader))[1]
                val_pred = model(valprofs)
                valloss  = criterion(val_pred, valobs)
                print(f'Validation Loss = {valloss.item():.3f}')

            #Forward pass:
            obs_pred = model(profs)
            loss     = criterion(obs_pred, obs)


            #Backward pass:
            optimizer.zero_grad()
            loss.backward()

            #Update neurons:
            optimizer.step()

            loss_arr[i,epoch] = loss.item()
            valloss_arr[i,epoch] = valloss.item()
            
            if not quiet:
                if i%1000 == 0:
                    print(f'Channel {channel}, Epoch={epoch+1}, batch = {i} of {nbatches}, loss={loss.item():.3f}, LR={scheduler.get_last_lr()[0]}')
        
        scheduler.step()




    return loss_arr, valloss_arr



In [4]:
'''

AMSR2 CHANNEL PREDICTION MODEL:
    1: Ocean

'''

sfc = [1]

with xr.open_dataset(f'training_data/{satellite}_training_data.nc') as f:
    
    sfctype = f.sfctype.values

    correct_sfc = np.isin(sfctype, sfc)
    sfcindcs = np.where(correct_sfc)[0]

    Tbs = f.Tbs.values[sfcindcs]

#---Don't shuffle data before splitting
#np.random.seed(40)
#Tbs = array_funcs.shuffle_data(Tbs, axis=0)

if Tbs.shape[0] > 10.0e+06:
    Tbs = Tbs[:10_000_000,:]


print(Tbs.shape)

#---Split data into train/test/val:
train_indcs, test_indcs, val_indcs = local_functions.split_data_indcs(Tbs)

Tbs_train = Tbs[train_indcs]
Tbs_test  = Tbs[test_indcs]
Tbs_val   = Tbs[val_indcs]

#---Shuffle before converting to tensors:
np.random.seed(40)
Tbs_train = array_funcs.shuffle_data(Tbs_train, axis=0)


(10000000, 12)


In [5]:
'''
Predict channels: Train all models
'''

torch.set_num_threads(10)

for ichan, channel in enumerate(chan_desc):

    print(channel)

    #---Extract channel, x = predictors, y = channel to predict
    x_train, y_train = extract_channel(Tbs_train, channel)
    x_test,  y_test  = extract_channel(Tbs_test, channel)
    x_val,   y_val   = extract_channel(Tbs_val, channel)
    
    x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)
    x_test,  y_test  = torch.tensor(x_test), torch.tensor(y_test)
    x_val,   y_val   = torch.tensor(x_val), torch.tensor(y_val)


    #---Set up dataloaders:
    train_loader = torch.utils.data.DataLoader(dataset=dataset(x_train,y_train), 
                                               batch_size=batch_size, 
                                               shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=dataset(x_test,y_test), 
                                               batch_size=None, 
                                               shuffle=False, drop_last=False)
    val_loader = torch.utils.data.DataLoader(dataset=dataset(x_val,y_val), 
                                               batch_size=None, 
                                               shuffle=False, drop_last=False)

    #---Create model:
    model = channel_predictor(input_size, hidden_size, output_size)

    #---Train_model:
    nbatches = len(train_loader)
    nepochs_stage1 = 5
    nepochs_stage2 = 10
    nepochs_stage3 = 20

    loss_stage1, valloss_stage1 = train_model(model, nepochs=5, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=1, validation_dataloader=val_loader)
    loss_stage2, valloss_stage2 = train_model(model, nepochs=10, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=2, validation_dataloader=val_loader)
    loss_stage3, valloss_stage3 = train_model(model, nepochs=20, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=3, validation_dataloader=val_loader)

    torch.save(model.state_dict(), f'models/{sensor}_{satellite}_channel_predictor_{channel}_ocean_no89.pt')

    loss_data = data2xarray(data_vars = (loss_stage1, loss_stage2, loss_stage3, 
                                     valloss_stage1, valloss_stage2, valloss_stage3),
                        var_names = ('LossStage1','LossStage2','LossStage3',
                                     'ValidationLossStage1', 'ValidationLossStage2', 'ValidationLossStage3'),
                        dims = (nbatches,nepochs_stage1,nepochs_stage2, nepochs_stage3),
                        dim_names = ('training_batches', 'epochs_stage1', 'epochs_stage2', 'epochs_stage3'))

    loss_data.to_netcdf(f'diagnostics/loss_data_{channel}_ocean_no89.nc', engine='netcdf4')

    print(f'Finished training model for channel {channel}.')

6V
Training stage: 1
Validation Loss = 24179.082
Channel 6V, Epoch=1, batch = 0 of 8000, loss=26847.863, LR=0.001
Validation Loss = 1.371
Channel 6V, Epoch=1, batch = 1000 of 8000, loss=0.527, LR=0.001
Validation Loss = 0.095
Channel 6V, Epoch=1, batch = 2000 of 8000, loss=0.444, LR=0.001
Validation Loss = 0.710
Channel 6V, Epoch=1, batch = 3000 of 8000, loss=0.423, LR=0.001
Validation Loss = 0.872
Channel 6V, Epoch=1, batch = 4000 of 8000, loss=0.499, LR=0.001
Validation Loss = 1.377
Channel 6V, Epoch=1, batch = 5000 of 8000, loss=0.789, LR=0.001
Validation Loss = 0.359
Channel 6V, Epoch=1, batch = 6000 of 8000, loss=0.331, LR=0.001
Validation Loss = 0.005
Channel 6V, Epoch=1, batch = 7000 of 8000, loss=0.503, LR=0.001
Validation Loss = 0.734
Channel 6V, Epoch=2, batch = 0 of 8000, loss=0.446, LR=0.0009045084971874737
Validation Loss = 0.000
Channel 6V, Epoch=2, batch = 1000 of 8000, loss=0.417, LR=0.0009045084971874737
Validation Loss = 0.055
Channel 6V, Epoch=2, batch = 2000 of 8000

In [6]:
#General parameters:
input_size = nchans - 1 + 1 #Remaining channels + surface type

In [7]:
'''

AMSR2 CHANNEL PREDICTION MODEL:
    2: Non-Ocean Surfaces

'''

with xr.open_dataset(f'training_data/{satellite}_training_data.nc') as f:
    
    sfctype = f.sfctype.values

    correct_sfc = sfctype > 1

    Tbs = f.Tbs.values[correct_sfc,:]
    sfctype = sfctype[correct_sfc]

#---Split data into train/test/val:
train_indcs, test_indcs, val_indcs = local_functions.split_data_indcs(Tbs)

Tbs_train = Tbs[train_indcs]
Tbs_test  = Tbs[test_indcs]
Tbs_val   = Tbs[val_indcs]

sfctype_train = sfctype[train_indcs].astype(np.float32)
sfctype_test  = sfctype[test_indcs].astype(np.float32)
sfctype_val   = sfctype[val_indcs].astype(np.float32)

#---Shuffle before converting to tensors:
np.random.seed(40)
Tbs_train, shuffled_indcs = array_funcs.shuffle_data(Tbs_train, axis=0, return_indcs=True)
sfctype_train = sfctype_train[shuffled_indcs]

In [8]:
sfctype_train, sfctype_test, sfctype_val

(array([14., 11., 17., ...,  2., 12., 13.],
       shape=(32186272,), dtype=float32),
 array([14., 14., 14., ..., 15.,  5.,  5.], shape=(4023284,), dtype=float32),
 array([ 2.,  2.,  2., ..., 11., 11., 11.], shape=(4023284,), dtype=float32))

In [9]:
chan_desc[10]

'37V'

In [11]:
'''
Predict channels: Train all models
'''

torch.set_num_threads(10)

for ichan, channel in enumerate(chan_desc):

    #if ichan < 10: continue

    print(channel)

    #---Extract channel, x = predictors, y = channel to predict
    x_train, y_train = extract_channel(Tbs_train, channel)
    x_test,  y_test  = extract_channel(Tbs_test, channel)
    x_val,   y_val   = extract_channel(Tbs_val, channel)

    x_train = np.concatenate((x_train, sfctype_train[:,None]), axis=1)
    x_test  = np.concatenate((x_test,  sfctype_test[:,None]), axis=1)
    x_val   = np.concatenate((x_val,   sfctype_val[:,None]), axis=1)
    
    x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)
    x_test,  y_test  = torch.tensor(x_test), torch.tensor(y_test)
    x_val,   y_val   = torch.tensor(x_val), torch.tensor(y_val)


    #---Set up dataloaders:
    train_loader = torch.utils.data.DataLoader(dataset=dataset(x_train,y_train), 
                                               batch_size=batch_size, 
                                               shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=dataset(x_test,y_test), 
                                               batch_size=None, 
                                               shuffle=False, drop_last=False)
    val_loader = torch.utils.data.DataLoader(dataset=dataset(x_val,y_val), 
                                               batch_size=None, 
                                               shuffle=False, drop_last=False)

    #---Create model:
    model = channel_predictor(input_size, hidden_size, output_size)

    #---Train_model:
    nbatches = len(train_loader)
    nepochs_stage1 = 5
    nepochs_stage2 = 10
    nepochs_stage3 = 20

    loss_stage1, valloss_stage1 = train_model(model, nepochs=5, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=1, validation_dataloader=val_loader)
    loss_stage2, valloss_stage2 = train_model(model, nepochs=10, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=2, validation_dataloader=val_loader)
    loss_stage3, valloss_stage3 = train_model(model, nepochs=20, dataloader=train_loader, 
                                          learning_rate=0.001, quiet=False, stage=3, validation_dataloader=val_loader)

    torch.save(model.state_dict(), f'models/{sensor}_{satellite}_channel_predictor_{channel}_nonocean_no89.pt')

    loss_data = data2xarray(data_vars = (loss_stage1, loss_stage2, loss_stage3, 
                                     valloss_stage1, valloss_stage2, valloss_stage3),
                        var_names = ('LossStage1','LossStage2','LossStage3',
                                     'ValidationLossStage1', 'ValidationLossStage2', 'ValidationLossStage3'),
                        dims = (nbatches,nepochs_stage1,nepochs_stage2, nepochs_stage3),
                        dim_names = ('training_batches', 'epochs_stage1', 'epochs_stage2', 'epochs_stage3'))

    loss_data.to_netcdf(f'diagnostics/loss_data_{channel}_nonocean_no89.nc', engine='netcdf4')

    print(f'Finished training model for channel {channel}.')

6V
Training stage: 1
Validation Loss = 45295.352
Channel 6V, Epoch=1, batch = 0 of 32186, loss=59913.645, LR=0.001
Validation Loss = 0.178
Channel 6V, Epoch=1, batch = 1000 of 32186, loss=1.547, LR=0.001
Validation Loss = 0.021
Channel 6V, Epoch=1, batch = 2000 of 32186, loss=1.006, LR=0.001
Validation Loss = 0.231
Channel 6V, Epoch=1, batch = 3000 of 32186, loss=1.101, LR=0.001
Validation Loss = 0.663
Channel 6V, Epoch=1, batch = 4000 of 32186, loss=1.943, LR=0.001
Validation Loss = 0.010
Channel 6V, Epoch=1, batch = 5000 of 32186, loss=0.894, LR=0.001
Validation Loss = 1.130
Channel 6V, Epoch=1, batch = 6000 of 32186, loss=3.162, LR=0.001
Validation Loss = 0.060
Channel 6V, Epoch=1, batch = 7000 of 32186, loss=0.798, LR=0.001
Validation Loss = 0.021
Channel 6V, Epoch=1, batch = 8000 of 32186, loss=1.202, LR=0.001
Validation Loss = 4.172
Channel 6V, Epoch=1, batch = 9000 of 32186, loss=3.899, LR=0.001
Validation Loss = 0.010
Channel 6V, Epoch=1, batch = 10000 of 32186, loss=0.882, LR=

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Validation Loss = 0.016
Channel 24V, Epoch=4, batch = 1000 of 32186, loss=0.852, LR=0.0007938926261462366
Validation Loss = 0.124
Channel 24V, Epoch=4, batch = 2000 of 32186, loss=0.791, LR=0.0007938926261462366
Validation Loss = 0.014
Channel 24V, Epoch=4, batch = 3000 of 32186, loss=0.734, LR=0.0007938926261462366
Validation Loss = 0.000
Channel 24V, Epoch=4, batch = 4000 of 32186, loss=0.775, LR=0.0007938926261462366
Validation Loss = 0.056
Channel 24V, Epoch=4, batch = 5000 of 32186, loss=0.803, LR=0.0007938926261462366
Validation Loss = 0.082
Channel 24V, Epoch=4, batch = 6000 of 32186, loss=0.737, LR=0.0007938926261462366
Validation Loss = 0.001
Channel 24V, Epoch=4, batch = 7000 of 32186, loss=0.642, LR=0.0007938926261462366
Validation Loss = 0.039
Channel 24V, Epoch=4, batch = 8000 of 32186, loss=0.692, LR=0.0007938926261462366
Validation Loss = 0.047
Channel 24V, Epoch=4, batch = 9000 of 32186, loss=0.754, LR=0.0007938926261462366
Validation Loss = 0.078
Channel 24V, Epoch=4, 