# Channel Prediction Model for Quality Control

This notebook contains the code for, and outlines the development of, the
GMI Channel Prediction Neural Network used for quality control (QC) of GMI L1C
data. Simple feed-forward neural networks (FFNNs) are trained on
good-quality brightness temperature (Tb) observation vectors to predict
each channel from all of the remaining channels.

The end product is a tree of FFNNs, split by surface type (ocean or
non-ocean) and by channel, since these are treated as fully distinct problems.
The non-ocean models additionally receive the surface type as an extra predictor.

In [None]:
import glob
import os

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import torch
from torch import nn

import paths
from src.utils import data2xarray, array_funcs, extract_channel
from src import surface, sensor_info, training_funcs
from src.classes.dataset_class import dataset
from src.classes.model_class import channel_predictor

In [None]:
# ---- General parameters (shared by every model in this notebook) ----

# Sensor / platform identifiers, used in output file names.
satellite = sensor_info.satellite
sensor = sensor_info.sensor
nchans = sensor_info.nchannels

# Network shape: one channel (output) is predicted from the others (inputs).
input_size = nchans - 1
hidden_size = 256
output_size = 1

# Mini-batch size for the training DataLoader.
batch_size = 1000

# Number of CPU threads handed to torch.set_num_threads in the training cells.
ncpus = 10

In [None]:
'''

GMI CHANNEL PREDICTION MODEL:
    1: Ocean

'''
# Load ocean-only samples (surface type in `sfc`), cap how many are kept, and
# split them into train / test / validation sets.

sfc = [1]                      # surface-type codes treated as "ocean"
max_ocean_samples = 5_000_000  # upper bound on the number of ocean samples kept

with xr.open_dataset(f'{paths.training_datapath}/{sensor_info.satellite}_training_data.nc') as f:

    sfctype = f.sfctype.values

    correct_sfc = np.isin(sfctype, sfc)
    # Truncate the index list *before* gathering rows: same result as gathering
    # everything and slicing afterwards, without copying rows that are discarded.
    # NOTE(review): this keeps the first `max_ocean_samples` matches in file order,
    # not a random subsample; if the file is ordered (e.g. by time) the cap may
    # bias the training set.
    sfcindcs = np.flatnonzero(correct_sfc)[:max_ocean_samples]

    Tbs = f.Tbs.values[sfcindcs]

print(Tbs.shape)

# Seed *before* splitting so that the partition (if split_data_indcs is random)
# and the shuffle below are both reproducible across kernel restarts.
np.random.seed(40)

#---Split data into train/test/val:
train_indcs, test_indcs, val_indcs = training_funcs.split_data_indcs(Tbs)

Tbs_train = Tbs[train_indcs]
Tbs_test  = Tbs[test_indcs]
Tbs_val   = Tbs[val_indcs]

#---Shuffle before converting to tensors:
Tbs_train = training_funcs.shuffle_data(Tbs_train, axis=0)


In [None]:
'''
Predict channels: Train all models
'''

torch.set_num_threads(ncpus)

# Training schedule, identical for every channel: three successive stages run
# on the same model, each for its own number of epochs.
nepochs_stage1 = 5
nepochs_stage2 = 10
nepochs_stage3 = 20
stage_epochs = (nepochs_stage1, nepochs_stage2, nepochs_stage3)
learning_rate = 0.001

# Resolve and create output directories before any training, so a bad/undefined
# path fails immediately instead of after the first model has been trained.
# NOTE(review): `model_path` is not defined anywhere in this notebook; it must be
# provided by the environment (presumably paths.model_path -- confirm).
model_dir = f'{model_path}/{sensor}'
diagnostics_dir = 'diagnostics'
os.makedirs(model_dir, exist_ok=True)
os.makedirs(diagnostics_dir, exist_ok=True)

