In [1]:
from __future__ import print_function
from IPython.display import clear_output
    
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import numpy as np

from numpy.random import randint, rand

import copy

In [2]:
class Net(nn.Module):
    """Small convolutional classifier.

    Two 3x3 convolutions (1 -> 32 -> 64 channels), a 2x2 max-pool, and two
    fully connected layers (9216 -> 128 -> 10) with dropout in between.
    ``forward`` returns log-probabilities over the 10 outputs.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Layers are created in the same order as before so that seeded
        # initialisation and state_dict key order stay unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 * 12 * 12, i.e. the flattened conv output for a 28x28 input.
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return log-softmax scores (over dim 1) for the input batch ``x``."""
        features = F.relu(self.conv1(x))
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)
        hidden = self.dropout1(features).flatten(start_dim=1)
        hidden = self.dropout2(F.relu(self.fc1(hidden)))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [3]:
# Training settings
BATCH_SIZE = 64          # batch size handed to the training DataLoader
TEST_BATCH_SIZE = 1000   # batch size handed to the test DataLoader
EPOCHS = 1               # number of passes over the training loader
LR = 1.0                 # NOTE(review): not referenced by the visible cells
GAMMA = 0.7              # NOTE(review): not referenced by the visible cells
SEED = 1                 # value passed to torch.manual_seed
LOG_INTERVAL = 10        # print a progress line every LOG_INTERVAL batches

In [4]:
use_cuda = torch.cuda.is_available()

# Seed every RNG the notebook draws from: torch initialises the network
# weights, while the genetic operators use numpy.random (randint / rand).
# Seeding only torch left the evolutionary search non-reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if use_cuda else "cpu")

# Keyword arguments forwarded to the two DataLoaders.
train_kwargs = {'batch_size': BATCH_SIZE}
test_kwargs = {'batch_size': TEST_BATCH_SIZE}

if use_cuda:
    # NOTE(review): shuffling is only enabled when CUDA is available and is
    # applied to the test loader as well, so batch order differs between CPU
    # and GPU runs -- confirm this is intended.
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True,
                   'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [5]:
