In [None]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
import torchvision.datasets as datasets
from tqdm import tqdm, trange
import numpy as np
import matplotlib.pyplot as plt


In [None]:
def create_datasets(batch_size, validation_size=0.2, seed=None):
    """Build MNIST train / validation / test data loaders.

    The official MNIST training set (downloaded to ``data/`` if missing) is
    split into a training part and a validation part; the official test set
    is returned untouched.

    Args:
        batch_size: Number of samples per batch for all three loaders.
        validation_size: Fraction of the training set held out for
            validation. Must satisfy ``0 <= validation_size < 1``.
        seed: Optional integer seed for the train/validation split. With the
            default ``None`` the global NumPy RNG is used, so the split
            differs between runs unless ``np.random.seed`` was called. The
            seed only fixes *which* samples land in each subset; the order in
            which ``SubsetRandomSampler`` draws them still follows PyTorch's
            RNG.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)`` -- note the order.

    Raises:
        ValueError: If ``validation_size`` is outside ``[0, 1)``.
    """
    if not 0.0 <= validation_size < 1.0:
        raise ValueError(
            f"validation_size must be in [0, 1), got {validation_size!r}")

    transform = transforms.ToTensor()
    train_data = datasets.MNIST(root='data/',
                                train=True,
                                download=True,
                                transform=transform)

    test_data = datasets.MNIST(root='data/',
                               train=False,
                               download=True,
                               transform=transform)

    # Shuffle the sample indices once, then carve off the first `split`
    # of them as the validation subset.
    num_train = len(train_data)
    indices = list(range(num_train))
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(indices)
    split = int(np.floor(validation_size * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # Both loaders wrap the same dataset object; the samplers keep the two
    # index sets disjoint.
    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              sampler=train_sampler)

    valid_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              sampler=valid_sampler)

    test_loader = DataLoader(test_data,
                             batch_size=batch_size)

    return train_loader, valid_loader, test_loader
    

## IOCMLP:

- 3 hidden layers (the weights of layers 2 to 4 are constrained to be positive), 800 nodes each
- batchnorm between every layer
- activation: ELU



