In [1]:
import itertools

import numpy as np
import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchinfo import summary

In [2]:
n_epochs = 3
batch_size_train = 64
batch_size_test = 5000
momentum = 0.5  # not referenced by the Adam optimizer cell below
log_interval = 10  # batches between loss updates on the training progress bar

random_seed = 1
torch.backends.cudnn.enabled = False
# The evolutionary search samples with NumPy's global RNG (np.random.choice),
# so seed NumPy as well as torch to make whole-notebook runs reproducible.
np.random.seed(random_seed)
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1097da690>

In [3]:
# Both splits share one root so MNIST is downloaded and cached in a single place.
mnist_root = './mnist/'
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # Commonly quoted MNIST training-set mean / std.
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(mnist_root, train=True, download=True,
                               transform=transform),
    batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(mnist_root, train=False, download=True,
                               transform=transform),
    batch_size=batch_size_test, shuffle=True)

In [4]:
loss_func = nn.NLLLoss()


def perform_stats(net, loader=test_loader, show_stats=True, n_batches=-1, tqdm=tqdm):
    """Evaluate ``net`` on ``loader`` and report mean loss and accuracy.

    Args:
        net: model whose output is treated as class probabilities; ``.log()`` is
            taken before passing it to ``loss_func`` (an ``nn.NLLLoss``).
        loader: iterable of ``(X_batch, Y_batch)`` pairs; must support ``len()``
            when a progress bar is shown. Defaults to ``test_loader`` as bound
            when this cell ran.
        show_stats: print a one-line summary when True.
        n_batches: evaluate only the first ``n_batches`` batches; any negative
            value means the whole loader.
        tqdm: progress-bar factory, or None to disable the bar.

    Returns:
        ``(loss, accuracy)``: per-example mean loss and accuracy in [0, 1].

    Raises:
        ValueError: if no examples were evaluated (empty loader or n_batches == 0).
    """
    # islice stops before pulling an extra batch from the loader.
    batches = loader if n_batches < 0 else itertools.islice(loader, n_batches)
    if tqdm is not None:
        n_expected = len(loader) if n_batches < 0 else min(n_batches, len(loader))
        batches = tqdm(batches, leave=False, total=n_expected)

    n_correct, n_examples, loss_sum = 0, 0, 0.0
    # Evaluation only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for X_batch, Y_batch in batches:
            Y_batch_pred = net(X_batch)
            n_correct += (Y_batch_pred.argmax(dim=-1) == Y_batch).sum().item()
            # NLLLoss averages over the batch; re-weight by batch size so the
            # result is a per-example mean even if the last batch is smaller.
            loss_sum += loss_func(Y_batch_pred.log(), Y_batch).item() * len(X_batch)
            n_examples += len(X_batch)

    if n_examples == 0:
        raise ValueError('perform_stats: no examples evaluated '
                         '(empty loader or n_batches == 0)')
    loss_total = loss_sum / n_examples
    accuracy = n_correct / n_examples
    if show_stats:
        print(f'Average Loss: {loss_total:.03f}, Accuracy: {accuracy*100:.03f}%')
    return loss_total, accuracy

