In [1]:
%load_ext autoreload
%autoreload 2

import time
import random
import pickle, os
import numpy as np
import cvxpy as cp

from copy import deepcopy

import matplotlib.pyplot as plt
%matplotlib inline

import torch
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter

from free_flyer.free_flyer import FreeFlyer
from free_flyer.utils import *
from solvers.mlopt_ff import MLOPT_FF

In [2]:
# Load the train/test datasets named by the default FreeFlyer config.
prob = FreeFlyer()  # use default config; pass a different config file otherwise
config_fn = './free_flyer/config/default.p'

# NOTE(security): pickle.load can execute arbitrary code while deserializing.
# Only point these paths at files you generated yourself or otherwise trust.
# `with` guarantees each handle is closed even if unpickling raises.
with open(config_fn, 'rb') as config_file:
    # Only the first entry of the config tuple (the dataset name) is used here.
    dataset_name, _, _ = pickle.load(config_file)

relative_path = os.getcwd()
dataset_fn = os.path.join(relative_path, 'free_flyer', 'data', dataset_name)

with open(os.path.join(dataset_fn, 'train.p'), 'rb') as train_file:
    # Per the original notebook comment the tuple is laid out as
    # (p_train, x_train, u_train, y_train, c_train, times_train).
    train_data = pickle.load(train_file)
x_train = train_data[1]
y_train = train_data[3]

with open(os.path.join(dataset_fn, 'test.p'), 'rb') as test_file:
    test_data = pickle.load(test_file)
p_test, x_test, u_test, y_test, c_test, times_test = test_data

n_test = x_test.shape[0]

### Hacky utility functions for passing network weights around

In [3]:
def copy_shared_params(target, source):
    """Copy every shared parameter from `source` into `target`, in place.

    Parameters are matched by name. The weight and bias of the last entry of
    `target.ff_layers` are skipped, and so is any `source` parameter whose
    name does not exist in `target`.

    Args:
        target: network whose parameters are overwritten; must expose
            `ff_layers` and `named_parameters()`.
        source: network providing the values; must expose
            `named_parameters()`.
    """
    last_idx = len(target.ff_layers) - 1
    skipped = {
        'ff_layers.{}.weight'.format(last_idx),
        'ff_layers.{}.bias'.format(last_idx),
    }
    destination = dict(target.named_parameters())
    for name, src_param in source.named_parameters():
        if name in skipped or name not in destination:
            continue
        # Write through .data so autograd does not record the copy.
        destination[name].data.copy_(src_param.data)

In [4]:
def copy_all_but_last(target, source_data):
    """Overwrite every parameter of `target` except its last ff layer.

    `source_data` is read positionally: weight then bias for each entry of
    `target.conv_layers`, followed by weight then bias for each entry of
    `target.ff_layers` except the final one. Any trailing entries of
    `source_data` are ignored.

    Args:
        target: network exposing `conv_layers`, `ff_layers` and
            `named_parameters()`; modified in place.
        source_data: indexable sequence of tensors in the order above.
    """
    named = dict(target.named_parameters())

    # Build the parameter names in the same order source_data is consumed.
    ordered_names = []
    for layer in range(len(target.conv_layers)):
        ordered_names.append('conv_layers.{}.weight'.format(layer))
        ordered_names.append('conv_layers.{}.bias'.format(layer))
    for layer in range(len(target.ff_layers) - 1):
        ordered_names.append('ff_layers.{}.weight'.format(layer))
        ordered_names.append('ff_layers.{}.bias'.format(layer))

    for position, name in enumerate(ordered_names):
        # Write through .data so autograd does not record the copy.
        named[name].data.copy_(source_data[position].data)

# Updates the last-layer parameters of target NN using the list of weights in source_data
def copy_last(target, source_data):
    """Overwrite the last ff layer of `target` from the tail of `source_data`.

    The weight is taken from `source_data[-2]` and the bias from
    `source_data[-1]`; no other parameter of `target` is touched.

    Args:
        target: network exposing `ff_layers` and `named_parameters()`;
            modified in place.
        source_data: sequence of tensors whose final two entries are the
            last layer's weight and bias.
    """
    lookup = dict(target.named_parameters())
    final = len(target.ff_layers) - 1

    weight_key = 'ff_layers.{}.weight'.format(final)
    bias_key = 'ff_layers.{}.bias'.format(final)

    # Write through .data so autograd does not record the copies.
    lookup[weight_key].data.copy_(source_data[-2].data)
    lookup[bias_key].data.copy_(source_data[-1].data)

