In [25]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import normalize
from math import sqrt
from detorch import DE, Policy, Strategy
from detorch.config import default_config, Config
from typing import Type
import numpy as np
import sys
import matplotlib.pyplot as plt
sys.path.append('..')
sys.path.append('viz')
from optimneuralts import NetworkDropout, DENeuralTSDiag, LenientDENeuralTSDiag
import viz_config
import logging

logging.basicConfig(level=logging.INFO)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = Config(default_config)
# Interval from which candidate points are drawn and to which they are clipped.
bounds = [-2.5, 1.5]
# Polynomial coefficients, lowest degree first: reward_fn computes
# theta @ [1, x, x^2, x^3, x^4, x^5].
theta = torch.Tensor([0, 3, -2, -4, 1, 1]).to(device)
# First constructor argument of NetworkDropout (the printed model shows
# in_features=1, i.e. points are one-dimensional).
d = 1
# Only make CUDA tensors the default when CUDA actually exists; the
# unconditional call made the notebook unusable on CPU-only machines even
# though `device` above already handles that case.
if torch.cuda.is_available():
    torch.set_default_tensor_type("torch.cuda.FloatTensor")

# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Classes

In [26]:
class PullPolicy(Policy):
    """Population member for the DE search: a single candidate point in `bounds`.

    The candidate is stored both as `self.point` (a tensor on `device`) and as
    `self.params` (a non-trainable `nn.Parameter` wrapping the same values).

    Args:
        eval_fn: callable taking the candidate point and returning a 4-tuple
            whose first element is a tensor score and whose second element is
            stored as `activation_grad`; the last two elements are ignored.
    """

    def __init__(self, eval_fn):
        super().__init__()
        # Start from a point drawn uniformly inside the allowed interval.
        self.point = torch.FloatTensor(1).uniform_(*bounds).to(device)
        self.params = nn.Parameter(
            self.point, requires_grad=False
        )
        self.eval_fn = eval_fn
        self.ucb = None
        # Defined up front so callers reading `member.activation_grad` (and
        # comparing it against None) never hit an AttributeError on a member
        # that has not been evaluated yet.
        self.activation_grad = None

    def evaluate(self):
        """Score the current candidate and cache the result.

        Returns:
            The first value returned by `eval_fn`, converted to a Python float.
        """
        # Clip first so the score is always computed for an in-bounds point.
        self.transform()
        ucb, activation_grad, _, _ = self.eval_fn(self.point)
        ucb = ucb.detach().item()
        self.activation_grad = activation_grad
        self.ucb = ucb
        return ucb

    def transform(self):
        """Clip `params` into `bounds` and refresh `point` / `params` with the result."""
        self.point = torch.clip(self.params, *bounds).to(device)
        self.params = nn.Parameter(
            self.point, requires_grad=False
        )


class DEConfig:
    """Settings object registered with detorch through `config("de")(...)`.

    The field names mirror differential-evolution hyper-parameters; their exact
    interpretation is defined by detorch, not by this notebook.
    """

    n_step: int = 3
    population_size: int = 60
    differential_weight: float = 0.8
    crossover_probability: float = 0.9
    strategy: Strategy = Strategy.best1bin
    # Placeholder only: find_best_member overwrites `seed` before every DE run.
    # Kept as an int so the default agrees with the annotation.
    seed: int = 0


# Utility

In [118]:
def compute_jaccard(found_solution, true_solution):
    """Compute the Jaccard index between two collections of vectors.

    Args:
        found_solution: object exposing `.tolist()` and `len()` (e.g. a tensor
            or array) whose rows are the vectors that were found.
        true_solution: same kind of object holding the reference vectors.

    Returns:
        A tuple `(jaccard, n_in_inter)` where `n_in_inter` is the number of
        rows of `found_solution` that also appear in `true_solution` and
        `jaccard = n_in_inter / (len(found) + len(true) - n_in_inter)`.
        When both inputs are empty the index is reported as 0.0 instead of
        raising ZeroDivisionError.
    """

    def _freeze(item):
        # Nested lists are not hashable; turn them into nested tuples.
        if isinstance(item, list):
            return tuple(_freeze(sub) for sub in item)
        return item

    # A set gives constant-time membership tests instead of scanning a list
    # of rows for every found vector.
    true_items = {_freeze(vec) for vec in true_solution.tolist()}

    n_in_inter = sum(_freeze(vec) in true_items for vec in found_solution.tolist())

    n_union = len(found_solution) + len(true_solution) - n_in_inter
    if n_union == 0:
        return 0.0, 0

    return n_in_inter / n_union, n_in_inter

