In [1]:
import os, sys
import pickle
import numpy as np  
import torch
import pandas as pd
import matplotlib.pyplot as plt
import time
import copy
import numpy as np
from sklearn.model_selection import train_test_split
sys.path.insert(0,'../../LongTermEMG-master')
os.chdir("/home/laiy/gitrepos/msr_final/LongTermEMG-master/LongTermClassificationMain/TrainingsAndEvaluations/ForTrainingSessions/TSD_DNN")

In [2]:
with open("../../../Processed_datasets/TSD_features_set_training_session.pickle", 'rb') as f:
    dataset_training = pickle.load(file=f)

examples_datasets_train = dataset_training['examples_training']
print('traning examples ', np.shape(examples_datasets_train))
print("one group example ", np.shape(examples_datasets_train[0][0][0]))
labels_datasets_train = dataset_training['labels_training']
print('traning labels ', np.shape(labels_datasets_train))
print("one group label ", np.shape(labels_datasets_train[0][0][0]))

traning examples  (1, 4, 4)
one group example  (1009, 385)
traning labels  (1, 4, 4)
one group label  (1009,)


In [3]:
os.chdir("/home/laiy/gitrepos/msr_final/LongTermEMG_myo")

from Models.TSD_neural_network import TSD_Network
from PrepareAndLoadData.load_dataset_in_dataloader import load_dataloaders_training_sessions