### Load MLOPT network with pre-trained weights

In [5]:
# --- Configuration for this cell -------------------------------------------
system = 'free_flyer'
prob_features = ['x0', 'obstacles_map']
n_features = 4
# NOTE(review): presumably a CUDA device index, so 1 would need a second GPU
# to be visible - confirm against MLOPT_FF.setup_network. Use -1 for CPU.
device_id = 1
fn_saved = 'models/mloptff_free_flyer_20200716_0708.pt'   # New spaced out dataset

# --- Build the solver, its strategies, and the network ----------------------
mlopt_cnn = MLOPT_FF(system, prob, prob_features)
mlopt_cnn.construct_strategies(n_features, train_data)
mlopt_cnn.setup_network(device_id=device_id)

# Replace the freshly set-up network weights with the pre-trained checkpoint.
mlopt_cnn.load_network(fn_saved)

# Display the architecture as the cell output.
mlopt_cnn.model

Loading presaved classifier model from models/mloptff_free_flyer_20200716_0708.pt


CNNet(
  (conv_activation): ReLU()
  (ff_activation): ReLU()
  (conv_layers): ModuleList(
    (0): Conv2d(3, 16, kernel_size=(2, 2), stride=(2, 2))
    (1): Conv2d(16, 16, kernel_size=(2, 2), stride=(2, 2))
    (2): Conv2d(16, 16, kernel_size=(2, 2), stride=(2, 2))
  )
  (ff_layers): ModuleList(
    (0): Linear(in_features=260, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=128, bias=True)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): Linear(in_features=128, out_features=458, bias=True)
  )
)

In [10]:
# `mlopt_model` is the pre-trained classifier itself (an alias, not a copy), so
# training it later also updates `mlopt_cnn.model`.
mlopt_model = mlopt_cnn.model
# `feas_model` starts as an independent deep copy of the classifier.
feas_model = deepcopy(mlopt_model)
# Sync every parameter except the last ff layer from the classifier into the copy.
copy_shared_params(feas_model, mlopt_model)

In [7]:
# Step sizes: `update_lr` scales the manual gradient step on the feas_model
# weights, `meta_lr` is the Adam learning rate for the classifier.
update_lr = 1e-3
meta_lr = 1e-3

# Number of entries in the training parameter set, measured along axis 0 of 'x0'.
num_train = len(train_data[0]['x0'])

# TensorBoard event files for this run are written under runs/warm_start_3.
writer = SummaryWriter(log_dir="runs/warm_start_3")

In [None]:
# Training hyper-parameters: batch sizes come from the solver's stored
# training_params; the loop/reporting cadence is overridden for this notebook.
BATCH_SIZE = mlopt_cnn.training_params['BATCH_SIZE']
TEST_BATCH_SIZE = mlopt_cnn.training_params['TEST_BATCH_SIZE']
TRAINING_ITERATIONS = 5    # overrides training_params['TRAINING_ITERATIONS']
# The original cell assigned this override to a misspelled name
# (`CCHECKPOINT_AFTER`), so it never took effect.
CHECKPOINT_AFTER = 1000    # overrides training_params['CHECKPOINT_AFTER']
SAVEPOINT_AFTER = 200      # overrides training_params['SAVEPOINT_AFTER']

# Number of problems sampled per inner update: enough to cover one batch of
# per-obstacle samples, but never fewer than 10.
NUM_META_PROBLEMS = max(int(np.ceil(BATCH_SIZE / mlopt_cnn.problem.n_obs)), 10)
UPDATE_STEP = 1  # number of inner gradient steps per batch

# `model` is the classifier being trained (same object as mlopt_model).
model = mlopt_model
model.to(device=mlopt_cnn.device)

# Training inputs: feed-forward features X and class labels Y (column 0 of
# mlopt_cnn.labels). The CNN inputs are rebuilt per batch inside the loop.
params = train_data[0]
X = mlopt_cnn.features[:mlopt_cnn.problem.n_obs*mlopt_cnn.num_train]
Y = mlopt_cnn.labels[:mlopt_cnn.problem.n_obs*mlopt_cnn.num_train,0]

