In [2]:
import sys
import os

sys.path.append('..')
from data_preprocessing.TrajectoryDatasetLoader import TrajectoryDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
import torch
import numpy as np
from tqdm import tqdm
import pandas as pd
import json
from sklearn.metrics import precision_recall_curve

from folium.plugins import DualMap

import folium
import matplotlib.pyplot as plt

import torch.nn.functional as F
from torch_geometric_temporal.nn.recurrent import A3TGCN, DCRNN, TGCN, GConvGRU
from branca.colormap import linear

from IPython.display import display
from sklearn.metrics import auc
import h3
from geopy.distance import geodesic


%load_ext autoreload
%autoreload 2

In [3]:
# Regions table indexed by H3 hex id (the index is fed to h3.cell_to_latlng further below).
regions = pd.read_pickle("../data/regions_enhanced.pkl")
#visualize_df = regions.copy()
#visualize_df.head()

In [4]:
# Side-by-side map of ground-truth (left) vs. predicted (right) visits per timestep.
def visualize_ab(label, prediction, ids_to_hex_names):
    """Compare ground-truth and predicted visits per timestep on a folium DualMap.

    For every timestep ``t`` a layer ``label{t}`` is added to the left map
    (``dual_m.m1``) and a layer ``prediction{t}`` to the right map
    (``dual_m.m2``); a layer control lets the user toggle them.

    Args:
        label: ground-truth visits. It is transposed with axes (1, 2, 0) and the
            leading (feature) axis is dropped, so it presumably has layout
            (nodes, 1, timesteps) — TODO confirm. Must support ``.int()``
            (e.g. a torch tensor).
        prediction: predictions with the same layout as ``label``, already
            binarised by the caller.
        ids_to_hex_names: mapping node index -> H3 hex id (index of the regions frame).

    Returns:
        The populated ``folium.plugins.DualMap``.
    """
    # (nodes, feature, time) -> (feature, time, nodes); squeeze only the feature
    # axis so a one-timestep input still yields a 2-D (time, nodes) array.
    label_t = np.transpose(label, (1, 2, 0)).squeeze(0)
    prediction_t = np.transpose(prediction, (1, 2, 0)).squeeze(0)

    visualize_df = pd.read_pickle("../data/regions_enhanced.pkl").copy()

    def _as_column(values):
        # Map per-node 0/1 values onto the hex-indexed regions frame.
        return pd.Series({ids_to_hex_names[node]: int(v) for node, v in enumerate(values.int())})

    label_columns = []
    for t in range(label_t.shape[0]):
        column = f"label{t}"
        visualize_df[column] = _as_column(label_t[t, :])
        label_columns.append(column)

    # Iterate over the prediction's own timestep count, not the label's.
    prediction_columns = []
    for t in range(prediction_t.shape[0]):
        column = f"prediction{t}"
        visualize_df[column] = _as_column(prediction_t[t, :])
        prediction_columns.append(column)

    cmap = linear.Blues_09.scale(0, 1)
    dual_m = DualMap(location=(48.12, 11.58), tiles=None, zoom_start=10)
    # Children added to the DualMap itself appear on both panes (folium DualMap docs),
    # which is what we want for the base tiles and the layer control.
    folium.TileLayer("openstreetmap").add_to(dual_m)

    # Labels go to the left pane only and predictions to the right pane only;
    # adding them to `dual_m` would put every layer on both panes.
    for column in label_columns:
        visualize_df.explore(name=column, m=dual_m.m1, cmap=cmap, column=column, style_kwds={"fillOpacity": 0.2})
    for column in prediction_columns:
        visualize_df.explore(name=column, m=dual_m.m2, cmap=cmap, column=column, style_kwds={"fillOpacity": 0.2})

    folium.LayerControl().add_to(dual_m)
    return dual_m


def visualize_y(y, idx_to_node_ids):
    """Render how often each hex is visited across all timesteps of ``y``.

    Args:
        y: visit indicator with layout (nodes, feature, timesteps) and a single
            feature channel; entries equal to 1 count as a visit.
        idx_to_node_ids: mapping node index -> H3 hex id.

    Returns:
        The folium map from ``GeoDataFrame.explore`` colouring the ``visited`` column.
    """
    # (nodes, feature, timesteps) -> (feature, timesteps, nodes) -> (timesteps, nodes).
    # Squeeze only the feature axis: a bare squeeze() turned a single-timestep
    # input into a 1-D array, and iterating it then failed.
    y_t = np.transpose(y, (1, 2, 0)).squeeze(0)

    occurance_dict = {hex_id: 0 for hex_id in idx_to_node_ids.values()}
    for row in y_t:
        for node, value in enumerate(row):
            if value == 1:
                occurance_dict[idx_to_node_ids[node]] += 1

    visualize_df = pd.read_pickle("../data/regions_enhanced.pkl").copy()
    visualize_df["visited"] = pd.Series(occurance_dict)
    # Keep the colour scale non-degenerate (0..0) when nothing was visited, and
    # avoid max() on an empty mapping.
    upper = max(max(occurance_dict.values(), default=0), 1)
    cmap = linear.Blues_09.scale(0, upper)

    return visualize_df.explore("visited", cmap=cmap)

