In [None]:
import torch
from model import EdgePredictorGNN
from prepare_data import train_loader, val_loader

In [None]:
import numpy as np

# Scratch check: take the first training batch and map the source-node indices
# of its edge_index onto a reference array to see which nodes the edges touch.
data = next(iter(train_loader))
edge_src, edge_dst = data.edge_index  # first dim of edge_index unpacks into 2 rows
# NOTE(review): np.linspace(0, 1001, 1000) produces 1000 non-integer values spaced
# ~1.002 apart; np.arange(1000) may have been intended. Indexing raises IndexError
# if the batch references a node index >= 1000 — confirm the intent.
test_x = np.linspace(0, 1001, 1000)
test_x[edge_src]

In [None]:
best_model = 4  # epoch number of the checkpoint to evaluate; change to load a different epoch

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Architecture hyper-parameters. These presumably must match the configuration
# the checkpoint was saved with, or load_state_dict will reject it — confirm
# against the training script.
in_channels, hidden_channels, edge_out_channels = 2, 64, 1

# Build the network on the target device, then restore the saved weights.
checkpoint_path = f"model_epoch_{best_model}.pt"
model = EdgePredictorGNN(in_channels, hidden_channels, edge_out_channels)
model = model.to(device)
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (recent PyTorch versions also accept weights_only=True for plain state dicts).
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

def predict(loader, name, net=None, dev=None):
    """Run inference over every batch in ``loader`` and print the MSE.

    Args:
        loader: Iterable of batches. Each batch must support ``.to(device)``,
            be callable through the model, and expose its regression targets
            as ``.edge_attr``.
        name: Label used in the printed summary line (e.g. ``"Train"``).
        net: Model to evaluate. Defaults to the notebook-level ``model``.
        dev: Device each batch is moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``(preds, targets)``: CPU tensors holding the model outputs and the
        ``edge_attr`` values of all batches, concatenated along dim 0.

    Raises:
        ValueError: if ``loader`` yields no batches, or if the predictions and
            targets have different shapes (broadcasting would otherwise make
            the MSE silently wrong).
    """
    net = model if net is None else net
    dev = device if dev is None else dev
    # Make sure dropout/normalisation layers are in inference mode even if an
    # earlier cell switched the model back to training mode.
    net.eval()

    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(dev)
            all_preds.append(net(batch).cpu())
            all_targets.append(batch.edge_attr.cpu())

    if not all_preds:
        raise ValueError(f"{name} loader yielded no batches")

    preds = torch.cat(all_preds, dim=0)
    targets = torch.cat(all_targets, dim=0)
    if preds.shape != targets.shape:
        # e.g. (E, 1) vs (E,) would broadcast to (E, E) and give a meaningless MSE.
        raise ValueError(
            f"{name} set: prediction shape {tuple(preds.shape)} does not match "
            f"target shape {tuple(targets.shape)}"
        )

    mse = torch.mean((preds - targets) ** 2).item()
    print(f"{name} set: MSE = {mse:.6f}")
    return preds, targets

# Collect predictions and targets for both splits; predict() prints each split's MSE.
print("Evaluating model...")
train_preds, train_targets = predict(loader=train_loader, name="Train")
val_preds, val_targets = predict(loader=val_loader, name="Validation")


In [None]:
import matplotlib.pyplot as plt

# Predicted vs. target scatter for a subset of rows: every ROW_STEP-th row
# starting at ROW_START.
# NOTE(review): presumably this picks one edge type out of 9 that repeat in a
# fixed order — confirm against the data-preparation code.
ROW_START, ROW_STEP = 2, 9
for split_name, split_targets, split_preds in (
    ("Train", train_targets, train_preds),
    ("Validation", val_targets, val_preds),
):
    fig, ax = plt.subplots()
    ax.scatter(split_targets[ROW_START::ROW_STEP], split_preds[ROW_START::ROW_STEP], s=0.005)
    ax.set(
        title=f"{split_name}: prediction vs. target (rows {ROW_START}::{ROW_STEP})",
        xlabel="Target",
        ylabel="Prediction",
    )
    plt.show()

In [None]:
# One scatter panel per edge type (assuming 9 edge types).
# The model is built with edge_out_channels = 1, so val_preds has a single
# column and column indexing ([:, i]) cannot select an edge type. Following the
# `[k::9]` slicing used for the overall scatter plots, edge type i is taken to
# be every 9th row starting at row i.
# NOTE(review): confirm this interleaved edge ordering with the data pipeline.
N_EDGE_TYPES = 9
fig, axes = plt.subplots(3, 3, figsize=(15, 12))
for edge_type, ax in enumerate(axes.flat):
    ax.scatter(
        val_targets[edge_type::N_EDGE_TYPES],
        val_preds[edge_type::N_EDGE_TYPES],
        s=0.5,
        alpha=0.5,
    )
    ax.set_title(f'Edge {edge_type}')
    ax.set_xlabel('Target')
    ax.set_ylabel('Prediction')
fig.suptitle('Validation: prediction vs. target per edge type')
fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Overlaid histograms of all validation target and prediction values, flattened
# to 1-D. Both series share one set of bin edges: binning each array separately
# (bins=50 twice) gives them different bin widths and makes the overlay misleading.
val_targets_flat = val_targets.cpu().numpy().flatten()
val_preds_flat = val_preds.cpu().numpy().flatten()
val_bins = np.histogram_bin_edges(np.concatenate([val_targets_flat, val_preds_flat]), bins=50)

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(val_targets_flat, bins=val_bins, alpha=0.5, label='Validation Targets')
ax.hist(val_preds_flat, bins=val_bins, alpha=0.5, label='Validation Predictions')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.set_title('Distribution of Validation Targets and Predictions')
ax.legend()
plt.show()

In [None]:
# Overlaid histograms of all training target and prediction values, flattened
# to 1-D. Both series share one set of bin edges so the bars are directly
# comparable (independent bins=50 would give each series its own bin widths).
targets_flat = train_targets.cpu().numpy().flatten()
preds_flat = train_preds.cpu().numpy().flatten()
train_bins = np.histogram_bin_edges(np.concatenate([targets_flat, preds_flat]), bins=50)

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(targets_flat, bins=train_bins, alpha=0.5, label='Train Targets')
ax.hist(preds_flat, bins=train_bins, alpha=0.5, label='Train Predictions')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.set_title('Distribution of Train Targets and Predictions')
ax.legend()
plt.show()