In [1]:
import torch
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision import datasets, transforms
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [13]:
import torch.nn as nn
class Flatten(nn.Module):
    '''
    Flattens all dimensions except the first (batch) one: (N, d1, d2, ...) -> (N, d1*d2*...).
    '''
    def forward(self, x):
        # reshape copies when x is not contiguous (e.g. after a transpose/permute),
        # where view would raise a RuntimeError; for contiguous inputs it is a view.
        return x.reshape(x.size(0), -1)

In [2]:
class MLPModel(nn.Module):
    '''
    Fully connected classifier: three hidden Linear+ReLU layers followed by a
    Linear output layer.

    Args:
        input_dim: number of features of one sample once flattened
            (e.g. 3*28*28 for a 3x28x28 image).
        hidden_dim: width of each of the three hidden layers.
        output_dim: number of output logits. Defaults to 10, the value that was
            previously hard-coded.
    '''
    def __init__(self, input_dim, hidden_dim, output_dim=10):
        super(MLPModel, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, input):
        '''
        Flattens every dimension except the first (batch) one and returns the
        logits, of shape (batch, output_dim).
        '''
        # reshape (unlike view) also accepts non-contiguous tensors.
        input = input.reshape(input.size(0), -1)
        return self.layers(input)

In [3]:
class PR_CNN(nn.Module):
    '''
    Three strided conv layers (kernel 3, stride 3, no padding, LeakyReLU)
    followed by a linear classifier.

    Keyword arguments:
        input_channels (int, required): channels of the input images. The conv
            layers produce input_channels*5, input_channels*25 and then 1536 channels.
        output_channels (int, required): number of output logits.
        input_dim (int or (height, width), optional): spatial size of the input,
            used to size the classifier. When omitted, the last conv layer is
            assumed to output a 1x1 map (which is the case for 28x28 inputs).

    Raises:
        TypeError: if input_channels or output_channels is missing.
        ValueError: if input_dim is too small for the three convolutions.
    '''

    # Every conv layer uses this kernel size and stride (and no padding).
    _KERNEL = 3
    _STRIDE = 3
    _LAST_CONV_CHANNELS = 1536

    def __init__(self, **kwargs):
        super(PR_CNN, self).__init__()
        self.expected_input_size = kwargs.get('input_dim', None)
        input_channels = self._required(kwargs, 'input_channels')
        output_channels = self._required(kwargs, 'output_channels')

        # First, second and third layers
        self.conv1 = self._conv_block(input_channels, input_channels * 5)
        self.conv2 = self._conv_block(input_channels * 5, input_channels * 25)
        self.conv3 = self._conv_block(input_channels * 25, self._LAST_CONV_CHANNELS)

        # Classification layer (nn.Flatten has no parameters, so the Linear
        # layer keeps the state_dict key prefix 'fc.1').
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self._flattened_size(self.expected_input_size), output_channels)
        )

    @staticmethod
    def _required(kwargs, key):
        # Fail early with a readable message instead of e.g. `None * 5`.
        value = kwargs.get(key)
        if value is None:
            raise TypeError("PR_CNN() missing required keyword argument: '{}'".format(key))
        return value

    @classmethod
    def _conv_block(cls, in_channels, out_channels):
        # Conv + LeakyReLU wrapped in a Sequential so parameters stay named convN.0.*
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=cls._KERNEL,
                      stride=cls._STRIDE),
            nn.LeakyReLU()
        )

    @classmethod
    def _flattened_size(cls, input_dim):
        # Number of features left after conv3, i.e. the classifier's input size.
        if input_dim is None:
            return cls._LAST_CONV_CHANNELS
        if isinstance(input_dim, int):
            height, width = input_dim, input_dim
        else:
            height, width = input_dim
        for _ in range(3):
            # Conv2d output size with no padding and dilation 1.
            height = (height - cls._KERNEL) // cls._STRIDE + 1
            width = (width - cls._KERNEL) // cls._STRIDE + 1
        if height < 1 or width < 1:
            raise ValueError('input_dim {} is too small for three kernel-3/stride-3 '
                             'convolutions'.format(input_dim))
        return cls._LAST_CONV_CHANNELS * height * width

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.fc(x)
        return x

In [4]:
import numpy as np
import torch

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(model, train_loader, optimizer, loss_fn, print_every=100):
    '''
    Trains the model for one epoch.

    Args:
        model: network to train; it is put in training mode. Batches are moved
            to the module-level `device` before the forward pass.
        train_loader: iterable of (images, labels) batches, e.g. a DataLoader.
        optimizer: optimizer over the model's parameters; stepped once per batch.
        loss_fn: callable(output, labels) returning a scalar loss tensor.
        print_every: currently unused (per-iteration logging is disabled);
            kept so existing callers keep working.

    Returns:
        (mean of the per-batch losses, accuracy in percent over the samples
        actually processed). Both are NaN if the loader yields no batches.
    '''
    model.train()
    batch_losses = []
    n_correct = 0
    # Count processed samples instead of using len(train_loader.dataset):
    # with drop_last=True or a subset sampler, the two differ.
    n_seen = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
        n_correct += torch.sum(output.argmax(1) == labels).item()
        n_seen += labels.size(0)
    if n_seen == 0:
        # Empty loader: avoid np.mean([]) warnings and a meaningless 0% accuracy.
        return float('nan'), float('nan')
    accuracy = 100.0 * n_correct / n_seen
    return np.mean(np.array(batch_losses)), accuracy
            
