# Utilities

In [None]:
import os
from enum import Enum

import torch
import torchvision


def mk_dir(export_dir):
    """Create ``export_dir`` (and any missing parents) unless it already exists.

    A short message is printed saying whether the directory was created or
    was already present.

    Args:
        export_dir: Path of the directory to create.

    Raises:
        OSError: If the directory cannot be created (for example because of
            missing permissions).
    """
    if os.path.exists(export_dir):
        print('dir already exists: ', export_dir)
        return

    # exist_ok=True covers the race where another process creates the
    # directory between the existence check above and this call; any other
    # failure is a real error and is allowed to propagate.
    os.makedirs(export_dir, exist_ok=True)
    print('created dir: ', export_dir)

In [None]:
class Dataset(Enum):
    """Identifiers for the image datasets that the loader helpers can provide."""
    MNIST = 0
    KMNIST = 1
    FASHION_MNIST = 2
    EMNIST = 3 # <-- Default URL for this dataset is currently offline (https://github.com/pytorch/vision/issues/1296).

In [None]:
def get_dataset_loaders(dataset=Dataset.MNIST,
                        train_batch=64,
                        test_batch=1000,
                        get_validation=False,
                        dir='./mnist_data/',
                        unroll_img=True,
                        max_value=1,
                        **kwargs):
    '''
    Generate the DataLoaders for various datasets.

    The data is downloaded into ``dir`` if it is not already there.  Every
    loader shuffles its data.

    Args:
        dataset: The target dataset (a Dataset member).
        train_batch: The training batch size.
        test_batch: The batch size of the test (and validation) loader.
        get_validation: Whether to randomly split the held-out (train=False) data
                        into a test part and a validation part, for an unbiased
                        test of the trained network.
        dir: Dataset directory.
        unroll_img: Whether the images should be unrolled into a vector (for MLP) or not (for Conv Net's).
        max_value: Factor the ToTensor output is multiplied by (skipped when it is 1).
        **kwargs: Other arguments passed to torch.utils.data.DataLoader constructor.

    Returns: A list of the training data loader, the test data loader and, optionally, the validation data loader.
    '''

    class ReshapeTransform:
        """Keep the first dimension of a tensor and reshape the rest to ``new_size``."""

        def __init__(self, new_size):
            self.new_size = new_size

        def __call__(self, img):
            return img.view(img.size(0), *self.new_size)

    class RescaleTransform:
        """Multiply a tensor by ``new_max``."""

        def __init__(self, new_max):
            self.new_max = new_max

        def __call__(self, ft):
            return ft * self.new_max

    dataset_factory = __find_dataset(dataset)

    steps = [torchvision.transforms.ToTensor()]
    if unroll_img:
        steps.append(ReshapeTransform((-1,)))  # Flatten each image plane to a vector.
    if max_value != 1:
        steps.append(RescaleTransform(max_value))
    transform = torchvision.transforms.Compose(steps)

    train_loader = torch.utils.data.DataLoader(
        dataset_factory(dir, train=True, download=True, transform=transform),
        batch_size=train_batch, shuffle=True, **kwargs)

    held_out = dataset_factory(dir, train=False, download=True, transform=transform)

    if not get_validation:
        test_loader = torch.utils.data.DataLoader(
            held_out, batch_size=test_batch, shuffle=True, **kwargs)
        return [train_loader, test_loader]

    # Give the left-over sample of an odd-sized set to the validation part so
    # that the two lengths always add up to len(held_out), which random_split
    # requires.
    n_test = len(held_out) // 2
    n_validation = len(held_out) - n_test
    test_dataset, validation_dataset = torch.utils.data.random_split(
        held_out, [n_test, n_validation])

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=test_batch, shuffle=True, **kwargs)
    validation_loader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=test_batch, shuffle=True, **kwargs)

    return [train_loader, test_loader, validation_loader]

#classifier

In [None]:
def __find_dataset(dataset="mnist"):
    """Resolve a dataset identifier to a torchvision dataset constructor.

    Args:
        dataset: A Dataset member, or the (case-insensitive) name of one,
            e.g. "mnist" or "fashion_mnist".

    Returns:
        A callable that builds the dataset when called with the usual
        torchvision arguments (root, train, download, transform, ...).  For
        EMNIST the callable always requests the "balanced" split.

    Raises:
        ValueError: If ``dataset`` does not identify a known dataset.
    """
    if isinstance(dataset, str):
        # Map a member name to the member itself; unknown names are left
        # untouched and rejected below.
        dataset = Dataset.__members__.get(dataset.strip().upper(), dataset)

    if dataset == Dataset.MNIST:
        return torchvision.datasets.MNIST
    if dataset == Dataset.FASHION_MNIST:
        return torchvision.datasets.FashionMNIST
    if dataset == Dataset.KMNIST:
        return torchvision.datasets.KMNIST
    if dataset == Dataset.EMNIST:
        # Manually overwrite URL for this dataset (https://github.com/pytorch/vision/issues/1296).
        # NOTE: this changes the class attribute, i.e. it affects every later
        # use of torchvision.datasets.EMNIST in the process.
        torchvision.datasets.EMNIST.url = 'https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download'
        return lambda *args, **kwargs: torchvision.datasets.EMNIST(*args, split="balanced", **kwargs)

    raise ValueError("{} is not a recognised dataset.  Acceptable values are {}".format(
        dataset,
        [d for d in Dataset]
    ))

In [None]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.optim as optim
import pickle
import numpy as np
from enum import Enum

try:
    import seaborn as sns
except ImportError:
    # seaborn is optional: without it the current matplotlib style is kept.
    pass
else:
    # Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*"
    # and later removed the old names, so use whichever name this
    # installation actually provides (new name first to avoid the
    # deprecation warning).
    for _style_name in ('seaborn-v0_8-paper', 'seaborn-paper'):
        if _style_name in plt.style.available:
            plt.style.use(_style_name)
            break

