In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import KFold
import os
import copy

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

# CUDA_LAUNCH_BLOCKING is read from the process environment by the CUDA runtime; assigning a
# plain Python variable (as this cell used to) has no effect. Set it before any CUDA work so
# kernel launches run synchronously and errors are reported at the failing operation.
# NOTE: synchronous launches slow training down - remove this line for full-speed runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# List every input file mounted under /kaggle/input.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
%pip install codecarbon comet_ml

In [None]:
from comet_ml import Experiment
from codecarbon import EmissionsTracker
from datetime import datetime

# Initialise and start CodeCarbon tracker (stopped, and emissions printed, at the end of training)
tracker = EmissionsTracker()
tracker.start()

start_time = datetime.now()
print(f'Start time is {start_time}')

# Initialise the Comet experiment. The API key is taken from the COMET_API_KEY environment
# variable (e.g. a Kaggle secret) instead of being hard-coded in the notebook; when it is
# unset, None is passed and Comet falls back to its own configuration lookup.
experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="",
    workspace="",
)

In [None]:
## Custom data (unsupervised, 2D), chord keys starting from 1
### Definition
#### Chord dictionaries
# Chord keys come in blocks of seven per scale degree d = 0..6 (numerals I..VII):
#   7*d + 1 .. 7*d + 4  -> seventh chords, suffixes "7", "56", "34", "2"
#   7*d + 5 .. 7*d + 7  -> triads,         suffixes "35", "6", "46"
# e.g. I7=1, I56=2, I34=3, I2=4, I35=5, I6=6, I46=7, II7=8, ..., VII46=49.
_ROMAN_NUMERALS = ("I", "II", "III", "IV", "V", "VI", "VII")
_SEVENTH_SUFFIXES = ("7", "56", "34", "2")
_TRIAD_SUFFIXES = ("35", "6", "46")

triads_dictionary = {
    7 * degree + offset: numeral + suffix
    for degree, numeral in enumerate(_ROMAN_NUMERALS)
    for offset, suffix in enumerate(_TRIAD_SUFFIXES, start=5)
}
seventhchords_dictionary = {
    7 * degree + offset: numeral + suffix
    for degree, numeral in enumerate(_ROMAN_NUMERALS)
    for offset, suffix in enumerate(_SEVENTH_SUFFIXES, start=1)
}

# key -> chord name for all 49 chords; seventh chords come first in iteration order.
chords_dictionary = {**seventhchords_dictionary, **triads_dictionary}

