In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import getpass
import os
import shutil

import numpy as np
import torch
from torch import nn
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchinfo import summary
import torch.utils.tensorboard as tb

import models
import mnist

# Seed every RNG the notebook draws from: torch (network init, mutation noise)
# and NumPy's global RNG (np.random.choice drives selection in calc_next_population).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED);

In [3]:
# Per-user TensorBoard log directory, wiped on every run so each run of the
# notebook starts with a fresh set of event files.
# getpass.getuser() is used instead of os.getlogin(): the latter raises OSError
# when the process has no controlling terminal (e.g. a kernel started by a
# Jupyter server or inside a container).
user = getpass.getuser()
tb_log_dir = f'/tmp/tensorboard/{user}'
if os.path.exists(tb_log_dir):
    shutil.rmtree(tb_log_dir)
    

In [4]:
# Event writer for the fitness/accuracy histograms logged by the evolution loop.
logger = tb.SummaryWriter(log_dir=tb_log_dir)

In [5]:

# Every genome is scored on one fixed mini-batch: the first batch (index 0)
# yielded by the training loader.
X_batch, Y_batch = next(iter(mnist.train_loader))
batch_idx = 0
def calc_pheo_fitness(net, X=None, Y=None, loss_func=None):
    """Score one phenotype (a network) on a mini-batch.

    Args:
        net: callable mapping a batch of inputs to per-class scores. The loss is
            computed on ``net(X).log()``, so this presumably expects ``net`` to
            output probabilities (e.g. a softmax head) -- TODO confirm for
            models.SmallNet; exact zeros would give -inf/NaN losses.
        X: input batch; defaults to the module-level ``X_batch``.
        Y: integer class labels; defaults to the module-level ``Y_batch``.
        loss_func: called as ``loss_func(log_probs, Y)``; defaults to
            ``mnist.loss_func``.

    Returns:
        (loss, accuracy) as Python floats.
    """
    if X is None:
        X = X_batch
    if Y is None:
        Y = Y_batch
    if loss_func is None:
        loss_func = mnist.loss_func
    # Fitness evaluation never back-propagates; disabling autograd avoids
    # building a graph for every genome of every generation.
    with torch.no_grad():
        Y_pred = net(X)
        n_correct = (Y_pred.argmax(dim=-1) == Y).sum().item()
        loss = loss_func(Y_pred.log(), Y).item()
    accuracy = n_correct / len(Y)
    return loss, accuracy

In [6]:
# Network class whose flattened parameters (via parameters_to_vector) form each genome.
model = models.SmallNet

def to_np_array(population):
    """Pack a sequence of genomes into a 1-D NumPy object array.

    Elements are stored one by one so NumPy never tries to stack equal-length
    tensors into a 2-D numeric array; the array holds the original objects.
    """
    packed = np.empty(len(population), dtype=object)
    for slot, genome in enumerate(population):
        packed[slot] = genome
    return packed

def get_init_population(pop_size=100):
    """Sample ``pop_size`` freshly initialised ``model()`` networks as genomes.

    Args:
        pop_size: number of genomes (default 100, the previously hard-coded value).

    Returns:
        1-D object ndarray of flat parameter tensors (parameters_to_vector),
        detached from autograd.
    """
    genomes = [nn.utils.parameters_to_vector(model().parameters()).detach()
               for _ in range(pop_size)]
    return to_np_array(genomes)

def geno2pheno():
    """Stub (name suggests genotype -> phenotype conversion); does nothing, returns None."""


def pheno2geno():
    """Stub (name suggests phenotype -> genotype conversion); does nothing, returns None."""

def calc_fitnesses(population):
    """Evaluate every genome in ``population`` with ``calc_pheo_fitness``.

    One ``model()`` instance is reused: each genome (a flat parameter vector)
    is written into it with ``vector_to_parameters`` before scoring.

    Returns:
        (fitnesses, accuracies): lists aligned with ``population``. Fitness is
        the negated loss (higher is better).
    """
    agent = model()
    fitnesses = []
    fitnesses_acc = []
    # Nothing in the evolutionary loop needs gradients.
    with torch.no_grad():
        for dna in population:
            nn.utils.vector_to_parameters(dna, agent.parameters())
            nll, acc = calc_pheo_fitness(agent)
            fitnesses.append(-nll)
            fitnesses_acc.append(acc)
    return fitnesses, fitnesses_acc

def calc_mutate(pop, sigma=1e-2):
    """Return Gaussian-perturbed copies of each genome in ``pop``.

    Args:
        pop: iterable of genome tensors; not modified.
        sigma: standard deviation of the additive noise (default 1e-2, the
            previously hard-coded value).

    Returns:
        list of new tensors, one per input genome.
    """
    return [dna + sigma * torch.randn_like(dna) for dna in pop]

