In [1]:
from run import *
from utils import *

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [2]:
%load_ext autoreload
%autoreload 2

In [15]:
# problem = problems.Quadratic(d=1, n=1024, seed=0)
# show_problem(problem)

In [4]:
# fig, axes = plt.subplots(1, 1)
# for sample in val_offline_dataset:
#     show_problem(sample['problem'], color='grey', ax=axes, x_min=False)
#     plt.scatter(sample['problem'].info["x_min"], sample['problem'].info["y_min"], s=10, c='red', zorder=2)

In [80]:
from torch.nn import functional as F

def run(model, sample):
    """Run ``model`` on one dataset ``sample`` and collect everything the plotting helpers need.

    The mode is chosen from the sample's keys:

    * offline -- the sample already carries a trajectory (``"x"`` and ``"y"``); the
      wrapped network ``model.model`` is applied to it in one forward pass.
    * online -- only ``sample["problem"]`` is used; ``model.run`` is called with
      ``n_steps = seq_len + 1`` and must return a dict with ``"logits"``, ``"x"``, ``"y"``.

    Args:
        model: the solver; needs ``.model`` (offline) or ``.run`` and ``.config`` (online).
        sample: dataset item with ``"x_min"`` and ``"problem"``, plus ``"x"``/``"y"`` offline.

    Returns:
        dict with ``logits``, ``probs``, ``x``, ``y``, ``target`` and ``problem``, all
        tensors moved to CPU and detached from the autograd graph.
    """
    if "x" in sample and "y" in sample:
        # Add a batch axis of one for the forward pass, then strip it again.
        logits = model.model(
            x=sample["x"].unsqueeze(0).to(device),
            y=sample["y"].unsqueeze(0).to(device),
        ).squeeze(0).detach().cpu()
        x = sample["x"].cpu()
        y = sample["y"].cpu()
    else:
        outputs = model.run(
            problem=sample["problem"],
            n_steps=model.config["model_params"]["seq_len"] + 1,
        )
        logits = outputs["logits"].detach().cpu()
        x = outputs["x"].cpu()
        y = outputs["y"].cpu()

    # probs = F.softmax(logits, -1)
    # Channel 0 is scaled by 1023 and rounded, presumably mapping a [0, 1] prediction onto
    # an index of a 1024-point grid. NOTE(review): 1023 assumes problem.n == 1024, and no
    # sigmoid is applied here while the raw-output probe cell inspects sigmoid(output);
    # confirm which scale the network actually emits.
    probs = torch.round(logits[..., 0] * 1023)
    return {
        "logits": logits,
        "probs": probs,
        "x": x,
        "y": y,
        "target": sample["x_min"],
        "problem": sample["problem"],
    }

def show(results, title=""):
    """Plot one ``run`` result: observed points over the problem curve, plus predictions.

    Left panel: ``show_problem`` draws the problem (grey dashed); observed ``(x, y)``
    points are scattered on top, coloured along a ``jet`` colormap by step. Right panel
    (1-D ``probs`` only): one vertical line per step at ``results["probs"][i]``, plus a
    white line labelled ``target`` at ``results["target"]``.

    Args:
        results: dict as returned by ``run`` (keys ``probs``, ``x``, ``y``, ``target``,
            ``problem``).
        title: figure super-title.
    """
    probs = results["probs"]
    n_steps = len(probs)

    # ``matplotlib.cm.get_cmap`` was deprecated in 3.7 and removed in 3.9.
    cmap = plt.get_cmap('jet')
    colors = [cmap(c) for c in np.linspace(0, 1, n_steps)]

    fig, axes = plt.subplots(1, 2, figsize=(16, 4), gridspec_kw=dict(wspace=0.125))
    fig.suptitle(title)
    axes[1].set_title("Predicted Distribution")
    show_problem(results["problem"], ax=axes[0], color="grey", linestyle='--')

    if probs.ndim == 2:
        # Observed point j is coloured like step j + 1. A single scatter of the first
        # n_steps - 1 points gives the same picture as redrawing every growing prefix.
        if n_steps > 1:
            axes[0].scatter(
                results["x"][:n_steps - 1],
                results["y"][:n_steps - 1],
                c=colors[1:],
                zorder=2,
            )
    else:
        # When there is one x per prediction the final x has no following step, so drop it.
        if len(results["x"]) == n_steps:
            x_obs, y_obs = results["x"][:-1], results["y"][:-1]
        else:
            x_obs, y_obs = results["x"], results["y"]
        axes[0].scatter(x_obs, y_obs, color=colors[1:], zorder=2)
        for value, color in zip(probs, colors):
            axes[1].vlines(value, 0, 1, colors=color, zorder=2)

    _, ymax = axes[1].get_ylim()
    axes[1].vlines(x=results["target"], ymin=0, ymax=ymax, colors='white', label='target')
    axes[1].set_xlabel('x')
    axes[1].legend(loc='center left', title='# observations', bbox_to_anchor=(1, 0.55))

    plt.show()

