<a href="https://colab.research.google.com/github/Bitdribble/dlwpt-code/blob/master/colab/PyTorchCh8_Width.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

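Experimenting with model width (the number of channels in each convolution layer), following [Deep Learning with PyTorch](https://pytorch.org/assets/deep-learning/Deep-Learning-with-PyTorch.pdf), Chapter 8. A small CNN is trained to tell airplanes from birds using the CIFAR-10 images of those two classes.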

In [None]:
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Compact tensor printing, plus a fixed global seed for reproducible runs.
torch.set_printoptions(linewidth=75, edgeitems=2)
torch.manual_seed(123)

<torch._C.Generator at 0x7fde4c3baa90>

In [None]:
# Data preparation: fetch the raw CIFAR-10 training and test splits into
# data_path. (Within this notebook the raw copies are not used again; they
# presumably exist to trigger the download before the normalized datasets
# below are built with download=False.)
data_path = '.'
cifar10, cifar10_val = (
    datasets.CIFAR10(data_path, train=split_is_train, download=True)
    for split_is_train in (True, False)
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

In [None]:
# Normalize data.
# One shared transform: tensor conversion followed by per-channel (R, G, B)
# normalization with hard-coded mean/std (presumably precomputed over the
# CIFAR-10 training split). Sharing a single object guarantees that the
# validation split is normalized with exactly the same statistics as the
# training split.
cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4915, 0.4823, 0.4468),
                         (0.2470, 0.2435, 0.2616)),
])
transformed_cifar10 = datasets.CIFAR10(
    data_path, train=True, download=False, transform=cifar10_transform)
transformed_cifar10_val = datasets.CIFAR10(
    data_path, train=False, download=False, transform=cifar10_transform)

In [None]:
# Restrict data to airplanes and birds.
# label_map maps the original CIFAR-10 class ids to binary labels 0/1.
# Its keys also drive the filter below, so the selected classes and the
# remapping can never drift apart.
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']

cifar2 = [(img, label_map[label])
          for img, label in transformed_cifar10 if label in label_map]
cifar2_val = [(img, label_map[label])
              for img, label in transformed_cifar10_val if label in label_map]

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Training on device {device}.")

In [None]:
def training_loop(n_epochs, device, optimizer, model, loss_fn, train_loader, log_epochs=0):
  """Train `model` in place for `n_epochs` passes over `train_loader`.

  Args:
    n_epochs: number of epochs (full passes over `train_loader`) to run.
    device: torch.device each batch is moved to before the forward pass;
      the model is expected to already live on this device.
    optimizer: optimizer over `model`'s parameters; stepped once per batch.
    model: module to train; it is switched to training mode first.
    loss_fn: callable(outputs, labels) returning a scalar loss tensor.
    train_loader: iterable of (imgs, labels) batches that supports len().
    log_epochs: if non-zero, print the mean training loss every
      `log_epochs` epochs and after the final epoch; 0 disables logging.
  """
  # Training mode matters for layers such as dropout/batch-norm.
  model.train()
  n_batches = len(train_loader)
  for epoch in range(1, n_epochs + 1):
    loss_train = 0.0

    for imgs, labels in train_loader:
      imgs = imgs.to(device=device)
      labels = labels.to(device=device)

      outputs = model(imgs)
      loss = loss_fn(outputs, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      loss_train += loss.item()

    # `epoch` is already 1-based, so it is used directly (no +1 offset);
    # this logs the correct epoch number and always includes the last epoch.
    if log_epochs and (epoch % log_epochs == 0 or epoch == n_epochs):
      # max(..., 1) avoids ZeroDivisionError for an empty loader.
      mean_loss = loss_train / max(n_batches, 1)
      print(f"{datetime.datetime.now()} Epoch {epoch}, "
            f"Training loss {mean_loss:.3f}")

def validate(model, device, train_loader, val_loader):
  """Print the classification accuracy of `model` on the train and val loaders.

  Args:
    model: module whose output is reduced with a max over dim 1 to get the
      predicted class, so it is expected to return per-class scores of shape
      (batch, n_classes).
    device: torch.device each batch is moved to before the forward pass.
    train_loader: iterable of (imgs, labels) batches, reported as "train".
    val_loader: iterable of (imgs, labels) batches, reported as "val".

  The model is put in eval mode for the measurement, and its previous
  train/eval mode is restored afterwards (even if an exception occurs).
  A loader that yields no samples is reported as "n/a" instead of raising
  ZeroDivisionError.
  """
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for name, loader in [("train", train_loader), ("val", val_loader)]:
        correct = 0
        total = 0
        for imgs, labels in loader:
          imgs = imgs.to(device=device)
          labels = labels.to(device=device)

          outputs = model(imgs)
          _, predicted = torch.max(outputs, dim=1)

          total += labels.shape[0]
          correct += int((predicted == labels).sum())

        if total:
          print(f"Accuracy {name}: {correct / total:.2f}")
        else:
          print(f"Accuracy {name}: n/a (no samples)")
  finally:
    model.train(was_training)

In [None]:
# Adding memory capacity: Width
class NetWidth(nn.Module):
  """Two-conv-layer binary classifier whose width is set by `n_chans1`.

  Layout:
    conv(3 -> n_chans1, 3x3, pad 1) -> tanh -> maxpool(2)
    -> conv(n_chans1 -> n_chans1 // 2, 3x3, pad 1) -> tanh -> maxpool(2)
    -> linear(8*8*(n_chans1 // 2) -> 32) -> tanh -> linear(32 -> 2)

  The 8*8 term assumes 3-channel 32x32 inputs: the padded 3x3 convs keep the
  spatial size and the two 2x pools reduce 32x32 to 8x8. Inputs of any other
  size raise a shape error in `fc1` rather than being silently mis-batched.

  Args:
    n_chans1: output channels of the first conv; must be >= 2 so that the
      second conv (n_chans1 // 2 channels) is non-empty.

  Raises:
    ValueError: if n_chans1 < 2.
  """

  def __init__(self, n_chans1):
    super().__init__()
    if n_chans1 < 2:
      # n_chans1 // 2 would be 0, leaving the second conv with no outputs.
      raise ValueError(f"n_chans1 must be >= 2, got {n_chans1}")
    self.n_chans1 = n_chans1
    self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
    self.act1 = nn.Tanh()

    self.pool1 = nn.MaxPool2d(2)
    self.conv2 = nn.Conv2d(n_chans1, n_chans1//2, kernel_size=3, padding=1)
    self.act2 = nn.Tanh()

    self.pool2 = nn.MaxPool2d(2)
    self.fc1 = nn.Linear(8*8*(n_chans1//2), 32)
    self.act3 = nn.Tanh()

    self.fc2 = nn.Linear(32, 2)

  def forward(self, x):
    """Return (batch, 2) class scores for a (batch, 3, 32, 32) input."""
    out = self.pool1(self.act1(self.conv1(x)))
    out = self.pool2(self.act2(self.conv2(out)))
    # Flatten all but the batch dim (equivalent to nn.Flatten()). Unlike
    # view(-1, ...), this never folds spatial positions into the batch dim,
    # so a wrongly sized input fails in fc1 instead of changing batch size.
    out = torch.flatten(out, start_dim=1)
    out = self.act3(self.fc1(out))
    out = self.fc2(out)
    return out

In [None]:
# Width experiment: 32 channels in the first conv layer.
model = NetWidth(n_chans1=32).to(device=device)
optimizer = optim.SGD(model.parameters(), lr=1e-2)
train_loader = torch.utils.data.DataLoader(
    cifar2,
    batch_size=64,
    shuffle=True,
)
loss_fn = nn.CrossEntropyLoss()

# 100 epochs of SGD, logging the mean training loss every 10 epochs.
training_loop(
    n_epochs=100,
    device=device,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    log_epochs=10,
)

In [None]:
# Measure accuracy on both splits; shuffling is not needed for evaluation.
train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)

# Keyword arguments keep each loader bound to the right parameter;
# validate() takes the device as its second argument.
validate(model=model, device=device,
         train_loader=train_loader, val_loader=val_loader)