In [None]:
import pandas as pd
import random
from geopy import distance
import torch
import numpy as np
from datetime import datetime
import utils

In [None]:
# Weather data: hourly observations from three location-specific CSVs (w1-w3) plus lookup helpers.
import pandas as pd
# Feature columns read from each weather row by get_weather().
cols = ["rain", "temp", "rhum"]

# Weather DataFrames for the three locations; populated by setup_weather_dfs(), None until then.
w1 = w2 = w3 = None

def setup_weather_dfs(base_dir="./datasets/scats/2023/3m", skip_rows=None):
    """Load the three weather CSVs into the module-level ``w1``, ``w2`` and ``w3``.

    The leading rows of each file are dropped (original index labels are kept)
    so the frames that ``get_weather`` scans are smaller.

    Args:
        base_dir: directory containing ``w1.csv``, ``w2.csv`` and ``w3.csv``.
        skip_rows: optional mapping ``name -> number of leading rows to drop``;
            names not given fall back to the defaults below.

    Returns:
        Tuple ``(w1, w2, w3)`` of the loaded frames (also stored as globals).
    """
    global w1, w2, w3
    # NOTE(review): presumably these leading rows fall outside the period that
    # get_weather queries — re-check the counts if the CSVs change.
    default_skip = {"w1": 520000, "w2": 160000, "w3": 520000}
    skip = {**default_skip, **(skip_rows or {})}

    def _load(name):
        # Read one CSV and drop its first skip[name] rows by label.
        frame = pd.read_csv(f"{base_dir}/{name}.csv")
        return frame.drop(frame.index[:skip[name]])

    w1, w2, w3 = _load("w1"), _load("w2"), _load("w3")
    return w1, w2, w3

# Load the weather CSVs once when this cell runs, so get_weather() can query w1/w2/w3.
setup_weather_dfs()

import geopy.distance as distance

def get_df_closest(lat, long):
    """Return whichever weather frame (w1, w2 or w3) has the reference point nearest to (lat, long).

    Distances are geopy distances in metres. On a tie the earlier entry
    (w1, then w2, then w3) wins, because only a strictly smaller distance
    replaces the current best.

    Raises:
        ValueError: no distance compared smaller than infinity (e.g. NaN distances).
    """
    anchors = (
        ("w1", (53.306, -6.439)),
        ("w2", (53.364, -6.350)),
        ("w3", (53.428, -6.241)),
    )
    best_key, best_dist = None, float("inf")
    for key, (anchor_lat, anchor_long) in anchors:
        d = distance.distance((lat, long), (anchor_lat, anchor_long)).m
        if d < best_dist:
            best_key, best_dist = key, d
    # Look the globals up at call time so the frames loaded by setup_weather_dfs are used.
    frames = {"w1": w1, "w2": w2, "w3": w3}
    if best_key not in frames:
        raise ValueError("Invalid closest")
    return frames[best_key]
    
def get_weather(month, day, hour, lat, long, year=2023):
    """Return the ``cols`` values (rain, temp, rhum) for one hour from the nearest weather frame.

    Args:
        month: 10, 11 or 12 — only Oct/Nov/Dec are mapped.
        day: day of month (zero-padded to two digits for the lookup).
        hour: hour of day, matched against the ``HH:00`` part of the ``date`` column.
        lat, long: location used by ``get_df_closest`` to choose w1/w2/w3.
        year: year in the timestamp; defaults to 2023, the dataset year.

    Returns:
        List of floats, one per name in ``cols``.

    Raises:
        ValueError: ``month`` is not 10, 11 or 12 (or propagated from ``get_df_closest``).
        IndexError: no row has the requested timestamp.
    """
    df = get_df_closest(lat, long)
    month_str = {10: "oct", 11: "nov", 12: "dec"}.get(month)
    if month_str is None:
        raise ValueError("Invalid month")
    # Timestamps in the CSV look like "05-oct-2023 07:00".
    date = f"{str(day).zfill(2)}-{month_str}-{year} {str(hour).zfill(2)}:00"
    matches = df.loc[df['date'] == date]
    if matches.empty:
        # Same exception type as the old bare ``iloc[0]``, but with a useful message.
        raise IndexError(f"No weather row for {date!r}")
    return [float(x) for x in matches.iloc[0][cols].values]

