# Training a simple CNN model for tornado detection using PyTorch

This notebook steps through training a simple CNN model on a subset of the TorNet dataset.

The resulting model will not have any skill; the goal is simply to provide a working end-to-end example of how to set up a data loader, and how to build and fit a model.

In [12]:
import sys
# Uncomment if tornet isn't installed in your environment or in your path already
#sys.path.append('../')  

import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import optim, nn
from torch.utils.data import Dataset
from torchvision import transforms, utils

from tornet.data.loader import read_file, TornadoDataLoader
from tornet.data.preprocess import add_coordinates, remove_time_dim, permute_dims
from tornet.data.constants import ALL_VARIABLES

In [14]:
# Create basic dataloader
# This option loads directly from netcdf files, and will be slow.
# To speed up training, rebuild the dataset as array_record
# (see tornet/data/tfds/tornet/README.md).

# Root directory of the downloaded data (must contain catalog.csv).
# Set the TORNET_ROOT environment variable to point elsewhere instead of
# editing this hard-coded local path.
data_root = os.environ.get('TORNET_ROOT', '/Users/jackgao/Downloads/data')

data_type = 'train'
years = [2013, 2014]

catalog_path = os.path.join(data_root, 'catalog.csv')
if not os.path.exists(catalog_path):
    raise RuntimeError('Unable to find catalog.csv at '+data_root)

catalog_raw = pd.read_csv(catalog_path, parse_dates=['start_time', 'end_time'])
# Keep only rows of the requested split whose start_time falls in `years`
catalog = catalog_raw.loc[
    (catalog_raw['type'] == data_type)
    & catalog_raw['start_time'].dt.year.isin(years)
]
catalog = catalog.sample(frac=1, random_state=1234)  # shuffles list (fixed seed)
file_list = [os.path.join(data_root, f) for f in catalog.filename]

# Fail loudly here rather than silently training on zero batches later
if not file_list:
    raise RuntimeError(
        f'No {data_type!r} samples for years {years} found in {catalog_path}'
    )

# Dataset, with preprocessing
class TornadoDataset(TornadoDataLoader, Dataset):
    """Torch ``Dataset`` view of tornet's ``TornadoDataLoader``.

    Adds no behavior of its own; mixing in ``torch.utils.data.Dataset`` lets
    instances be passed to ``torch.utils.data.DataLoader``.
    """
import multiprocessing

transform = transforms.Compose([
            # add coordinates tensor to data
            lambda d: add_coordinates(d, include_az=False, tilt_last=False, backend=torch),
            # Remove time dimension
            lambda d: remove_time_dim(d)])
torch_ds = TornadoDataset(file_list,
                          variables=ALL_VARIABLES,
                          n_frames=1,
                          tilt_last=False,  # so ordering of dims is [time,tilt,az,range]
                          transform=transform)

# Torch data loader
batch_size = 32
# With num_workers > 0, every worker gets a pickled copy of torch_ds. Under the
# 'spawn'/'forkserver' start methods (defaults on macOS and Windows, and on Linux
# from Python 3.14), the lambdas in `transform` and the notebook-defined
# TornadoDataset class cannot be pickled/unpickled in the workers, so iteration
# would crash. Only use worker processes when they are forked.
num_workers = 8 if multiprocessing.get_start_method() == 'fork' else 0
torch_dl = torch.utils.data.DataLoader(torch_ds,
                                       batch_size=batch_size,
                                       num_workers=num_workers)


In [15]:
# If the data was registered with tensorflow_datasets, run this cell instead.
# The environment variable TFDS_DATA_DIR should point to the location of the resaved dataset.

#import tensorflow_datasets as tfds # need version >= 4.9.3
#import tornet.data.tfds.tornet.tornet_dataset_builder # registers 'tornet'
#from tornet.data.torch.loader import TFDSTornadoDataset
#data_type='train'
#years = [2017,]
#ds = tfds.data_source('tornet')
## Dataset, with preprocessing
#transform = transforms.Compose([
#            # transpose to [time,tilt,az,rng]
#            lambda d: permute_dims(d,(0,3,1,2)),
#            # add coordinates tensor to data
#            lambda d: add_coordinates(d,include_az=False,tilt_last=False,backend=torch), 
#            # Remove time dimension
#            lambda d: remove_time_dim(d)])                                
#datasets = [
#     TFDSTornadoDataset(ds['%s-%d' % (data_type,y)] ,transform)  for y in years
#     ]
#dataset = torch.utils.data.ConcatDataset(datasets)
#torch_dl = torch.utils.data.DataLoader( dataset, 
#                                      batch_size=32,
#                                      num_workers=20)

