In [1]:
import os, sys
import torch
from pathlib import Path
import numpy as np
import matplotlib
from matplotlib import cm
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker


__file__ = os.path.dirname(os.path.realpath("__file__"))
root_dir = (Path(__file__).parent / "..").resolve()
lib_dir = (root_dir / "lib").resolve()
print("The root path: {:}".format(root_dir))
print("The library path: {:}".format(lib_dir))
assert lib_dir.exists(), "{:} does not exist".format(lib_dir)
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))

from datasets import SynAdaptiveEnv
from xlayers.super_core import SuperSequential, SuperMLPv1

The root path: /Users/xuanyidong/Desktop/AutoDL-Projects
The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib


In [2]:
def optimize_fn(xs, ys, test_sets):
    xs = torch.FloatTensor(xs).view(-1, 1)
    ys = torch.FloatTensor(ys).view(-1, 1)
    
    model = SuperSequential(
        SuperMLPv1(1, 10, 20, torch.nn.ReLU),
        SuperMLPv1(20, 10, 1, torch.nn.ReLU)
    )
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.01, weight_decay=1e-4, amsgrad=True
    )
    for _iter in range(100):
        preds = model(ys)

        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(preds, ys)
        loss.backward()
        optimizer.step()
        
    with torch.no_grad():
        answers = []
        for test_set in test_sets:
            test_set = torch.FloatTensor(test_set).view(-1, 1)
            preds = model(test_set).view(-1).numpy()
            answers.append(preds.tolist())
    return answers

def f(x):
    return np.cos( 0.5 * x + x * x)

def get_data(mode):
    dataset = SynAdaptiveEnv(mode=mode)
    times, xs, ys = [], [], []
    for i, (_, t, x) in enumerate(dataset):
        times.append(t)
        xs.append(x)
    dataset.set_transform(f)
    for i, (_, _, y) in enumerate(dataset):
        ys.append(y)
    return times, xs, ys

def visualize_syn(save_path):
    save_dir = (save_path / '..').resolve()
    save_dir.mkdir(parents=True, exist_ok=True)
    
    dpi, width, height = 40, 2000, 900
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 40, 40, 5
    
    fig = plt.figure(figsize=figsize)
    
    times, xs, ys = get_data(None)
    
    def draw_ax(cur_ax, xaxis, yaxis, xlabel, ylabel,
                alpha=0.1, color='k', linestyle='-', legend=None, plot_only=False):
        if legend is not None:
            cur_ax.plot(xaxis[:1], yaxis[:1], color=color, label=legend)
        cur_ax.plot(xaxis, yaxis, color=color, linestyle=linestyle, alpha=alpha, label=None)
        if not plot_only:
            cur_ax.set_xlabel(xlabel, fontsize=LabelSize)
            cur_ax.set_ylabel(ylabel, rotation=0, fontsize=LabelSize)
            for tick in cur_ax.xaxis.get_major_ticks():
                tick.label.set_fontsize(LabelSize - font_gap)
                tick.label.set_rotation(10)
            for tick in cur_ax.yaxis.get_major_ticks():
                tick.label.set_fontsize(LabelSize - font_gap)
    
    cur_ax = fig.add_subplot(2, 1, 1)
    draw_ax(cur_ax, times, xs, "time", "x", alpha=1.0, legend=None)

    cur_ax = fig.add_subplot(2, 1, 2)
    draw_ax(cur_ax, times, ys, "time", "y", alpha=0.1, legend="ground truth")
    
    train_times, train_xs, train_ys = get_data("train")
    draw_ax(cur_ax, train_times, train_ys, None, None, alpha=1.0, color='r', legend=None, plot_only=True)
    
    valid_times, valid_xs, valid_ys = get_data("valid")
    draw_ax(cur_ax, valid_times, valid_ys, None, None, alpha=1.0, color='g', legend=None, plot_only=True)
    
    test_times, test_xs, test_ys = get_data("test")
    draw_ax(cur_ax, test_times, test_ys, None, None, alpha=1.0, color='b', legend=None, plot_only=True)
    
    # optimize MLP models
    [train_preds, valid_preds, test_preds] = optimize_fn(train_xs, train_ys, [train_xs, valid_xs, test_xs])
    draw_ax(cur_ax, train_times, train_preds, None, None,
            alpha=1.0, linestyle='--', color='r', legend="MLP", plot_only=True)
    import pdb; pdb.set_trace()
    draw_ax(cur_ax, valid_times, valid_preds, None, None,
            alpha=1.0, linestyle='--', color='g', legend=None, plot_only=True)
    draw_ax(cur_ax, test_times, test_preds, None, None,
            alpha=1.0, linestyle='--', color='b', legend=None, plot_only=True)

    plt.legend(loc=1, fontsize=LegendFontsize)

    fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
    plt.close("all")
    # plt.show()

