# CNN with Optuna
Hyperparameter tuning for a simple CNN on MNIST.

In [None]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import optuna

In [None]:
# MNIST training split, downloaded into ./data on first use; each image is
# converted to a tensor by torchvision's ToTensor transform.
train_ds = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Shuffled mini-batches of 64 samples for training.
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=64,
    shuffle=True,
)

In [None]:
class Net(nn.Module):
    """Small CNN classifier: conv -> ReLU -> 2x2 max-pool -> dropout -> linear.

    The linear layer is sized for single-channel 28x28 inputs (as produced by
    the MNIST loader in this notebook): the 3x3 convolution yields 32 maps of
    26x26, the 2x2 pooling halves that to 13x13, and 32 * 13 * 13 = 5408.

    Args:
        dropout: Drop probability applied to the flattened features before
            the final linear layer.
    """

    def __init__(self, dropout=0.2):
        super().__init__()
        self.conv = nn.Conv2d(1, 32, 3, 1)
        self.fc1 = nn.Linear(5408, 10)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the 10 class logits for a batch of images ``x``."""
        x = torch.relu(self.conv(x))
        # Without this pooling step the flattened feature size would be
        # 32 * 26 * 26 = 21632, which does not match fc1's 5408 inputs.
        # The functional form keeps the state_dict layout unchanged.
        x = nn.functional.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        return self.fc1(x)

In [None]:
def objective(trial):
    """Optuna objective: train a fresh ``Net`` and return its training loss.

    Samples ``dropout`` (uniform in [0.1, 0.5]) and ``lr`` (log-uniform in
    [1e-4, 1e-2]) from ``trial``, trains for three epochs over the
    module-level ``train_loader`` with Adam and cross-entropy loss, and
    returns the mean per-sample loss of the final epoch.

    Args:
        trial: The ``optuna`` trial used to sample hyperparameters.

    Returns:
        Mean training loss over the last epoch, as a Python float.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    dropout = trial.suggest_float('dropout', 0.1, 0.5)
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)

    model = Net(dropout)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    n_epochs = 3
    epoch_loss = float('inf')
    model.train()  # make sure dropout is active while fitting
    for epoch in range(n_epochs):
        total_loss = 0.0
        n_samples = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            out = model(data)
            loss = loss_fn(out, target)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch does not skew the mean.
            batch_size = target.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
        if n_samples == 0:
            raise ValueError('train_loader yielded no samples')
        epoch_loss = total_loss / n_samples

    # Averaging over the whole epoch gives a far less noisy objective than the
    # loss of whichever mini-batch happened to be drawn last.
    return epoch_loss

In [None]:
# Search for the hyperparameters that minimise the value returned by
# `objective`, using 10 trials, then report the best combination found.
study = optuna.create_study(direction='minimize')
study.optimize(func=objective, n_trials=10)

best_params = study.best_params
print('Best params:', best_params)

In [None]:
# `best_params` holds both 'dropout' and 'lr'. Net only accepts `dropout`, so
# unpacking the whole dict into Net(...) raises TypeError; `lr` belongs to the
# optimizer instead.
final_model = Net(dropout=best_params['dropout'])
final_optimizer = optim.Adam(final_model.parameters(), lr=best_params['lr'])
final_loss_fn = nn.CrossEntropyLoss()

# Retrain with the best hyperparameters (same schedule as the tuning trials)
# so the checkpoint contains trained weights rather than a random init.
final_model.train()
for _ in range(3):
    for data, target in train_loader:
        final_optimizer.zero_grad()
        final_loss = final_loss_fn(final_model(data), target)
        final_loss.backward()
        final_optimizer.step()

torch.save(final_model.state_dict(), 'best_cnn.pt')