<a href="https://colab.research.google.com/github/ayulockin/debugNNwithWandB/blob/master/MNIST_pytorch_wandb_LRFinder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports and Setups

In [1]:
!pip install wandb -q

[K     |████████████████████████████████| 1.4MB 3.5MB/s 
[K     |████████████████████████████████| 92kB 15.7MB/s 
[K     |████████████████████████████████| 102kB 15.9MB/s 
[K     |████████████████████████████████| 102kB 16.9MB/s 
[K     |████████████████████████████████| 460kB 52.4MB/s 
[K     |████████████████████████████████| 71kB 11.9MB/s 
[K     |████████████████████████████████| 71kB 9.1MB/s 
[?25h  Building wheel for shortuuid (setup.py) ... [?25l[?25hdone
  Building wheel for watchdog (setup.py) ... [?25l[?25hdone
  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for gql (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
  Building wheel for graphql-core (setup.py) ... [?25l[?25hdone


In [0]:
import wandb

In [3]:
!wandb login

[34m[1mwandb[0m: You can find your API key in your browser here: https://app.wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 69f60a7711ce6b8bbae91ac6d15e45d6b1f1430e
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [0]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np

#### For GPU

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


## MNIST Hand written Dataset

In [6]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,))])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')

  0%|          | 16384/9912422 [00:00<01:30, 109230.24it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:00, 22846423.70it/s]                           


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw


32768it [00:00, 316632.92it/s]                           
0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:00, 5356440.85it/s]                           
8192it [00:00, 127880.17it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


## Model



In [0]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)

        self.fc1 = nn.Linear(9216, 128, bias=False)
        self.fc2 = nn.Linear(128, 10, bias=False)

    def forward(self, x):
        ## Conv 1st Block
        x = self.conv1(x)
        x = F.relu(x) 
        x = self.conv2(x)
        x = F.relu(x) 
        x = F.max_pool2d(x, 2)

        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

## Learning Rate Finder with W&B

In [0]:
from torch.optim.lr_scheduler import _LRScheduler

In [0]:
## Reference: https://github.com/davidtvs/pytorch-lr-finder/blob/14abc0b8c3edd95eefa385c2619028e73831622a/torch_lr_finder/lr_finder.py

class LRFinder(object):
    def __init__(self,model,optimizer,device=None,memory_cache=True,cache_dir=None):
        # Check if the optimizer is already attached to a scheduler
        self.optimizer = optimizer
        self.model = model
        self.history = {"lr": [], "loss": []}
        self.best_loss = None
        self.device = device    
    
    def range_test(self,
        train_loader,
        val_loader=None,
        start_lr=None,
        end_lr=10,
        num_iter=100,
        smooth_f=0.05,
        diverge_th=8,
        accumulation_steps=1,
        logwandb=False
    ):
        # Reset test results
        self.history = {"lr": [], "loss": []}
        self.best_loss = None

        # Move the model to the proper device
        self.model.to(self.device)

        # Set the starting learning rate
        if start_lr:
            self._set_learning_rate(start_lr)

        # Initialize the proper learning rate policy
        lr_schedule = ExponentialLR(self.optimizer, end_lr, num_iter)
        
        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError("smooth_f is outside the range [0, 1]")

        # Create an iterator to get data batch by batch
        iter_wrapper = DataLoaderIterWrapper(train_loader)
        
        for iteration in range(num_iter):
            # Train on batch and retrieve loss
            loss = self._train_on_batch(iter_wrapper, accumulation_steps)
    
            # Update the learning rate
            lr_schedule.step()
            self.history["lr"].append(lr_schedule.get_lr()[0])

            # Track the best loss and smooth it if smooth_f is specified
            if iteration == 0:
                self.best_loss = loss
            else:
                if smooth_f > 0:
                    loss = smooth_f * loss + (1 - smooth_f) * self.history["loss"][-1]
                if loss < self.best_loss:
                    self.best_loss = loss

            # Check if the loss has diverged; if it has, stop the test
            self.history["loss"].append(loss)
            
            if logwandb:
              wandb.log({'lr': lr_schedule.get_lr()[0], 'loss': loss})

            if loss > diverge_th * self.best_loss:
                print("Stopping early, the loss has diverged")
                break

        print("Learning rate search finished")

    def _set_learning_rate(self, new_lrs):
        if not isinstance(new_lrs, list):
            new_lrs = [new_lrs] * len(self.optimizer.param_groups)
        if len(new_lrs) != len(self.optimizer.param_groups):
            raise ValueError(
                "Length of `new_lrs` is not equal to the number of parameter groups "
                + "in the given optimizer"
            )

        for param_group, new_lr in zip(self.optimizer.param_groups, new_lrs):
            param_group["lr"] = new_lr

    def _train_on_batch(self, iter_wrapper, accumulation_steps):
        self.model.train()
        total_loss = None  # for late initialization

        self.optimizer.zero_grad()
        for i in range(accumulation_steps):
            inputs, labels = iter_wrapper.get_batch()
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = self.model(inputs)
            loss = F.nll_loss(outputs, labels)

            # Loss should be averaged in each step
            loss /= accumulation_steps

            loss.backward()

            if total_loss is None:
                total_loss = loss.item()
            else:
                total_loss += loss.item()

        self.optimizer.step()

        return total_loss

    def plot(self, skip_start=10, skip_end=5, log_lr=True, show_lr=None):
        if skip_start < 0:
            raise ValueError("skip_start cannot be negative")
        if skip_end < 0:
            raise ValueError("skip_end cannot be negative")
        if show_lr is not None and not isinstance(show_lr, float):
            raise ValueError("show_lr must be float")

        # Get the data to plot from the history dictionary. Also, handle skip_end=0
        # properly so the behaviour is the expected
        lrs = self.history["lr"]
        losses = self.history["loss"]
        if skip_end == 0:
            lrs = lrs[skip_start:]
            losses = losses[skip_start:]
        else:
            lrs = lrs[skip_start:-skip_end]
            losses = losses[skip_start:-skip_end]

        # Plot loss as a function of the learning rate
        plt.plot(lrs, losses)
        if log_lr:
            plt.xscale("log")
        plt.xlabel("Learning rate")
        plt.ylabel("Loss")

        if show_lr is not None:
            plt.axvline(x=show_lr, color="red")
        plt.show()

    def get_best_lr(self):
      lrs = self.history['lr']
      losses = self.history['loss']
      return lrs[losses.index(min(losses))]

class ExponentialLR(_LRScheduler):
    def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
        self.end_lr = end_lr
        self.num_iter = num_iter
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.num_iter
        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]

class DataLoaderIterWrapper(object):
    def __init__(self, data_loader, auto_reset=True):
        self.data_loader = data_loader
        self.auto_reset = auto_reset
        self._iterator = iter(data_loader)

    def __next__(self):
        # Get a new set of inputs and labels
        try:
            inputs, labels = next(self._iterator)
        except StopIteration:
            if not self.auto_reset:
                raise
            self._iterator = iter(self.data_loader)
            inputs, labels = next(self._iterator)

        return inputs, labels

    def get_batch(self):
        return next(self)

## Train-Test loop

In [0]:
def train(model, device, train_loader, optimizer, epoch, steps_per_epoch=20):
  # Switch model to training mode. This is necessary for layers like dropout, batchnorm etc which behave differently in training and evaluation mode
  model.train()
  train_total = 0
  train_correct = 0

  # We loop over the data iterator, and feed the inputs to the network and adjust the weights.
  for batch_idx, (data, target) in enumerate(train_loader, start=0):
    if batch_idx > steps_per_epoch:
      break
    # Load the input features and labels from the training dataset
    data, target = data.to(device), target.to(device)
    
    # Reset the gradients to 0 for all learnable weight parameters
    optimizer.zero_grad()
    
    # Forward pass: Pass image data from training dataset, make predictions about class image belongs to (0-9 in this case)
    output = model(data)
    
    # Define our loss function, and compute the loss
    loss = F.nll_loss(output, target)

    scores, predictions = torch.max(output.data, 1)
    train_total += target.size(0)
    train_correct += int(sum(predictions == target))
            
    # Backward pass: compute the gradients of the loss w.r.t. the model's parameters
    loss.backward()
    
    # Update the neural network weights
    optimizer.step()

  acc = round((train_correct / train_total) * 100, 2)
  print('Epoch [{}], Loss: {}, Accuracy: {}, '.format(epoch, loss.item(), acc), end='')
  wandb.log({'Train Loss': loss.item(), 'Train Accuracy': acc})


In [0]:
def test(model, device, test_loader, classes):
  # Switch model to evaluation mode. This is necessary for layers like dropout, batchnorm etc which behave differently in training and evaluation mode
  model.eval()
  
  test_loss = 0
  test_total = 0
  test_correct = 0

  with torch.no_grad():
      for data, target in test_loader:
          # Load the input features and labels from the test dataset
          data, target = data.to(device), target.to(device)
          
          # Make predictions: Pass image data from test dataset, make predictions about class image belongs to (0-9 in this case)
          output = model(data)
          
          # Compute the loss sum up batch loss
          test_loss += F.nll_loss(output, target, reduction='sum').item()
          
          scores, predictions = torch.max(output.data, 1)
          test_total += target.size(0)
          test_correct += int(sum(predictions == target))
          
  acc = round((test_correct / test_total) * 100, 2)
  print(' Test_loss: {}, Test_accuracy: {}'.format(test_loss/test_total, acc))
  wandb.log({'Test Loss': test_loss/test_total, 'Test Accuracy': acc})


## W&B and Model Init

In [28]:
wandb.init(project='lrfinder')

W&B Run: https://app.wandb.ai/ayush-thakur/lrfinder/runs/y4rs46wh

In [27]:
net = Net().to(device)
print(net)

optimizer = optim.Adam(net.parameters(), lr=1e-9)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=False)
  (fc2): Linear(in_features=128, out_features=10, bias=False)
)


## LR Finder



In [29]:
lr_finder = LRFinder(net, optimizer, device)
lr_finder.range_test(trainloader, end_lr=10, num_iter=100, logwandb=True)

Learning rate search finished


## Train the model

In [31]:
del net
net = Net().to(device)
optimizer = optim.Adam(net.parameters())

wandb.watch(net, log='all')

for epoch in range(10):
  train(net, device, trainloader, optimizer, epoch)
  test(net, device, testloader, classes)

print('Finished Training')

Epoch [0], Loss: 0.23745277523994446, Accuracy: 68.15,  Test_loss: 0.4461846143245697, Test_accuracy: 86.56
Epoch [1], Loss: 0.4042225480079651, Accuracy: 87.13,  Test_loss: 0.34654149870872497, Test_accuracy: 90.29
Epoch [2], Loss: 0.27657175064086914, Accuracy: 90.77,  Test_loss: 0.2819734499692917, Test_accuracy: 91.15
Epoch [3], Loss: 0.30954664945602417, Accuracy: 92.93,  Test_loss: 0.18147338354587556, Test_accuracy: 94.77
Epoch [4], Loss: 0.33351895213127136, Accuracy: 94.64,  Test_loss: 0.20199446659088136, Test_accuracy: 93.5
Epoch [5], Loss: 0.07027415931224823, Accuracy: 95.16,  Test_loss: 0.13558567674160005, Test_accuracy: 96.06
Epoch [6], Loss: 0.13497908413410187, Accuracy: 95.91,  Test_loss: 0.13018470935821533, Test_accuracy: 95.85
Epoch [7], Loss: 0.40916383266448975, Accuracy: 96.13,  Test_loss: 0.11119639990329742, Test_accuracy: 96.59
Epoch [8], Loss: 0.08424792438745499, Accuracy: 96.95,  Test_loss: 0.10213404145240784, Test_accuracy: 97.02
Epoch [9], Loss: 0.0587