In [None]:
import math, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import numpy as np, random, os, time, tarfile, gc
import copy
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)
torch.manual_seed(1); random.seed(1); np.random.seed(1)

device cuda


In [None]:
class PermutedMNIST:
    def __init__(self, num_tasks=10, seed=123):
        shuffler = torch.Generator().manual_seed(seed)
        self.perms = [torch.randperm(784, generator=shuffler) for _ in range(num_tasks)]
        tfm = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
        tr = datasets.MNIST("./data", True,  download=True, transform=tfm)
        te = datasets.MNIST("./data", False, download=True, transform=tfm)
        x_tr = tr.data.float().view(-1,784)/255.
        x_te = te.data.float().view(-1,784)/255.
        y_tr = F.one_hot(tr.targets, 10).float()
        y_te = F.one_hot(te.targets, 10).float()
        self.tasks = [(TensorDataset(x_tr[:, p], y_tr), TensorDataset(x_te[:, p], y_te))for p in self.perms]
        self.input_dim = 784
        self.n_classes = 10
        self.num_tasks = num_tasks
    def get_task(self, tid):
        return self.tasks[tid]

class SplitMNIST:
    pairs = [(0,1), (2,3), (4,5), (6,7), (8,9)]
    def __init__(self):
        tfm = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
        tr = datasets.MNIST("./data", True,  download=True, transform=tfm)
        te = datasets.MNIST("./data", False, download=True, transform=tfm)
        x_tr = tr.data.float().view(-1,784)/255.
        x_te = te.data.float().view(-1,784)/255.
        self.tasks=[]
        for a, b in self.pairs:
            msk_tr = (tr.targets==a)|(tr.targets==b)
            msk_te = (te.targets==a)|(te.targets==b)
            y_tr = F.one_hot((tr.targets[msk_tr]==b).long(), 2).float()
            y_te = F.one_hot((te.targets[msk_te]==b).long(), 2).float()
            self.tasks.append((TensorDataset(x_tr[msk_tr], y_tr), TensorDataset(x_te[msk_te], y_te)))
        self.input_dim = 784
        self.n_classes = 2
        self.num_tasks = 5
    def get_task(self, tid):
        return self.tasks[tid]

In [None]:
class PlainMLP(nn.Module):
    def __init__(self, dims):
        super().__init__()
        layers=[]
        for din,dout in zip(dims[:-1],dims[1:]):
            layers.append(nn.Linear(din,dout))
            layers.append(nn.ReLU())
        layers.pop() # last one no relu
        self.net = nn.Sequential(*layers)
    def forward(self,x): return self.net(x)

class BayesianLinear(nn.Module):
    def __init__(self, in_f, out_f, prior_var=1.0):
        super().__init__()
        self.w_mu     = nn.Parameter(torch.empty(out_f,in_f))
        self.w_logvar = nn.Parameter(torch.full((out_f,in_f), -6.0))
        self.b_mu     = nn.Parameter(torch.empty(out_f))
        self.b_logvar = nn.Parameter(torch.full((out_f,), -6.0))
        nn.init.normal_(self.w_mu,0,0.1); nn.init.normal_(self.b_mu,0,0.1)
        self.register_buffer("pw_mu", torch.zeros_like(self.w_mu))
        self.register_buffer("pw_logvar", torch.full_like(self.w_mu, math.log(prior_var)))
        self.register_buffer("pb_mu", torch.zeros_like(self.b_mu))
        self.register_buffer("pb_logvar", torch.full_like(self.b_mu, math.log(prior_var)))
    def _sample(self):
        ew = torch.randn_like(self.w_mu)
        eb = torch.randn_like(self.b_mu)
        w = self.w_mu + (0.5*self.w_logvar).exp()*ew
        b = self.b_mu + (0.5*self.b_logvar).exp()*eb
        return w,b
    def forward(self,x,sample=True):
        w,b = self._sample() if sample else (self.w_mu, self.b_mu)
        return F.linear(x,w,b)
    def helper_kl(self, m, lv, m0, lv0):
        v, v0 = lv.exp(), lv0.exp()
        return 0.5*((lv0-lv) + (v+(m-m0).pow(2))/v0 -1).sum()
    def kl(self):
        return self.helper_kl(self.w_mu,self.w_logvar,self.pw_mu,self.pw_logvar) + self.helper_kl(self.b_mu,self.b_logvar,self.pb_mu,self.pb_logvar)
    def update_prior(self):
        self.pw_mu.data.copy_(self.w_mu.data)
        self.pw_logvar.data.copy_(self.w_logvar.data)
        self.pb_mu.data.copy_(self.b_mu.data)
        self.pb_logvar.data.copy_(self.b_logvar.data)

