In [None]:
# %% Deep learning - Section 9.71
#    Dropout regularisation

# This code pertains to a deep learning course provided by Mike X. Cohen on Udemy:
#   > https://www.udemy.com/course/deeplearning_x
# The "base" code in this repository is adapted (with very minor modifications)
# from code developed by the course instructor (Mike X. Cohen), while the
# "exercises" and the "code challenges" contain more original solutions and
# creative input from my side. If you are interested in DL (and if you are
# reading this statement, chances are that you are), go check out the course; it
# is singularly good.


In [None]:
# %% Libraries and modules
import numpy               as np
import matplotlib.pyplot   as plt
import torch
import torch.nn            as nn
import seaborn             as sns
import copy
import torch.nn.functional as F

from torch.utils.data                 import DataLoader,TensorDataset
from sklearn.model_selection          import train_test_split
from google.colab                     import files
from torchsummary                     import summary
from matplotlib_inline.backend_inline import set_matplotlib_formats
# Render inline matplotlib figures as SVG (vector graphics) instead of PNG
set_matplotlib_formats('svg')


In [None]:
# %% Define dropout probability and some synthetic data

# Seed torch so the random drop mask (and every later random cell) is
# reproducible under Restart Kernel -> Run All
torch.manual_seed(0)

prob    = 0.5
dropout = nn.Dropout(p=prob)
x       = torch.ones(10)

# A freshly created nn.Dropout is in training mode: each element is zeroed
# with probability `prob` and the surviving elements are scaled up by
# 1/(1-prob), so the output keeps the input's expected value
y = dropout(x)

print(x)
print(y)
print(torch.mean(y))   # equals 1 only in expectation, not for every draw


In [None]:
# %% Dropout is turned off when evaluating the model

# .eval() puts the module into evaluation mode and returns the module itself,
# so it can be called directly; in this mode nn.Dropout returns its input
# unchanged (no zeroing, no rescaling)
y = dropout.eval()(x)

print(x, y, torch.mean(y), sep='\n')


In [None]:
# %% Alternate way: use F.dropout()

# F.dropout() is a plain function: it does not read any module's train/eval
# flag, so putting `dropout` into .eval() has no effect on it ...
dropout.eval()
# Pass p=prob explicitly: F.dropout has its own default p, which would
# otherwise silently diverge from the module's probability if `prob` changes.
# `training` defaults to True, so this call still drops and rescales.
y = F.dropout(x, p=prob)

print(y)
print(torch.mean(y))

# ... instead, the mode must be passed explicitly via `training`; tying it to
# the module's flag (False after .eval()) mirrors what nn.Dropout does
y = F.dropout(x, p=prob, training=dropout.training)

print(y)
print(torch.mean(y))


In [None]:
# %% The mode flag persists: switch back with .train() after using .eval()

# .train() / .eval() return the module itself, so the call can be chained.
# Training mode: random elements are zeroed and the survivors rescaled
y = dropout.train()(x)
print(y)

# Evaluation mode: the input passes through unchanged
y = dropout.eval()(x)
print(y)

# Back to training mode -- replace with `y = dropout(x)` to check that the
# module would otherwise stay in evaluation mode
y = dropout.train()(x)
print(y)