#### Sequences
# Every harmonic sequence is: an opening seventh chord, one of three interchangeable second
# chords, then a three-chord ending. Each template lists the opening chord, its three
# second-chord options and its endings; every ending is combined with every option
# (options varying fastest), giving 216 five-chord sequences in a fixed order.
_SEQUENCE_TEMPLATES = [
    ("I7", ("II35", "VII6", "V46"), [
        ("I56", "IV6", "V6"),
        ("I56", "IV6", "II35"),
        ("I56", "IV2", "II35"),
        ("VI34", "II35", "V6"),
        ("VI34", "IV35", "II6"),
        ("VI34", "VII6", "VI6"),
        ("VI34", "IV2", "V34"),
        ("IV2", "VII6", "VI6"),
        ("IV2", "V34", "III2"),
        ("IV2", "VII56", "V34"),
    ]),
    ("II7", ("III35", "I6", "VI46"), [
        ("II56", "V35", "I6"),
        ("II56", "V6", "VI35"),
        ("II56", "V2", "I6"),
        ("II56", "V35", "I35"),
        ("VII34", "III35", "VI6"),
        ("VII34", "V35", "III35"),
        ("VII34", "I56", "II6"),
        ("VII34", "V2", "VI34"),
        ("V2", "I6", "II35"),
        ("V2", "VI34", "IV2"),
        ("V2", "I56", "VI34"),
    ]),
    ("III7", ("IV35", "II6", "VII46"), [
        ("III56", "VI35", "II35"),
        ("III56", "VI35", "IV6"),
        ("III56", "VI2", "IV35"),
        ("I34", "IV35", "V6"),
        ("I34", "IV35", "I35"),
        ("I34", "VI35", "IV6"),
        ("I34", "II6", "I6"),
        ("I34", "VI2", "II6"),
        ("VI2", "II6", "I6"),
        ("VI2", "VII34", "V2"),
        ("VI2", "II56", "V2"),
    ]),
    ("IV7", ("V35", "III6", "I46"), [
        ("IV56", "II34", "V6"),
        ("IV56", "II34", "III56"),
        ("IV56", "VII2", "III6"),
        ("II34", "V35", "I6"),
        ("II34", "VII2", "III6"),
        ("II34", "III56", "II6"),
        ("II34", "VII2", "I34"),
        ("VII2", "I34", "II56"),
        ("VII2", "I34", "VI2"),
        ("VII2", "III56", "I34"),
    ]),
    ("V7", ("VI35", "IV6", "II46"), [
        ("V56", "I35", "VI6"),
        ("V56", "I35", "VI35"),
        ("V56", "I2", "VI35"),
        ("III34", "VI35", "II6"),
        ("III34", "I35", "VI6"),
        ("III34", "IV6", "III6"),
        ("III34", "I2", "II34"),
        ("I2", "IV56", "III6"),
        ("I2", "II34", "V35"),
        ("I2", "IV56", "II34"),
    ]),
    ("VI7", ("VII35", "V6", "III46"), [
        ("VI56", "II35", "V35"),
        ("VI56", "IV34", "V6"),
        ("VI56", "II2", "V6"),
        ("IV34", "V56", "I35"),
        ("IV34", "II35", "V6"),
        ("IV34", "V6", "IV6"),
        ("IV34", "II2", "III34"),
        ("II2", "V6", "IV6"),
        ("II2", "III34", "I2"),
        ("II2", "V56", "III34"),
    ]),
    ("VII7", ("I35", "VI6", "IV46"), [
        ("VII56", "III35", "I6"),
        ("VII56", "I6", "II35"),
        ("VII56", "V34", "I35"),
        ("V34", "I35", "IV6"),
        ("V34", "III35", "I6"),
        ("V34", "VI6", "V6"),
        ("V34", "III2", "IV34"),
        ("III2", "VI6", "V6"),
        ("III2", "IV34", "II2"),
        ("III2", "VI56", "IV34"),
    ]),
]

sequences = [
    [opening, second, *ending]
    for opening, second_options, endings in _SEQUENCE_TEMPLATES
    for ending in endings
    for second in second_options
]
# Translate chord names into their integer keys. The reverse lookup is built once (O(1) per
# chord instead of scanning every dictionary entry for every chord). An unknown chord name now
# raises immediately: the old scan silently dropped it, producing a sequence shorter than
# 5 chords that only failed much later, at the reshape step.
chord_name_to_key = {name: key for key, name in chords_dictionary.items()}
digi_sequences = []
for sequence in sequences:
    unknown_chords = [chord for chord in sequence if chord not in chord_name_to_key]
    if unknown_chords:
        raise ValueError(f"Unknown chord name(s) {unknown_chords} in sequence {sequence}")
    digi_sequences.append([chord_name_to_key[chord] for chord in sequence])