class BayesianMLP(nn.Module):
    def __init__(self, in_dim, hidden, out_dim, heads=1, prior_var=1.0):
        super().__init__()
        self.hidden = nn.ModuleList()
        last = in_dim
        for h in hidden:
            self.hidden.append(BayesianLinear(last,h,prior_var))
            last = h
        self.heads = nn.ModuleList([BayesianLinear(last,out_dim,prior_var) for _ in range(heads)])
        self.out_dim = out_dim
    def add_head(self, out_dim):
        head = BayesianLinear(self.hidden[-1].w_mu.size(0), out_dim)
        head.to(next(self.parameters()).device); self.heads.append(head)
    def forward(self,x,head_id=0,sample=True):
        for l in self.hidden: x = torch.relu(l(x,sample))
        return self.heads[head_id](x,sample)
    def kl(self):
        return sum(l.kl() for l in self.hidden)+sum(h.kl() for h in self.heads)
    def update_prior(self):
        for l in self.hidden: l.update_prior()
        for h in self.heads:  h.update_prior()


In [None]:
class VCLGlobalSigma:
    def __init__(self, in_dim, hidden, n_classes,
                 single_head=True, lr=1e-4, mc=3,
                 init_log_sigma=-3.0,
                 beta=1, prior_var=1.0,
                 coreset_size=0, coreset_epochs=0):

        self.in_dim, self.hidden_sizes = in_dim, list(hidden)
        self.n_classes = n_classes
        self.single_head = single_head
        self.beta            = beta
        self.lr, self.mc = lr, mc
        self.prior_var = prior_var
        self.print_freq      = 5 # debug change
        self.coreset_size    = coreset_size
        self.coreset_epochs = coreset_epochs

        self.model = BayesianMLP(in_dim=self.in_dim, hidden=self.hidden_sizes, out_dim=self.n_classes, heads=1, prior_var=self.prior_var).to(device)

        self.log_sigma = nn.Parameter(torch.tensor(float(init_log_sigma)))
        self.model.register_parameter("log_sigma", self.log_sigma)

        self.rmse_hist = []
        self.core_x, self.core_y = [], []

        init_sig2 = torch.exp(2 * self.log_sigma).item()
        print(f"sigma init={init_sig2} " + f"beta = {self.beta}, lr={self.lr}, coreset size of {self.coreset_size}")

    # for LR just do random other method is pretty much identical in performance to permuted mnist
    def coreset_selection(self, x_full, y_full):
        n_samples = x_full.size(0)
        if self.coreset_size <= 0:
            return None, None, x_full, y_full # don't do coresets.
        if self.coreset_size >= n_samples:
             print("Coreset size > train size")
             return x_full, y_full, None, None

        perm = torch.randperm(n_samples, device=x_full.device)
        core_idx = perm[:self.coreset_size]
        non_core_mask = torch.ones(n_samples, dtype=torch.bool, device=x_full.device)
        non_core_mask[core_idx] = False

        core_x, core_y = x_full[core_idx], y_full[core_idx]
        non_core_x, non_core_y = x_full[non_core_mask], y_full[non_core_mask]

        return core_x, core_y, non_core_x, non_core_y

    @torch.no_grad()
    def evaluate(self, dset, t, model_to_eval):
        model_to_eval.eval()
        rmses = []
        overall_sq_sum = 0.0
        overall_n = 0
        num_eval_tasks = t + 1
        for task_idx in range(num_eval_tasks):
            _, test_ds = dset.get_task(task_idx)
            if len(test_ds) == 0:
                 rmses.append(float('nan'))
                 continue
            loader = DataLoader(test_ds, batch_size=1024, shuffle=False)
            sq_sum, n_elem = 0.0, 0
            head_id_eval = 0
            for xb_test, yb_test in loader:
                xb_test, yb_test = xb_test.to(device), yb_test.to(device)
                pred = model_to_eval(xb_test, head_id=head_id_eval, sample=False)
                sq_sum += (pred - yb_test).pow(2).sum().item()
                n_elem += yb_test.numel()
                overall_sq_sum += (pred - yb_test).pow(2).sum().item()
                overall_n += yb_test.numel()
            rmse = math.sqrt(sq_sum / n_elem) if n_elem > 0 else 0.0
            rmses.append(rmse)
        valid_rmses = [r for r in rmses if not math.isnan(r)]
        overall = math.sqrt(overall_sq_sum/overall_n) if valid_rmses else -1223334444
        self.rmse_hist.append((rmses, overall))
        tasks_str = ", ".join(f"T{i}={r}" for i, r in enumerate(rmses))
        print(f"After Task {t}: full RMSE={overall} | [{tasks_str}]")

    def fit(self, dset, epochs=50, batch_size=256):
        start_fit_time = time.time()
        head_id = 0

        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        for t in range(dset.num_tasks):
            task_start_time = time.time()
            tr_ds, _ = dset.get_task(t)

            x, y = tr_ds.tensors[0].to(device), tr_ds.tensors[1].to(device)

            cx, cy, x_nc, y_nc = self.coreset_selection(x, y)
            if cx is not None:
                self.core_x.append(cx.cpu())
                self.core_y.append(cy.cpu())

            if x_nc is not None and x_nc.size(0) > 0:
                num_train_samples_nc = x_nc.size(0)
                current_batch_size_nc = min(batch_size, num_train_samples_nc)
                loader_nc = DataLoader(TensorDataset(x_nc, y_nc), batch_size=current_batch_size_nc, shuffle=True)

                for ep in tqdm(range(epochs), desc=f"Task {t} learning", leave=False):
                    self.model.train()
                    for batch_idx, (xb, yb) in enumerate(loader_nc):
                        xb, yb = xb.to(device), yb.to(device)
                        opt.zero_grad()

                        batch_nll = 0.0; batch_mse = 0.0
                        inv_var = torch.exp(-2 * self.log_sigma)
                        out_dim = yb.size(1) if yb.ndim > 1 else 1
                        log_2pi = math.log(2 * math.pi)

                        for _ in range(self.mc):
                            out = self.model(xb, head_id=head_id, sample=True)
                            se_sum = (out - yb).pow(2).sum()
                            nll_mc_sample = (0.5 * xb.size(0) * out_dim * log_2pi) + (xb.size(0) * out_dim * self.log_sigma) + (0.5 * inv_var * se_sum)
                            batch_nll += nll_mc_sample
                            batch_mse += se_sum.item()
                        nll = batch_nll / self.mc
                        mse = batch_mse / (self.mc * xb.size(0) * out_dim)
                        kl = self.model.kl()

                        loss = nll / xb.size(0) + self.beta * kl / xb.size(0)

                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                        opt.step()

                        # debug
                        if (batch_idx == 0 and ep % self.print_freq == 0):
                            nll_s = nll.item() / xb.size(0)
                            kl_s = kl.item() / xb.size(0)
                            ratio = (self.beta * kl_s) / nll_s
                            sig2 = torch.exp(2 * self.log_sigma).item()
                            print(f"Task {t} Ep {ep} " + f"L={loss.item()} NLL={nll_s} + b*KL={self.beta * kl_s}) "
                                  + f" KL/NLL={ratio}  RMSE={math.sqrt(mse)}  var={sig2}")
                print(f"Task {t} non coreset done")
            else:
                print(f"Task {t}: error 123")

            # store so coreset doesn't override
            state_after_propagation = copy.deepcopy(self.model.state_dict())

            model_state_before_eval = state_after_propagation

            if self.coreset_epochs > 0 and self.core_x:
                cx_all = torch.cat(self.core_x).to(device); cy_all = torch.cat(self.core_y).to(device)
                num_train_samples_core = cx_all.size(0)
                if num_train_samples_core > 0:
                    current_batch_size_core = min(batch_size, num_train_samples_core)
                    loader_core = DataLoader(TensorDataset(cx_all, cy_all), batch_size=current_batch_size_core, shuffle=True)

                    self.model.update_prior()

                    for ep in tqdm(range(self.coreset_epochs), desc=f"Task {t} rem", leave=False):
                        self.model.train()
                        for batch_idx, (xb, yb) in enumerate(loader_core):
                            xb, yb = xb.to(device), yb.to(device)
                            opt.zero_grad()

                            batch_nll = 0.0; batch_mse = 0.0
                            inv_var = torch.exp(-2 * self.log_sigma)
                            out_dim = yb.size(1) if yb.ndim > 1 else 1
                            log_2pi = math.log(2 * math.pi)

                            for _ in range(self.mc):
                                out = self.model(xb, head_id=head_id, sample=True)
                                se_sum = (out - yb).pow(2).sum()
                                nll_mc_sample = (0.5 * xb.size(0) * out_dim * log_2pi) + (xb.size(0) * out_dim * self.log_sigma) + (0.5 * inv_var * se_sum)
                                batch_nll += nll_mc_sample
                                batch_mse += se_sum.item()

                            nll = batch_nll / self.mc
                            mse = batch_mse / (self.mc * xb.size(0) * out_dim)
                            kl = self.model.kl()
                            loss = nll / xb.size(0) + self.beta * kl / xb.size(0)

                            loss.backward()
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                            opt.step()

                    print(f"Task {t} coreset done.")
                    model_state_before_eval = copy.deepcopy(self.model.state_dict())
                else:
                     print(f"Task {t} no coreset error tr43")
            else:
                print(f"Task {t} no coreset")

            eval_model = BayesianMLP(in_dim=self.in_dim, hidden=self.hidden_sizes, out_dim=self.n_classes, heads=1, prior_var=self.prior_var).to(device)
            eval_model.register_parameter("log_sigma", self.log_sigma)
            eval_model.load_state_dict(model_state_before_eval)
            self.evaluate(dset, t, eval_model)

            self.model.load_state_dict(state_after_propagation)

            del eval_model

            self.model.update_prior()

            task_end_time = time.time()

        end_fit_time = time.time()
        print(f"took {end_fit_time - start_fit_time} s")