def visualize_y_yhat(y_hat,y,idx_to_node_ids, threshold=None):
    """Show per-hex visit counts of labels (left) and predictions (right) on a DualMap.

    Args:
        y_hat: predictions with the same layout as ``y``.
        y: ground truth, presumably (nodes, 1, timesteps) — TODO confirm.
        idx_to_node_ids: mapping node index -> H3 hex id.
        threshold: if given, a prediction counts as a visit when it is
            ``>= threshold``. If ``None`` (default, previous behaviour) only
            entries exactly equal to 1 count, which raw sigmoid outputs such as
            TemporalGNN's practically never are.

    Returns:
        The ``folium.plugins.DualMap`` with one coloured GeoJson layer per pane.
    """
    # Squeeze only the feature axis so one timestep still gives (timesteps, nodes).
    label_t = np.transpose(y.detach(), (1, 2, 0)).squeeze(0)
    prediction_t = np.transpose(y_hat.detach(), (1, 2, 0)).squeeze(0)

    visualize_df = pd.read_pickle("../data/regions_enhanced.pkl").copy()

    def _visits_per_hex(values, is_visit):
        # Count, per hex, how many timesteps satisfy `is_visit`.
        counts = {hex_id: 0 for hex_id in idx_to_node_ids.values()}
        for row in values:
            for node, value in enumerate(row):
                if is_visit(value):
                    counts[idx_to_node_ids[node]] += 1
        return pd.Series(counts)

    def _equals_one(value):
        return value == 1

    if threshold is None:
        prediction_is_visit = _equals_one
    else:
        def prediction_is_visit(value):
            return value >= threshold

    visualize_df["visited_label"] = _visits_per_hex(label_t, _equals_one)
    visualize_df["visited_prediction"] = _visits_per_hex(prediction_t, prediction_is_visit)

    m = DualMap(location=(48.12, 11.58), tiles=None, zoom_start=10)
    folium.TileLayer("openstreetmap").add_to(m.m1)
    folium.TileLayer("openstreetmap").add_to(m.m2)

    # max(..., 1) keeps each colour scale non-degenerate when a side has no visits.
    cmap1 = linear.Blues_09.scale(0, max(visualize_df["visited_label"].max(), 1))
    cmap2 = linear.Blues_09.scale(0, max(visualize_df["visited_prediction"].max(), 1))

    t1 = folium.GeoJsonTooltip(fields=["visited_label"])
    t2 = folium.GeoJsonTooltip(fields=["visited_prediction"])

    gj1 = folium.GeoJson(visualize_df, tooltip=t1, style_function=lambda x: {"fillColor": cmap1(x["properties"]["visited_label"]), "fillOpacity": 0.5, "color": "white", "weight": 1})
    gj2 = folium.GeoJson(visualize_df, tooltip=t2, style_function=lambda x: {"fillColor": cmap2(x["properties"]["visited_prediction"]), "fillOpacity": 0.5, "color": "white", "weight": 1})

    gj1.add_to(m.m1)
    gj2.add_to(m.m2)

    return m

class MyTransform:
    """Per-feature min-max scaler fitted on a temporal-signal dataset.

    ``mi``/``mx`` hold, per feature channel, the minimum/maximum of
    ``dataset.features`` reduced over axes 0, 1 and -1 (presumably snapshots,
    nodes and timesteps once stacked — TODO confirm). Inputs to the scaling
    methods are arrays whose second-to-last axis is the feature axis and whose
    last axis is time, e.g. (nodes, features, timesteps).
    """

    def __init__(self, dataset, a,b) -> None:
        """Fit the scaler.

        Args:
            dataset: object exposing ``features`` (sequence of equally shaped arrays).
            a: lower bound of the target range.
            b: upper bound of the target range.
        """
        self.mi = np.min(dataset.features, axis=(0,1,-1))
        self.mx = np.max(dataset.features, axis=(0,1,-1))

        self.a = a
        self.b = b

    def _bounds(self, x):
        """Return per-feature (min, max), each tiled to shape (features, timesteps of ``x``)."""
        mit = np.tile(self.mi, (x.shape[-1], 1)).T
        mxt = np.tile(self.mx, (x.shape[-1], 1)).T
        return mit, mxt

    def min_max_scale(self, x):
        """Map ``x`` linearly from [mi, mx] to [a, b] per feature.

        Features that were constant in the fitting data (mx == mi) would divide
        by zero and produce NaN/inf; their span is treated as 1 instead, so a
        value equal to that constant maps to ``a``.
        """
        mit, mxt = self._bounds(x)
        span = mxt - mit
        span = np.where(span == 0, 1, span)
        return self.a + (x - mit) * (self.b - self.a) / span

    def inv_min_max_scale(self,x_prime):
        """Invert :meth:`min_max_scale`; constant features map back to their constant."""
        mit, mxt = self._bounds(x_prime)
        return ((x_prime - self.a) * (mxt - mit)) / (self.b - self.a) + mit

In [5]:
def intersection_over_union(prediction, label):
    """Visit-count IoU between a prediction and a label.

    Axis 1 is squeezed out of both tensors, then each node's values are summed
    over the remaining axis 1. The result is
    sum_n min(label_n, pred_n) / sum_n max(label_n, pred_n).
    """
    per_node_label = label.squeeze(1).sum(dim=1)
    per_node_pred = prediction.squeeze(1).sum(dim=1)
    overlap = torch.min(per_node_label, per_node_pred).sum()
    union = torch.max(per_node_label, per_node_pred).sum()
    return overlap / union

def abs_visited_difference(prediciton, label):
    """Absolute difference between the total visit counts of label and prediction."""
    return (label.sum() - prediciton.sum()).abs()