#### Matrix definition
# Chord-placement grid: each non-zero entry is a chord key from chords_dictionary, and its
# (row, column) is where that chord is drawn in the 2D frames built further below.
_graph_rows = [
    [0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [35, 48, 12, 49, 13, 26, 14, 27, 40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 2, 0, 0, 3, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 15, 0, 0, 16, 0, 0, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 49, 13, 26, 14, 27, 40, 28, 41, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 16, 0, 0, 17, 0, 0, 18, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 29, 0, 0, 30, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 14, 27, 40, 28, 41, 5, 42, 6, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 30, 0, 0, 31, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 43, 0, 0, 44, 0, 0, 45, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 28, 41, 5, 42, 6, 19, 7, 20, 33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 44, 0, 0, 45, 0, 0, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 9, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 42, 6, 19, 7, 20, 33, 21, 34, 47, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 10, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 22, 0, 0, 23, 0, 0, 24, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 20, 33, 21, 34, 47, 35, 48, 12, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 0, 0, 24, 0, 0, 25, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 36, 0, 0, 37, 0, 0, 38, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 34, 47, 35, 48, 12, 49, 13, 26, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 37, 0, 0, 38, 0, 0, 39, 0, 0],
]
# The last 7 rows hold no chords: pad with all-zero rows up to a square 28x28 grid.
_graph_rows += [[0] * 28 for _ in range(28 - len(_graph_rows))]
graph = torch.tensor(_graph_rows, dtype=torch.int32)

In [None]:
### Conversion
#### Retrieve indexes
# Locate every chord of a sequence in the graph, in the order the chords appear in the sequence.
def get_sequence_indexes(seq, graph):
    """Find where each chord of ``seq`` sits in the chord-placement matrix ``graph``.

    Args:
        seq: iterable of chord keys, in sequence order.
        graph: 2D tensor whose non-zero entries are chord keys.

    Returns:
        A flat list of dicts ``{"position": column, "chord": chord, "row": row}`` with one
        entry per occurrence of each chord. Entries are ordered by chord (in ``seq`` order),
        then by row, then by column. A chord that never occurs in ``graph`` contributes no
        entries.
    """
    chords_list = []
    for chord in seq:
        # nonzero() on the boolean mask yields (row, column) pairs in row-major order - the
        # same order as scanning rows top to bottom. Unlike the old per-row ``.item()`` call,
        # this also works when a chord occurs more than once within one row.
        for row, position in (graph == chord).nonzero().tolist():
            chords_list.append({
                "position": position,  # column where the value should be inserted
                "chord": chord,        # value to be inserted
                "row": row,            # row where the value should be inserted
            })
    return chords_list
# Inspect the matrix positions found for the first digitised harmonic sequence.
first_digi_sequence = digi_sequences[0]
chords_list = get_sequence_indexes(first_digi_sequence, graph)
chords_list

In [None]:
#### Group dicts
# Group dicts
# For every digitised sequence, build one group of matrix positions per chord
# (each group later becomes one 28x28 frame).
grouped_chord_sequences = []
for sequence in digi_sequences:             # for each harmonic sequence
    grouped_chord_sequence = []
    for chord in sequence:
        # Look each chord up on its own. The old approach grouped the flat position list by
        # "same chord value as the previous entry", which would merge two identical adjacent
        # chords into a single frame and silently shorten the sequence.
        positions = get_sequence_indexes([chord], graph)
        if not positions:
            # Previously such a chord was silently dropped, leaving fewer frames than chords.
            raise ValueError(f"Chord key {chord} does not occur anywhere in graph")
        grouped_chord_sequence.append(positions)
    grouped_chord_sequences.append(grouped_chord_sequence)

len(grouped_chord_sequences)

In [None]:
#### Get matrices
def get_tensor_with_positions(list_entry, frame_size=(28, 28)):
    """Render one chord's positions into a 2D frame.

    Args:
        list_entry: iterable of dicts with "row", "position" and "chord" keys
            (as produced by get_sequence_indexes).
        frame_size: (height, width) of the frame; defaults to the 28x28 chord grid.

    Returns:
        A float tensor of shape ``frame_size``, zero everywhere except at each
        (row, position), which holds that entry's chord value.
    """
    frame = torch.zeros(*frame_size)
    for dict_entry in list_entry:
        frame[dict_entry.get('row'), dict_entry.get('position')] = dict_entry.get('chord')
    return frame

def get_sequence_data(grouped_dictionary, frame_size=(28, 28)):
    """Turn a sequence's per-chord position groups into a list of frames (one per chord).

    Args:
        grouped_dictionary: list of position groups, one group per chord.
        frame_size: (height, width) forwarded to get_tensor_with_positions.

    Returns:
        List of float tensors of shape ``frame_size``, in chord order.
    """
    # get_tensor_with_positions allocates a fresh tensor on every call, so the frames never
    # share storage and no defensive deepcopy is needed.
    return [get_tensor_with_positions(list_entry, frame_size) for list_entry in grouped_dictionary]

# One list of per-chord frames for each grouped harmonic sequence.
sequences_data = [
    get_sequence_data(grouped_dictionary)
    for grouped_dictionary in grouped_chord_sequences
]
len(sequences_data)

In [None]:
#### Normalize and augment
# Normalise chord keys (1..49) into [0, 1] and augment the data set: NUM_AUGMENTED_COPIES
# copies of every sequence, copy i divided by 49 + i/100 (49.00, 49.01, ..., 49.09).
# The previous version computed this per-copy divisor but then always divided by 49, so the
# "augmented" copies were exact duplicates of each other.
NUM_AUGMENTED_COPIES = 10
normalized_sequences_data = []
for i in range(NUM_AUGMENTED_COPIES):
    divisor = 49.0 + i / 100
    for sequence in sequences_data:
        # Division produces new tensors, so the source frames stay untouched.
        normalized_sequences_data.append([frame / divisor for frame in sequence])

print(len(normalized_sequences_data))

In [None]:
#### Reshape
# Stack the frames of every sequence into one (frames, height, width) array and collect all
# sequences into a single numpy array of shape (num_sequences, SEQUENCE_LENGTH, height, width).
SEQUENCE_LENGTH = 5
reshaped_sequences = []
for i, sequence in enumerate(normalized_sequences_data):
    # collate() splits each sample into 4 input frames + 1 target frame; fail loudly here
    # instead of silently dropping extra frames or hitting an IndexError on missing ones.
    if len(sequence) != SEQUENCE_LENGTH:
        raise ValueError(
            f"Sequence {i} has {len(sequence)} frames, expected {SEQUENCE_LENGTH}")
    reshaped_sequences.append(np.stack([np.asarray(frame) for frame in sequence]))
reshaped_sequences_np = np.array(reshaped_sequences)

print(reshaped_sequences_np.shape)

In [None]:
# Shuffle custom data in place along the first axis (the sequence axis). A seeded Generator
# makes the shuffled order reproducible between runs; the unseeded global RNG did not.
SHUFFLE_SEED = 42
np.random.default_rng(SHUFFLE_SEED).shuffle(reshaped_sequences_np)

In [None]:
### Utils
def collate(batch):
    """Batch samples and split them into model input and prediction target.

    Args:
        batch: list of per-sample arrays laid out (frames, height, width) with at least
            5 frames - the reshape cell builds (5, 28, 28).

    Returns:
        ``(x, y)`` on ``device``: x holds frames 0-3 with shape (B, 1, 4, H, W);
        y is frame 4 with shape (B, 1, H, W).
    """
    # Stack into a single ndarray first: torch.tensor() on a list of ndarrays converts
    # element by element and is very slow (PyTorch emits a warning about it).
    batch = torch.as_tensor(np.stack(batch)).unsqueeze(1)  # add channel dim -> (B, 1, T, H, W)
    batch = batch.to(device)
    return batch[:, :, 0:4], batch[:, :, 4]

def reset_weights(m):
    """Move ``m`` to ``device`` and re-initialise each direct child that supports it.

    Intended for ``model.apply(reset_weights)``: apply() visits every submodule, so each
    layer that defines ``reset_parameters`` is reached through its parent and reset once.
    Modules without ``reset_parameters`` (e.g. ConvLSTMCell, whose peephole weights are
    plain nn.Parameters) are not re-initialised. Every reset layer is printed.
    """
    m.to(device)
    resettable_children = [child for child in m.children() if hasattr(child, 'reset_parameters')]
    for child in resettable_children:
        print(f'Reset trainable parameters of layer = {child}')
        child.reset_parameters()
    

In [None]:
### ConvLSTM cell and layer
# Original ConvLSTM cell as proposed by Shi et al.
class ConvLSTMCell(nn.Module):
    """One ConvLSTM time step with peephole connections (after Shi et al.).

    Args:
        in_channels: channels of the input frame X.
        out_channels: channels of the hidden state H and cell state C.
        kernel_size, padding: forwarded to the shared gate convolution.
        activation: "tanh" or "relu"; applied to the candidate cell input and to C.
        frame_size: (height, width) of the peephole weight maps.

    Raises:
        ValueError: if ``activation`` is not "tanh" or "relu" (previously this left
            ``self.activation`` unset and only failed later, inside forward()).
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, padding, activation, frame_size):
        super(ConvLSTMCell, self).__init__()
        if activation == "tanh":
            self.activation = torch.tanh
        elif activation == "relu":
            self.activation = torch.relu
        else:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'tanh' or 'relu'")

        # One convolution computes all four gate pre-activations from [X, H_prev].
        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        self.conv = nn.Conv2d(
            in_channels=in_channels + out_channels,
            out_channels=4 * out_channels,
            kernel_size=kernel_size,
            padding=padding)

        # Peephole weights for the Hadamard products with the cell state, drawn from a
        # standard normal distribution (torch.randn), shape (out_channels, *frame_size).
        self.W_ci = nn.Parameter(torch.randn(out_channels, *frame_size))
        self.W_co = nn.Parameter(torch.randn(out_channels, *frame_size))
        self.W_cf = nn.Parameter(torch.randn(out_channels, *frame_size))

    def forward(self, X, H_prev, C_prev):
        """Advance one step.

        X is concatenated with H_prev along dim 1 (channels); H_prev and C_prev have
        ``out_channels`` channels. Returns the new (H, C).
        """
        # Idea adapted from https://github.com/ndrplz/ConvLSTM_pytorch
        conv_output = self.conv(torch.cat([X, H_prev], dim=1))
        i_conv, f_conv, C_conv, o_conv = torch.chunk(conv_output, chunks=4, dim=1)

        input_gate = torch.sigmoid(i_conv + self.W_ci * C_prev)
        forget_gate = torch.sigmoid(f_conv + self.W_cf * C_prev)

        # Current cell state
        C = forget_gate * C_prev + input_gate * self.activation(C_conv)
        # The output gate peeks at the *new* cell state.
        output_gate = torch.sigmoid(o_conv + self.W_co * C)

        # Current hidden state
        H = output_gate * self.activation(C)
        return H, C

### ConvLSTM layer
class ConvLSTM(nn.Module):
    """Unrolls a ConvLSTMCell over the time axis of a frame sequence.

    Input X: (batch_size, num_channels, seq_len, height, width).
    Output: hidden state at every step, (batch_size, out_channels, seq_len, height, width).
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, padding, activation, frame_size):
        super(ConvLSTM, self).__init__()
        self.out_channels = out_channels

        # We will unroll this over time steps
        self.convLSTMcell = ConvLSTMCell(in_channels, out_channels,
                                         kernel_size, padding, activation, frame_size)

    def forward(self, X):
        batch_size, _, seq_len, height, width = X.size()

        # Output and zero-initialised hidden/cell states are allocated on X's device and
        # dtype (previously on the global `device`), so the layer works wherever the
        # model and its input actually live, e.g. CPU inference while CUDA is available.
        state_shape = (batch_size, self.out_channels, height, width)
        output = X.new_zeros(batch_size, self.out_channels, seq_len, height, width)
        H = X.new_zeros(state_shape)
        C = X.new_zeros(state_shape)

        # Unroll over time steps
        for time_step in range(seq_len):
            H, C = self.convLSTMcell(X[:, :, time_step], H, C)
            output[:, :, time_step] = H
        return output