training_loss = torch.nn.CrossEntropyLoss()
meta_opt = torch.optim.Adam(model.parameters(), lr=meta_lr, weight_decay=0.00001)

def _build_cnn_batch(feature_indices):
    """Build the CNN input tensor for the given rows of the feature matrix X.

    Each index is mapped through `mlopt_cnn.cnn_features_idx` to a
    (problem index, obstacle index) pair, the problem's parameters are sliced
    out of `params`, and the corresponding CNN feature image is constructed.
    Returns a float tensor on `mlopt_cnn.device`.
    """
    batch = np.zeros((len(feature_indices), 3, mlopt_cnn.problem.H, mlopt_cnn.problem.W))
    for row, feat_idx in enumerate(feature_indices):
        prob_idx = mlopt_cnn.cnn_features_idx[feat_idx][0]
        obs_idx = mlopt_cnn.cnn_features_idx[feat_idx][1]
        batch_prob_params = {k: params[k][prob_idx] for k in params}
        batch[row] = mlopt_cnn.problem.construct_cnn_features(batch_prob_params, mlopt_cnn.prob_features, ii_obs=obs_idx)
    return Variable(torch.from_numpy(batch)).float().to(device=mlopt_cnn.device)


n_obs = mlopt_cnn.problem.n_obs
margin = 10.  # hinge margin applied to the summed feasibility scores

