In [None]:
!pip install -q pyswarms

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/104.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m102.4/104.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import pyswarms as ps

# Set device
# Get gpu, mps or cpu device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [None]:
# Define the model
class MyModel(nn.Module):
    def __init__(self, num_classes=10, num_layers=5, neurons_per_layer=256):
        super(MyModel, self).__init__()
        layers = []
        layers.append(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0))
        layers.append(nn.BatchNorm2d(96))
        layers.append(nn.ReLU())
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        for _ in range(num_layers):
            layers.append(nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1))
            layers.append(nn.BatchNorm2d(96))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        # Adjust input size for the linear layer based on the output size of preceding layers
        dummy_input = torch.randn(1, 3, 128, 128)
        self.layers = nn.Sequential(*layers)
        self._to_linear = None
        self.convs(dummy_input)

        self.fc = nn.Sequential(
            nn.Linear(self._to_linear, neurons_per_layer),
            nn.ReLU(),
            nn.Linear(neurons_per_layer, num_classes)
        )

    def convs(self, x):
        x = self.layers(x)
        if self._to_linear is None:
            self._to_linear = x[0].shape[0] * x[0].shape[1] * x[0].shape[2]
        return x

    def forward(self, x):
        x = self.layers(x)
        x = x.view(-1, self._to_linear)
        x = self.fc(x)
        return x

