Take a group of the 13 closest nodes and mask out the 3 nearest the centre:
1. select a random node
2. find the 12 closest nodes
3. figure out the centre
4. find 3 nodes closest to the centre

In [None]:
# weather stuff
import pandas as pd
# Weather columns pulled from an observation row, in the order get_weather
# returns them.
cols = ['rain', 'temp', 'rhum']

# Module-level slots for the three weather tables; they stay None until
# setup_weather_dfs() loads the CSV files.
w1 = w2 = w3 = None

def setup_weather_dfs():
    """Load the three weather CSV files into the globals w1, w2 and w3.

    After loading, a fixed number of leading rows is removed from each
    table (520000 for w1 and w3, 160000 for w2). Per the original author's
    note these rows are not needed and dropping them speeds things up.
    The remaining rows keep their original index labels.
    """
    global w1, w2, w3

    def without_leading_rows(frame, n_rows):
        # Drop the first n_rows rows by index label.
        return frame.drop(frame.index[:n_rows])

    w1 = pd.read_csv("./datasets/scats/2023/3m/w1.csv")
    w2 = pd.read_csv("./datasets/scats/2023/3m/w2.csv")
    w3 = pd.read_csv("./datasets/scats/2023/3m/w3.csv")

    # NOTE(review): the cut-off counts are hard-coded; presumably they mark
    # where the period of interest starts in each file -- confirm.
    w1 = without_leading_rows(w1, 520000)
    w2 = without_leading_rows(w2, 160000)
    w3 = without_leading_rows(w3, 520000)

# Populate w1, w2 and w3 once, up front, so the lookups below can use them.
setup_weather_dfs()

import geopy.distance as distance

def get_df_closest(lat, long):
    """Return the weather table of the station nearest to (lat, long).

    Args:
        lat: Latitude of the query point.
        long: Longitude of the query point.

    Returns:
        The module-level object w1, w2 or w3 belonging to the nearest of
        the three hard-coded station positions. These are whatever
        setup_weather_dfs() stored (None if it has not been run).

    Raises:
        ValueError: If no station compares closer than infinity, so no
            station could be selected.
    """
    station_positions = {
        "w1": (53.306, -6.439),
        "w2": (53.364, -6.350),
        "w3": (53.428, -6.241)
    }

    def metres_from_query(station_lat, station_long):
        # Distance in metres between the query point and one station.
        return distance.distance((lat, long), (station_lat, station_long)).m

    nearest_key = None
    nearest_dist = float("inf")
    for station, (station_lat, station_long) in station_positions.items():
        candidate = metres_from_query(station_lat, station_long)
        # Strict comparison: on a tie the earlier station wins.
        if candidate < nearest_dist:
            nearest_key, nearest_dist = station, candidate

    tables = {"w1": w1, "w2": w2, "w3": w3}
    if nearest_key not in tables:
        raise ValueError("Invalid closest")
    return tables[nearest_key]
    
def get_weather(month, day, hour, lat, long):
    """Return the weather readings nearest to a location at a given hour.

    The table of the closest weather station is searched for the row whose
    'date' column equals "DD-mon-2023 HH:00" (for example
    "05-oct-2023 14:00"). The values of the columns listed in `cols` are
    returned from the first matching row.

    Args:
        month: Month number; only 10, 11 and 12 are supported.
        day: Day of the month (zero-padded to two digits for the lookup).
        hour: Hour of the day (zero-padded to two digits for the lookup).
        lat: Latitude used to pick the nearest station.
        long: Longitude used to pick the nearest station.

    Returns:
        A list of floats, one per entry of `cols`, in that order.

    Raises:
        ValueError: If `month` is not 10, 11 or 12.
        IndexError: If the station table has no row for the requested
            date and hour.
    """
    month_names = {10: "oct", 11: "nov", 12: "dec"}

    # Named station_df so it does not shadow the module-level `df`.
    station_df = get_df_closest(lat, long)

    month_str = month_names.get(month)
    if month_str is None:
        raise ValueError("Invalid month")

    day_pad = str(day).zfill(2)
    hour_pad = str(hour).zfill(2)
    date = f"{day_pad}-{month_str}-2023 {hour_pad}:00"

    matches = station_df.loc[station_df['date'] == date]
    if matches.empty:
        # Same exception type that .iloc[0] would raise on an empty
        # selection, but with a message that names the missing timestamp.
        raise IndexError(f"No weather observation found for {date!r}")

    row = matches.iloc[0]
    return [float(value) for value in row[cols].values]

In [None]:
import pandas as pd
import random
from geopy import distance
import torch
import numpy as np
from datetime import datetime

In [None]:
# Load the processed traffic table. Later cells read its "Site", "Lat",
# "Long", "Time" and "Volume" columns.
df_path = "./datasets/scats/2023/3m/processed.csv"
df = pd.read_csv(df_path)
# Preview the first rows as the cell output.
df.head()