In [None]:
### ConvLSTM model 1 layer

class Seq2Seq(nn.Module):
    """One ConvLSTM layer plus batch-norm, followed by a conv head predicting the next frame.

    Input X: (batch, num_channels, seq_len, height, width). Only the last time step of the
    recurrent features is fed to the head; the result is squashed into (0, 1) by a sigmoid.
    """

    def __init__(self, num_channels, num_kernels, kernel_size, padding,
                 activation, frame_size):
        super().__init__()
        # Recurrent body: ConvLSTM over the time axis, then 3D batch-norm of its feature maps.
        self.sequential = nn.Sequential()
        recurrent_layer = ConvLSTM(
            in_channels=num_channels, out_channels=num_kernels,
            kernel_size=kernel_size, padding=padding,
            activation=activation, frame_size=frame_size)
        self.sequential.add_module("convlstm1", recurrent_layer)
        self.sequential.add_module("batchnorm1", nn.BatchNorm3d(num_features=num_kernels))

        # Head: map the hidden feature maps back to num_channels for the predicted frame.
        self.conv = nn.Conv2d(
            in_channels=num_kernels, out_channels=num_channels,
            kernel_size=kernel_size, padding=padding)

    def forward(self, X):
        features = self.sequential(X)
        last_step = features[:, :, -1]
        return torch.sigmoid(self.conv(last_step))

