# Mutichannel-multitask model
Combine all system monitoring conditions together into a multitask learning approach which to apply to both the CNN and the LSTM based network

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
from natsort import natsorted
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

import sys
sys.path.insert(0, "../")
from data_preparation.prepare_data import import_data
from model import make_loaders, CNN, LSTMattn, MultiClassifier, eval_batch

### Training routine

In [7]:
def train_multitask(model, train_loader, test_loader, dir_checkpoint, dir_writer, 
                    epochs=10, batch_size=4, lr=1e-5, save_cp=True, gpu=False):
    
    writer = SummaryWriter(dir_writer)
    
    print(f'''Start training: 
                Epocs = {epochs}
                Batch size = {batch_size}
                Learning rate = {lr} 
                Training size = {train_loader.dataset.__len__()}
                Validation size = {test_loader.dataset.__len__()}
                CUDA = {gpu}
          ''')
   
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    for epoch in range(epochs):

        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        model.train()
        epoch_loss = 0

        for i, sample_batch in enumerate(train_loader):

            inputs, labels = sample_batch['sequence'], sample_batch['label'] 
            if gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()

            outputs = model(inputs)
            # multi-task loss
            loss = 0
            for lx in range(len(outputs)):
                loss += criterion(outputs[lx], labels[:, lx])
            epoch_loss += loss.item()

            if i%10 == 0:
                print(f'epoch = {epoch+1:d}, iteration = {i:d}/{len(train_loader):d}, loss = {loss.item():.5f}')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        print(f'Epoch finished ! Loss: {epoch_loss/i}')
        # training set accuracy
        accuracy = eval_batch(model, train_loader, n_labels=len(targets))
        print(f'Accuracy = {accuracy}')
        writer.add_scalars('accuracy', {f'label_{i}':a for i,a in enumerate(accuracy)}, len(train_loader) * (epoch+1))
        
        if save_cp:
            torch.save(model.state_dict(), dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))

        
    writer.close()

### Main

In [8]:
data_dir = '<path-to-dataset>'
writer_dir = './runs/experiment_1'
checkpoint_dir = './checkpoints/experiment_1/'
os.makedirs(checkpoint_dir)

channels = ['CP', 'FS1', 'PS1', 'PS2', 'PS3', 'PS4', 'PS5', 'SE', 'VS1']
targets = [0,1,2,3]
gpu = False
sequence = 50

print('Preparing data')
data = import_data(data_dir, sequence)
train_loader, test_loader = make_loaders(data, channels, targets)

# CNN model
# model = CNN(sequence, input_dim=len(channels)) 
# LSTM model
model = LSTMattn(len(channels), hidden_dim=20, num_layers=1)
# add multi-task classifier
model.classifier = MultiClassifier(model.classifier.in_features)

if gpu:
    model.cuda()

train_multitask(model, train_loader, test_loader, checkpoint_dir, writer_dir, lr=1e-3 )
accuracy = eval_batch(model, test_loader, n_labels=4)
print(accuracy)

Preparing data


  "num_layers={}".format(dropout, num_layers))


Start training: 
                Epocs = 10
                Batch size = 4
                Learning rate = 0.001 
                Training size = 1984
                Validation size = 221
                CUDA = False
          
Starting epoch 1/10.
epoch = 1, iteration = 0/496, loss = 5.10720
epoch = 1, iteration = 10/496, loss = 4.73586
epoch = 1, iteration = 20/496, loss = 4.83484
epoch = 1, iteration = 30/496, loss = 5.03915
epoch = 1, iteration = 40/496, loss = 4.47967
epoch = 1, iteration = 50/496, loss = 5.01119
epoch = 1, iteration = 60/496, loss = 4.71459
epoch = 1, iteration = 70/496, loss = 4.86597
epoch = 1, iteration = 80/496, loss = 4.30579
epoch = 1, iteration = 90/496, loss = 4.11236
epoch = 1, iteration = 100/496, loss = 4.56768
epoch = 1, iteration = 110/496, loss = 4.49729
epoch = 1, iteration = 120/496, loss = 4.35327
epoch = 1, iteration = 130/496, loss = 4.98818
epoch = 1, iteration = 140/496, loss = 4.18252
epoch = 1, iteration = 150/496, loss = 3.89825
epoch = 1

epoch = 4, iteration = 140/496, loss = 3.75629
epoch = 4, iteration = 150/496, loss = 3.27393
epoch = 4, iteration = 160/496, loss = 2.64722
epoch = 4, iteration = 170/496, loss = 3.68406
epoch = 4, iteration = 180/496, loss = 1.46410
epoch = 4, iteration = 190/496, loss = 3.51601
epoch = 4, iteration = 200/496, loss = 4.02048
epoch = 4, iteration = 210/496, loss = 2.69166
epoch = 4, iteration = 220/496, loss = 2.82160
epoch = 4, iteration = 230/496, loss = 2.76942
epoch = 4, iteration = 240/496, loss = 2.24821
epoch = 4, iteration = 250/496, loss = 3.75624
epoch = 4, iteration = 260/496, loss = 3.81716
epoch = 4, iteration = 270/496, loss = 4.00216
epoch = 4, iteration = 280/496, loss = 3.25451
epoch = 4, iteration = 290/496, loss = 3.24127
epoch = 4, iteration = 300/496, loss = 3.65961
epoch = 4, iteration = 310/496, loss = 3.27634
epoch = 4, iteration = 320/496, loss = 2.60549
epoch = 4, iteration = 330/496, loss = 2.16721
epoch = 4, iteration = 340/496, loss = 1.85576
epoch = 4, it

epoch = 7, iteration = 320/496, loss = 2.35833
epoch = 7, iteration = 330/496, loss = 1.77145
epoch = 7, iteration = 340/496, loss = 1.13963
epoch = 7, iteration = 350/496, loss = 1.91002
epoch = 7, iteration = 360/496, loss = 1.29701
epoch = 7, iteration = 370/496, loss = 4.67425
epoch = 7, iteration = 380/496, loss = 1.66910
epoch = 7, iteration = 390/496, loss = 3.21856
epoch = 7, iteration = 400/496, loss = 3.27639
epoch = 7, iteration = 410/496, loss = 1.86339
epoch = 7, iteration = 420/496, loss = 1.48884
epoch = 7, iteration = 430/496, loss = 1.69565
epoch = 7, iteration = 440/496, loss = 3.47465
epoch = 7, iteration = 450/496, loss = 1.00986
epoch = 7, iteration = 460/496, loss = 2.22415
epoch = 7, iteration = 470/496, loss = 3.57184
epoch = 7, iteration = 480/496, loss = 2.37786
epoch = 7, iteration = 490/496, loss = 1.92961
Epoch finished ! Loss: 2.763951911257975
Accuracy = [1.0, 0.55, 0.64, 0.52]
Checkpoint 7 saved !
Starting epoch 8/10.
epoch = 8, iteration = 0/496, loss =

Accuracy = [1.0, 0.57, 0.7, 0.55]
Checkpoint 10 saved !
[1.0, 0.57, 0.73, 0.55]
