In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from lib.utils import to_grayscale, load_model

%run NGLCM.ipynb

In [4]:
class HEX(nn.Module):
    """HEX classifier: AlexNet features combined with an NGLCM texture branch.

    Two feature extractors run side by side:

    * ``self.nglcm`` -- ``NGLCM`` (defined in NGLCM.ipynb) applied to the
      grayscale image, followed by dropout, ``Linear(colors -> dim)`` and ReLU.
    * ``self.cnn`` -- torchvision AlexNet whose classifier head is replaced so
      it ends at a 4096-d ReLU output instead of class logits.

    One shared linear layer ``self.classifier`` maps the concatenation
    ``[cnn | nglcm]`` to logits.  In training mode ``forward`` returns the
    joint logits with the part explained by the NGLCM-only logits projected
    out; in eval mode it returns logits computed from the CNN features alone
    (NGLCM features replaced by zeros).

    Args:
        dim: passed to ``NGLCM`` and used as the NGLCM-branch Linear width.
        num_classes: number of output logits.
        colors: passed to ``NGLCM``; input width of the NGLCM-branch Linear.
        alex_pretrained: forwarded to ``models.alexnet(pretrained=...)``.
        alex_params: optional state dict partially copied into ``self.cnn``
            (keys that ``self.cnn`` does not have are skipped).
    """

    # Added to the per-sample max before normalising, to avoid dividing by
    # zero when a feature vector is all zeros.  Note 10e-8 == 1e-7; the value
    # is kept exactly as originally written.
    _EPS = 10e-8

    def __init__(self, dim, num_classes=1000, colors=16, alex_pretrained=True, alex_params=None):
        super().__init__()

        self.nglcm = nn.Sequential(
            NGLCM(dim, colors),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=colors, out_features=dim, bias=True),
            nn.ReLU(inplace=True),
        )

        # NOTE(review): newer torchvision deprecates `pretrained=` in favour of
        # `weights=`; kept for compatibility with the installed version.
        self.cnn = models.alexnet(pretrained=alex_pretrained)
        # Replace the head with new layers ending in a 4096-d ReLU output (no
        # final logit layer).  9216 is presumably the flattened AlexNet feature
        # map (256 * 6 * 6) -- confirm for the installed torchvision.
        # NOTE(review): these Linear layers are freshly initialised, so any
        # pretrained FC weights are discarded unless `alex_params` supplies them.
        self.cnn.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(in_features=9216, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096, bias=True),
            nn.ReLU(inplace=True)
        )

        if alex_params is not None:
            self.load_my_state_dict(alex_params)

        # Input is [cnn (4096) | flattened nglcm].  `colors * dim` assumes the
        # NGLCM branch flattens to colors * dim features per sample (e.g. a
        # (batch, colors, dim) output) -- confirm against NGLCM.ipynb.
        self.classifier = nn.Linear(in_features=4096 + colors * dim, out_features=num_classes, bias=True)

    def load_my_state_dict(self, state_dict):
        """Copy matching entries of ``state_dict`` into ``self.cnn`` in place.

        Entries whose key is not present in ``self.cnn.state_dict()`` are
        silently skipped; every copied key is printed.  The copy runs under
        ``torch.no_grad()`` (as ``nn.Module.load_state_dict`` does) so that
        copying from tensors that require grad does not record autograd history
        on the model's parameters.

        Raises:
            RuntimeError: if a matching entry has an incompatible shape.
        """
        own_state = self.cnn.state_dict()
        with torch.no_grad():
            for name, param in state_dict.items():
                if name not in own_state:
                    continue
                print('loading param: ', name)
                own_state[name].copy_(param)

    @staticmethod
    def _project_out(FA, FG):
        """Project ``FA`` onto the orthogonal complement of ``FG``'s column space.

        Computes ``(I - FG (FG^T FG)^{-1} FG^T) FA`` with ``I`` of size
        batch x batch, built on FG's device and dtype.  ``FG^T FG`` is
        num_classes x num_classes with rank <= batch, so it is singular
        whenever batch < num_classes.  In that case ``inverse`` may raise or
        return unreliable values.
        """
        FGT = FG.t()
        eye = torch.eye(FG.shape[0], device=FG.device, dtype=FG.dtype)
        return (eye - FG.mm(FGT.mm(FG).inverse()).mm(FGT)).mm(FA)

    def forward(self, x):
        """Compute HEX logits for a batch of images.

        Args:
            x: image batch accepted by both ``to_grayscale`` and AlexNet
                (presumably (batch, channels, H, W) -- confirm).

        Returns:
            Training mode: joint logits with the NGLCM-only logits projected
            out (see ``_project_out``).  Eval mode: logits from the CNN
            features alone, with the NGLCM slot zeroed.
        """
        # Grayscale conversion is fixed preprocessing; no gradients needed.
        with torch.no_grad():
            x_gray = to_grayscale(x)

        # Call the modules (not `.forward`) so registered hooks still run.
        nglcm = self.nglcm(x_gray).reshape(x.shape[0], -1)
        cnn = self.cnn(x)

        # Both branches end in ReLU, so features are >= 0.  Scale each sample
        # by its own max so the two branches have comparable magnitudes.
        nglcm = nglcm / (nglcm.max(1, keepdim=True)[0] + self._EPS)
        cnn = cnn / (cnn.max(1, keepdim=True)[0] + self._EPS)

        if not self.training:
            # Inference uses only the CNN branch: zero out the NGLCM slot.
            return self.classifier(torch.cat((cnn, torch.zeros_like(nglcm)), 1))

        FA = self.classifier(torch.cat((cnn, nglcm), 1))                      # joint logits
        FG = self.classifier(torch.cat((torch.zeros_like(cnn), nglcm), 1))    # NGLCM-only logits
        return self._project_out(FA, FG)

In [5]:
def test():
    """Smoke-test HEX on a single PACS training batch.

    Builds an 8-class ``HEX`` model and pulls the first batch from the PACS
    train loader (``lib.datasets``).  It then prints the input shape and the
    output shape in training mode (projected logits) and in eval mode
    (CNN-only logits).  The model is returned to training mode afterwards,
    even if the eval pass fails.

    Raises:
        RuntimeError: if the PACS train loader yields no batches.
    """
    Hex = HEX(224, num_classes=8)
    from lib.datasets import datasets

    data = datasets()
    data.create_dataset('pacs', img_size=224, data_aug=True)
    batch_loader = data.batch_loader(64)

    # Only the first batch is needed: (images, labels), where images are
    # batch_size x channels x img_size x img_size and labels are batch_size.
    batch = next(iter(batch_loader['train']), None)
    if batch is None:
        raise RuntimeError("PACS train loader produced no batches")
    img, _ = batch

    print(img.shape)
    # Shape check only, so skip building autograd graphs.
    with torch.no_grad():
        print(Hex(img).shape)
        Hex.eval()
        try:
            print(Hex(img).shape)
        finally:
            Hex.train()