In [3]:
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np

In [4]:
class NGLCM(nn.Module):
    """Neural grey-level co-occurrence matrix (NGLCM) layer.

    Each sample is flattened and scaled by ``colors - 1``. Every scaled value,
    and its difference to the preceding value in flattened order, is compared
    against the learnable thresholds ``a`` and ``b`` respectively. The clamped
    (soft) comparisons are combined with a batched matrix product, giving one
    ``colors x colors`` matrix per sample.
    """

    def __init__(self, dim, colors=16):
        """NGLCM layer implementation

        Parameters:
            dim : int
                Dimensions of image (dim * dim). Currently not used by the
                layer -- ``forward`` flattens whatever it receives -- but kept
                for API compatibility.
            colors : int
                Size of colors space: the number of thresholds in each of
                ``a`` and ``b`` and the side length of the output matrix.
        """
        super(NGLCM, self).__init__()

        # Scale factor applied to the input in ``forward``.
        # NOTE(review): presumably maps inputs in [0, 1] onto [0, colors - 1];
        # confirm the expected input range against the data pipeline.
        self.colors = colors - 1

        # Learnable thresholds for the scaled values, shape (colors, 1).
        self.a = nn.Parameter(torch.zeros([colors, 1]))
        nn.init.xavier_uniform_(self.a)

        # Learnable thresholds for the neighbour differences, shape (colors, 1).
        self.b = nn.Parameter(torch.zeros([colors, 1]))
        nn.init.xavier_uniform_(self.b)

    def forward(self, x):
        """Compute one NGLCM matrix per sample.

        Parameters:
            x : torch.Tensor
                Batch whose first dimension is the batch size; all remaining
                dimensions are flattened together. Need not be contiguous.

        Returns:
            torch.Tensor of shape (batch, colors, colors).
        """
        # (B, 1, N): one row of scaled values per sample. ``reshape`` (unlike
        # ``view``) also accepts non-contiguous inputs, e.g. transposed images.
        a = x.reshape(x.shape[0], 1, -1) * self.colors
        # Difference to the previous element in flattened order (the first
        # element is compared with 0). Note that "previous" wraps across row
        # and channel boundaries. ``pad`` keeps a's dtype and device.
        b = a - nn.functional.pad(a[:, :, :-1], (1, 0))
        # (B, 1, N) - (colors, 1) broadcasts to (B, colors, N); clamping turns
        # each comparison into a soft indicator in [0, 1]. The product
        # (B, colors, N) @ (B, N, colors) yields (B, colors, colors).
        return (torch.clamp(a - self.a, 0, 1)
                .matmul(torch.clamp(b - self.b, 0, 1).transpose(1, 2)))

In [2]:
def test(size=224, batch_size=10):
    """Smoke-test the NGLCM layer on one training batch of MNIST.

    Builds MNIST through the project-local ``lib.datasets`` helper, takes the
    first batch of the ``'train'`` loader, and prints the batch shape and the
    shape of the NGLCM output for it.

    Parameters:
        size : int
            Image side length requested from the dataset (224 is the default
            input size for AlexNet). Also passed as the layer's ``dim``.
        batch_size : int
            Number of images per batch.

    Raises:
        RuntimeError: if the training loader yields no batches.
    """
    # Project-local helper, imported lazily so the rest of the notebook does
    # not depend on it.
    from lib import datasets

    # Keep the layer's `dim` in sync with the requested image size instead of
    # repeating the literal.
    nglcm = NGLCM(dim=size)

    data = datasets.datasets()
    data.create_dataset('mnist', img_size=size, data_aug=True)
    batch_loader = data.batch_loader(batch_size)

    # Only the first batch is needed; the labels are not used here.
    # img - batch_size x channels x img_size x img_size - data
    try:
        img, _label = next(iter(batch_loader['train']))
    except StopIteration:
        raise RuntimeError("training loader produced no batches") from None

    print(img.shape)
    print(nglcm(img).shape)