def runnshow(model, train_idx=0, val_idx=0):
    """Plot predictions for one train and one validation problem, offline and online.

    Offline, with ``model.train()``: ``train_offline_dataset[train_idx]`` and
    ``val_offline_dataset[val_idx]``. Online, with ``model.eval()``:
    ``val_online_dataset[val_idx]``. The three datasets are read as notebook globals.
    All inference runs under ``torch.no_grad()``. The model is left in eval mode.
    """
    with torch.no_grad():
        model.train()
        train_offline = run(model, train_offline_dataset[train_idx])
        val_offline = run(model, val_offline_dataset[val_idx])
        show(train_offline, title="Train problem: Offline Training")
        show(val_offline, title="Validation problem: Offline Training")

        model.eval()
        val_online = run(model, val_online_dataset[val_idx])
        show(val_online, title="Validation problem: Online Inference")

In [3]:
import re
from collections import defaultdict


def get_best_checkpoint(log_dir, key='epoch') -> str:
    """Return the checkpoint in ``<log_dir>/checkpoints`` with the largest ``key`` number.

    Checkpoint names are expected to contain ``epoch=<int>-step=<int>`` (for example
    ``epoch=499-step=1000.ckpt``). ``last.ckpt`` and ``.ckpt`` files whose name does not
    match that pattern are ignored. "Best" here means the latest by ``key``; no
    validation metric is consulted.

    Args:
        log_dir: run directory that contains a ``checkpoints`` sub-directory.
        key: ``'epoch'`` or ``'step'``, the number to maximise. Ties are broken by the
            other number, then by file name, so the result does not depend on the order
            ``os.listdir`` returns.

    Returns:
        Full path to the selected checkpoint file.

    Raises:
        ValueError: if ``key`` is not ``'epoch'`` or ``'step'``.
        FileNotFoundError: if no matching checkpoint file exists.
    """
    if key not in ('epoch', 'step'):
        raise ValueError(f"key must be 'epoch' or 'step', got {key!r}")

    checkpoints_dir = os.path.join(log_dir, 'checkpoints')
    pattern = re.compile(r'epoch=(\d+)-step=(\d+)')

    ranked = []
    for filename in os.listdir(checkpoints_dir):
        if not filename.endswith('.ckpt') or filename == 'last.ckpt':
            continue
        match = pattern.search(filename)
        if match is None:
            # e.g. a manually named checkpoint: there is nothing to rank it by.
            continue
        epoch, step = int(match.group(1)), int(match.group(2))
        primary, secondary = (epoch, step) if key == 'epoch' else (step, epoch)
        ranked.append((primary, secondary, filename))

    if not ranked:
        raise FileNotFoundError(
            f"no 'epoch=<n>-step=<n>' checkpoints found in {checkpoints_dir}")

    best_filename = max(ranked)[-1]
    return os.path.join(checkpoints_dir, best_filename)

In [14]:
# checkpoint_file = '../results/DPT/gdbzctgp/checkpoints/epoch=499-step=1000.ckpt'
run_name = "cdv3c4id"
root_dir = os.path.join("..", "results", "DPT_2", run_name)
# Pick the checkpoint with the highest epoch number in <root_dir>/checkpoints.
checkpoint_file = get_best_checkpoint(root_dir, key='epoch')

model = DPTSolver.load_from_checkpoint(checkpoint_file)
model = model.to(device)

In [None]:
# config = load_config("config.yaml")
config = model.config  # reuse the config object carried by the loaded model
dl = get_dataloaders(config)

