In [23]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import copy
from networks import DQN
from scipy import stats
import itertools
import time

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)
print("Torch Version: ", torch.__version__)

Device:  cuda
Torch Version:  1.0.1.post2


In [153]:
def increase_capacity_keep_lr(network, capacity, optimizer, device):
    # Store old ids
    old_ids = [id(p) for p in network.parameters()]
    old_param_sizes = [p.size() for p in network.parameters()]

    network.increase_capacity(capacity)

    # Store new ids
    new_ids = [id(p) for p in network.parameters()]
    new_param_sizes = [p.size() for p in network.parameters()]

    # Store old state 
    opt_state_dict = optimizer.state_dict()
    for old_id, new_id, new_param_size, old_param_size in zip(old_ids, new_ids, new_param_sizes, old_param_sizes):
        # Store step, and exp_avgs
        step = opt_state_dict['state'][old_id]['step']
        old_exp_avg = opt_state_dict['state'][old_id]['exp_avg']
        old_exp_avg_sq = opt_state_dict['state'][old_id]['exp_avg_sq']
        old_max_exp_avg_sq = opt_state_dict['state'][old_id]['max_exp_avg_sq']

        exp_avg = torch.zeros(new_param_size)
        exp_avg_sq = torch.zeros(new_param_size)
        max_exp_avg_sq =  torch.zeros(new_param_size)
        # Extend exp_avgs to new shape depending on wether param is bias or weight
        if exp_avg.dim()>1:
            # Weights
            exp_avg[0:old_param_size[0],0:old_param_size[1]] = old_exp_avg
            exp_avg_sq[0:old_param_size[0],0:old_param_size[1]] = old_exp_avg_sq
            max_exp_avg_sq[0:old_param_size[0],0:old_param_size[1]] = old_max_exp_avg_sq
        else:
            # Biases/last layer
            exp_avg[0:old_param_size[0]] = old_exp_avg
            exp_avg_sq[0:old_param_size[0]] = old_exp_avg_sq
            max_exp_avg_sq[0:old_param_size[0]] = old_max_exp_avg_sq
        
        # Delete old id from state_dict and update new params and new id
        del opt_state_dict['state'][old_id]
        opt_state_dict['state'][new_id] = {
            'step': step,
            'exp_avg': exp_avg,
            'exp_avg_sq': exp_avg_sq.to(device),
            'max_exp_avg_sq' : max_exp_avg_sq.to(device)
        }
        opt_state_dict['param_groups'][0]['params'].remove(old_id)
        opt_state_dict['param_groups'][0]['params'].append(new_id)

    network.to(device)
    optimizer = optim.Adam(network.parameters(), amsgrad=True)
    optimizer.load_state_dict(opt_state_dict)
    
    return network, optimizer

In [173]:
# def generate_n_XOR(batch_size, n_inputs, n, p):
#     Xs =  torch.zeros(batch_size,n_inputs, dtype=torch.float)
#     Ys = torch.zeros(batch_size,1, dtype=torch.float)
#     for i in range(batch_size):
#         Xs[i] = torch.randint(0,2,(1,n_inputs))
#         if random.random() < p:
#             Ys[i] = torch.sum(Xs[i][:random.randint(1,n_inputs)]) == 1
#         else:
#             Ys[i] = torch.sum(Xs[i][:n]) == 1
#     return Xs.to(device), Ys.to(device)

# def generate_n_XOR_uniform(batch_size, n_inputs, n, p, delta):
#     Xs =  torch.zeros(batch_size,n_inputs, dtype=torch.float)
#     Ys = torch.zeros(batch_size,1, dtype=torch.float)
#     for i in range(batch_size):
#         Xs[i] = torch.randint(0,2,(1,n_inputs))
#         if random.random() < p:
#             Ys[i] = torch.sum(Xs[i][:random.randint(1,n_inputs)]) == 1
#         else:
#             Ys[i] = torch.sum(Xs[i][:n]) == 1
#         Xs[i] += torch.FloatTensor(Xs[i].size()).uniform_(-delta, delta)
#     return Xs.to(device), Ys.to(device)

# def generate_n_XOR_float(batch_size, n_inputs, n, p):
#     Xs =  torch.zeros((batch_size,n_inputs), dtype=torch.float)
#     Ys = torch.zeros((batch_size,1), dtype=torch.float)
#     for i in range(batch_size):
#         Xs[i] = torch.rand((1,n_inputs))
#         if random.random() < p:
#             Ys[i] = torch.sum(Xs[i][:random.randint(1,n_inputs)]>0.5) == 1
#         else:
#             Ys[i] = torch.sum(Xs[i][:n]>0.5) == 1
#     return Xs, Ys

