In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import torch.nn.functional as F
from torch import nn
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import matplotlib

import utils
from BayesModel import BFC, FC

import cProfile
import pstats
device = torch.device('cpu')

In [45]:
# Notebook-wide matplotlib styling: large text and a 10x6 inch default canvas.
# 'normal' is not a font family matplotlib can resolve (it only triggers
# "findfont: Font family ['normal'] not found" warnings before falling back to the
# default), so the generic 'sans-serif' family is requested explicitly instead.
font = {'family': 'sans-serif',
        'weight': 'normal',
        'size': 22}
matplotlib.rc('font', **font)
plt.rcParams["figure.figsize"] = (10, 6)

In [60]:
def plot_model_pred(model, Y_test, X_test, enums=10, single=False):
    Y_sample = Y_test.detach().numpy()

    # sample test data enums times and average to make a prediction
    Y_pred = torch.zeros(enums, *Y_test.size())
    for j in range(enums):
        Y_pred[j] = model(X_test)

    if single:  # Plots all the predictions of a single data point
        y1 = Y_pred.detach().numpy()[:, 1]
        Y = Y_sample[1]
        pred_mean = y1.mean()
        x = np.linspace(0, 1, len(y1))
        plt.plot(x, y1, "bo", label="all preds for single datapoint")
        plt.plot(x, np.ones_like(x) * Y_sample[1], "r--", lw=5, label="true value")
        plt.plot(x, np.ones_like(x) * pred_mean, "k--", lw=4, label="prediction mean")
        plt.plot(x, pred_mean + np.ones_like(x) * y1.std(), "g--", lw=4, label="prediction std")
        plt.plot(x, pred_mean - np.ones_like(x) * y1.std(), "g--", lw=4)
        plt.xlabel("Prediction count")
        plt.ylabel("Value")
        plt.title("All predictions for a single datapoint after training")
    else:  # Plot mean of all predictions for 40 datapoints, with error bounds
        y1 = Y_pred.detach().numpy()
        pred_mean = y1.mean(0) 
        pred_std = y1.std(0)
        print("mean mean: ", pred_mean.mean())
        print("mean std:", pred_mean.std())
        print("std mean: ", pred_std.mean())
        print("std std:", pred_std.std())
        x = np.arange(len(Y_sample))
        xx = np.c_[x, x, x, x, x, x] if Y_test.size(1) == 3 else np.c_[x, x]
        interval = np.c_[pred_mean - pred_std, pred_mean + pred_std]
        R2 = 1 - ((Y_sample - pred_mean) ** 2).sum() / ((Y_sample - Y_sample.mean()) ** 2).sum()
    
        plt.plot(Y_sample[:40], "bo", ms=8, label="target")
        plt.plot(pred_mean[:40], "ro", ms=4, label="mean")
        plt.plot(xx[:40], interval[:40], "go", ms=4, label="1 sigma")
        plt.title(f"R2: {R2}")
        plt.xlabel("Datapoints")
        plt.ylabel("Value")
    plt.legend()
    plt.show()

def plot_model_pred_freq(model, Y_test, X_test):
    """Plot one deterministic prediction pass of ``model`` against the targets.

    ``X_test`` is cast to float32 before the forward pass. The first 100
    targets and predictions are drawn, and the R2 score computed over all of
    the test data is shown in the title.
    """
    targets = Y_test.detach().numpy()
    inputs = X_test.to(torch.float)
    preds = model(inputs).detach().numpy()

    # Coefficient of determination: 1 - residual sum of squares / total sum of squares.
    residual_ss = ((targets - preds) ** 2).sum()
    total_ss = ((targets - targets.mean()) ** 2).sum()
    R2 = 1 - residual_ss / total_ss

    shown = slice(None, 100)
    plt.plot(targets[shown], "bo-", ms=8, label="target")
    plt.plot(preds[shown], "ro--", ms=4, label="prediction")
    plt.title(f"R2-score on test data: {R2}")
    plt.legend()
    plt.show()

