### Imports

First we need to set up the Python environment correctly by adding the top directory of the repository to the Python path. Once that's done, we can import packages from inside the repository.

In [1]:
import sys, yaml, math, numpy as np, torch
sys.path.append('/scratch') # This line is equivalent to doing source scripts/source_me.sh in a bash terminal
from torch.utils.data import DataLoader
from TauRNN.trainers import Trainer
from TauRNN import datasets
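
If this cell is re-run several times, the same directory gets appended to `sys.path` repeatedly. A minimal sketch of a guard against that is shown below; `add_repo_to_path` is a hypothetical helper used only for illustration and is not part of TauRNN.

```python
import sys

def add_repo_to_path(repo_root):
    """Append repo_root to sys.path only if it is not already there."""
    root = str(repo_root)
    if root not in sys.path:
        sys.path.append(root)
    return root

# '/scratch' is the repository location used in this notebook.
add_repo_to_path('/scratch')
add_repo_to_path('/scratch')  # the second call changes nothing
print(sys.path.count('/scratch'))
```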

### Configuring

Most training options are set in a YAML configuration file. We load this config, and the options inside are then passed to the relevant pieces of the training framework.

In [2]:
with open('/scratch/TauRNN/config/taurnn.yaml') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
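
To make the layout of the config concrete, here is a sketch of the kind of structure the rest of this notebook expects. The section names (`data`, `data_loader`, `trainer`, `model`) and the `t_v_split` key are the ones used in later cells; every value, and the `batch_size` entry, is a made-up placeholder rather than the content of the real `taurnn.yaml`.

```python
import yaml

# Hypothetical stand-in for taurnn.yaml; only the key names come from the notebook.
example_yaml = """
data:
  t_v_split: 0.1      # placeholder: fraction of samples held out for validation
data_loader:
  batch_size: 32      # placeholder value
trainer: {}
model: {}
"""

example_config = yaml.safe_load(example_yaml)

# Each top-level section is later unpacked with ** into one piece of the framework.
for section in ('data', 'data_loader', 'trainer', 'model'):
    print(section, '->', example_config[section])
```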

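### Collate function

The collate function assembles the individual samples of a batch into tensors: the fixed-size inputs are stacked, while the variable-length point and hit sequences are packed with `pack_sequence`.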

In [3]:
from torch.nn.utils.rnn import pack_sequence
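def collate(data):
    # Fixed-size inputs: stack into a single batch tensor
    x_fixed = torch.stack([ d['x_fixed'] for d in data ])
    # Variable-length inputs: drop the leading dimension, transpose so the
    # sequence length comes first, then pack (samples need not be sorted by length)
    x_point = pack_sequence([ d['x_point_var'].squeeze(dim=0).T for d in data ], enforce_sorted=False)
    x_hit = pack_sequence([ d['x_hit_var'].squeeze(dim=0).T for d in data ], enforce_sorted=False)
    # Targets as a float tensor
    y = torch.tensor([ d['y'] for d in data ]).float()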
    return { 'x_fixed': x_fixed,
             'x_point_var': x_point,
             'x_hit_var': x_hit,
             'y': y }
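
To illustrate what `pack_sequence` does with sequences of different lengths, here is a small self-contained sketch on toy tensors. The shapes are invented for the example: three sequences of 3, 1 and 2 steps, each step carrying 2 values.

```python
import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

# Three toy sequences; the first dimension is the sequence length.
seqs = [torch.ones(3, 2), torch.ones(1, 2), torch.ones(2, 2)]

# enforce_sorted=False lets the sequences arrive in any length order.
packed = pack_sequence(seqs, enforce_sorted=False)

# Unpacking pads the shorter sequences with zeros and returns
# the original lengths in the original order.
padded, lengths = pad_packed_sequence(packed, batch_first=True)
print(padded.shape)
print(lengths)
```

The packed object stores only the 6 real steps, so a recurrent layer fed with it does not process padding.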

### Loading inputs

Here we load the dataset and create the trainer, which is responsible for building the model and overseeing training. The full dataset is then sliced into a training set and a validation set, with the validation fraction given by `t_v_split` in the config, and each set is wrapped in a `DataLoader`.

In [4]:
full_dataset = datasets.get_dataset(**config['data'])
trainer = Trainer(**config['trainer'])

fulllen = len(full_dataset)
tv_num = math.ceil(fulllen*config['data']['t_v_split'])
splits = np.cumsum([fulllen-tv_num,0,tv_num])

train_dataset = torch.utils.data.Subset(full_dataset,np.arange(start=0,stop=splits[0]))
valid_dataset = torch.utils.data.Subset(full_dataset,np.arange(start=splits[1],stop=splits[2]))
train_loader = DataLoader(train_dataset, collate_fn=collate, **config['data_loader'], shuffle=True)
valid_loader = DataLoader(valid_dataset, collate_fn=collate, **config['data_loader'], shuffle=False)
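
The index arithmetic of the split is easier to follow on small numbers. The sketch below repeats the same steps for a toy dataset of 10 samples and a hypothetical `t_v_split` of 0.25; neither number comes from the real config.

```python
import math
import numpy as np

fulllen = 10        # toy dataset size
t_v_split = 0.25    # hypothetical validation fraction

# Number of validation samples, rounded up.
tv_num = math.ceil(fulllen * t_v_split)

# Cumulative boundaries: end of train, start of valid, end of valid.
splits = np.cumsum([fulllen - tv_num, 0, tv_num])

train_idx = np.arange(start=0, stop=splits[0])
valid_idx = np.arange(start=splits[1], stop=splits[2])
print(splits, train_idx, valid_idx)
```

The validation set is taken from the end of the dataset, so the split is only representative if the samples are not ordered in some meaningful way.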

### Building the model

The trainer will load the network architecture and compile it into a model.

In [5]:
trainer.build_model(**config['model'])

### Training!

Once the setup is done, all that's left is to run the training and save the summary statistics to a file.

In [6]:
train_summary = trainer.train(train_loader, valid_data_loader=valid_loader, **config['trainer'])
print(train_summary)
torch.save(train_summary, 'summary_test.pt')

loss = 0.64253:   3%|▎         | 933/28397 [00:21<09:13, 49.62it/s] 

RuntimeError: Length of all samples has to be greater than 0, but found an element in 'lengths' that is <= 0

loss = 0.64253:   3%|▎         | 933/28397 [00:40<09:13, 49.62it/s]
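
The `RuntimeError` in the output above has the wording PyTorch uses when a sequence being packed has length zero, which suggests that at least one sample has an empty point or hit array. Below is a minimal sketch of one possible workaround. It assumes each array is laid out as (1, features, steps), as the `.squeeze(dim=0).T` in `collate` implies, and `drop_empty` is a hypothetical helper, not part of TauRNN.

```python
import torch
from torch.nn.utils.rnn import pack_sequence

def drop_empty(data, keys=('x_point_var', 'x_hit_var')):
    """Keep only samples whose variable-length arrays have at least one step.

    Assumes the last dimension of each array is the sequence length.
    """
    return [d for d in data if all(d[k].shape[-1] > 0 for k in keys)]

# Toy samples: the second one has a point array with zero steps.
good = {'x_point_var': torch.zeros(1, 4, 3), 'x_hit_var': torch.zeros(1, 5, 2)}
bad = {'x_point_var': torch.zeros(1, 4, 0), 'x_hit_var': torch.zeros(1, 5, 2)}

kept = drop_empty([good, bad])

# Packing now only sees sequences with at least one step.
packed = pack_sequence([d['x_point_var'].squeeze(dim=0).T for d in kept],
                       enforce_sorted=False)
print(len(kept), packed.data.shape)
```

Calling such a filter at the top of `collate` would skip the offending samples; whether dropping them is acceptable, rather than fixing them upstream in the dataset, depends on the analysis.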