In [None]:
class weight_norm(Enum):
    """Weight normalisation schemes the Classifier can apply to linear layers."""
    NONE = 0
    L1 = 1
    L2 = 2


class Classifier():
    '''Trains a network on image classification tasks.

    Args:
        network: The network to be trained.
        train_loader: A pytorch DataLoader for the training data.
        test_loader: A pytorch DataLoader for the test data.
        learning_rate: The learning rate for the network training.
        optimizer: "adam" or "sgd" for ADAM or SGD optimizer, or a callable
                   that is invoked as optimizer(parameters, lr=learning_rate).
        loss: "nll" or "mse" for negative log-likelihood or mean-squared error
              loss (against one-hot targets), or a callable
              loss(output, target, reduction=...).
        weight_range: A tuple of limits (lower, upper) to clip the network
                      parameters to after every optimiser step.  None, or a
                      None entry, denotes no limit.
        weight_normalisation: A weight_norm enum for normalising the weights during training.
        init_weight_mean: Initial mean of the weights.
        init_weight_std: Initial standard deviation of linear layer weights.
        init_conv_weight_std: Initial standard deviation of convolutional layer weights.
        n_epochs: How many epochs to train for.
        n_test_per_epoch: How many times to test the network performance during a single epoch.
        log_interval: How often (in batches) to log the loss during training.
        save_path: Where to save the network and training information.

    Raises:
        NotImplementedError: If the optimizer is not recognised.
        ValueError: If the loss is not recognised or weight_range does not
                    have exactly two entries.
    '''

    def __init__(self,
                 network,
                 train_loader,
                 test_loader,

                 learning_rate=0.001,
                 optimizer="adam",
                 loss = "nll",
                 weight_range=(0,1), # None denotes no limit.
                 weight_normalisation=weight_norm.NONE,

                 init_weight_mean=0.0,
                 init_weight_std=0.01,

                 init_conv_weight_std=0.01,

                 n_epochs=10,
                 n_test_per_epoch=0,
                 log_interval=25,

                 save_path="classifier"
                 ):

        self.network = network
        self.train_loader = train_loader
        self.test_loader = test_loader

        self.learning_rate = learning_rate
        self.optimizer = self._build_optimizer(optimizer)
        self.loss = self._build_loss(loss)

        self.weight_range = weight_range
        # Keyword arguments for Tensor.clamp_; stays empty when no limit is set.
        self.clamp_weight_args = {}
        if self.weight_range is not None:
            if len(self.weight_range) != 2:
                raise ValueError("weight range must be of length 2.")
            lower, upper = self.weight_range
            if lower is not None:
                self.clamp_weight_args['min'] = lower
            if upper is not None:
                self.clamp_weight_args['max'] = upper

        if (init_weight_mean is not None) and (init_weight_std is not None):
            def init_weights(m):
                # Exact type checks on purpose: subclasses keep their own initialisation.
                if type(m) is torch.nn.Linear:
                    print("Setting weights for", m)
                    m.weight.normal_(init_weight_mean, init_weight_std)
                    # Flip the sign of each weight at random (factor is -1 or +1).
                    m.weight *= (2*torch.randint_like(m.weight,0,2)-1)
                    if self._clamping_enabled():
                        m.weight.clamp_(**self.clamp_weight_args)
                elif type(m) is torch.nn.Conv2d:
                    print("Setting weights for", m)
                    m.weight.normal_(init_weight_mean, init_conv_weight_std)
                    if self._clamping_enabled():
                        m.weight.clamp_(**self.clamp_weight_args)
            with torch.no_grad():
                self.network.apply(init_weights)

        self.weight_normalisation = weight_normalisation

        self.n_epochs = n_epochs
        self.n_test_per_epoch = n_test_per_epoch
        self.log_interval = log_interval

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.network.to(self.device)

        print("Prepared classifier with network:\n\n", self.network)

        self.train_losses = []
        self.train_counter = []
        self.test_losses = []
        self.test_correct = []

        self.save_path = save_path
        mk_dir(self.save_path)
        self.network_save_path = os.path.join(self.save_path, "network.pth")
        self.scores_save_path = os.path.join(self.save_path, "scores.pkl")
        self.loss_save_path = os.path.join(self.save_path, "loss.pkl")

    def _build_optimizer(self, optimizer):
        """Create the optimiser from a factory callable or from "adam" / "sgd"."""
        if callable(optimizer):
            return optimizer(self.network.parameters(), lr=self.learning_rate)
        if isinstance(optimizer, str):
            name = optimizer.lower()
            if name == "adam":
                return optim.Adam(self.network.parameters(), lr=self.learning_rate)
            if name == "sgd":
                return optim.SGD(self.network.parameters(), lr=self.learning_rate)
        raise NotImplementedError("Unrecognised optimizer : {}".format(optimizer))

    def _build_loss(self, loss):
        """Return the loss callable for a callable, "nll" or "mse" specification."""
        if callable(loss):
            return loss
        if loss == "nll":
            return F.nll_loss
        if loss == "mse":
            return self._one_hot_mse
        raise ValueError("Unrecognised loss : {}".format(loss))

    @staticmethod
    def _one_hot_mse(probs, target, *args, **kwargs):
        """Mean-squared error between the outputs and one-hot encoded targets."""
        one_hot = torch.zeros(probs.shape, device=target.device).scatter_(1, target.unsqueeze(-1), 1)
        return F.mse_loss(probs, one_hot, *args, **kwargs)

    def _clamping_enabled(self):
        """Whether at least one clamp limit is configured."""
        return self.weight_range is not None and bool(self.clamp_weight_args)

    def _normalise_weights(self, m):
        """Scale the weight of a linear layer down when its norm exceeds 1."""
        if type(m) is not torch.nn.Linear:
            return
        if self.weight_normalisation == weight_norm.L1:
            norm = m.weight.abs().sum().item()
        elif self.weight_normalisation == weight_norm.L2:
            norm = torch.sqrt(torch.pow(m.weight.abs(), 2).sum()).item()
        else:
            return
        if norm > 1:
            m.weight.div_(norm)

    def train(self):
        """Run the full training and pickle the logged losses and test scores.

        Returns:
            A tuple (train_losses, test_correct) of numpy arrays whose rows are
            [epoch, loss] and [epoch, number of correct test predictions].
        """
        # One test is always run at the end of each epoch, the rest within it.
        n_test_in_epoch = max(self.n_test_per_epoch - 1,0)

        for i_epoch in range(1, self.n_epochs + 1):
            self.train_epoch(i_epoch, n_test_in_epoch)
            self.test(i_epoch)

        for data, f, lab in zip([self.train_losses, self.test_correct],
                                [self.loss_save_path, self.scores_save_path],
                                ["Losses", "Scores"]):
            with open(f, 'wb+') as output:
                pickle.dump(np.array(data), output, pickle.HIGHEST_PROTOCOL)
                print('{} saved to {}'.format(lab, f))

        return np.array(self.train_losses), np.array(self.test_correct)

    def train_epoch(self, i_epoch, n_test_in_epoch=0):
        """Train for one pass over the training data.

        Args:
            i_epoch: 1-based index of the epoch (used for logging).
            n_test_in_epoch: Number of additional tests to run at regular
                intervals during the epoch.
        """
        len_loader = len(self.train_loader)
        len_loader_dataset = len(self.train_loader.dataset)
        # None when the loader was built from a batch sampler.
        loader_batch_size = getattr(self.train_loader, "batch_size", None)

        if n_test_in_epoch>=1:
            # At least 1, so that the modulo below never divides by zero.
            test_interval = max(1, int(len_loader / n_test_in_epoch))
        else:
            test_interval = 2*len_loader

        self.network.train()
        for batch_idx, (data, target) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            probs = self.network(data.to(self.device)).squeeze(1)
            loss = self.loss(probs, target.to(self.device))
            loss.backward()
            self.optimizer.step()

            if self._clamping_enabled():
                with torch.no_grad():
                    for p in self.network.parameters():
                        p.clamp_(**self.clamp_weight_args)

            if self.weight_normalisation != weight_norm.NONE:
                with torch.no_grad():
                    self.network.apply(self._normalise_weights)

            if batch_idx>0 and (batch_idx % test_interval == 0):
                # Record the test at its fractional position within the epoch,
                # matching the convention used for train_losses.
                self.test((i_epoch - 1) + batch_idx / len_loader)
                # validate() switches to eval mode; resume training mode.
                self.network.train()

            if batch_idx % self.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    i_epoch, batch_idx * len(data), len_loader_dataset,
                             100. * batch_idx / len_loader, loss.item()))
                self.train_losses.append([(i_epoch - 1) + batch_idx/len_loader, loss.item()])
                samples_per_batch = loader_batch_size or len(data)
                self.train_counter.append(
                    (batch_idx * samples_per_batch) + ((i_epoch - 1) * len_loader_dataset))

    def test(self, i_epoch):
        """Evaluate on the test loader, log the result and keep the best network.

        The network is saved whenever the number of correct predictions is
        strictly higher than in every earlier test (and on the first test).

        Args:
            i_epoch: Epoch value stored alongside the test result.
        """
        test_loss, correct = self.validate(self.test_loader)

        if all(correct > score for _, score in self.test_correct):
            self.save()

        self.test_losses.append([i_epoch, test_loss])
        self.test_correct.append([i_epoch, correct])

        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(self.test_loader.dataset),
            100. * correct / len(self.test_loader.dataset)))

    def validate(self, data_loader):
        """Evaluate the network on ``data_loader`` (switches it to eval mode).

        Args:
            data_loader: Loader yielding (data, target) batches; its
                ``dataset`` length is used to average the loss.

        Returns:
            A tuple (average loss per sample, number of correct predictions).
        """
        self.network.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in data_loader:
                # Outputs are moved to the CPU, where the loader's targets are.
                probs = self.network(data.to(self.device)).squeeze(1).to('cpu')
                test_loss += self.loss(probs, target, reduction='sum').item()
                pred = probs.max(-1)[-1]
                correct += pred.eq(target).sum().item()
        # Average over the data that was actually evaluated.
        test_loss /= len(data_loader.dataset)

        return test_loss, correct

    def save(self):
        """Write the network's state dict to ``network_save_path``."""
        torch.save(self.network.state_dict(), self.network_save_path)

    def load(self, path):
        """Load a state dict from ``path`` into the network."""
        self.network.load_state_dict(torch.load(path, map_location=self.device))

