In [5]:
import torch
import torch.nn as nn
import torchvision.models as models
from lib.utils import to_grayscale

%run NGLCM.ipynb

In [6]:
class HEX(nn.Module):
    """AlexNet features + NGLCM texture features with the texture-only logits
    projected out of the joint logits.

    This appears to implement HEX ("Learning Robust Representations by
    Projecting Superficial Statistics Out", Wang et al., ICLR 2019) — TODO confirm.

    Args:
        dim: passed to ``NGLCM`` and used as the output width of the Linear
            layer that follows it. Presumably the input image side length (the
            test cell passes 224) — TODO confirm against NGLCM.
        num_classes: number of output logits.
        colors: passed to ``NGLCM`` and used as the input width of the Linear
            layer that follows it (presumably the number of gray levels).
        alex_pretrained: forwarded to ``torchvision.models.alexnet(pretrained=...)``.
    """

    # Added to each per-sample max before dividing, to avoid division by zero.
    # NOTE(review): 10e-8 == 1e-7; 1e-8 may have been intended. Value kept as-is.
    _EPS = 10e-8

    def __init__(self, dim, num_classes=1000, colors=16, alex_pretrained=True):
        super().__init__()

        # Texture branch. NOTE(review): ``self.classifier`` below assumes this
        # branch's flattened output has ``dim * colors`` features; that depends
        # on NGLCM's output shape, which is not visible here — TODO confirm.
        self.nglcm = nn.Sequential(
            NGLCM(dim, colors),
            nn.Linear(in_features=colors, out_features=dim, bias=True),
            nn.ReLU(inplace=True)
        )

        # AlexNet backbone whose classifier is replaced so that it ends at a
        # 4096-wide ReLU instead of a class-score layer. 9216 is the flattened
        # feature size torchvision's AlexNet feeds into its classifier.
        self.cnn = models.alexnet(pretrained=alex_pretrained)
        self.cnn.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=9216, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096, bias=True),
            nn.ReLU(inplace=True)
        )

        # Shared linear head applied to [cnn features, texture features].
        self.classifier = nn.Linear(in_features=4096 + dim * colors, out_features=num_classes, bias=True)

    @classmethod
    def _max_normalize(cls, features):
        """Return ``features`` divided by its max along dim 1 (plus ``_EPS``).

        Out-of-place on purpose: the inputs come straight out of
        ``ReLU(inplace=True)``, whose output autograd keeps for the backward
        pass, so an in-place division would make ``backward()`` fail.
        """
        return features / (features.max(1, keepdim=True)[0] + cls._EPS)

    @staticmethod
    def _project_out(FA, FG):
        """Return ``(I - FG (FG^T FG)^{-1} FG^T) FA``.

        Removes from ``FA`` the component lying in the column space of ``FG``.
        Both are (batch, num_classes). ``FG^T FG`` must be invertible, which
        requires batch >= num_classes and ``FG`` to have full column rank.
        """
        FGT = FG.t()
        # Build the identity on FG's device and dtype so this also works on GPU
        # and for non-float32 inputs.
        identity = torch.eye(FG.shape[0], device=FG.device, dtype=FG.dtype)
        return (identity - FG.mm(FGT.mm(FG).inverse()).mm(FGT)).mm(FA)

    def forward(self, x):
        """Compute HEX logits for a batch of images.

        Args:
            x: image batch accepted by both AlexNet and ``to_grayscale``;
                presumably (batch, 3, H, W) — TODO confirm.

        Returns:
            Tensor of shape (batch, num_classes): the joint logits ``FA`` with
            the column space of the texture-only logits ``FG`` projected out.
        """
        with torch.no_grad():
            # Grayscale conversion is fixed preprocessing; no gradient needed.
            x_gray = to_grayscale(x)

        # Call the modules (not ``.forward``) so registered hooks still run.
        nglcm = self.nglcm(x_gray).view(x.shape[0], -1)
        cnn = self.cnn(x)

        # Each branch is scaled by its OWN per-sample max.
        nglcm = self._max_normalize(nglcm)
        cnn = self._max_normalize(cnn)

        FA = self.classifier(torch.cat((cnn, nglcm), 1))
        # Texture-only logits: same head, with the CNN features zeroed out
        # (zeros_like keeps the device and dtype of ``cnn``).
        FG = self.classifier(torch.cat((torch.zeros_like(cnn), nglcm), 1))

        return self._project_out(FA, FG)

In [7]:
def test():
    """Smoke test: build a HEX model, then print the shape of one training
    batch and the shape of the model's output on that batch.

    Needs the project-local ``lib.datasets`` package and the 'pacs' dataset
    it creates to be available.
    """
    # Project-local import kept inside the function so that defining this cell
    # does not require the dataset package.
    from lib.datasets import datasets

    hex_model = HEX(224, num_classes=8)

    data = datasets()
    data.create_dataset('pacs', img_size=224, data_aug=True)
    batch_loader = data.batch_loader(64)

    # Only the first training batch is needed; its labels are unused here.
    # img - batch_size x channels x img_size x img_size - data
    try:
        img, _labels = next(iter(batch_loader['train']))
    except StopIteration:
        raise RuntimeError("the 'train' loader yielded no batches") from None

    print(img.shape)
    # Shape check only: don't build an autograd graph.
    with torch.no_grad():
        print(hex_model(img).shape)