In [4]:
def DANN_BN_Training(gesture_classifier, crossEntropyLoss, optimizer_classifier, train_dataset_source, scheduler,
                     train_dataset_target, validation_dataset_source, patience_increment=10, max_epochs=500,
                     domain_loss_weight=1e-1):
    """
    gesture_classification: model
    crossEntropyLoss
    optimizer_classifier
    scheduler
    
    target: unlabeled; source: labeled
    train_dataset_source: the first session of a participant's training set
    validation_dataset_source:  the first session of a participant's validation set
    train_dataset_target: one seesion (except for the first) of a participant's traning set

    patience_increment: number of epchos to wait after no best loss is found and before existing training
    max_epochs
    
    domain_loss_weight: coefficient of doman loss percantage to account in calculating loss (loss_main_source and loss_doman_target)
    
    
    """
    since = time.time()
    patience = 0 + patience_increment

    # Create a list of dictionaries that will hold the weights of the batch normalisation layers for each dataset
    #  (i.e. each participants)
    list_dictionaries_BN_weights = []
    for index_BN_weights in range(2):
        state_dict = gesture_classifier.state_dict()
        batch_norm_dict = {}
        for key in state_dict:
            if "batchNorm" in key:
                batch_norm_dict.update({key: state_dict[key]})
        list_dictionaries_BN_weights.append(copy.deepcopy(batch_norm_dict))

    best_loss = float("inf")
    best_state = {'epoch': 0, 'state_dict': copy.deepcopy(gesture_classifier.state_dict()),
                  'optimizer': optimizer_classifier.state_dict(), 'scheduler': scheduler.state_dict()}

    print("STARTING TRAINING")
    for epoch in range(1, max_epochs):
        epoch_start = time.time()

        loss_main_sum, n_total = 0, 0
        loss_domain_sum, loss_src_class_sum, loss_src_vat_sum, loss_trg_cent_sum, loss_trg_vat_sum = 0, 0, 0, 0, 0
        running_corrects, running_correct_domain, total_for_accuracy, total_for_domain_accuracy = 0, 0, 0, 0

        'TRAINING'
        gesture_classifier.train()
        for source_batch, target_batch in zip(train_dataset_source, train_dataset_target):

            input_source, labels_source = source_batch
            input_source, labels_source = input_source, labels_source
            input_target, _ = target_batch
            input_target = input_target

            # Feed the inputs to the classifier network
            # Retrieves the BN weights calculated so far for the source dataset
            BN_weights = list_dictionaries_BN_weights[0]
            gesture_classifier.load_state_dict(BN_weights, strict=False)
            pred_gesture_source, pred_domain_source = gesture_classifier(input_source, get_all_tasks_output=True)

            'Classifier losses setup.'
            # Supervised/self-supervised gesture classification
            loss_source_class = crossEntropyLoss(pred_gesture_source, labels_source)

            # Try to be bad at the domain discrimination for the full network

            label_source_domain = torch.zeros(len(pred_domain_source), device='cpu', dtype=torch.long)
            loss_domain_source = crossEntropyLoss(pred_domain_source, label_source_domain)
            # Combine all the loss of the classifier
            loss_main_source = (0.5 * loss_source_class + domain_loss_weight * loss_domain_source)

            ' Update networks '
            # Update classifiers.
            # Zero the gradients
            optimizer_classifier.zero_grad()
            # loss_main_source.backward(retain_graph=True)
            loss_main_source.backward()
            optimizer_classifier.step()
            # Save the BN stats for the source
            state_dict = gesture_classifier.state_dict()
            batch_norm_dict = {}
            for key in state_dict:
                if "batchNorm" in key:
                    batch_norm_dict.update({key: state_dict[key]})
            list_dictionaries_BN_weights[0] = copy.deepcopy(batch_norm_dict)

            _, pred_domain_target = gesture_classifier(input_target, get_all_tasks_output=True)
            label_target_domain = torch.ones(len(pred_domain_target), device='cpu', dtype=torch.long)
            loss_domain_target = 0.5 * (crossEntropyLoss(pred_domain_target, label_target_domain))
            # Combine all the loss of the classifier
            loss_domain_target = 0.5 * domain_loss_weight * loss_domain_target
            # Update classifiers.
            # Zero the gradients
            loss_domain_target.backward()
            optimizer_classifier.step()

            # Save the BN stats for the target
            state_dict = gesture_classifier.state_dict()
            batch_norm_dict = {}
            for key in state_dict:
                if "batchNorm" in key:
                    batch_norm_dict.update({key: state_dict[key]})
            list_dictionaries_BN_weights[1] = copy.deepcopy(batch_norm_dict)

            loss_main = loss_main_source + loss_domain_target
            loss_domain = loss_domain_source + loss_domain_target

            loss_domain_sum += loss_domain.item()
            loss_src_class_sum += loss_source_class.item()
            loss_main_sum += loss_main.item()
            n_total += 1

            _, gestures_predictions_source = torch.max(pred_gesture_source.data, 1)
            running_corrects += torch.sum(gestures_predictions_source == labels_source.data)
            total_for_accuracy += labels_source.size(0)

            _, gestures_predictions_domain_source = torch.max(pred_domain_source.data, 1)
            _, gestures_predictions_domain_target = torch.max(pred_domain_target.data, 1)
            running_correct_domain += torch.sum(gestures_predictions_domain_source == label_source_domain.data)
            running_correct_domain += torch.sum(gestures_predictions_domain_target == label_target_domain.data)
            total_for_domain_accuracy += label_source_domain.size(0)
            total_for_domain_accuracy += label_target_domain.size(0)

        print('Accuracy source %4f,'
              ' main loss classifier %4f,'
              ' source classification loss %4f,'
              ' loss domain distinction %4f,'
              ' accuracy domain distinction %4f'
              %
              (running_corrects.item() / total_for_accuracy,
               loss_main_sum / n_total,
               loss_src_class_sum / n_total,
               loss_domain_sum / n_total,
               running_correct_domain.item() / total_for_domain_accuracy
               ))

        'VALIDATION STEP'
        running_loss_validation = 0.
        running_corrects_validation = 0
        total_validation = 0
        n_total_val = 0

        # BN_weights = copy.deepcopy(list_dictionaries_BN_weights[0])
        # gesture_classifier.load_state_dict(BN_weights, strict=False)
        gesture_classifier.eval()
        for validation_batch in validation_dataset_source:
            # get the inputs
            inputs, labels = validation_batch

            inputs, labels = inputs, labels
            # zero the parameter gradients
            optimizer_classifier.zero_grad()

            with torch.no_grad():
                # forward
                outputs = gesture_classifier(inputs)
                _, predictions = torch.max(outputs.data, 1)

                loss = crossEntropyLoss(outputs, labels)
                loss = loss.item()

                # statistics
                running_loss_validation += loss
                running_corrects_validation += torch.sum(predictions == labels.data)
                total_validation += labels.size(0)
                n_total_val += 1

        epoch_loss = running_loss_validation / n_total_val
        epoch_acc = running_corrects_validation.item() / total_validation
        print('{} Loss: {:.8f} Acc: {:.8}'.format("VALIDATION", epoch_loss, epoch_acc))

        scheduler.step(running_loss_validation / n_total_val)
        if running_loss_validation / n_total_val < best_loss:
            print("New best validation loss: ", running_loss_validation / n_total_val)
            best_loss = running_loss_validation / n_total_val
            BN_weights = copy.deepcopy(list_dictionaries_BN_weights[1])
            gesture_classifier.load_state_dict(BN_weights, strict=False)
            best_state = {'epoch': epoch, 'state_dict': copy.deepcopy(gesture_classifier.state_dict()),
                          'optimizer': optimizer_classifier.state_dict(), 'scheduler': scheduler.state_dict()}
            patience = epoch + patience_increment

        if patience < epoch:
            break

        print("Epoch {} of {} took {:.3f}s".format(
            epoch, max_epochs, time.time() - epoch_start))

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    return best_state