In [None]:
import pandas as pd
import random
from geopy import distance
import torch
import numpy as np
from datetime import datetime

class Handler:
    """Access layer for the processed SCATS traffic data.

    Construction reads ``processed.csv``, finds each node's ``prox`` nearest
    nodes, re-orders every such neighbourhood by distance to its mean lat/long,
    keeps a copy of the data with ``Time`` parsed by ``pd.to_datetime``, and
    loads the precomputed graph embeddings.

    NOTE(review): node ids ``0..n_nodes-1`` are used directly as ``Site`` labels
    (``site_lat_long.loc[i]`` and ``df['Site'] == node``), which assumes sites
    are numbered contiguously from 0 — confirm against the data.
    """

    def __init__(self):
        self.n_nodes = 609
        self.df_path = "./datasets/scats/2023/3m/processed.csv"
        self.df = pd.read_csv(self.df_path)
        # Neighbourhood size used when building ``closest`` below.
        self.prox = 20
        # One (Lat, Long) row per Site (first occurrence), indexed by Site.
        self.site_lat_long = self.df.groupby("Site")[["Lat", "Long"]].first()

        # closest[i]: ids of the ``prox`` nodes nearest to node i, nearest first.
        self.closest = self._compute_closest(self.prox)
        # Mean (lat, long) of each node's neighbourhood.
        self.centre_of_i_neighbours = [
            self._neighbour_centre(self.closest[i]) for i in range(self.n_nodes)
        ]
        # Same neighbourhoods, re-ordered by distance to their centre.
        self.closest_ct_close = [
            self._order_by_distance_to(self.centre_of_i_neighbours[i], self.closest[i])
            for i in range(self.n_nodes)
        ]

        self.df_time_type = self._with_parsed_time(self.df)

        self.load_graph_emb()

    def _compute_closest(self, n_closest):
        """For each node, the ids of its ``n_closest`` nearest nodes by geopy distance.

        The node itself (distance 0) is normally first; equal distances are
        ordered by node id.
        """
        # Look every node's coordinates up once instead of inside the inner loop.
        coords = [(self.get_lat_node_id(i), self.get_long_node_id(i)) for i in range(self.n_nodes)]
        closest = []
        for lat_i, long_i in coords:
            # Sorting (distance, id) pairs breaks distance ties by the smaller id.
            ranked = sorted(
                (self.get_dist_lat_long(lat_i, long_i, lat_j, long_j), j)
                for j, (lat_j, long_j) in enumerate(coords)
            )
            # Take the top n_closest.
            closest.append([j for _, j in ranked[:n_closest]])
        return closest

    def _neighbour_centre(self, neighbours):
        """Arithmetic mean (lat, long) of the given node ids."""
        lt = sum(self.get_lat_node_id(j) for j in neighbours) / len(neighbours)
        lg = sum(self.get_long_node_id(j) for j in neighbours) / len(neighbours)
        return lt, lg

    def _order_by_distance_to(self, centre, neighbours):
        """``neighbours`` sorted by distance to ``centre`` (ties broken by id)."""
        ranked = sorted(
            (self.get_dist_lat_long(centre[0], centre[1], self.get_lat_node_id(j), self.get_long_node_id(j)), j)
            for j in neighbours
        )
        return [j for _, j in ranked]

    @staticmethod
    def _with_parsed_time(df):
        """Copy of ``df`` with its ``Time`` column converted by ``pd.to_datetime``."""
        parsed = df.copy()
        parsed['Time'] = pd.to_datetime(df['Time'])
        return parsed

    def sample(self):
        """Return the centre-ordered neighbourhood of a uniformly random node."""
        # random.randint includes both bounds, so n_nodes - 1 is the largest valid index.
        i = random.randint(0, self.n_nodes - 1)
        return self.closest_ct_close[i]

    def regen(self):
        """Re-read the processed CSV and rebuild the parsed-time copy used for sampling."""
        self.df = pd.read_csv(self.df_path)
        # sample_time/get_series read df_time_type, so keep it in sync with df.
        self.df_time_type = self._with_parsed_time(self.df)

    def set_prox(self, prox):
        """Set ``self.prox``.

        NOTE: the neighbourhood lists are built in ``__init__`` only; changing
        ``prox`` afterwards does not rebuild ``closest``/``closest_ct_close``.
        """
        self.prox = prox

    def get_dist_lat_long(self, lat1, long1, lat2, long2):
        """Distance in metres between two (lat, long) points via ``geopy.distance.distance``."""
        return distance.distance((lat1, long1), (lat2, long2)).m

    def get_lat_node_id(self, i):
        """Latitude of site label ``i``."""
        return self.site_lat_long.loc[i, "Lat"]

    def get_long_node_id(self, i):
        """Longitude of site label ``i``."""
        return self.site_lat_long.loc[i, "Long"]

    def sample_node(self):
        """Uniformly random node id in ``[0, n_nodes - 1]``."""
        return random.randint(0, self.n_nodes - 1)

    def sample_time(self):
        """Random value from the parsed ``Time`` column (times with more rows are likelier)."""
        return random.choice(self.df_time_type['Time'])

    def get_series(self, node, time, len):
        """Hourly ``Volume`` values for ``node`` over ``len`` hours starting at ``time``.

        Hours with no matching row contribute 0; if several rows share a
        timestamp, the first one is used. ``time`` is presumably a pandas
        Timestamp (as returned by ``sample_time``).
        """
        # ``len`` shadows the builtin here; the name is kept for caller compatibility.
        site = self.df_time_type[self.df_time_type['Site'] == node]
        vols = []
        for _ in range(len):
            vol = site.loc[site['Time'] == time, 'Volume']
            vols.append(0 if vol.empty else vol.iloc[0])
            time += pd.Timedelta(hours=1)
        return vols

    def load_graph_emb(self):
        """Load the precomputed graph embeddings onto the CPU into ``self.embeddings``."""
        emb_path = "./datasets/scats/2023/3m/graph_emb_3m.pt"
        # torch.load unpickles the file; only load trusted embedding files.
        self.embeddings = torch.load(emb_path, map_location=torch.device('cpu'))

    def get_gemb(self, i):
        """Embedding entry ``i`` (indexed by node id)."""
        return self.embeddings[i]

    def text_time_to_datetime(self, text_time):
        """Parse ``"YYYY-MM-DD HH:MM:SS"`` into a ``datetime``."""
        time_obj = datetime.strptime(text_time, "%Y-%m-%d %H:%M:%S")
        return time_obj

    def datetime_to_text_time(self, time_obj):
        """Format a datetime as ``"YYYY-MM-DD HH:MM:SS"``."""
        return time_obj.strftime("%Y-%m-%d %H:%M:%S")

    def encode_time(self, time_obj):
        """Encode a datetime as a float32 tensor ``[hour, day, month, is_weekend]``.

        ``is_weekend`` is 1 for Saturday/Sunday (``weekday() >= 5``), otherwise 0.
        """
        hour = time_obj.hour
        month = time_obj.month
        day = time_obj.day
        weekday_or_weekend = 0 if time_obj.weekday() < 5 else 1
        features_tensor = torch.tensor([hour, day, month, weekday_or_weekend], dtype=torch.float32)
        return features_tensor


