In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from operations import OPS

In [3]:
# MNIST train/test splits as float tensors in [0, 1].
# Both splits need the same transform: without one, the test split yields PIL
# images, which torch's default DataLoader collate function cannot batch.
to_tensor = transforms.ToTensor()
train = datasets.MNIST('./data', train=True, transform=to_tensor, download=True)
valid = datasets.MNIST('./data', train=False, transform=to_tensor, download=True)

train_loader = DataLoader(train, batch_size=10)

In [4]:
# Take the first batch from `train_loader`: x holds the images, y the targets.
x, y = next(iter(train_loader))

In [5]:
# Batch shape, laid out as (batch, channels, height, width).
x.size()

torch.Size([10, 1, 28, 28])

In [24]:
# Instantiate two candidate ops from the `OPS` table imported from `operations`.
# NOTE(review): the repr printed in the next cell (Conv2d(1, 1, ..., stride=(100, 100)))
# suggests these factories take (C, stride, affine). If so, the second argument
# here is a *stride*, not an output width. Confirm the intent against `operations`.
mlp, conv = OPS['mlp'](1, 100, False), OPS['conv'](1, 32, True)

In [7]:
# Display the `conv` op built in the preceding cell.
# NOTE(review): the repr printed below (Conv2d(1, 1, stride=(100, 100)) followed by a
# BatchNorm1d) does not match `OPS['conv'](1, 32, True)`, and this cell's execution
# count (7) is lower than that cell's (24). The output is stale. Re-run the notebook
# top to bottom before trusting it. A BatchNorm1d after a Conv2d also looks suspect.
conv

ConvBlock(
  (op): Sequential(
    (0): Conv2d(1, 1, kernel_size=(3, 3), stride=(100, 100), padding=(1, 1), bias=False)
    (1): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [8]:
# NOTE(review): `RandomData` is neither imported nor defined anywhere in this
# notebook, so this cell raises NameError on a fresh kernel (Restart & Run All).
# Import it from wherever it lives in the project, or delete the cell: `dataset`
# is not used by any later cell shown here.
dataset = RandomData(100, 2, {'task': 2})

In [9]:
# Re-check the batch shape: (batch, channels, height, width).
x.size()

torch.Size([10, 1, 28, 28])

In [10]:
# Rebuild the training loader with per-channel normalization and shuffling.
# Use the same './data' root as the first loading cell, so MNIST is downloaded
# and stored only once instead of in two different directories.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)  # standard MNIST training-set mean / std
train_loader = DataLoader(
    datasets.MNIST(
        './data', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(MNIST_MEAN, MNIST_STD),
        ]),
    ),
    batch_size=10,
    shuffle=True,
)

In [11]:
# Take one (shuffled, normalized) batch: x holds the images, y the targets.
x, y = next(iter(train_loader))

In [12]:
# Normalization changes values, not shape: still (batch, channels, height, width).
x.size()

torch.Size([10, 1, 28, 28])

