In [12]:
from torch import nn, Tensor, tensor, arange, zeros, cat, gather, matmul, device
from torch.nn.functional import pad
from torch_geometric.data import Data
from math import log2, ceil


class SOTLayer(nn.Module):
    
    def __init__(self, number_of_leaves: int, kernel_size: int, stride: int = 2, padding: int = 0, lr: float = 0.3, device = device("cpu")):
        super(SOTLayer, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.leaf_num = 2**ceil(log2(number_of_leaves))
        self.depth = log2(self.leaf_num)
        self.learning_rates = self.learning_rates_per_branch(lr)
        self.nodes = nn.Parameter(data= Tensor((self.leaf_num*2)-1, self.kernel_size**2), requires_grad=False).to(device)
        self.nodes.data.uniform_(0, 1)
        self.node_indices = arange(self.nodes.shape[0])
        self.device = device

        #self.tree_graph = Data(x= self.nodes, edge_index = self.get_tree_edges())
    
    def get_tree_edges(self):
        start, num = 0, 2
        layers, prev_layer = [], [0]
        for n in range(int(log2(self.leaf_num)+1)):
            nodes_per_layer = num ** n
            layer = list(range(start, + start + nodes_per_layer))
            dupl_l = [val for val in prev_layer for _ in (0, 1)]
            layers += (list(zip(dupl_l, layer)))
            start += nodes_per_layer 
            prev_layer = layer
        return tensor(layers[1:]).T
    
    def img2patches(self, x):
        padded = pad(x, [self.padding] * 4, "constant", 0)
        p = padded.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride)
        out = p.reshape(p.shape[0],p.shape[1], p.shape[2] * p.shape[3], p.shape[4] * p.shape[5])
        output_dim_x, output_dim_y = p.shape[2], p.shape[3]
        return out, output_dim_x, output_dim_y
    
    def pnorm(self, x1, x2, p=2):
        return torch.pow(torch.pow(x1 - x2.unsqueeze(dim=3), p).sum(dim=4), p)
    
    def learning_rates_per_branch(self, lr: float):
        return (lr * 2 ** arange(1,self.depth+2, dtype=torch.float)) / (2**(self.depth+1))
    
    def _propagate_through_tree(self, X):
        patch_num = X.shape[2]
        start, num = 1, 2
        layers = []
        layer_state = torch.zeros(patch_num, 1, dtype=int)
        update_amount_indices = torch.zeros(2, dtype=int)
        for n in range(1, int(log2(self.leaf_num)+1)):
            nodes_per_layer = num ** n
            layer = arange(start, + start + nodes_per_layer)
            layer_state = layer_state.repeat_interleave(2).reshape(patch_num,nodes_per_layer)
            max_val = layer_state.max(dim=1).values.unsqueeze(dim=1)
            competing_indices = layer.repeat(patch_num,1)[layer_state == max_val].reshape(patch_num, num)
            competing_nodes = self.nodes[competing_indices].clone().to('cpu')
            dist = self.pnorm(competing_nodes, X)
            bmu_dists, bmu = torch.min(dist, 3)
            bmu_index = gather(input = competing_indices, dim = 1, index = bmu.squeeze().unsqueeze(dim=1))
            layer_state = layer_state.add((layer == bmu_index).to(torch.int64))
            layers.append(layer_state)
            start += nodes_per_layer 
        return cat(layers, dim=1), bmu_index, bmu_dists
    
    def forward(self, X):
        X, output_dim_x, output_dim_y = self.img2patches(X)
        indices, bmu_indices, bmu_dists = self._propagate_through_tree(X)
        neighborhood_lrs = gather(input=self.learning_rates, index= indices.flatten(), dim=0).reshape(indices.shape)
        neighborhood_updates = (neighborhood_lrs.unsqueeze(2).to(self.device) * (X.unsqueeze(3).squeeze(0).squeeze(0).to(self.device) - self.nodes[1:,:].unsqueeze(0).to(self.device))).mean(0)
        self.nodes[1:, :] += neighborhood_updates
        return
    
    

