<a href="https://colab.research.google.com/github/astopchatyy/GenAlgoTest/blob/main/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import random
import copy
import matplotlib.pyplot as plt
from typing import Dict, List, Type, Iterator
from tqdm import tqdm
from abc import ABC, abstractmethod
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
import os
# os.cpu_count() returns None when the CPU count cannot be determined; DataLoader's
# num_workers needs an int, so fall back to loading in the main process (0).
num_workers = os.cpu_count() or 0
print("Number of workers:", num_workers)

Device: cuda
Number of workers: 12


In [None]:
from torch import Tensor


class Metric(ABC):
    """Interface for a per-batch score that is aggregated over a whole dataset.

    ``eval`` scores one batch; ``aggregate`` combines the per-batch scores.
    ``higher_is_better`` tells consumers that rank models by this metric
    which direction counts as an improvement.
    """

    higher_is_better: bool = True

    @abstractmethod
    def eval(self, predicted_values: torch.Tensor, true_values: torch.Tensor) -> float:
        """Return the score of a single batch."""
        return 0.0

    @abstractmethod
    def aggregate(self, values: List) -> float:
        """Combine the per-batch scores produced by :meth:`eval` into one value."""
        return 0.0


class Accuracy(Metric):
    """Fraction of samples whose arg-max along dim 1 equals the label."""

    higher_is_better = True

    def eval(self, predicted_values: Tensor, true_values: Tensor) -> float:
        # Assumes predicted_values is (batch, num_classes) scores and true_values
        # holds class indices -- TODO confirm for other uses.
        predicted = predicted_values.argmax(dim=1)
        return (predicted == true_values).float().mean().item()

    def aggregate(self, values: List) -> float:
        """Return the unweighted mean of the per-batch accuracies.

        NOTE: every batch is weighted equally, so a short final batch counts as
        much as a full one.

        Raises:
            ValueError: if ``values`` is empty.
        """
        if not values:
            raise ValueError("cannot aggregate an empty list of metric values")
        return sum(values) / len(values)


class MSE(Metric):
    """Mean squared error between predictions and targets (lower is better)."""

    higher_is_better = False

    def eval(self, predicted_values: Tensor, true_values: Tensor) -> float:
        return torch.mean((predicted_values - true_values) ** 2).item()

    def aggregate(self, values: List) -> float:
        """Return the mean of the per-batch MSEs.

        Averaging (instead of summing) keeps the result independent of how many
        batches the dataset was split into.

        Raises:
            ValueError: if ``values`` is empty.
        """
        if not values:
            raise ValueError("cannot aggregate an empty list of metric values")
        return sum(values) / len(values)