In [None]:
trainer = VCLGlobalSigma(in_dim=784, hidden=[100, 100],
n_classes=10, single_head=True, init_log_sigma=-1.5, beta=1.0, lr=1e-4, mc=3,
prior_var=1.0, coreset_size=200, coreset_epochs=0)

permuted_mnist_data = PermutedMNIST(num_tasks=10)
trainer.fit(permuted_mnist_data, epochs=50, batch_size=256)

sigma init=0.049787066876888275 beta = 1.0, lr=0.0001, coreset size of 200


100%|██████████| 9.91M/9.91M [00:00<00:00, 41.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.19MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.45MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.1MB/s]


Task 0 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 0 Ep 0 L=908.9370727539062 NLL=31.655075073242188 + b*KL=877.281982421875)  KL/NLL=27.71378619043097  RMSE=0.6107874940618245  var=0.04979702830314636
Task 0 Ep 5 L=848.68798828125 NLL=-1.1240925788879395 + b*KL=849.8120727539062)  KL/NLL=-755.9982947264202  RMSE=0.21503497543400812  var=0.05390974506735802
Task 0 Ep 10 L=825.8154296875 NLL=-2.8678908348083496 + b*KL=828.683349609375)  KL/NLL=-288.9521942576844  RMSE=0.17734232830536453  var=0.04357746243476868
Task 0 Ep 15 L=806.4302978515625 NLL=-2.759669303894043 + b*KL=809.18994140625)  KL/NLL=-293.2198942331384  RMSE=0.18334911632513878  var=0.03631551191210747
Task 0 Ep 20 L=787.6663818359375 NLL=-2.623272657394409 + b*KL=790.2896728515625)  KL/NLL=-301.2609728630059  RMSE=0.18576388334311916  var=0.03158487379550934
Task 0 Ep 25 L=768.573974609375 NLL=-2.9561386108398438 + b*KL=771.5300903320312)  KL/NLL=-260.99252839596664  RMSE=0.17940126924522945  var=0.028622472658753395
Task 0 Ep 30 L=748.5680541992188 NLL=-4.075110435