In [5]:
class ConvNet(nn.Module):
    """Small CNN (795 parameters) mapping 1x28x28 images to 10 class probabilities."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 2, kernel_size=3)
        self.conv2 = nn.Conv2d(2, 5, kernel_size=3)
        self.conv3 = nn.Conv2d(5, 10, kernel_size=3)
        self.conv4 = nn.Conv2d(10, 10, kernel_size=1)
        self.fc1 = nn.Linear(10, 10)

    def forward(self, x):
        """Return softmax probabilities of shape (batch, 10) for input (batch, 1, 28, 28)."""
        # Spatial size: 28 -conv-> 26 -pool-> 13 -conv-> 11 -pool-> 5 -conv-> 3 -pool-> 1
        x = torch.relu(torch.max_pool2d(self.conv1(x), 2))
        x = torch.relu(torch.max_pool2d(self.conv2(x), 2))
        x = torch.relu(torch.max_pool2d(self.conv3(x), 2))
        x = torch.relu(self.conv4(x))
        # flatten(1) keeps the batch dimension. For 28x28 input it yields (batch, 10);
        # for other sizes fc1 raises instead of spatial positions being silently
        # reinterpreted as additional examples.
        x = x.flatten(1)
        x = self.fc1(x)
        return x.softmax(dim=-1)
    
    
net = ConvNet()
summary_kwargs = dict(
    input_size=(batch_size_train, 1, 28, 28),
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
)
# Left as the last expression so the layer table is displayed.
summary(net, **summary_kwargs)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
ConvNet                                  --                        --                        --                        --
├─Conv2d: 1-1                            [64, 1, 28, 28]           [64, 2, 26, 26]           20                        [1, 2, 3, 3]
├─Conv2d: 1-2                            [64, 2, 13, 13]           [64, 5, 11, 11]           95                        [2, 5, 3, 3]
├─Conv2d: 1-3                            [64, 5, 5, 5]             [64, 10, 3, 3]            460                       [5, 10, 3, 3]
├─Conv2d: 1-4                            [64, 10, 1, 1]            [64, 10, 1, 1]            110                       [10, 10, 1, 1]
├─Linear: 1-5                            [64, 10]                  [64, 10]                  110                       [10, 10]
Total params: 795
Trainable params: 795
Non-trainable params: 0
Total mult-adds (M): 1.

In [6]:
# Untrained baseline on the test set (10 output classes, so ~10% is chance level).
perform_stats(net, loader=test_loader)

  0%|          | 0/2 [00:00<?, ?it/s]

Average Loss: 2.340, Accuracy: 10.100%


(2.340292453765869, 0.101)

In [8]:
# Adam with lr=1e-2 (an SGD optimizer with lr=1e-3 was the earlier alternative).
opt = torch.optim.Adam(params=net.parameters(), lr=1e-2)

In [10]:
# Supervised training; test-set stats are printed after every epoch.
LOG_EPS = 1e-12  # floor on predicted probabilities before taking the log
for epoch in tqdm(range(n_epochs)):
    batches = tqdm(enumerate(train_loader), leave=False, total=len(train_loader))
    for batch_idx, (X_batch, Y_batch) in batches:
        Y_batch_pred = net(X_batch)
        # `net` returns softmax probabilities. A probability of exactly 0 would give
        # log(0) = -inf: an infinite loss and NaN gradients that the optimizer would
        # write into the weights. Clamping keeps the log-probabilities finite.
        loss = loss_func(Y_batch_pred.clamp_min(LOG_EPS).log(), Y_batch)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if batch_idx % log_interval == 0:
            batches.set_postfix(loss=f'{loss.item():.3f}')
    perform_stats(net)
    
    

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [145]:
# Flatten every parameter of `net` into one 1-D tensor `v`, then load it straight back.
v = nn.utils.parameters_to_vector(net.parameters())
nn.utils.vector_to_parameters(v, net.parameters())

In [148]:
# conet: sigmoid MLP 2*len(v) -> 100, six 100 -> 100 hidden layers, then 100 -> 2*len(v).
# Layers are created in the same order as listing them by hand, so parameter
# initialisation draws from the RNG identically.
conet_layers = [nn.Linear(len(v) * 2, 100), nn.Sigmoid()]
for _ in range(6):
    conet_layers += [nn.Linear(100, 100), nn.Sigmoid()]
conet = nn.Sequential(*conet_layers, nn.Linear(100, len(v) * 2))

# munet: bottleneck MLP len(v) -> 100 -> 50 -> 100 -> len(v), sigmoid between layers.
munet_widths = [len(v), 100, 50, 100, len(v)]
munet_layers = []
for layer_idx, (d_in, d_out) in enumerate(zip(munet_widths[:-1], munet_widths[1:])):
    if layer_idx > 0:
        munet_layers.append(nn.Sigmoid())
    munet_layers.append(nn.Linear(d_in, d_out))
munet = nn.Sequential(*munet_layers)
summary(conet)

Layer (type:depth-idx)                   Param #
Sequential                               --
├─Linear: 1-1                            69,700
├─Sigmoid: 1-2                           --
├─Linear: 1-3                            10,100
├─Sigmoid: 1-4                           --
├─Linear: 1-5                            10,100
├─Sigmoid: 1-6                           --
├─Linear: 1-7                            10,100
├─Sigmoid: 1-8                           --
├─Linear: 1-9                            10,100
├─Sigmoid: 1-10                          --
├─Linear: 1-11                           10,100
├─Sigmoid: 1-12                          --
├─Linear: 1-13                           10,100
├─Sigmoid: 1-14                          --
├─Linear: 1-15                           70,296
Total params: 200,596
Trainable params: 200,596
Non-trainable params: 0

In [6]:

# One training batch, drawn once and reused by every fitness evaluation below.
X_batch, Y_batch = next(iter(train_loader))
batch_idx = 0
def calc_pheo_fitness(net, X=None, Y=None):
    """Score one phenotype (a network) on a single batch.

    Args:
        net: model whose output is treated as class probabilities (``.log()`` is
            taken before ``loss_func``).
        X: input batch; defaults to the module-level ``X_batch`` looked up at call
            time (note: other cells also assign ``X_batch``).
        Y: integer labels for ``X``; defaults to the module-level ``Y_batch``.

    Returns:
        ``(loss, accuracy)`` as Python floats.
    """
    if X is None:
        X = X_batch
    if Y is None:
        Y = Y_batch
    # Fitness evaluation never backpropagates, and it runs once per genome per
    # generation, so skip building the autograd graph.
    with torch.no_grad():
        Y_pred = net(X)
        n_correct = (Y_pred.argmax(dim=-1) == Y).sum().item()
        loss = loss_func(Y_pred.log(), Y).item()
    accuracy = n_correct / len(Y)
    return loss, accuracy

In [7]:
import getpass
import os
import shutil

import torch.utils.tensorboard as tb

# os.getlogin() raises OSError when the process has no controlling terminal
# (e.g. a kernel started by a service or inside a container); getpass.getuser()
# falls back to environment variables / the password database instead.
user = getpass.getuser()
tb_log_dir = f'/tmp/tensorboard/{user}'
# Remove logs from earlier runs so TensorBoard shows only the current experiment.
if os.path.exists(tb_log_dir):
    shutil.rmtree(tb_log_dir)

In [8]:
from torch.utils.tensorboard import SummaryWriter

# Event writer for the evolutionary run; closed in a later cell.
logger = SummaryWriter(log_dir=tb_log_dir)

In [9]:
# Single reusable network: calc_fitnesses loads each genome's parameters into it before scoring.
agent = ConvNet()

def to_np_array(population):
    """Pack a sequence of tensors into a 1-D NumPy object array.

    Filling an ``object`` array element by element stops NumPy from trying to
    convert the tensors into a 2-D numeric array, so ``np.random.choice`` can
    sample whole genomes.
    """
    packed = np.empty(len(population), dtype=object)
    for slot, genome in enumerate(population):
        packed[slot] = genome
    return packed

def get_init_population(size=100):
    """Create ``size`` random genomes, each the flattened parameters of a fresh ConvNet.

    Args:
        size: number of genomes. The default of 100 matches the 1 elite + 70
            mutants + 29 children that calc_next_population produces by default.

    Returns:
        1-D object ndarray of detached 1-D parameter tensors.
    """
    genomes = [nn.utils.parameters_to_vector(ConvNet().parameters()).detach()
               for _ in range(size)]
    return to_np_array(genomes)

def geno2pheno():
    """Placeholder for a genotype -> phenotype mapping; currently does nothing."""


def pheno2geno():
    """Placeholder for a phenotype -> genotype mapping; currently does nothing."""

def calc_fitnesses(population):
    """Score every genome in ``population``.

    Each genome is loaded into the shared ``agent`` network and evaluated with
    ``calc_pheo_fitness``. Fitness is the negated loss, so higher is better.

    Returns:
        Two lists aligned with ``population``: negated losses and accuracies.
    """
    neg_losses = []
    accuracies = []
    for genome in population:
        nn.utils.vector_to_parameters(genome, agent.parameters())
        nll, acc = calc_pheo_fitness(agent)
        neg_losses.append(-nll)
        accuracies.append(acc)
    return neg_losses, accuracies

def calc_mutate(pop, sigma=1e-2):
    """Return Gaussian-perturbed copies of every genome in ``pop``.

    Args:
        pop: iterable of tensors (flattened parameter vectors).
        sigma: standard deviation of the additive zero-mean Gaussian noise; the
            default keeps the original step size of 1e-2.

    Returns:
        A new list of tensors; the inputs are not modified.
    """
    return [genome + sigma * torch.randn_like(genome) for genome in pop]

def calc_crossover(pop1, pop2):
    """Pair genomes from ``pop1`` and ``pop2`` and return their element-wise means.

    Pairs are formed positionally; extra genomes in the longer input are ignored.
    """
    children = []
    for mother, father in zip(pop1, pop2):
        children.append(torch.stack((mother, father)).mean(dim=0))
    return children
    
def calc_next_population(population, fitnesses, n_mutants=70, n_parents=29,
                         temperature=1.0, rng=None):
    """Build the next generation: the elite, mutants and crossover children.

    Genomes are sampled with replacement with probability
    ``softmax(fitness / temperature)``, so fitter genomes are picked more often.

    Args:
        population: 1-D object ndarray of genome tensors.
        fitnesses: one float per genome (higher is better).
        n_mutants: number of mutated copies of sampled genomes.
        n_parents: number of crossover children, each the mean of two sampled parents.
        temperature: softmax temperature (> 0); smaller means stronger selection.
        rng: object providing NumPy's ``choice`` (e.g. ``np.random.default_rng(0)``)
            for reproducible selection; defaults to the global ``np.random``.

    Returns:
        Object ndarray of length ``1 + n_mutants + n_parents`` (100 by default).

    Raises:
        ValueError: if lengths of ``population`` and ``fitnesses`` differ, or
            ``temperature`` is not positive.
    """
    if len(population) != len(fitnesses):
        raise ValueError(f'population has {len(population)} genomes but '
                         f'{len(fitnesses)} fitnesses were given')
    if temperature <= 0:
        raise ValueError(f'temperature must be positive, got {temperature}')
    if rng is None:
        rng = np.random

    # float64 keeps the probabilities summing to 1 well within choice()'s tolerance.
    logits = torch.tensor(fitnesses, dtype=torch.float64) / temperature
    prob = logits.softmax(dim=-1).numpy()

    # Elitism: the best genome always survives unchanged.
    elite = population[np.argmax(fitnesses)]
    mutants = calc_mutate(rng.choice(population, size=n_mutants, p=prob))
    parents1 = rng.choice(population, size=n_parents, p=prob)
    parents2 = rng.choice(population, size=n_parents, p=prob)
    children = calc_crossover(parents1, parents2)

    return to_np_array([elite, *mutants, *children])
    
    

N_GENERATIONS = 10000


def _log_population_stats(main_tag, hist_tag, label, values, step):
    """Write max/mean/median scalars and a histogram of ``values`` to ``logger``."""
    values = np.asarray(values)
    logger.add_scalars(main_tag,
                       {f'{label} max': np.max(values),
                        f'{label} mean': np.mean(values),
                        f'{label} med': np.median(values)},
                       global_step=step)
    logger.add_histogram(hist_tag, values, global_step=step)


population = get_init_population()
for gen_idx in tqdm(range(N_GENERATIONS)):
    fitnesses, fitnesses_acc = calc_fitnesses(population)
    _log_population_stats('fitness', 'fitness histogram', 'fitness',
                          fitnesses, gen_idx)
    # Accuracy curves get 'accuracy …' keys so TensorBoard doesn't label them 'fitness'.
    _log_population_stats('accuracies', 'accuracies histogram', 'accuracy',
                          fitnesses_acc, gen_idx)
    population = calc_next_population(population, fitnesses)


  0%|          | 0/10000 [00:00<?, ?it/s]

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.


KeyboardInterrupt



In [11]:
# Genome length = number of ConvNet parameters.
population[0].size()

torch.Size([795])

In [None]:

# Close the TensorBoard SummaryWriter created earlier in the notebook.
logger.close()

In [169]:
import sys
from types import ModuleType, FunctionType
from gc import get_referents

# Types excluded from the walk: classes (every instance references its class),
# functions (they reference far too much, including modules) and modules.
BLACKLIST = type, ModuleType, FunctionType


def getsize(obj):
    """Approximate the memory footprint of ``obj`` plus everything it references.

    Walks the object graph level by level with ``gc.get_referents`` and sums
    ``sys.getsizeof`` over each distinct object (by ``id``). Blacklisted objects
    are neither counted nor traversed.

    Raises:
        TypeError: if ``obj`` itself is a class, module or function.
    """
    if isinstance(obj, BLACKLIST):
        raise TypeError('getsize() does not take argument of type: ' + str(type(obj)))
    visited = set()
    total_bytes = 0
    frontier = [obj]
    while frontier:
        to_expand = []
        for item in frontier:
            if isinstance(item, BLACKLIST) or id(item) in visited:
                continue
            visited.add(id(item))
            total_bytes += sys.getsizeof(item)
            to_expand.append(item)
        # Next level: everything directly referenced by the newly counted objects.
        frontier = get_referents(*to_expand)
    return total_bytes

In [171]:
# Approximate footprint (sum of sys.getsizeof over reachable objects) of a fresh ConvNet.
getsize(obj=ConvNet())

21402