def load_from_checkpoint(path, model, optimizer):
    """Restore model and optimizer state from a checkpoint written with ``torch.save``.

    The checkpoint must be a dict with keys ``model_state_dict``,
    ``optimizer_state_dict``, ``epoch`` and ``loss``.

    Returns:
        (model, optimizer, epoch, loss): the same model and optimizer objects that
        were passed in, updated in place, plus the stored epoch and loss.
    """
    state = torch.load(path)
    for target, key in ((model, "model_state_dict"), (optimizer, "optimizer_state_dict")):
        target.load_state_dict(state[key])
    return model, optimizer, state["epoch"], state["loss"]

def get_weight_rebal(train_dataset):
    """Build per-element BCE weights that rebalance positives against negatives.

    With ``factor = #zeros / #ones`` counted over all snapshot labels, positive
    entries get weight 1 and negative entries weight ``1 / factor``. The weight
    tensor has the shape of the first snapshot's ``y``.

    Args:
        train_dataset: re-iterable collection of snapshots exposing ``.y``.

    Returns:
        The weight tensor. (The previous version computed it but returned None.)

    Raises:
        ValueError: if the labels contain no positives or no negatives, where the
            rebalancing factor is undefined.
    """
    # next(iter(...)): next() on the dataset itself only works if it is already an
    # iterator; every other call site in this notebook uses next(iter(...)).
    y = next(iter(train_dataset)).y
    num_elem = sum(snapshot.y.numel() for snapshot in train_dataset)
    num_ones = sum(snapshot.y.sum() for snapshot in train_dataset)
    num_zeros = num_elem - num_ones
    if num_ones == 0 or num_zeros == 0:
        raise ValueError("labels must contain both positives and negatives to rebalance")
    factor = num_zeros / num_ones
    return torch.ones_like(y) / factor + (1.0 - 1.0 / factor) * y

def get_threshold_and_fscore(labels, predictions):
    """Pick the probability threshold that maximises F1 on the precision-recall curve.

    Args:
        labels: list of ground-truth tensors (stackable into one tensor).
        predictions: list of predicted-probability tensors with matching shapes.

    Returns:
        (best_threshold, best_fscore, auc_score), where ``auc_score`` is the area
        under the precision-recall curve.
    """
    y_true = torch.stack(labels).flatten().detach().numpy()
    y_score = torch.stack(predictions).flatten().detach().numpy()

    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    # F1 at every PR point; the tiny constant avoids 0/0 where P = R = 0.
    fscore = (2 * precision * recall) / (precision + recall + 1e-10)
    best = np.argmax(fscore)
    return thresholds[best], fscore[best], auc(recall, precision)

def get_iou(labelss, predictionss):
    """Visit-count intersection-over-union between stacked labels and predictions.

    Visits are summed per node over axes 0, 2 and 3 (stacked snapshots of
    presumably (nodes, 1, timesteps) tensors — TODO confirm). The result is
    IoU = sum_n min(label_n, pred_n) / sum_n max(label_n, pred_n), which lies
    in [0, 1] for non-negative inputs.
    """
    predictionss_visits_per_node = torch.sum(predictionss, dim=(0, 2, 3))
    labelss_visits_per_node = torch.sum(labelss, dim=(0, 2, 3))

    max_per_node = torch.max(predictionss_visits_per_node, labelss_visits_per_node)
    min_per_node = torch.min(predictionss_visits_per_node, labelss_visits_per_node)
    # Intersection (min) over union (max). The previous max/min ratio was the
    # inverse (>= 1) and disagreed with intersection_over_union.
    return torch.sum(min_per_node) / torch.sum(max_per_node)

def create_json_file(filename):
    """Create (or truncate) ``filename`` so that it contains an empty JSON object."""
    empty = {}
    with open(filename, 'w') as handle:
        json.dump(empty, handle)

def add_data_to_json(filename, new_data):
    """Merge ``new_data`` into the JSON object stored in ``filename``.

    Existing keys are overwritten by ``new_data``; the whole file is rewritten.
    """
    with open(filename, 'r') as source:
        merged = json.load(source)

    merged.update(new_data)

    with open(filename, 'w') as sink:
        json.dump(merged, sink)

def get_normalized_visited_difference(labelss, predictionss):
    """Relative error of the total predicted visits with respect to the true total.

    Positive when the prediction over-counts visits, negative when it
    under-counts; inf/nan if ``labelss`` contains no visits.
    """
    true_total = torch.sum(labelss)
    return (torch.sum(predictionss) - true_total) / true_total

In [6]:
# Node input features (per hex) and the target feature the model predicts.
x_features = ["hex_work","hex_errand", "hex_leisure","activity_work","activity_errand","activity_leisure", "visited"]
y_features = ["visited"]

dataset_name = "dataset_6trajspp"
# NOTE(review): the meaning of the last argument (1000) is not visible here; the printed
# snapshot_count below is 1000, so it presumably caps the number of samples — confirm.
loader = TrajectoryDatasetLoader(f"../data/{dataset_name}.json", x_features, y_features, 1000)
dataset = loader.get_dataset()

# NOTE(review): 33799 is unexplained — presumably the uncapped number of samples; confirm.
print("Dataset type:  ", dataset)
print("Number of samples / sequences: ",  dataset.snapshot_count)

Dataset type:   <torch_geometric_temporal.signal.static_graph_temporal_signal.StaticGraphTemporalSignal object at 0x108b38e80>
Number of samples / sequences:  1000


In [7]:
# 80/20 train/test split of the snapshots.
train_set, test_set = temporal_signal_split(dataset, train_ratio=0.8)

# Per-feature min-max scalers to [0, 1].
# NOTE(review): test_transform is never used — test_loop scales test inputs with the
# training statistics (train_transform), which avoids leaking test statistics.
train_transform = MyTransform(train_set, 0,1)
test_transform = MyTransform(test_set, 0,1)