Task 1 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 1 Ep 0 L=11.775821685791016 NLL=11.775821685791016 + b*KL=0.0)  KL/NLL=0.0  RMSE=0.32051856383044086  var=0.024251272901892662
Task 1 Ep 5 L=2.008873462677002 NLL=0.15118011832237244 + b*KL=1.8576934337615967)  KL/NLL=12.28794800782138  RMSE=0.23327020716522803  var=0.035657092928886414
Task 1 Ep 10 L=0.8966151475906372 NLL=-0.9890345335006714 + b*KL=1.8856496810913086)  KL/NLL=-1.9065559565620855  RMSE=0.21859577964220378  var=0.043158240616321564
Task 1 Ep 15 L=1.3785109519958496 NLL=-0.4148448407649994 + b*KL=1.7933558225631714)  KL/NLL=-4.322955588060618  RMSE=0.2315611880527385  var=0.04861212149262428
Task 1 Ep 20 L=1.1098082065582275 NLL=-0.6244744062423706 + b*KL=1.7342826128005981)  KL/NLL=-2.7771876564745703  RMSE=0.22732113073503077  var=0.051430895924568176
Task 1 Ep 25 L=0.9120521545410156 NLL=-0.8121347427368164 + b*KL=1.724186897277832)  KL/NLL=-2.1230305841460337  RMSE=0.22301774154817233  var=0.051652003079652786
Task 1 Ep 30 L=0.36835813522338867 NLL=-1.371280908

