# Mutant UCB

In [3]:
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import openml
from sklearn.model_selection import train_test_split
import numpy as np
import os
from dragon.search_space.bricks_variables import mlp_var, identity_var, operations_var, mlp_const_var, dag_var, node_var
from dragon.search_space.cells import AdjMatrix, Node
from dragon.search_space.zellij_variables import ArrayVar
from dragon.search_algorithm.zellij_neighborhoods import ArrayInterval


# Fetch OpenML dataset 32 (the log below reports it as "pendigits") and split it
# 70/30 into train/validation with a fixed seed.
dataset = openml.datasets.get_dataset(32)
# NOTE(review): per the openml-python docs, get_data() returns
# (X, y, categorical_indicator, attribute_names), so `numerical` presumably holds
# the categorical indicator despite its name. Verify before relying on it.
data, _, numerical, names = dataset.get_data()
X = data.drop(columns='class')
y = data[["class"]].astype(int)  # kept 2-D: shape (n_samples, 1)
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=0,
)
class CustomDataset(Dataset):
    """Map-style dataset over in-memory features and integer class labels.

    ``X`` and ``y`` may be pandas objects (as used in this notebook) or anything
    ``np.asarray`` accepts. Features are stored as a ``float32`` tensor and labels
    as an ``int64`` tensor. Item ``i`` is the pair ``(X[i], y[i])``.
    """

    def __init__(self, X, y):
        super().__init__()
        # torch.tensor always copies, so the dataset never aliases the caller's
        # data. Explicit dtypes replace the legacy FloatTensor/LongTensor
        # constructors.
        self.X = torch.tensor(np.asarray(X), dtype=torch.float32)
        self.y = torch.tensor(np.asarray(y), dtype=torch.long)

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return self.X[index], self.y[index]
# Training data: shuffled mini-batches of 32.
train_set = CustomDataset(X_train, y_train)
train_loader = DataLoader(train_set, shuffle=True, batch_size=32)

# Validation data: fixed order, one sample per batch.
# NOTE(review): batch_size=1 makes evaluation slow; accuracy in eval mode should
# not depend on it, so a larger batch could be considered.
val_set = CustomDataset(X_val, y_val)
val_loader = DataLoader(val_set, shuffle=False, batch_size=1)

class MetaArchi(nn.Module):
    """Network assembled from one candidate drawn from the search space.

    ``args`` must provide:
      * ``'Dag'``: an ``AdjMatrix`` applied to the input;
      * ``'Out'``: a ``Node`` applied to the DAG's output.
    """

    def __init__(self, args, input_shape):
        super().__init__()
        # Shape of one input sample, e.g. (16,) for the 16 features used here.
        self.input_shape = input_shape

        # Hidden part of the network: the candidate's AdjMatrix, configured for
        # the input shape (``set`` presumably instantiates its layers; confirm).
        assert isinstance(args['Dag'], AdjMatrix), f"The 'Dag' argument should be an 'AdjMatrix'. Got {type(args['Dag'])} instead."
        self.dag = args['Dag']
        self.dag.set(input_shape)

        # Final layer, configured from the DAG's output shape. The message reports
        # type(args['Out']), the value actually being checked.
        assert isinstance(args['Out'], Node), f"The 'Out' argument should be a 'Node'. Got {type(args['Out'])} instead."
        self.output = args["Out"]
        self.output.set(self.dag.output_shape)

    def forward(self, X):
        """Run ``X`` through the DAG, then through the output node."""
        out = self.dag(X)
        return self.output(out)

    def save(self, path):
        """Save ``state_dict()`` to ``<path>/best_model.pth``, creating ``path`` if needed."""
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(path, exist_ok=True)
        full_path = os.path.join(path, "best_model.pth")
        torch.save(self.state_dict(), full_path)