def generate_test_set(n):
    combinations = list(itertools.product([0,1], repeat=n))
    combinations = np.array(combinations)
    combinations = np.flip(combinations,1)
    combinations = combinations.tolist()
    Xs_test = torch.tensor(combinations, dtype=torch.float)
    Ys_test = (torch.sum(Xs_test, dim=1)==1).view(-1,1).float()
    return Xs_test.to(device),Ys_test.to(device)

def generate_train_sample(n, level, batch_size):
    combinations = list(itertools.product([0,1], repeat=n))
    combinations = np.array(combinations)
    combinations = np.flip(combinations,1)
    combinations = combinations.tolist()
    Xs_test = torch.tensor(combinations, dtype=torch.float)[:2**level]
    Ys_test = (torch.sum(Xs_test, dim=1)==1).view(-1,1).float()[:2**level]
    r = torch.randint(0,Xs_test.size(0),(batch_size,))
    Xs_test = Xs_test[r]
    Ys_test = Ys_test[r]
    return Xs_test.to(device),Ys_test.to(device)

In [177]:
generate_train_sample(3, 2, 9)

(tensor([[1., 1., 0.],
         [0., 1., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 0.],
         [0., 1., 0.],
         [1., 1., 0.],
         [0., 1., 0.]]), tensor([[0.],
         [1.],
         [0.],
         [0.],
         [0.],
         [0.],
         [1.],
         [0.],
         [1.]]))

In [175]:
def train(iterations, network, criterion, optimizer, batch_size, non_linearity, n_inputs, level, p, Xs_test, Ys_test):
    eval_loss = float("inf")
    losses = []
    i = 0
    # Train untill max iterations reached or loss threshold passed
    while i<iterations and eval_loss>0.05:
        optimizer.zero_grad()

        # Uniform 0.5, Naive 1.0, Zaremba 0.9
#         Xs, Ys = generate_n_XOR_uniform(batch_size, n_inputs, level, p, 0.1) 
        generate_train_sample(n_inputs, level, batch_size)

        prediction = non_linearity(network(Xs))
        loss = criterion(prediction, Ys)

        loss.backward()
        optimizer.step()
        i += 1
        
        with torch.no_grad():
            prediction = non_linearity(network(Xs_test[:2**level]))
            eval_loss = criterion(prediction, Ys_test[:2**level])
            losses.append(eval_loss)
            
    return network, optimizer, eval_loss, i, losses

def n_way_xor_experiment(batch_size, initial_capacity, capacity, non_linearity, n_inputs, p, iterations, seeds, keep_lr):
    # Generate test set of size n_inputs
    Xs_test, Ys_test = generate_test_set(n_inputs)
    
    # Initialise arrays. First 2 levels are skipped
    loss_per_level = np.zeros(n_inputs-1)
    duration_per_level = np.zeros(n_inputs-1)
    total_loss = np.zeros((n_inputs-1)*iterations)
    
    for seed in range(seeds):
        # Set seeds
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Initialisation network, optimiser and loss
        network = DQN(n_inputs, initial_capacity.copy(), 1, non_linearity).to(device)
        optimizer = optim.Adam(network.parameters(),lr=0.01, amsgrad=True)
        criterion = nn.MSELoss()
        
        global_i = 0
        for level in range(2,n_inputs + 1):
            # Train network on level
            network, optimizer, loss, duration, losses = train(iterations, network, criterion, optimizer, batch_size, non_linearity, n_inputs, level, p, Xs_test, Ys_test)
            
            total_loss[global_i:global_i+duration] = total_loss[global_i:global_i+duration] + losses
            global_i = global_i+duration
            # Increase capacity
            if capacity is not None:
                network, optimizer = increase_capacity_keep_lr(network, capacity, optimizer, device)
                if not keep_lr:
                    optimizer = optim.Adam(network.parameters(), amsgrad=True)
                network.to(device)
            
            # Gather stats
            duration_per_level[level-2] += duration
            loss_per_level[level-2] += loss 
    
    for level in range(n_inputs-1):
        print('Number of inputs: ', level + 2)
        print(loss_per_level[level]/seeds)
        print(duration_per_level[level]/seeds)
        
    print('Total Duration:', np.sum(duration_per_level)/seeds)
    print(network)
    # Indicate last index of 0 with vertical line
    plt.plot(total_loss/seeds)

