In [1]:
from constants import *
from solver import SolveScheduling
from get_data import *
from train import * 
from models import *
from projectnet import *

import warnings
warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

In [2]:
params = {"n": 24, "c_ramp": 0.4, "gamma_under": 50, "gamma_over": 0.5}
scheduling_solver = SolvePointQP(params)
dist_solver = SolveScheduling(params)

X_train, Y_train, X_test, Y_test, X_train_pt, Y_train_pt, X_test_pt, Y_test_pt = get_data() 

In [3]:
G = scheduling_solver.G[24*2:24*2+23*2,:]
A = torch.cat((G, torch.eye(G.shape[0]).to(DEVICE)), dim=1).float().to(DEVICE)
b = params['c_ramp'] * torch.ones((24 - 1)*2, device=DEVICE).float()

In [4]:
print(A.shape, b.shape)

torch.Size([46, 118]) torch.Size([46])


In [None]:
rounds = 5
projectnet = ProjectNet(A, b, 24, rounds=rounds).to(DEVICE)
train_projectnet(projectnet, Y_train_pt, params, epochs=50, verbose=True) 
torch.save(projectnet.state_dict(), "projectnet.pt")

epoch  0  mean loss:  1.1017354394380863  median  1.0509483218193054 cur:  0.8622184991836548
epoch  1  mean loss:  0.9571446770429611  median  0.941331684589386 cur:  0.6837594509124756
epoch  2  mean loss:  0.7941359728574753  median  0.8484386801719666 cur:  0.7007178068161011
epoch  3  mean loss:  0.7236167746782303  median  0.8097432851791382 cur:  0.5807981491088867
epoch  4  mean loss:  0.6688397830724716  median  0.7655080854892731 cur:  0.5104993581771851
epoch  5  mean loss:  0.582694853246212  median  0.7079532742500305 cur:  0.4150388240814209
epoch  6  mean loss:  0.49595375657081603  median  0.684368908405304 cur:  0.44040292501449585
epoch  7  mean loss:  0.5047060745954514  median  0.662131667137146 cur:  0.44454240798950195
epoch  8  mean loss:  0.5292767044901848  median  0.6324037611484528 cur:  0.46450942754745483
epoch  9  mean loss:  0.4757129240036011  median  0.6030329763889313 cur:  0.36275070905685425
epoch  10  mean loss:  0.4041717004776001  median  0.580399

In [6]:
model_pnet = Net(X_train[:,:-1], Y_train, [200,200]).to(DEVICE)
train_with_pnet(model_pnet, projectnet, X_train_pt, Y_train_pt, params, rounds=rounds, epochs=100, lr=1e-4)
torch.save(model_pnet.state_dict(), "model_pnet.pt")

epoch  0  mean loss:  2.51083858196552 time 1.481426477432251
epoch  1  mean loss:  1.553789357153269 time 1.4794176816940308
epoch  2  mean loss:  1.1527113012778454 time 1.4907232125600178
epoch  3  mean loss:  0.8107200103998184 time 1.5050384402275085
epoch  4  mean loss:  0.3639902199804783 time 1.5169790744781495
epoch  5  mean loss:  0.29352716997265815 time 1.531642993291219
epoch  6  mean loss:  0.2747478397190571 time 1.5256685188838415
epoch  7  mean loss:  0.2625738888978958 time 1.5468855798244476
epoch  8  mean loss:  0.2579862396419048 time 1.5459132724338107
epoch  9  mean loss:  0.2558953915536404 time 1.5364607095718383
epoch  10  mean loss:  0.25212527602910995 time 1.5312221917239102
epoch  11  mean loss:  0.25008384227752684 time 1.5219195485115051
epoch  12  mean loss:  0.24860980480909348 time 1.5203453577481782
epoch  13  mean loss:  0.24431421250104904 time 1.5186446223940169
epoch  14  mean loss:  0.24303831920027733 time 1.5170502662658691
epoch  15  mean los

(Net(
   (lin): Linear(in_features=149, out_features=24, bias=True)
   (net): Sequential(
     (0): Linear(in_features=149, out_features=200, bias=True)
     (1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU()
     (3): Dropout(p=0.2, inplace=False)
     (4): Linear(in_features=200, out_features=200, bias=True)
     (5): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (6): ReLU()
     (7): Dropout(p=0.2, inplace=False)
     (8): Linear(in_features=200, out_features=24, bias=True)
   )
 ),
 [4.566705226898193,
  4.294374465942383,
  4.353069305419922,
  3.4285054206848145,
  3.1843297481536865,
  2.7806813716888428,
  2.92234468460083,
  2.948564052581787,
  2.6324281692504883,
  2.0423240661621094,
  2.6829676628112793,
  2.4536995887756348,
  1.586218237876892,
  1.4523104429244995,
  2.149012804031372,
  1.3575828075408936,
  0.9071885943412781,
  0.5882694721221924,
  0.9317345023155212,
  1.476

In [8]:
def evaluate_(model):
    with torch.no_grad():
        p = model(X_test_pt)
        costs = []
        for i in range(p.shape[0]):
            d = projectnet(p[i:i+1,:], tol=1e-4).detach()
            c = task_loss(d, Y_test_pt[i:i+1,:], params).mean()
            costs.append(c.item())
    return costs

def evaluate_task(model):
    with torch.no_grad():
        m,s = model(X_test_pt)
        costs = []
        for i in range(m.shape[0]):
            d = dist_solver(m[i:i+1,:],s[i:i+1,:]).detach()
            c = task_loss(d, Y_test_pt[i:i+1,:], params).mean()
            costs.append(c.item())
    return costs

In [10]:
pnet_costs = evaluate_(model_pnet)
print("Test pnet costs:", np.mean(pnet_costs))

Test pnet costs: 0.7645002260634818


In [19]:
task_net = Net2(X_train[:,:-1], Y_train, [200,200]).to(DEVICE)
train_task_net(task_net, params, X_train_pt, Y_train_pt)

epoch: 0 loss: 0.8989258842034773 times: 4.283150911331177
epoch: 1 loss: 0.8982206691395153 times: 4.237900614738464
epoch: 2 loss: 0.8988134734558336 times: 4.14656925201416
epoch: 3 loss: 0.8979367180304094 times: 4.076341986656189
epoch: 4 loss: 0.8971390962600708 times: 4.060247945785522
epoch: 5 loss: 0.8972614335291313 times: 4.073833346366882
epoch: 6 loss: 0.8975307864028138 times: 4.084853546960013
epoch: 7 loss: 0.8978663357821378 times: 4.077466577291489
epoch: 8 loss: 0.8980425897270742 times: 4.07149420844184
epoch: 9 loss: 0.8981027281284333 times: 4.051555848121643
epoch: 10 loss: 0.8977667033672333 times: 4.048304232684049
epoch: 11 loss: 0.8980071008205414 times: 4.040776193141937
epoch: 12 loss: 0.8977414160966873 times: 4.023163006855891
epoch: 13 loss: 0.8981837499141693 times: 4.023526685578482
epoch: 14 loss: 0.897748441696167 times: 4.042127656936645


KeyboardInterrupt: 

In [15]:
task_costs = evaluate_task(task_net)

In [16]:
print("Test task costs:", np.mean(task_costs))

Test task costs: 0.8804126808834132