Task 2 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 2 Ep 0 L=5.545581340789795 NLL=5.545581340789795 + b*KL=0.0)  KL/NLL=0.0  RMSE=0.33612513070690925  var=0.049719277769327164
Task 2 Ep 5 L=1.5133916139602661 NLL=0.24000994861125946 + b*KL=1.2733817100524902)  KL/NLL=5.305537197189136  RMSE=0.2477994831542156  var=0.05970671772956848
Task 2 Ep 10 L=0.9197371006011963 NLL=-0.5321578979492188 + b*KL=1.451894998550415)  KL/NLL=-2.7283161710943213  RMSE=0.22926167267161207  var=0.05551838502287865
Task 2 Ep 15 L=1.3095134496688843 NLL=-0.2568244934082031 + b*KL=1.5663379430770874)  KL/NLL=-6.098865113256591  RMSE=0.2357010184651078  var=0.05298571661114693
Task 2 Ep 20 L=1.464585542678833 NLL=-0.18230199813842773 + b*KL=1.6468875408172607)  KL/NLL=-9.033842512064659  RMSE=0.23715662896482964  var=0.051650967448949814
Task 2 Ep 25 L=0.9360312223434448 NLL=-0.7461028099060059 + b*KL=1.6821340322494507)  KL/NLL=-2.254560644881322  RMSE=0.22457219355125996  var=0.05073997750878334
Task 2 Ep 30 L=0.7675714492797852 NLL=-0.9587448835372925 

Task 3 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 3 Ep 0 L=4.961034297943115 NLL=4.961034297943115 + b*KL=0.0)  KL/NLL=0.0  RMSE=0.32475067380777006  var=0.04826143756508827
Task 3 Ep 5 L=1.0317474603652954 NLL=-0.384521484375 + b*KL=1.4162689447402954)  KL/NLL=-3.683198474702381  RMSE=0.23272210612873054  var=0.05670076981186867
Task 3 Ep 10 L=1.1313657760620117 NLL=-0.4497801661491394 + b*KL=1.581146001815796)  KL/NLL=-3.5153751116973773  RMSE=0.23132845346826217  var=0.053557734936475754
Task 3 Ep 15 L=0.7154818177223206 NLL=-0.9626151919364929 + b*KL=1.6780970096588135)  KL/NLL=-1.743268778340165  RMSE=0.21951207521804905  var=0.05160393565893173
Task 3 Ep 20 L=0.8939873576164246 NLL=-0.8484441637992859 + b*KL=1.7424315214157104)  KL/NLL=-2.0536784808716213  RMSE=0.22228235451076947  var=0.04989532381296158
Task 3 Ep 25 L=0.3464404344558716 NLL=-1.4548801183700562 + b*KL=1.8013205528259277)  KL/NLL=-1.2381230110175667  RMSE=0.20848081772504584  var=0.04902258887887001
Task 3 Ep 30 L=0.4817357063293457 NLL=-1.3549723625183105 