# Operations a DAG node may use: an MLP layer or an identity.
# NOTE(review): size=10 presumably bounds the number of DAG nodes; confirm in DRAGON.
candidate_operations = operations_var("Candidate operations", size=10, candidates=[mlp_var("MLP"), identity_var("Identity")])
dag = dag_var("Dag", candidate_operations)
# Output node: an MLP with a constant width of 10 (presumably one unit per class),
# then a softmax over the last axis. An explicit dim replaces the deprecated implicit
# choice; for 1-D and 2-D inputs, dim=-1 is the axis the implicit rule picked.
# NOTE(review): nn.CrossEntropyLoss in train_model already applies log-softmax, so
# softmax is applied twice during training. Consider an identity activation if
# node_var supports one.
out = node_var("Out", operation=mlp_const_var('Operation', 10), activation_function=nn.Softmax(dim=-1))

# Full search space: one DAG followed by the output node.
search_space = ArrayVar(dag, out, label="Search Space", neighbor=ArrayInterval())

def train_model(model, data_loader, n_epochs=2, lr=0.05):
    """Train ``model`` in place with plain SGD on a cross-entropy loss.

    Args:
        model: the ``nn.Module`` to train; its parameters are updated in place.
        data_loader: yields ``(X, y)`` batches where ``y`` holds class indices,
            shaped ``(batch, 1)`` (as produced by ``CustomDataset``) or ``(batch,)``.
        n_epochs: number of full passes over ``data_loader``.
        lr: SGD learning rate.

    Returns:
        The same ``model`` object, left in training mode.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()
    for _ in range(n_epochs):
        for X, y in data_loader:
            optimizer.zero_grad()
            # Drop only the label axis. A bare squeeze() would also remove the
            # batch axis of a size-1 final batch and make the loss call fail.
            if y.dim() > 1:
                y = y.squeeze(1)
            pred = model(X)
            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()
    return model

def test_model(model, data_loader):
    """Return the classification accuracy of ``model`` over ``data_loader``.

    Puts ``model`` in eval mode (and leaves it there) and runs it without
    gradient tracking. The predicted class is the argmax over axis 1 of the
    model output. Targets may be shaped ``(batch, 1)`` or ``(batch,)``.

    Returns:
        ``correct / len(data_loader.dataset)`` as a float in [0, 1].
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for X, y in data_loader:
            # Same target handling as train_model: drop only the label axis.
            if y.dim() > 1:
                y = y.squeeze(1)
            prediction = model(X).argmax(dim=1)
            correct += (prediction == y).sum().item()
    return correct / len(data_loader.dataset)


# The name matches pytest's test_* pattern; keep pytest from collecting this
# helper as a test if the module is ever imported by a test file.
test_model.__test__ = False

def loss_function(args, idx, *extra_args):
    """Build, train (2 epochs) and score one candidate architecture.

    Args:
        args: candidate values in the same order as the variables of
            ``search_space``; they are keyed by those variables' labels to form
            the dict ``MetaArchi`` expects.
        idx: identifier of the candidate, used only in the printed log line.
        *extra_args: further positional arguments from the caller; ignored.

    Returns:
        ``(1 - validation accuracy, trained model)``; lower loss is better.
    """
    labels = [variable.label for variable in search_space]
    named_args = dict(zip(labels, args))
    # Take the input width from the data instead of hard-coding 16.
    model = MetaArchi(named_args, input_shape=(X_train.shape[1],))
    model = train_model(model, train_loader)
    accuracy = test_model(model, val_loader)
    print(f'Idx = {idx}, accuracy = {accuracy}')
    return 1 - accuracy, model

2024-09-27 16:47:02,234 | INFO | pickle write pendigits


In [4]:
from dragon.search_algorithm.mutant_ucb import Mutant_UCB


# Mutant-UCB search over `search_space`, scoring each candidate with loss_function.
# "test_mutant" is presumably the output/save directory name (confirm in DRAGON).
search_algorithm = Mutant_UCB(
    search_space,
    "test_mutant",
    T=500,
    N=5,
    K=10,
    E=0.01,
    evaluation=loss_function,
)
min_loss = search_algorithm.run()

