In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
import os
import re
from pathlib import Path

import torch
from tqdm import tqdm

from model import ReachabilityClassifierLSTM as Model
from dataset import SingleDataset
import plotly.graph_objects as go

#val in name
dataset = SingleDataset(Path(os.getcwd()).parent / 'data' / 'test', 100000, False, True)

points = dataset.data[:, 1:4].cpu()
true_labels = dataset.data[:, 0].cpu() != -1

In [32]:
model_id = 633
model_dir = Path(os.getcwd()).parent / "trained_models"
pattern = rf"{re.escape("reachability_classifier")}_[a-z]+-[a-z]+-{model_id}"
folder = next((f for f in model_dir.iterdir() if re.match(pattern, f.name)), None)

In [33]:
#model = Model(160, {"nhead": 8, "dim_feedforward": 640, "dropout": 0.1}, 4, 160,
#              {"n_heads": 8, "mlp_dim": 640, "dropout": 0.1}, 4).to("cuda")
model = Model(
         encoder_config={'width': 512, 'depth': 1},
         decoder_config={'width': 128, 'depth': 4},
         output_func=torch.nn.Sigmoid()
    ).to("cuda")
model.load_state_dict(torch.load(next(Path(model_dir / folder).glob('*.pth')), map_location="cuda"))

<All keys matched successfully>

In [34]:
model.eval()
pred_labels = torch.empty((0,), dtype=torch.long).to("cuda")
true_positives = 0
true_negatives = 0
false_positives = 0
false_negatives = 0
for i, (morphs, poses, labels) in enumerate(tqdm(dataset, desc=f"Evaluation")):
    morphs = morphs.to("cuda", non_blocking=True)
    poses = poses.to("cuda", non_blocking=True)
    labels = labels.to("cuda", non_blocking=True)
    with torch.no_grad():
        pred = model(poses, morphs)
    pred_label = pred > 0.5
    true_label = labels != -1
    true_positives += (pred_label & true_label).sum().item()
    true_negatives += (~pred_label & ~true_label).sum().item()
    false_positives += (pred_label & ~true_label).sum().item()
    false_negatives += (~pred_label & true_label).sum().item()
    pred_labels = torch.cat([pred_labels, pred_label], dim=0)
pred_labels = pred_labels.cpu().squeeze(dim=1)
true_positives /= dataset.data.shape[0]
true_negatives /= dataset.data.shape[0]
false_positives /= dataset.data.shape[0]
false_negatives /= dataset.data.shape[0]
true_positives, true_negatives, false_positives, false_negatives

Evaluation: 101it [00:12,  8.29it/s]                         


(0.47898140213907137,
 0.45470758354755786,
 0.04529241645244216,
 0.021018597860928624)

In [35]:
unique_points, inverse = torch.unique(points, dim=0, return_inverse=True)

unique_pred = torch.zeros(unique_points.shape[0], dtype=torch.bool)
unique_labels = torch.zeros(unique_points.shape[0], dtype=torch.bool)
for i in range(len(unique_points)):
    mask = (inverse == i)
    unique_pred[i] = pred_labels[mask].any()
    unique_labels[i] = true_labels[mask].any()

In [36]:
true_positive = unique_labels & unique_pred
true_negative = ~unique_labels & ~unique_pred
false_positive = unique_labels & ~unique_pred
false_negative = ~unique_labels & unique_pred

In [37]:
(true_positive.sum().item() / unique_points.shape[0],
true_negative.sum().item() / unique_points.shape[0],
false_positive.sum().item() / unique_points.shape[0],
false_negative.sum().item() / unique_points.shape[0])

(0.16998714652956298,
 0.6638817480719794,
 0.002892030848329049,
 0.16323907455012854)

In [38]:
# Create scatter traces
traces = []

tp_points = unique_points[true_positive]
traces.append(go.Scatter3d(
    x=tp_points[:, 0],
    y=tp_points[:, 1],
    z=tp_points[:, 2],
    mode='markers',
    marker=dict(size=3, color='white', line=dict(color='gray', width=1)),
    name='True Positives',
    opacity=0.25
))

tn_points = unique_points[true_negative]
traces.append(go.Scatter3d(
    x=tn_points[:, 0],
    y=tn_points[:, 1],
    z=tn_points[:, 2],
    mode='markers',
    marker=dict(size=3, color='black'),
    name='True negatives',
    opacity=0.25
))

fp_points = unique_points[false_positive]
traces.append(go.Scatter3d(
    x=fp_points[:, 0],
    y=fp_points[:, 1],
    z=fp_points[:, 2],
    mode='markers',
    marker=dict(size=3, color='red'),
    name='False Positive',
    opacity=0.9
))

fn_points = unique_points[false_negative]
traces.append(go.Scatter3d(
    x=fn_points[:, 0],
    y=fn_points[:, 1],
    z=fn_points[:, 2],
    mode='markers',
    marker=dict(size=3, color='blue'),
    name='False Negative',
    opacity=0.9
))

# Combine and render figure
fig = go.Figure(data=traces)
fig.update_layout(
    scene=dict(
        xaxis=dict(
            title='X',
            showgrid=False,
            showbackground=False,
            zeroline=False,
            visible=False
        ),
        yaxis=dict(
            title='Y',
            showgrid=False,
            showbackground=False,
            zeroline=False,
            visible=False
        ),
        zaxis=dict(
            title='Z',
            showgrid=False,
            showbackground=False,
            zeroline=False,
            visible=False
        ),
        aspectmode='cube'
    ),
    legend=dict(itemsizing='constant'), width=1000, height=1000,
    paper_bgcolor='white'
)
fig.show()