def calc_crossover(pop1, pop2):
    """Pair genomes from ``pop1`` and ``pop2`` and return their element-wise means.

    Pairs are formed positionally (zip semantics: the shorter input wins).
    """
    children = []
    for mom, dad in zip(pop1, pop2):
        pair = torch.stack([mom, dad], dim=0)
        children.append(pair.mean(dim=0))
    return children
    
def calc_next_population(population, fitnesses, elite_frac=.1, crossover_frac=.0, rng=None):
    """Build the next generation by elitism, mutation and (optional) crossover.

    Args:
        population: 1-D object ndarray of genome tensors (as built by
            ``to_np_array``); fancy indexing below requires an ndarray.
        fitnesses: floats aligned with ``population``; higher is better.
        elite_frac: fraction of top genomes copied over unchanged (default .1).
        crossover_frac: fraction of the new population produced by
            ``calc_crossover`` (default 0, i.e. disabled, as before).
        rng: object with a NumPy-style ``choice`` method, e.g.
            ``np.random.default_rng(seed)``; defaults to the global ``np.random``.

    Returns:
        1-D object ndarray with the same length as ``population``.

    Raises:
        ValueError: if elite_frac + crossover_frac leaves no room for mutants.
    """
    N = len(population)
    if N == 0:
        return to_np_array([])
    rng = np.random if rng is None else rng

    n_keep_best = int(elite_frac * N)
    n_parents = int(crossover_frac * N)
    n_mutants = N - n_keep_best - n_parents
    if n_mutants < 0:
        raise ValueError('elite_frac + crossover_frac must not exceed 1')

    fit = np.asarray(fitnesses, dtype=np.float64)
    # Elites: the n_keep_best highest-fitness genomes. Slice from an explicit
    # start index: `order[-n_keep_best:]` becomes `order[-0:]`, i.e. the WHOLE
    # population, when n_keep_best == 0 (N < 10 with the default fraction).
    order = np.argsort(fit)
    npop = list(population[order[N - n_keep_best:]])

    # Softmax selection probabilities, in float64 with max-subtraction for
    # numerical stability. NaN fitnesses would make rng.choice raise.
    weights = np.exp(fit - fit.max())
    prob = weights / weights.sum()

    mutants = calc_mutate(rng.choice(population, size=n_mutants, p=prob))
    parents1 = rng.choice(population, size=n_parents, p=prob)
    parents2 = rng.choice(population, size=n_parents, p=prob)
    children = calc_crossover(parents1, parents2)

    npop.extend(mutants)
    npop.extend(children)
    return to_np_array(npop)
    
    

N_GENERATIONS = 10000

population = get_init_population()
try:
    for gen_idx in tqdm(range(N_GENERATIONS)):
        fitnesses, fitnesses_acc = calc_fitnesses(population)
        population = calc_next_population(population, fitnesses)

        # Histograms describe the generation that was just evaluated (before it
        # was replaced), indexed by its generation number.
        logger.add_histogram('fitness histogram', np.array(fitnesses), global_step=gen_idx)
        logger.add_histogram('accuracies histogram', np.array(fitnesses_acc), global_step=gen_idx)
finally:
    # The run is usually stopped by hand (KeyboardInterrupt); push buffered
    # events to disk right away so TensorBoard shows everything logged so far.
    logger.flush()


  0%|          | 0/10000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Next: instead of hand-coding the crossover (and mutation) operators, try to learn them with neural networks.

In [24]:
net = model()
v = torch.nn.utils.parameters_to_vector(net.parameters())
# Round trip: write the flat vector back into the same network. This should leave
# the values unchanged and checks that vector <-> parameter conversion works.
torch.nn.utils.vector_to_parameters(v, net.parameters())
dna_len = v.numel()

In [41]:
class CONet(nn.Module):
    """Dense candidate crossover network: three ``nn.Linear(n, n)`` layers.

    ``nn.Linear`` acts on the last dimension only, so for an input of shape
    ``(..., n)`` each genome along the leading dimensions is transformed
    independently. With a ``(1, 2, n)`` input the two parent genomes never
    interact. NOTE(review): a crossover operator presumably needs to mix the
    parents (e.g. concatenate them to ``2 * n``) -- confirm the intent.
    """

    def __init__(self, n=None):
        """Args:
            n: genome length; defaults to the module-level ``dna_len``.
        """
        super().__init__()
        if n is None:
            n = dna_len
        self.lin1 = nn.Linear(n, n)
        self.lin2 = nn.Linear(n, n)
        self.lin3 = nn.Linear(n, n)

    def forward(self, x):
        """Map ``(..., n)`` to ``(..., n)``: two sigmoid hidden layers, linear output."""
        x = torch.sigmoid(self.lin1(x))
        x = torch.sigmoid(self.lin2(x))
        return self.lin3(x)