In [61]:
def train_model(*, model, optimer, data, device, epochs, enums=10, kl_on = True, noise_std=0.1):
    """Train a stochastic model on a Gaussian likelihood plus an optional KL term.

    For every mini-batch the model is sampled ``enums`` times, the samples are
    averaged, and the negative Gaussian log-likelihood of the targets under
    that average is minimised (optionally together with the model's KL term).

    Parameters
    ----------
    model : callable
        Called as ``model(x, train=True)``; must also expose ``kl_reset()``
        when ``kl_on`` is True.
    optimer : torch.optim.Optimizer
        Optimiser stepping the model parameters.
    data : iterable
        Yields ``(x, y)`` tensor pairs; iterated once per epoch.
    device : torch.device
        Device the batches and the sample buffer are placed on.
    epochs : int
        Number of passes over ``data``.
    enums : int
        Number of stochastic forward passes averaged per batch.
    kl_on : bool
        Whether to add the KL term to the loss.
    noise_std : float
        Standard deviation of the Gaussian observation model.

    Returns
    -------
    dict
        ``{"total": [...], "recon": [...], "kl": [...]}`` holding the loss of
        the LAST batch of each epoch as 0-d numpy arrays ("kl" stays empty
        when ``kl_on`` is False).

    Raises
    ------
    ValueError
        If ``data`` yields no batches.
    """
    pbar = tqdm(range(epochs))
    losses = {"total": [], "recon": [], "kl": []}
    for epoch in pbar:
        loss = None
        for batch_idx, (x, y) in enumerate(data):
            # Keep inputs, targets and the sample buffer on the requested device.
            x, y = x.to(device), y.to(device)
            optimer.zero_grad()

            # sample batch enums times during training and average prediction before calculating loss
            outs = torch.zeros(enums, *y.size(), device=device)
            for j in range(enums):
                outs[j] = model(x, train=True)

            pred = outs.mean(0)  # average over the samples is the final prediction per datapoint
            loss_recon = -torch.distributions.Normal(pred, noise_std).log_prob(y).mean()

            if kl_on:
                # len(x) is the number of samples in this mini-batch (not the number
                # of batches), so the KL is scaled by batch size times enums.
                loss_kl = model.kl_reset() / (len(x) * enums)
                loss = loss_recon + loss_kl
            else:
                loss = loss_recon

            loss.backward()
            optimer.step()

        if loss is None:
            raise ValueError("train_model: `data` yielded no batches")

        if kl_on:
            pbar.set_description(f"total loss: {loss:.4f}, recon. loss: {loss_recon:.4f}, kl_loss: {loss_kl:.4f}")
            losses["kl"].append(loss_kl.detach().cpu().numpy())
        else:
            pbar.set_description(f"total loss: {loss:.4f}, recon. loss: {loss_recon:.4f}")
        # .cpu() makes the bookkeeping work when training on a non-CPU device.
        losses["total"].append(loss.detach().cpu().numpy())
        losses["recon"].append(loss_recon.detach().cpu().numpy())

    return losses

def train_model_freq(*, model, optimer, data, device, epochs):
    """Train a deterministic model with mean-squared-error loss.

    Parameters
    ----------
    model : callable
        Called as ``model(x)`` with float32 inputs.
    optimer : torch.optim.Optimizer
        Optimiser stepping the model parameters.
    data : iterable
        Yields ``(x, y)`` tensor pairs; iterated once per epoch.
    device : torch.device
        Device the batches are moved to.
    epochs : int
        Number of passes over ``data``.

    Returns
    -------
    list
        The loss of the LAST batch of each epoch, as 0-d numpy arrays.

    Raises
    ------
    ValueError
        If ``data`` yields no batches.
    """
    losses = []
    pbar = tqdm(range(epochs))
    loss_fn = torch.nn.MSELoss()
    for epoch in pbar:
        loss = None
        for bi, (x, y) in enumerate(data):
            # Move to the target device and cast to float32 in one step.
            x = x.to(device, torch.float)
            y = y.to(device, torch.float)
            optimer.zero_grad()

            pred = model(x)
            loss = loss_fn(pred, y)

            loss.backward()
            optimer.step()

        if loss is None:
            raise ValueError("train_model_freq: `data` yielded no batches")

        # .cpu() keeps this working when the loss lives on a non-CPU device.
        losses.append(loss.detach().cpu().numpy())
        pbar.set_description(f"total loss: {loss:.4f}")
    return losses