In [None]:
class GeneticTrainer:
    def __init__(self,
                 genetic_population_size: int,
                 superviser_population_size: int,
                 model_class: Type[nn.Module],
                 optimizer_class: Type[optim.Optimizer],
                 criterion: nn.modules.loss._Loss,
                 metric: Metric,
                 model_params: Dict={},
                 optimizer_params: Dict={}) -> None:

        self._model_class = model_class
        self._optimizer_class = optimizer_class
        self._genetic_population_size = genetic_population_size
        self._superviser_population_size = superviser_population_size
        self._total_population_size = self._superviser_population_size + self._genetic_population_size
        self._metric = metric
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._model_params = model_params
        self._optimizer_params = optimizer_params

        self._population = [self._model_class(**self._model_params).to(device) for _ in range(self._total_population_size)]
        self._criterion = criterion
        self._optimizers = [self._optimizer_class(self._population[i].parameters(), **self._optimizer_params) for i in range(self._total_population_size)]

        self._training_history = None

    def train(self,
              cycles: int,
              train_dataloader: DataLoader,
              test_dataloader: DataLoader,
              validation_dataloader: DataLoader|None=None,
              survivor_fraction: float = 0.5,
              mutation_chance: float = 0.1,
              epoches_per_cycle=1,
              last_cycle_evolution=True,
              verbose=2) -> Dict:

        if not validation_dataloader:
            validation_dataloader = train_dataloader

        self._training_history = {
            "crossover_history": [[] for _ in range(self._genetic_population_size + self._superviser_population_size)],
            "train_losses": [[] for _ in range(self._genetic_population_size + self._superviser_population_size)],
            "test_metric": [[] for _ in range(self._genetic_population_size + self._superviser_population_size)],
            "val_metric": [[] for _ in range(self._genetic_population_size + self._superviser_population_size)]
        }
        if verbose == 2:
            total_len = cycles * (self._total_population_size) * epoches_per_cycle * (len(train_dataloader) + len(validation_dataloader) + len(test_dataloader))
        elif verbose == 1:
            total_len = cycles * (self._total_population_size) * epoches_per_cycle
        else:
            total_len = cycles
        with tqdm(total=total_len, leave=False) if verbose > 0 else None as pbar:
            for cycle in range(cycles):
                if verbose not in [1,2]:
                    pbar.set_description(f"Cycle: {cycle+1}/{cycles}")
                    pbar.update(1)
                fitness_scores = []
                for i in range(self._total_population_size):
                    fitness_list = []
                    test_fitness_list = []
                    train_losses_list = []

                    fitness = 0

                    for epoch in range(epoches_per_cycle):
                        message = f"Cycle: {cycle+1}/{cycles}, entity: {i+1}/{self._total_population_size}, epoch:{epoch+1}/{epoches_per_cycle}"
                        if verbose == 2:
                            progr_bar = pbar
                        else:
                            progr_bar = None
                            if verbose == 1:
                                pbar.set_description(message)
                                pbar.update(1)
                        losses = self._train_one_model_one_epoch(self._population[i],
                                                                self._optimizers[i],
                                                                self._criterion,
                                                                train_dataloader,
                                                                progr_bar,
                                                                f"{message}, training: ")
                        train_losses_list.append(losses)

                        fitness = self._evaluate_model(self._population[i], validation_dataloader, progr_bar, f"{message}, evaluating validation: ")
                        fitness_list.append(fitness)

                        test_fitness = self._evaluate_model(self._population[i], test_dataloader, progr_bar, f"{message}, evaluating test: ")
                        test_fitness_list.append(test_fitness)

                    fitness_scores.append(fitness)
                    self._training_history["train_losses"][i].append(train_losses_list)
                    self._training_history["val_metric"][i].append(fitness_list)
                    self._training_history["test_metric"][i].append(test_fitness_list)

                if cycle < cycles - 1 or last_cycle_evolution:
                  ranked = sorted(list(zip(fitness_scores[:self._genetic_population_size], list(range(self._genetic_population_size)))), key=lambda x: x[0], reverse=True)

                  survivors_count = math.ceil(self._genetic_population_size * survivor_fraction)

                  for _, index in ranked[survivors_count:]:
                      index_a, index_b = random.sample(range(survivors_count), 2)
                      self._training_history["crossover_history"][index].append((cycle, index_a, index_b))

                      self._population[index] = self._crossover(self._population[index_a], self._population[index_b], mutation_chance)
                      self._optimizers[index] = self._reset_optimizer(self._optimizers[index], self._population[index])

        return self._training_history

    def _crossover(self, model_a: nn.Module, model_b: nn.Module, mutation_chance: float) -> nn.Module:
        child = self._model_class(**self._model_params).to(device)
        with torch.no_grad():
            mods_a = dict(model_a.named_modules())
            mods_b = dict(model_b.named_modules())
            if self._superviser_population_size and mutation_chance:
                mutation_index = random.sample(range(self._superviser_population_size), 1)[0]
                mods_mut = dict(self._population[mutation_index].named_modules())

            for name, module_child in child.named_modules():

                if len(list(module_child.children())) != 0:
                    continue

                if self._superviser_population_size and  mutation_chance and random.random() < mutation_chance:
                    src = mods_mut[name] # type: ignore
                elif random.random() < 0.5:
                    src = mods_a[name]
                else:
                    src = mods_b[name]

                for param_child, param_src in zip(module_child.parameters(), src.parameters()):
                    param_child.data.copy_(param_src.data)

                for buf_name, buf_child in module_child.named_buffers():
                    buf_src = getattr(src, buf_name)
                    buf_child.copy_(buf_src)
                    return child
        return child

    def _train_one_model_one_epoch(self,
                                   model: nn.Module,
                                   optimizer: optim.Optimizer,
                                   criterion: nn.modules.loss._Loss,
                                   dataloader: DataLoader,
                                   pbar: tqdm|None=None,
                                   pbar_message: str|None=None) -> List:

        model.train()
        train_losses = []
        for i, (inputs, labels) in enumerate(dataloader):
            inputs, labels = inputs.to(self._device), labels.to(self._device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            if pbar:
                if pbar_message:
                    pbar.set_description(f"{pbar_message}{i+1}/{len(dataloader)}")
                pbar.update(1)

        return train_losses

    def _evaluate_model(self, model: nn.Module, dataloader: DataLoader,  pbar: tqdm|None=None, pbar_message: str|None=None) -> float:
        model.eval()
        scores = []
        with torch.no_grad():
                for i, (inputs, labels) in enumerate(dataloader):
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    scores.append(self._metric.eval(outputs, labels))
                    if pbar:
                        if pbar_message:
                            pbar.set_description(f"{pbar_message}{i+1}/{len(dataloader)}")
                        pbar.update(1)
        return self._metric.aggregate(scores)

    def _reset_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
        param_groups = optimizer.param_groups

        base_hyperparams = {k: v for k, v in param_groups[0].items() if k != 'params'}

        return self._optimizer_class(model.parameters(), **base_hyperparams)

    def _layerwise_mse(self, model_a: nn.Module, model_b: nn.Module):
        layers_a = [m for m in model_a.modules() if len(list(m.parameters())) > 0]
        layers_b = [m for m in model_b.modules() if len(list(m.parameters())) > 0]

        mses = []
        for la, lb in zip(layers_a, layers_b):
            params_a = torch.cat([p.view(-1) for p in la.parameters()])
            params_b = torch.cat([p.view(-1) for p in lb.parameters()])
            mse = nn.functional.mse_loss(params_a, params_b).item()
            mses.append(mse)

        return torch.mean(Tensor(mses)).item()

    def extract_model(self, smilarity_fraction: float=0.8) -> nn.Module:
        mses = [[] for i in range(pop_size)]
        for i in range(self._genetic_population_size):
          for j in range(i + 1, self._genetic_population_size):
            mses[i].append(self._layerwise_mse(self._population[i], self._population[j]))
            mses[j].append(self._layerwise_mse(self._population[i], self._population[j]))

        row_means = np.mean(mses, axis=1)
        selector = row_means <= np.quantile(row_means, smilarity_fraction)

        avg_model = self._model_class(**self._model_params).to(device)
        avg_state = avg_model.state_dict()

        keys = avg_state.keys()
        for key in keys:
            stacked = torch.stack([self._population[i].state_dict()[key] for i in range(self._genetic_population_size) if selector[i]], dim=0)
            avg_state[key] = stacked.mean(dim=0)

        avg_model.load_state_dict(avg_state)
        return avg_model


In [None]:
class SimpleCNN(nn.Module):
    """Four conv blocks followed by a two-layer classifier head with 10 outputs.

    Each 3x3 conv uses padding=1 and keeps the spatial size; the first three
    blocks halve it with 2x2 max pooling. ``fc1`` expects ``256 * 4 * 4``
    features, which corresponds to 32x32 inputs (32 -> 16 -> 8 -> 4).
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the same order and under the same names
        # (conv1..conv4, pool, dropout1, dropout2, fc1, fc2). state_dict keys and
        # the RNG order of parameter initialisation therefore stay the same.
        widths = (3, 32, 64, 128, 256)
        for number, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"conv{number}", nn.Conv2d(c_in, c_out, 3, padding=1))
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256 * 4 * 4, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        convs = (self.conv1, self.conv2, self.conv3, self.conv4)
        for depth, conv in enumerate(convs):
            x = torch.relu(conv(x))
            if depth < len(convs) - 1:  # the last conv block is not pooled
                x = self.pool(x)
            x = self.dropout1(x)

        features = torch.flatten(x, start_dim=1)
        hidden = self.dropout2(torch.relu(self.fc1(features)))
        return self.fc2(hidden)



In [None]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Both CIFAR-10 splits share the download root, download flag and preprocessing.
cifar10_kwargs = dict(root='./data', download=True, transform=transform)
base_train_set = torchvision.datasets.CIFAR10(train=True, **cifar10_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **cifar10_kwargs)

100%|██████████| 170M/170M [00:14<00:00, 11.9MB/s]


In [None]:
validation_fraction = 0.2
# The validation split is carved out of base_train_set, so size it from that set.
# Sizing it from len(testset) made the split depend on the test set instead.
validation_size = int(len(base_train_set) * validation_fraction)
seed = 42
g = torch.Generator().manual_seed(seed)
valset, trainset = torch.utils.data.random_split(base_train_set, [validation_size, len(base_train_set) - validation_size], generator=g)

val_loader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False, num_workers=num_workers)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=num_workers)

test_loader = torch.utils.data.DataLoader(testset, batch_size=128,
                                          shuffle=False, num_workers=num_workers)

In [None]:
pop_size = 12
superviser_pop_size = 4
# 12 genetic models evolve; 4 supervisors are only trained and donate mutations.
trainer = GeneticTrainer(
    genetic_population_size=pop_size,
    superviser_population_size=superviser_pop_size,
    model_class=SimpleCNN,
    optimizer_class=optim.Adam,
    criterion=nn.CrossEntropyLoss(),
    metric=Accuracy(),
    optimizer_params={"lr": 0.001},
)
history = []

In [None]:
# Two consecutive 16-cycle runs: one epoch per cycle with evolution after every
# cycle, then two epochs per cycle without evolving after the final cycle.
for extra_kwargs in ({}, {"epoches_per_cycle": 2, "last_cycle_evolution": False}):
    history.append(trainer.train(16, train_loader, test_loader, val_loader, verbose=1, **extra_kwargs))



In [None]:
# For each genetic model (the first `pop_size` history entries), find its best
# validation score in the latest run and the test score recorded at that same
# cycle/epoch.
last_run = history[-1]
max_vals = [0 for _ in range(pop_size)]
test_vals = [0 for _ in range(pop_size)]
for i in range(pop_size):
    for val_epochs, test_epochs in zip(last_run["val_metric"][i], last_run["test_metric"][i]):
        for val, test_val in zip(val_epochs, test_epochs):
            if val > max_vals[i]:
                max_vals[i] = val
                # Pair with the test score of the same epoch, not the run's final one.
                test_vals[i] = test_val

print(max_vals)
print(max(max_vals))

print("====")
print(test_vals)
print(test_vals[np.argmax(max_vals)])

[0.8273437507450581, 0.8349609375, 0.84033203125, 0.8353515639901161, 0.8313476555049419, 0.833984375, 0.8312500007450581, 0.8319335952401161, 0.8389648459851742, 0.8357421867549419, 0.8348632827401161, 0.8352539055049419]
0.84033203125
====
[0.8121044303797469, 0.8102254746835443, 0.807753164556962, 0.8079509493670886, 0.8095332278481012, 0.8119066455696202, 0.8137856012658228, 0.8097310126582279, 0.8071598101265823, 0.8126977848101266, 0.8122033227848101, 0.8072587025316456]
0.807753164556962


In [None]:
# Average the 80% most mutually similar genetic models and score the result on the test set.
model = trainer.extract_model(smilarity_fraction=0.8)
trainer._evaluate_model(model, dataloader=test_loader)

0.821993670886076

In [None]:
import os

# The directory-creation cell sits at the end of the notebook. Create the target
# here too so this cell works regardless of execution order (Drive must be mounted).
os.makedirs("/content/drive/MyDrive/gen_algo/4l/gen", exist_ok=True)
for i in range(trainer._genetic_population_size):
    torch.save(trainer._population[i], f"/content/drive/MyDrive/gen_algo/4l/gen/model_4l_gen_{i}")

In [None]:
# NOTE(review): weights_only=False unpickles arbitrary objects; only load files you wrote yourself.
# Each replaced model also gets a fresh optimizer. Otherwise trainer._optimizers[slot] would
# keep stepping the discarded model's parameters and the loaded model would never be updated.
for i in range(trainer._genetic_population_size):
    trainer._population[i] = torch.load(f"/content/drive/MyDrive/gen_algo/4l/gen/model_4l_gen_{i}", weights_only=False, map_location=trainer._device)
    trainer._optimizers[i] = trainer._reset_optimizer(trainer._optimizers[i], trainer._population[i])

for i in range(trainer._superviser_population_size):
    slot = trainer._genetic_population_size + i  # supervisors live after the genetic models
    trainer._population[slot] = torch.load(f"/content/drive/MyDrive/gen_algo/4l/sep/model_4l_sep_{i}", weights_only=False, map_location=trainer._device)
    trainer._optimizers[slot] = trainer._reset_optimizer(trainer._optimizers[slot], trainer._population[slot])

In [None]:
# Snapshot the current genetic models as a parent pool. The population list itself is
# overwritten in the next cell. A slice avoids rebinding the global `model`, which
# otherwise would stop pointing at the extracted average model.
pops = list(trainer._population[:trainer._genetic_population_size])

In [None]:
# Rebuild every genetic model as a child of two random pool members. Each child gets a
# fresh optimizer; the old one still references the replaced model's parameters.
for i in range(trainer._genetic_population_size):
    index_a, index_b = random.sample(range(len(pops)), 2)
    trainer._population[i] = trainer._crossover(pops[index_a], pops[index_b], mutation_chance=0.3)
    trainer._optimizers[i] = trainer._reset_optimizer(trainer._optimizers[i], trainer._population[i])

In [None]:
# Two 4-cycle runs, first with a high mutation chance and then with mutation disabled.
h = [[], []]
for run_index, run_mutation_chance in enumerate((0.4, 0)):
    h[run_index] = trainer.train(4, train_loader, test_loader, val_loader, epoches_per_cycle=1, last_cycle_evolution=False, mutation_chance=run_mutation_chance)



In [None]:
# smilarity_fraction=1 keeps every genetic model in the average.
model = trainer.extract_model(smilarity_fraction=1)
trainer._evaluate_model(model, dataloader=test_loader)

0.8269382911392406

In [None]:
# Same as the cell above: average all genetic models and score on the test set.
model = trainer.extract_model(smilarity_fraction=1)
trainer._evaluate_model(model, dataloader=test_loader)

0.8267405063291139

In [None]:
# Baseline: a single model with no supervisors, trained for 200 epochs in one cycle.
trainer2 = GeneticTrainer(
    genetic_population_size=1,
    superviser_population_size=0,
    model_class=SimpleCNN,
    optimizer_class=optim.Adam,
    criterion=nn.CrossEntropyLoss(),
    metric=Accuracy(),
    optimizer_params={"lr": 0.001},
)
hh = trainer2.train(
    1, train_loader, test_loader, val_loader,
    epoches_per_cycle=200, last_cycle_evolution=False, mutation_chance=0,
)




In [None]:
# Best validation accuracy of the baseline, its epoch, and the test accuracy at that epoch.
baseline_val_curve = hh["val_metric"][0][0]
baseline_best_epoch = np.argmax(baseline_val_curve)
print(max(baseline_val_curve), baseline_best_epoch, hh["test_metric"][0][0][baseline_best_epoch])

0.8433593772351742 195 0.8214003164556962


In [None]:
# Per-entity test metrics from the final cycle of the second phase-two run.
[entity_runs[-1] for entity_runs in h[1]["test_metric"]]

[[0.8163568037974683],
 [0.8149723101265823],
 [0.8134889240506329],
 [0.811807753164557],
 [0.8158623417721519],
 [0.8145767405063291],
 [0.8084454113924051],
 [0.8166534810126582],
 [0.8089398734177216],
 [0.8143789556962026],
 [0.8137856012658228],
 [0.8194224683544303],
 [0.8082476265822784],
 [0.8157634493670886],
 [0.807060917721519],
 [0.806368670886076]]

In [None]:
import os

# Ensure the target directory exists even if the makedirs cell has not run yet
# (Drive must be mounted).
os.makedirs("/content/drive/MyDrive/gen_algo/4l/sep", exist_ok=True)
for i in range(trainer2._genetic_population_size):
    torch.save(trainer2._population[i], f"/content/drive/MyDrive/gen_algo/4l/sep/model_4l_sep_{i}")

In [None]:
# NOTE(review): weights_only=False unpickles arbitrary objects; only load files you wrote yourself.
# Rebuild the optimizer so further training updates the loaded model, not the replaced one.
for i in range(trainer2._genetic_population_size):
    trainer2._population[i] = torch.load(f"/content/drive/MyDrive/gen_algo/4l/sep/model_4l_sep_{i}", weights_only=False, map_location=trainer2._device)
    trainer2._optimizers[i] = trainer2._reset_optimizer(trainer2._optimizers[i], trainer2._population[i])

In [None]:
# Mount Google Drive at /content/drive (Colab only). The save/load cells in this
# notebook read and write under /content/drive/MyDrive/gen_algo, so run this first.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

# Output folders for the saved genetic ("gen") and baseline/supervisor ("sep") models.
for save_dir in ("/content/drive/MyDrive/gen_algo/4l/gen", "/content/drive/MyDrive/gen_algo/4l/sep"):
    os.makedirs(save_dir, exist_ok=True)