def test(model, test_loader, loss_fn):
    '''
    Tests the model on data from test_loader
    '''
    model.eval()
    test_loss = 0
    n_correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            loss = loss_fn(output, labels)
            test_loss += loss.item()
            n_correct += torch.sum(output.argmax(1) == labels).item()

    average_loss = test_loss / len(test_loader)
    accuracy = 100.0 * n_correct / len(test_loader.dataset)
#     print('Test average loss: {:.4f}, accuracy: {:.3f}'.format(average_loss, accuracy))
    return average_loss, accuracy


def fit(train_dataloader, val_dataloader, model, optimizer, loss_fn, n_epochs, scheduler=None):
    '''
    Trains `model` for `n_epochs` epochs, evaluating on `val_dataloader` after
    each epoch and printing one summary line per epoch.

    Args:
        train_dataloader: batches used for training (passed to `train`).
        val_dataloader: batches used for evaluation (passed to `test`).
        model, optimizer, loss_fn: forwarded to `train` / `test`.
        n_epochs: number of epochs.
        scheduler: optional learning-rate scheduler, stepped once per epoch.
            A ReduceLROnPlateau scheduler is stepped with the epoch's
            validation loss, as its `step` requires a metric.

    Returns:
        List of validation accuracies (percent), one per epoch.
    '''
    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []

    for epoch in range(n_epochs):
        train_loss, train_accuracy = train(model, train_dataloader, optimizer, loss_fn)
        val_loss, val_accuracy = test(model, val_dataloader, loss_fn)
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
        print('Epoch {}/{}: train_loss: {:.4f}, train_accuracy: {:.4f}, val_loss: {:.4f}, val_accuracy: {:.4f}'.format(epoch+1, n_epochs,
                                                                                                          train_losses[-1],
                                                                                                          train_accuracies[-1],
                                                                                                          val_losses[-1],
                                                                                                          val_accuracies[-1]))

    return val_accuracies

In [5]:
# ToTensor scales pixel values to [0, 1]; Normalize then standardizes them with
# the commonly used MNIST mean/std. The single-value statistics are broadcast
# over every channel of the image tensor.
transform_dataset = Compose([
    ToTensor(),
    Normalize(mean=(0.1307,), std=(0.3081,)),
])

In [6]:
# ImageFolder labels each image by the name of its class sub-directory.
train_dataset, test_dataset = (
    datasets.ImageFolder(root=split_dir, transform=transform_dataset)
    for split_dir in ('mnist/train', 'mnist/test')
)

In [7]:
# Only the training data needs shuffling. Accuracy does not depend on sample
# order, so the test set is read in a fixed, reproducible order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False)

In [8]:
# Same folder layout as the plain MNIST dataset, but with the permuted images.
train_perm_dataset, test_perm_dataset = (
    datasets.ImageFolder(root=split_dir, transform=transform_dataset)
    for split_dir in ('mnist-permutated/train', 'mnist-permutated/test')
)

In [9]:
# Only the training data needs shuffling. Accuracy does not depend on sample
# order, so the test set is read in a fixed, reproducible order.
train_perm_loader = torch.utils.data.DataLoader(train_perm_dataset, batch_size=100, shuffle=True)
test_perm_loader = torch.utils.data.DataLoader(test_perm_dataset, batch_size=100, shuffle=False)

In [10]:
nb_epochs = 25
hidden_layers = 512
learning_rate = 0.001
seed = 0
loss_fn = nn.CrossEntropyLoss()

# Train a fresh MLP on each dataset. Re-seeding torch's global RNG before each run
# gives both models identical initial weights and makes the runs repeatable (GPU
# kernels may still add some nondeterminism). The normal/permuted comparison
# therefore differs only in the data.
mlp_accuracies = {}
for dataset_name, (train_dl, val_dl) in [('normal', (train_loader, test_loader)),
                                         ('permuted', (train_perm_loader, test_perm_loader))]:
    torch.manual_seed(seed)
    model_mlp = MLPModel(3*28*28, hidden_layers).to(device)
    optimizer = torch.optim.Adam(model_mlp.parameters(), lr=learning_rate)
    mlp_accuracies[dataset_name] = fit(train_dl, val_dl, model_mlp, optimizer, loss_fn, nb_epochs)
