# In [1]:
import logging

import torch
 
from base_models.aggwr import AGGWR
logger = logging.getLogger('EGGWR-Log')

# Notebook output: ModuleNotFoundError: No module named 'base_models'

# In [None]:
def add_sparse(matrix, indices, value):
    """
    Return ``matrix`` with ``value`` added to a single entry of a 3-D tensor.
    :param matrix: 3-D tensor the increment is added to (not modified in place)
    :param indices: Sequence of three integer coordinates addressing the entry
    :param value: Amount added at the addressed entry
    :return: Result of ``matrix.add`` with a one-entry sparse COO tensor
    """
    row, col, depth = indices
    # One stored element carrying the increment, same overall size as matrix.
    increment = torch.sparse_coo_tensor(
        [[row], [col], [depth]],
        [value],
        matrix.size(),
    )
    return matrix.add(increment)

def mul_sparse(matrix, indices, value):
    """
    Return ``matrix`` with a single entry of a 3-D tensor scaled by ``value``.

    An element-wise product with a one-entry sparse tensor would zero every
    other entry of ``matrix``. The scaling is therefore expressed as an
    addition of ``entry * (value - 1)`` at the addressed position, which
    leaves all other entries untouched.
    :param matrix: 3-D tensor (dense or sparse COO); not modified in place
    :param indices: Sequence of three integer coordinates addressing the entry
    :param value: Factor the addressed entry is multiplied with
    :return: New tensor in which only the addressed entry has been scaled
    """
    a, b, c = indices
    current = matrix[a, b, c]
    # Guard in case indexing a sparse tensor yields a sparse scalar.
    if current.is_sparse:
        current = current.to_dense()
    delta = current * (value - 1)
    update = torch.sparse_coo_tensor([[a], [b], [c]], delta.reshape(1), matrix.size())
    return matrix.add(update)