# Instantiate the dense CONet and display its layer table for one pair of genomes.
conet = CONet()
summary(model=conet, input_size=(1, 2, dna_len))

Layer (type:depth-idx)                   Output Shape              Param #
CONet                                    --                        --
├─Linear: 1-1                            [1, 2, 130]               17,030
├─Linear: 1-2                            [1, 2, 130]               17,030
├─Linear: 1-3                            [1, 2, 130]               17,030
Total params: 51,090
Trainable params: 51,090
Non-trainable params: 0
Total mult-adds (M): 0.05
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.20
Estimated Total Size (MB): 0.21

In [12]:
# Plain MLP operators on flat genomes of length len(v).
# conet: 2*len(v) -> 100 -> (6 hidden 100-wide ReLU layers) -> 2*len(v);
#        presumably fed a concatenated parent pair.
# munet: len(v) -> 100 -> 50 -> 100 -> len(v) bottleneck; Tanh bounds outputs to [-1, 1].
conet = nn.Sequential(
    nn.Linear(2 * len(v), 100),
    nn.ReLU(),
    *[layer for _ in range(6) for layer in (nn.Linear(100, 100), nn.ReLU())],
    nn.Linear(100, 2 * len(v)),
)
munet = nn.Sequential(
    nn.Linear(len(v), 100), nn.ReLU(),
    nn.Linear(100, 50), nn.ReLU(),
    nn.Linear(50, 100), nn.ReLU(),
    nn.Linear(100, len(v)), nn.Tanh(),
)
summary(munet)

Layer (type:depth-idx)                   Param #
Sequential                               --
├─Linear: 1-1                            68,600
├─ReLU: 1-2                              --
├─Linear: 1-3                            5,050
├─ReLU: 1-4                              --
├─Linear: 1-5                            5,100
├─ReLU: 1-6                              --
├─Linear: 1-7                            69,185
├─Tanh: 1-8                              --
Total params: 147,935
Trainable params: 147,935
Non-trainable params: 0

In [48]:
torch.manual_seed(0)
# Fixed random permutations of the 1000 hidden units, applied between the
# chunked layers of MuNet/CONet below. With `arange` they were the identity,
# so each 50-wide chunk was processed in isolation and never mixed with the
# others. The names and the seeding above suggest a shuffle was intended.
rand_idx1 = torch.randperm(1000)
rand_idx2 = torch.randperm(1000)
rand_idx3 = torch.randperm(1000)
class MuNet(nn.Module):
    """Chunked, weight-shared network over a flat genome (presumably a learned mutation operator).

    Forward pass:
    - zero-pad each row of ``x`` to ``width`` values;
    - view it as ``width // chunk`` chunks of ``chunk`` values;
    - apply four ``nn.Linear(chunk, chunk)`` layers, each followed by ReLU, with
      weights shared across chunks;
    - after each of the first three layers, reorder the ``width`` hidden values
      with a fixed index tensor; this is the only step that moves values
      between chunks;
    - crop the result to the genome length.

    Args:
        n_genes: genome length; ``None`` (default) uses ``len(v)`` looked up at
            call time.
        perms: three index tensors of length ``width``; ``None`` (default) uses
            the module-level ``rand_idx1..3`` looked up at call time.
        width: padded hidden size (default 1000, the previously hard-coded value).
        chunk: chunk size (default 50); must divide ``width``.

    NOTE(review): the final ReLU makes every output non-negative. If outputs
    are meant to be parameter values or deltas (which can be negative), that
    is likely unintended -- confirm.
    """

    def __init__(self, n_genes=None, perms=None, width=1000, chunk=50):
        super().__init__()
        if width % chunk:
            raise ValueError(f'width ({width}) must be a multiple of chunk ({chunk})')
        if perms is not None and len(perms) != 3:
            raise ValueError('perms must contain exactly three index tensors')
        self.n_genes = n_genes
        self.perms = None if perms is None else tuple(perms)
        self.width = width
        self.chunk = chunk
        self.lin1 = nn.Linear(chunk, chunk)
        self.lin2 = nn.Linear(chunk, chunk)
        self.lin3 = nn.Linear(chunk, chunk)
        self.lin4 = nn.Linear(chunk, chunk)

    def forward(self, x):
        """Map ``x`` (assumed ``(batch, n_genes)``) to a ``(batch, n_genes)`` output."""
        n_genes = len(v) if self.n_genes is None else self.n_genes
        perms = (rand_idx1, rand_idx2, rand_idx3) if self.perms is None else self.perms
        pad = self.width - n_genes
        if pad < 0:
            raise ValueError(f'genome length {n_genes} exceeds hidden width {self.width}')
        x = torch.cat([x, x.new_zeros(len(x), pad)], dim=-1)
        n_chunks = self.width // self.chunk
        # The trailing None means "no reordering after the last layer".
        for lin, perm in zip((self.lin1, self.lin2, self.lin3, self.lin4), (*perms, None)):
            x = torch.relu(lin(x.reshape(-1, n_chunks, self.chunk)))
            x = x.reshape(-1, self.width)
            if perm is not None:
                x = x[:, perm]
        return x[:, :n_genes]
    