In [None]:
# Number of distinct sites; later cells treat it as the node count and
# iterate node ids with range(n_nodes).
n_nodes = df["Site"].nunique()
n_nodes
# nodes from 0 to 608
# NOTE(review): the rest of the notebook assumes the Site ids are exactly
# 0 .. n_nodes - 1 -- confirm against the data.

In [None]:
# get lat long pair for each site_id
# The result is indexed by site id, holds the first "Lat"/"Long" value per
# site, and backs the lat() and long() helpers.
site_lat_long = df.groupby("Site")[["Lat", "Long"]].first()
# Preview the first rows as the cell output.
site_lat_long.head()

### Preprocessing closest nodes

For each node, find and store the `n_closest` nodes nearest to it (the node itself is one of the candidates, at distance 0).

In [None]:
def get_dist_lat_long(lat1, long1, lat2, long2):
    """Return the distance in metres between two (lat, long) points."""
    start = (lat1, long1)
    end = (lat2, long2)
    separation = distance.distance(start, end)
    return separation.m

def lat(i):
    """Return the "Lat" value stored in site_lat_long for site id *i*."""
    latitude = site_lat_long.loc[i, "Lat"]
    return latitude

def long(i):
    """Return the "Long" value stored in site_lat_long for site id *i*."""
    longitude = site_lat_long.loc[i, "Long"]
    return longitude

In [None]:
n_closest = 13

# Read every node's coordinates once. Calling lat()/long() inside the
# pairwise loop would repeat the same table lookups n_nodes times per node.
node_coords = [(lat(k), long(k)) for k in range(n_nodes)]

# closest[i] holds the ids of the n_closest nodes nearest to node i, ordered
# by increasing distance (ties broken by node id). Node i is a candidate for
# its own list, since j ranges over every node.
closest = []
for i in range(n_nodes):
    lat_i, long_i = node_coords[i]
    cur_dists = sorted(
        (get_dist_lat_long(lat_i, long_i, lat_j, long_j), j)
        for j, (lat_j, long_j) in enumerate(node_coords)
    )
    # take the top n_closest
    closest.append([cur_dists[k][1] for k in range(n_closest)])

# Show one neighbour list as the cell output.
closest[2]

In [None]:
# Distance in metres from each node to the last entry of its neighbour list.
vals = [
    get_dist_lat_long(lat(i), long(i), lat(closest[i][-1]), long(closest[i][-1]))
    for i in range(n_nodes)
]
mean = sum(vals) / len(vals)
mean

# mean dist between node and its farthest closest node

### Find the centre

In [None]:
def centre_i_neighbours(i):
    """Return the (lat, long) centre of node *i*'s neighbour list.

    The centre is the plain average of the neighbours' latitudes and
    longitudes. The sums are divided by the global n_closest, so this
    relies on closest[i] holding exactly n_closest entries.
    """
    neighbours = closest[i]
    mean_lat = sum(map(lat, neighbours)) / n_closest
    mean_long = sum(map(long, neighbours)) / n_closest
    return mean_lat, mean_long

In [None]:
def get_closest_to_centre(i):
    """Return node *i*'s neighbours ordered by distance to their centre.

    The first id in the returned list is the neighbour closest to the
    centre computed by centre_i_neighbours(i). Equal distances are ordered
    by node id.
    """
    centre_lat, centre_long = centre_i_neighbours(i)
    ranked = sorted(
        (get_dist_lat_long(centre_lat, centre_long, lat(j), long(j)), j)
        for j in closest[i]
    )
    return [node_id for _, node_id in ranked]

In [None]:
def sample_neighbours(seed_node):
    """Return the neighbour group of *seed_node*, nearest-to-centre first.

    Thin wrapper around get_closest_to_centre.
    """
    ordered_neighbours = get_closest_to_centre(seed_node)
    return ordered_neighbours

### Random generation

In [None]:
# select random node from 0 to n_nodes
node = random.randint(0, n_nodes)
display(node, sample_neighbours(node))

### sample a time and node

In [None]:
# Number of nodes masked per sample: sample() treats the first n_mask
# neighbours (the ones nearest the centre) as the hidden ones.
n_mask = 3

In [None]:
def get_volume(node, time):
    """Return the volume of *node* at *time*, expressed in units of 25.

    The first row of `df` whose "Site" equals *node* and whose "Time"
    equals *time* supplies the raw "Volume". The result is
    round(volume / 25), i.e. the number of 25-sized steps, not the volume
    rounded to a multiple of 25.

    Raises:
        IndexError: If no row matches the node and time.
    """
    matches = (df["Site"] == node) & (df["Time"] == time)
    raw_volume = df[matches]["Volume"].values[0]
    return round(raw_volume / 25)