In [62]:
def plot_losses(losses, model="FNN", save=False, show=False):
    """Plot the loss curves recorded during training.

    Parameters
    ----------
    losses : dict
        Must contain "total" and "recon" sequences; a non-empty "kl" sequence
        is drawn as a third curve.
    model : str
        Name used in the title and in the output filename.
    save : bool
        Write the figure to ``figures/carbon_<title with underscores>``.
    show : bool
        Display the figure.
    """
    plt.plot(losses["total"], label="total loss")
    plt.plot(losses["recon"], label="recon loss")
    kl_curve = losses.get("kl", [])
    # Skip an absent/empty KL curve so the legend has no entry without a line.
    if len(kl_curve) > 0:
        plt.plot(kl_curve, label="kl loss")
    title = f'{model} Losses'
    plt.title(title)
    plt.legend()
    # Save BEFORE showing: on backends that close the figure in show(), a later
    # savefig would write an empty canvas.
    if save: plt.savefig(f'figures/carbon_{title.replace(" ", "_")}', bbox_inches='tight', dpi=300)
    if show: plt.show()
    # Close instead of clf(): a cleared-but-open figure is still rendered by the
    # notebook as an empty "<Figure ... with 0 Axes>".
    plt.close()

In [114]:
def plot_pred_uncertainty(x, y, gnet, train_size, net_type, enums=500, save=False, show=False, n_points=50):
    """Plot the predictive mean and a +/- 2 sigma band of a stochastic network.

    ``gnet`` is evaluated ``enums`` times on ``x``; mean and standard deviation
    across those passes give the curve and the shaded band. The R2 score over
    all datapoints is shown in the title.

    Parameters
    ----------
    x : torch.Tensor
        Inputs forwarded unchanged to ``gnet``.
    y : torch.Tensor
        Targets; indexed as ``[:, 0]`` for the band, so a 2-D tensor is expected.
    gnet : callable
        Called as ``gnet(x)``; each result is stored in a buffer shaped like ``y``.
    train_size, net_type
        Only used to build the output filename (dots are stripped from it).
    enums : int
        Number of forward passes.
    save : bool
        Write the figure to ``figures/carbon_<net_type>_<train_size>``.
    show : bool
        Display the figure.
    n_points : int
        Maximum number of leading datapoints drawn.
    """
    # Inference only: no_grad avoids recording an autograd graph for every pass.
    with torch.no_grad():
        y_pred = torch.zeros(enums, *y.size())
        for i in tqdm(range(enums)):
            y_pred[i] = gnet(x)
    y_pred = y_pred.detach().numpy()
    y_mean = y_pred.mean(axis=0)
    y_sigma = y_pred.std(axis=0)
    # Work in numpy throughout instead of mixing a torch tensor with ndarrays.
    y_true = y.detach().cpu().numpy()

    # Clamp so the x-axis always matches the data when fewer points are available.
    shown = min(n_points, len(y_true))
    idx = np.arange(shown)

    R2 = 1 - ((y_true - y_mean) ** 2).sum() / ((y_true - y_true.mean()) ** 2).sum()
    plt.plot(idx, y_mean[:shown], 'ro', lw=1, label='Predictive mean')
    plt.plot(idx, y_true[:shown], "bo", ms=4, lw=1, label='Target value')
    plt.fill_between(idx,
                    (y_mean + 2 * y_sigma)[:shown, 0],
                    (y_mean - 2 * y_sigma)[:shown, 0],
                    alpha=0.5, label='Epistemic uncertainty')
    plt.title(f'Prediction with R2-score on test data: {R2:.4f}')
    plt.ylim([0, 1])
    plt.xlabel("data-points")
    plt.ylabel("coordinate value")
    # plt.legend()
    # Save BEFORE showing: on backends that close the figure in show(), a later
    # savefig would write an empty canvas.
    if save:
        filepath = f'figures/carbon_{net_type}_{train_size}'.replace(".", "")
        plt.savefig(filepath, bbox_inches='tight', dpi=300)
    if show: plt.show()
    # Close instead of clf(): a cleared-but-open figure is still rendered by the
    # notebook as an empty "<Figure ... with 0 Axes>".
    plt.close()