print("Number of train examples: ", train_set.snapshot_count)
print("Number of test examples: ", test_set.snapshot_count)

Number of train examples:  800
Number of test examples:  200


In [9]:
# First training snapshot, used as a running example by later cells.
example = next(iter(train_set))
# Scaled input features of the example (fed to the smoke-test forward pass later).
x_prime = train_transform.min_max_scale(example.x)
# Map of how often each hex is visited in the example's labels.
display(visualize_y(example.y, loader.ids_to_hex_names))

In [39]:
def ellipsoidal_distance(p1, p2) -> float:
    """Geodesic distance in meters between ``p1`` and ``p2``.

    Each point is a ``(lat, lon)`` pair; the distance comes from geopy's
    ``geodesic`` on its default ellipsoid.
    """
    return geodesic(p1, p2).meters

def compute_distance_matrix(df_sites, dist_metric=ellipsoidal_distance):
    """Build an N x N distance matrix for N sites.

    Args:
        df_sites: DataFrame with one row per site; each row's values, in column
            order (here ``lat``, ``lon``), form the point passed to ``dist_metric``.
        dist_metric: callable ``(p1, p2) -> float`` taking two ``(lat, lon)`` tuples.
            It is evaluated for every ordered pair, so asymmetric metrics work too.

    Returns:
        A float64 DataFrame indexed and columned by ``df_sites.index`` with
        entry ``[orig, dest] = dist_metric(orig, dest)``.
    """
    # Materialise the points once. Nesting iterrows() rebuilt one Series per row on
    # every outer iteration, and filling an empty frame via .at left it object-dtype.
    points = list(df_sites.itertuples(index=False, name=None))
    distances = np.empty((len(points), len(points)), dtype=float)
    for i, origin in enumerate(points):
        for j, destination in enumerate(points):
            distances[i, j] = dist_metric(origin, destination)
    return pd.DataFrame(distances, index=df_sites.index, columns=df_sites.index)

# Attach each hex's centre coordinates (H3 cell -> (lat, lon)) and compute all pairwise
# geodesic distances between region hexes (N^2 geodesic calls, can be slow for many hexes).
regions["lat"] = regions.index.to_series().apply(lambda x: h3.cell_to_latlng(x)[0])
regions["lon"] = regions.index.to_series().apply(lambda x: h3.cell_to_latlng(x)[1])
distances_df = compute_distance_matrix(regions[["lat", "lon"]])

In [55]:
from pyproj import Transformer
from shapely import MultiPoint

In [62]:
# Radius of gyration over each trip's farthest-apart (source, destination) hexes.
def radius_of_gyration(y,ids_to_hex_names, distances_df, projected_crs="EPSG:3857"):
    """Radius of gyration of the trips' endpoint hexes.

    For every trip (column of ``y``), the two visited hexes farthest apart in
    ``distances_df`` are taken as its source and destination. All endpoints are
    projected to ``projected_crs``, and
    r_g = sqrt(mean_i |p_i - c|^2), where c is the centroid of the endpoints.

    Args:
        y: visit indicator, (nodes, trips) or (nodes, 1, trips); truthy entries
            mark a visit. (Previously ignored in favour of the global ``example.y``.)
        ids_to_hex_names: node index -> H3 hex id. (Previously ignored as well.)
        distances_df: hex x hex distance frame.
        projected_crs: metric CRS for the centroid/distance computation.
            NOTE(review): Web Mercator (the default, kept for compatibility)
            inflates ground distances by roughly 1/cos(latitude); a local UTM
            CRS would give ground meters — confirm which is intended.

    Returns:
        The radius of gyration in ``projected_crs`` units, or NaN when no trip
        has a visited hex.
    """
    if y.ndim == 3:
        y = y.squeeze(1)

    trips = [[] for _ in range(y.shape[1])]
    for i in range(y.shape[0]):
        for j in range(y.shape[1]):
            if y[i, j]:
                trips[j].append(ids_to_hex_names[i])

    source_dest_hexes = []
    for trip in trips:
        if not trip:
            # A trip without visits has no endpoints (it used to add (None, None)).
            continue
        max_dist_pair = trip[0], trip[0]
        max_dist = -1
        for hex_a in trip:
            for hex_b in trip:
                dist = distances_df.at[hex_a, hex_b]
                if max_dist < dist:
                    max_dist = dist
                    max_dist_pair = hex_a, hex_b
        source_dest_hexes.extend(max_dist_pair)

    if not source_dest_hexes:
        return float("nan")

    transformer = Transformer.from_crs("EPSG:4326", projected_crs)
    projected = [transformer.transform(*h3.cell_to_latlng(hex_id)) for hex_id in source_dest_hexes]
    points_t = MultiPoint(projected)
    c = points_t.centroid
    n_points = len(projected)
    return np.sqrt(1 / n_points * np.sum([c.distance(p) ** 2 for p in points_t.geoms]))

5973.6226594929

In [63]:
# Jump size: per-trip distance between the two farthest-apart visited hexes.
def jump_sizes(y,ids_to_hex_names, distances_df):
    """Largest pairwise distance among each trip's visited hexes.

    Args:
        y: visit indicator of shape (nodes, trips); truthy entries mark a visit.
        ids_to_hex_names: node index -> H3 hex id (labels of ``distances_df``).
        distances_df: hex x hex distance frame.

    Returns:
        One float per trip (column of ``y``): the maximum distance between any two
        of its visited hexes (0.0 for a single hex), or NaN for a trip without
        visits. (Previously such trips returned the sentinel -1.)
    """
    n_nodes, n_trips = y.shape[0], y.shape[1]
    sizes = []
    for trip_idx in range(n_trips):
        visited = [ids_to_hex_names[node] for node in range(n_nodes) if y[node, trip_idx]]
        if not visited:
            sizes.append(float("nan"))
            continue
        # Sub-matrix of pairwise distances among this trip's hexes.
        pairwise = distances_df.loc[visited, visited].to_numpy(dtype=float)
        sizes.append(float(pairwise.max()))
    return sizes

