In [None]:
import torch
import numpy as np
from tqdm import tqdm
from model import SetTransformer
from input_pipeline import get_random_datasets
from train import get_parameters, compute_groundtruth, LogLikelihood

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Run configuration: everything a reader might want to tune lives here.
# Fall back to the CPU so the notebook still runs on machines without a GPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
CHECKPOINT = 'models/run00_iteration_20000.pth'  # trained weights to evaluate
BATCH_SIZE = 1024    # datasets sampled per evaluation batch
NUM_BATCHES = 10000  # evaluation batches averaged over
K = 4                # number of clusters / mixture components (model head has 5 * K outputs)
MIN_SIZE = 100       # presumably the min number of points per sampled dataset — confirm in get_random_datasets
MAX_SIZE = 500       # presumably the max number of points per sampled dataset
SEED = 0             # fixes the sampled datasets so reported numbers are reproducible

# Seed the RNGs the data sampler may draw from so evaluation results are repeatable.
# NOTE(review): assumes get_random_datasets uses numpy and/or torch RNGs — confirm.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the model

In [None]:
# Recreate the trained architecture (2-D input points, 5 * K output values),
# switch it to inference mode on DEVICE, then restore the trained weights.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model = (
    SetTransformer(in_dimension=2, out_dimension=5 * K)
    .eval()
    .to(DEVICE)
)
model.load_state_dict(torch.load(CHECKPOINT, map_location=DEVICE))

# Evaluate the model

In [None]:
# Average the criterion over NUM_BATCHES freshly sampled batches, scoring both
# the mixture predicted by the model and the true generating parameters.
# The results are stored in `loss` and `true_loss`.
criterion = LogLikelihood()

model_scores = []
groundtruth_scores = []

for _ in tqdm(range(NUM_BATCHES)):
    x, params = get_random_datasets(BATCH_SIZE, K, MIN_SIZE, MAX_SIZE)
    x = x.to(DEVICE)

    with torch.no_grad():
        # Score the mixture parameters predicted by the network.
        y = model(x)
        means, variances, pis = get_parameters(y)
        z = criterion(x, means, variances, pis)
        model_scores.append(z.item())

        # Score the ground-truth parameters on the very same batch for reference.
        z = compute_groundtruth(x, params, criterion)
        groundtruth_scores.append(z.item())

# Every batch has the same size, so the plain mean of per-batch values is the average.
loss = sum(model_scores) / NUM_BATCHES
true_loss = sum(groundtruth_scores) / NUM_BATCHES

In [None]:
# Average criterion value for the model's prediction, then for the ground-truth parameters.
# Reference values from the author's run: 1.47324, 1.47486
print(round(loss, 5), round(true_loss, 5), sep=', ')

# Visualize predictions

In [None]:
# Sample a single dataset, predict its mixture parameters, and compare them
# with the true generating parameters, both numerically and on a scatter plot.
data, params = get_random_datasets(1, K, MIN_SIZE, MAX_SIZE)
# Keep the only dataset in the batch and convert it to numpy for printing and plotting.
params = {name: value[0].numpy() for name, value in params.items()}

with torch.no_grad():
    predicted = get_parameters(model(data.to(DEVICE)).cpu())
# Drop the batch dimension from each predicted parameter tensor.
means, variances, pis = (tensor[0].numpy() for tensor in predicted)
data = data[0].numpy()

# NOTE(review): predicted components are not matched to the true ones, so the
# rows below may appear in a different order between "true" and "predicted".
print('true probabilities of belonging to different clusters:')
print(params['pis'], '\n')

print('predicted probabilities of belonging to different clusters:')
print(pis, '\n')

print('true std of gaussians:')
print(np.sqrt(params['variances']), '\n')

print('predicted std of gaussians:')
print(np.sqrt(variances))

fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], alpha=0.3, label='data points')
ax.scatter(params['means'][:, 0], params['means'][:, 1], c='r', label='true means')
ax.scatter(means[:, 0], means[:, 1], c='b', label='predicted means')
ax.set(title='Mixture means: ground truth vs. prediction', xlabel='dimension 0', ylabel='dimension 1')
ax.legend();