# GoldenDropout Demo

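This notebook compares GoldenDropout — inverted dropout that keeps each activation with probability 1/φ ≈ 0.618 (φ is the golden ratio), i.e. a drop rate of about 0.382 — against standard `nn.Dropout(p=0.5)` on a simple MNIST classifier.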

Run the cells in order to train both models, plot their per-epoch training loss, and compare their test accuracy.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

class GoldenDropout(nn.Module):
    """Inverted dropout whose keep probability defaults to 1/phi (phi = golden ratio).

    In training mode each input element is kept with probability ``keep_prob``
    (1/phi ~= 0.618 by default, i.e. a drop rate of ~0.382). The surviving
    elements are multiplied by ``scale = 1 / keep_prob`` (phi by default), so
    the expected activation is unchanged. In eval mode the input is returned
    unchanged.

    Args:
        keep_prob: optional keep probability in ``(0, 1]``. ``None`` (the
            default) uses ``1 / phi``.

    Raises:
        ValueError: if ``keep_prob`` is outside ``(0, 1]`` (or is NaN).
    """

    def __init__(self, keep_prob=None):
        super().__init__()
        phi = (1 + 5 ** 0.5) / 2
        if keep_prob is None:
            keep_prob = 1.0 / phi
        keep_prob = float(keep_prob)
        if not 0.0 < keep_prob <= 1.0:
            raise ValueError(f'keep_prob must be in (0, 1], got {keep_prob}')
        # Keep these as plain Python floats so torch.tensor() below receives a
        # scalar rather than copy-constructing from a tensor (which warns).
        self.keep_prob = keep_prob
        self.scale = 1.0 / keep_prob
        # Buffers move with the module on .to(device) and are saved in state_dict.
        self.register_buffer('keep_prob_buffer', torch.tensor(self.keep_prob))
        self.register_buffer('scale_buffer', torch.tensor(self.scale))

    def forward(self, x):
        if not self.training:
            return x
        # Keep an element where a uniform [0, 1) draw falls below keep_prob,
        # then rescale the survivors so the expectation matches eval mode.
        mask = (torch.rand_like(x) < self.keep_prob_buffer).to(x.dtype)
        return x * mask * self.scale_buffer

    def extra_repr(self):
        return f'keep_prob={self.keep_prob:.4f}'

# Simple NN
class Net(nn.Module):
    """MLP: 784 inputs -> 128 hidden (ReLU) -> dropout -> 10 logits.

    Inputs are flattened to rows of 784 features (e.g. 1x28x28 MNIST images).

    Args:
        dropout_type: ``'standard'`` selects ``nn.Dropout(0.5)``; ``'golden'``
            selects ``GoldenDropout()``.

    Raises:
        ValueError: if ``dropout_type`` is neither ``'standard'`` nor ``'golden'``.
    """

    def __init__(self, dropout_type='standard'):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        if dropout_type == 'standard':
            self.dropout = nn.Dropout(0.5)
        elif dropout_type == 'golden':
            self.dropout = GoldenDropout()
        else:
            # Reject typos instead of silently falling back to GoldenDropout.
            raise ValueError(
                f"dropout_type must be 'standard' or 'golden', got {dropout_type!r}"
            )
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, 784)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

# Data: MNIST train/test splits, each image converted to a tensor by ToTensor().
transform = transforms.ToTensor()
_mnist_common = {'root': './data', 'download': True, 'transform': transform}

trainset = torchvision.datasets.MNIST(train=True, **_mnist_common)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(train=False, **_mnist_common)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)

In [None]:
def train(net, epochs=5, loader=None):
    """Train ``net`` with Adam and cross-entropy loss; return mean loss per epoch.

    Args:
        net: model to optimise. It is switched to training mode first so that
            dropout is active even if it was previously left in eval mode.
        epochs: number of passes over ``loader``.
        loader: iterable of ``(inputs, labels)`` batches. ``None`` (the
            default) uses the module-level ``trainloader``.

    Returns:
        A list with one float per epoch: the average batch loss of that epoch.

    Raises:
        ValueError: if ``loader`` yields no batches in an epoch.
    """
    if loader is None:
        loader = trainloader
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())
    net.train()
    losses = []
    for _ in range(epochs):
        running_loss = 0.0
        n_batches = 0  # counted, so loaders without __len__ also work
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            raise ValueError('loader yielded no batches')
        losses.append(running_loss / n_batches)
    return losses

# Reset the RNG before building each model. Both nets create fc1/dropout/fc2 in
# the same order, so they start from identical initial weights and differ only
# in their dropout layer. The fixed seed also makes reruns repeatable.
SEED = 0

# Train standard
torch.manual_seed(SEED)
net_standard = Net('standard')
losses_standard = train(net_standard)

# Train golden
torch.manual_seed(SEED)
net_golden = Net('golden')
losses_golden = train(net_golden)

# Plot mean training loss per epoch, with epochs numbered from 1.
plt.plot(range(1, len(losses_standard) + 1), losses_standard, label='Standard Dropout')
plt.plot(range(1, len(losses_golden) + 1), losses_golden, label='Golden Dropout')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
def accuracy(net, loader):
    """Return the fraction of samples in ``loader`` that ``net`` classifies correctly.

    The model runs in eval mode (so dropout is disabled) under
    ``torch.no_grad()``. Its previous train/eval mode is restored afterwards,
    even if iteration raises.

    Args:
        net: classifier whose outputs are per-class scores along dim 1.
        loader: iterable of ``(inputs, labels)`` batches.

    Returns:
        ``correct / total`` as a float in ``[0, 1]``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in loader:
                predicted = net(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)
    if total == 0:
        raise ValueError('loader yielded no samples')
    return correct / total

# Report test-set accuracy for each trained model, standard first.
for _caption, _model in (('Standard Accuracy:', net_standard),
                         ('Golden Accuracy:', net_golden)):
    print(_caption, accuracy(_model, testloader))