# In [1]:
class EGGWR_Plus(AGGWR):
    """
    Episodic Gamma-GWR network that stores its temporal edges in a sparse tensor.

    In addition to the state managed by ``AGGWR`` this class keeps:

    * ``temporal``: sparse COO tensor of size ``(size, size, 8)``. For a
      temporal edge ``prev -> current`` slot 0 holds the activation count,
      slots 1-6 the running average of the action and slot 7 the running
      average of the reward.
    * ``H``: dense label histogram of size ``(size, num_labels, n_classes)``;
      the last dimension grows on demand in ``update_label``.
    """

    def __init__(self, semantic=False, **kwargs):
        """
        :param semantic: If True, ``forward`` adapts weights only on correct predictions and inserts
                         nodes only on wrong predictions
        :param kwargs: Forwarded to ``AGGWR.__init__``; ``num_labels`` is additionally read here
        """
        # Assigned before the base constructor runs (ordering kept from the original code).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        super().__init__(**kwargs)
        self.name = 'EGGWR'
        # Third dimension stores activation count, average action and average reward of an edge.
        self.temporal = torch.sparse_coo_tensor((self.size, self.size, 8)).coalesce()
        self.num_labels = kwargs.get('num_labels')
        self.H = torch.zeros((self.size, self.num_labels, 1))
        self.num_neighbors = 0
        self.semantic = semantic

    def insert_node(self, sample, b):
        """
        Insert new node into network at highest index in node list V, edge list E, context C, temporal connection list
        and, habituation list h.
        :param sample: Feature values of current observation
        :param b: Index of BMU in 2D tensor V of network nodes
        """
        self.h = torch.cat((self.h, torch.ones(1)))
        # New node is placed halfway between the sample and the BMU.
        self.V = torch.cat((self.V, torch.unsqueeze((sample + self.V[b]) / 2, 0)), dim=0)
        # Grow the edge matrix by one row and one column; -1 marks "no edge".
        self.E = torch.cat((self.E, torch.full((1, self.size), -1)), dim=0)
        self.E = torch.cat((self.E, torch.full((self.size + 1, 1), -1)), dim=1)
        context = (0.5 * (self.global_C + self.C[b])).view(1, self.n_context, -1)
        self.C = torch.cat((self.C, context), dim=0)
        self.H = torch.cat((self.H, torch.zeros((1, self.H.shape[1], self.H.shape[2]))), dim=0)
        # Growing a sparse tensor only needs a larger size; stored entries stay as they are.
        rows, cols, depth = self.temporal.size()
        self.temporal = torch.sparse_coo_tensor(
            self.temporal.indices(), self.temporal.values(), (rows + 1, cols + 1, depth)
        ).coalesce()
        self.size += 1

    def delete_node(self, idx):
        """
        Delete node without edge connections in node list V, edge list E, context C , temporal connection list,
        and habituation list h.
        :param idx: Index of node to be deleted
        """
        self.h = torch.cat((self.h[:idx], self.h[idx + 1:]), dim=0)
        self.V = torch.cat((self.V[:idx], self.V[idx + 1:]), dim=0)
        self.E = torch.cat((self.E[:idx], self.E[idx + 1:]), dim=0)
        self.E = torch.cat((self.E[:, :idx], self.E[:, idx + 1:]), dim=1)
        self.C = torch.cat((self.C[:idx], self.C[idx + 1:]), dim=0)
        self.H = torch.cat((self.H[:idx], self.H[idx + 1:]), dim=0)
        # Row/column removal is done on a dense copy and converted back to sparse afterwards.
        dense_temporal = self.temporal.to_dense()
        dense_temporal = torch.cat((dense_temporal[:idx], dense_temporal[idx + 1:]), dim=0)
        dense_temporal = torch.cat((dense_temporal[:, :idx], dense_temporal[:, idx + 1:]), dim=1)
        self.temporal = dense_temporal.to_sparse().coalesce()
        self.size -= 1

    def update_bmu(self, b, sample, correct_pred):
        """
        Adapt position of BMU (in case of no new insertion)
        :param b: Index of BMU in 2D tensor V of network nodes
        :param sample: Feature values of current observation
        :param correct_pred: If False, weights and context are left untouched and only habituation is updated
        """
        if correct_pred:
            self.V[b] += self.eps_b * self.h[b] * (sample - self.V[b])
            self.C[b] += self.eps_b * self.h[b] * (self.global_C - self.C[b])
        self.h[b] += self.tau_b * self.kappa * (1 - self.h[b]) - self.tau_b

    def update_neighbor(self, n, b, sample, correct_pred):
        """
        Adapt position of BMU neighbor, update its edge connection to the BMU, and check if it has to be deleted
        :param n: Index of BMU neighbor in 2D tensor V of network nodes
        :param b: Index of BMU in 2D tensor V of network nodes
        :param sample: Feature values of current observation
        :param correct_pred: If False, weights and context are left untouched and only habituation is updated
        :return delete: True if current node has to be deleted (no outgoing edges), False otherwise
        """
        if correct_pred:
            self.V[n] += self.eps_n * self.h[n] * (sample - self.V[n])
            # NOTE(review): the context update uses eps_b while the weight update uses eps_n;
            # possibly a copy/paste slip -- left unchanged, confirm against the base class.
            self.C[n] += self.eps_b * self.h[n] * (self.global_C - self.C[n])
        self.h[n] += self.tau_n * self.kappa * (1 - self.h[n]) - self.tau_n

        delete = False
        if self.E[b, n] < self.max_age:
            # Edge is still young enough: age it symmetrically.
            self.E[b, n] += 1
            self.E[n, b] += 1
        else:
            # Edge expired: remove it and flag the neighbor if it has no edges left.
            self.E[b, n], self.E[n, b] = -1, -1
            if bool(torch.all(self.E[n] == -1)):
                delete = True

        return delete

    def update_temporal(self, current_idx, prev_idx, action, reward):
        """
        Update the temporal connection list.
        :param current_idx: Index of current BMU
        :param prev_idx: Index of the previous BMU (-1 if there is none)
        :param action: Action associated with the transition; averaged into slots 1-6 of the edge
        :param reward: Reward associated with the transition; averaged into slot 7 of the edge
        """
        # Nothing to record for the first sample or for a self-transition.
        if prev_idx == -1 or prev_idx == current_idx:
            return

        old_slots = self.temporal[prev_idx, current_idx].to_dense()
        new_slots = old_slots.clone()
        new_slots[0] += 1
        num_activ = new_slots[0]

        # Running average: avg_n = avg_{n-1} * (n - 1) / n + x / n
        keep = (num_activ - 1) / num_activ
        new_slots[1:7] *= keep
        new_slots[1:7] += (1 / num_activ) * action
        new_slots[7] *= keep
        new_slots[7] += (1 / num_activ) * reward

        # Sparse tensors do not support item assignment, so the edge is rewritten by
        # subtracting its old slot values and adding the new ones.
        n_slots = old_slots.shape[0]
        edge_indices = [[prev_idx] * n_slots, [current_idx] * n_slots, list(range(n_slots))]
        size = self.temporal.size()
        self.temporal = self.temporal.add(torch.sparse_coo_tensor(edge_indices, -old_slots, size))
        self.temporal = self.temporal.add(torch.sparse_coo_tensor(edge_indices, new_slots, size))
        self.temporal = self.temporal.coalesce()

    def update_label(self, label, b, r=None):
        """
        Update label matrix H according to currently seen sample and corresponding labels
        :param label: List of class labels of input observations
        :param b: Index of BMU in 2D tensor V of network nodes
        :param r: Index of new node in 2D tensor V of network nodes
        """
        for i in range(self.num_labels):
            if self.num_labels == 1:
                lab = int(label)
            else:
                lab = int(label[i])
            # -1 marks a missing label for this label type.
            if lab == -1:
                continue
            if self.H.shape[2] <= lab:
                # Grow the class dimension so that index `lab` becomes valid.
                extra = torch.zeros((self.H.shape[0], self.num_labels, lab - self.H.shape[2] + 1))
                self.H = torch.cat((self.H, extra), dim=2)
            if r is None:
                self.H[b][i] -= self.delta_minus
                self.H[b][i][lab] += self.delta_plus + self.delta_minus
            else:
                self.H[r][i][lab] = 1

    @torch.no_grad()
    def forward(self, it, data, action, reward):
        """
        Original Episodic Gamma-GWR algorithm as in Parisi et al. (2018)
        b: Index of BMU | s: Index of second BMU | a: BMU activity
        :param it: Number of batch/iteration
        :param data: List of mini-batch samples (contains just a single sample for continuous data stream)
        :param action: Action passed on to update_temporal for every sample of the batch
        :param reward: Reward passed on to update_temporal for every sample of the batch
        :return: Tuple (BMU index per sample, predicted class per label type and sample)
        """
        y_pred = [[] for _ in range(self.num_labels)]
        bmu_indices = []
        prev_bmu = -1
        correct_pred = True
        false_pred = True
        test_context = torch.zeros(self.global_C.shape)
        if not self.training:
            # Evaluate with a scratch context; the trained one is restored before returning.
            old_context = self.global_C
            self.global_C = test_context
        for sample, label in data:
            b, s, a = self.activate_bmu(sample)
            for l in range(self.num_labels):
                y_pred[l].append(torch.argmax(self.H[b][l]).item())
            bmu_indices.append(b)
            if self.training:
                if self.semantic:
                    correct_pred = y_pred[0][-1] == label.item()
                    false_pred = not correct_pred
                self.update_temporal(b, prev_bmu, action, reward)
                if a < self.a_t and self.h[b] < self.h_t and false_pred:
                    self.insert_node(sample, b)
                    self.update_edges(b, s, self.size - 1)
                    self.update_label(label, b, self.size - 1)
                    # NOTE(review): the second format argument is a literal 4, not sample values.
                    logger.info('Iteration {}. Inserted new node at position (first dimensions): {}. Label: {}. '
                                'BMU index: {}. Updated {} size: {}.'.format(it, 4,
                                                                             label, b, self.name,  self.size))
                else:
                    self.update_edges(b, s)
                    self.update_label(label, b)
                    self.update_bmu(b, sample, correct_pred)

                    neighbors = torch.nonzero(self.E[b] > -1)
                    if len(neighbors) > self.num_neighbors:
                        self.num_neighbors = len(neighbors)
                    n_deleted = 0
                    for n in neighbors:
                        # Every earlier deletion shifted the remaining indices down by one.
                        idx = int(n) - n_deleted
                        if self.update_neighbor(idx, b, sample, correct_pred):
                            logger.info('Iteration {}. Deleted node at index: {}. Label: {}. Updated {} size: {}.'.
                                        format(it, idx, torch.argmax(self.H[idx]).item(), self.name,
                                               self.size))
                            self.delete_node(idx)
                            # Only a deletion below the BMU moves the BMU index.
                            if idx < b:
                                b -= 1
                            n_deleted += 1
                # update Context for the next step
                self.update_global_context(b)
                # Remember the BMU only after deletions so the index is valid for the next sample.
                prev_bmu = b
            else:
                # NOTE(review): iterating upwards copies the current sample into every context
                # row; if a shift of older contexts was intended the loop must run backwards.
                test_context[0] = sample
                for i in range(1, self.n_context):
                    test_context[i] = test_context[i-1]
                self.global_C = test_context
        if not self.training:
            self.global_C = old_context
        return bmu_indices, y_pred