accuracies = mlp_accuracies['normal']
accuracies_perm = mlp_accuracies['permuted']
print("Accuracy of the MLP with the normal dataset: ", accuracies[-1])
print("Accuracy of the MLP with the permuted dataset: ", accuracies_perm[-1])

Epoch 1/25: train_loss: 0.5482, train_accuracy: 81.9005, val_loss: 0.4184, val_accuracy: 86.1400
Epoch 2/25: train_loss: 0.2712, train_accuracy: 91.3955, val_loss: 0.2483, val_accuracy: 92.0000
Epoch 3/25: train_loss: 0.2098, train_accuracy: 93.3856, val_loss: 0.1964, val_accuracy: 93.7700
Epoch 4/25: train_loss: 0.1792, train_accuracy: 94.3831, val_loss: 0.2024, val_accuracy: 93.4400
Epoch 5/25: train_loss: 0.1591, train_accuracy: 95.0075, val_loss: 0.1820, val_accuracy: 94.1800
Epoch 6/25: train_loss: 0.1432, train_accuracy: 95.3731, val_loss: 0.1356, val_accuracy: 95.4900
Epoch 7/25: train_loss: 0.1272, train_accuracy: 95.9726, val_loss: 0.1583, val_accuracy: 95.1400
Epoch 8/25: train_loss: 0.1199, train_accuracy: 96.2289, val_loss: 0.1874, val_accuracy: 94.3800
Epoch 9/25: train_loss: 0.1072, train_accuracy: 96.6020, val_loss: 0.1387, val_accuracy: 95.7900
Epoch 10/25: train_loss: 0.1040, train_accuracy: 96.5398, val_loss: 0.1054, val_accuracy: 96.7900
Epoch 11/25: train_loss: 0.09

In [14]:
nb_epochs = 25
learning_rate = 0.001
seed = 0
pre_defined_kwargs = {'input_dim': (28,28), 'input_channels': 3, 'output_channels': 10}
loss_fn = nn.CrossEntropyLoss()

# Train a fresh CNN on each dataset. Re-seeding torch's global RNG before each run
# gives both models identical initial weights and makes the runs repeatable (GPU
# kernels may still add some nondeterminism). The normal/permuted comparison
# therefore differs only in the data.
cnn_accuracies = {}
for dataset_name, (train_dl, val_dl) in [('normal', (train_loader, test_loader)),
                                         ('permuted', (train_perm_loader, test_perm_loader))]:
    torch.manual_seed(seed)
    model_cnn = PR_CNN(**pre_defined_kwargs).to(device)
    optimizer = torch.optim.Adam(model_cnn.parameters(), lr=learning_rate)
    cnn_accuracies[dataset_name] = fit(train_dl, val_dl, model_cnn, optimizer, loss_fn, nb_epochs)
accuracies = cnn_accuracies['normal']
accuracies_perm = cnn_accuracies['permuted']
print("Accuracy of the CNN with the normal dataset: ", accuracies[-1])
print("Accuracy of the CNN with the permuted dataset: ", accuracies_perm[-1])

Epoch 1/25: train_loss: 0.3030, train_accuracy: 90.8308, val_loss: 0.1242, val_accuracy: 96.0000
Epoch 2/25: train_loss: 0.0940, train_accuracy: 97.0299, val_loss: 0.0728, val_accuracy: 97.6900
Epoch 3/25: train_loss: 0.0599, train_accuracy: 98.0920, val_loss: 0.0702, val_accuracy: 97.6600
Epoch 4/25: train_loss: 0.0418, train_accuracy: 98.6343, val_loss: 0.0682, val_accuracy: 98.0100
Epoch 5/25: train_loss: 0.0284, train_accuracy: 99.1169, val_loss: 0.0726, val_accuracy: 97.8000
Epoch 6/25: train_loss: 0.0231, train_accuracy: 99.2413, val_loss: 0.0709, val_accuracy: 98.0700
Epoch 7/25: train_loss: 0.0158, train_accuracy: 99.4801, val_loss: 0.0727, val_accuracy: 97.9700
Epoch 8/25: train_loss: 0.0157, train_accuracy: 99.4900, val_loss: 0.0642, val_accuracy: 98.2300
Epoch 9/25: train_loss: 0.0148, train_accuracy: 99.4677, val_loss: 0.0724, val_accuracy: 98.0000
Epoch 10/25: train_loss: 0.0151, train_accuracy: 99.4701, val_loss: 0.0695, val_accuracy: 98.0900
Epoch 11/25: train_loss: 0.00

We can observe that, for both the MLP and the CNN, the accuracies obtained with the normal dataset and with the permuted dataset are quite similar. This is not surprising: the images of each class (0 to 9) share similar patterns, which in the normal dataset correspond to the shapes humans naturally recognize (an 8, for example). The permutation scrambles these human-readable shapes. However, the permutation used for this exercise still lets the networks recognize a consistent pattern for each class, which means it transforms each class in a "unique" and separable way. A different kind of permutation could have led to a worse score.