# Optical SA

In [None]:
from abc import abstractmethod
from enum import Enum

import matplotlib.pyplot as plt
import torch

try:
    import seaborn as sns
except ImportError:
    # seaborn is optional: without it the current matplotlib style is kept.
    pass
else:
    # Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8*"
    # and later removed the old names, so use whichever name this
    # installation actually provides (new name first to avoid the
    # deprecation warning).
    for _style_name in ('seaborn-v0_8', 'seaborn'):
        if _style_name in plt.style.available:
            plt.style.use(_style_name)
            break

import numpy as np

class Encoding(Enum):
    """Which physical quantity the input of the saturable-absorption non-linearity represents."""
    INTENSITY = 1  # selects the SatAbsNL_I_* functions
    AMPLITUDE = 2  # selects the SatAbsNL_E_* functions

class Gradient(Enum):
    """Which backward pass the saturable-absorption non-linearity uses."""
    EXACT = 1        # selects the *_exactGrad function
    APPROXIMATE = 2  # selects the *_approxGrad function
    ZERO = 3         # selects the *_zeroGrad function
    POSITIVE = 4     # selects the *_positiveGrad function
    NEGATIVE = 5     # selects the *_negativeGrad function

In [None]:
class SatAbsNL(torch.nn.Module):
    """Saturable-absorption non-linearity wrapped as a torch module.

    The module applies one of the SatAbsNL_I_* / SatAbsNL_E_* autograd
    functions, chosen by ``encoding`` and ``gradient``.

    Args:
        encoding: An Encoding member selecting the intensity or amplitude form.
        gradient: A Gradient member selecting the backward pass.
        OD: The ``OD`` argument forwarded to the autograd function.
        I_sat: The ``I_sat`` argument forwarded to the autograd function.

    Raises:
        ValueError: If the (encoding, gradient) combination is not recognised.
    """

    def __init__(self, encoding=Encoding.INTENSITY, gradient=Gradient.APPROXIMATE, OD=5, I_sat=1):
        super().__init__()

        self.sat_abs_nl_func = self.__get_sat_abs_nl_func(encoding, gradient).apply
        self.encoding = encoding
        self.gradient = gradient
        self.OD = OD
        self.I_sat = I_sat

    def forward(self, input):
        """Apply the selected non-linearity to ``input``."""
        return self.sat_abs_nl_func(input, self.OD, self.I_sat)

    def __get_sat_abs_nl_func(self, encoding, gradient):
        """Look up the autograd function class for ``encoding`` and ``gradient``."""
        # Built at call time because the function classes are defined further
        # down in the file.
        functions = {
            (Encoding.INTENSITY, Gradient.APPROXIMATE): SatAbsNL_I_approxGrad,
            (Encoding.INTENSITY, Gradient.EXACT): SatAbsNL_I_exactGrad,
            (Encoding.INTENSITY, Gradient.ZERO): SatAbsNL_I_zeroGrad,
            (Encoding.INTENSITY, Gradient.POSITIVE): SatAbsNL_I_positiveGrad,
            (Encoding.INTENSITY, Gradient.NEGATIVE): SatAbsNL_I_negativeGrad,
            (Encoding.AMPLITUDE, Gradient.APPROXIMATE): SatAbsNL_E_approxGrad,
            (Encoding.AMPLITUDE, Gradient.EXACT): SatAbsNL_E_exactGrad,
            (Encoding.AMPLITUDE, Gradient.ZERO): SatAbsNL_E_zeroGrad,
            (Encoding.AMPLITUDE, Gradient.POSITIVE): SatAbsNL_E_positiveGrad,
            (Encoding.AMPLITUDE, Gradient.NEGATIVE): SatAbsNL_E_negativeGrad,
        }

        try:
            return functions[(encoding, gradient)]
        except (KeyError, TypeError):
            # Fail here with a clear message instead of returning None and
            # crashing later on `None.apply`.
            raise ValueError(
                "Unrecognised options for saturated absorption non-linearity:\n\tencoding={}\n\tgradient={}".format(
                    encoding, gradient
                )) from None

    def extra_repr(self):
        """Return the configuration shown in the module's repr."""
        return 'encoding={}, gradient={}, OD={}, I_sat={}'.format(
            self.encoding, self.gradient, self.OD, self.I_sat
        )