In [176]:
device = 'cpu'
t = time.time()
plt.figure()
n_way_xor_experiment(4, [5,1], [1,1], torch.sigmoid ,5, 0.2, 10000, 10,True)
plt.show()
# print(time.time()-t)
# plt.figure()
# n_way_xor_experiment(1, [5], [1], torch.sigmoid ,5, 0.2, 10000, 10,False)
# plt.show()
print('-----------')
t = time.time()
plt.figure()
n_way_xor_experiment(4, [5,5], None, torch.sigmoid ,5, 0.2, 10000, 10, True)
plt.show()
print(time.time()-t)

KeyboardInterrupt: 

<Figure size 432x288 with 0 Axes>

DQN(
  (layers): ModuleList(
    (0): Linear(in_features=3, out_features=1, bias=True)
    (1): Linear(in_features=1, out_features=1, bias=True)
    (2): Linear(in_features=1, out_features=1, bias=True)
  )
)
DQN(
  (layers): ModuleList(
    (0): Linear(in_features=3, out_features=1, bias=True)
    (1): Linear(in_features=1, out_features=2, bias=True)
    (2): Linear(in_features=2, out_features=1, bias=True)
  )
)


In [24]:

plt.figure()
n_way_xor_experiment(1, [2], [1], torch.relu ,5, 0.2, 10000, 5)
plt.show()

print('-----------')
plt.figure()
n_way_xor_experiment(100, [5], None, torch.relu ,5, 0.2, 10000, 5)
plt.show()

TypeError: n_way_xor_experiment() missing 1 required positional argument: 'keep_lr'

<Figure size 432x288 with 0 Axes>

# Hirose

In [None]:
def train_hirose(network, criterion, optimizer, batch_size, non_linearity, n_inputs, n, p, Xs_test, Ys_test):
    before_loss = float("inf")
    
    with torch.no_grad():
        prediction = non_linearity(network(Xs_test[:2**n]))
        before_loss = criterion(prediction, Ys_test[:2**n])
    
    eval_loss = float("inf")
    i = 0
    # Train untill max iterations reached or loss threshold passed
    while eval_loss>0.05:        
        optimizer.zero_grad()

        # Uniform 0.5, Naive 1.0, Zaremba 0.9
        Xs, Ys = generate_n_XOR_float(batch_size, n_inputs, n, p) 

        prediction = non_linearity(network(Xs))
        loss = criterion(prediction, Ys)

        loss.backward()
        optimizer.step()
        i += 1
        
        with torch.no_grad():
            prediction = non_linearity(network(Xs_test[:2**n]))
            eval_loss = criterion(prediction, Ys_test[:2**n])
        
        if i%1000==0:
            if before_loss*0.99<=eval_loss:
                break
            else:
                before_loss = eval_loss
            
    return network, optimizer, eval_loss, i

def n_way_xor_experiment_hirose(batch_size, initial_capacity, capacity, non_linearity, n_inputs, p, seeds, keep_lr):
    # Generate test set of size n_inputs
    Xs_test,Ys_test = generate_test_set(n_inputs)
    
    # Initialise arrays. First 2 levels are skipped
    loss_per_level = np.zeros(n_inputs-2)
    duration_per_level = np.zeros(n_inputs-2)
    
    for seed in range(seeds):
        # Set seeds
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        
        # Initialisation network, optimiser and loss
        network = DQN(n_inputs, initial_capacity.copy(), 1, non_linearity).to(device)
        optimizer = optim.Adam(network.parameters(), amsgrad=True)
        criterion = nn.MSELoss()
        
        for level in range(2,n_inputs+1):
            loss = float("inf")
            while loss>0.01:
                # Train network on level
                network, optimizer, loss, duration = train_hirose(network, criterion, optimizer, batch_size, non_linearity, n_inputs, level, p, Xs_test, Ys_test)
            
                if loss>0.01:
                    
                    # Increase capacity
                    if capacity is not None:
                        network, optimizer = increase_capacity_keep_lr(network, capacity, optimizer, device)
                        if not keep_lr:
                            optimizer = optim.Adam(network.parameters(), amsgrad=True)
                        network.to(device)
                        
        
                # Gather stats
                duration_per_level[level-2] += duration
            loss_per_level[level-2] += loss
            
    for level in range(n_inputs-2):
        print('Level: ', level)
        print(loss_per_level[level]/seeds)
        print(duration_per_level[level]/seeds)
        
    print('Total Duration:', np.sum(duration_per_level)/seeds)

In [None]:
t = time.time()
n_way_xor_experiment_hirose(100, [2], [1], torch.sigmoid ,5, 0.2, 7, True)
print(time.time()-t)
t = time.time()
n_way_xor_experiment_hirose(100, [2], [1], torch.sigmoid ,5, 0.2, 7, False)
print(time.time()-t)

In [78]:
0.29*0.99

0.28709999999999997

In [109]:
[i for i in range(2,3)]

[2]