In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
from tqdm import tqdm
import plotly.io

import optuna

In [2]:
# Preprocessing shared by both splits: convert images to tensors, then
# standardize with the mean/std computed on the MNIST training data.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.1307, ), std=(0.3081, ), inplace=True),
])


def _mnist_loader(train, shuffle):
    """Return a DataLoader over one MNIST split (downloaded to ./data if missing).

    Both splits share the `transforms` pipeline above, a batch size of 100,
    pinned host memory and a single worker process.
    """
    dataset = torchvision.datasets.MNIST(
        root='data',
        train=train,
        transform=transforms,
        download=True,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=100,
        shuffle=shuffle,
        pin_memory=True,
        num_workers=1,
    )


train_dataloader = _mnist_loader(train=True, shuffle=True)
# NOTE(review): train=False is MNIST's official test split. It is used as the
# validation set for the hyperparameter search, so the reported accuracies are
# not an unbiased test estimate.
val_dataloader = _mnist_loader(train=False, shuffle=False)

In [3]:
class CNN(nn.Module):
    """Stack of conv blocks followed by a linear classifier, with an analytic FLOP estimate.

    Each block is ``Conv2d -> ReLU -> AvgPool2d(2) -> Dropout2d``. The conv uses
    stride 1 and ``padding=kernel_size // 2``. For odd kernels this keeps the
    spatial size, so only the pooling halves it.

    Args:
        num_layers: number of conv blocks. 0 feeds the raw input straight to the classifier.
        num_filters: output channels of every conv layer.
        kernel_size: square convolution kernel size.
        input_channels: channels of the input images (1 for MNIST).
        input_size: spatial input size, either an int (square) or an ``(H, W)`` pair.
        dropout: channel-dropout probability used after every block.
        num_classes: number of output logits.

    Attributes:
        flops: approximate FLOPs for one sample's forward pass. This is 2x the
            multiply-accumulates of every conv layer and of the final linear layer.
            Bias adds, activations, pooling and dropout are not counted.

    Raises:
        ValueError: if the feature map becomes too small for 2x2 pooling
            (e.g. more than 4 blocks on a 28x28 input).
    """

    def __init__(self, num_layers, num_filters, kernel_size,
                 input_channels=1, input_size=28, dropout=0.2, num_classes=10):
        super().__init__()
        if isinstance(input_size, int):
            H, W = input_size, input_size
        else:
            H, W = input_size
        padding = kernel_size // 2
        in_channels = input_channels
        layers = []
        self.flops = 0

        for block in range(num_layers):
            # Output size of a stride-1 conv. It equals (H, W) for odd kernels
            # and grows by 1 for even kernels, because padding = kernel_size // 2.
            H_conv = H + 2 * padding - kernel_size + 1
            W_conv = W + 2 * padding - kernel_size + 1
            if H_conv < 2 or W_conv < 2:
                raise ValueError(
                    f"Feature map of size {H_conv}x{W_conv} at block {block + 1} is too small "
                    f"for 2x2 average pooling; reduce num_layers (got {num_layers})."
                )

            layers += [
                nn.Conv2d(in_channels, num_filters, kernel_size, padding=padding),
                nn.ReLU(),
                nn.AvgPool2d(2),
                nn.Dropout2d(p=dropout),
            ]

            # 2 FLOPs (multiply + add) per weight per output pixel.
            self.flops += 2 * in_channels * num_filters * (kernel_size ** 2) * H_conv * W_conv
            H, W = H_conv // 2, W_conv // 2  # AvgPool2d(2) floors odd sizes
            in_channels = num_filters

        self.conv = nn.Sequential(*layers)

        # Size the classifier from the channels the conv stack actually produces.
        # When num_layers == 0 that is input_channels, not num_filters.
        fc_in = in_channels * H * W
        self.fc = nn.Linear(fc_in, num_classes)
        self.flops += 2 * fc_in * num_classes

    def forward(self, x):
        """Return class logits for an image batch.

        Expects ``x`` shaped (N, input_channels, H, W) matching ``input_size``.
        """
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [4]:
model = CNN(num_layers=4, num_filters=3, kernel_size=3)
# Total number of scalar parameters across every tensor returned by .parameters().
pn = sum(param.numel() for param in model.parameters())
pn

322

In [5]:
# Display the layer-by-layer module tree of the small example model.
model

CNN(
  (conv): Sequential(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Dropout2d(p=0.2, inplace=False)
    (4): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (7): Dropout2d(p=0.2, inplace=False)
    (8): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (11): Dropout2d(p=0.2, inplace=False)
    (12): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (15): Dropout2d(p=0.2, inplace=False)
  )
  (fc): Linear(in_features=3, out_features=10, bias=True)
)

In [6]:
def _train_model(model, loader, num_epochs, device, lr=1e-3):
    """Train `model` in place with AdamW and cross-entropy loss for `num_epochs` passes over `loader`."""
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for _ in range(num_epochs):
        for X, y in loader:
            # The loaders use pin_memory=True, so host->device copies can be asynchronous.
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            optimizer.step()


