## Dropout

Dropout is a regularization method that reduces overfitting by randomly dropping neurons (setting their outputs to zero) during training.

![](./dropout.png)

$r_j^{(l)} \sim \text{Bernoulli}(p)$, where $r_j^{(l)} \in \{0, 1\}$ is the mask entry for neuron $j$ in layer $l$ and $p$ is the probability that the neuron is kept (activated).
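Written out for a fully connected layer, the forward pass with dropout can be sketched as follows. This is the standard formulation from the original dropout paper (Srivastava et al., 2014); the symbols $\mathbf{y}^{(l)}$ (outputs of layer $l$), $\mathbf{w}_i^{(l+1)}$, $b_i^{(l+1)}$ and the activation $f$ are assumed notation that these notes do not define elsewhere, and $*$ denotes element-wise multiplication.

$$
\begin{aligned}
r_j^{(l)} &\sim \text{Bernoulli}(p), \\
\tilde{\mathbf{y}}^{(l)} &= \mathbf{r}^{(l)} * \mathbf{y}^{(l)}, \\
z_i^{(l+1)} &= \mathbf{w}_i^{(l+1)} \tilde{\mathbf{y}}^{(l)} + b_i^{(l+1)}, \\
y_i^{(l+1)} &= f\big(z_i^{(l+1)}\big).
\end{aligned}
$$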

## Dropout during training

1. Randomly generate a mask $u$ (one Bernoulli($p$) entry per neuron) that determines which neurons are kept and which are dropped; dropped neurons output zero. Applying the mask yields a subnetwork;
2. Using a minibatch, run forward propagation through this subnetwork, compute the loss, and run backward propagation to update its weights.
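The training steps above can be illustrated on a toy example: sample a Bernoulli mask, apply it to one layer's activations, and take one gradient step for a single linear unit. This is a minimal sketch in plain Python, not the notebook's implementation; `sample_mask`, `apply_mask`, `sgd_step` and the squared-error loss are hypothetical choices made only to show that the weights attached to dropped neurons receive zero gradient, so only the sampled subnetwork is updated.

```python
import random

def sample_mask(n, p, rng):
    """Draw one mask entry r_j ~ Bernoulli(p) per unit: 1 = kept, 0 = dropped."""
    return [1 if rng.random() < p else 0 for _ in range(n)]

def apply_mask(a, mask):
    """Dropped units output 0; kept units pass through unchanged."""
    return [r * a_j for r, a_j in zip(mask, a)]

def sgd_step(w, a_masked, target, lr):
    """One SGD step on the squared loss (z - target)^2 of a linear unit
    z = sum_j w_j * a_masked_j (hypothetical toy objective).
    dL/dw_j = 2 * (z - target) * a_masked_j is 0 for every dropped unit,
    so only the weights of the sampled subnetwork change."""
    z = sum(w_j * a_j for w_j, a_j in zip(w, a_masked))
    grad = [2 * (z - target) * a_j for a_j in a_masked]
    return [w_j - lr * g_j for w_j, g_j in zip(w, grad)]

rng = random.Random(0)
a = [0.5, 1.2, -0.3, 2.0]  # activations of the previous layer for one example
w = [0.1, 0.2, 0.3, 0.4]   # weights of the next (linear) unit
mask = sample_mask(len(a), p=0.5, rng=rng)
new_w = sgd_step(w, apply_mask(a, mask), target=1.0, lr=0.1)
for r, w_old, w_new in zip(mask, w, new_w):
    print(r, w_old, w_new)  # rows with r == 0 keep their old weight
```

A real network repeats this with a fresh mask for each forward pass, so over many minibatches different subnetworks get trained.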

## Dropout during testing

1. Run forward propagation through the whole network, with no neurons dropped;
2. Rescale each neuron's output by the keep probability $p$,
$$a = p \cdot a,$$
so that the expected input to the next layer is the same as during training.
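The rescaling rule above can be checked numerically: averaged over many random masks, the training-time output of a unit equals its rescaled test-time output $p \cdot a$. The sketch below (plain Python, hypothetical function names) also shows the common "inverted" variant, which divides kept outputs by $p$ during training so that no rescaling is needed at test time.

```python
import random

def dropout_forward(a, p, training, rng):
    """Classic dropout as described above (p = keep probability).
    Training: each unit is kept with probability p, otherwise it outputs 0.
    Testing: every unit is kept and its output is rescaled by p."""
    if training:
        return [a_j if rng.random() < p else 0.0 for a_j in a]
    return [p * a_j for a_j in a]

def inverted_dropout_forward(a, p, training, rng):
    """Inverted dropout: kept outputs are divided by p during training,
    so the test-time forward pass is the identity."""
    if training:
        return [a_j / p if rng.random() < p else 0.0 for a_j in a]
    return list(a)

def mean_training_output(forward, a, p, trials, rng):
    """Average the training-mode outputs over many independently sampled masks."""
    totals = [0.0] * len(a)
    for _ in range(trials):
        for j, out in enumerate(forward(a, p, True, rng)):
            totals[j] += out
    return [t / trials for t in totals]

a, p = [1.0, -2.0, 3.0], 0.8
print(mean_training_output(dropout_forward, a, p, 20000, random.Random(0)))
print(dropout_forward(a, p, False, random.Random(0)))  # close to the averages above
```

PyTorch's built-in `torch.nn.Dropout` uses the inverted variant; note that its argument `p` is the probability of zeroing an element, not of keeping it.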

## Perspectives on dropout

### Hinton's view on dropout

Hinton views the final model as an ensemble of subnets.

All subnets share weights, since each subnet inherits a subset of the weights of the whole network. Most of the $2^n$ possible subnets (for $n$ droppable neurons) are never trained explicitly, but weight sharing gives the remaining subnets good weight settings as well.

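Let $J(\theta, u)$ be the objective of the subnet selected by the mask $u$. The objective of the whole ensemble is then the expected loss over the mask distribution $q(u)$:
$$\mathbb{E}_{u\sim q(u)}\, J(\theta, u)$$
Each training step samples a mask $u$ and takes a gradient step on $J(\theta, u)$ for that single mask, which is an unbiased stochastic estimate of the gradient of this expectation.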
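For a tiny model the expectation over masks can be computed exactly by enumerating all $2^n$ masks and compared with the average loss of randomly sampled masks, which is what training effectively estimates. The objective `subnet_loss` (squared error of one linear unit with masked inputs) and the parameter values are hypothetical, chosen only for illustration.

```python
import itertools
import math
import random

def subnet_loss(theta, mask, x, y):
    """Hypothetical toy J(theta, u): squared error of a linear unit whose inputs are masked by u."""
    z = sum(t * m * xi for t, m, xi in zip(theta, mask, x))
    return (z - y) ** 2

def expected_loss_exact(theta, x, y, p):
    """E_{u ~ q(u)} J(theta, u) over all 2^n masks, with q(u) = prod_j p^{u_j} (1 - p)^{1 - u_j}."""
    total = 0.0
    for mask in itertools.product([0, 1], repeat=len(x)):
        q = math.prod(p if m else 1 - p for m in mask)
        total += q * subnet_loss(theta, mask, x, y)
    return total

def expected_loss_sampled(theta, x, y, p, num_samples, rng):
    """Monte Carlo estimate: one random mask per sample, as in one training step."""
    total = 0.0
    for _ in range(num_samples):
        mask = [1 if rng.random() < p else 0 for _ in x]
        total += subnet_loss(theta, mask, x, y)
    return total / num_samples

theta, x, y, p = [0.5, -1.0, 2.0], [1.0, 0.5, 1.5], 1.0, 0.8
print(expected_loss_exact(theta, x, y, p))  # about 3.48
print(expected_loss_sampled(theta, x, y, p, 20000, random.Random(0)))
```

Enumeration is only feasible for a handful of units; real networks rely on the sampled estimate.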


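**Dropout is an approximation to geometric model averaging.** Rescaling by $p$ at test time approximates the normalized geometric mean of the predictions of all subnets, each weighted by $q(u)$; for a single softmax layer with dropout on its inputs, the two are exactly equal.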
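The exact case can be verified by brute force: for a softmax classifier with dropout on its inputs, the $q(u)$-weighted, renormalized geometric mean over all $2^n$ subnets matches a single forward pass with inputs scaled by $p$. This works because each subnet's log-normalizer does not depend on the class and cancels after renormalization. The weights, bias and input values below are hypothetical.

```python
import itertools
import math

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def logits(W, b, x):
    """Class scores W x + b; W holds one row of weights per class."""
    return [sum(w * xi for w, xi in zip(row, x)) + b_k for row, b_k in zip(W, b)]

def geometric_ensemble(W, b, x, p):
    """Normalized geometric mean of the softmax predictions of all 2^n subnets
    (input masks), each weighted by its Bernoulli(p) probability q(u)."""
    log_scores = [0.0] * len(b)
    for mask in itertools.product([0, 1], repeat=len(x)):
        q = math.prod(p if m else 1 - p for m in mask)
        probs = softmax(logits(W, b, [m * xi for m, xi in zip(mask, x)]))
        for k, prob in enumerate(probs):
            log_scores[k] += q * math.log(prob)
    return softmax(log_scores)  # exponentiate and renormalize over classes

def weight_scaled(W, b, x, p):
    """Test-time rule: one pass over the full model with the dropped inputs scaled by p."""
    return softmax(logits(W, b, [p * xi for xi in x]))

W = [[0.5, -1.0, 2.0], [1.5, 0.3, -0.7], [-0.2, 0.8, 0.1]]
b = [0.1, -0.2, 0.0]
x = [1.0, 2.0, -0.5]
print(geometric_ensemble(W, b, x, 0.8))
print(weight_scaled(W, b, x, 0.8))  # agrees with the line above up to float rounding
```

With hidden nonlinearities between the dropped units and the softmax, this equality no longer holds, which is why the rule is an approximation in general.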

### Dropout and feature selection

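Because any neuron may be dropped at any step, the network cannot rely solely on a few highly effective features; dropout encourages it to also learn from features that are less effective than other neurons' features.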


```python
import sys
import importlib

import torch
from torch import nn

# Helper modules (model, dataset, train) live in the local ../dlutils directory
sys.path.append("../dlutils")
import model
import dataset
import train

# Reload the helpers so that edits to them take effect without restarting the kernel
for module in (model, dataset, train):
    importlib.reload(module)

from dataset import load_fashion_mnist_dataset
from train import train_3ch
```

## Load dataset

```python
batch_size = 256
train_loader, test_loader = load_fashion_mnist_dataset(batch_size=batch_size)
```

## The Model

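```python
from model import Dropout  # custom dropout layer from ../dlutils

# Probability of zeroing (dropping) a unit in each dropout layer;
# the keep probability used in the text is p = 1 - p_zeroed
p1_zeroed, p2_zeroed = 0.2, 0.5

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(), nn.Linear(784, 256), nn.ReLU(),
            # Add a dropout layer after the first fully connected layer
            Dropout(p1_zeroed),
            nn.Linear(256, 256), nn.ReLU(),
            # Add a dropout layer after the second fully connected layer
            Dropout(p2_zeroed),
            nn.Linear(256, 10))

    def forward(self, X):
        return self.net(X)

    def train(self, mode=True):
        # Keep nn.Module's signature: train(mode=True) returns self.
        # Assumes model.Dropout reads a boolean `train` attribute (its source is not shown);
        # `training` is set as well in case it follows the standard nn.Module flag.
        self.training = mode
        for subnet in self.modules():
            if isinstance(subnet, Dropout):
                subnet.train = mode
                subnet.training = mode
        return self

    def eval(self):
        return self.train(False)

net = Net()
net.train()
```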
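Since the source of `model.Dropout` is not shown, here is one hypothetical way such a layer could be written. `SketchDropout` and `p_zeroed` are illustrative names, not the helper's actual API. It implements inverted dropout and reads the standard `self.training` flag, so the inherited `nn.Module.train()` / `eval()` switch it without any custom overrides; the built-in `torch.nn.Dropout(p)` behaves the same way, with `p` the probability of zeroing.

```python
import torch
from torch import nn

class SketchDropout(nn.Module):
    """Hypothetical stand-in for model.Dropout using inverted dropout.
    p_zeroed is the probability of dropping a unit, as in p1_zeroed / p2_zeroed."""

    def __init__(self, p_zeroed):
        super().__init__()
        if not 0.0 <= p_zeroed < 1.0:
            raise ValueError("p_zeroed must be in [0, 1)")
        self.p_zeroed = p_zeroed

    def forward(self, X):
        if not self.training or self.p_zeroed == 0.0:
            return X  # test time: inverted dropout needs no rescaling
        keep = 1.0 - self.p_zeroed
        mask = (torch.rand_like(X) < keep).to(X.dtype)  # 1 = kept, 0 = dropped
        return mask * X / keep  # rescale kept units so the expected output equals X

torch.manual_seed(0)
drop = SketchDropout(0.5)
x = torch.ones(2, 8)
drop.train()
print(drop(x))  # entries are 0 or 2 (= 1 / keep probability)
drop.eval()
print(drop(x))  # input returned unchanged
```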

## Training

```python
num_epochs = 10
lr = 0.1
loss = nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=lr)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# train_3ch is the training-loop helper already imported from ../dlutils
train_3ch(net, loss, num_epochs, train_loader, optimizer=trainer, test_loader=test_loader, device=device)
```


epoch 0, training loss 0.003167, training accuracy 0.695267, testing loss 0.003320, testing accuracy 0.689200
epoch 1, training loss 0.002960, training accuracy 0.719917, testing loss 0.003045, testing accuracy 0.712200
epoch 2, training loss 0.002778, training accuracy 0.737333, testing loss 0.002902, testing accuracy 0.730300
epoch 3, training loss 0.002679, training accuracy 0.750467, testing loss 0.002807, testing accuracy 0.743100
epoch 4, training loss 0.002594, training accuracy 0.760617, testing loss 0.002679, testing accuracy 0.757700
epoch 5, training loss 0.002465, training accuracy 0.774533, testing loss 0.002602, testing accuracy 0.756700
epoch 6, training loss 0.002407, training accuracy 0.779550, testing loss 0.002487, testing accuracy 0.778600
epoch 7, training loss 0.002366, training accuracy 0.782867, testing loss 0.002488, testing accuracy 0.776300
epoch 8, training loss 0.002284, training accuracy 0.792217, testing loss 0.002411, testing accuracy 0.780700
epoch 9, t