for channel in sensor_info.channel_descriptions:

    print(channel)

    # Reseed per channel: weight init and DataLoader shuffling become reproducible
    # and independent of which channels were trained before this one.
    torch.manual_seed(40)

    #---Extract channel, x = predictors, y = channel to predict
    x_train, y_train = extract_channel(Tbs_train, channel)
    x_test,  y_test  = extract_channel(Tbs_test, channel)
    x_val,   y_val   = extract_channel(Tbs_val, channel)

    x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)
    x_test,  y_test  = torch.tensor(x_test), torch.tensor(y_test)
    x_val,   y_val   = torch.tensor(x_val), torch.tensor(y_val)

    #---Set up dataloaders (test_loader is not used in this cell):
    train_loader = torch.utils.data.DataLoader(dataset=dataset(x_train,y_train),
                                               batch_size=batch_size,
                                               shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=dataset(x_test,y_test),
                                               batch_size=None,
                                               shuffle=False, drop_last=False)
    val_loader = torch.utils.data.DataLoader(dataset=dataset(x_val,y_val),
                                               batch_size=None,
                                               shuffle=False, drop_last=False)

    #---Create model. The input width is taken from the predictors themselves so
    #   this cell does not depend on the global `input_size`, which a later cell
    #   redefines for the non-ocean models.
    model = channel_predictor(x_train.shape[1], hidden_size, output_size)

    #---Train model: one call per stage; enumerate keeps stage number and epoch
    #   count paired so they cannot drift apart.
    nbatches = len(train_loader)
    stage_results = []
    for stage, nepochs in enumerate(stage_epochs, start=1):
        stage_results.append(training_funcs.train_model(model,
                                                        nepochs=nepochs,
                                                        dataloader=train_loader,
                                                        learning_rate=learning_rate,
                                                        quiet=False,
                                                        stage=stage,
                                                        validation_dataloader=val_loader))
    (loss_stage1, valloss_stage1), (loss_stage2, valloss_stage2), (loss_stage3, valloss_stage3) = stage_results

    torch.save(model.state_dict(),
               f'{model_dir}/{sensor}_{satellite}_channel_predictor_{channel}_ocean.pt')

    loss_data = data2xarray(data_vars = (loss_stage1, loss_stage2, loss_stage3,
                                     valloss_stage1, valloss_stage2, valloss_stage3),
                        var_names = ('LossStage1','LossStage2','LossStage3',
                                     'ValidationLossStage1', 'ValidationLossStage2', 'ValidationLossStage3'),
                        dims = (nbatches,nepochs_stage1,nepochs_stage2, nepochs_stage3),
                        dim_names = ('training_batches', 'epochs_stage1', 'epochs_stage2', 'epochs_stage3'))

    loss_data.to_netcdf(f'{diagnostics_dir}/loss_data_{channel}_ocean.nc', engine='netcdf4')

    print(f'Finished training model for channel {channel}.')

In [None]:
# Non-ocean models take one extra predictor: the surface type is appended
# to the (nchans - 1) remaining channels.
input_size = (nchans - 1) + 1

In [None]:
'''

GMI CHANNEL PREDICTION MODEL:
    2: Non-Ocean Surfaces

'''
# Load every non-ocean sample (sfctype > 1) together with its surface type,
# which the non-ocean models use as an additional predictor.

with xr.open_dataset(f'{paths.training_datapath}/{sensor_info.satellite}_training_data.nc') as f:

    sfctype = f.sfctype.values

    correct_sfc = sfctype > 1

    Tbs = f.Tbs.values[correct_sfc,:]
    sfctype = sfctype[correct_sfc]

# Seed *before* splitting so that the partition (if split_data_indcs is random)
# and the shuffle below are both reproducible across kernel restarts.
np.random.seed(40)

#---Split data into train/test/val:
train_indcs, test_indcs, val_indcs = training_funcs.split_data_indcs(Tbs)

Tbs_train = Tbs[train_indcs]
Tbs_test  = Tbs[test_indcs]
Tbs_val   = Tbs[val_indcs]

sfctype_train = sfctype[train_indcs].astype(np.float32)
sfctype_test  = sfctype[test_indcs].astype(np.float32)
sfctype_val   = sfctype[val_indcs].astype(np.float32)

#---Shuffle before converting to tensors; apply the same permutation to the
#   surface types so each Tb vector keeps its own surface type.
Tbs_train, shuffled_indcs = training_funcs.shuffle_data(Tbs_train, axis=0, return_indcs=True)
sfctype_train = sfctype_train[shuffled_indcs]

assert Tbs_train.shape[0] == sfctype_train.shape[0], 'Tbs/sfctype training rows are misaligned'