# Keep only the underlying datasets: training set, then validation offline / online.
train_offline_dataset = dl["train_dataloaders"].dataset
val_offline_dataset, val_online_dataset = (
    dl["val_dataloaders"][0].dataset,
    dl["val_dataloaders"][1].dataset,
)

In [12]:
import torch.nn.functional as F

In [None]:
sample = train_offline_dataset[0]
# Inference-only probe: no autograd graph is needed for a raw forward pass.
with torch.no_grad():
    outputs = model.model(
        x=sample["x"].unsqueeze(0).to(device),
        y=sample["y"].unsqueeze(0).to(device),
    ).squeeze(0).detach().cpu()
# Sigmoid of channel 0 at the last step, shown next to the sample it came from.
torch.sigmoid(outputs[-1][0]), sample

In [None]:
ax = plt.gca()
ax.scatter(sample["x"], sample["y"])
ax.set(xlabel="x", ylabel="y", title="Training sample: observed (x, y) points")
plt.show()
# show_problem(sample["problem"], ax=ax)

In [7]:
# sample = train_offline_dataset[501]

# fig = plt.figure(figsize=(6, 3))
# ax = plt.gca()
# show_problem(sample["problem"], ax=ax, color="grey", linestyle='--')
# cmap = cm.get_cmap('jet')
# colors = [cmap(c) for c in np.linspace(0, 1, config["model_params"]["seq_len"])]
# ax.scatter(sample["x"], sample["y"], c=colors, zorder=2)
# plt.show()

In [7]:
# fig = plt.figure(figsize=(3, 8))
# ax = plt.gca()
# for i in range(10):
#     sample = val_offline_dataset[i]
#     show_problem(sample["problem"], ax=ax, color="grey", linestyle='--')

In [9]:
# train_targets = torch.tensor([sample["x_min"] for sample in train_offline_dataset])
# val_targets = torch.tensor([sample["x_min"] for sample in val_offline_dataset])

In [None]:
# Train problem #100 and validation problem #0, offline (train mode) then online (eval mode).
runnshow(model, val_idx=0, train_idx=100)

In [None]:
model.eval()
# Pure inference: skip autograd bookkeeping, as runnshow does for its offline calls.
with torch.no_grad():
    valon_batch_eval_mode = run(model, val_online_dataset[2])
    show(valon_batch_eval_mode, title="Validation problem: Online Inference")

In [31]:
# model.eval()

# targets = torch.tensor([val_online_dataset[i]["x_min"] for i in range(len(val_online_dataset))])
# sort_indexes = torch.argsort(targets)

# for idx in sort_indexes:
#     results = run(model, val_online_dataset[idx])
#     show(results, title=f'train/{idx}')

In [None]:
dataset = train_offline_dataset
# dataset = val_offline_dataset
model.train()

# dataset = val_online_dataset
# model.eval()

STRIDE = 100  # plot every STRIDE-th problem after sorting by target x_min
targets = torch.tensor([dataset[i]["x_min"] for i in range(len(dataset))])
sort_indexes = torch.argsort(targets)[::STRIDE]

fig, axes = plt.subplots(1, 3, figsize=(16, 4), gridspec_kw=dict(wspace=0.125), sharey=True)
axes[0].set_title('Zero step')
axes[1].set_title('First step')
axes[2].set_title('Last step')

# ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; ``plt.get_cmap`` is stable.
cmap = plt.get_cmap('jet')
colors = [cmap(c) for c in np.linspace(0, 1, len(sort_indexes))]
with torch.no_grad():
    # ``.tolist()`` gives plain ints, so datasets are indexed the same way as above.
    for i, c in zip(sort_indexes.tolist(), colors):
        results = run(model, dataset[i])
        target = torch.as_tensor(results['target']).item()
        for ax, step in zip(axes, (0, 1, -1)):
            ax.plot(results['probs'][step], c=c, label=target)
axes[0].legend()
plt.show()

### loss

In [19]:
from dpt.train import Loss

# Loss configured from the checkpoint's config; the sequence covers seq_len + 1 steps.
loss = Loss(
    mode=config["loss"],
    eps=config["label_smoothing"],
    num_classes=config["model_params"]["num_actions"],
    seq_len=config["model_params"]["seq_len"] + 1,
)