Task 4 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 4 Ep 0 L=5.037041187286377 NLL=5.037041187286377 + b*KL=0.0)  KL/NLL=0.0  RMSE=0.32191129683491654  var=0.046153485774993896
Task 4 Ep 5 L=0.3995068073272705 NLL=-0.948577880859375 + b*KL=1.3480846881866455)  KL/NLL=-1.4211639501495994  RMSE=0.219211831086742  var=0.05460754409432411
Task 4 Ep 10 L=1.1788958311080933 NLL=-0.3115638196468353 + b*KL=1.490459680557251)  KL/NLL=-4.78380218295796  RMSE=0.2343034550293078  var=0.05149635672569275
Task 4 Ep 15 L=0.8364073634147644 NLL=-0.7510859370231628 + b*KL=1.5874933004379272)  KL/NLL=-2.113597422326615  RMSE=0.2244289702003206  var=0.04915712773799896
Task 4 Ep 20 L=0.7548840045928955 NLL=-0.9023079872131348 + b*KL=1.6571919918060303)  KL/NLL=-1.8366145654151058  RMSE=0.22108235144466032  var=0.04818981513381004
Task 4 Ep 25 L=0.9596253633499146 NLL=-0.7558416128158569 + b*KL=1.7154669761657715)  KL/NLL=-2.269611711076438  RMSE=0.2241421868974673  var=0.0472627654671669
Task 4 Ep 30 L=0.503569483757019 NLL=-1.2457808256149292 + b*KL

Task 5 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 5 Ep 0 L=6.5211076736450195 NLL=6.5211076736450195 + b*KL=0.0)  KL/NLL=0.0  RMSE=0.34029720644510925  var=0.04517822340130806
Task 5 Ep 5 L=0.584222674369812 NLL=-1.034487247467041 + b*KL=1.618709921836853)  KL/NLL=-1.5647461346674798  RMSE=0.21763012651003202  var=0.05250629782676697
Task 5 Ep 10 L=0.7078524827957153 NLL=-1.0946104526519775 + b*KL=1.8024629354476929)  KL/NLL=-1.6466706772996358  RMSE=0.2168151756100267  var=0.04870166629552841
Task 5 Ep 15 L=1.1231489181518555 NLL=-0.7975530624389648 + b*KL=1.9207019805908203)  KL/NLL=-2.4082435025917888  RMSE=0.22311486200689573  var=0.04626486822962761
Task 5 Ep 20 L=0.9001696109771729 NLL=-1.0887397527694702 + b*KL=1.988909363746643)  KL/NLL=-1.8267996173441596  RMSE=0.21691915139621612  var=0.04518156498670578
Task 5 Ep 25 L=0.43094682693481445 NLL=-1.5837805271148682 + b*KL=2.0147273540496826)  KL/NLL=-1.2721000918731202  RMSE=0.20642820857041758  var=0.044549789279699326
Task 5 Ep 30 L=0.7264542579650879 NLL=-1.331638574600

Task 6 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 6 Ep 0 L=6.581894874572754 NLL=6.581894874572754 + b*KL=0.0)  KL/NLL=0.0  RMSE=0.33529726276457944  var=0.04272717982530594
Task 6 Ep 5 L=1.0254392623901367 NLL=-0.6334799528121948 + b*KL=1.6589192152023315)  KL/NLL=-2.6187398793567573  RMSE=0.22710994482871089  var=0.050977371633052826
Task 6 Ep 10 L=0.6484266519546509 NLL=-1.1330074071884155 + b*KL=1.7814340591430664)  KL/NLL=-1.5723057482596137  RMSE=0.21602662180938642  var=0.047677651047706604
Task 6 Ep 15 L=0.6327153444290161 NLL=-1.2311619520187378 + b*KL=1.863877296447754)  KL/NLL=-1.5139172335463682  RMSE=0.2139390241260494  var=0.04548844322562218
Task 6 Ep 20 L=0.17220795154571533 NLL=-1.7493975162506104 + b*KL=1.9216054677963257)  KL/NLL=-1.098438433772788  RMSE=0.20284195521979978  var=0.044445086270570755
Task 6 Ep 25 L=0.6204533576965332 NLL=-1.3320157527923584 + b*KL=1.9524691104888916)  KL/NLL=-1.4658003153460097  RMSE=0.2117521402790546  var=0.04359719529747963
Task 6 Ep 30 L=0.09010601043701172 NLL=-1.8947628736

