In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchsummaryX import summary
import numpy as np
import time
from tqdm.notebook import trange, tqdm

In [2]:
device = torch.device('cuda:0' \
                      if torch.cuda.is_available() \
                      else 'cpu')
device

from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
])

dataset = torchvision.datasets.mnist.MNIST('/workspace/datasets/', download=True, transform=transform)

# create a random split of the dataset, 80% for training and 20% for validation
mnist, val = random_split(dataset, [int(len(dataset)*0.8), int(len(dataset)*0.2)])

In [3]:
class Residual_MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
    def forward(self, x):
        return self.mlp(x) + x

In [4]:
sample = torch.randn(1, 10)
L = Residual_MLP(10, 100)
print("x.shape", sample.shape)
print("L(x).shape", L(sample).shape)

x.shape torch.Size([1, 10])
L(x).shape torch.Size([1, 10])


In [5]:
N1 = nn.Sequential(
    nn.Flatten(),
    nn.LazyLinear(10),
)



In [6]:
def train(model, dataset, epochs):
    optimizer = torch.optim.Adam(model.parameters())
    loss = nn.CrossEntropyLoss()
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
    model = model.to(device)
    for epoch in trange(epochs):
        start = time.time()
        for (xs, targets) in tqdm(dataloader):
            xs, targets = xs.to(device), targets.to(device)
            ys = model(xs)
            optimizer.zero_grad()
            l = loss(ys, targets)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                acc = (ys.argmax(axis=1) == targets).sum() / xs.shape[0]
        duration = time.time() - start
        print("[%d] acc = %.2f loss = %.4f in %.2f seconds." % (epoch, acc.item(), l.item(), duration))

In [7]:
train(N1, mnist, 1)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/375 [00:00<?, ?it/s]

[0] acc = 0.88 loss = 0.4785 in 13.26 seconds.


In [8]:
N2 = nn.Sequential(
    N1,
    Residual_MLP(10, 5),
)

train(N2, mnist, 1)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/375 [00:00<?, ?it/s]

[0] acc = 0.94 loss = 0.2226 in 6.48 seconds.


In [13]:
N3 = nn.Sequential(
    N2,
    Residual_MLP(10, 5),
    
)
train(N3, mnist, 1)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/375 [00:00<?, ?it/s]

[0] acc = 0.95 loss = 0.1391 in 6.75 seconds.


In [10]:
N4 = nn.Sequential(
    N3,
    Residual_MLP(10, 5),
)
train(N4, mnist, 1)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/375 [00:00<?, ?it/s]

[0] acc = 0.90 loss = 0.3728 in 7.02 seconds.


In [11]:
print(N2)

Sequential(
  (0): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=10, bias=True)
  )
  (1): Residual_MLP(
    (mlp): Sequential(
      (0): Linear(in_features=10, out_features=5, bias=True)
      (1): ReLU()
      (2): Linear(in_features=5, out_features=10, bias=True)
    )
  )
)


In [14]:
print(N3)

Sequential(
  (0): Sequential(
    (0): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=784, out_features=10, bias=True)
    )
    (1): Residual_MLP(
      (mlp): Sequential(
        (0): Linear(in_features=10, out_features=5, bias=True)
        (1): ReLU()
        (2): Linear(in_features=5, out_features=10, bias=True)
      )
    )
  )
  (1): Residual_MLP(
    (mlp): Sequential(
      (0): Linear(in_features=10, out_features=5, bias=True)
      (1): ReLU()
      (2): Linear(in_features=5, out_features=10, bias=True)
    )
  )
)
