In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import pickle

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load this file if it comes from a trusted source.
with open('../data/HeadsOrTails_data.pkl', 'rb') as pkl_file:
    data = pickle.load(pkl_file)

# train / test splits
data_tr, data_te = data['data_tr'], data['data_te']
# mean / std used to (de-)normalise the inputs x and the outputs y
x_m, x_std = data['x_m'], data['x_std']
y_m, y_std = data['y_m'], data['y_std']

In [None]:
# The test set is only used for evaluation, so it is not shuffled: this keeps evaluation and plots
# reproducible. batch_size=100000 presumably makes the whole test set fit in a single batch.
train_loader = torch.utils.data.DataLoader(data_tr, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(data_te, batch_size=100000, shuffle=False)

# Define loss functions

In [None]:
def KL_loss(logsig_i, mu_i, logsig_f=torch.tensor(0.0), mu_f=torch.tensor(0.0)):
    """KL divergence KL(N_i || N_f) between two diagonal Gaussians, averaged over the batch.

    N_i = N(mu_i, exp(logsig_i)**2) and N_f = N(mu_f, exp(logsig_f)**2), element-wise.
    With the default logsig_f = mu_f = 0, N_f is the standard normal distribution.

    Parameters
    ----------
    logsig_i, mu_i : torch tensor, [..., d]
        log standard deviation and mean of the first distribution
    logsig_f, mu_f : torch tensor, broadcastable to [..., d], optional
        log standard deviation and mean of the second distribution, by default 0

    Returns
    -------
    torch tensor, scalar
        KL divergence summed over the last (latent) axis and averaged over all other entries
    """
    # Work with the log-variance ratio instead of exp(.)**2 / exp(.)**2 so that large
    # log-sigmas do not overflow to inf/inf = nan.
    log_var_ratio = 2*(logsig_i - logsig_f)
    temp = -log_var_ratio - 1 + torch.exp(log_var_ratio) + (mu_f - mu_i)**2 * torch.exp(-2*logsig_f)
    temp = 0.5*torch.sum(temp, dim=-1)
    return torch.mean(temp)

def MSELoss_allTargets(pred, target, return_indices=False):
    """MSE against whichever of several possible targets is closest, per data point.

    For every data point the squared error is summed over the predicted features and computed for
    each alternative target; only the smallest of these is kept.

    Parameters
    ----------
    pred : torch tensor, [N, f]
        predictions, N = batch size, f = nr of predicted features
    target : torch tensor, [N, f, a]
        all possible targets, a = nr of possible alternatives per data point
    return_indices : bool, optional
        whether to also return which alternative was closest, by default False

    Returns
    -------
    MSE : torch tensor, scalar
        summed minimal squared errors divided by the number of predicted values (N * f)
    indices : torch tensor, [N,]
        index of the closest alternative per data point (only if return_indices is True)
    """
    # squared error per alternative, summed over the feature axis -> [N, a]
    sq_err = ((pred[..., None] - target) ** 2).sum(dim=1)
    # keep only the closest alternative for every data point -> [N]
    closest = sq_err.min(dim=-1)
    mse = closest.values.sum() / pred.numel()
    return (mse, closest.indices) if return_indices else mse

# Choose device

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('cuda available' if device.type == 'cuda' else 'cuda not available')

# Define model

In [None]:
class MyModel(torch.nn.Module):
    """Regressor with a Gaussian latent variable z.

    * ``mlp1`` maps x to (log sigma, mu) of a latent Gaussian conditioned on the input only.
    * ``mlp2`` maps (x, y) to (log sigma, mu) of a latent Gaussian conditioned on input and output.
    * ``mlp3`` decodes (x, z) into a prediction of y.

    Parameters
    ----------
    latent_dim : int
        dimension of the latent variable z
    x_dim : int, optional
        nr of input features, by default 1
    y_dim : int, optional
        nr of output features, by default 1
    hidden_dim : int, optional
        width of the hidden layers of all three MLPs, by default 32
    """
    def __init__(self, latent_dim, x_dim=1, y_dim=1, hidden_dim=32):
        super(MyModel, self).__init__()
        self.latent_dim = latent_dim

        # each encoder outputs log sigma and mu, hence 2*latent_dim outputs
        self.mlp1 = self._make_mlp(x_dim, hidden_dim, latent_dim*2)
        self.mlp2 = self._make_mlp(x_dim + y_dim, hidden_dim, latent_dim*2)
        self.mlp3 = self._make_mlp(latent_dim + x_dim, hidden_dim, y_dim)

    @staticmethod
    def _make_mlp(in_dim, hidden_dim, out_dim):
        """Three linear layers with LeakyReLU activations in between."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(hidden_dim, out_dim),
        )

    def _split(self, logsig_mu):
        """Split an encoder output along its last axis into (log sigma, mu)."""
        return logsig_mu[..., :self.latent_dim], logsig_mu[..., self.latent_dim:]

    @staticmethod
    def _sample(logsig, mu):
        """Reparameterised sample z = eps * exp(logsig) + mu with eps ~ N(0, 1)."""
        return torch.randn_like(logsig)*torch.exp(logsig) + mu

    def forward(self, x, y, cond_on_final=False):
        """Predict y from x by sampling a latent variable z.

        Parameters
        ----------
        x : torch tensor, [..., x_dim]
            input
        y : torch tensor, [..., y_dim]
            target; only used when cond_on_final is True
        cond_on_final : bool, optional
            if True, z is sampled from the distribution predicted from (x, y) instead of
            from x alone, by default False

        Returns
        -------
        tuple
            (logsig_i, mu_i, logsig_f, mu_f, pred) if cond_on_final else (pred, );
            the prediction is always the last element
        """
        logsig_i, mu_i = self._split(self.mlp1(x))
        if cond_on_final:
            # latent distribution conditioned on both input and output; z is only drawn from it,
            # so no (unused) sample from the x-only distribution is taken
            logsig_f, mu_f = self._split(self.mlp2(torch.cat((x, y), dim=-1)))
            z = self._sample(logsig_f, mu_f)
        else:
            z = self._sample(logsig_i, mu_i)

        pred = self.mlp3(torch.cat((x, z), dim=-1))

        if cond_on_final:
            return logsig_i, mu_i, logsig_f, mu_f, pred
        return (pred, )

In [None]:
# latent_dim=2 keeps the latent space plottable further below; a larger value would be more useful in practice
model = MyModel(latent_dim=2)
model = model.to(device=device)
print(model)
# total number of scalar weights over all parameter tensors
n_params = sum(map(torch.numel, model.parameters()))
print('Total nr of parameters:', n_params)

# Evaluate model before training

In [None]:
def _predict_on_loader(model, loader, cond_on_final):
    """Run `model` once on every batch of `loader` and collect inputs, targets and predictions.

    Column 0 of each batch is the input x and column 1 the target y. Returns three numpy arrays
    (x, y, prediction), concatenated over all batches and aligned row by row.
    """
    xs, ys, preds = [], [], []
    with torch.no_grad():  # evaluation only, no gradients needed
        for batch in loader:
            batch = batch.to(device=device)
            x, y = batch[:, [0]], batch[:, [1]]
            pred = model(x, y, cond_on_final=cond_on_final)[-1]
            xs.append(x.cpu().numpy())
            ys.append(y.cpu().numpy())
            preds.append(pred.cpu().numpy())
    return np.concatenate(xs), np.concatenate(ys), np.concatenate(preds)


model.eval()
x_np, y_np, pred_np = _predict_on_loader(model, test_loader, cond_on_final=False)
# take x from the same pass as the conditioned prediction so rows stay aligned even if the loader shuffles
x_cond_np, _, pred_cond_np = _predict_on_loader(model, test_loader, cond_on_final=True)
print('Shapes:', pred_np.shape, y_np.shape)

fig, ax = plt.subplots()
ax.scatter(x_np, y_np, s=1, label='Real')
ax.scatter(x_np, pred_np, s=1, label='Predicted')
ax.scatter(x_cond_np, pred_cond_np, s=1, label='Predicted, cond. on output')
ax.set_xlabel('Bet (€)')
ax.set_ylabel('Winnings (€)')
ax.set_title('Real and predicted before training\n(to check initialization of network)')
ax.legend()
ax.set_aspect('equal')

# Train

In [None]:
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

losses = {
            # on training data, while conditioning on output:
            'MSE_train': [],
            'KL_train': [],
            'loss_train': [],

            # on validation data, without conditioning on output:
            'MSE_val': [],
            'MSE_val_allTargets': [],  # MSE compared to closest target

            # on validation data, but still conditioned on output:
            'KL_val2': [],
            'MSE_val2': []
        }

n_epochs = 1000
KL_increase_range = 50  # nr of epochs over which we increase the weight of the KL loss
early_stop_factor = 0.995  # 5-epoch mean loss must fall below this fraction of the previous 5-epoch mean


def _raise_if_nan(pred, stage):
    """Raise a ValueError if `pred` contains a NaN; `stage` names the phase in the message."""
    if torch.isnan(pred).any():
        raise ValueError(f'nan value in {stage} prediction')


def _record_epoch(losses, losses_temp, keys):
    """Append the epoch mean of every key in `keys` from `losses_temp` to `losses` and print it."""
    for key in keys:
        losses[key].append(np.mean(losses_temp[key]))
        print(f'\t{key:10} {losses[key][-1]:8.4}')


def _not_improving(history, factor):
    """True if the mean of the last 5 entries is above `factor` times the mean of the 5 before."""
    return np.mean(history[-5:]) > factor*np.mean(history[-10:-5])


for epoch in range(n_epochs):
    print(f'epoch {epoch}')

    # ====================== 1) TRAIN ======================
    # KL weight ramps linearly from 0 to 0.5 during the first KL_increase_range epochs
    KL_weight = np.clip(epoch/KL_increase_range, a_min=0.0, a_max=1.0)/2
    print(' \t KL_weight:', KL_weight)
    losses_temp = {key: [] for key in losses}

    model.train()
    for batch in train_loader:
        batch = batch.to(device=device)
        optimizer.zero_grad()

        out = model(batch[..., [0]], batch[..., [1]], cond_on_final=True)
        pred = out[-1]
        _raise_if_nan(pred, 'train')

        MSE_loss = F.mse_loss(pred, batch[..., [1]])
        # out[:-1] = (logsig_i, mu_i, logsig_f, mu_f), so this is KL(N_i || N_f) with N_i predicted
        # from x alone and N_f predicted from (x, y).
        # NOTE(review): the usual CVAE objective is KL(posterior || prior), i.e. KL(N_f || N_i) --
        # confirm the intended direction.
        KL_loss1 = KL_loss(*out[:-1])

        loss = (1.0-KL_weight)*MSE_loss + KL_weight*KL_loss1
        loss.backward()
        optimizer.step()

        losses_temp['MSE_train'].append(MSE_loss.item())
        losses_temp['KL_train'].append(KL_loss1.item())
        losses_temp['loss_train'].append(loss.item())

    _record_epoch(losses, losses_temp, ['MSE_train', 'KL_train', 'loss_train'])

    model.eval()
    # Evaluation only: disable autograd so no graph is built over the test batches.
    with torch.no_grad():
        # =============== 2) EVALUATE WITHOUT CONDITIONING ON OUTPUT ===============
        # Not conditioned on output, so the KL loss cannot be calculated here.
        for batch in test_loader:
            batch = batch.to(device=device)
            pred = model(batch[..., [0]], batch[..., [1]], cond_on_final=False)[-1]
            _raise_if_nan(pred, 'validation')

            MSE_loss = F.mse_loss(pred, batch[..., [1]])
            # columns 1: are treated as the possible targets; compare with the closest one
            y_all = batch[..., 1:]
            MSE_loss_allTargets = MSELoss_allTargets(pred, y_all.unsqueeze(1))

            losses_temp['MSE_val'].append(MSE_loss.item())
            losses_temp['MSE_val_allTargets'].append(MSE_loss_allTargets.item())

        _record_epoch(losses, losses_temp, ['MSE_val', 'MSE_val_allTargets'])

        # ================ 3) EVALUATE WITH CONDITIONING ON OUTPUT ================
        for batch in test_loader:
            batch = batch.to(device=device)
            out = model(batch[..., [0]], batch[..., [1]], cond_on_final=True)
            pred = out[-1]
            _raise_if_nan(pred, 'validation')

            MSE_loss = F.mse_loss(pred, batch[..., [1]])
            KL_loss1 = KL_loss(*out[:-1])

            losses_temp['MSE_val2'].append(MSE_loss.item())
            losses_temp['KL_val2'].append(KL_loss1.item())

        _record_epoch(losses, losses_temp, ['MSE_val2', 'KL_val2'])

    # ======== Early stopping =========
    # Stop once neither the closest-target MSE nor the KL loss on validation data decreases enough.
    if (
        epoch > KL_increase_range
        and _not_improving(losses['MSE_val_allTargets'], early_stop_factor)
        and _not_improving(losses['KL_val2'], early_stop_factor)
    ):
        print('======= !! Validation loss did not decrease enough =======')
        print('Stopping training')
        break



# Plot losses

In [None]:
fig, ax = plt.subplots()

# one curve per tracked loss, on a log scale
for name, history in losses.items():
    ax.plot(history, label=name)
ax.legend()
ax.set_yscale('log')
ax.set_ylabel('loss')
ax.set_xlabel('epoch')
plt.show()


# Compare input, output, predicted output

In [None]:
# Take a test batch explicitly instead of relying on `batch` left over from the training loop.
batch = next(iter(test_loader)).to(device)
n_show = 10  # nr of data points to print

print('============== INPUT ==============')
# Round before casting: de-normalised bets are floats (they can come out as e.g. 49.99998),
# and astype(int) alone would truncate them.
temp_x = np.rint((batch[:n_show, [0]]*x_std + x_m).cpu().numpy()).astype(int)
print(temp_x)

print('============== OUTPUTS ==============')
temp_y = (batch[:n_show, [1]]*y_std + y_m).cpu().numpy()
print(temp_y)

print('============== PREDICTIONS ==============')
model.eval()
with torch.no_grad():
    pred = model(batch[:n_show, [0]], batch[:n_show, [1]], cond_on_final=False)[-1]*y_std + y_m
temp_y = pred.cpu().numpy()
print(temp_y)

In [None]:
fig, ax = plt.subplots(figsize=(4,3), dpi=200)
n_samples = 10  # nr of stochastic predictions per data point

model.eval()
with torch.no_grad():  # plotting only, no gradients needed
    for batch in test_loader:
        batch_dev = batch.to(device=device)  # moved once, reused for every sample
        x = (batch[:, [0]]*x_std + x_m).cpu().numpy()
        for i in range(n_samples):
            # every forward pass draws a new latent sample, hence a different prediction
            pred = model(batch_dev[..., [0]], batch_dev[..., [1]], cond_on_final=False)[-1]*y_std + y_m
            ax.scatter(x, pred.cpu().numpy(), s=1, c='tab:orange', alpha=0.5,
                       label='predicted' if i == 0 else '_nolegend_')

        y = (batch[:, [1]]*y_std + y_m).cpu().numpy()
        ax.scatter(x.flatten(), y.flatten(), s=1, label='true')
ax.set_xlabel('bet (€)')
ax.set_ylabel('winnings (€)')
ax.set_aspect('equal')
ax.legend()
fig.subplots_adjust(left=0.2, bottom=0.2)

In [None]:
fig, ax = plt.subplots()
n_samples = 10  # nr of stochastic predictions per data point

model.eval()
with torch.no_grad():  # plotting only, no gradients needed
    for batch in test_loader:
        batch_dev = batch.to(device=device)  # moved once, reused for every sample
        x = (batch[:, [0]]*x_std + x_m).cpu().numpy()
        for i in range(n_samples):
            # latent samples come from the distribution conditioned on the true output
            pred = model(batch_dev[..., [0]], batch_dev[..., [1]], cond_on_final=True)[-1]*y_std + y_m
            ax.scatter(x, pred.cpu().numpy(), s=1, c='tab:orange', alpha=0.5,
                       label='predicted' if i == 0 else '_nolegend_')

        y = (batch[:, [1]]*y_std + y_m).cpu().numpy()
        ax.scatter(x, y, s=1, label='true')
ax.set_title('Conditioned on output')
ax.set_xlabel('bet (€)')
ax.set_ylabel('winnings (€)')
ax.set_aspect('equal')
ax.legend()

# 1 input, 2000 sampled outputs

In [None]:
%matplotlib inline
N = 2000
bet = 50.0
x = ((torch.Tensor([[bet]]*N)-x_m)/x_std).to(device)
# y is required by the forward signature but not used when cond_on_final=False
y = ((torch.Tensor([[bet]]*N)-y_m)/y_std).to(device)
model.eval()
with torch.no_grad():
    pred = model(x, y)[-1]*y_std + y_m
pred = pred.cpu().numpy()

# Bins and x-limits scale with the bet (for bet=50: bins over [-100, 100], xlim [-110, 110]).
# Predictions outside +-hist_range are not counted in the histogram.
hist_range = 2*abs(bet)
fig, ax = plt.subplots()
hist_counts, _, _ = ax.hist(pred, bins=np.linspace(-hist_range, hist_range, 50), label='predicted winnings')
ax.set_xlabel('predicted winnings (€)')
ax.set_ylabel('nr of predictions')
ax.set_title(f'Predicted winnings for a bet of €{bet}\n(latent space dim = {model.latent_dim})')
ax.vlines([bet, -bet], 0, np.max(hist_counts), colors='r', linestyles='dashed', label='possible ground truth winnings')
ax.legend()
ax.set_xlim(-1.1*hist_range, 1.1*hist_range)
plt.show()

# Plot latent space
This section requires `latent_dim=2`, so that the latent space can be plotted in 2D.

In [None]:
# the latent-space plots below draw z on a 2D plane
if model.latent_dim != 2:
    raise ValueError('Plotting latent space only works for latent_dim=2')

In [None]:
import matplotlib.patches as mpatch

In [None]:
fig, ax = plt.subplots(figsize=(10,10))
N = 300
bet = 50.0
x = ((torch.Tensor([[bet]]*N)-x_m)/x_std).to(device)
y = ((torch.Tensor([[bet]]*N)-y_m)/y_std).to(device)
# normalised -bet (note: -y equals this only when y_m == 0)
y_neg = ((torch.Tensor([[-bet]]*N)-y_m)/y_std).to(device)
z_lim = 3.0   # contour grid covers [-z_lim, z_lim] in both latent dimensions
n_grid = 100  # grid points per latent dimension


def _split_logsig_mu(logsig_mu):
    """Split an encoder output along its last axis into (log sigma, mu)."""
    return logsig_mu[..., :model.latent_dim], logsig_mu[..., model.latent_dim:]


def _to_np(t):
    """Detached CPU numpy copy of a tensor (matplotlib cannot consume CUDA tensors)."""
    return t.detach().cpu().numpy()


with torch.no_grad():
    # (legend suffix, colour, logsig, mu) of the three latent Gaussians: predicted from x only,
    # from (x, bet) and from (x, -bet)
    dists = [
        ('', 'black', *_split_logsig_mu(model.mlp1(x))),
        (f', conditioned on {bet}', 'tab:blue', *_split_logsig_mu(model.mlp2(torch.cat((x, y), dim=-1)))),
        (f', conditioned on {-bet}', 'tab:orange', *_split_logsig_mu(model.mlp2(torch.cat((x, y_neg), dim=-1)))),
    ]

    # Sample via reparameterisation: z = eps*exp(logsig) + mu
    zs = [torch.randn_like(logsig)*torch.exp(logsig) + mu for _, _, logsig, mu in dists]

    # 1-sigma ellipse centred at mu; all N rows share the same input, so row 0 is representative
    for _, colour, logsig, mu in dists:
        sig = np.exp(_to_np(logsig[0]))
        ax.add_patch(mpatch.Ellipse(_to_np(mu[0]), width=2*sig[0], height=2*sig[1],
                                    fc='None', edgecolor=colour))

    # plot samples
    for (suffix, colour, _, _), z in zip(dists, zs):
        z_np = _to_np(z)
        ax.scatter(z_np[:, 0], z_np[:, 1], s=1, label=f'latent space samples{suffix}', c=colour)

    # plot means
    for suffix, colour, _, mu in dists:
        mu_np = _to_np(mu)
        ax.scatter(mu_np[:, 0], mu_np[:, 1], s=50, label=f'latent space means{suffix}', c=colour, marker='x')

    # contour plot of the decoder prediction over a (z1, z2) grid at the same bet
    z1_grid, z2_grid = np.meshgrid(
        np.linspace(-z_lim, z_lim, n_grid),
        np.linspace(-z_lim, z_lim, n_grid))
    Z1 = torch.tensor(z1_grid).float().to(device).reshape(-1, 1)
    Z2 = torch.tensor(z2_grid).float().to(device).reshape(-1, 1)
    x_grid = ((torch.Tensor([[bet]]*len(Z1))-x_m)/x_std).to(device)
    pred = _to_np(model.mlp3(torch.cat((x_grid, Z1, Z2), dim=-1))*y_std + y_m)

    contours = ax.contourf(z1_grid, z2_grid, pred.reshape(n_grid, n_grid),
                           levels=[-100,-75,-50,-25,0,25,50,75,100], cmap='viridis', alpha=0.5, zorder=-10)
    fig.colorbar(contours, ax=ax)
    ax.legend()
    ax.set_aspect('equal')
    ax.set_xlabel('$z_1$')
    ax.set_ylabel('$z_2$')
plt.show()

# Calculate Wasserstein distance

In [None]:
from scipy.stats import wasserstein_distance

n_samples = 100  # nr of stochastic predictions per data point

model.eval()
wd = []  # one distance per test data point, collected over all batches
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)

        preds = np.zeros((len(batch), n_samples))
        for i in range(n_samples):
            pred = model(batch[:, [0]], batch[:, [1]], cond_on_final=False)[-1]
            preds[:, i] = pred.cpu().numpy()[:, 0]*y_std+y_m

        # columns 1: of each row, de-normalised, form the "real" distribution for that data point
        reals = batch[:, 1:].cpu().numpy()*y_std + y_m

        for pred_samples, real in zip(preds, reals):  # iterate over all data points (as far as I know, this cannot be batched)
            wd.append(wasserstein_distance(pred_samples, real))

print('Mean Wasserstein distance between real and predicted:', np.mean(wd))