def plot_pred_freq(x, y, fnet, train_size, params, save=False, show=False):
    """Plot the predictions of a deterministic network against the targets.

    The first 50 predictions and targets are drawn and the R2 score over all
    datapoints is shown in the title.

    Parameters
    ----------
    x : torch.Tensor
        Inputs; cast to float32 before the forward pass.
    y : torch.Tensor
        Target values.
    fnet : callable
        Called as ``fnet(x)``.
    train_size
        Only used in the output filename (dots are stripped from it).
    params : sequence
        ``len(params[0])`` is embedded in the output filename.
    save : bool
        Write the figure to ``figures/carbon_FCNET_ts<train_size>_l<len(params[0])>``.
    show : bool
        Display the figure.
    """
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        y_pred = fnet(x.to(torch.float)).detach().cpu().numpy()
    # Work in numpy throughout instead of mixing a torch tensor with ndarrays.
    y_true = y.detach().cpu().numpy()

    R2 = 1 - ((y_true - y_pred) ** 2).sum() / ((y_true - y_true.mean()) ** 2).sum()

    plt.plot(y_pred[:50], 'ro', lw=1, label='Predicted value')
    plt.plot(y_true[:50], "bo", ms=4, lw=1, label='Target value')
    plt.title(f'Prediction with R2-score on test data: {R2:.4f}')
    plt.ylim([0, 1])
    plt.legend(loc="lower right")
    plt.xlabel("data-points")
    plt.ylabel("coordinate value")
    # Save BEFORE showing: on backends that close the figure in show(), a later
    # savefig would write an empty canvas.
    if save:
        filepath = f'figures/carbon_FCNET_ts{train_size}_l{len(params[0])}'.replace(".", "")
        plt.savefig(filepath, bbox_inches='tight', dpi=300)
    if show: plt.show()
    # Close instead of clf(): a cleared-but-open figure is still rendered by the
    # notebook as an empty "<Figure ... with 0 Axes>".
    plt.close()

In [121]:
# Experiment sweep: for every data configuration, train a BFC with a normal prior
# and save its predictive-uncertainty figure for the held-out data.
# NOTE(review): no random seed is set anywhere visible, so repeated runs will
# presumably give different numbers — confirm whether utils/BayesModel seed internally.

# Each tuple is (test_size, batch_size), forwarded to utils.get_nanotube_data below.
data_params = [(0.2, 200), (0.99, 10)]
# Only referenced by the disabled FC baseline below: params[0] is the list of hidden
# layer sizes; the second entry is not used by any visible code
# (presumably an epoch count — TODO confirm).
fc_params = [([20, 20], 1000), ([50, 50, 50, 50, 50, 50], 5000)]

for test_s, batch_s in data_params:
    # Returns a training loader plus a (X_test, Y_test) pair.
    train_loader, test_data = utils.get_nanotube_data(test_size=test_s, batch_size=batch_s)
    X_test, Y_test = test_data

    # Disabled experiment: deterministic FC baseline, one run per fc_params entry.
    # for params in fc_params:
    #     print("FC", test_s, params)
    #     fnet = FC(features=7, classes=1, hiddens=params[0])
    #     foptimizer = torch.optim.AdamW(fnet.parameters(), lr=0.001)
    #     losses = train_model_freq(model=fnet, optimer=foptimizer, data=train_loader, device=device, epochs=15)
    #     plot_pred_freq(X_test, Y_test, fnet, test_s, params, save=True, show=False)
    
    # Disabled experiment: BFC with a "vmf" prior.
    # print("VMF", test_s)
    # vnet = BFC(features=7, classes=1, hiddens=[20,20], prior={"dist": "vmf", "loc": 1, "scale": .1, "record": False, "dist_kwargs": {"k": 100}})
    # voptimizer = torch.optim.AdamW(vnet.parameters(), lr=.001)
    # losses = train_model(model=vnet, optimer=voptimizer, data=train_loader, device=device, epochs=15, enums=10, kl_on=True)
    # plot_pred_uncertainty(X_test, Y_test, vnet, test_s, 'VmfNET_highlr', enums=500, save=True, show=False)
        
    # Active experiment: BFC with a normal prior (loc 0, scale 0.1), two hidden layers of 20.
    print('GAUSS', test_s)
    gnet = BFC(features=7, classes=1, hiddens=[20, 20], prior={"dist": "normal", "loc": 0, "scale": .1, "record": False})
    goptimizer = torch.optim.AdamW(gnet.parameters(), lr=0.001)
    # The returned loss history is kept in `losses` but not plotted in this cell.
    losses = train_model(model=gnet, optimer=goptimizer, data=train_loader, device=device, epochs=15, enums=10, kl_on=True)
    # test_s is passed as the `train_size` argument, which only ends up in the saved
    # filename; show=False means the figure is written to disk and not displayed.
    plot_pred_uncertainty(X_test, Y_test, gnet, test_s, 'GaussNET_redo', enums=500, save=True, show=False)


GAUSS 0.2


  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

GAUSS 0.99


  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

<Figure size 720x432 with 0 Axes>