In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import cvxpy as cp

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.optim as optim
import os
import random

import sys
sys.path.insert(0, './mlopt-micp')
sys.path.insert(0, './mlopt-micp/cartpole')

import optimizer
from problem import Cartpole
from src.ae import Encoder, get_cartpole_encoder

In [3]:
def euclidean_dist(x,y):
    # x: NxD
    # y: MxD
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    assert d == y.size(1)
    
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    return torch.pow(x-y, 2).sum(2)

In [4]:
pp = Cartpole()

In [5]:
print('Total number of classes: {}'.format(pp.n_strategies))
print('Length of feature vector: {}'.format(pp.n_features))

Total number of classes: 581
Length of feature vector: 13


In [6]:
dim_in, dim_z = pp.n_features, 4

enc = get_cartpole_encoder(dim_in, dim_z).cuda()
enc(torch.from_numpy(pp.features[:2]).float().cuda())

# training parameters
TRAINING_ITERATIONS = int(5000)
BATCH_SIZE = int(64)
CHECKPOINT_AFTER = int(1250)
SAVEPOINT_AFTER = int(2500)

rand_idx = list(np.arange(0, pp.n_strategies-1))

indices = [rand_idx[ii * BATCH_SIZE:(ii + 1) * BATCH_SIZE] for ii in range((len(rand_idx) + BATCH_SIZE - 1) // BATCH_SIZE)]
random.shuffle(indices)

enc_dict = {}
str_dict = {}
for ii in range(len(pp.features)):
    str_idx = int(pp.labels[ii,0])
    str_dict[ii] = str_idx
    if str_idx in enc_dict.keys():
        enc_dict[str_idx] += [ii]
    else:
        enc_dict[str_idx] = [ii]
        
feats = torch.from_numpy(pp.features).float().cuda()

In [7]:
optimizer = optim.Adam(enc.parameters(),lr=3e-4)
N = pp.n_strategies # number of classes in training set
Nc = 100 # number of classes per episode
Ns = 20  # number of support examples per class
Nq = 20  # number of query examples per class
training_iters = 10000
for tt in range(training_iters):
    optimizer.zero_grad()

    #sample classes for this iter
    V = np.random.randint(0, pp.n_strategies, Nc)
    Sk = {}  # support examples
    Qk = {}  # query examples
    ck = torch.zeros((Nc, dim_z))

    for ii, v in enumerate(V):
        if len(enc_dict[v]) <= Ns: #if not enough examples for support
            Sk[v] = enc_dict[v]
            Qk[v] = enc_dict[v]
        else:
            Sk[v] = random.sample(enc_dict[v], Ns)
            Qk[v] = [kk for kk in enc_dict[v] if kk not in Sk[v]]
            if len(Qk[v]) > Nq: #if not enough examples for query
                Qk[v] = random.sample(Qk[v], Nq)
        enc_support = enc(feats[Sk[v],:])
        ck[ii,:] = torch.mean(enc_support, axis=0).float().cuda()
        
    losses = torch.zeros(len(V))
    correct = torch.zeros(len(V))
    total = torch.zeros(len(V))
    for ii, v in enumerate(V):
        fx = enc(feats[Qk[v],:]) #current features
        dists = euclidean_dist(fx.cuda(),ck.cuda()) #compute distance between centroid & query embeds
        log_p_y = dists[:,ii] + torch.log(torch.sum(torch.exp(-dists)+1e-6, axis=1))
        losses[ii] = log_p_y.mean()
        #compute accuracy
        correct[ii] = torch.sum(torch.argmin(dists,axis=1)==ii)
        total[ii] = len(Qk[v])
        
    acc = torch.sum(correct)/torch.sum(total)
    
    if tt % 50 == 0: #print for debug
        print(acc, torch.mean(losses))
    
    torch.mean(losses).backward()
    torch.nn.utils.clip_grad_norm_(enc.parameters(), 1.)
    total_norm = 0.
    optimizer.step()

tensor(0.1619) tensor(4.6039, grad_fn=<MeanBackward0>)
tensor(0.1455) tensor(4.5528, grad_fn=<MeanBackward0>)
tensor(0.1178) tensor(4.1660, grad_fn=<MeanBackward0>)
tensor(0.1367) tensor(3.6817, grad_fn=<MeanBackward0>)
tensor(0.1187) tensor(3.5258, grad_fn=<MeanBackward0>)
tensor(0.1740) tensor(3.2067, grad_fn=<MeanBackward0>)
tensor(0.2357) tensor(2.9375, grad_fn=<MeanBackward0>)
tensor(0.2140) tensor(2.5882, grad_fn=<MeanBackward0>)
tensor(0.1843) tensor(2.5683, grad_fn=<MeanBackward0>)
tensor(0.2529) tensor(2.4319, grad_fn=<MeanBackward0>)
tensor(0.2296) tensor(2.4675, grad_fn=<MeanBackward0>)
tensor(0.1947) tensor(2.4063, grad_fn=<MeanBackward0>)
tensor(0.2221) tensor(2.4792, grad_fn=<MeanBackward0>)
tensor(0.2652) tensor(2.4749, grad_fn=<MeanBackward0>)
tensor(0.2371) tensor(2.4477, grad_fn=<MeanBackward0>)
tensor(0.2482) tensor(2.2410, grad_fn=<MeanBackward0>)
tensor(0.2605) tensor(2.3936, grad_fn=<MeanBackward0>)
tensor(0.2709) tensor(2.3220, grad_fn=<MeanBackward0>)
tensor(0.3

KeyboardInterrupt: 

In [8]:
#test script
n_train_strategies = pp.n_strategies #store how many strats in train set
c_k = torch.zeros((n_train_strategies,4)) 
embeddings = enc(feats) #embed training points
for ii in range(n_train_strategies): #compute train centroids
    inds = enc_dict[ii]
    c_k[ii,:] = torch.mean(embeddings[inds,:],axis=0).cuda()

#compute strategy dictionary for all problems
pp.training_batch_percentage = 1.
pp.construct_strategies()
strat_lookup = {}
for k, v in pp.strategy_dict.items():
    strat_lookup[v[0]] = v[1:]

#setup for test
test_feats = torch.from_numpy(pp.features[int(0.9*pp.n_probs):,:]).float().cuda()
test_enc = enc(test_feats).cuda()
test_dists = torch.cdist(test_enc,c_k.cuda()).detach().cpu().numpy()
test_start = int(0.9*pp.n_probs)
n_test = int(0.1*pp.n_probs)
ind_max = np.argsort(test_dists)[:,:pp.n_evals]
feasible = np.zeros(n_test)
costs = np.zeros(n_test)

In [9]:
prob_success = False

for ii in range(n_test):
    for jj in range(pp.n_evals):
        y_guess = strat_lookup[ind_max[ii,jj]]
        try:
            prob_success, cost, solve_time = pp.solve_mlopt_prob_with_idx(ii+test_start, y_guess)
            if prob_success:
                feasible[ii] = 1.
                costs[ii] = cost
                print('Succeded at {} with {} tries'.format(ii,jj+1))
                break
        except:
            print('mosek failed at '.format(ii))

Succeded at 1 with 4 tries
Succeded at 2 with 0 tries
Succeded at 4 with 7 tries
Succeded at 6 with 1 tries
Succeded at 7 with 1 tries
Succeded at 8 with 0 tries
Succeded at 10 with 3 tries
Succeded at 11 with 0 tries
Succeded at 12 with 1 tries
Succeded at 14 with 9 tries
Succeded at 17 with 3 tries
Succeded at 18 with 7 tries
Succeded at 21 with 9 tries
Succeded at 24 with 0 tries
Succeded at 26 with 0 tries
Succeded at 27 with 3 tries
Succeded at 28 with 2 tries
Succeded at 29 with 7 tries
Succeded at 31 with 1 tries
Succeded at 32 with 0 tries
Succeded at 34 with 8 tries
Succeded at 36 with 0 tries
Succeded at 37 with 0 tries
Succeded at 38 with 0 tries
Succeded at 39 with 0 tries
Succeded at 41 with 4 tries
Succeded at 43 with 3 tries
Succeded at 44 with 3 tries
Succeded at 45 with 0 tries
Succeded at 46 with 4 tries
Succeded at 47 with 6 tries
Succeded at 48 with 1 tries
Succeded at 49 with 2 tries
Succeded at 50 with 2 tries
Succeded at 52 with 4 tries
Succeded at 54 with 0 trie

In [None]:
global_acc = sum(sum(np.equal(ind_max,pp.labels[test_start:,0][:,None])))/(0.1*pp.n_probs)
global_acc

In [12]:
np.mean(feasible[:ii])

0.6464646464646465