In [39]:
from scipy.interpolate import RegularGridInterpolator

# Create a 3D grid
x_unique = torch.unique(unique_points[:, 0]).sort()[0]
y_unique = torch.unique(unique_points[:, 1]).sort()[0]
z_unique = torch.unique(unique_points[:, 2]).sort()[0]
nx, ny, nz = len(x_unique), len(y_unique), len(z_unique)
# Create separate grids for true and false samples
tp_grid = torch.zeros((nx, ny, nz), device=unique_points.device)
tn_grid = torch.zeros((nx, ny, nz), device=unique_points.device)
fp_grid = torch.zeros((nx, ny, nz), device=unique_points.device)
fn_grid = torch.zeros((nx, ny, nz), device=unique_points.device)

for i in range(len(unique_points)):
    point = unique_points[i]
    ix = torch.searchsorted(x_unique, point[0])
    iy = torch.searchsorted(y_unique, point[1])
    iz = torch.searchsorted(z_unique, point[2])

    if true_positive[i]:
        tp_grid[ix, iy, iz] = 1
    elif true_negative[i]:
        tn_grid[ix, iy, iz] = 1
    elif false_positive[i]:
        fp_grid[ix, iy, iz] = 1
    elif false_negative[i]:
        fn_grid[ix, iy, iz] = 1

# Create figure
fig = go.Figure()

# Create finer grid for interpolation
interp_factor = 1
x_fine = torch.linspace(x_unique.min(), x_unique.max(), nx * interp_factor)
y_fine = torch.linspace(y_unique.min(), y_unique.max(), ny * interp_factor)
z_fine = torch.linspace(z_unique.min(), z_unique.max(), nz * interp_factor)

for grid, name, color in zip([tp_grid, tn_grid, fp_grid, fn_grid], ["True Positive", "True Negative", "False Positive", "False Negative"], ["white", "black", "red", "blue"]):
    # Interpolate true samples
    interpolator = RegularGridInterpolator((x_unique, y_unique, z_unique), grid.numpy(), method='linear', bounds_error=False, fill_value=0)
    X_fine, Y_fine, Z_fine = torch.meshgrid(x_fine, y_fine, z_fine, indexing='ij')
    points_fine = torch.stack([X_fine.ravel(), Y_fine.ravel(), Z_fine.ravel()], dim=-1)
    values = interpolator(points_fine).reshape(X_fine.shape)

    # Add dummy traces for legend
    fig.add_trace(go.Scatter3d(
        x=[None], y=[None], z=[None],
        mode='markers',
        marker=dict(size=10, color=color, line=dict(color='gray', width=1)) if color=="white" else dict(size=10, color=color),
        name=name,
        showlegend=True
    ))

    # Add true samples with low opacity
    fig.add_trace(go.Volume(
        x=X_fine.flatten(),
        y=Y_fine.flatten(),
        z=Z_fine.flatten(),
        value=values.flatten(),
        isomin=0.1,
        isomax=1,
        opacity=0.025 if "True" in name else 0.2,
        surface_count=20,
        colorscale=[
            [0, 'rgba(0,0,0,0)'],
            [1, color]
        ],
        showlegend=False,
        showscale=False,
        caps=dict(x_show=False, y_show=False, z_show=False),
    ))

fig.update_layout(
    scene=dict(
        xaxis=dict(
            title='X',
            showgrid=False,
            showbackground=False,
            zeroline=False,
            visible=False
        ),
        yaxis=dict(
            title='Y',
            showgrid=False,
            showbackground=False,
            zeroline=False,
            visible=False
        ),
        zaxis=dict(
            title='Z',
            showgrid=False,
            showbackground=False,
            zeroline=False,
            visible=False
        ),
        aspectmode='cube'
    ),
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="left",
        x=0.01
    ),
    width=1000,
    height=1000
)

fig.show()

In [40]:
# Create scatter traces
traces = []

tp_points = unique_points[unique_labels]
traces.append(go.Scatter3d(
    x=tp_points[:, 0],
    y=tp_points[:, 1],
    z=tp_points[:, 2],
    mode='markers',
    marker=dict(size=3, color='red', line=dict(color='gray', width=1)),
    name='Positives',
    opacity=0.6
))

tn_points = unique_points[~unique_labels]
traces.append(go.Scatter3d(
    x=tn_points[:, 0],
    y=tn_points[:, 1],
    z=tn_points[:, 2],
    mode='markers',
    marker=dict(size=3, color='blue'),
    name='negatives',
    opacity=0.6
))


# Combine and render figure
fig = go.Figure(data=traces)
fig.update_layout(
    scene=dict(
        xaxis=dict(
            title='X',
            showgrid=False,
            showbackground=False,
            zeroline=False,
            visible=True
        ),
        yaxis=dict(
            title='Y',
            showgrid=False,
            showbackground=False,
            zeroline=False,
            visible=True
        ),
        zaxis=dict(
            title='Z',
            showgrid=False,
            showbackground=False,
            zeroline=False,
            visible=True
        ),
        aspectmode='cube'
    ),
    legend=dict(itemsizing='constant'), width=1000, height=1000,
    paper_bgcolor='white'
)
fig.show()