In [1]:
# import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F

# Using dropout

In [7]:
# define a dropout instance and make some data
SEED = 0  # fix the RNG so the random dropout masks in this notebook are reproducible on Restart & Run All
torch.manual_seed(SEED)

prob = .5

dropout = nn.Dropout(p=prob)
x = torch.ones(10)

# let's see what dropout returns
y = dropout(x)
print(x, "\n")
print(y, "\n") # during training PyTorch scales kept elements UP by 1/(1-p), so with p=0.5 each surviving 1 becomes 2
print(y*(1-prob), "\n") # multiplying by (1-p) undoes that scale-up and reveals the raw 0/1 keep mask (the "scale down" formulation)
print(torch.mean(y)) # thanks to the 1/(1-p) rescaling the expected mean is 1; a single random draw varies around it

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]) 

tensor([0., 0., 0., 2., 0., 0., 2., 2., 2., 0.]) 

tensor([0., 0., 0., 1., 0., 0., 1., 1., 1., 0.]) 

tensor(0.8000)


In [3]:
# nn.Dropout is turned off when evaluating the model
dropout.eval() # switch the module to evaluation mode (dropout.training becomes False)
y = dropout(x)
print(y, "\n") # the layer is still called, but in eval mode nn.Dropout returns its input unchanged
print(torch.mean(y))
assert torch.equal(y, x)  # sanity check: eval-mode dropout is the identity

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]) 

tensor(1.)


In [9]:
# F.dropout() is a plain function: it has no train/eval state of its own and is
# not connected to the `dropout` module, so dropout.eval() cannot switch it off.
# Its `training` argument defaults to True, so it drops elements unless told otherwise.

dropout.eval() # only changes the module's mode; it has no effect on the functional call below
y = F.dropout(x, p=prob) # pass p explicitly so it stays in sync with `prob` (F.dropout's own default is 0.5)
print(y) # NOTICE: though the module is in evaluation mode, dropout is still applied
print(torch.mean(y))

tensor([0., 0., 0., 0., 2., 0., 0., 2., 0., 0.])
tensor(0.4000)


In [5]:
# but you can switch it off manually by passing training=False
# (inside a custom nn.Module, pass training=self.training so the functional
# call follows the module's .train()/.eval() mode)
y = F.dropout(x, p=prob, training=False)

print(y) # NOTICE: compare with the cell above; with training=False the input passes through unchanged
print(torch.mean(y))

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor(1.)


In [11]:
# the train/eval mode is sticky: after switching a module to eval mode it stays
# there until .train() is called again

dropout.train() # switch to training mode
y = dropout(x)
print(y, "\n") # with dropout


dropout.eval() # switch to evaluation mode
y = dropout(x)
print(y, "\n") # without dropout


# dropout.train()  # deliberately left commented out to show that the mode persists
y = dropout(x) # no .train() call since the eval() above, so evaluation mode is still active
print(y) # still w/o dropout ;)
assert not dropout.training  # confirms the module is still in eval mode

tensor([0., 2., 2., 0., 2., 2., 0., 2., 2., 2.]) 

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]) 

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