In [57]:
class StemNet(nn.Module):
    """Network stem: a single Conv2d applied before the searchable modules.

    This will always be the beginning of the network. DARTS only recomposes
    the modules that come after the stem, so the stem is defined separately
    from the other modules in the network.

    Args:
        in_channels: number of channels in the input images (1 for MNIST).
        cell_dim: number of output channels, i.e. the intermediate channel
            width used by the remaining modules of the network.
        kernel_size: size of the square convolution kernel. No padding is
            applied and the stride is 1, so each spatial dimension shrinks by
            ``kernel_size - 1`` (28x28 -> 26x26 with the default of 3).
    """

    def __init__(self, in_channels: int = 1, cell_dim: int = 100, kernel_size: int = 3):
        super().__init__()
        self.stem = nn.Conv2d(in_channels, cell_dim, kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stem convolution to a (N, in_channels, H, W) batch."""
        return self.stem(x)


In [58]:
class ConvBlock(nn.Module):
    """ReLU -> Conv2d -> BatchNorm2d.

    Uses the same pre-activation ordering as ``DilConv``. The convolution is
    bias-free because the following BatchNorm re-centres every channel.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        affine: whether the BatchNorm has a learnable scale and shift.
        padding: zero-padding for the convolution. Defaults to 0, which
            shrinks H and W by ``kernel_size - 1`` at stride 1. For odd
            kernel sizes at stride 1, ``kernel_size // 2`` preserves them.
    """

    def __init__(self, c_in, c_out, kernel_size, stride, affine=True, padding=0):
        super().__init__()
        self.relu = nn.ReLU(inplace=False)
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=False)
        # `affine` was previously accepted but never used; it now reaches the BatchNorm.
        self.bn = nn.BatchNorm2d(c_out, affine=affine)

    def forward(self, x):
        return self.bn(self.conv(self.relu(x)))

In [100]:
class DilConv(nn.Module):
    """Dilated, depthwise-separable convolution with pre-activation.

    Pipeline, stored in order as ``self.op``:
        ReLU -> depthwise dilated Conv2d (groups=c_in, c_in -> c_in)
             -> pointwise 1x1 Conv2d (c_in -> c_out)
             -> BatchNorm2d(c_out)

    Both convolutions are bias-free. ``affine`` controls whether the
    BatchNorm has a learnable scale and shift.
    """

    def __init__(self, c_in, c_out, kernel_size,
                 stride, padding, dilation, affine=True):
        super().__init__()

        # Spatial filtering: one dilated kernel per input channel.
        depthwise = nn.Conv2d(c_in, c_in, kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, groups=c_in, bias=False)
        # Channel mixing: 1x1 projection from c_in to c_out.
        pointwise = nn.Conv2d(c_in, c_out, 1, padding=0, bias=False)

        layers = [nn.ReLU(inplace=False), depthwise, pointwise,
                  nn.BatchNorm2d(c_out, affine=affine)]
        self.op = nn.Sequential(*layers)

    def forward(self, x):
        result = self.op(x)
        return result

In [101]:
# 1-channel input -> 100-channel feature maps via a 3x3 stem (the StemNet defaults).
stem = StemNet(in_channels=1, cell_dim=100, kernel_size=3)

In [102]:
# Local candidate-op table. Each factory takes (c, stride, affine) and builds a
# module with c input and c output channels.
# NOTE(review): this rebinds the `OPS` imported from `operations` in the first cell.
# NOTE(review): 'conv' uses no padding, so it shrinks H/W by 2, while 'dil_conv'
# preserves them (24x24 vs 26x26 in the shape checks at the end). The outputs of the
# two ops cannot be summed as DARTS mixed ops require.
OPS = dict(
    dil_conv=lambda c, stride, affine: DilConv(
        c, c, kernel_size=3, stride=stride, padding=2, dilation=2, affine=affine),
    conv=lambda c, stride, affine: ConvBlock(
        c, c, kernel_size=3, stride=stride, affine=affine),
)

In [103]:
# Plain conv op: 100 channels, stride 1, affine=True.
conv = OPS['conv'](c=100, stride=1, affine=True)

In [104]:
# Dilated separable conv op: 100 channels, stride 1, affine BatchNorm.
dill = OPS['dil_conv'](c=100, stride=1, affine=True)

In [105]:
# Run the stem on the MNIST batch `x`. With the 3x3 kernel and no padding, the
# spatial size shrinks from 28x28 to 26x26 (as the printed torch.Size shows).
out = stem(x)

In [106]:
# Stem output shape: (batch, cell_dim, height, width).
out.size()

torch.Size([10, 100, 26, 26])

In [107]:
# Shape after the plain conv op (no padding, so H/W shrink by 2).
conv(out).size()

torch.Size([10, 100, 24, 24])

In [108]:
# Shape after the dilated conv op (padding=2 with dilation=2 keeps H/W unchanged).
dill(out).size()

torch.Size([10, 100, 26, 26])