In [1]:
import sys
sys.path.append("..")
from utils.dataset import FerDataset

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


import matplotlib.pyplot as plt
%matplotlib inline

In [41]:
class GatingUnit(nn.Module):
    """Gating unit described in 'Convolutional Networks with Adaptive Inference Graphs'."""
    
    def __init__(self, in_channels, hidden_layer_dim):
        super(GatingUnit, self).__init__()
        
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.estimate_relevance = nn.Sequential(
            nn.Linear(in_channels, hidden_layer_dim),
            nn.ReLU(),
            nn.Linear(hidden_layer_dim, 2))
        
    def forward(self, x):
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1) 
        x = self.estimate_relevance(x)
        x = F.gumbel_softmax(x, tau=1, hard=True)
        return x
    
    
class AdaptiveConv2d(nn.Module):
    """Adaptive Conv2d layer described in 'Convolutional Networks with Adaptive Inference Graphs'."""
    
    def __init__(self,in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(GatingUnit, self).__init__()
        
        self.gate = GatingUnit()
        self.weights = torch
        
    def forward(self, x):
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1) 
        x = self.estimate_relevance(x)
        x = F.gumbel_softmax(x, tau=1, hard=True)
        return x

In [58]:
x = torch.rand(12, 3, 5, 5)
gate = GatingUnit(in_channels=3, hidden_layer_dim=16)
cond = gate(x)
batch_slice = torch.where(cond == )

In [59]:
batch_slice[]

tensor([[ 0.,  1.],
        [ 1.,  0.],
        [ 0.,  1.],
        [ 0.,  1.],
        [ 1.,  0.],
        [ 0.,  1.],
        [ 0.,  1.],
        [ 1.,  0.],
        [ 0.,  1.],
        [ 1.,  0.],
        [ 0.,  1.],
        [ 0.,  1.]])

In [29]:
x = torch.rand(2, 3, 5, 5)

In [32]:
x = torch.rand(2, 3, 5, 5)
[]
out = x[0, 1].pow(2).sum()
out.backward()
x.grad

tensor([[ 0., -2.],
        [ 0.,  0.]])

In [61]:
torch.where?

In [None]:
c = np.arange(0, 16).reshape(4, 4)
x = np.array([c, c, c])[np.newaxis, ...]
x = torch.from_numpy(x).float()
y = F.adaptive_avg_pool2d(x, 1)
print(y.shape)

In [39]:
batch_slice

array([0, 0, 0, 1, 1])