Task 7 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 7 Ep 0 L=6.518470764160156 NLL=6.518470764160156 + b*KL=0.0)  KL/NLL=0.0  RMSE=0.3324276461550422  var=0.04188117757439613
Task 7 Ep 5 L=0.04267013072967529 NLL=-1.5039385557174683 + b*KL=1.5466086864471436)  KL/NLL=-1.028372256677281  RMSE=0.2075315703845403  var=0.048275452107191086
Task 7 Ep 10 L=0.37517476081848145 NLL=-1.3087584972381592 + b*KL=1.6839332580566406)  KL/NLL=-1.2866646227017464  RMSE=0.2122776481022966  var=0.045677173882722855
Task 7 Ep 15 L=0.42648768424987793 NLL=-1.3448235988616943 + b*KL=1.7713112831115723)  KL/NLL=-1.3171328080581512  RMSE=0.21151168522896246  var=0.04408686235547066
Task 7 Ep 20 L=0.4025275707244873 NLL=-1.433563232421875 + b*KL=1.8360908031463623)  KL/NLL=-1.2807881519425226  RMSE=0.2096200088304095  var=0.04283541068434715
Task 7 Ep 25 L=-0.44530296325683594 NLL=-2.3136563301086426 + b*KL=1.8683533668518066)  KL/NLL=-0.8075327966984942  RMSE=0.19094901194313904  var=0.04241352155804634
Task 7 Ep 30 L=0.026658296585083008 NLL=-1.87276768

Task 8 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 8 Ep 0 L=6.749584197998047 NLL=6.749584197998047 + b*KL=0.0)  KL/NLL=0.0  RMSE=0.3325406404387076  var=0.040786515921354294
Task 8 Ep 5 L=0.37208282947540283 NLL=-1.2757091522216797 + b*KL=1.6477919816970825)  KL/NLL=-1.2916674453792318  RMSE=0.2128297771325038  var=0.04787399247288704
Task 8 Ep 10 L=0.7706296443939209 NLL=-1.0561573505401611 + b*KL=1.826786994934082)  KL/NLL=-1.729654197833107  RMSE=0.2175185235492165  var=0.04455837234854698
Task 8 Ep 15 L=0.6009624004364014 NLL=-1.3329700231552124 + b*KL=1.9339324235916138)  KL/NLL=-1.450844647664237  RMSE=0.2116567851811301  var=0.042753349989652634
Task 8 Ep 20 L=0.5095998048782349 NLL=-1.4795074462890625 + b*KL=1.9891072511672974)  KL/NLL=-1.3444388239859324  RMSE=0.20861411167923907  var=0.04186514392495155
Task 8 Ep 25 L=0.6037907600402832 NLL=-1.4121372699737549 + b*KL=2.015928030014038)  KL/NLL=-1.4275722855551465  RMSE=0.2098656205865987  var=0.04120565950870514
Task 8 Ep 30 L=0.4483994245529175 NLL=-1.5966895818710327 

Task 9 learning:   0%|          | 0/50 [00:00<?, ?it/s]

Task 9 Ep 0 L=6.841977119445801 NLL=6.841977119445801 + b*KL=0.0)  KL/NLL=0.0  RMSE=0.33086240992980037  var=0.0397176519036293
Task 9 Ep 5 L=0.4975717067718506 NLL=-1.0448216199874878 + b*KL=1.5423933267593384)  KL/NLL=-1.4762264651240746  RMSE=0.2179483689016608  var=0.048338983207941055
Task 9 Ep 10 L=0.6079267263412476 NLL=-1.0729718208312988 + b*KL=1.6808985471725464)  KL/NLL=-1.5665821921308691  RMSE=0.21733416172564002  var=0.04637938365340233
Task 9 Ep 15 L=0.07655990123748779 NLL=-1.7135663032531738 + b*KL=1.7901262044906616)  KL/NLL=-1.0446786920892062  RMSE=0.20366245227980878  var=0.04420487955212593
Task 9 Ep 20 L=0.07308101654052734 NLL=-1.7879360914230347 + b*KL=1.861017107963562)  KL/NLL=-1.040874512736281  RMSE=0.20224469238267287  var=0.04287329688668251
Task 9 Ep 25 L=0.08788645267486572 NLL=-1.8152642250061035 + b*KL=1.9031506776809692)  KL/NLL=-1.0484152397563888  RMSE=0.2017489381791369  var=0.04205869883298874
Task 9 Ep 30 L=0.3837013244628906 NLL=-1.545927882194