# Jump sizes for the running example; squeeze(1) drops the single feature axis of example.y.
y = example.y.squeeze(1)
ids_to_hex_names = loader.ids_to_hex_names

jump_sizes(y,ids_to_hex_names, distances_df)

[2469.4034620044667, 935.8012463085796, 9251.840579436606]

In [67]:
from collections import Counter

In [71]:
def top_k_frequent_elements(lst, k):
    """Return the ``k`` most common elements of ``lst``, most frequent first.

    Elements with equal counts keep first-seen order (``Counter.most_common``).
    """
    ranked = Counter(lst).most_common(k)
    return list(map(lambda pair: pair[0], ranked))


def k_radius_of_gyration(y,ids_to_hex_names, distances_df,k=2, projected_crs="EPSG:3857"):
    """k-radius of gyration over the k most frequent trip endpoint hexes.

    For every trip (column of ``y``), the two visited hexes farthest apart in
    ``distances_df`` are its endpoints. Only the ``k`` hexes occurring most often
    among all endpoints are kept. They are projected to ``projected_crs``, and
    r_g = sqrt(mean_i |p_i - c|^2) is taken around their centroid.

    Args:
        y: visit indicator, (nodes, trips) or (nodes, 1, trips); truthy entries
            mark a visit. (Previously ignored in favour of the global ``example.y``.)
        ids_to_hex_names: node index -> H3 hex id. (Previously ignored as well.)
        distances_df: hex x hex distance frame.
        k: number of most frequent endpoint hexes to keep.
        projected_crs: metric CRS for the centroid/distance computation.
            NOTE(review): Web Mercator (default, kept for compatibility) inflates
            ground distances by roughly 1/cos(latitude) — confirm intent.

    Returns:
        The k-radius of gyration in ``projected_crs`` units, or NaN when no trip
        has a visited hex or ``k`` selects no hexes.
    """
    if y.ndim == 3:
        y = y.squeeze(1)

    trips = [[] for _ in range(y.shape[1])]
    for i in range(y.shape[0]):
        for j in range(y.shape[1]):
            if y[i, j]:
                trips[j].append(ids_to_hex_names[i])

    source_dest_hexes = []
    for trip in trips:
        if not trip:
            # A trip without visits has no endpoints (it used to add (None, None)).
            continue
        max_dist_pair = trip[0], trip[0]
        max_dist = -1
        for hex_a in trip:
            for hex_b in trip:
                dist = distances_df.at[hex_a, hex_b]
                if max_dist < dist:
                    max_dist = dist
                    max_dist_pair = hex_a, hex_b
        source_dest_hexes.extend(max_dist_pair)

    if not source_dest_hexes:
        return float("nan")

    top_k_source_dest_hexes = top_k_frequent_elements(source_dest_hexes, k)
    if not top_k_source_dest_hexes:
        return float("nan")

    transformer = Transformer.from_crs("EPSG:4326", projected_crs)
    projected = [transformer.transform(*h3.cell_to_latlng(hex_id)) for hex_id in top_k_source_dest_hexes]
    points_t = MultiPoint(projected)
    c = points_t.centroid
    # Separate name so the `k` argument is not overwritten.
    n_points = len(projected)
    return np.sqrt(1 / n_points * np.sum([c.distance(p) ** 2 for p in points_t.geoms]))

# 2-radius of gyration for the running example (squeeze(1) drops the feature axis).
y = example.y.squeeze(1)
ids_to_hex_names = loader.ids_to_hex_names
k_radius_of_gyration(y,ids_to_hex_names, distances_df, 2)

1847.5235559695147

In [None]:
# Inspect the edge attributes of the example graph snapshot.
print(example.edge_attr)

In [None]:
# https://www.kdnuggets.com/2021/09/imbalanced-classification-without-re-balancing-data.html
# https://discuss.pytorch.org/t/how-to-apply-a-weighted-bce-loss-to-an-imbalanced-dataset-what-will-the-weight-tensor-contain/56823
# https://datascience.stackexchange.com/questions/58735/weighted-binary-cross-entropy-loss-keras-implementation

In [None]:
#model = TemporalGNN(node_features=len(x_features), periods=periods).to(device)
#optimizer = torch.optim.Adam(model.parameters(), lr=lr)
#loss_fn = torch.nn.BCELoss(weight_rebal)