In [5]:
def train_DA_spectrograms(examples_datasets_train, labels_datasets_train, num_kernels, filter_size=(4, 10),
                          algo_name="DANN",
                          path_weights_to_save_to="../Weights/weights_", batch_size=512, patience_increment=10,
                          path_weights_fine_tuning="../weights_TWO_CYCLES_normal_training_fine_tuning",
                          number_of_cycle_for_first_training=3, number_of_cycles_rest_of_training=3,
                          number_of_classes=11, 
                          feature_vector_input_length=None, learning_rate=0.001316):
    """
    examples_datasets_train
    labels_datasets_train
    num_kernels
    filter_size
    algo_name
    path_weights_to_save_to: path to save DANN weights
    batch_size
    patience_increment
    path_weights_fine_tuning: path to load normal TSD_DNN weights 
    number_of_cycle_for_first_training
    number_of_cycles_rest_of_training
    number_of_classes
    spectrogram_model
    """
    participants_train, participants_validation, participants_test = load_dataloaders_training_sessions(
        examples_datasets_train, labels_datasets_train, batch_size=batch_size,
        number_of_cycle_for_first_training=number_of_cycle_for_first_training, get_validation_set=True,
        number_of_cycles_rest_of_training=number_of_cycles_rest_of_training)

    for participant_i in range(len(participants_train)):
        print("SHAPE SESSIONS: ", np.shape(participants_train[participant_i]))

        # Skip the first session as it will be identical to normal training
        for session_j in range(1, len(participants_train[participant_i])):
            print(np.shape(participants_train[participant_i][session_j]))

            gesture_classification = TSD_Network(number_of_class=number_of_classes, num_neurons=num_kernels,
                                                 feature_vector_input_length=feature_vector_input_length)

            # loss functions
            crossEntropyLoss = nn.CrossEntropyLoss()
            # optimizer
            precision = 1e-8
            optimizer_classifier = optim.Adam(gesture_classification.parameters(), lr=learning_rate, betas=(0.5, 0.999))
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_classifier, mode='min', factor=.2,
                                                             patience=5, verbose=True, eps=precision)
            # Fine-tune from the previous training
            gesture_classification, optimizer_classifier, scheduler, start_epoch = load_checkpoint( 
                model=gesture_classification, optimizer=optimizer_classifier, scheduler=scheduler,  
                filename=path_weights_fine_tuning + "/participant_%d/best_state_%d.pt" % (participant_i, 0))  

            best_weights = DANN_BN_Training(gesture_classifier=gesture_classification, scheduler=scheduler, 
                                            optimizer_classifier=optimizer_classifier,  
                                            train_dataset_source=participants_train[participant_i][0],  
                                            train_dataset_target=participants_train[participant_i][session_j],  
                                            validation_dataset_source=participants_validation[participant_i][0],  
                                            crossEntropyLoss=crossEntropyLoss,  
                                            patience_increment=patience_increment,  
                                            domain_loss_weight=1e-1)

            if not os.path.exists(path_weights_to_save_to + algo_name + "/participant_%d" % participant_i):
                os.makedirs(path_weights_to_save_to + algo_name + "/participant_%d" % participant_i)
            torch.save(best_weights, f=path_weights_to_save_to + algo_name + "/participant_%d/best_state_%d.pt" % (participant_i, session_j))

In [6]:
os.chdir("/home/laiy/gitrepos/msr_final/LongTermEMG-master")
from LongTermClassificationMain.Models.TSD_neural_network import TSD_Network
from LongTermClassificationMain.TrainingsAndEvaluations.training_loops_preparations import load_checkpoint
from LongTermClassificationMain.Models.model_training import train_model_standard

import torch
import torch.nn as nn
import torch.optim as optim

os.chdir("/home/laiy/gitrepos/msr_final/LongTermEMG-master/LongTermClassificationMain/TrainingsAndEvaluations/ForTrainingSessions/TSD_DNN")


In [7]:
num_neurons = [200, 200, 200]
feature_vector_input_length = 385
gestures_to_remove = [5, 6, 9, 10]
gestures_to_remove = None
number_of_class = 11
number_of_cycle_for_first_training = 4
number_of_cycles_rest_of_training = 4
learning_rate = 0.002515

path_weights_fine_tuning = "Weights_TSD/weights_THREE_CYCLES_TSD_ELEVEN_Gestures"
algo_name = "DANN_THREE_CYCLES_11Gestures_TSD"

In [8]:
train_DA_spectrograms(examples_datasets_train, labels_datasets_train, filter_size=None,
                      num_kernels=num_neurons, algo_name=algo_name,
                      path_weights_fine_tuning=path_weights_fine_tuning,
                      number_of_classes=number_of_class,
                      number_of_cycle_for_first_training=number_of_cycle_for_first_training,
                      number_of_cycles_rest_of_training=number_of_cycles_rest_of_training,
                      batch_size=128,
                      feature_vector_input_length=feature_vector_input_length,
                      path_weights_to_save_to="Weights_TSD/weights_", learning_rate=learning_rate)