In [None]:
class IOCMLP(nn.Module):
    """Fully connected classifier: two ELU hidden layers, each followed by batch norm.

    Data flow: ``fc1 -> ELU -> bn1 -> fc2 -> ELU -> bn2 -> fc3``. The output
    is the raw ``fc3`` activation (logits); no softmax is applied here.

    Each hidden layer owns its own ``BatchNorm1d``. Sharing one instance
    between both layers would make them share running mean/variance and
    affine parameters, so in ``eval()`` mode both would be normalised with
    statistics that belong to neither.

    Args:
        input_size: Number of features per sample after flattening.
        num_classes: Number of output logits.
        hidden_size: Width of both hidden layers (default 800).
    """

    def __init__(self, input_size, num_classes, hidden_size=800):
        super(IOCMLP, self).__init__()
        self.input_size = input_size
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=num_classes)
        self.elu = nn.ELU()  # stateless, so it is safe to reuse
        # Batch norm carries per-layer state: one instance per hidden layer.
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)`` for a batch ``x``.

        ``x`` may have any shape whose elements regroup into rows of
        ``input_size`` values; it is flattened with ``view(-1, input_size)``.
        """
        x = x.view(-1, self.input_size)  # flatten the image input
        x = self.bn1(self.elu(self.fc1(x)))
        x = self.bn2(self.elu(self.fc2(x)))
        return self.fc3(x)

    

In [None]:
# PARAMETERS (as specified in the paper)
lr = 1e-4         # learning rate handed to the Adam optimizer
num_classes = 10  # size of the network's output layer

In [None]:
# MNIST: every 28x28 image is flattened into one 784-long input vector.
input_size = 28 * 28
model = IOCMLP(num_classes=num_classes, input_size=input_size)
print(model)

# Both objects live at module level because train_IOCMLP reads them as globals.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

IOCMLP(
  (fc1): Linear(in_features=784, out_features=800, bias=True)
  (fc2): Linear(in_features=800, out_features=800, bias=True)
  (fc3): Linear(in_features=800, out_features=10, bias=True)
  (elu): ELU(alpha=1.0)
  (bn): BatchNorm1d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [None]:
def fix_neg_weights(t, epsilon=5):
    """Map negative entries of ``t`` to small positive values; keep the rest.

    Every entry ``w < 0`` is replaced by ``exp(w - epsilon)``; entries that
    are ``>= 0`` are returned unchanged. The result is a new tensor; ``t`` is
    not modified.

    Args:
        t: Input tensor. Non-floating-point tensors are converted to the
            default floating dtype first, because the exponential is not
            representable in an integer dtype.
        epsilon: Offset subtracted inside the exponential. Larger values push
            the replaced weights closer to zero (default 5).

    Returns:
        Floating-point tensor with the same shape as ``t`` and no negative
        entries.
    """
    if not torch.is_floating_point(t):
        t = t.to(torch.get_default_dtype())
    # torch.where selects element-wise instead of multiplying by a 0/1 mask,
    # so infinite entries cannot turn into NaN via `inf * 0`.
    return torch.where(t < 0, torch.exp(t - epsilon), t)

# Smoke check: an integer tensor of negative values still yields a torch.Tensor.
t = torch.tensor([-1,-1])
type(fix_neg_weights(t))

torch.Tensor

In [None]:
def train_IOCMLP(model, batch_size, patience, n_epochs):
    """Train ``model`` with early stopping while keeping fc2/fc3 weights non-negative.

    After every optimizer step the weight matrices of ``model.fc2`` and
    ``model.fc3`` are passed through ``fix_neg_weights`` and written back
    in place. The write must be in place: the optimizer holds references to
    the original ``Parameter`` objects, so re-binding ``layer.weight`` to a
    fresh ``nn.Parameter`` would silently detach those layers from the
    optimizer and they would stop learning.

    This function reads the following names from the notebook's global
    scope rather than taking them as arguments: ``train_loader``,
    ``valid_loader``, ``criterion``, ``optimizer``, ``fix_neg_weights`` and
    ``EarlyStopping``.
    NOTE(review): ``EarlyStopping`` is not defined or imported in the visible
    cells; presumably it writes the best weights to ``checkpoint.pt`` (the
    final ``torch.load`` depends on that file) -- confirm its source.

    Args:
        model: Network exposing ``fc2`` and ``fc3`` linear layers.
        batch_size: Not used here (the global loaders already fix the batch
            size); kept so existing calls keep working.
        patience: Forwarded to ``EarlyStopping``.
        n_epochs: Maximum number of epochs.

    Returns:
        Tuple ``(model, avg_train_losses, avg_valid_losses)``: the model
        after loading ``checkpoint.pt``, and the per-epoch mean training and
        validation losses.
    """
    avg_train_losses = []
    avg_valid_losses = []

    early_stopping = EarlyStopping(patience=patience, verbose=True)
    for epoch in trange(1, n_epochs+1):
        # Per-batch losses for the current epoch only.
        train_losses = []
        valid_losses = []

        ##########
        #TRAINING#
        ##########
        model.train()
        for data, target in train_loader:
            optimizer.zero_grad()
            # FORWARD PASS
            out = model(data)
            loss = criterion(out, target)
            # BACKWARD PASS
            loss.backward()
            # OPTIMIZER STEP
            optimizer.step()

            # FIXING NEGATIVE WEIGHTS (in place, so the optimizer keeps
            # tracking the very same Parameter objects)
            with torch.no_grad():
                for layer in (model.fc2, model.fc3):
                    layer.weight.copy_(fix_neg_weights(layer.weight))
            # RECORD LOSS
            train_losses.append(loss.item())

        ############
        #VALIDATION#
        ############
        model.eval()
        # No gradients are needed for evaluation; skip building the graph.
        with torch.no_grad():
            for data, target in valid_loader:
                out = model(data)
                loss = criterion(out, target)
                valid_losses.append(loss.item())

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        msg = (f'[{epoch}/{n_epochs}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        print(msg)

        early_stopping(valid_loss, model)

        if early_stopping.early_stop:
            print("early stopping")
            break

    # Restore the weights stored in the checkpoint file before returning.
    model.load_state_dict(torch.load('checkpoint.pt'))
    return  model, avg_train_losses, avg_valid_losses

In [None]:
# Offset for the exponential re-mapping of negative weights.
# NOTE(review): nothing in the visible cells reads this variable.
epsilon = 5
batch_size = 32
n_epochs = 100
patience = 50

# create_datasets returns (train, valid, test) in that order. Unpacking it as
# (train, test, valid) would run early stopping / model selection on the
# test set.
train_loader, valid_loader, test_loader = create_datasets(batch_size)
model, train_loss, validation_loss = train_IOCMLP(model, batch_size, patience, n_epochs)







  0%|          | 0/100 [00:00<?, ?it/s][A[A[A[A[A[A





  1%|          | 1/100 [01:09<1:54:30, 69.40s/it][A[A[A[A[A[A

[1/100] train_loss: 0.75836 valid_loss: 3.73056
Validation loss decreased (inf --> 3.730562).  Saving model ...








  2%|▏         | 2/100 [02:24<1:56:11, 71.14s/it][A[A[A[A[A[A

[2/100] train_loss: 0.52515 valid_loss: 5.24029
EarlyStopping counter: 1 out of 50








  3%|▎         | 3/100 [03:40<1:57:18, 72.57s/it][A[A[A[A[A[A

[3/100] train_loss: 0.47245 valid_loss: 7.52064
EarlyStopping counter: 2 out of 50








  4%|▍         | 4/100 [04:55<1:57:25, 73.39s/it][A[A[A[A[A[A

[4/100] train_loss: 0.44082 valid_loss: 7.51067
EarlyStopping counter: 3 out of 50








  5%|▌         | 5/100 [06:11<1:57:07, 73.97s/it][A[A[A[A[A[A

[5/100] train_loss: 0.41855 valid_loss: 11.83380
EarlyStopping counter: 4 out of 50








  6%|▌         | 6/100 [07:26<1:56:22, 74.28s/it][A[A[A[A[A[A

[6/100] train_loss: 0.40215 valid_loss: 12.04852
EarlyStopping counter: 5 out of 50


KeyboardInterrupt: ignored