In [None]:
class TemporalGNN(torch.nn.Module):
    """Temporal graph encoder followed by a two-layer MLP head.

    ``config`` keys used: ``rgcl`` (one of ``"A3TGCN"``, ``"DCRNN"``, ``"TGCN"``,
    ``"GConvGRU"``), ``out_channels`` (hidden size of the recurrent cell) and
    ``linear_dim`` (hidden size of the head). The head maps each node's hidden
    state to ``periods`` outputs squashed into (0, 1) by a sigmoid.
    """

    # Cells that forward() unrolls manually over the last (time) axis of x.
    _STEPWISE_CELLS = ("DCRNN", "TGCN", "GConvGRU")

    def __init__(self, node_features, periods, config):
        super(TemporalGNN, self).__init__()
        self.config = config
        rgcl = config["rgcl"]
        out_channels = config["out_channels"]

        # The "checked, weight / no weight" notes are the original author's.
        if rgcl == "A3TGCN":
            # Attention Temporal Graph Convolutional Cell; checked, weight
            self.tgnn = A3TGCN(node_features, out_channels, periods)
        elif rgcl == "DCRNN":
            # checked, weight
            self.tgnn = DCRNN(node_features, out_channels, 1)
        elif rgcl == "TGCN":
            # checked, no weight
            self.tgnn = TGCN(node_features, out_channels, True)
        elif rgcl == "GConvGRU":
            # checked, weight
            self.tgnn = GConvGRU(node_features, out_channels, 1)
        else:
            # Fail fast: printing and continuing left self.tgnn undefined until forward().
            raise ValueError(f"unknown rgcl {rgcl!r}; expected one of A3TGCN, DCRNN, TGCN, GConvGRU")

        # Equals single-shot prediction
        self.linear1 = torch.nn.Linear(out_channels, config["linear_dim"])
        self.linear2 = torch.nn.Linear(config["linear_dim"], periods)

    def forward(self, x, edge_index):
        """Run the temporal encoder and the MLP head.

        Args:
            x: node features over time; the last axis is time (step-wise cells
                get ``x[:, :, t]``), presumably (nodes, features, T).
            edge_index: graph edge indices passed to the recurrent cell.

        Returns:
            Per-node outputs of size ``periods`` in (0, 1).
        """
        if self.config["rgcl"] in self._STEPWISE_CELLS:
            # Unroll over time, threading the hidden state. H=None on the first step
            # is the cells' default hidden state (torch_geometric_temporal signature).
            h = None
            for t in range(x.shape[-1]):
                h = self.tgnn(x[:, :, t], edge_index, H=h)
        else:
            h = F.relu(self.tgnn(x, edge_index))

        h = F.relu(self.linear1(h))
        # torch.sigmoid: same values as the deprecated F.sigmoid.
        return torch.sigmoid(self.linear2(h))



In [None]:
def custom_bce_loss(predictions, targets, distances):
    """
    Distance-weighted Binary Cross Entropy (BCE) loss.

    The per-element BCE log-likelihood is mixed across nodes with
    ``distances @ log_likelihood`` and the negative mean is returned.

    Args:
    - predictions: predicted probabilities (0 to 1); axis 1 is squeezed out.
    - targets: binary targets (0 or 1) with the same layout; axis 1 is squeezed out.
    - distances: float64 weight matrix left-multiplied onto the per-node terms,
      presumably the (nodes x nodes) distance matrix — TODO confirm.

    Returns:
    - loss: scalar float64 loss tensor.
    """
    p = predictions.squeeze(1).double()
    t = targets.squeeze(1).double()

    eps = 1e-7  # keeps log() finite when a prediction is exactly 0 or 1
    log_likelihood = t * torch.log(p + eps) + (1 - t) * torch.log(1 - p + eps)
    return -torch.mean(torch.matmul(distances, log_likelihood))

# NOTE(review): these two functions repeat the definitions in the earlier distance cell;
# keep both copies identical or drop this one.
def ellipsoidal_distance(p1, p2) -> float:
    """Geodesic distance in meters between ``p1`` and ``p2``.

    Each point is a ``(lat, lon)`` pair; the distance comes from geopy's
    ``geodesic`` on its default ellipsoid.
    """
    return geodesic(p1, p2).meters

def compute_distance_matrix(df_sites, dist_metric=ellipsoidal_distance):
    """Build an N x N distance matrix for N sites.

    Args:
        df_sites: DataFrame with one row per site; each row's values, in column
            order (here ``lat``, ``lon``), form the point passed to ``dist_metric``.
        dist_metric: callable ``(p1, p2) -> float`` taking two ``(lat, lon)`` tuples.
            It is evaluated for every ordered pair, so asymmetric metrics work too.

    Returns:
        A float64 DataFrame indexed and columned by ``df_sites.index`` with
        entry ``[orig, dest] = dist_metric(orig, dest)``.
    """
    # Materialise the points once. Nesting iterrows() rebuilt one Series per row on
    # every outer iteration, and filling an empty frame via .at left it object-dtype.
    points = list(df_sites.itertuples(index=False, name=None))
    distances = np.empty((len(points), len(points)), dtype=float)
    for i, origin in enumerate(points):
        for j, destination in enumerate(points):
            distances[i, j] = dist_metric(origin, destination)
    return pd.DataFrame(distances, index=df_sites.index, columns=df_sites.index)

# NOTE(review): this recomputes the same lat/lon columns and distance matrix as the earlier
# distance cell; reusing that `distances_df` would avoid a second N^2 geodesic pass.
regions["lat"] = regions.index.to_series().apply(lambda x: h3.cell_to_latlng(x)[0])
regions["lon"] = regions.index.to_series().apply(lambda x: h3.cell_to_latlng(x)[1])
distances_df = compute_distance_matrix(regions[["lat", "lon"]])

# Reorder rows/columns into the key order of loader.hex_names_to_ids, presumably the
# node-index order used by the model — confirm.
distances_df_sorted = distances_df.reindex(index=list(loader.hex_names_to_ids.keys()), columns=list(loader.hex_names_to_ids.keys()))
distances_df_sorted_np = distances_df_sorted.to_numpy()
distances_df_sorted_np.shape

In [None]:
# Smoke test: one forward pass of an untrained model on the running example.
config = {"linear_dim" : 64,
          "out_channels": 64,
          "lr" : 0.1,
          "epochs" : 2,
          "rgcl" : "DCRNN"
          }