In [None]:
### K-fold Cross Validator
# Params
torch.manual_seed(42)
num_epochs = 100
# reduction='sum': each batch loss is summed over every pixel of every sample; epoch totals
# are divided by the number of samples in the fold below to get a per-sample value.
criterion = nn.BCELoss(reduction='sum')

# Fold results storage objects (keyed by fold index)
train_start_results = {}
val_start_results = {}

train_end_results = {}
val_end_results = {}

# One per-epoch list per fold is appended to each of these
train_results = []
val_results = []

accuracy_train_results = []
accuracy_val_results = []

raw_accuracy_train_results = []
raw_accuracy_val_results = []

# Define the K-fold Cross Validator (fixed random_state -> reproducible fold assignment)
k_folds = 5
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

# Whole dataset
dataset = reshaped_sequences_np

# NOTE(review): get_keys and get_accuracy are called below but are not defined in any cell
# of this notebook - define them before running this cell.
# NOTE(review): the augmentation cell stores several rescaled copies of every base sequence
# and KFold splits those copies independently, so validation folds contain rescaled copies
# of training sequences; validation scores will be optimistic.


def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Returns:
        (sum of batch losses, sum of batch accuracies, accuracy of the last batch - None if
        the loader yielded no batches)
    """
    model.train()
    loss_sum = 0
    accuracy_sum = 0
    last_batch_accuracy = None
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        output = model(x)
        loss = criterion(output.flatten(), y.flatten())

        # Backpropagate, apply the optimizer and reset the gradients
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_sum += loss.item()
        last_batch_accuracy = get_accuracy(get_keys(y), get_keys(output))
        accuracy_sum += last_batch_accuracy
    return loss_sum, accuracy_sum, last_batch_accuracy


def _evaluate(model, loader, criterion):
    """Return (sum of batch losses, sum of batch accuracies) over ``loader``, without gradients."""
    model.eval()
    loss_sum = 0
    accuracy_sum = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            output = model(x)
            loss_sum += criterion(output.flatten(), y.flatten()).item()
            accuracy_sum += get_accuracy(get_keys(y), get_keys(output))
    return loss_sum, accuracy_sum


# K-fold Cross Validation model evaluation
for fold, (train_ids, val_ids) in enumerate(kfold.split(dataset)):
    print(f'FOLD {fold}')

    # Sample elements randomly from this fold's ids, no replacement.
    train_loader = DataLoader(dataset, batch_size=15, collate_fn=collate,
                              sampler=SubsetRandomSampler(train_ids))
    val_loader = DataLoader(dataset, batch_size=15, collate_fn=collate,
                            sampler=SubsetRandomSampler(val_ids))
    # loader.dataset is the WHOLE dataset even when a sampler is used, so epoch totals are
    # normalised by the fold sizes. Dividing by len(loader.dataset), as before, understated
    # both losses - the validation loss far more than the training loss.
    n_train = len(train_ids)
    n_val = len(val_ids)

    # Initialization
    model = Seq2Seq(num_channels=1, num_kernels=28, kernel_size=(3, 3), padding=(1, 1),
                    activation="tanh", frame_size=(28, 28)).to(device)
    model.apply(reset_weights)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Fresh per-fold histories. The accuracy lists used to be created once, before the fold
    # loop, so they kept growing across folds and every fold's entry in accuracy_*_results /
    # raw_accuracy_*_results was one and the same shared list.
    train_results_per_epoch = []
    val_results_per_epoch = []
    accuracy_train_per_epoch = []
    accuracy_val_per_epoch = []
    raw_accuracy_train_per_batch = []
    raw_accuracy_val_per_batch = []

    for epoch in range(1, num_epochs + 1):
        train_loss, train_accuracy, last_batch_accuracy_train = _train_one_epoch(
            model, train_loader, optimizer, criterion)
        val_loss, val_accuracy = _evaluate(model, val_loader, criterion)

        # Per-sample loss and accuracy for the epoch
        total_train_loss = train_loss / n_train
        total_val_loss = val_loss / n_val
        train_results_per_epoch.append(total_train_loss)
        val_results_per_epoch.append(total_val_loss)
        accuracy_train_per_epoch.append(train_accuracy / n_train)
        accuracy_val_per_epoch.append(val_accuracy / n_val)

        # NOTE(review): as before, the "raw" train entry is the last batch's accuracy while
        # the "raw" val entry is the summed accuracy of the whole epoch.
        raw_accuracy_train_per_batch.append(last_batch_accuracy_train)
        raw_accuracy_val_per_batch.append(val_accuracy)

        # Store scores of the first and of the last epoch
        if epoch == 1:
            val_start_results[fold] = total_val_loss
            train_start_results[fold] = total_train_loss
        if epoch == num_epochs:
            val_end_results[fold] = total_val_loss
            train_end_results[fold] = total_train_loss

        print("Epoch:{} Training Loss:{:.2f} Validation Loss:{:.2f}\n".format(
            epoch, total_train_loss, total_val_loss))

    train_results.append(train_results_per_epoch)
    val_results.append(val_results_per_epoch)

    accuracy_train_results.append(accuracy_train_per_epoch)
    accuracy_val_results.append(accuracy_val_per_epoch)
    raw_accuracy_train_results.append(raw_accuracy_train_per_batch)
    raw_accuracy_val_results.append(raw_accuracy_val_per_batch)

    # Saving the model
    save_path = f'./convlstm-custom2000-model-fold-{fold}.pth'
    torch.save(model.state_dict(), save_path)

In [None]:
# Stop the CodeCarbon tracker and print the emissions value it returns.

emissions: float = tracker.stop()
print(f"Emissions: {emissions} kg")

# Wall-clock time of the run: `stop_time` is the end timestamp and
# `time_spend` the elapsed duration (a timedelta).
stop_time = datetime.now()
time_spend = stop_time - start_time
print(f"Elapsed time: {time_spend}")

# Time logs. The duration is logged as a numeric metric in seconds; the two
# timestamps are not numeric, so they go to Comet's "Other" tab as ISO-8601 strings.
experiment.log_other("start_time", start_time.isoformat())
experiment.log_other("stop_time", stop_time.isoformat())
experiment.log_metric("time_spend", time_spend.total_seconds())

# Turn off Comet
experiment.end()

In [None]:
# Inference on a single sequence (batch_size=1).
data_loader = DataLoader(dataset, batch_size=1, collate_fn=collate, drop_last=True)
data, target = next(iter(data_loader))
# The training loop moves every batch to `device` explicitly, so do the same here;
# otherwise a CUDA model receives a CPU tensor and raises a device mismatch.
data = data.to(device)

model.eval()
with torch.no_grad():
    output = model(data)
output.shape

In [None]:
# Generated image visualization: first item of the model output, viewed as a
# 28x28 single-channel image (assumes each output item has 784 values).
img_gen = output[0].cpu().reshape(28, 28, 1)

fig, ax = plt.subplots()
ax.imshow(img_gen.squeeze(-1))  # drop the singleton channel so imshow gets a 2-D array
ax.set_title("Generated image")
plt.show()

In [None]:
# Target image visualization: first target of the batch, viewed as a
# 28x28 single-channel image (assumes each target item has 784 values).
target_reshaped = target[0].cpu().reshape(28, 28, 1)

fig, ax = plt.subplots()
ax.imshow(target_reshaped.squeeze(-1))  # drop the singleton channel so imshow gets a 2-D array
ax.set_title("Target image")
plt.show()

In [None]:
# Inference for a batch of sequences
N_SEQUENCES = 15  # number of sequences to predict and display side by side

data_loader = DataLoader(dataset, batch_size=N_SEQUENCES, collate_fn=collate, drop_last=True)
data, target = next(iter(data_loader))
# Match the training loop, which moves inputs to `device` before the forward pass.
data = data.to(device)

model.eval()
with torch.no_grad():
    output = model(data)

# Reshape targets and predictions to (N_SEQUENCES, 28, 28, 1). Both are moved to
# the CPU so torch.cat below does not mix devices and the result can be plotted.
targets = target.cpu().reshape(N_SEQUENCES, 28, 28, 1)
imgs_gen = output.cpu().reshape(N_SEQUENCES, 28, 28, 1)

# Join tensors for a single image grid: the first N_SEQUENCES entries are
# targets, the remaining N_SEQUENCES are the model predictions.
combined = torch.cat((targets, imgs_gen), 0)
combined.shape

In [None]:
# Show targets and predictions in one grid. `combined` holds the targets first
# and the predictions second, and ImageGrid fills row by row, so the top row
# shows targets and the bottom row shows predictions.
n_cols = combined.shape[0] // 2

fig = plt.figure(figsize=(15, 15))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, n_cols),  # 2 rows x n_cols columns of axes
                 axes_pad=0.05,  # pad between axes in inch.
                 )

for ax, im in zip(grid, combined.cpu()):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

plt.show()

In [None]:
# Print the per-fold losses stored when epoch == 1, and their average.
# The accumulator is named `total` so the builtin sum() is not shadowed for
# every later cell in the kernel.
print(f'Start K-FOLD RESULTS FOR {k_folds} FOLDS')
total = 0.0
for key, value in train_start_results.items():
    print(f'Fold {key}: {value}')
    total += value
if train_start_results:
    print(f'Average train: {total/len(train_start_results)}')
else:
    print('Average train: no folds recorded')

total = 0.0
for key, value in val_start_results.items():
    print(f'Fold {key}: {value}')
    total += value
if val_start_results:
    print(f'Average val: {total/len(val_start_results)}')
else:
    print('Average val: no folds recorded')

In [None]:
# Print the per-fold losses from the last epoch (the training loop overwrites
# them on every epoch after the first), and their average.
# The accumulator is named `total` so the builtin sum() is not shadowed.
print(f'End K-FOLD RESULTS FOR {k_folds} FOLDS')
total = 0.0
for key, value in train_end_results.items():
    print(f'Fold {key}: {value}')
    total += value
if train_end_results:
    print(f'Average train: {total/len(train_end_results)}')
else:
    print('Average train: no folds recorded')

total = 0.0
for key, value in val_end_results.items():
    print(f'Fold {key}: {value}')
    total += value
if val_end_results:
    print(f'Average val: {total/len(val_end_results)}')
else:
    print('Average val: no folds recorded')

In [None]:
# Train (solid) and validation (dashed) loss per epoch; one colour per fold.

import matplotlib.pyplot as plt
# Epoch axis derived from the recorded curves instead of a hard-coded 100.
# `x` stays a module-level name because other cells plot against it.
x = list(range(len(train_results[0])))

fold_colors = ["blue", "brown", "green", "orange", "magenta"]

fig, ax = plt.subplots()
for fold_idx, train_curve in enumerate(train_results):
    ax.plot(x, train_curve, c=fold_colors[fold_idx % len(fold_colors)],
            label=f"train fold {fold_idx + 1}")
for fold_idx, val_curve in enumerate(val_results):
    ax.plot(x, val_curve, c=fold_colors[fold_idx % len(fold_colors)], ls="dashed",
            label=f"val fold {fold_idx + 1}")
ax.legend(loc='upper right', shadow=True)
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.set_title(f'Train and validation results for {len(train_results)} folds')
plt.show()

In [None]:
# Train loss per epoch, one colour per fold. The epoch axis is built per curve,
# so this cell does not rely on `x` from another cell.
fold_colors = ["blue", "brown", "green", "orange", "magenta"]

fig, ax = plt.subplots()
for fold_idx, train_curve in enumerate(train_results):
    ax.plot(range(len(train_curve)), train_curve,
            c=fold_colors[fold_idx % len(fold_colors)],
            label=f"train fold {fold_idx + 1}")
ax.legend(loc='upper right', shadow=True)
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.set_title(f'Train results for {len(train_results)} folds')
plt.show()

In [None]:
# Validation loss per epoch (dashed), one colour per fold. The epoch axis is
# built per curve, so this cell does not rely on `x` from another cell.
fold_colors = ["blue", "brown", "green", "orange", "magenta"]

fig, ax = plt.subplots()
for fold_idx, val_curve in enumerate(val_results):
    ax.plot(range(len(val_curve)), val_curve,
            c=fold_colors[fold_idx % len(fold_colors)], ls="dashed",
            label=f"val fold {fold_idx + 1}")
ax.legend(loc='upper right', shadow=True)
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.set_title(f'Validation results for {len(val_results)} folds')
plt.show()

In [None]:
import json

# Persist the per-fold, per-epoch loss curves next to the notebook so they can
# be re-plotted later without retraining. Train results are written first.
results_to_save = (
    ("train_results.json", train_results),
    ("val_results.json", val_results),
)
for file_name, curves in results_to_save:
    with open("./" + file_name, 'w') as outfile:
        json.dump(curves, outfile)