In [38]:
import math
import sys
sys.path.append("..")
from utils.dataset import FerDataset

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch.nn.parameter import Parameter


import matplotlib.pyplot as plt
%matplotlib inline

In [12]:
class GatingUnit(nn.Module):
    """Gating unit described in 'Convolutional Networks with Adaptive Inference Graphs'."""
    
    def __init__(self, in_channels, gate_dim):
        super(GatingUnit, self).__init__()
        
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.estimate_relevance = nn.Sequential(
            nn.Linear(in_channels, gate_dim),
            nn.ReLU(),
            nn.Linear(gate_dim, 2))
        
    def forward(self, x):
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1) 
        x = self.estimate_relevance(x)
        print(x)
        x = F.gumbel_softmax(x, tau=1, hard=True)
        return x
    
    
class AdaptiveConv2d(nn.Module):
    """Adaptive Conv2d layer described in 'Convolutional Networks with Adaptive Inference Graphs'."""
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, gate_dim=16):
        super(AdaptiveConv2d, self).__init__()

        self.gate = GatingUnit(in_channels, gate_dim)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, dilation, groups, bias)
        
    def forward(self, x):
        decision = self.gate(x)
        conv_index = decision[:, 1].nonzero().view(-1)
        conv_in = x.index_select(0, conv_index)
        conv_out = self.conv(x)
        
        if conv_in.size() != conv_out.size():
            # Do Math
        else:
            # Also do math
        return x

In [35]:
x = torch.rand(12, 3, 5, 5)
gate = GatingUnit(in_channels=3, hidden_layer_dim=16)
cond = gate(x)
conv_index = cond[:, 1].nonzero().view(-1)
conv_input = x.index_select(0, conv_index)
conv_weights = weights.

tensor([[ 0.0290, -0.3717],
        [ 0.0129, -0.3197],
        [ 0.0089, -0.3038],
        [ 0.0343, -0.3490],
        [ 0.0230, -0.3272],
        [ 0.0037, -0.3145],
        [ 0.0210, -0.2988],
        [ 0.0235, -0.3167],
        [ 0.0168, -0.3138],
        [ 0.0071, -0.3151],
        [-0.0012, -0.3231],
        [ 0.0021, -0.3088]], grad_fn=<ThAddmmBackward>)
tensor([[0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.]], grad_fn=<ThAddBackward>)


tensor([[[[0.5495, 0.2985, 0.1253, 0.8119, 0.4992],
          [0.9188, 0.4821, 0.6444, 0.8395, 0.3465],
          [0.8262, 0.9112, 0.6921, 0.8597, 0.3539],
          [0.7665, 0.6900, 0.5419, 0.8898, 0.8607],
          [0.9310, 0.6509, 0.8718, 0.6782, 0.8211]],

         [[0.8046, 0.9768, 0.0070, 0.3013, 0.0972],
          [0.9024, 0.3047, 0.8814, 0.1755, 0.0179],
          [0.7028, 0.7762, 0.0333, 0.6288, 0.3000],
          [0.4325, 0.9721, 0.8618, 0.5498, 0.3575],
          [0.0274, 0.1275, 0.6650, 0.7355, 0.1149]],

         [[0.7766, 0.6726, 0.3634, 0.6606, 0.2743],
          [0.8063, 0.4997, 0.9470, 0.7933, 0.0951],
          [0.7158, 0.8008, 0.3565, 0.0465, 0.0358],
          [0.9518, 0.8723, 0.0257, 0.7804, 0.4957],
          [0.6515, 0.2795, 0.9468, 0.9447, 0.2961]]],


        [[[0.4921, 0.6068, 0.5353, 0.9972, 0.2090],
          [0.2539, 0.7632, 0.5563, 0.7647, 0.6703],
          [0.7683, 0.8440, 0.4695, 0.5384, 0.6641],
          [0.4563, 0.9586, 0.3373, 0.8670, 0.9790],
    

In [46]:
x = torch.rand(5)
print(x)
y = x.index_select(0, torch.LongTensor([0, 1]))
print(y)
y = y**2


tensor([0.0100, 0.7428, 0.1073, 0.0196, 0.1097])
tensor([0.0100, 0.7428])
tensor([0.0100, 0.7428, 0.1073, 0.0196, 0.1097])
tensor([0.0001, 0.5517])


In [25]:
cond.nonzero?

In [32]:
x = torch.rand(2, 3, 5, 5)
[]
out = x[0, 1].pow(2).sum()
out.backward()
x.grad

tensor([[ 0., -2.],
        [ 0.,  0.]])

In [61]:
torch.where?

In [None]:
c = np.arange(0, 16).reshape(4, 4)
x = np.array([c, c, c])[np.newaxis, ...]
x = torch.from_numpy(x).float()
y = F.adaptive_avg_pool2d(x, 1)
print(y.shape)

In [39]:
batch_slice

array([0, 0, 0, 1, 1])

In [48]:
torch.index_put_?