class SatAbsNL_I(torch.autograd.Function):
    """
    Saturable-absorption non-linearity for an intensity-valued input.

    The forward pass computes I_in * exp(-OD / (1 + I_in / I_sat)).  The
    backward pass is left to the subclasses, which differ only in the
    gradient they propagate.
    """

    @staticmethod
    def forward(ctx, I_in, OD, I_sat):
        """
        Compute the output and stash I_in, OD and I_sat on ctx for the
        backward pass.
        """
        ctx.save_for_backward(I_in)
        ctx.OD, ctx.I_sat = OD, I_sat

        transmission = torch.exp(-OD / (1 + I_in/I_sat))
        return I_in * transmission

    @staticmethod
    @abstractmethod
    def backward(ctx, grad_output):
        """
        Receive the gradient of the loss with respect to the output and return
        one gradient per forward argument (I_in, OD, I_sat).
        """
        raise NotImplementedError()

class SatAbsNL_I_exactGrad(SatAbsNL_I):
    """Backward pass using the analytic derivative of the forward expression."""

    @staticmethod
    def backward(ctx, grad_output):
        (I_in,) = ctx.saved_tensors
        ratio = I_in / ctx.I_sat
        slope = 1 + ratio * (ctx.OD / (1 + ratio) ** 2)
        transmission = torch.exp(-ctx.OD / (1 + ratio))
        # forward takes 3 arguments, so 3 gradients are returned; OD and I_sat
        # get no gradient.
        return grad_output * slope * transmission, None, None

class SatAbsNL_I_approxGrad(SatAbsNL_I):
    """Backward pass that scales the gradient by the exponential factor only."""

    @staticmethod
    def backward(ctx, grad_output):
        (I_in,) = ctx.saved_tensors
        transmission = torch.exp(-ctx.OD / (1 + I_in / ctx.I_sat))
        # forward takes 3 arguments, so 3 gradients are returned; OD and I_sat
        # get no gradient.
        return grad_output * transmission, None, None

class SatAbsNL_I_zeroGrad(SatAbsNL_I):
    """Backward pass that propagates a gradient of zeros."""

    @staticmethod
    def backward(ctx, grad_output):
        blocked = torch.zeros(grad_output.shape, device=grad_output.device)
        return blocked, None, None

class SatAbsNL_I_positiveGrad(SatAbsNL_I):
    """Backward pass that propagates the incoming gradient unchanged."""

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None, None