In [None]:
def sample():
    """Draw one random (time, neighbour group) sample.

    A random "Time" value from `df` and a random seed node are chosen. The
    seed's neighbour group is ordered nearest-to-centre first, and the
    first n_mask nodes of that order are treated as masked.

    Returns:
        A tuple (time, neighbours, missing_sum, rem_vols, og_vols) where
        `neighbours` is the ordered list of node ids, `og_vols` are the
        volumes of the masked nodes, `missing_sum` is their sum and
        `rem_vols` are the volumes of the remaining nodes. Returns None
        when the sample cannot be built from the data.
    """
    time = df["Time"].sample().values[0]
    # randrange excludes n_nodes; random.randint(0, n_nodes) could return
    # n_nodes, which is out of range for `closest`.
    node = random.randrange(n_nodes)

    try:
        neighbours = sample_neighbours(node)
        vols = [get_volume(neighbour, time) for neighbour in neighbours]
    except (IndexError, KeyError, ValueError):
        # Best effort: skip samples with missing or unusable data (for
        # example no volume row for a node at this time). A bare except
        # would also swallow KeyboardInterrupt and make retry loops
        # impossible to stop.
        return None

    # miss the first n_mask neighbours
    og_vols = vols[:n_mask]
    rem_vols = vols[n_mask:]
    missing_sum = sum(og_vols)

    return time, neighbours, missing_sum, rem_vols, og_vols

# Quick check: draw one sample and show it as the cell output.
sample()

In [None]:
def get_time_vector(time):
    """Split a "YYYY-MM-DD HH:MM:SS" string into (month, day, hour)."""
    parsed = datetime.strptime(time, '%Y-%m-%d %H:%M:%S')
    return parsed.month, parsed.day, parsed.hour

In [None]:
def sample_n(n):
    """Collect *n* successful samples, retrying whenever sample() fails.

    Falsy results from sample() (it returns None on failure) are skipped.
    The loop keeps going until *n* samples have been gathered, so it does
    not terminate if sample() never succeeds. For n <= 0 an empty list is
    returned without sampling.
    """
    collected = []
    while len(collected) < n:
        candidate = sample()
        if not candidate:
            continue
        collected.append(candidate)
    return collected

### creating a dataset

In [None]:
# Intended dataset sizes.
# NOTE(review): neither name is referenced by the visible dataset-building
# cells, which pass 100 directly -- confirm which sizes are intended.
n_train = 1000
n_test = 100

In [None]:
# Load the saved embeddings onto the CPU; generate_dataset indexes this
# object by node id (embeddings[i]).
# NOTE(review): torch.load unpickles the file -- only load files from a
# trusted source.
embeddings_path = "./datasets/scats/2023/3m/graph_emb_3m.pt"
embeddings = torch.load(embeddings_path, map_location=torch.device('cpu'))

In [None]:
def encode_time_to_feature(time):
    """Encode a "YYYY-MM-DD HH:MM:SS" string as a 2-element float32 tensor.

    Returns:
        torch.Tensor: [hour, weekend_flag], where weekend_flag is 1 for
        Saturday and Sunday and 0 otherwise.
    """
    moment = datetime.strptime(time, "%Y-%m-%d %H:%M:%S")

    # weekday() counts Monday as 0 and Sunday as 6, so 5 and 6 are the weekend.
    is_weekend = int(moment.weekday() >= 5)

    return torch.tensor([moment.hour, is_weekend], dtype=torch.float32)

In [None]:
def generate_dataset(count):
    """Build feature and target tensors from *count* random samples.

    Each feature row is the concatenation, in this order, of:
      * the time encoding from encode_time_to_feature,
      * the embedding of every node in the sampled group,
      * the weather at the location of the group's first node (the one
        nearest the group's centre),
      * the sum of the masked nodes' volumes,
      * the volumes of the unmasked nodes.

    Each target row holds the individual volumes of the masked nodes.

    Returns:
        (X, Y): the stacked feature rows and the stacked target rows.
    """
    feature_rows = []
    target_rows = []
    for timestamp, nodes, sum_miss, rem_vols, og_vols in sample_n(count):
        centre_node = nodes[0]

        pieces = [encode_time_to_feature(timestamp)]
        pieces += [embeddings[i] for i in nodes]
        pieces += [
            torch.tensor(
                get_weather(
                    *get_time_vector(timestamp),
                    lat(centre_node),
                    long(centre_node),
                )
            ),
            torch.tensor([sum_miss]),
            torch.tensor(rem_vols),
        ]

        # flatten to 1 dim
        feature_rows.append(torch.cat(pieces).view(-1))
        target_rows.append(torch.tensor(og_vols))

    return torch.stack(feature_rows), torch.stack(target_rows)

In [None]:
# Build a 100-sample training set and show the shapes of both tensors.
trainX, trainY = generate_dataset(100)
display(trainX.shape, trainY.shape)

In [None]:
# save tensors
# Features and targets are written to separate files.
torch.save(trainX, "./datasets/scats/2023/3m/train3X100.pt")
torch.save(trainY, "./datasets/scats/2023/3m/train3Y100.pt")

In [None]:
# Build and save a 100-sample test set.
# NOTE(review): the test samples come from the same random sampler as the
# training samples, so the same (time, node) draw can appear in both sets.
testX, testY = generate_dataset(100)
torch.save(testX, "./datasets/scats/2023/3m/test3X100.pt")
torch.save(testY, "./datasets/scats/2023/3m/test3Y100.pt")