def _evaluate_accuracy(model, loader, device):
    """Return the fraction of samples in `loader` whose argmax logit equals the label.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in loader:
            label_pred = model(X.to(device, non_blocking=True)).argmax(-1)
            correct += (label_pred == y.to(device, non_blocking=True)).sum().item()
            total += y.shape[0]
    if total == 0:
        raise ValueError("Validation loader yielded no samples; cannot compute accuracy.")
    return correct / total


def objective(trial):
    """Optuna objective: sample a CNN architecture and training length, train it, and score it.

    Trains on `train_dataloader` and evaluates on `val_dataloader`.

    NOTE(review): `val_dataloader` is built from MNIST's train=False (test) split,
    so hyperparameters are effectively selected on the test set.

    Returns:
        (accuracy, flops): validation accuracy (to maximize) and the model's
        analytic FLOP count (to minimize). This order must match the study's
        `directions`.
    """
    num_layers = trial.suggest_int("num_layers", 1, 4)
    num_filters = trial.suggest_int("num_filters", 1, 64)
    kernel_size = trial.suggest_categorical("kernel_size", (3, 5, 7))
    num_epochs = trial.suggest_int("num_epochs", 3, 10)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN(num_layers=num_layers, num_filters=num_filters, kernel_size=kernel_size).to(device)

    _train_model(model, train_dataloader, num_epochs, device)
    accuracy = _evaluate_accuracy(model, val_dataloader, device)

    return accuracy, model.flops

In [7]:
# TPE sampler for the two-objective search.
# - multivariate/group: model hyperparameters jointly rather than independently.
# - n_startup_trials: per Optuna docs, random sampling is used for the first
#   50 trials before TPE takes over.
# NOTE(review): study.optimize runs with n_jobs=10, and per Optuna docs parallel
# runs are not reproducible even with a fixed seed.
sampler = optuna.samplers.TPESampler(
    seed=0,
    n_startup_trials=50,
    multivariate=True,
    group=True,
)



In [8]:
# One direction per value returned by `objective`: (accuracy, flops).
study = optuna.create_study(
    sampler=sampler,
    directions=['maximize', 'minimize'],
)

[I 2025-03-07 00:19:16,257] A new study created in memory with name: no-name-bc210a34-4d12-4614-bbf9-285438ecf1a8


In [9]:
# Switch Optuna logging to its most verbose level for the optimization run below.
# NOTE(review): this runs after create_study, so the study-creation message above
# was emitted at the previous level. Consider moving it into the setup cells.
optuna.logging.set_verbosity(optuna.logging.DEBUG)

In [10]:
# Run 1000 trials with 10 parallel jobs; garbage collection runs after each trial
# so finished models can be released.
study.optimize(
    objective,
    n_jobs=10,
    n_trials=1000,
    gc_after_trial=True,
)

[I 2025-03-07 00:21:23,892] Trial 2 finished with values: [0.9878, 11908960.0] and parameters: {'num_layers': 2, 'num_filters': 56, 'kernel_size': 3, 'num_epochs': 3}.
[I 2025-03-07 00:23:07,059] Trial 4 finished with values: [0.9776, 270480.0] and parameters: {'num_layers': 1, 'num_filters': 15, 'kernel_size': 3, 'num_epochs': 6}.
[I 2025-03-07 00:23:07,235] Trial 6 finished with values: [0.9908, 36497390.0] and parameters: {'num_layers': 3, 'num_filters': 53, 'kernel_size': 5, 'num_epochs': 6}.
[I 2025-03-07 00:23:42,059] Trial 7 finished with values: [0.9915, 9165456.0] and parameters: {'num_layers': 3, 'num_filters': 18, 'kernel_size': 7, 'num_epochs': 7}.
[I 2025-03-07 00:24:18,925] Trial 9 finished with values: [0.9643, 1240680.0] and parameters: {'num_layers': 4, 'num_filters': 15, 'kernel_size': 3, 'num_epochs': 8}.
[I 2025-03-07 00:24:21,120] Trial 3 finished with values: [0.9835, 1136016.0] and parameters: {'num_layers': 1, 'num_filters': 63, 'kernel_size': 3, 'num_epochs': 8

In [14]:
# NOTE(review): the trailing ';' suppresses the cell's output, so this line has no
# visible effect. Drop the ';' to display the Pareto-optimal trials, or delete the cell.
study.best_trials;

In [19]:
# Lowest accuracy kept in view on the (log-scaled) x-axis.
MIN_ACCURACY_SHOWN = 0.95

# Pareto front of the two values returned by `objective`: accuracy (x) vs FLOPs (y).
# target_names replaces the default "Objective 0/1" axis labels.
fig = optuna.visualization.plot_pareto_front(
    study,
    target_names=["Accuracy", "FLOPs"],
)

# Plotly expects log-axis ranges in log10 units, so this shows accuracies from
# MIN_ACCURACY_SHOWN up to 1.0. Trials below that fall outside the visible range.
fig.update_layout(
    xaxis_type="log",
    xaxis_range=[np.log10(MIN_ACCURACY_SHOWN), 0],
    yaxis_type="log",
    width=1000,
    height=600,
);

In [20]:
# Render the Pareto-front figure configured in the previous cell.
fig.show()