In [20]:
def runnshow2(results, loss, u, n_panels=5):
    """Plot the loss curve, then predicted vs. ground-truth rows for the first steps.

    Args:
        results: dict from ``run``; ``results["probs"][i]`` is drawn in panel ``i``.
        loss: loss values to plot as a line (presumably one per step, from a
            ``reduction='none'`` call; confirm).
        u: ground-truth rows; ``u[i]`` is drawn against ``results["probs"][i]``.
        n_panels: number of leading steps to compare; clipped to ``len(u)``.
    """
    # Loss curve over the step index (= number of observations seen).
    plt.figure(figsize=(16, 3))
    plt.title('Loss')
    plt.plot(loss)
    plt.xlabel('Number of observations')
    plt.show()

    n = min(n_panels, len(u))
    if n == 0:
        return
    # squeeze=False keeps ``axes`` indexable even when n == 1.
    _, axes = plt.subplots(1, n, figsize=(16, 3), gridspec_kw=dict(wspace=0.05), squeeze=False)
    axes = axes[0]
    for i in range(n):
        axes[i].set_title(f'{i} observations')
        axes[i].plot(results["probs"][i], c='red', label='predicted')
        axes[i].plot(u[i], c='blue', label='ground truth')

    axes[0].legend()
    plt.show()

In [None]:
model.train()
results = run(model, train_offline_dataset[1])
# results = run(model, val_offline_dataset[1])

# Prepend a batch axis to the logits and tile the target across all seq_len + 1 steps,
# presumably matching the (batch, steps, ...) layout Loss expects. TODO confirm.
predictions = results["logits"].to(device).unsqueeze(0)
targets = (
    results["target"]
    .unsqueeze(0)
    .repeat(1, config["model_params"]["seq_len"] + 1)
    .to(device)
)
l = loss(predictions, targets, reduction='none').squeeze(0).cpu()
# ``loss.u`` is read after the call; presumably the ground-truth rows the loss built.
u = loss.u.squeeze(0).cpu().T
runnshow2(results, l, u)

### gif

In [25]:
from matplotlib.animation import FuncAnimation, PillowWriter

def gif(results, title='result'):
    """Animate a ``run`` result step by step and save it as ``<title>.gif`` at 1 fps.

    Top axis: the problem's target function on ``get_xaxis(problem.d, problem.n)``
    (grey dashed) with a white marker line at ``results["target"]``; observed points
    accumulate frame by frame. Bottom axis: ``results["probs"][i]`` for the current
    frame, with white lines at the target and at ``results["x"][i]`` when that exists.
    The figure is closed after saving, so only the GIF file is produced.
    """
    probs = results["probs"]
    n_frames = len(probs)

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; ``plt.get_cmap`` is stable.
    cmap = plt.get_cmap('jet')
    colors = [cmap(c) for c in np.linspace(0, 1, n_frames)]
    fig, axes = plt.subplots(2, 1, figsize=(5, 5), sharex=True, gridspec_kw=dict(hspace=0))

    def init():
        problem = results["problem"]
        grid = get_xaxis(problem.d, problem.n)
        y = problem.target(grid)
        axes[0].set_title("Target Function")
        axes[0].plot(y, '--', c='grey', markersize=1)
        ymin, ymax = axes[0].get_ylim()
        axes[0].vlines(x=results["target"], ymin=ymin, ymax=y.min(), colors='white')
        axes[0].set_ylim(ymin, ymax)

    def update(i):
        axes[1].clear()
        axes[0].scatter(results["x"][:i], results["y"][:i], c=colors[:i], zorder=2)
        axes[1].plot(probs[i], c=colors[i], label=i)

        _, ymax = axes[1].get_ylim()
        axes[1].vlines(x=results["target"], ymin=0, ymax=ymax, colors='white')
        # ``results["x"]`` can be one shorter than ``probs`` (``show`` handles that case too);
        # the final frame then has no queried point to mark.
        if i < len(results["x"]):
            axes[1].vlines(x=results["x"][i], ymin=0, ymax=ymax, colors='white')

    ani = FuncAnimation(fig, update, init_func=init, frames=range(n_frames))
    ani.save(f'{title}.gif', writer=PillowWriter(fps=1))
    plt.close(fig)

In [None]:
model.eval()
# Pure inference: no autograd graph is needed to produce the animation data.
with torch.no_grad():
    results = run(model, val_online_dataset[10])
gif(results)