In [14]:
import torch
from torch import nn
from torch import Tensor
from torch.nn.functional import pad
from math import log2, ceil


#torch.random.manual_seed(0)  # Set a known random seed for reproducibility


class NewSOTT(nn.Module):
    '''
    
    '''

    def __init__(self, kernel_size: int, leaf_num: int, 
                 niter: int, stride: int = 2, padding: int = 0, 
                 alpha: float = None, sigma: float = None, device = torch.device("cpu")
                 ):
        super(NewSOTT, self).__init__()
        self.kernel_size = kernel_size
        self.leaf_num = leaf_num
        self.stride = stride
        self.padding = padding
        self.niter = niter
        #self.locations = self._neuron_locations()
        self.leaf_num = 2**ceil(log2(leaf_num))
        self.nodes = torch.nn.Parameter(data=Tensor((leafNum*2)-2, kernel_size**2), requires_grad=False)
        self.nodes.data.uniform_(0, 1)
        if alpha is None:
            self.alpha = 0.3
        else:
            self.alpha = float(alpha)
        if sigma is None:
            self.sigma = leaf_num / 2.0
        else:
            self.sigma = float(sigma)
        #self.w = torch.randint(low=0, high=256, size = (grid_size[0] * grid_size[1], kernel_size * kernel_size), 
        #                       dtype=torch.uint8, device=device, requires_grad=False)
        self.it = 1
        self.device = device

    def _2diImg2col(self, x):
        padded = pad(x, [self.padding] * 4, "constant", 0)
        p = padded.unfold(0, self.kernel_size, self.stride).unfold(1, self.kernel_size, self.stride)
        out = p.reshape(p.shape[0] * p.shape[1], p.shape[2] * p.shape[3])
        output_dim_x, output_dim_y = p.shape[0], p.shape[1]
        return out, output_dim_x, output_dim_y

    def _neuron_locations(self):
        a, b = torch.meshgrid(torch.arange(self.grid_size[0]), torch.arange(self.grid_size[1]))
        return torch.transpose(torch.LongTensor(torch.stack([a.flatten(), b.flatten()])), 0, 1)

    def _pnorm(self, x1, x2, p=2):
        return torch.pow(torch.pow(x1 - x2.unsqueeze(dim=1), p).sum(2), p)

    def get_bmu_indices(self, x):
        dist = self._pnorm(self.w, x)
        _, bmu_index = torch.min(dist, 1)
        return bmu_index

    def adjust_synapses(self, x):
        x, output_dim_x, output_dim_y = self._2diImg2col(x)
        dist = self._pnorm(self.w, x)
        _, bmu_index = torch.min(dist, 1)
        bmu_loc = self.locations[bmu_index, :]

        learning_rate_op = 1.0 - self.it / self.niter
        alpha_op = self.alpha * learning_rate_op
        sigma_op = self.sigma * learning_rate_op

        bmu_distance_squares = torch.sum(torch.pow(self.locations.expand(bmu_loc.shape[0], 
                                                                         self.locations.shape[0],
                                                                         self.locations.shape[1]) - bmu_loc.unsqueeze(dim=1), 2), 2)
        neighbourhood_func = torch.exp(torch.neg(torch.div(bmu_distance_squares, sigma_op ** 2)))
        learning_rate = alpha_op * neighbourhood_func
        self.w.data += (learning_rate.unsqueeze(2) * (x.unsqueeze(dim=1) - self.w.unsqueeze(dim=0))).mean(dim=0)
        self.it += 1
        return bmu_loc.reshape(2, output_dim_x, output_dim_y)

    def set_mode(self, m):
        self.mode = m

    def forward(self, x) -> Tensor:
        x.to(self.device)
        output = self.adjust_synapses(x)
        return output


In [15]:
device = torch.device("cpu")
epochs = 10
NewSOTT(kernel_size = 3, 
       leaf_num= 100, 
       niter = epochs,
       stride = 1,
       device = device
                 )

NameError: name 'grid_size' is not defined