# Convert images to tensors, then normalise with a fixed mean / std
# (presumably the MNIST training-set statistics -- TODO confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# download=True lets the notebook run on a fresh checkout: the files are
# fetched into ../data when missing and reused when already present.
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('../data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [6]:
# Initial population: 20 independently initialised networks on `device`.
models = [Net().to(device) for _ in range(20)]

In [7]:
def selection(models, losses, k=3, n_pop=20):
    """
    Tournament selection.

    Builds a parent pool of ``n_pop`` entries. Each entry is the winner of a
    tournament between ``k`` candidates drawn uniformly (with replacement)
    from the population; the candidate with the LOWEST loss wins.

    :param models: sequence of candidate models
    :param losses: sequence of losses, ``losses[i]`` belonging to ``models[i]``
                   (lower is better)
    :param k: tournament size
    :param n_pop: number of parents to return
    :return: tuple ``(parent_models, parent_losses)`` of two lists of length ``n_pop``
    :raises ValueError: if the population is empty or the two sequences differ in length
    """
    pop_size = len(models)
    if pop_size == 0:
        raise ValueError("selection() requires a non-empty population")
    if len(losses) != pop_size:
        raise ValueError("models and losses must have the same length")

    parent_models = []
    parent_losses = []

    while len(parent_models) < n_pop:
        # Candidates are drawn from the actual population size rather than
        # n_pop, so the two values no longer have to coincide.
        winner = randint(pop_size)

        # Compare against k-1 further random candidates.
        for challenger in randint(0, pop_size, k - 1):
            # A smaller loss is the better individual. The previous `>`
            # comparison kept the worst candidate of every tournament.
            if losses[challenger] < losses[winner]:
                winner = challenger

        parent_models.append(models[winner])
        parent_losses.append(losses[winner])

    return parent_models, parent_losses

In [8]:
def crossover(p1, p2, r_cross=0.8):
    """
    One-point crossover over the entries of the parents' state dicts.

    The ordered state-dict keys are treated as the "genome". With probability
    ``r_cross`` a cut point ``pt`` is drawn and the children become
    ``p1[:pt] + p2[pt:]`` and ``p2[:pt] + p1[pt:]``; otherwise they are plain
    copies of ``p1`` and ``p2``. The parents are never modified.

    :param p1: first parent (torch.nn.Module)
    :param p2: second parent, with the same state-dict keys as ``p1``
    :param r_cross: probability of performing the recombination
    :return: list ``[c1, c2]`` with two new modules located on the CPU
    """
    # Children are independent copies of the parents by default. They are kept
    # on the CPU, as before; the caller moves them to the target device.
    c1 = copy.deepcopy(p1).cpu()
    c2 = copy.deepcopy(p2).cpu()

    # Read the genome from the parents themselves. The previous version used
    # the notebook-global `model`, which only existed after the training cell
    # had run.
    sd1 = p1.state_dict()
    sd2 = p2.state_dict()
    keys = list(sd1.keys())

    # Recombination needs at least two entries to have a valid cut point.
    if len(keys) >= 2 and rand() < r_cross:
        # pt in [1, len(keys) - 1]: each parent contributes at least one tensor.
        pt = randint(1, len(keys))
        tail = keys[pt:]

        # Build each child's state from the untouched parent tensors. The old
        # in-place swap overwrote sd1 before sd2 was built from it, which made
        # both children identical.
        child1_state = dict(sd1)
        child1_state.update((key, sd2[key]) for key in tail)
        child2_state = dict(sd2)
        child2_state.update((key, sd1[key]) for key in tail)

        # load_state_dict copies the values, so children never alias parents.
        c1.load_state_dict(child1_state)
        c2.load_state_dict(child2_state)

    return [c1, c2]

In [9]:
def mutation(m, r_mut=0.3):
    """
    Mutation operator.

    Every element of every selected tensor is, independently and with
    probability ``r_mut``, replaced by a fresh value drawn uniformly from
    [-1, 1). A tensor is selected when its state-dict key contains ``'bias'``,
    ``'conv'`` or ``'fc'`` (the same name filter as before). The module is
    modified in place.

    :param m: torch.nn.Module to mutate
    :param r_mut: per-element mutation probability
    :return: the same module ``m``
    """
    mutable_tags = ('bias', 'conv', 'fc')

    with torch.no_grad():
        # state_dict() tensors share storage with the module's parameters, so
        # writing into them updates the module directly.
        for key, tensor in m.state_dict().items():
            if not tensor.is_floating_point():
                continue
            if not any(tag in key for tag in mutable_tags):
                continue

            # Vectorised replacement for the former per-element Python loops.
            mask = torch.rand_like(tensor) < r_mut
            # Continuous uniform values. The previous randint(-1.0, 1.0) call
            # could only ever produce the integers -1 and 0.
            fresh = torch.empty_like(tensor).uniform_(-1.0, 1.0)
            tensor.copy_(torch.where(mask, fresh, tensor))

    return m

In [10]:
def genetic_algorithm(device, models, losses, n_pop=20):
    """
    Produce the next generation from the current population.

    Runs tournament selection, pairs the selected parents, applies crossover
    to each pair and mutation to each child, and moves the children to
    ``device``.

    :param device: torch device the children are moved to
    :param models: current population (sequence of modules)
    :param losses: loss of each model, aligned with ``models``
    :param n_pop: size of the generation to produce
    :return: list of ``n_pop`` child modules
    """
    # Forward n_pop so the parent pool has the size the pairing loop expects
    # (it used to fall back silently to selection's own default).
    parent_models, _ = selection(models, losses, n_pop=n_pop)

    children = []
    for i in range(0, n_pop, 2):
        # Take the selected parents in pairs. For an odd n_pop the last parent
        # is paired with the first one instead of indexing out of range.
        p1 = parent_models[i]
        p2 = parent_models[(i + 1) % n_pop]

        for child in crossover(p1, p2):
            children.append(mutation(child).to(device))

    # Each pair yields two children; trim the surplus child for odd n_pop.
    return children[:n_pop]

In [None]:
# Evolutionary training loop: for every batch, score each individual on that
# batch and replace the population with the next generation.
for e in range(EPOCHS):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Fitness evaluation only: nothing is back-propagated, so autograd is
        # disabled and the losses are stored as plain floats. Previously a
        # list of graph-attached tensors was handed to np.array().
        losses = []
        with torch.no_grad():
            for model in models:
                # Kept in train mode as before (dropout stays active).
                model.train()
                output = model(data)
                loss = F.nll_loss(output, target)
                losses.append(loss.item())
        best_loss = min(losses)

        models = genetic_algorithm(device, models, losses, n_pop=len(models))

        if batch_idx % LOG_INTERVAL == 0:
            # Report the best individual of the generation rather than the
            # loss of whichever model happened to be evaluated last.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'
                  .format(
                      e, batch_idx * len(data),
                      len(train_loader.dataset),
                      100. * batch_idx / len(train_loader),
                      best_loss
                  )
                 )

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)




In [10]:
# Grab the state dict of `model` for inspection.
# NOTE(review): `model` is not defined in this cell; it is the name last bound
# inside the training loop, so this cell only works after that loop has run.
state_dict = model.state_dict()

In [11]:
# Print every entry's name and shape, then replace it with a zero tensor of
# the same shape.
for entry_name in list(state_dict):
    entry_shape = state_dict[entry_name].shape
    print(entry_name)
    print(entry_shape)
    state_dict[entry_name] = torch.zeros(entry_shape)

conv1.weight
torch.Size([32, 1, 3, 3])
conv1.bias
torch.Size([32])
conv2.weight
torch.Size([64, 32, 3, 3])
conv2.bias
torch.Size([64])
fc1.weight
torch.Size([128, 9216])
fc1.bias
torch.Size([128])
fc2.weight
torch.Size([10, 128])
fc2.bias
torch.Size([10])


In [29]:
# Load `state_dict` back into `model`; the call's return value is displayed
# as the cell output.
# NOTE(review): relies on `model` and `state_dict` left behind by earlier cells.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [31]:
# Dump every named parameter of `model`: its name followed by its value.
# (The conv/weight distinctions made no difference to what was printed.)
for param_name, param_value in model.named_parameters():
    print(param_name)
    print(param_value)

conv1.weight
Parameter containing:
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
     