# NOTE(review): output horizon is half the loader's trajectories per entry — presumably
# half of each entry is input and half is predicted; confirm.
periods = loader.countTrajectoriesPerEntry //2
node_features = len(x_features)
model = TemporalGNN(node_features, periods, config).to("cpu")
# unsqueeze(1) re-inserts the feature axis so y_hat lines up with example.y.
y_hat = model(x_prime.float().to("cpu"),example.edge_index).unsqueeze(1)
y = example.y

#distances = torch.tensor(distances_df_sorted_np.astype(float))
#distances_normalized = torch.nn.functional.normalize(distances)
#custom_bce_loss(y_hat, y,distances_normalized)

y_hat.shape

In [None]:
def test_loop(test_dataset, model, criterion):
    loss = 0
    device = torch.device('cpu')
    '''
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    '''
    model.to(device)

    # Store for analysis
    predictions = []
    labels = []
    input_trajs = []

    model.eval()
    snapshot = next(iter(test_dataset))
    edge_index = snapshot.edge_index.to(device)

    for snapshot in test_dataset:
        with torch.no_grad():
            x = snapshot.x
            y = snapshot.y.float().to(device)
            x = train_transform.min_max_scale(x).float().to(device)
            y_hat = model(x, edge_index).unsqueeze(1)

            loss = loss + criterion(y_hat, y)
            
            # Store for analysis below
            labels.append(y)
            predictions.append(y_hat)
            input_trajs.append(x[:,6,:])


    loss = loss / test_dataset.snapshot_count
    loss = loss.item()

    print("Test BCE: {:.4f}".format(loss))
    return predictions, labels, input_trajs