from torch import rand
import torch
img = rand(1,1,6,6)    
print(img)
cuda = device('cuda') 
tree = SOTLayer(number_of_leaves = 7, 
                kernel_size = 3, 
                stride = 1, 
                padding = 0,
                lr = 0.5,
                device = cuda
               )

tree.forward(img)
#a,_,_ = tree.img2patches(img)
#s, bmu_indices, bmu_dists = tree.propagate_through_tree(img)
#l = tree.learning_rates_per_branch(0.2)

#torch.gather(input=l, index= s.flatten(), dim=0).reshape(s.shape)
#tree.forward(img)
print(tree.nodes.shape)


tensor([[[[0.0617, 0.6301, 0.1402, 0.2946, 0.2888, 0.1461],
          [0.0952, 0.3256, 0.5204, 0.4625, 0.8959, 0.8129],
          [0.9650, 0.2442, 0.9210, 0.2398, 0.9317, 0.0939],
          [0.6034, 0.9018, 0.0202, 0.9215, 0.3780, 0.5504],
          [0.1710, 0.4506, 0.9833, 0.1394, 0.0281, 0.3948],
          [0.7193, 0.6816, 0.3467, 0.9278, 0.3472, 0.8764]]]])
-1 -1
-1 -1
-1 -1
torch.Size([15, 9])


In [42]:
lr, depth = 0.2, 3
#lr * 
torch.tensor(2).pow(arange(1,depth+2, dtype=torch.float)).flip(0) / (2**(depth+1))

tensor([1.0000, 0.5000, 0.2500, 0.1250])

16

In [21]:
p = torch.zeros(1).repeat(16,1)
l = torch.arange(16).unsqueeze(dim=1)
a = cat((p,l),1)
print(a)
print(a.repeat_interleave(2))#.reshape(16,4)#.repeat(1,2)#.T#.flatten()

tensor([[ 0.,  0.],
        [ 0.,  1.],
        [ 0.,  2.],
        [ 0.,  3.],
        [ 0.,  4.],
        [ 0.,  5.],
        [ 0.,  6.],
        [ 0.,  7.],
        [ 0.,  8.],
        [ 0.,  9.],
        [ 0., 10.],
        [ 0., 11.],
        [ 0., 12.],
        [ 0., 13.],
        [ 0., 14.],
        [ 0., 15.]])
tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  0.,  0.,  2.,  2.,  0.,  0.,
         3.,  3.,  0.,  0.,  4.,  4.,  0.,  0.,  5.,  5.,  0.,  0.,  6.,  6.,
         0.,  0.,  7.,  7.,  0.,  0.,  8.,  8.,  0.,  0.,  9.,  9.,  0.,  0.,
        10., 10.,  0.,  0., 11., 11.,  0.,  0., 12., 12.,  0.,  0., 13., 13.,
         0.,  0., 14., 14.,  0.,  0., 15., 15.])


In [2]:
from torch import rand
import torch
a = rand(1,1,10)    
a.squeeze()

tensor([0.4812, 0.3411, 0.0650, 0.4459, 0.9007, 0.0556, 0.4408, 0.5423, 0.8323,
        0.2666])

In [25]:
b = tensor([1,2,3,4])
b.repeat(2,1).T.flatten()

tensor([1, 1, 2, 2, 3, 3, 4, 4])

In [28]:
tensor([[1,2,3], [4,5,6], [7,8,9]])[[0,2],]

tensor([[1, 2, 3],
        [7, 8, 9]])

In [15]:
from torch_geometric.utils import k_hop_subgraph

k_hop_subgraph(tree.tree_graph,2,3)

AttributeError: 'Data' object has no attribute 'k_hop_subgraph'

In [14]:
import torch
from torch import nn
from torch import Tensor
from torch.nn.functional import pad
from math import log2, ceil


