# In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
from matrixprofile import matrixProfile, motifs
from pyclustering.nnet.som import type_conn
import dtwsom

class FindMotifs:
    """Discover recurring subsequences (motifs) in a set of curves.

    The curves are concatenated into one 1-D series and a matrix profile is
    computed over it with a window as long as one curve (``self.size``). For
    each motif found, the subsequence at its first occurrence is kept as the
    motif's representative ("center").
    """

    def __init__(self, time_series):
        """Store the curves and their concatenation.

        Args:
            time_series: sequence of 1-D sequences, one per curve (row). The
                length of the first row is used as the window length; this
                presumably assumes every row has the same length -- TODO
                confirm with callers.
        """
        self.original_data = time_series
        self.size = len(time_series[0])  # data points per curve (window length)
        self.sample_size = len(time_series)  # number of curves (rows)
        # NOTE(review): the matrix profile slides over this concatenation, so
        # a window may straddle the boundary between two consecutive curves.
        self.joined_data = np.concatenate(time_series)

    def find_motifs(self, max_motifs_user=1000):
        """Compute the matrix profile and extract up to ``max_motifs_user`` motifs.

        Args:
            max_motifs_user: maximum number of motifs requested from
                ``motifs.motifs``.

        Returns:
            ``(mtfs, motif_d, motif_center_list)``:
            - the motifs as returned by ``motifs.motifs`` (one group of start
              indices into ``self.joined_data`` per motif);
            - their distances;
            - the first occurrence of each motif, as an array of
              ``self.size`` points.
            All three are also stored on ``self``.
        """
        # matrixprofile-ts's ``stomp(tsA, m, tsB=None)`` requires the window
        # length ``m``; use one curve's length so the windows the profile is
        # built on match the centers extracted below.
        self.mp = matrixProfile.stomp(self.joined_data, self.size)
        self.mtfs, self.motif_d = motifs.motifs(
            self.joined_data, self.mp, max_motifs=max_motifs_user
        )
        self.motif_center_list = self._motif_centers(self.mtfs)
        return self.mtfs, self.motif_d, self.motif_center_list

    def _motif_centers(self, motif_groups):
        """Return a copy of the ``self.size``-point window at each group's first start index."""
        return [
            self.joined_data[group[0]:group[0] + self.size].copy()
            for group in motif_groups
        ]

class AutoDTWSOM(FindMotifs):
    """Cluster motif centers with a DTW-based self-organizing map (SOM).

    Construction finds the motifs of ``original_data`` (via ``find_motifs``)
    and builds a ``rows`` x ``cols`` ``dtwsom.DtwSom`` network;
    ``train_network`` fits it on the motif centers, and the remaining methods
    inspect the resulting clusters.
    """

    def __init__(self, structure_user, epochs_user, original_data, rows=3, cols=3):
        """Find the motifs and create the (untrained) SOM.

        Args:
            structure_user: connection structure forwarded to ``dtwsom.DtwSom``.
            epochs_user: number of epochs used by ``train_network``.
            original_data: curves handed to ``FindMotifs``.
            rows: number of SOM rows (default 3, the previously fixed value).
            cols: number of SOM columns (default 3, the previously fixed value).
        """
        super().__init__(original_data)
        self.mtfs, self.motif_d, self.motif_center_list = self.find_motifs()

        self.rows = rows
        self.cols = cols
        self.structure = structure_user
        self.network = dtwsom.DtwSom(self.rows, self.cols, self.structure)
        self.epochs_user = epochs_user

    def train_network(self):
        """Train the SOM on the motif centers for ``self.epochs_user`` epochs."""
        self.network.train(self.motif_center_list, self.epochs_user)

    def showing_results(self):
        """Show the SOM matrices, every neuron's weights and every cluster's members.

        Neuron ``k`` is drawn in subplot ``(k // cols, k % cols)`` of a grid
        with the same ``rows`` x ``cols`` shape as the SOM.
        """
        self.network.show_distance_matrix()
        self.network.show_winner_matrix()

        # The subplot grid is rows x cols; the original code read this count
        # from self.network._size, which pyclustering sets to rows * cols.
        n_neurons = self.rows * self.cols

        # Figure 1: the learned weight sequence of every neuron.
        _, axs = self._neuron_grid()
        for neuron_index in range(n_neurons):
            row, col = divmod(neuron_index, self.cols)
            ax = axs[row, col]
            neuron_weights = self.network._weights[neuron_index]
            ax.plot(np.arange(len(neuron_weights)), neuron_weights, label=str(neuron_index))
            ax.set_title("Unit: " + str(neuron_index))
            self._label_edge_axis(ax, row, col, "Sequence values")
        plt.tight_layout()
        plt.show()

        # Figure 2: the motif centers captured by every neuron (cluster).
        clusters = self.network.capture_objects
        _, axs = self._neuron_grid()
        for neuron_index in range(n_neurons):
            row, col = divmod(neuron_index, self.cols)
            ax = axs[row, col]
            for member_index in clusters[neuron_index]:
                member = self.motif_center_list[member_index]
                ax.plot(np.arange(len(member)), member)
            # Title/labels are set even for empty clusters so the grid stays readable.
            ax.set_title("Cluster: " + str(neuron_index))
            self._label_edge_axis(ax, row, col, "Seq. values")
        plt.tight_layout()
        plt.show()

    def _neuron_grid(self):
        """Create a rows x cols grid of shared-axis subplots; ``axs`` is always 2-D."""
        return plt.subplots(
            self.rows, self.cols, figsize=(10, 6), sharex=True, sharey=True, squeeze=False
        )

    def _label_edge_axis(self, ax, row, col, ylabel):
        """Put ``ylabel`` on the left column and "Time" on the bottom row."""
        if col == 0:
            ax.set_ylabel(ylabel)
        if row == self.rows - 1:
            ax.set_xlabel("Time")

    def _unit_occurrences(self, unit):
        """Yield ``(motif_index, start)`` for each occurrence of each motif captured by ``unit``."""
        for motif_index in self.network.capture_objects[unit]:
            for start in self.mtfs[motif_index]:
                yield motif_index, start

    def recover_curves(self, original_data, unit):
        """Return every occurrence of the motifs captured by SOM neuron ``unit``.

        Args:
            original_data: unused; kept for backward compatibility. Curves are
                sliced from ``self.joined_data``.
            unit: index of the SOM neuron (cluster).

        Returns:
            List of ``self.size``-point slices (views) of ``self.joined_data``.
        """
        return [
            self.joined_data[start:start + self.size]
            for _, start in self._unit_occurrences(unit)
        ]

    def recover_curves_plot(self, unit):
        """Plot, one figure each, every occurrence of the motifs captured by ``unit``.

        X tick labels (every 100 points) show the position inside
        ``self.joined_data`` rather than inside the window.
        """
        x_ticks = np.arange(0, self.size, step=100)
        for motif_index, start in self._unit_occurrences(unit):
            # Index of the original curve (row) the occurrence starts in,
            # assuming every row is self.size points long -- TODO confirm.
            id_curva = start // self.size
            x_tick_labels = np.arange(start, start + self.size, step=100)

            plt.plot(self.joined_data[start:start + self.size])
            # Title reads "Curve <id> of motif <i> of cluster <unit>".
            plt.title(f"Curva {id_curva} do motif {motif_index} do cluster {unit}")
            plt.xticks(x_ticks, x_tick_labels)

            plt.xlabel('X data points', fontsize=12)
            plt.ylabel('Force (nN)', fontsize=12)

            plt.show()