In [None]:
# Build the shared Handler (reads processed.csv, precomputes neighbourhoods, loads
# graph embeddings); feat_one/sample_one below use this global `h`.
h = Handler()
h.sample_node()  # smoke test: a random node id in [0, n_nodes - 1]

In [None]:
# Fraction of values to mask as missing. Not applied in these cells; per the
# final cell, masking happens at training time.
miss_rate = .1

In [None]:
def feat_one(node, time, x_len=3, q_len=4):
    """Build one (features, target) pair for ``node`` starting at ``time``.

    Reads ``q_len`` hourly volumes from the global Handler ``h``. Features are
    the node's graph embedding, the first ``x_len`` volumes and the encoded
    time, concatenated in that order. The target is the remaining
    ``q_len - x_len`` volumes.

    Returns:
        ``(xt, yt)`` as 1-D float32 tensors (the embedding is assumed 1-D — TODO confirm).
    """
    emb = h.get_gemb(node)
    time_ft = h.encode_time(time)
    series = h.get_series(node, time, q_len)
    # Leading values are model inputs; the rest are the prediction targets.
    xval, yval = series[:x_len], series[x_len:]

    xt = torch.cat([emb, torch.tensor(xval, dtype=torch.float32), time_ft])
    yt = torch.tensor(yval, dtype=torch.float32)
    return xt, yt