class SatAbsNL_I_negativeGrad(SatAbsNL_I):
    """Backward pass that propagates the incoming gradient with flipped sign."""

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg(), None, None

class SatAbsNL_E(torch.autograd.Function):
    """
    Saturable-absorption non-linearity for an amplitude-valued input.

    The forward pass computes E_in * exp(-(OD/2) / (1 + E_in**2 / I_sat)).
    The backward pass is left to the subclasses, which differ only in the
    gradient they propagate.
    """

    @staticmethod
    def forward(ctx, E_in, OD, I_sat):
        """
        Compute the output and stash E_in, OD and I_sat on ctx for the
        backward pass.
        """
        ctx.save_for_backward(E_in)
        # Note this OD is still defined w.r.t the transmitted intensity.
        ctx.OD, ctx.I_sat = OD, I_sat

        transmission = torch.exp(- (OD/2) / (1 + (E_in**2)/I_sat))
        return E_in * transmission

    @staticmethod
    @abstractmethod
    def backward(ctx, grad_output):
        """
        Receive the gradient of the loss with respect to the output and return
        one gradient per forward argument (E_in, OD, I_sat).
        """
        raise NotImplementedError()

class SatAbsNL_E_exactGrad(SatAbsNL_E):
    """Backward pass using the analytic derivative of the forward expression."""

    @staticmethod
    def backward(ctx, grad_output):
        (E_in,) = ctx.saved_tensors
        ratio = (E_in ** 2) / ctx.I_sat
        slope = 1 + ratio * (ctx.OD / (1 + ratio) ** 2)
        transmission = torch.exp(-(ctx.OD / 2) / (1 + ratio))
        # forward takes 3 arguments, so 3 gradients are returned; OD and I_sat
        # get no gradient.
        return grad_output * slope * transmission, None, None

class SatAbsNL_E_approxGrad(SatAbsNL_E):
    """Backward pass that scales the gradient by the exponential factor only."""

    @staticmethod
    def backward(ctx, grad_output):
        (E_in,) = ctx.saved_tensors
        transmission = torch.exp(-(ctx.OD / 2) / (1 + (E_in ** 2) / ctx.I_sat))
        # forward takes 3 arguments, so 3 gradients are returned; OD and I_sat
        # get no gradient.
        return grad_output * transmission, None, None

class SatAbsNL_E_zeroGrad(SatAbsNL_E):
    """Backward pass that propagates a gradient of zeros."""

    @staticmethod
    def backward(ctx, grad_output):
        blocked = torch.zeros(grad_output.shape, device=grad_output.device)
        return blocked, None, None

class SatAbsNL_E_positiveGrad(SatAbsNL_E):
    """Backward pass that propagates the incoming gradient scaled by exp(-5)."""

    @staticmethod
    def backward(ctx, grad_output):
        # NOTE(review): the factor is a hard-coded constant, independent of
        # ctx.OD; it equals exp(-OD/2) only for OD=10 -- confirm this is intended.
        return np.exp(-5) * grad_output, None, None

class SatAbsNL_E_negativeGrad(SatAbsNL_E):
    """Backward pass that propagates the incoming gradient with flipped sign."""

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg(), None, None

#Network

In [None]:
import torch.nn as nn