def make_deterministic(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available, cuDNN is additionally switched to deterministic
    mode with benchmarking turned off.

    Args:
        seed: value handed to each library's seeding function.
    """
    # The three generators are independent, so the seeding order is irrelevant.
    for seed_generator in (torch.manual_seed, np.random.seed, random.seed):
        seed_generator(seed)

    if not torch.cuda.is_available():
        return

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def project_point(point):
    """Return the features [1, p, p^2, p^3, p^4, p^5] of `point` as a tensor on `device`."""
    features = [1, point]
    for degree in range(2, 6):
        features.append(point ** degree)
    return torch.tensor(features).to(device)

def reward_fn(point, add_noise=True):
    """Evaluate the polynomial reward at `point`, optionally with Gaussian noise.

    Args:
        point: value accepted by `project_point`.
        add_noise: when true, one standard-normal sample is added to the value.

    Returns:
        `theta @ project_point(point)` plus the noise term when requested.
    """
    # The sample is drawn even when add_noise is False, so every call advances
    # the global torch RNG state by the same amount.
    noise = torch.normal(0, 1, (1,)).item()

    vec = project_point(point)
    # `vec` is one-dimensional, so no transpose is needed for the dot product;
    # `.T` on a non-2-D tensor is deprecated in PyTorch.
    value = theta @ vec

    return value + add_noise * noise


def gen_warmup_vecs_and_rewards(n_warmup):
    """Draw `n_warmup` uniform points in `bounds` with their noisy rewards.

    Returns:
        `(vecs, rewards)`, each reshaped to `n_warmup` rows.
    """
    low, high = bounds[0], bounds[1]

    # Both accumulators start with an empty tensor, which is concatenated
    # together with the drawn samples at the end.
    point_chunks = [torch.tensor([])]
    reward_chunks = [torch.tensor([])]

    for _ in range(n_warmup):
        # Draw the point first, then its reward, so RNG consumption alternates
        # between the two for every sample.
        point = torch.FloatTensor(1).uniform_(low, high).to(device)
        point_chunks.append(point)
        reward_chunks.append(torch.tensor([reward_fn(point)]))

    vecs = torch.cat(point_chunks).view((n_warmup, -1))
    rewards = torch.cat(reward_chunks).view((n_warmup, -1))
    return vecs, rewards


def plot_estimate(agent, trial, fn=None, title=""):
    """Plot the true reward, the agent's estimate and the played points, then save to disk.

    Args:
        agent: object exposing `get_sample(point)` (returning a 4-tuple whose
            third and fourth items are the estimate and its confidence width)
            and `dataset.hist_features` / `dataset.hist_rewards`.
        trial: trial index, used only in the default file name.
        fn: file stem under ``viz/images/exp_poly/``; a trailing ``.png`` is
            tolerated and not duplicated. When None, the name is built from
            the notebook globals `n_trials` and `exploration_mult` and from
            `trial`.
        title: figure title.

    The figure is written as a PNG and the current figure is cleared afterwards.
    """
    n_points = 1000
    # Evaluate over the same interval the points are drawn from.
    x = np.linspace(bounds[0], bounds[1], n_points)

    # Noise-free reward on the grid.
    y = []
    for x_point in x:
        y.append(reward_fn(float(x_point), add_noise=False).cpu().numpy())

    y_pred = []
    cbs = []
    x_tns = torch.from_numpy(x)
    x_tns = x_tns.view(n_points, 1).float()

    for point in x_tns:
        _, _, activ, cb = agent.get_sample(point.to(device))
        y_pred.append(activ)
        # The shaded band drawn below is +/- 3 times the returned width.
        cbs.append(3 * cb)

    y_pred = np.array(y_pred)
    cbs = np.array(cbs)

    point_played = agent.dataset.hist_features.squeeze(0).cpu().numpy()
    rewards_rec = agent.dataset.hist_rewards.squeeze(0).cpu().numpy()
    n_played = point_played.shape[0]

    plt.plot(x, y, color="tab:blue", label="Vraie fonction")
    plt.plot(x, y_pred, color="tab:orange", label="Estimation de la fonction")
    plt.fill_between(
        x, y_pred + cbs, y_pred - cbs, alpha=0.3, color="tab:orange", zorder=-1, label="Intervalle de confiance",
    )
    plt.plot(
        x,
        [0] * n_points,
        color="black",
        linestyle="dashed",
        label="Seuil bonne/mauvaise action",
    )
    plt.scatter(
        point_played[:n_played],
        rewards_rec[:n_played],
        color="black",
        alpha=0.5,
        label="Points joués précédemment",
    )
    plt.scatter(
        point_played[-1], rewards_rec[-1], color="green", label="Dernier point joué"
    )

    plt.title(title)

    plt.xlabel("$x$")
    plt.ylabel("$y$")

    plt.legend()
    if fn is None:
        filename = f"viz/images/exp_poly/regTS_{n_trials}_trials_expl_{exploration_mult}_trial_{trial}.png"
    else:
        # Some callers already include the extension; avoid "name.png.png".
        stem = str(fn)
        if stem.endswith(".png"):
            stem = stem[: -len(".png")]
        filename = f"viz/images/exp_poly/{stem}.png"

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    plt.savefig(filename)

    plt.clf()


def find_best_member(eval_fn, de_config, seed):
    """Run one differential-evolution search and return its best population member.

    Args:
        eval_fn: scoring callable handed to every `PullPolicy` (see
            `PullPolicy.evaluate` for the expected return value).
        de_config: DE settings class/object; its `seed` attribute is
            overwritten in place with `seed`.
        seed: seed stored on `de_config` before the DE is built.

    Returns:
        `de.population[de.current_best]` after `de.train()` has run.
    """
    de_config.seed = seed
    config = Config(default_config)

    # A class body that assigns `eval_fn` cannot read the enclosing function's
    # `eval_fn` under that same name (the lookup skips function scopes), which
    # is why the original fell back to the global `agent.get_sample` and
    # silently ignored the argument. Reading it through an alias fixes that.
    sample_fn = eval_fn

    @config("policy")
    class PolicyConfig:
        policy: Type[Policy] = PullPolicy
        eval_fn: object = sample_fn

    config("de")(de_config)

    de = DE(config)
    de.train()

    return de.population[de.current_best]


def gen_mask(example_vec, agent, account=False):
    """Sample inverted-dropout masks for evaluation and for training.

    Each mask entry is 0 with probability `dropout_rate` and
    `1 / (1 - dropout_rate)` otherwise (`dropout_rate` is a notebook global).

    Args:
        example_vec: tensor whose shape the evaluation mask copies.
        agent: object exposing `dataset.hist_features` and `net.hidden_size`.
        account: when true, the training mask gets one row more than the
            number of stored observations (presumably for an observation
            about to be added - confirm with the caller).

    Returns:
        `(eval_mask, training_mask)`, both moved to `device`.
    """
    keep_prob = 1 - dropout_rate

    # Mask for 1 vec
    eval_mask = (
        torch.distributions.Bernoulli(
            torch.full_like(example_vec, keep_prob)
        ).sample()
        / keep_prob
    ).to(device)

    n_obs = agent.dataset.hist_features.shape[0]

    if account:
        # The original wrote `n_obs + 1`, which computed and discarded the sum.
        n_obs += 1

    # Mask for multiple vecs
    training_mask = (
        torch.distributions.Bernoulli(
            torch.full((n_obs, agent.net.hidden_size), keep_prob)
        ).sample()
        / keep_prob
    ).to(device)

    return eval_mask, training_mask



# Visualize 1 run

In [139]:
# Configuration of the single visualised run; `run` is also used as the RNG seed.
run = 42
make_deterministic(run)

n_trials = 500
dropout_rate = 0.1
width = 100
# The network is built after seeding (presumably so its initial weights are
# reproducible - this relies on NetworkDropout drawing from the seeded RNGs).
net = NetworkDropout(d, width,dropout=dropout_rate).to(device)
reg = 1
exploration_mult = 1
delay = 0
# NOTE(review): self-assignment, has no effect.
reward_fn = reward_fn
de_config = DEConfig
de_policy = PullPolicy
# `style` and `sampletype` are forwarded to the agent and reused in plot file names.
style="ts"
sampletype="f"
# Argument passed to every agent.train(...) call in the single-run cells.
max_n_steps = 100 
agent = DENeuralTSDiag(net, nu=exploration_mult, lambda_=reg, style=style, sampletype=sampletype)
# agent = LenientDENeuralTSDiag(net, nu=exploration_mult, lambda_=reg, reward_sample_thresholds=[float('-inf'), 0])


In [140]:
# Draw 10 random warm-up points with their noisy rewards and move them to `device`.
vecs, rewards = gen_warmup_vecs_and_rewards(10)
vecs, rewards = vecs.to(device), rewards.to(device)

# One representative input vector (first warm-up point).
example_vec = vecs[0]

In [141]:
# Install the warm-up observations as the agent's history.
agent.dataset.set_hists(vecs, rewards)

# Accumulate the element-wise squared gradient of every warm-up point into agent.U.
for vec in vecs:
    activ, grad = agent.compute_activation_and_grad(vec)
    agent.U += grad * grad

# Switch the network to eval mode before calling the agent's own training
# routine (its displayed return value is presumably the final loss - confirm).
agent.net.eval()
agent.train(max_n_steps)

1.516566276550293

In [142]:
# Save a plot of the estimate right after warm-up; the third positional
# argument is `fn`, the output file stem.
plot_estimate(agent, 0,  f'{style}_expl_{exploration_mult}_{sampletype}_run{run}_trial_0')
# Put the network back into training mode for the trial loop that follows.
agent.net.train()


NetworkDropout(
  (fc1): Linear(in_features=1, out_features=100, bias=True)
  (activate): ReLU()
  (fc2): Linear(in_features=100, out_features=1, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

<Figure size 980x700 with 0 Axes>

### Train 1 run

In [143]:
# Main loop of the single run: pick a point with DE, observe its reward,
# update the agent, and periodically save a plot.
for i in range(n_trials):
    # The trial index doubles as the DE seed.
    best_member = find_best_member(agent.get_sample, de_config, i)
    best_member_grad = best_member.activation_grad
    # Add a leading dimension to the played point, and two to the reward,
    # before storing them in the dataset.
    a_t = best_member.point.unsqueeze(0)
    r_t = reward_fn(a_t).unsqueeze(0).unsqueeze(0)

    logging.info(f"trial: {i}")

    # Same squared-gradient accumulation as during warm-up.
    agent.U += best_member_grad * best_member_grad

    agent.dataset.add(a_t, r_t)

    # Fit (and plot) with the network in eval mode, then restore training mode.
    agent.net.eval()
    agent.train(max_n_steps)
    if (i + 1) % 100 == 0:
        # A plot is saved every 100 trials; the file stem carries the trial count.
        plot_estimate(
            agent,
            0,
            f"{style}_expl_{exploration_mult}_{sampletype}_run{run}_trial_{i+1}"
        )
    agent.net.train()


INFO:root:trial: 0
INFO:root:trial: 1
INFO:root:trial: 2
INFO:root:trial: 3
INFO:root:trial: 4
INFO:root:trial: 5
INFO:root:trial: 6
INFO:root:trial: 7
INFO:root:trial: 8
INFO:root:trial: 9
INFO:root:trial: 10
INFO:root:trial: 11
INFO:root:trial: 12
INFO:root:trial: 13
INFO:root:trial: 14
INFO:root:trial: 15
INFO:root:trial: 16
INFO:root:trial: 17
INFO:root:trial: 18
INFO:root:trial: 19
INFO:root:trial: 20
INFO:root:trial: 21
INFO:root:trial: 22
INFO:root:trial: 23
INFO:root:trial: 24
INFO:root:trial: 25
INFO:root:trial: 26
INFO:root:trial: 27
INFO:root:trial: 28
INFO:root:trial: 29
INFO:root:trial: 30
INFO:root:trial: 31
INFO:root:trial: 32
INFO:root:trial: 33
INFO:root:trial: 34
INFO:root:trial: 35
INFO:root:trial: 36
INFO:root:trial: 37
INFO:root:trial: 38
INFO:root:trial: 39
INFO:root:trial: 40
INFO:root:trial: 41
INFO:root:trial: 42
INFO:root:trial: 43
INFO:root:trial: 44
INFO:root:trial: 45
INFO:root:trial: 46
INFO:root:trial: 47
INFO:root:trial: 48
INFO:root:trial: 49
INFO:root:

<Figure size 980x700 with 0 Axes>

# Train all runs, for all algos and every exploration

In [33]:
# Evaluation grid: 1000 evenly spaced points over [-2.5, 1.5], as a (1000, 1) column.
x_arr = np.linspace(-2.5, 1.5, 1000)
x_arr = x_arr.reshape(1000, 1)
x = torch.from_numpy(x_arr).to(device).float()
# Noise-free reward at every grid point.
y = []
for point in x:
    y.append(reward_fn(point, add_noise=False).cpu().numpy())

y = np.array(y)
# Ground-truth solution set: grid points whose noise-free reward is >= 0.
true_sol = x[np.where(y >= 0)[0]].to(device).float()
n_true_sol = len(true_sol)
logging.info(n_true_sol)
# Passed as `n_sigmas` to agent.find_solution_in_vecs in the benchmark cells.
n_sigmas = 3

INFO:root:419


In [34]:
# Benchmark: for every algorithm and exploration multiplier, execute 50 seeded
# runs of `n_trials` trials and compare the solution set found on the grid `x`
# with `true_sol`. Results are collected in `metrics_dict[algo][str(mult)]`.
metrics_dict = {}
algos = ["UCB", "regTS", "lenientTS"]
exploration_mults = [1, 10]
logging.info(metrics_dict)
max_n_steps = 10
for algo in algos:
    metrics_dict[algo] = {}
    for exploration_mult in exploration_mults:
        metrics_dict[algo][str(exploration_mult)] = {}
        metrics_dict[algo][str(exploration_mult)]["jaccards"] = []
        metrics_dict[algo][str(exploration_mult)]["percent_inter"] = []
        metrics_dict[algo][str(exploration_mult)]["percent_found"] = []
        # Unlike the three lists above, "fails" is a plain counter.
        metrics_dict[algo][str(exploration_mult)]["fails"] = 0

        # First run index; raise it to resume a partially completed sweep.
        b = 0

        for i in range(b, 50):
            logging.info(f"at algo: {algo}, expl_mult: {exploration_mult}, run: {i}")
            # The run index is the seed.
            make_deterministic(seed=i)
            # NOTE(review): this warm-up draw is overwritten below before being
            # used, but it still consumes random numbers.
            vecs, rewards = gen_warmup_vecs_and_rewards(10)
            vecs, rewards = vecs.to(device), rewards.to(device)
            n_trials = 500
            width = 100
            net = NetworkDropout(d, width).to(device)
            reg = 1
            delay = 0
            # NOTE(review): self-assignment, has no effect.
            reward_fn = reward_fn
            de_config = DEConfig
            de_policy = PullPolicy
            # NOTE(review): `lr` is not read anywhere in this cell.
            lr = 1e-2

            if algo == "UCB":
                agent = DENeuralTSDiag(
                    net, nu=exploration_mult, lambda_=reg, style="ucb"
                )
            elif algo == "regTS":
                agent = DENeuralTSDiag(net, nu=exploration_mult, lambda_=reg, style="ts")
            elif algo == "lenientTS":
                agent = LenientDENeuralTSDiag(
                    [float("-inf"), 0],
                    net=net,
                    nu=exploration_mult,
                    lambda_=reg,
                )

            # Warm-up data actually used for this run.
            vecs, rewards = gen_warmup_vecs_and_rewards(10)
            vecs, rewards = vecs.to(device), rewards.to(device)
            # Warmup
            for vec in vecs:
                activ, grad = agent.compute_activation_and_grad(vec)
                agent.U += grad * grad
            agent.dataset.set_hists(vecs, rewards)
            agent.train(max_n_steps)


            # Playing
            for j in range(n_trials):
                # The trial index seeds the DE search.
                best_member = find_best_member(agent.get_sample, de_config=de_config, seed=j)
                best_member_grad = best_member.activation_grad
                a_t = best_member.point.unsqueeze(0)
                r_t = reward_fn(a_t).unsqueeze(0).unsqueeze(0)

                agent.U += best_member_grad * best_member_grad

                agent.dataset.add(a_t, r_t)

                agent.train(max_n_steps)

            # Stop using mask in evaluation
            agent.net.set_mask(None)
            sol, _, _ = agent.find_solution_in_vecs(x, 0, n_sigmas=n_sigmas)
            sol = sol.to(device)
            n_sol = len(sol)

            jaccard, n_inter = compute_jaccard(sol, true_sol)
            # Share of the true solution set that was recovered.
            percent_found = n_inter / n_true_sol
            # Share of the returned points that belong to the true solution set.
            if n_sol == 0:
                percent_inter = 0
            else:
                percent_inter = n_inter / n_sol

            metrics_dict[algo][str(exploration_mult)]["jaccards"].append(jaccard)
            metrics_dict[algo][str(exploration_mult)]["percent_inter"].append(
                percent_inter
            )
            metrics_dict[algo][str(exploration_mult)]["percent_found"].append(
                percent_found
            )

            # A run that returns no point at all is counted as a failure and plotted.
            if n_sol == 0:
                logging.info(f"Found no solution for run {i}")
                metrics_dict[algo][str(exploration_mult)]["fails"] += 1

                # NOTE(review): the file name says "100_trials" while n_trials
                # is 500 - confirm which is intended.
                plot_estimate(
                    agent,
                    n_trials,
                    fn=f"no_sol_{algo}_expl_{exploration_mult}_100_trials_seed_{i}",
                )

            logging.info(
                f"jaccard: {jaccard}, percent_inter: {percent_inter}, percent_found: {percent_found}"
            )

# Persist all metrics once the whole sweep has finished.
torch.save(metrics_dict, f"metrics_dict_exp_synth_{max_n_steps}_basic.pth")


INFO:root:{}
INFO:root:at algo: UCB, expl_mult: 1, run: 0
INFO:root:jaccard: 0.2935560859188544, percent_inter: 1.0, percent_found: 0.2935560859188544
INFO:root:at algo: UCB, expl_mult: 1, run: 1
INFO:root:Found no solution for run 1
INFO:root:jaccard: 0.0, percent_inter: 0, percent_found: 0.0
INFO:root:at algo: UCB, expl_mult: 1, run: 2
INFO:root:Found no solution for run 2
INFO:root:jaccard: 0.0, percent_inter: 0, percent_found: 0.0
INFO:root:at algo: UCB, expl_mult: 1, run: 3
INFO:root:jaccard: 0.3198090692124105, percent_inter: 1.0, percent_found: 0.3198090692124105
INFO:root:at algo: UCB, expl_mult: 1, run: 4
INFO:root:jaccard: 0.16945107398568018, percent_inter: 1.0, percent_found: 0.16945107398568018
INFO:root:at algo: UCB, expl_mult: 1, run: 5
INFO:root:Found no solution for run 5
INFO:root:jaccard: 0.0, percent_inter: 0, percent_found: 0.0
INFO:root:at algo: UCB, expl_mult: 1, run: 6
INFO:root:jaccard: 0.13365155131264916, percent_inter: 1.0, percent_found: 0.13365155131264916

<Figure size 980x700 with 0 Axes>

In [35]:
# Log mean, standard deviation and range of every metric, grouped by
# algorithm and exploration multiplier.
for algo, results_by_mult in metrics_dict.items():
    logging.info(f"Algo: {algo}")
    for expl_mult, results_by_metric in results_by_mult.items():
        logging.info(f'Expl mult: {expl_mult}')
        for metric, values in results_by_metric.items():
            logging.info(f'Metric: {metric}')
            logging.info(f"mean: {np.mean(values)} +- {np.std(values)} ")
            logging.info(f"interval:  [{np.min(values)}, {np.max(values)}]")
            logging.info("============================================")

    logging.info("============================================")
    

INFO:root:Algo: UCB
INFO:root:Expl mult: 1
INFO:root:Metric: jaccards
INFO:root:mean: 0.14186434060030653 +- 0.11181670212544564 
INFO:root:interval:  [0.0, 0.38028169014084506]
INFO:root:Metric: percent_inter
INFO:root:mean: 0.7370404500921525 +- 0.4371655031317151 
INFO:root:interval:  [0.0, 1.0]
INFO:root:Metric: percent_found
INFO:root:mean: 0.14214797136038188 +- 0.11225017322597196 
INFO:root:interval:  [0.0, 0.38663484486873506]
INFO:root:Metric: fails
INFO:root:mean: 13.0 +- 0.0 
INFO:root:interval:  [13, 13]
INFO:root:Expl mult: 10
INFO:root:Metric: jaccards
INFO:root:mean: 0.1349208915200734 +- 0.10946343537566157 
INFO:root:interval:  [0.0, 0.3794749403341289]
INFO:root:Metric: percent_inter
INFO:root:mean: 0.73712 +- 0.43738864365687413 
INFO:root:interval:  [0.0, 1.0]
INFO:root:Metric: percent_found
INFO:root:mean: 0.1351312649164678 +- 0.1096843887020497 
INFO:root:interval:  [0.0, 0.3794749403341289]
INFO:root:Metric: fails
INFO:root:mean: 13.0 +- 0.0 
INFO:root:interval

## Test avec weight decay

In [36]:
# Same benchmark as the basic sweep, except that every agent is built with
# decay=True. Results are collected in `metrics_dict[algo][str(mult)]`.
metrics_dict = {}
algos = ["UCB", "regTS", "lenientTS"]
exploration_mults = [1, 10]
logging.info(metrics_dict)
max_n_steps = 10
for algo in algos:
    metrics_dict[algo] = {}
    for exploration_mult in exploration_mults:
        metrics_dict[algo][str(exploration_mult)] = {}
        metrics_dict[algo][str(exploration_mult)]["jaccards"] = []
        metrics_dict[algo][str(exploration_mult)]["percent_inter"] = []
        metrics_dict[algo][str(exploration_mult)]["percent_found"] = []
        # Unlike the three lists above, "fails" is a plain counter.
        metrics_dict[algo][str(exploration_mult)]["fails"] = 0

        # First run index; raise it to resume a partially completed sweep.
        b = 0

        for i in range(b, 50):
            logging.info(f"at algo: {algo}, expl_mult: {exploration_mult}, run: {i}")
            # The run index is the seed.
            make_deterministic(seed=i)
            # NOTE(review): this warm-up draw is overwritten below before being
            # used, but it still consumes random numbers.
            vecs, rewards = gen_warmup_vecs_and_rewards(10)
            vecs, rewards = vecs.to(device), rewards.to(device)
            n_trials = 500
            width = 100
            net = NetworkDropout(d, width).to(device)
            reg = 1
            delay = 0
            # NOTE(review): self-assignment, has no effect.
            reward_fn = reward_fn
            de_config = DEConfig
            de_policy = PullPolicy
            # NOTE(review): `lr` is not read anywhere in this cell.
            lr = 1e-2

            if algo == "UCB":
                agent = DENeuralTSDiag(
                    net, nu=exploration_mult, lambda_=reg, style="ucb", decay=True
                )
            elif algo == "regTS":
                agent = DENeuralTSDiag(net, nu=exploration_mult, lambda_=reg, style="ts", decay=True)
            elif algo == "lenientTS":
                agent = LenientDENeuralTSDiag(
                    [float("-inf"), 0],
                    net=net,
                    nu=exploration_mult,
                    lambda_=reg,
                    decay=True
                )

            # Warm-up data actually used for this run.
            vecs, rewards = gen_warmup_vecs_and_rewards(10)
            vecs, rewards = vecs.to(device), rewards.to(device)
            agent.dataset.set_hists(vecs, rewards)

            # Warmup
            for vec in vecs:
                activ, grad = agent.compute_activation_and_grad(vec)
                agent.U += grad * grad
            agent.train(max_n_steps)

            # Playing
            for j in range(n_trials):
                # The trial index seeds the DE search.
                best_member = find_best_member(agent.get_sample, de_config=de_config, seed=j)
                best_member_grad = best_member.activation_grad
                a_t = best_member.point.unsqueeze(0)
                r_t = reward_fn(a_t).unsqueeze(0).unsqueeze(0)

                agent.U += best_member_grad * best_member_grad

                agent.dataset.add(a_t, r_t)

                agent.train(max_n_steps)


            # Stop using mask in evaluation
            agent.net.set_mask(None)
            sol, _, _ = agent.find_solution_in_vecs(x, 0, n_sigmas=n_sigmas)
            sol = sol.to(device)
            n_sol = len(sol)

            jaccard, n_inter = compute_jaccard(sol, true_sol)
            # Share of the true solution set that was recovered.
            percent_found = n_inter / n_true_sol
            # Share of the returned points that belong to the true solution set.
            if n_sol == 0:
                percent_inter = 0
            else:
                percent_inter = n_inter / n_sol

            metrics_dict[algo][str(exploration_mult)]["jaccards"].append(jaccard)
            metrics_dict[algo][str(exploration_mult)]["percent_inter"].append(
                percent_inter
            )
            metrics_dict[algo][str(exploration_mult)]["percent_found"].append(
                percent_found
            )

            # A run that returns no point at all is counted as a failure and plotted.
            if n_sol == 0:
                logging.info(f"Found no solution for run {i}")
                # NOTE(review): the file name says "100_trials" while n_trials
                # is 500 - confirm which is intended.
                plot_estimate(
                    agent,
                    n_trials,
                    fn=f"no_sol_{algo}_decay_expl_{exploration_mult}_100_trials_seed_{i}",
                )
                metrics_dict[algo][str(exploration_mult)]["fails"] += 1


            logging.info(
                f"jaccard: {jaccard}, percent_inter: {percent_inter}, percent_found: {percent_found}"
            )

# Persist all metrics once the whole sweep has finished.
torch.save(metrics_dict, f"metrics_dict_exp_synth_{max_n_steps}_decay.pth")


INFO:root:{}
INFO:root:at algo: UCB, expl_mult: 1, run: 0
INFO:root:jaccard: 0.3341288782816229, percent_inter: 1.0, percent_found: 0.3341288782816229
INFO:root:at algo: UCB, expl_mult: 1, run: 1
INFO:root:Found no solution for run 1
INFO:root:jaccard: 0.0, percent_inter: 0, percent_found: 0.0
INFO:root:at algo: UCB, expl_mult: 1, run: 2
INFO:root:Found no solution for run 2
INFO:root:jaccard: 0.0, percent_inter: 0, percent_found: 0.0
INFO:root:at algo: UCB, expl_mult: 1, run: 3
INFO:root:jaccard: 0.30787589498806683, percent_inter: 1.0, percent_found: 0.30787589498806683
INFO:root:at algo: UCB, expl_mult: 1, run: 4
INFO:root:jaccard: 0.24105011933174225, percent_inter: 1.0, percent_found: 0.24105011933174225
INFO:root:at algo: UCB, expl_mult: 1, run: 5
INFO:root:Found no solution for run 5
INFO:root:jaccard: 0.0, percent_inter: 0, percent_found: 0.0
INFO:root:at algo: UCB, expl_mult: 1, run: 6
INFO:root:jaccard: 0.2304147465437788, percent_inter: 0.8695652173913043, percent_found: 0.2

<Figure size 980x700 with 0 Axes>

In [37]:
# Log mean, standard deviation and range of every metric of the weight-decay
# sweep, grouped by algorithm and exploration multiplier.
for algo, results_by_mult in metrics_dict.items():
    logging.info(f"Algo: {algo}")
    for expl_mult, results_by_metric in results_by_mult.items():
        logging.info(f'Expl mult: {expl_mult}')
        for metric, values in results_by_metric.items():
            logging.info(f'Metric: {metric}')
            logging.info(f"mean: {np.mean(values)} +- {np.std(values)} ")
            logging.info(f"interval:  [{np.min(values)}, {np.max(values)}]")
            logging.info("============================================")

    logging.info("============================================")
    

INFO:root:Algo: UCB
INFO:root:Expl mult: 1
INFO:root:Metric: jaccards
INFO:root:mean: 0.14292063871228178 +- 0.14066535876671035 
INFO:root:interval:  [0.0, 0.448]
INFO:root:Metric: percent_inter
INFO:root:mean: 0.5501525341371906 +- 0.47752363900503275 
INFO:root:interval:  [0.0, 1.0]
INFO:root:Metric: percent_found
INFO:root:mean: 0.14821002386634843 +- 0.1487493980568872 
INFO:root:interval:  [0.0, 0.5346062052505967]
INFO:root:Metric: fails
INFO:root:mean: 21.0 +- 0.0 
INFO:root:interval:  [21, 21]
INFO:root:Expl mult: 10
INFO:root:Metric: jaccards
INFO:root:mean: 0.14164765363082893 +- 0.13510168005961679 
INFO:root:interval:  [0.0, 0.4491017964071856]
INFO:root:Metric: percent_inter
INFO:root:mean: 0.6465621732311254 +- 0.4552683035624665 
INFO:root:interval:  [0.0, 1.0]
INFO:root:Metric: percent_found
INFO:root:mean: 0.1472076372315036 +- 0.14343690507856838 
INFO:root:interval:  [0.0, 0.5369928400954654]
INFO:root:Metric: fails
INFO:root:mean: 16.0 +- 0.0 
INFO:root:interval:  

## Test avec dropout plutôt que le bonus d'exploration

In [150]:
# Benchmark with dropout in the network instead of an explicit exploration
# multiplier: for each dropout rate, execute 50 seeded runs of `n_trials`
# trials. The network is put in eval mode around every agent.train(...) call
# and around the final solution search.
metrics_dict = {}
algos = ["regTS"]
logging.info(metrics_dict)
max_n_steps = 10
for algo in algos:
    metrics_dict[algo] = {}
    for dropout_rate in [0.2, 0.5, 0.8]:
        metrics_dict[algo][str(dropout_rate)] = {}
        metrics_dict[algo][str(dropout_rate)]["jaccards"] = []
        metrics_dict[algo][str(dropout_rate)]["percent_inter"] = []
        metrics_dict[algo][str(dropout_rate)]["percent_found"] = []
        # Unlike the three lists above, "fails" is a plain counter.
        metrics_dict[algo][str(dropout_rate)]["fails"] = 0
        # First run index; raise it to resume a partially completed sweep.
        b = 0

        for i in range(b, 50):
            logging.info(f"at algo: {algo}, dropout: {dropout_rate}, run: {i}")
            # The run index is the seed.
            make_deterministic(seed=i)
            # NOTE(review): this warm-up draw is overwritten below before being
            # used, but it still consumes random numbers; it is kept so that
            # seeded results do not change.
            vecs, rewards = gen_warmup_vecs_and_rewards(10)
            vecs, rewards = vecs.to(device), rewards.to(device)
            n_trials = 500
            width = 100
            net = NetworkDropout(d, width, dropout=dropout_rate).to(device)
            reg = 1
            sampletype = "f"
            # NOTE(review): self-assignment, has no effect.
            reward_fn = reward_fn
            de_config = DEConfig
            de_policy = PullPolicy
            # NOTE(review): `bern_p` / `p_vec` are not read anywhere in this cell.
            bern_p = 1 - dropout_rate
            p_vec = torch.tensor([bern_p] * width)

            agent = DENeuralTSDiag(
                net, lambda_=reg, style="ts", sampletype=sampletype
            )

            # Warm-up data actually used for this run.
            vecs, rewards = gen_warmup_vecs_and_rewards(10)
            vecs, rewards = vecs.to(device), rewards.to(device)
            agent.dataset.set_hists(vecs, rewards)

            # Warmup
            # NOTE(review): unlike the other benchmarks, the agent is trained
            # once per warm-up vector (inside the loop) - confirm this is intended.
            for vec in vecs:
                activ, grad = agent.compute_activation_and_grad(vec)
                agent.U += grad * grad

                agent.net.eval()
                agent.train(max_n_steps)
                agent.net.train()

            # Train
            for j in range(n_trials):

                # The trial index seeds the DE search.
                best_member = find_best_member(agent.get_sample, de_config=de_config, seed=j)
                best_member_grad = best_member.activation_grad
                a_t = best_member.point.unsqueeze(0)
                r_t = reward_fn(a_t).unsqueeze(0).unsqueeze(0)

                # A missing gradient aborts the run; it is reported as a
                # NaN failure right after the loop.
                if best_member_grad is None:
                    break
                agent.U += best_member_grad * best_member_grad

                agent.dataset.add(a_t, r_t)


                agent.net.eval()
                agent.train(max_n_steps)
                agent.net.train()

            if best_member_grad is None:
                metrics_dict[algo][str(dropout_rate)]["fails"] += 1
                logging.info(
                    f"Encountered a fail in {algo} {dropout_rate} because of nans"
                )
                continue
            agent.net.eval()
            sol, _, _ = agent.find_solution_in_vecs(x, 0, n_sigmas=n_sigmas)
            sol = sol.to(device)
            n_sol = len(sol)

            jaccard, n_inter = compute_jaccard(sol, true_sol)
            # Share of the true solution set that was recovered.
            percent_found = n_inter / n_true_sol
            # Share of the returned points that belong to the true solution set.
            if n_sol == 0:
                percent_inter = 0
            else:
                percent_inter = n_inter / n_sol

            metrics_dict[algo][str(dropout_rate)]["jaccards"].append(jaccard)
            metrics_dict[algo][str(dropout_rate)]["percent_inter"].append(percent_inter)
            metrics_dict[algo][str(dropout_rate)]["percent_found"].append(percent_found)

            # A run that returns no point at all is counted as a failure and plotted.
            if n_sol == 0:
                logging.info(f"Found no solution for run {i}")
                metrics_dict[algo][str(dropout_rate)]["fails"] += 1

                # plot_estimate appends ".png" itself, so the stem must not
                # carry the extension (the original produced "...png.png").
                # NOTE(review): the name says "100_trials" while n_trials is
                # 500 - confirm which is intended.
                plot_estimate(
                    agent,
                    n_trials,
                    fn=f"no_sol_{algo}_drop_{dropout_rate}_100_trials_seed_{i}",
                )
            agent.net.train()

            logging.info(
                f"jaccard: {jaccard}, percent_inter: {percent_inter}, percent_found: {percent_found}"
            )

# Persist all metrics once the whole sweep has finished.
torch.save(metrics_dict, f"metrics_dict_exp_synth_{max_n_steps}_dropout.pth")


INFO:root:{}
INFO:root:at algo: regTS, dropout: 0.2, run: 0
INFO:root:jaccard: 0.14558472553699284, percent_inter: 1.0, percent_found: 0.14558472553699284
INFO:root:at algo: regTS, dropout: 0.2, run: 1
INFO:root:Found no solution for run 1
INFO:root:jaccard: 0.0, percent_inter: 0, percent_found: 0.0
INFO:root:at algo: regTS, dropout: 0.2, run: 2
INFO:root:jaccard: 0.11933174224343675, percent_inter: 1.0, percent_found: 0.11933174224343675
INFO:root:at algo: regTS, dropout: 0.2, run: 3
INFO:root:Found no solution for run 3
INFO:root:jaccard: 0.0, percent_inter: 0, percent_found: 0.0
INFO:root:at algo: regTS, dropout: 0.2, run: 4
INFO:root:jaccard: 0.19331742243436753, percent_inter: 1.0, percent_found: 0.19331742243436753
INFO:root:at algo: regTS, dropout: 0.2, run: 5
INFO:root:Found no solution for run 5
INFO:root:jaccard: 0.0, percent_inter: 0, percent_found: 0.0
INFO:root:at algo: regTS, dropout: 0.2, run: 6
INFO:root:jaccard: 0.2458233890214797, percent_inter: 1.0, percent_found: 0.

<Figure size 980x700 with 0 Axes>

In [151]:
# Log mean, standard deviation and range of every metric, grouped by
# algorithm and dropout rate.
# The keys are strings, so they get their own loop variable: reusing
# `dropout_rate` would overwrite the numeric notebook global read by gen_mask.
for algo, results_by_dropout in metrics_dict.items():
    logging.info(f"Algo: {algo}")
    for dropout_key, results_by_metric in results_by_dropout.items():
        logging.info(f'dropout: {dropout_key}')
        for metric, values in results_by_metric.items():
            logging.info(f'Metric: {metric}')
            logging.info(f"mean: {np.mean(values)} +- {np.std(values)} ")
            logging.info(f"interval:  [{np.min(values)}, {np.max(values)}]")
            logging.info("============================================")

    logging.info("============================================")
    

INFO:root:Algo: regTS
INFO:root:dropout: 0.2
INFO:root:Metric: jaccards
INFO:root:mean: 0.07050119331742243 +- 0.09472410283632834 
INFO:root:interval:  [0.0, 0.2744630071599045]
INFO:root:Metric: percent_inter
INFO:root:mean: 0.44 +- 0.4963869458396343 
INFO:root:interval:  [0.0, 1.0]
INFO:root:Metric: percent_found
INFO:root:mean: 0.07050119331742243 +- 0.09472410283632834 
INFO:root:interval:  [0.0, 0.2744630071599045]
INFO:root:Metric: fails
INFO:root:mean: 28.0 +- 0.0 
INFO:root:interval:  [28, 28]
INFO:root:dropout: 0.5
INFO:root:Metric: jaccards
INFO:root:mean: 0.05427207637231504 +- 0.07976904878667916 
INFO:root:interval:  [0.0, 0.2386634844868735]
INFO:root:Metric: percent_inter
INFO:root:mean: 0.4 +- 0.48989794855663565 
INFO:root:interval:  [0.0, 1.0]
INFO:root:Metric: percent_found
INFO:root:mean: 0.05427207637231504 +- 0.07976904878667916 
INFO:root:interval:  [0.0, 0.2386634844868735]
INFO:root:Metric: fails
INFO:root:mean: 30.0 +- 0.0 
INFO:root:interval:  [30, 30]
INFO

# TEST: Keep dropout active all the time instead of turning it off during training

In [None]:
# Variant of the dropout benchmark in which the network is never switched to
# eval mode: the eval()/train() toggles around agent.train(...) and around the
# solution search are commented out below. Plots and metrics are written under
# an "always_on" name so they do not overwrite the files of the previous
# dropout experiment.
metrics_dict = {}
algos = ["regTS"]
logging.info(metrics_dict)
max_n_steps = 10
for algo in algos:
    metrics_dict[algo] = {}
    for dropout_rate in [0.2, 0.5, 0.8]:
        metrics_dict[algo][str(dropout_rate)] = {}
        metrics_dict[algo][str(dropout_rate)]["jaccards"] = []
        metrics_dict[algo][str(dropout_rate)]["percent_inter"] = []
        metrics_dict[algo][str(dropout_rate)]["percent_found"] = []
        # Unlike the three lists above, "fails" is a plain counter.
        metrics_dict[algo][str(dropout_rate)]["fails"] = 0
        # First run index; raise it to resume a partially completed sweep.
        b = 0

        for i in range(b, 50):
            logging.info(f"at algo: {algo}, dropout: {dropout_rate}, run: {i}")
            # The run index is the seed.
            make_deterministic(seed=i)
            # NOTE(review): this warm-up draw is overwritten below before being
            # used, but it still consumes random numbers; it is kept so that
            # seeded results do not change.
            vecs, rewards = gen_warmup_vecs_and_rewards(10)
            vecs, rewards = vecs.to(device), rewards.to(device)
            n_trials = 500
            width = 100
            net = NetworkDropout(d, width, dropout=dropout_rate).to(device)
            reg = 1
            sampletype = "f"
            # NOTE(review): self-assignment, has no effect.
            reward_fn = reward_fn
            de_config = DEConfig
            de_policy = PullPolicy
            # NOTE(review): `bern_p` / `p_vec` are not read anywhere in this cell.
            bern_p = 1 - dropout_rate
            p_vec = torch.tensor([bern_p] * width)

            agent = DENeuralTSDiag(
                net, lambda_=reg, style="ts", sampletype=sampletype
            )

            # Warm-up data actually used for this run.
            vecs, rewards = gen_warmup_vecs_and_rewards(10)
            vecs, rewards = vecs.to(device), rewards.to(device)
            agent.dataset.set_hists(vecs, rewards)

            # Warmup
            # NOTE(review): the agent is trained once per warm-up vector
            # (inside the loop) - confirm this is intended.
            for vec in vecs:
                activ, grad = agent.compute_activation_and_grad(vec)
                agent.U += grad * grad

                # agent.net.eval()
                agent.train(max_n_steps)
                # agent.net.train()

            # Train
            for j in range(n_trials):

                # The trial index seeds the DE search.
                best_member = find_best_member(agent.get_sample, de_config=de_config, seed=j)
                best_member_grad = best_member.activation_grad
                a_t = best_member.point.unsqueeze(0)
                r_t = reward_fn(a_t).unsqueeze(0).unsqueeze(0)

                # A missing gradient aborts the run; it is reported as a
                # NaN failure right after the loop.
                if best_member_grad is None:
                    break
                agent.U += best_member_grad * best_member_grad

                agent.dataset.add(a_t, r_t)


                # agent.net.eval()
                agent.train(max_n_steps)
                # agent.net.train()

            if best_member_grad is None:
                metrics_dict[algo][str(dropout_rate)]["fails"] += 1
                logging.info(
                    f"Encountered a fail in {algo} {dropout_rate} because of nans"
                )
                continue
            # agent.net.eval()
            sol, _, _ = agent.find_solution_in_vecs(x, 0, n_sigmas=n_sigmas)
            sol = sol.to(device)
            n_sol = len(sol)

            jaccard, n_inter = compute_jaccard(sol, true_sol)
            # Share of the true solution set that was recovered.
            percent_found = n_inter / n_true_sol
            # Share of the returned points that belong to the true solution set.
            if n_sol == 0:
                percent_inter = 0
            else:
                percent_inter = n_inter / n_sol

            metrics_dict[algo][str(dropout_rate)]["jaccards"].append(jaccard)
            metrics_dict[algo][str(dropout_rate)]["percent_inter"].append(percent_inter)
            metrics_dict[algo][str(dropout_rate)]["percent_found"].append(percent_found)

            # A run that returns no point at all is counted as a failure and plotted.
            if n_sol == 0:
                logging.info(f"Found no solution for run {i}")
                metrics_dict[algo][str(dropout_rate)]["fails"] += 1

                # plot_estimate appends ".png" itself, so the stem must not
                # carry the extension; "always_on" keeps these plots apart
                # from those of the previous dropout experiment.
                # NOTE(review): the name says "100_trials" while n_trials is
                # 500 - confirm which is intended.
                plot_estimate(
                    agent,
                    n_trials,
                    fn=f"no_sol_{algo}_drop_always_on_{dropout_rate}_100_trials_seed_{i}",
                )
            agent.net.train()

            logging.info(
                f"jaccard: {jaccard}, percent_inter: {percent_inter}, percent_found: {percent_found}"
            )

# Saved under its own name: the original reused the file name of the previous
# dropout experiment and therefore overwrote its results.
torch.save(metrics_dict, f"metrics_dict_exp_synth_{max_n_steps}_dropout_always_on.pth")


In [None]:
# Log mean, standard deviation and range of every metric, grouped by
# algorithm and dropout rate.
# The keys are strings, so they get their own loop variable: reusing
# `dropout_rate` would overwrite the numeric notebook global read by gen_mask.
for algo, results_by_dropout in metrics_dict.items():
    logging.info(f"Algo: {algo}")
    for dropout_key, results_by_metric in results_by_dropout.items():
        logging.info(f'dropout: {dropout_key}')
        for metric, values in results_by_metric.items():
            logging.info(f'Metric: {metric}')
            logging.info(f"mean: {np.mean(values)} +- {np.std(values)} ")
            logging.info(f"interval:  [{np.min(values)}, {np.max(values)}]")
            logging.info("============================================")

    logging.info("============================================")
    