In [3]:
# Visualization
home_dir = Path.home()
desktop_dir = home_dir / 'Desktop'
print('The Desktop is at: {:}'.format(desktop_dir))
visualize_syn(desktop_dir / 'tot-synthetic-v0.pdf')

The Desktop is at: /Users/xuanyidong/Desktop
> [0;32m<ipython-input-2-dec7d637caaa>[0m(89)[0;36mvisualize_syn[0;34m()[0m
[0;32m     87 [0;31m            alpha=1.0, linestyle='--', color='r', legend="MLP", plot_only=True)
[0m[0;32m     88 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 89 [0;31m    draw_ax(cur_ax, valid_times, valid_preds, None, None,
[0m[0;32m     90 [0;31m            alpha=1.0, linestyle='--', color='g', legend=None, plot_only=True)
[0m[0;32m     91 [0;31m    draw_ax(cur_ax, test_times, test_preds, None, None,
[0m
ipdb> train_times
[0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.6000000000000001, 0.7000000000000001, 0.8, 0.9, 1.0, 1.1, 1.2000000000000002, 1.3, 1.4000000000000001, 1.5, 1.6, 1.7000000000000002, 1.8, 1.9000000000000001, 2.0, 2.1, 2.2, 2.3000000000000003, 2.4000000000000004, 2.5, 2.6, 2.7, 2.8000000000000003, 2.9000000000000004

ipdb> train_preds
[-0.04611632227897644, -0.045859843492507935, -0.045347750186920166, -0.04458075761795044, -0.04355984926223755, -0.04228568077087402, -0.04075917601585388, -0.03898113965988159, -0.036952465772628784, -0.03467392921447754, -0.03214627504348755, -0.029370546340942383, -0.026347368955612183, -0.023629456758499146, -0.021652281284332275, -0.019537389278411865, -0.01728537678718567, -0.014701485633850098, -0.011017769575119019, -0.007136136293411255, -0.0030573904514312744, 0.0012176334857940674, 0.00568816065788269, 0.01035335659980774, 0.015212282538414001, 0.024441495537757874, 0.034274160861968994, 0.04434235394001007, 0.05476266145706177, 0.06553322076797485, 0.07665219902992249, 0.08811751008033752, 0.09992708265781403, 0.11207878589630127, 0.12457036972045898, 0.1348687708377838, 0.14432348310947418, 0.15401709079742432, 0.1639476716518402, 0.1741131991147995, 0.1845117211341858, 0.1951409876346588, 0.20599885284900665, 0.2170828878879547, 0.22839070856571198, 0.2

ipdb> train_ys
[1.0, 0.9999999910945311, 0.9999999198522608, 0.9999996790771866, 0.9999991066924582, 0.9999979837910936, 0.9999960318551232, 0.9999929091393012, 0.9999882062104364, 0.9999814406286233, 0.9999720507522349, 0.99995938864448, 0.9999427120556592, 0.9999211754520029, 0.9998938200591698, 0.9998595628861638, 0.9998171846936232, 0.9997653168692014, 0.9997024271721292, 0.9996268043090873, 0.9995365413042816, 0.9994295176281586, 0.9993033800516088, 0.9991555221958434, 0.9989830627524896, 0.9987828223539091, 0.9985512990804113, 0.9982846425989912, 0.9979786269375996, 0.9976286219098381, 0.9972295632174979, 0.9967759212726425, 0.9962616687970814, 0.9956802472752349, 0.9950245323566472, 0.994286798326904, 0.9934586817905299, 0.9925311447367143, 0.9914944371884851, 0.9903380596683096, 0.9890507257480735, 0.9876203249889757, 0.9860338866170509, 0.9842775443227177, 0.9823365026177825, 0.9801950052305569, 0.9778363060688365, 0.9752426433311261, 0.9723952173981886, 0.9692741731891905, 0.

--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user