#torch.random.manual_seed(0)  # Set a known random seed for reproducibility


class NewSOTT(nn.Module):
    '''
    
    '''

    def __init__(self, kernel_size: int, leaf_num: int, 
                 niter: int, stride: int = 2, padding: int = 0, 
                 alpha: float = None, sigma: float = None, device = torch.device("cpu")
                 ):
        super(NewSOTT, self).__init__()
        self.kernel_size = kernel_size
        self.leaf_num = leaf_num
        self.stride = stride
        self.padding = padding
        self.niter = niter
        #self.locations = self._neuron_locations()
        self.leaf_num = 2**ceil(log2(leaf_num))
        self.nodes = torch.nn.Parameter(data=Tensor((leafNum*2)-2, kernel_size**2), requires_grad=False)
        self.nodes.data.uniform_(0, 1)
        if alpha is None:
            self.alpha = 0.3
        else:
            self.alpha = float(alpha)
        if sigma is None:
            self.sigma = leaf_num / 2.0
        else:
            self.sigma = float(sigma)
        #self.w = torch.randint(low=0, high=256, size = (grid_size[0] * grid_size[1], kernel_size * kernel_size), 
        #                       dtype=torch.uint8, device=device, requires_grad=False)
        self.it = 1
        self.device = device

    def _2diImg2col(self, x):
        padded = pad(x, [self.padding] * 4, "constant", 0)
        p = padded.unfold(0, self.kernel_size, self.stride).unfold(1, self.kernel_size, self.stride)
        out = p.reshape(p.shape[0] * p.shape[1], p.shape[2] * p.shape[3])
        output_dim_x, output_dim_y = p.shape[0], p.shape[1]
        return out, output_dim_x, output_dim_y

    def _neuron_locations(self):
        a, b = torch.meshgrid(torch.arange(self.grid_size[0]), torch.arange(self.grid_size[1]))
        return torch.transpose(torch.LongTensor(torch.stack([a.flatten(), b.flatten()])), 0, 1)

    def _pnorm(self, x1, x2, p=2):
        return torch.pow(torch.pow(x1 - x2.unsqueeze(dim=1), p).sum(2), p)

    def get_bmu_indices(self, x):
        dist = self._pnorm(self.w, x)
        _, bmu_index = torch.min(dist, 1)
        return bmu_index

    def adjust_synapses(self, x):
        x, output_dim_x, output_dim_y = self._2diImg2col(x)
        dist = self._pnorm(self.w, x)
        _, bmu_index = torch.min(dist, 1)
        bmu_loc = self.locations[bmu_index, :]

        learning_rate_op = 1.0 - self.it / self.niter
        alpha_op = self.alpha * learning_rate_op
        sigma_op = self.sigma * learning_rate_op

        bmu_distance_squares = torch.sum(torch.pow(self.locations.expand(bmu_loc.shape[0], 
                                                                         self.locations.shape[0],
                                                                         self.locations.shape[1]) - bmu_loc.unsqueeze(dim=1), 2), 2)
        neighbourhood_func = torch.exp(torch.neg(torch.div(bmu_distance_squares, sigma_op ** 2)))
        learning_rate = alpha_op * neighbourhood_func
        self.w.data += (learning_rate.unsqueeze(2) * (x.unsqueeze(dim=1) - self.w.unsqueeze(dim=0))).mean(dim=0)
        self.it += 1
        return bmu_loc.reshape(2, output_dim_x, output_dim_y)

    def set_mode(self, m):
        self.mode = m

    def forward(self, x) -> Tensor:
        x.to(self.device)
        output = self.adjust_synapses(x)
        return output


In [15]:
device = torch.device("cpu")
epochs = 10
NewSOTT(kernel_size = 3, 
       leaf_num= 100, 
       niter = epochs,
       stride = 1,
       device = device
                 )

NameError: name 'grid_size' is not defined