def train_loop(config, model_name, train_dataset, criterion, node_features, periods, transform=None):
    """Train a fresh TemporalGNN with full-batch Adam and return its best checkpoint.

    Each epoch sums ``criterion`` over all snapshots, divides by
    ``snapshot_count`` and takes one optimizer step. Before the step, the
    parameters that produced that epoch's loss are saved to
    ``./checkpoints/<model_name>/epoch<N>.pth``. After training, only the first
    lowest-loss checkpoint is kept on disk and its weights are loaded.

    Args:
        config: dict with ``lr``, ``epochs`` and the TemporalGNN keys.
        model_name: checkpoint sub-directory name.
        train_dataset: iterable of snapshots exposing ``snapshot_count``; all
            share the first snapshot's ``edge_index``.
        criterion: loss ``criterion(y_hat, y)``.
        node_features: forwarded to ``TemporalGNN``.
        periods: forwarded to ``TemporalGNN``.
        transform: scaler with ``min_max_scale``; defaults to the global
            ``train_transform``.

    Returns:
        The model with the best checkpoint's weights loaded (the freshly
        initialised model if ``config["epochs"]`` is 0).
    """
    if transform is None:
        transform = train_transform

    model = TemporalGNN(node_features, periods, config)
    # Running on GPU gave errors, so training is pinned to the CPU.
    device = torch.device('cpu')
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

    # TODO clear checkpoint dir
    model_dir = os.path.join(os.curdir, "checkpoints", model_name)
    os.makedirs(model_dir, exist_ok=True)
    model_paths = []
    epoch_losses = []

    model.train()
    edge_index = next(iter(train_dataset)).edge_index.to(device)
    for epoch in tqdm(range(config["epochs"])):
        loss = 0
        for snapshot in train_dataset:
            y = snapshot.y.float().to(device)
            x = transform.min_max_scale(snapshot.x).float().to(device)
            y_hat = model(x, edge_index).unsqueeze(1)
            loss = loss + criterion(y_hat, y)
            #loss = loss + custom_bce_loss(y_hat, y,distances_normalized)
        loss = loss / train_dataset.snapshot_count

        # Checkpoint BEFORE stepping so the stored weights are the ones that produced
        # `loss`. Saving after optimizer.step() paired each loss with the next epoch's
        # weights, so "best" selection loaded weights whose loss was never measured.
        model_path = os.path.join(model_dir, "epoch" + str(epoch) + ".pth")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss.detach(),
        }, model_path)
        model_paths.append(model_path)
        epoch_losses.append(loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print("Epoch {} train BCE: {:.4f}".format(epoch, loss.item()))

    if not model_paths:
        return model

    # First minimum wins (same tie-break as a strict `<` scan); losses are kept in
    # memory, so no checkpoint has to be reloaded just to compare losses.
    best_index = min(range(len(epoch_losses)), key=epoch_losses.__getitem__)
    best_model_path = model_paths[best_index]
    for path in model_paths:
        if path != best_model_path:
            os.remove(path)

    print("best_model_path: ", best_model_path)
    best_checkpoint = torch.load(best_model_path)
    model.load_state_dict(best_checkpoint["model_state_dict"])

    return model



# Test BCE: 0.0796
#model = train_loop(config, model_name, train_set, criterion, node_features, periods)
#predictions, labels, input_trajs= test_loop(test_set, model, criterion)

#save_path = os.path.join(os.curdir, "models", model_name+ ".pth")
#torch.save(model, save_path)
#model = torch.load(save_path)

In [None]:
# experiment: grid search over recurrent cell types and hyperparameters. Every run's
# config and test metrics are merged into `filename`; the model is saved to ./models/.
import itertools

filename = "experiment_results3.json"
criterion = torch.nn.BCELoss()

linear_dims = [256] #[2**i for i in range(6,6+3)]
out_channels = [256] #[2**i for i in range(6,6+3)]
lrs = [0.01] #[10**(-i) for i in range(1, 1+4)]
epochs = [40]
rgcls = ["DCRNN", "TGCN", "GConvGRU"]

# Starts a fresh results file: results from a previous run of this cell are overwritten.
create_json_file(filename)
# torch.save does not create missing directories.
models_dir = os.path.join(os.curdir, "models")
os.makedirs(models_dir, exist_ok=True)

# Same order as nested loops over rgcls > linear_dims > out_channels > lrs > epochs.
for rgcl, ld, oc, lr, ep in itertools.product(rgcls, linear_dims, out_channels, lrs, epochs):
    name = f"model_rgcl_{rgcl}_ld{ld}_oc{oc}_lr{lr}_ep{ep}_ds{dataset_name}"

    config = {
        "linear_dim" : ld,
        "out_channels": oc,
        "lr" : lr,
        "epochs" : ep,
        "rgcl": rgcl,
        "ds": dataset_name
        }

    model = train_loop(config, name, train_set, criterion, node_features, periods)
    predictions, labels, input_trajs = test_loop(test_set, model, criterion)
    # calculate metrics
    labelss = torch.stack(labels)
    predictionss = torch.stack(predictions)

    # classification metrics
    best_threshold, best_fscore, auc_score = get_threshold_and_fscore(labels, predictions)
    bce = criterion(predictionss, labelss)

    # mobility metrics
    predictions_01 = (predictionss >= best_threshold).int()
    nvd = get_normalized_visited_difference(labelss, predictions_01)
    # NOTE(review): IoU uses raw probabilities while nvd uses thresholded
    # predictions — confirm which is intended.
    iou = get_iou(labelss, predictionss)

    result = {"bce": bce.item(),
              "fscore": best_fscore.item(),
              "auc": auc_score.item(),
              "nvd": nvd.item(),
              "iou": iou.item()}

    # Do not rebind the global `dataset` here (it holds the loaded dataset that other
    # cells use; it was previously reset to {} on every iteration).
    add_data_to_json(filename, {name: config | result})
    torch.save(model, os.path.join(models_dir, name + ".pth"))
                



In [None]:
# Threshold / F1 / PR-AUC of the last run of the experiment loop (labels, predictions are its globals).
best_threshold, best_fscore, auc_score = get_threshold_and_fscore(labels, predictions)
print('Best Threshold=%f, F-Score=%.3f, AUC = %f' % (best_threshold, best_fscore, auc_score))

In [None]:
# Precision-recall curve of the last experiment run, with the F1-optimal point and the
# no-skill baseline (fraction of positive labels). The arrays are recomputed here because
# labels_arr_flatt / precision / recall / ix only existed inside get_threshold_and_fscore.
labels_arr_flatt = torch.stack(labels).flatten().detach().numpy()
predictions_arr_flatt = torch.stack(predictions).flatten().detach().numpy()
precision, recall, thresholds = precision_recall_curve(labels_arr_flatt, predictions_arr_flatt)
fscore = (2 * precision * recall) / (precision + recall + 1e-10)
ix = np.argmax(fscore)

no_skill = (labels_arr_flatt == 1).mean()

fig, ax = plt.subplots()
ax.plot([0, 1], [no_skill, no_skill], linestyle='--', label='No Skill')
ax.plot(recall, precision, marker='.', label='Model')
ax.scatter(recall[ix], precision[ix], marker='o', color='black', label='Best')
# Recall is the true-positive rate; precision is the positive predictive value.
ax.set(xlabel='Recall / TPR', ylabel='Precision / PPV', title='Precision-recall curve')
ax.legend()
plt.show()

In [None]:
# Pick one test snapshot (index 3) to inspect in the following cells.
exampe_id = 3
label = labels[exampe_id]
prediction = predictions[exampe_id]
input_traj = input_trajs[exampe_id]

In [None]:
# Stack all test labels into one tensor to inspect its shape.
labelst = torch.stack(labels)
labelst.shape

In [None]:
# Per-hex IoU of visit counts between labels and thresholded predictions over the whole test set.
# NOTE(review): `my_iou` was not defined anywhere in the notebook (NameError on a fresh kernel);
# this reconstruction (min / max of per-node visit counts, the same axes as get_iou) is an
# assumption about the intended metric — confirm.
labels_stacked = torch.stack(labels).float()
predictions_binary = (torch.stack(predictions) >= best_threshold).float()
label_visits = labels_stacked.sum(dim=(0, 2, 3))
prediction_visits = predictions_binary.sum(dim=(0, 2, 3))
# 0/0 -> NaN for hexes visited in neither, which explore() renders as missing.
my_iou = torch.minimum(label_visits, prediction_visits) / torch.maximum(label_visits, prediction_visits)

visualize_df = pd.read_pickle("../data/regions_enhanced.pkl")

visualize_df["my_iou"] = {loader.ids_to_hex_names[i]: v.item() for i,v in enumerate(my_iou)}

visualize_df.explore("my_iou")

In [None]:
# Map of the example's input "visited" channel. input_traj is (nodes, T) (x[:, 6, :] in
# test_loop); re-insert the feature axis in a new name so re-running the cell does not
# keep adding dimensions to input_traj.
input_traj_3d = input_traj.unsqueeze(1)
visualize_y(input_traj_3d, loader.ids_to_hex_names)

In [None]:
# Ground truth vs. thresholded prediction for the chosen example; visualize_ab takes
# (label, prediction, ...), so the label goes first.
visualize_ab(label.detach(), (prediction.detach() >= best_threshold).int(), loader.ids_to_hex_names)