class LinNet(nn.Module):
    """
    A simple MLP with configurable activation functions.

    Every layer is a bias-free nn.Linear followed by an optional activation.

    Args:
        n_hid: Width of the hidden layer(s): an int, or a list/tuple of ints.
        n_in: Number of input features.
        n_out: Number of outputs.
        activation: Zero-argument factory for the activation module used after
            every hidden layer, or None for no activation.
        output: Zero-argument factory for the module applied after the last
            layer, or None for no output module.
    """

    def __init__(self,
                 n_hid=[200],
                 n_in=784,
                 n_out=10,
                 activation=nn.ReLU,
                 output=lambda: nn.LogSoftmax(-1)):
        super().__init__()

        # The default list is only read, never modified; a tuple of widths is
        # accepted as well as a list or a single int.
        if isinstance(n_hid, (list, tuple)):
            n_hid = list(n_hid)
        else:
            n_hid = [n_hid]
        widths = [n_in] + n_hid + [n_out]
        i_last = len(widths) - 2

        stages = []
        for i_layer, (n1, n2) in enumerate(zip(widths, widths[1:])):
            mods = [nn.Linear(n1, n2, bias=False)]
            act_fn = activation if i_layer < i_last else output
            if act_fn is not None:
                mods.append(act_fn())
            stages.append(nn.Sequential(*mods))

        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Pass ``x`` through all layers in order."""
        for layer in self.layers:
            x = layer(x)
        return x

class ConvNet(nn.Module):
    """
    A simple CNN+MLP stacked network, with configurable activation functions.

    The convolutional stages (Conv2d, optional activation, optional pooling)
    take a single-channel input.  Their output is flattened, optionally passed
    through dropout and then through the fully connected layers.

    Args:
        n_ch_conv: Output channels of the conv layer(s): an int or a list/tuple.
        kernel_size_conv: Kernel size per conv layer (list/tuple with one entry
            per layer), or a single int used for every layer.
        conv_args: Extra keyword arguments for every nn.Conv2d.
        activation_conv: Zero-argument activation factory (or None) used for
            every conv layer, or a list/tuple with one entry per layer.
        pool_conv: Zero-argument factory for the pooling module, or None.
        dropout: A zero-argument factory, an nn.Module instance, or a flag
            (True selects nn.Dropout() with its default rate).
        n_in_fc: Number of features entering the first fully connected layer;
            it must equal the size of the flattened conv output.
        n_hid_fc: Width of the hidden fully connected layer(s): int or list/tuple.
        n_out: Number of outputs.
        activation_fc: Zero-argument activation factory for hidden fc layers, or None.
        bias_fc: Whether the fully connected layers have a bias.
        output: Zero-argument factory for the final module, or None.

    Raises:
        ValueError: If fewer kernel sizes or activations than conv layers are given.
    """

    def __init__(self,
                 n_ch_conv = [32, 64],
                 kernel_size_conv = [5, 5],
                 conv_args = {'stride':1, 'padding':0, 'bias':False},
                 activation_conv=nn.ReLU,
                 pool_conv=lambda: nn.MaxPool2d(kernel_size=2, stride=2),
                 dropout=False,
                 n_in_fc = 1024,
                 n_hid_fc=[128],
                 n_out=10,
                 activation_fc=nn.ReLU,
                 bias_fc=False,
                 output=lambda: nn.LogSoftmax(-1)):
        super().__init__()

        # The mutable defaults above are only read, never modified.
        if isinstance(n_ch_conv, (list, tuple)):
            n_ch_conv = list(n_ch_conv)
        else:
            n_ch_conv = [n_ch_conv]
        n_conv = len(n_ch_conv)
        channels = [1] + n_ch_conv  # the network input has a single channel

        # A list/tuple of kernel sizes is read as one entry per layer; only a
        # bare int is replicated.
        if isinstance(kernel_size_conv, int):
            kernel_size_conv = [kernel_size_conv] * n_conv

        if isinstance(activation_conv, (list, tuple)):
            activation_conv = list(activation_conv)
        else:
            activation_conv = [activation_conv] * n_conv

        # zip() below would otherwise silently build fewer conv layers.
        if len(kernel_size_conv) < n_conv or len(activation_conv) < n_conv:
            raise ValueError(
                "Need a kernel size and an activation for each of the {} conv layers, "
                "got {} kernel sizes and {} activations.".format(
                    n_conv, len(kernel_size_conv), len(activation_conv)))

        conv_stages = []
        for n_ch_in, n_ch_out, k_size, act_fn in zip(channels, channels[1:], kernel_size_conv, activation_conv):
            mods = [nn.Conv2d(n_ch_in, n_ch_out, k_size, **conv_args)]
            if act_fn is not None:
                mods.append(act_fn())
            if pool_conv is not None:
                mods.append(pool_conv())
            conv_stages.append(nn.Sequential(*mods))

        self.layers_conv = nn.ModuleList(conv_stages)

        if isinstance(n_hid_fc, (list, tuple)):
            n_hid_fc = list(n_hid_fc)
        else:
            n_hid_fc = [n_hid_fc]
        widths = [n_in_fc] + n_hid_fc + [n_out]
        i_last = len(widths) - 2

        if isinstance(dropout, nn.Module):
            # A ready-made module is used as is (modules are callable too, so
            # this has to be checked before the factory case).
            self.dropout = dropout
        elif callable(dropout):
            self.dropout = dropout()
        elif dropout:
            self.dropout = nn.Dropout()
        else:
            self.dropout = None

        fc_stages = []
        for i_layer, (fc_in, fc_out) in enumerate(zip(widths, widths[1:])):
            mods = [nn.Linear(fc_in, fc_out, bias=bias_fc)]
            act_fn = activation_fc if i_layer < i_last else output
            if act_fn is not None:
                mods.append(act_fn())
            fc_stages.append(nn.Sequential(*mods))

        self.layers_fc = nn.ModuleList(fc_stages)

        # Set after the first forward pass has reported the flattened size.
        self.printed_size = False

    def forward(self, x):
        """Run the conv stages, flatten, apply dropout (if any) and the fc layers."""

        for layer in self.layers_conv:
            x = layer(x)
        x = x.reshape(x.size(0), -1)

        if self.dropout is not None:
            x = self.dropout(x)

        if not self.printed_size:
            # Printed once, to help choosing n_in_fc.
            print("Size of input to first linear layer is", x.shape)
            self.printed_size = True

        for layer in self.layers_fc:
            x = layer(x)

        return x

#LOSS

In [None]:
from enum import Enum

class Loss(Enum):
    """Loss functions that the training scripts can select."""
    MSE = 0 # Mean-squared-error : corresponds to optically obtainable loss.
    CCE = 1 # Categorical cross-entropy : standard computational loss for classification problems.

#TRAIN ANN

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn


try:
    import seaborn as sns
except ImportError:
    # seaborn is optional: without it the current matplotlib style is kept.
    pass
else:
    # Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*"
    # and later removed the old names, so use whichever name this
    # installation actually provides (new name first to avoid the
    # deprecation warning).
    for _style_name in ('seaborn-v0_8-paper', 'seaborn-paper'):
        if _style_name in plt.style.available:
            plt.style.use(_style_name)
            break
# Settings of the reference (ANN) run: output directory, loss and the
# activation used in every layer.
save_loc="cnn/ANN"
loss=Loss.CCE
activation=nn.ReLU



####################################################
# Configure datasets.
####################################################

dataset = Dataset.MNIST

# Results for any dataset other than MNIST go to a directory with the
# dataset name appended, e.g. "cnn/ANN_KMNIST".
if dataset != Dataset.MNIST:
    save_loc += "_{}".format(str(dataset).split(".")[-1])

batch_size_train = 64
batch_size_test = 1000

####################################################
# Configure Networks.
####################################################

# The MSE loss is applied to the raw network output (no output module), the
# NLL loss to log-softmax outputs.
if loss==Loss.MSE:
    output = None
    loss_str = "mse"
elif loss==Loss.CCE:
    output = lambda: nn.LogSoftmax(-1)
    loss_str = "nll"
else:
    raise ValueError("Unrecognised loss :", loss)

# Keyword arguments for ConvNet.  n_in_fc must match the flattened size of the
# conv output; 47 outputs are used for EMNIST and 10 for every other dataset.
net_args = {
    'n_ch_conv': [32, 64],
    'kernel_size_conv': [5, 5],
    'n_in_fc': 1024,
    'n_hid_fc': [128],
    'activation_conv': [activation, activation],
    'activation_fc': activation,
    'dropout': lambda: nn.Dropout(0.4),
    'conv_args': {'stride': 1, 'padding': 0, 'bias': False},
    'pool_conv': lambda: nn.AvgPool2d(kernel_size=2, stride=2),
    'n_out': 10 if dataset != Dataset.EMNIST else 47,
    'bias_fc': False,
    'output': output
}

  ####################################################
  # Train classifiers
  ####################################################

n_seeds = 5

# Results per run, keyed by the run label ("seed0", "seed1", ...).
losses = {}
corrects = {}
valid_scores = {}

for i in range(n_seeds):
    lab = 'seed{}'.format(i)

    # Seed torch's random number generator so that the run labelled "seed{i}"
    # is reproducible (weight initialisation, shuffling and the random
    # test/validation split below all draw from it).
    torch.manual_seed(i)

    network = ConvNet(**net_args)

    train_loader, test_loader, validation_loader = get_dataset_loaders(
        dataset=dataset,
        train_batch=batch_size_train,
        test_batch=batch_size_test,
        unroll_img=False,
        max_value=1,
        get_validation=True)

    classifier = Classifier(network, train_loader, test_loader,
                            n_epochs=30 if dataset == Dataset.MNIST else 40,
                            learning_rate=5e-4,
                            init_weight_mean=0., init_weight_std=0.01, init_conv_weight_std=0.1,
                            loss=loss_str,
                            weight_range=None,
                            weight_normalisation=weight_norm.NONE,
                            log_interval=25, n_test_per_epoch=0,
                            save_path=os.path.join(save_loc, lab))

    train_losses, test_correct = classifier.train()

    losses[lab] = train_losses
    corrects[lab] = test_correct

    ####################################################
    # Validation
    ####################################################

    # Score the best network saved during training, not the final one.
    classifier.load(classifier.network_save_path)

    valid_loss, valid_correct = classifier.validate(validation_loader)
    valid_accuracy = 100. * valid_correct / len(validation_loader.dataset)

    print("Validation accuracy : {:.2f}%".format(valid_accuracy))
    valid_scores[lab] = valid_accuracy

    validation_save_path = os.path.join(classifier.save_path, "validation_score.pkl")
    # The handle gets its own name so that it does not replace the global
    # `output` (the output-module factory configured earlier).
    with open(validation_save_path, 'wb+') as validation_file:
        pickle.dump(np.array([valid_loss, valid_correct]), validation_file, pickle.HIGHEST_PROTOCOL)
        print('Validation scores saved to {}'.format(validation_save_path))

print("Validation scores are:")
for lab, score in valid_scores.items():
    print("\t{} : {:.2f}%".format(lab, score))

  ####################################################
  # Plot results
  ####################################################

fig_fname = os.path.join(save_loc, "training_performance")

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# later removed the old names: use whichever name is available, or the default
# style if neither is.
plot_style = next(
    (name for name in ('seaborn-v0_8-paper', 'seaborn-paper') if name in plt.style.available),
    'default')

with plt.style.context(plot_style, after_reset=True):

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 2.5), gridspec_kw={'wspace': 0.3})

    # Moving average over `window` logged values to smooth the loss curves.
    window = 25
    avg_mask = np.ones(window) / window

    for lab, data in losses.items():
        ax1.plot(np.convolve(data[:, 0], avg_mask, 'valid'),
                 np.convolve(data[:, 1], avg_mask, 'valid'),
                 label=lab, linewidth=0.75, alpha=0.8)
    ax1.legend()
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Losses")

    n_test_samples = len(test_loader.dataset)
    for lab, data in corrects.items():
        ax2.plot(data[:, 0], data[:, 1] / n_test_samples, label=lab)
        # Column 1 holds the number of correct predictions (column 0 is the
        # epoch); report it out of the number of test samples.
        print("{}: Best score {}/{}".format(lab, np.max(data[:, 1]), n_test_samples))
    ax2.legend()
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy")

    plt.savefig(fig_fname + ".png", bbox_inches='tight')
    plt.savefig(fig_fname + ".pdf", bbox_inches='tight')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Setting weights for Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
Setting weights for Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), bias=False)
Setting weights for Linear(in_features=1024, out_features=128, bias=False)
Setting weights for Linear(in_features=128, out_features=10, bias=False)
Prepared classifier with network:

 ConvNet(
  (layers_conv): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
      (1): ReLU()
      (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), bias=False)
      (1): ReLU()
      (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    )
  )
  (dropout): Dropout(p=0.4, inplace=False)
  (layers_fc): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=1024, out_features=128, bias=False)
      (1): ReLU()


#TRAIN ONN

In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn

In [None]:
# Run configuration consumed by the cells below.
gradient = Gradient.APPROXIMATE  # forwarded to SatAbsNL through sat_abs_nl_args
OD = 10  # forwarded to SatAbsNL as 'OD'; also picks the max_value used to rescale inputs
loss = Loss.MSE  # decides the loss string and the optional output layer
save_loc = "cnn/ONN"  # root directory for per-seed results and the summary figure

In [None]:
# Mini-batch sizes handed to get_dataset_loaders.
batch_size_train = 64
batch_size_test = 1000

# Dataset to train on.
dataset = Dataset.MNIST

# Any dataset other than MNIST gets its own results directory by suffixing
# the enum member name, e.g. "<save_loc>_KMNIST".
if dataset is not Dataset.MNIST:
    save_loc = "{}_{}".format(save_loc, dataset.name)

    ####################################################
    # Configure Networks.
    ####################################################

# Keyword arguments shared by every SatAbsNL activation instance.
sat_abs_nl_args = {
    'I_sat': 1,
    'OD': OD,
    'encoding': Encoding.AMPLITUDE,
    'gradient': gradient,
}

# Zero-argument factory: each call builds a new SatAbsNL from the shared arguments.
SANL = lambda: SatAbsNL(**sat_abs_nl_args)

# Translate the selected loss into the loss string given to Classifier and
# the (optional) factory for the network's output layer.
if loss == Loss.MSE:
    output = None
    loss_str = "mse"
elif loss == Loss.CCE:
    output = lambda: nn.LogSoftmax(-1)
    loss_str = "nll"
else:
    # Format the offending value into the message. Passing it as a second
    # positional argument would make the exception text render as a tuple.
    raise ValueError("Unrecognised loss : {}".format(loss))

# Constructor arguments for ConvNet. Activations, dropout, pooling and the
# output layer are supplied as zero-argument factories rather than instances.
net_args = {
    'n_ch_conv': [32, 64],
    'kernel_size_conv': [5, 5],
    'n_in_fc': 1024,
    'n_hid_fc': [128],
    'activation_conv': [SANL, SANL],
    'activation_fc': SANL,
    'dropout': lambda: nn.Dropout(0.4),
    'conv_args': {'stride': 1, 'padding': 0, 'bias': False},
    'pool_conv': lambda: nn.AvgPool2d(kernel_size=2, stride=2),
    # EMNIST uses 47 outputs here (presumably its class count - confirm split).
    'n_out': 10 if dataset != Dataset.EMNIST else 47,
    'bias_fc': False,
    'output': output,
}

    ####################################################
    # Train classifiers
    ####################################################

# Number of independent training runs. The loop index is only used to label
# each run ("seed0", "seed1", ...); no RNG seed is set in this loop.
n_seeds = 5

# Per-run results keyed by the run label.
losses = {}
corrects = {}
valid_scores = {}

for i in range(n_seeds):
    lab = 'seed{}'.format(i)

    # Fresh network for every run.
    network = ConvNet(**net_args)

    # Loaders are rebuilt for every run. Images are kept 2-D (unroll_img=False)
    # and rescaled to a maximum chosen from OD.
    train_loader, test_loader, validation_loader = get_dataset_loaders(
        dataset=dataset,
        train_batch=batch_size_train,
        test_batch=batch_size_test,
        unroll_img=False,
        max_value=15 if OD > 10 else 5,
        get_validation=True)

    classifier = Classifier(network, train_loader, test_loader,
                            n_epochs=30 if dataset == Dataset.MNIST else 40,
                            learning_rate=5e-4,
                            init_weight_mean=0., init_weight_std=0.01, init_conv_weight_std=0.1,
                            loss=loss_str,
                            weight_range=None,
                            weight_normalisation=weight_norm.NONE,
                            log_interval=25, n_test_per_epoch=0,
                            save_path=os.path.join(save_loc, lab))

    train_losses, test_correct = classifier.train()

    losses[lab] = train_losses
    corrects[lab] = test_correct

    ####################################################
    # Validation
    ####################################################

    # Reload the network stored at classifier.network_save_path before scoring
    # it on the validation loader.
    classifier.load(classifier.network_save_path)

    valid_loss, valid_correct = classifier.validate(validation_loader)

    # Compute the percentage once and reuse it for both the log line and the summary.
    valid_accuracy = 100. * valid_correct / len(validation_loader.dataset)
    print("Validation accuracy : {:.2f}%".format(valid_accuracy))
    valid_scores[lab] = valid_accuracy

    validation_save_path = os.path.join(classifier.save_path, "validation_score.pkl")
    # The handle gets its own name so it does not overwrite the module-level
    # `output` configuration variable with a closed file object.
    with open(validation_save_path, 'wb') as score_file:
        pickle.dump(np.array([valid_loss, valid_correct]), score_file, pickle.HIGHEST_PROTOCOL)
    print('Validation scores saved to {}'.format(validation_save_path))

print("Validation scores are:")
for lab, score in valid_scores.items():
    print("\t{} : {:.2f}%".format(lab, score))

    ####################################################
    # Plot results
    ####################################################

# Base path (without extension) for the summary figure.
fig_fname = os.path.join(save_loc, "training_performance")

# Newer matplotlib releases renamed the bundled seaborn styles with a
# "seaborn-v0_8-" prefix. Use the original name when it is available and fall
# back to the renamed one otherwise.
plot_style = 'seaborn-paper'
if plot_style not in plt.style.available:
    plot_style = 'seaborn-v0_8-paper'

with plt.style.context(plot_style, after_reset=True):

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 2.5), gridspec_kw={'wspace': 0.3})

    # Moving-average kernel used to smooth the training-loss curves.
    window = 25
    avg_mask = np.ones(window) / window

    # Left panel: smoothed loss per run (column 0 is the x value, column 1 the loss).
    for lab, data in losses.items():
        ax1.plot(np.convolve(data[:, 0], avg_mask, 'valid'),
                 np.convolve(data[:, 1], avg_mask, 'valid'),
                 label=lab, linewidth=0.75, alpha=0.8)
    ax1.legend()
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Losses")

    # Right panel: test accuracy per run. test_loader is the loader left over
    # from the last training run.
    n_test = len(test_loader.dataset)
    for lab, data in corrects.items():
        ax2.plot(data[:, 0], data[:, 1] / n_test, label=lab)
        # The best score is the maximum of the correct-count column only, and
        # it is reported against the number of test samples (not batches).
        print("{}: Best score {}/{}".format(lab, np.max(data[:, 1]), n_test))
    ax2.legend()
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy")

    plt.savefig(fig_fname + ".png", bbox_inches='tight')
    plt.savefig(fig_fname + ".pdf", bbox_inches='tight')