itr = 1
for epoch in range(TRAINING_ITERATIONS):  # loop over the dataset multiple times
    t0 = time.time()
    running_loss = 0.0
    # NOTE(review): arange(0, n-1) leaves out the last row of X - confirm
    # whether that is intentional.
    rand_idx = list(np.arange(0,X.shape[0]-1))
    random.shuffle(rand_idx)

    # Partition the shuffled indices into batches (the last one may be short).
    indices = [rand_idx[ii * BATCH_SIZE:(ii + 1) * BATCH_SIZE] for ii in range((len(rand_idx) + BATCH_SIZE - 1) // BATCH_SIZE)]

    for idx in indices:

        # fast_weights are network weights for feas_model with descent steps taken
        fast_weights = list(feas_model.parameters())
        for _step in range(UPDATE_STEP):
            inner_loss = torch.zeros(1)

            # Update feas_model to be using fast_weights
            copy_last(feas_model, fast_weights)
            copy_all_but_last(feas_model, fast_weights)

            for _problem in range(NUM_META_PROBLEMS):
                # NOTE(review): idx_val is drawn from [0, num_train) but is used
                # to index cnn_features_idx, which the batch code indexes with
                # rows of X - confirm this samples the intended problems.
                idx_val = np.random.randint(0,num_train)
                prob_idx = mlopt_cnn.cnn_features_idx[idx_val][0]
                prob_params = {k: params[k][prob_idx] for k in params}

                # One feed-forward feature row per obstacle (same row repeated).
                ff_inputs_inner = torch.from_numpy(mlopt_cnn.problem.construct_features(prob_params, mlopt_cnn.prob_features))
                ff_inputs_inner = Variable(ff_inputs_inner.repeat(n_obs,1)).float().to(device=mlopt_cnn.device)

                # One CNN feature image per obstacle.
                X_cnn_inner = np.zeros((n_obs, 3, mlopt_cnn.problem.H, mlopt_cnn.problem.W))
                for ii_obs in range(n_obs):
                    X_cnn_inner[ii_obs] = mlopt_cnn.problem.construct_cnn_features(prob_params,
                                    mlopt_cnn.prob_features,
                                    ii_obs=ii_obs)
                cnn_inputs_inner = Variable(torch.from_numpy(X_cnn_inner)).float().to(device=mlopt_cnn.device)

                # Classifier's current best strategy per obstacle.
                # NOTE(review): the original cell called `mlopt_model_copy`
                # here, a name no visible cell defines (NameError on a fresh
                # kernel). `model` is used on the assumption that the MLOPT
                # classifier being trained was intended - confirm.
                scores = model(cnn_inputs_inner, ff_inputs_inner).detach().cpu().numpy()
                class_labels = np.argmax(scores, axis=1)

                # Assemble the pinned assignment from the stored label rows of
                # the predicted classes (4 rows per obstacle).
                y_guess = np.zeros((4*n_obs, mlopt_cnn.problem.N-1))
                for cl_ii,cl in enumerate(class_labels):
                    cl_idx = np.where(mlopt_cnn.labels[:,0] == cl)[0][0]
                    y_obs = mlopt_cnn.labels[cl_idx,1:]
                    y_guess[4*cl_ii:4*(cl_ii+1)] = np.reshape(y_obs, (4, mlopt_cnn.problem.N-1))

                # Sometimes Mosek fails, so try again with Gurobi.
                # `except Exception` keeps the fallback but lets
                # KeyboardInterrupt/SystemExit stop a long run.
                try:
                    prob_success = mlopt_cnn.problem.solve_pinned(prob_params, y_guess, solver=cp.MOSEK)[0]
                except Exception:
                    prob_success = mlopt_cnn.problem.solve_pinned(prob_params, y_guess, solver=cp.GUROBI)[0]

                # feas_model score of the chosen class for each obstacle.
                # Sized by n_obs rather than a hard-coded 8.
                losses = torch.zeros(n_obs,1)

                feas_scores = feas_model(cnn_inputs_inner, ff_inputs_inner)
                for ii_obs in range(n_obs):
                    losses[ii_obs] = feas_scores[ii_obs,class_labels[ii_obs]]

                if prob_success:
                    # If problem feasible, push scores to positive value
                    inner_loss += torch.relu(margin - torch.sum(losses))
                else:
                    # If problem infeasible, push scores to negative value
                    inner_loss += torch.relu(margin + torch.sum(losses))

            # Descent step on feas_model network weights
            inner_loss /= float(NUM_META_PROBLEMS)
            grad = torch.autograd.grad(inner_loss, fast_weights)
            fast_weights = [w - update_lr * g for g, w in zip(grad, fast_weights)]

        # Pass inner loop weights to MLOPT classifier (except last layer).
        # Same undefined-name fix as above: `model` replaces `mlopt_model_copy`.
        copy_all_but_last(model, fast_weights)

        ff_inputs = Variable(torch.from_numpy(X[idx,:])).float().to(device=mlopt_cnn.device)
        labels = Variable(torch.from_numpy(Y[idx])).long().to(device=mlopt_cnn.device)

        # forward + backward + optimize
        cnn_inputs = _build_cnn_batch(idx)
        outputs = model(cnn_inputs, ff_inputs)

        loss = training_loss(outputs, labels).float().to(device=mlopt_cnn.device)
        running_loss += loss.item()

        class_guesses = torch.argmax(outputs,1)
        accuracy = torch.mean(torch.eq(class_guesses,labels).float())
        loss.backward()
        meta_opt.step()
        meta_opt.zero_grad() # zero the parameter gradients

        # Update feas_model weights: last layer from the inner-loop result,
        # all other layers from the freshly stepped classifier.
        copy_last(feas_model, [fw.detach() for fw in fast_weights])
        copy_shared_params(feas_model, model)

        if itr % CHECKPOINT_AFTER == 0:
            # Report loss/accuracy on a random subset of the training rows.
            rand_idx = list(np.arange(0,X.shape[0]-1))
            random.shuffle(rand_idx)
            test_inds = rand_idx[:TEST_BATCH_SIZE]

            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                ff_inputs = Variable(torch.from_numpy(X[test_inds,:])).float().to(device=mlopt_cnn.device)
                labels = Variable(torch.from_numpy(Y[test_inds])).long().to(device=mlopt_cnn.device)
                cnn_inputs = _build_cnn_batch(test_inds)
                outputs = model(cnn_inputs, ff_inputs)

                loss = training_loss(outputs, labels).float().to(device=mlopt_cnn.device)
                class_guesses = torch.argmax(outputs,1)
                accuracy = torch.mean(torch.eq(class_guesses,labels).float())
            print("loss:   "+str(loss.item())+",   acc:  "+str(accuracy.item()))

        if itr % SAVEPOINT_AFTER == 0:
            # Log the mean training loss since the last savepoint (or epoch start).
            writer.add_scalar('Loss/train', running_loss / float(SAVEPOINT_AFTER), itr)
            running_loss = 0.

        itr += 1