Idx = 0, accuracy = 0.09702850212249849
Idx = 1, accuracy = 0.5576106731352335
Idx = 2, accuracy = 0.10430563978168587
Idx = 3, accuracy = 0.0585203153426319
Idx = 4, accuracy = 0.9035779260157671
Idx = 5, accuracy = 0.30442692540933897
Idx = 6, accuracy = 0.09945421467556094
Idx = 7, accuracy = 0.10430563978168587
Idx = 8, accuracy = 0.0979381443298969
Idx = 9, accuracy = 0.09702850212249849
2024-09-27 16:47:55,412 | INFO | With p = 0.2 = 1 / 5, training 4 instead
Idx = 4, accuracy = 0.8583990297149787
2024-09-27 16:48:00,996 | INFO | Best found! 0.14160097028502128 < inf
2024-09-27 16:48:00,999 | INFO | With p = 0.4 = 2 / 5, mutating 4 to 10
Idx = 10, accuracy = 0.9423893268647665
2024-09-27 16:48:06,641 | INFO | Best found! 0.05761067313523349 < 0.14160097028502128
2024-09-27 16:48:06,644 | INFO | With p = 0.2 = 1 / 5, training 10 instead
Idx = 10, accuracy = 0.9460278956943602
2024-09-27 16:48:12,228 | INFO | Best found! 0.05397210430563981 < 0.05761067313523349
2024-09-27 16:48:12

In [5]:
from dragon.search_algorithm.ssea import SteadyStateEA

def loss_function(args, idx, *extra_args):
    """Build, train (10 epochs) and score one candidate architecture.

    Redefines the earlier 2-epoch version for the steady-state EA run.

    Args:
        args: candidate values in the same order as the variables of
            ``search_space``; they are keyed by those variables' labels to form
            the dict ``MetaArchi`` expects.
        idx: identifier of the candidate, used only in the printed log line.
        *extra_args: further positional arguments from the caller; ignored.

    Returns:
        ``(1 - validation accuracy, trained model)``; lower loss is better.
    """
    labels = [variable.label for variable in search_space]
    named_args = dict(zip(labels, args))
    # Take the input width from the data instead of hard-coding 16.
    model = MetaArchi(named_args, input_shape=(X_train.shape[1],))
    model = train_model(model, train_loader, n_epochs=10)
    accuracy = test_model(model, val_loader)
    print(f'Idx = {idx}, accuracy = {accuracy}')
    return 1 - accuracy, model

# Steady-state evolutionary search over `search_space`, scoring each candidate
# with loss_function; outputs go under save_dir.
search_algorithm = SteadyStateEA(
    search_space,
    n_iterations=100,
    population_size=20,
    selection_size=5,
    evaluation=loss_function,
    save_dir="test_ssea",
)
min_loss = search_algorithm.run()

2024-09-27 17:26:52,414 | INFO | The whole population has been created (size = 20), 20 have been randomy initialized.
2024-09-27 17:26:52,415 | INFO | We start by evaluating the whole population (size=20)
Idx = 0, accuracy = 0.25651910248635534
2024-09-27 17:26:55,654 | INFO | Best found ! 0.7434808975136447 < inf
Idx = 1, accuracy = 0.36901152213462707
2024-09-27 17:26:58,378 | INFO | Best found ! 0.6309884778653729 < 0.7434808975136447
Idx = 2, accuracy = 0.19860521528198907
Idx = 3, accuracy = 0.09187386294724076
Idx = 4, accuracy = 0.09702850212249849
Idx = 5, accuracy = 0.08793208004851426
Idx = 6, accuracy = 0.8244390539721043
2024-09-27 17:27:36,361 | INFO | Best found ! 0.17556094602789574 < 0.6309884778653729
Idx = 7, accuracy = 0.08459672528805337
Idx = 8, accuracy = 0.10794420861127957
Idx = 9, accuracy = 0.947847180109157
2024-09-27 17:28:08,397 | INFO | Best found ! 0.05215281989084297 < 0.17556094602789574
Idx = 10, accuracy = 0.16252274105518497
Idx = 11, accuracy = 0.20