In [16]:
# Create simple CNN model
from tornet.models.torch.cnn_baseline import NormalizeVariable
from tornet.data.constants import CHANNEL_MIN_MAX


class TornadoLikelihood(nn.Module):
    """
    Template for a CNN that produces a likelihood field.

    Each radar variable is normalized with a NormalizeVariable built from its
    CHANNEL_MIN_MAX bounds. The variables are then concatenated along the
    channel axis, NaNs are replaced by -3, and the stack is passed through the
    convolution layers.

    Args:
        radar_variables: names of the radar variables read from the input dict.
        n_tilts: number of tilts per variable in each input tensor. The first
            convolution expects n_tilts*len(radar_variables) input channels.
            NOTE(review): the original hard-coded 12 channels (2 tilts x 6
            variables); this assumes ALL_VARIABLES has 6 entries -- confirm.
    """
    def __init__(self, radar_variables=ALL_VARIABLES, n_tilts=2):
        super().__init__()
        self.radar_variables = radar_variables

        # Set up normalizers. nn.ModuleDict (not a plain dict) registers them as
        # submodules, so they are seen by .to()/state_dict()/modules() like any
        # other layer of this model.
        self.input_norm_layers = nn.ModuleDict()
        for v in radar_variables:
            min_max = np.array(CHANNEL_MIN_MAX[v])  # [min, max]
            # Plain Python floats avoid creating float64 tensors, which some
            # backends (e.g. MPS) cannot hold once the model is moved there.
            scale = float(1/(min_max[1]-min_max[0]))
            offset = float(min_max[0])
            self.input_norm_layers[v] = NormalizeVariable(scale, offset)

        # Processing layers
        in_channels = n_tilts*len(radar_variables)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=(3, 3), padding='same')
        # add more..
        self.conv_out = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(3, 3), padding='same')

    def _normalize_inputs(self, data):
        """Return a new dict mapping each radar variable to its normalized tensor."""
        return {v: self.input_norm_layers[v](data[v]) for v in self.radar_variables}

    def forward(self, x):
        """
        Map a dict of radar variables to a single-channel output field.

        Args:
            x: mapping from variable name to tensor, assumed [batch,tilt,az,rng].
               Keys not in self.radar_variables are ignored.

        Returns:
            Tensor of shape [batch,1,az,rng] (padding='same' keeps az/rng size).
        """
        # extract radar inputs
        x = {v: x[v] for v in self.radar_variables}  # each [batch,tilt,az,rng]
        # normalize
        x = self._normalize_inputs(x)  # each [batch,tilt,az,rng]
        # concatenate along channel (tilt) dim -> [batch,tilt*len(radar_variables),az,rng]
        x = torch.cat([x[v] for v in self.radar_variables], dim=1)
        # Replace NaNs with the constant -3 so the convolutions see finite values
        x = torch.where(torch.isnan(x), -3, x)

        # process
        x = self.conv1(x)
        # add more..
        x = self.conv_out(x)

        return x




ModuleNotFoundError: No module named 'tornet.models.torch'

In [None]:
# Train this model using torch lightning
import lightning as L
import torchmetrics
from torchmetrics import MetricCollection
from tornet.models.torch.cnn_baseline import TornadoClassifier

# Seed python/numpy/torch RNGs so weight initialization and training are reproducible
SEED = 1234
L.seed_everything(SEED)

# Metrics expected to be binary classification metrics that expect (logits,label)
#    where logits and label are both (N,) tensors
#    e.g. torchmetrics.classification.BinaryAccuracy
metrics = MetricCollection([
            torchmetrics.classification.BinaryAccuracy(),
            torchmetrics.classification.BinaryAUROC(),
            torchmetrics.classification.BinaryAveragePrecision()
        ])

cnn = TornadoLikelihood()
classifier = TornadoClassifier(cnn, metrics=metrics)

# Low number of train_batches/epochs only for demo purposes
LIMIT_TRAIN_BATCHES = 10
MAX_EPOCHS = 3
trainer = L.Trainer(limit_train_batches=LIMIT_TRAIN_BATCHES, max_epochs=MAX_EPOCHS)
trainer.fit(classifier, train_dataloaders=torch_dl)