GET one participant_examples  (4, 4)
   GET one training_index_examples  (4,)  at  0
   GOT one group XY  (4024, 385)    (4024,)
       one group XY test  (0,)    (0,)
       one group XY train (3621, 385)    (3621,)
       one group XY valid (403, 385)    (403, 385)
   GET one training_index_examples  (4,)  at  1
   GOT one group XY  (4190, 385)    (4190,)
       one group XY test  (0,)    (0,)
       one group XY train (3771, 385)    (3771,)
       one group XY valid (419, 385)    (419, 385)
   GET one training_index_examples  (4,)  at  2
   GOT one group XY  (4289, 385)    (4289,)
       one group XY test  (0,)    (0,)
       one group XY train (3860, 385)    (3860,)
       one group XY valid (429, 385)    (429, 385)
   GET one training_index_examples  (4,)  at  3
   GOT one group XY  (4099, 385)    (4099,)
       one group XY test  (0,)    (0,)
       one group XY train (3689, 385)    (3689,)
       one group XY valid (410, 385)    (410, 385)
dataloaders: 
   train  (1, 4)
   valid

  Variable._execution_engine.run_backward(


Accuracy source 0.891462, main loss classifier 0.286335, source classification loss 0.400808, loss domain distinction 0.641986, accuracy domain distinction 0.501953
VALIDATION Loss: 0.49979112 Acc: 0.84863524
New best validation loss:  0.49979111552238464
Epoch 1 of 500 took 0.594s
Accuracy source 0.903739, main loss classifier 0.250760, source classification loss 0.337574, loss domain distinction 0.578003, accuracy domain distinction 0.511858
VALIDATION Loss: 0.51330012 Acc: 0.83870968
Epoch 2 of 500 took 0.633s
Accuracy source 0.906808, main loss classifier 0.230268, source classification loss 0.302816, loss domain distinction 0.526905, accuracy domain distinction 0.510324
VALIDATION Loss: 0.48187330 Acc: 0.8560794
New best validation loss:  0.4818733036518097
Epoch 3 of 500 took 0.532s
Accuracy source 0.919643, main loss classifier 0.208102, source classification loss 0.262303, loss domain distinction 0.486857, accuracy domain distinction 0.501814
VALIDATION Loss: 0.55613410 Acc: 0.

Accuracy source 0.885324, main loss classifier 0.301480, source classification loss 0.430092, loss domain distinction 0.643969, accuracy domain distinction 0.497210
VALIDATION Loss: 0.63950646 Acc: 0.81141439
New best validation loss:  0.639506459236145
Epoch 1 of 500 took 0.532s
Accuracy source 0.899275, main loss classifier 0.250791, source classification loss 0.336432, loss domain distinction 0.586922, accuracy domain distinction 0.499442
VALIDATION Loss: 0.55062532 Acc: 0.83126551
New best validation loss:  0.5506253242492676
Epoch 2 of 500 took 0.526s
Accuracy source 0.907924, main loss classifier 0.233763, source classification loss 0.308926, loss domain distinction 0.532995, accuracy domain distinction 0.492327
VALIDATION Loss: 0.55077571 Acc: 0.82878412
Epoch 3 of 500 took 0.520s
Accuracy source 0.912109, main loss classifier 0.217942, source classification loss 0.281916, loss domain distinction 0.484622, accuracy domain distinction 0.496094
VALIDATION Loss: 0.46626574 Acc: 0.8

Accuracy source 0.910993, main loss classifier 0.205884, source classification loss 0.267212, loss domain distinction 0.390884, accuracy domain distinction 0.504046
VALIDATION Loss: 0.42521551 Acc: 0.85111663
Epoch 7 of 500 took 0.519s
Accuracy source 0.914900, main loss classifier 0.198051, source classification loss 0.251089, loss domain distinction 0.387187, accuracy domain distinction 0.501535
VALIDATION Loss: 0.41860244 Acc: 0.86848635
Epoch 8 of 500 took 0.526s
Accuracy source 0.921038, main loss classifier 0.195871, source classification loss 0.246116, loss domain distinction 0.384408, accuracy domain distinction 0.494420
VALIDATION Loss: 0.54058003 Acc: 0.82630273
Epoch 9 of 500 took 0.518s
Accuracy source 0.916853, main loss classifier 0.199931, source classification loss 0.256751, loss domain distinction 0.372859, accuracy domain distinction 0.502790
VALIDATION Loss: 0.56878448 Acc: 0.82133995
Epoch 10 of 500 took 0.521s
Accuracy source 0.915458, main loss classifier 0.201753