class CONet(nn.Module):
    """Chunked, weight-shared network over a flat genome (candidate learned crossover).

    Currently the same architecture as ``MuNet``. This definition replaces the
    dense ``CONet`` class defined in an earlier cell.

    Forward pass:
    - zero-pad each row of ``x`` to ``width`` values;
    - view it as ``width // chunk`` chunks of ``chunk`` values;
    - apply four ``nn.Linear(chunk, chunk)`` layers, each followed by ReLU, with
      weights shared across chunks;
    - after each of the first three layers, reorder the ``width`` hidden values
      with a fixed index tensor; this is the only step that moves values
      between chunks;
    - crop the result to the genome length.

    Args:
        n_genes: genome length; ``None`` (default) uses ``len(v)`` looked up at
            call time.
        perms: three index tensors of length ``width``; ``None`` (default) uses
            the module-level ``rand_idx1..3`` looked up at call time.
        width: padded hidden size (default 1000, the previously hard-coded value).
        chunk: chunk size (default 50); must divide ``width``.

    NOTE(review): it takes a single genome per row, so as written it cannot
    combine two parents. The final ReLU also forces non-negative outputs.
    Confirm both are intended.
    """

    def __init__(self, n_genes=None, perms=None, width=1000, chunk=50):
        super().__init__()
        if width % chunk:
            raise ValueError(f'width ({width}) must be a multiple of chunk ({chunk})')
        if perms is not None and len(perms) != 3:
            raise ValueError('perms must contain exactly three index tensors')
        self.n_genes = n_genes
        self.perms = None if perms is None else tuple(perms)
        self.width = width
        self.chunk = chunk
        self.lin1 = nn.Linear(chunk, chunk)
        self.lin2 = nn.Linear(chunk, chunk)
        self.lin3 = nn.Linear(chunk, chunk)
        self.lin4 = nn.Linear(chunk, chunk)

    def forward(self, x):
        """Map ``x`` (assumed ``(batch, n_genes)``) to a ``(batch, n_genes)`` output."""
        n_genes = len(v) if self.n_genes is None else self.n_genes
        perms = (rand_idx1, rand_idx2, rand_idx3) if self.perms is None else self.perms
        pad = self.width - n_genes
        if pad < 0:
            raise ValueError(f'genome length {n_genes} exceeds hidden width {self.width}')
        x = torch.cat([x, x.new_zeros(len(x), pad)], dim=-1)
        n_chunks = self.width // self.chunk
        # The trailing None means "no reordering after the last layer".
        for lin, perm in zip((self.lin1, self.lin2, self.lin3, self.lin4), (*perms, None)):
            x = torch.relu(lin(x.reshape(-1, n_chunks, self.chunk)))
            x = x.reshape(-1, self.width)
            if perm is not None:
                x = x[:, perm]
        return x[:, :n_genes]
    
munet = MuNet()
# Smoke test: push the reference genome `v` through as a batch of one.
print(munet(v.unsqueeze(0)).shape)
summary(munet)

torch.Size([1, 685])


Layer (type:depth-idx)                   Param #
MuNet                                    --
├─Linear: 1-1                            2,550
├─Linear: 1-2                            2,550
├─Linear: 1-3                            2,550
├─Linear: 1-4                            2,550
Total params: 10,200
Trainable params: 10,200
Non-trainable params: 0

In [7]:
# Size of a dense dna_len x dna_len weight matrix (presumably a sizing check for
# a fully connected genome-to-genome layer). Uses dna_len instead of the
# hard-coded 685 so it tracks the current model's genome length.
dna_len * dna_len

469225