In [None]:
'''
Predict channels: Train all models
'''

torch.set_num_threads(ncpus)

# Training schedule, identical for every channel: three successive stages run
# on the same model, each for its own number of epochs. (Previously stage 3 was
# run with nepochs_stage2 while the loss metadata declared nepochs_stage3.)
nepochs_stage1 = 5
nepochs_stage2 = 10
nepochs_stage3 = 20
stage_epochs = (nepochs_stage1, nepochs_stage2, nepochs_stage3)
learning_rate = 0.001

# Resolve and create output directories before any training, so a bad/undefined
# path fails immediately instead of after the first model has been trained.
# NOTE(review): `model_path` is not defined anywhere in this notebook; it must be
# provided by the environment (presumably paths.model_path -- confirm).
model_dir = f'{model_path}/{sensor}'
diagnostics_dir = 'diagnostics'
os.makedirs(model_dir, exist_ok=True)
os.makedirs(diagnostics_dir, exist_ok=True)

for channel in sensor_info.channel_descriptions:

    print(channel)

    # Reseed per channel: weight init and DataLoader shuffling become reproducible
    # and independent of which channels were trained before this one.
    torch.manual_seed(40)

    #---Extract channel, x = predictors, y = channel to predict
    x_train, y_train = extract_channel(Tbs_train, channel)
    x_test,  y_test  = extract_channel(Tbs_test, channel)
    x_val,   y_val   = extract_channel(Tbs_val, channel)

    #---Append the surface type as the last predictor column
    x_train = np.concatenate((x_train, sfctype_train[:,None]), axis=1)
    x_test  = np.concatenate((x_test,  sfctype_test[:,None]), axis=1)
    x_val   = np.concatenate((x_val,   sfctype_val[:,None]), axis=1)

    x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)
    x_test,  y_test  = torch.tensor(x_test), torch.tensor(y_test)
    x_val,   y_val   = torch.tensor(x_val), torch.tensor(y_val)

    #---Set up dataloaders (test_loader is not used in this cell):
    train_loader = torch.utils.data.DataLoader(dataset=dataset(x_train,y_train),
                                               batch_size=batch_size,
                                               shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset=dataset(x_test,y_test),
                                               batch_size=None,
                                               shuffle=False, drop_last=False)
    val_loader = torch.utils.data.DataLoader(dataset=dataset(x_val,y_val),
                                               batch_size=None,
                                               shuffle=False, drop_last=False)

    #---Create model. The input width is taken from the predictors themselves
    #   (remaining channels + surface type), so this cell does not depend on
    #   which value the global `input_size` currently holds.
    model = channel_predictor(x_train.shape[1], hidden_size, output_size)

    #---Train model: one call per stage; enumerate keeps stage number and epoch
    #   count paired so they cannot drift apart.
    nbatches = len(train_loader)
    stage_results = []
    for stage, nepochs in enumerate(stage_epochs, start=1):
        stage_results.append(training_funcs.train_model(model,
                                                        nepochs=nepochs,
                                                        dataloader=train_loader,
                                                        learning_rate=learning_rate,
                                                        quiet=False,
                                                        stage=stage,
                                                        validation_dataloader=val_loader))
    (loss_stage1, valloss_stage1), (loss_stage2, valloss_stage2), (loss_stage3, valloss_stage3) = stage_results

    torch.save(model.state_dict(),
               f'{model_dir}/{sensor}_{satellite}_channel_predictor_{channel}_nonocean.pt')

    loss_data = data2xarray(data_vars = (loss_stage1, loss_stage2, loss_stage3,
                                     valloss_stage1, valloss_stage2, valloss_stage3),
                        var_names = ('LossStage1','LossStage2','LossStage3',
                                     'ValidationLossStage1', 'ValidationLossStage2', 'ValidationLossStage3'),
                        dims = (nbatches,nepochs_stage1,nepochs_stage2, nepochs_stage3),
                        dim_names = ('training_batches', 'epochs_stage1', 'epochs_stage2', 'epochs_stage3'))

    loss_data.to_netcdf(f'{diagnostics_dir}/loss_data_{channel}_nonocean.nc', engine='netcdf4')

    print(f'Finished training model for channel {channel}.')