# Function to evaluate the model
def evaluate_model(params):
    num_layers = int(params[0][0])  # Corrected indexing
    neurons_per_layer = int(params[0][1])  # Corrected indexing

    # Define hyperparameters
    num_classes = 10
    learning_rate = 0.001
    num_epochs = 20
    batch_size = 72

    # Define model
    model = MyModel(num_classes, num_layers, neurons_per_layer).to(device)

    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Load CIFAR10 data
    transform = transforms.Compose([
        transforms.Resize((128,128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    # Define data loaders
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Training the model
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Print progress
        print('Epoch [%d/%d], Loss: %.4f, Accuracy: %.2f %%' % (epoch+1, num_epochs, running_loss / len(train_loader), 100 * correct / total))

    # Return validation accuracy
    return correct / total

# Define the search space
search_space = {
    'num_layers': (1, 5),  # Range for number of layers
    'neurons_per_layer': (16, 256),  # Range for neurons per layer
}

# Define the bounds for the parameters
lb = [search_space['num_layers'][0], search_space['neurons_per_layer'][0]]
ub = [search_space['num_layers'][1], search_space['neurons_per_layer'][1]]

# Initialize the swarm
num_particles = 10
dim = 2  # Dimensionality of the search space
options = {'c1': 0.5, 'c2': 0.3, 'w': 0.9}  # PSO hyperparameters
optimizer = ps.single.GlobalBestPSO(n_particles=num_particles, dimensions=dim, options=options, bounds=(lb, ub))

# Perform optimization
best_params, _ = optimizer.optimize(evaluate_model, iters=10)
print("Best hyperparameters:", best_params)

2024-04-03 18:54:54,035 - pyswarms.single.global_best - INFO - Optimize for 10 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best:   0%|          |0/10

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz



  0%|          | 0/170498071 [00:00<?, ?it/s][A
  0%|          | 65536/170498071 [00:00<08:08, 348547.19it/s][A
  0%|          | 196608/170498071 [00:00<03:43, 763469.39it/s][A
  0%|          | 458752/170498071 [00:00<02:27, 1155736.04it/s][A
  1%|          | 1048576/170498071 [00:00<01:05, 2568455.00it/s][A
  1%|▏         | 2293760/170498071 [00:00<00:30, 5516727.26it/s][A
  3%|▎         | 4521984/170498071 [00:00<00:15, 10517721.14it/s][A
  4%|▍         | 7536640/170498071 [00:00<00:09, 16359675.99it/s][A
  6%|▌         | 10452992/170498071 [00:00<00:08, 19513657.74it/s][A
  8%|▊         | 13598720/170498071 [00:01<00:06, 22842537.43it/s][A
 10%|▉         | 16842752/170498071 [00:01<00:06, 25606296.76it/s][A
 12%|█▏        | 20086784/170498071 [00:01<00:05, 27510297.63it/s][A
 14%|█▎        | 23232512/170498071 [00:01<00:05, 28630207.65it/s][A
 15%|█▌        | 26279936/170498071 [00:01<00:04, 29166772.56it/s][A
 17%|█▋        | 29458432/170498071 [00:01<00:04, 29916977

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch [1/20], Loss: 1.3248, Accuracy: 63.32 %
Epoch [2/20], Loss: 0.9840, Accuracy: 65.09 %
Epoch [3/20], Loss: 0.8646, Accuracy: 68.65 %
Epoch [4/20], Loss: 0.7871, Accuracy: 71.74 %
Epoch [5/20], Loss: 0.7262, Accuracy: 69.71 %
Epoch [6/20], Loss: 0.6730, Accuracy: 73.61 %
Epoch [7/20], Loss: 0.6303, Accuracy: 72.86 %
Epoch [8/20], Loss: 0.5926, Accuracy: 72.58 %
Epoch [9/20], Loss: 0.5538, Accuracy: 73.95 %
Epoch [10/20], Loss: 0.5139, Accuracy: 75.30 %
Epoch [11/20], Loss: 0.4853, Accuracy: 75.31 %
Epoch [12/20], Loss: 0.4458, Accuracy: 76.20 %
Epoch [13/20], Loss: 0.4119, Accuracy: 76.48 %
Epoch [14/20], Loss: 0.3829, Accuracy: 74.30 %
Epoch [15/20], Loss: 0.3570, Accuracy: 74.67 %
Epoch [16/20], Loss: 0.3367, Accuracy: 75.39 %
Epoch [17/20], Loss: 0.3000, Accuracy: 75.28 %
Epoch [18/20], Loss: 0.2737, Accuracy: 76.66 %
Epoch [19/20], Loss: 0.2532, Accuracy: 76.17 %


pyswarms.single.global_best:  10%|█         |1/10, best_cost=0.746

Epoch [20/20], Loss: 0.2386, Accuracy: 74.55 %
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 1.2552, Accuracy: 63.89 %
Epoch [2/20], Loss: 0.8626, Accuracy: 71.26 %
Epoch [3/20], Loss: 0.7159, Accuracy: 70.37 %
Epoch [4/20], Loss: 0.6256, Accuracy: 71.89 %
Epoch [5/20], Loss: 0.5546, Accuracy: 76.82 %
Epoch [6/20], Loss: 0.4914, Accuracy: 77.04 %
Epoch [7/20], Loss: 0.4488, Accuracy: 78.58 %
Epoch [8/20], Loss: 0.3972, Accuracy: 78.66 %
Epoch [9/20], Loss: 0.3623, Accuracy: 79.47 %
Epoch [10/20], Loss: 0.3152, Accuracy: 78.92 %
Epoch [11/20], Loss: 0.2870, Accuracy: 78.84 %
Epoch [12/20], Loss: 0.2528, Accuracy: 80.86 %
Epoch [13/20], Loss: 0.2186, Accuracy: 79.37 %
Epoch [14/20], Loss: 0.2015, Accuracy: 79.95 %
Epoch [15/20], Loss: 0.1837, Accuracy: 79.75 %
Epoch [16/20], Loss: 0.1515, Accuracy: 80.10 %
Epoch [17/20], Loss: 0.1471, Accuracy: 78.84 %
Epoch [18/20], Loss: 0.1328, Accuracy: 80.36 %
Epoch [19/20], Loss: 0.1152, Accuracy: 8

pyswarms.single.global_best:  20%|██        |2/10, best_cost=0.746

Epoch [20/20], Loss: 0.1027, Accuracy: 79.56 %
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 1.2710, Accuracy: 62.95 %
Epoch [2/20], Loss: 0.8803, Accuracy: 69.51 %
Epoch [3/20], Loss: 0.7363, Accuracy: 75.01 %
Epoch [4/20], Loss: 0.6432, Accuracy: 72.80 %
Epoch [5/20], Loss: 0.5702, Accuracy: 74.60 %
Epoch [6/20], Loss: 0.5093, Accuracy: 75.66 %
Epoch [7/20], Loss: 0.4707, Accuracy: 77.15 %
Epoch [8/20], Loss: 0.4135, Accuracy: 79.09 %
Epoch [9/20], Loss: 0.3717, Accuracy: 78.15 %
Epoch [10/20], Loss: 0.3378, Accuracy: 77.99 %
Epoch [11/20], Loss: 0.3022, Accuracy: 79.17 %
Epoch [12/20], Loss: 0.2724, Accuracy: 79.67 %
Epoch [13/20], Loss: 0.2407, Accuracy: 78.86 %
Epoch [14/20], Loss: 0.2218, Accuracy: 78.38 %
Epoch [15/20], Loss: 0.1969, Accuracy: 77.45 %
Epoch [16/20], Loss: 0.1736, Accuracy: 78.45 %
Epoch [17/20], Loss: 0.1597, Accuracy: 78.84 %
Epoch [18/20], Loss: 0.1414, Accuracy: 79.56 %
Epoch [19/20], Loss: 0.1345, Accuracy: 7

pyswarms.single.global_best:  30%|███       |3/10, best_cost=0.746

Epoch [20/20], Loss: 0.1189, Accuracy: 79.07 %
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 1.2377, Accuracy: 63.04 %
Epoch [2/20], Loss: 0.8637, Accuracy: 71.89 %
Epoch [3/20], Loss: 0.7211, Accuracy: 72.60 %
Epoch [4/20], Loss: 0.6271, Accuracy: 75.22 %
Epoch [5/20], Loss: 0.5561, Accuracy: 77.87 %
Epoch [6/20], Loss: 0.4922, Accuracy: 76.65 %
Epoch [7/20], Loss: 0.4392, Accuracy: 78.06 %
Epoch [8/20], Loss: 0.3936, Accuracy: 78.31 %
Epoch [9/20], Loss: 0.3581, Accuracy: 77.75 %
Epoch [10/20], Loss: 0.3186, Accuracy: 79.50 %
Epoch [11/20], Loss: 0.2817, Accuracy: 79.14 %
Epoch [12/20], Loss: 0.2452, Accuracy: 79.25 %
Epoch [13/20], Loss: 0.2247, Accuracy: 80.10 %
Epoch [14/20], Loss: 0.1968, Accuracy: 79.78 %
Epoch [15/20], Loss: 0.1763, Accuracy: 80.66 %
Epoch [16/20], Loss: 0.1572, Accuracy: 79.49 %
Epoch [17/20], Loss: 0.1437, Accuracy: 80.15 %
Epoch [18/20], Loss: 0.1254, Accuracy: 80.33 %
Epoch [19/20], Loss: 0.1215, Accuracy: 7

pyswarms.single.global_best:  40%|████      |4/10, best_cost=0.746

Epoch [20/20], Loss: 0.1069, Accuracy: 80.14 %
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 1.3589, Accuracy: 62.21 %
Epoch [2/20], Loss: 1.0063, Accuracy: 64.94 %
Epoch [3/20], Loss: 0.8722, Accuracy: 69.29 %
Epoch [4/20], Loss: 0.7943, Accuracy: 71.58 %
Epoch [5/20], Loss: 0.7257, Accuracy: 72.57 %
Epoch [6/20], Loss: 0.6685, Accuracy: 75.04 %
Epoch [7/20], Loss: 0.6205, Accuracy: 70.77 %
Epoch [8/20], Loss: 0.5765, Accuracy: 73.93 %
Epoch [9/20], Loss: 0.5393, Accuracy: 74.81 %
Epoch [10/20], Loss: 0.4995, Accuracy: 74.44 %
Epoch [11/20], Loss: 0.4595, Accuracy: 74.74 %
Epoch [12/20], Loss: 0.4334, Accuracy: 74.10 %
Epoch [13/20], Loss: 0.3954, Accuracy: 75.64 %
Epoch [14/20], Loss: 0.3695, Accuracy: 74.57 %
Epoch [15/20], Loss: 0.3445, Accuracy: 76.14 %
Epoch [16/20], Loss: 0.3121, Accuracy: 76.06 %
Epoch [17/20], Loss: 0.2966, Accuracy: 76.13 %
Epoch [18/20], Loss: 0.2642, Accuracy: 76.19 %
Epoch [19/20], Loss: 0.2432, Accuracy: 7

pyswarms.single.global_best:  50%|█████     |5/10, best_cost=0.746

Epoch [20/20], Loss: 0.2207, Accuracy: 76.82 %
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 1.3618, Accuracy: 57.81 %
Epoch [2/20], Loss: 0.9917, Accuracy: 65.26 %
Epoch [3/20], Loss: 0.8592, Accuracy: 68.21 %
Epoch [4/20], Loss: 0.7857, Accuracy: 68.26 %
Epoch [5/20], Loss: 0.7250, Accuracy: 72.20 %
Epoch [6/20], Loss: 0.6746, Accuracy: 71.85 %
Epoch [7/20], Loss: 0.6318, Accuracy: 74.21 %
Epoch [8/20], Loss: 0.5915, Accuracy: 73.74 %
Epoch [9/20], Loss: 0.5459, Accuracy: 74.77 %
Epoch [10/20], Loss: 0.5151, Accuracy: 73.66 %
Epoch [11/20], Loss: 0.4824, Accuracy: 74.89 %
Epoch [12/20], Loss: 0.4510, Accuracy: 74.96 %
Epoch [13/20], Loss: 0.4153, Accuracy: 76.17 %
Epoch [14/20], Loss: 0.3850, Accuracy: 74.85 %
Epoch [15/20], Loss: 0.3606, Accuracy: 75.49 %
Epoch [16/20], Loss: 0.3367, Accuracy: 75.07 %
Epoch [17/20], Loss: 0.3105, Accuracy: 76.01 %
Epoch [18/20], Loss: 0.2823, Accuracy: 74.07 %
Epoch [19/20], Loss: 0.2682, Accuracy: 7

pyswarms.single.global_best:  60%|██████    |6/10, best_cost=0.746

Epoch [20/20], Loss: 0.2527, Accuracy: 74.69 %
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 1.2695, Accuracy: 62.80 %
Epoch [2/20], Loss: 0.8590, Accuracy: 71.12 %
Epoch [3/20], Loss: 0.7049, Accuracy: 71.89 %
Epoch [4/20], Loss: 0.5997, Accuracy: 74.87 %
Epoch [5/20], Loss: 0.5280, Accuracy: 76.95 %
Epoch [6/20], Loss: 0.4600, Accuracy: 76.61 %
Epoch [7/20], Loss: 0.4077, Accuracy: 79.89 %
Epoch [8/20], Loss: 0.3589, Accuracy: 77.04 %
Epoch [9/20], Loss: 0.3150, Accuracy: 78.53 %
Epoch [10/20], Loss: 0.2800, Accuracy: 77.31 %
Epoch [11/20], Loss: 0.2489, Accuracy: 80.25 %
Epoch [12/20], Loss: 0.2174, Accuracy: 78.85 %
Epoch [13/20], Loss: 0.1987, Accuracy: 80.32 %
Epoch [14/20], Loss: 0.1736, Accuracy: 78.70 %
Epoch [15/20], Loss: 0.1545, Accuracy: 80.21 %
Epoch [16/20], Loss: 0.1442, Accuracy: 78.50 %
Epoch [17/20], Loss: 0.1286, Accuracy: 80.51 %
Epoch [18/20], Loss: 0.1181, Accuracy: 80.84 %
Epoch [19/20], Loss: 0.1119, Accuracy: 7

pyswarms.single.global_best:  70%|███████   |7/10, best_cost=0.746

Epoch [20/20], Loss: 0.1014, Accuracy: 79.08 %
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 1.2271, Accuracy: 64.42 %
Epoch [2/20], Loss: 0.8341, Accuracy: 71.70 %
Epoch [3/20], Loss: 0.6903, Accuracy: 76.03 %
Epoch [4/20], Loss: 0.5907, Accuracy: 74.35 %
Epoch [5/20], Loss: 0.5157, Accuracy: 76.95 %
Epoch [6/20], Loss: 0.4503, Accuracy: 78.51 %
Epoch [7/20], Loss: 0.4004, Accuracy: 78.36 %
Epoch [8/20], Loss: 0.3517, Accuracy: 79.69 %
Epoch [9/20], Loss: 0.3107, Accuracy: 79.83 %
Epoch [10/20], Loss: 0.2656, Accuracy: 80.15 %
Epoch [11/20], Loss: 0.2350, Accuracy: 79.78 %
Epoch [12/20], Loss: 0.2084, Accuracy: 79.83 %
Epoch [13/20], Loss: 0.1837, Accuracy: 80.62 %
Epoch [14/20], Loss: 0.1595, Accuracy: 80.37 %
Epoch [15/20], Loss: 0.1443, Accuracy: 80.63 %
Epoch [16/20], Loss: 0.1337, Accuracy: 79.57 %
Epoch [17/20], Loss: 0.1162, Accuracy: 80.27 %
Epoch [18/20], Loss: 0.1113, Accuracy: 80.86 %
Epoch [19/20], Loss: 0.0999, Accuracy: 8

pyswarms.single.global_best:  80%|████████  |8/10, best_cost=0.746

Epoch [20/20], Loss: 0.0887, Accuracy: 81.20 %
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 1.2362, Accuracy: 65.50 %
Epoch [2/20], Loss: 0.8312, Accuracy: 68.18 %
Epoch [3/20], Loss: 0.6874, Accuracy: 74.01 %
Epoch [4/20], Loss: 0.5823, Accuracy: 76.20 %
Epoch [5/20], Loss: 0.5143, Accuracy: 77.91 %
Epoch [6/20], Loss: 0.4556, Accuracy: 79.21 %
Epoch [7/20], Loss: 0.3971, Accuracy: 79.16 %
Epoch [8/20], Loss: 0.3524, Accuracy: 78.14 %
Epoch [9/20], Loss: 0.3173, Accuracy: 78.51 %
Epoch [10/20], Loss: 0.2787, Accuracy: 80.27 %
Epoch [11/20], Loss: 0.2442, Accuracy: 81.24 %
Epoch [12/20], Loss: 0.2149, Accuracy: 80.88 %
Epoch [13/20], Loss: 0.1916, Accuracy: 78.56 %
Epoch [14/20], Loss: 0.1726, Accuracy: 80.00 %
Epoch [15/20], Loss: 0.1541, Accuracy: 79.61 %
Epoch [16/20], Loss: 0.1379, Accuracy: 80.47 %
Epoch [17/20], Loss: 0.1266, Accuracy: 80.38 %
Epoch [18/20], Loss: 0.1149, Accuracy: 79.90 %
Epoch [19/20], Loss: 0.1100, Accuracy: 8

pyswarms.single.global_best:  90%|█████████ |9/10, best_cost=0.746

Epoch [20/20], Loss: 0.0961, Accuracy: 79.73 %
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/20], Loss: 1.2648, Accuracy: 66.25 %
Epoch [2/20], Loss: 0.8689, Accuracy: 68.13 %
Epoch [3/20], Loss: 0.7006, Accuracy: 71.95 %
Epoch [4/20], Loss: 0.6043, Accuracy: 76.23 %
Epoch [5/20], Loss: 0.5242, Accuracy: 77.43 %
Epoch [6/20], Loss: 0.4657, Accuracy: 78.18 %
Epoch [7/20], Loss: 0.4095, Accuracy: 77.20 %
Epoch [8/20], Loss: 0.3627, Accuracy: 78.77 %
Epoch [9/20], Loss: 0.3187, Accuracy: 77.55 %
Epoch [10/20], Loss: 0.2841, Accuracy: 77.75 %
Epoch [11/20], Loss: 0.2496, Accuracy: 79.24 %
Epoch [12/20], Loss: 0.2202, Accuracy: 79.92 %
Epoch [13/20], Loss: 0.1992, Accuracy: 79.93 %
Epoch [14/20], Loss: 0.1740, Accuracy: 81.09 %
Epoch [15/20], Loss: 0.1526, Accuracy: 79.80 %
Epoch [16/20], Loss: 0.1443, Accuracy: 80.32 %
Epoch [17/20], Loss: 0.1295, Accuracy: 80.19 %
Epoch [18/20], Loss: 0.1180, Accuracy: 79.44 %
Epoch [19/20], Loss: 0.1089, Accuracy: 8

pyswarms.single.global_best: 100%|██████████|10/10, best_cost=0.746
2024-04-03 22:02:30,874 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 0.7455, best pos: [  1.63923135 166.55545564]


Epoch [20/20], Loss: 0.1024, Accuracy: 79.47 %
Best hyperparameters: 0.7455


### Best position is [1.63923135 166.55545564], therefore, 2 layers, 167 neurons