In [None]:
def sample_one():
    """Draw one example: a random node's centre-ordered neighbourhood at a random time.

    Returns:
        ``(X, Y)``. ``X`` stacks the ``feat_one`` feature vectors of every node
        in ``h.closest_ct_close[node]`` (one row per neighbour). ``Y``
        concatenates their targets into a single flat tensor.
    """
    node, time = h.sample_node(), h.sample_time()
    neighbours = h.closest_ct_close[node]
    # The window sizes live in feat_one; every neighbour shares the same timestamp.
    feats = [feat_one(i, time) for i in neighbours]
    X = torch.stack([x for x, _ in feats])
    Y = torch.cat([y for _, y in feats])
    return X, Y

In [None]:
# Sanity check: a is (neighbourhood size, feature dim); b is the flattened targets.
a, b = sample_one()
display(a.shape, b.shape)

In [None]:
def sample_n(n):
    """Draw ``n`` independent examples with ``sample_one`` and batch them.

    Returns:
        ``(X, Y)``, each stacked along a new leading batch dimension of size ``n``.
    """
    draws = [sample_one() for _ in range(n)]
    batch_x = torch.stack([pair[0] for pair in draws])
    batch_y = torch.stack([pair[1] for pair in draws])
    return batch_x, batch_y

In [None]:
# Small batch to check batched shapes before generating the full datasets.
n = 100
X,Y = sample_n(n)

In [None]:
# Expect X: (n, neighbourhood size, feature dim); Y: (n, neighbourhood size * target length).
display(X.shape, Y.shape)

In [None]:
# Number of randomly sampled examples for the train and test tensors.
ntrain = 1000
ntest = 100

In [None]:
# Train and test are independent draws from the same random sampler (no seed is set
# in these cells), so examples can repeat across the two sets.
# NOTE(review): consider seeding and/or a disjoint split for a reproducible evaluation.
trainX, trainY = sample_n(ntrain)
testX, testY = sample_n(ntest)

In [None]:
# Persist the generated tensors next to the source data.
# NOTE(review): the suffixes hard-code "1k" and do not follow ntrain/ntest, and the
# test files reuse the "tr1" prefix ("tr1tx"/"tr1ty") — rename with care, since loaders
# elsewhere may depend on these names.
base = "./datasets/scats/2023/3m/t2"
torch.save(trainX, f"{base}_tr1kx.pt")
torch.save(trainY, f"{base}_tr1ky.pt")
torch.save(testX, f"{base}_tr1tx.pt")
torch.save(testY, f"{base}_tr1ty.pt")

In [None]:
# Tensors are saved without masking; missing-value masking (miss_rate) is applied during training.