# IMPORTS

In [None]:
from conflict_lstm.latest_run import *
from conflict_lstm.hpc_construct import *

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
import torchvision.transforms as transforms
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
import random

import matplotlib.pyplot as plt
import h5py

# Make sure the cuDNN backend is switched on, then enable its benchmark mode,
# which auto-tunes the convolution algorithm choice. Benchmark mode pays off
# when input sizes stay the same between iterations.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


## Functional form
The `wrapper_full` function carries out most of the heavy lifting when we train our models. We pass it the structure argument (as detailed in the LSTMencdec docstrings), the loss function, and the per-channel averages and standard deviations calculated from our dataset for normalisation. Here we use a binary cross entropy loss function with modified class weights. The class weights, average, standard deviation and HDF5 dataset files should be located in the directory from which the script is run. Please note that the example below is extracted from a script run on the HPC; running it inside an .ipynb notebook is not recommended because of the overhead. For previously run scripts, please see the results folder.

In [None]:
# Define the structure of the encoder-decoder built inside wrapper_full.
# Row 0 of `structure` describes the encoder and row 1 the decoder; each entry
# is one layer. A 0 entry denotes the absence of a layer, and a non-zero entry
# is the number of channels that layer convolves its input into.
# For example, the encoder below has two layers with 12 and 24 channels. The
# hidden state of the 24-channel layer is passed to the decoder, whose layers
# have 24, 12, 6 and 5 channels respectively.

structure = np.array([[12,24,0,0,0],[0,24,12,6,5]])

# Build a weighted binary cross entropy loss. The positive-class weights are
# loaded from disk and scaled down by a factor of 3 (presumably what the "3"
# in the run name "bce_3" refers to).
d = np.load("weights_bce.npy")
weights = torch.tensor(d)
# True division: floor division (`//`) would truncate every weight and turn any
# weight below 3 into 0, which removes the positive-class term from the loss.
weights = weights / 3
# NOTE(review): `device` is not defined in this cell; presumably it is provided
# by one of the star imports at the top of the notebook - confirm.
weights = weights.to(device)
b = nn.BCEWithLogitsLoss(pos_weight=weights)

# Load the per-channel average and standard deviation of our image sequences,
# used for standard-score (z-score) normalisation.
avg = np.load("min_event_25_avg.npy")
std = np.load("min_event_25_std.npy")

# Flag which image channels need to be normalised (presumably 1 = normalise,
# 0 = leave as is). Channels that are not normalised are categorical and
# therefore already on a normalised scale.
apbln = [0,1,0,0,1]

wrapper_full("bce_3", 10, structure, b, avg, std, apbln, lr = 0.001, epochs = 2000, batch_size = 200)
