In [17]:
import requests
from dotenv import load_dotenv
from datetime import datetime

import random
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import from_networkx

In [3]:
# API URLs
# Root URL of the STAC API that the fetch functions in this notebook query
# (host: geoservice.dlr.de, path prefix /eoc/ogc/stac/v1).
BASE_URL_STAC = "https://geoservice.dlr.de/eoc/ogc/stac/v1"
# Alternative STAC endpoint (planetarycomputer.microsoft.com), currently disabled:
#BASE_URL_STAC = "https://planetarycomputer.microsoft.com/api/stac/v1"

In [18]:
def datetime_to_numeric(datetime_str):
    """Convert a UTC timestamp string (trailing ``Z``) to POSIX seconds.

    Parameters
    ----------
    datetime_str : str or None
        Timestamp such as ``"2021-03-04T05:06:07.123Z"`` or, without
        fractional seconds, ``"2021-03-04T05:06:07Z"``.

    Returns
    -------
    float or None
        Seconds since 1970-01-01T00:00:00Z, or ``None`` when the input is
        not a string or matches none of the supported formats.
    """
    epoch = datetime(1970, 1, 1)
    for fmt in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ"):
        try:
            parsed = datetime.strptime(datetime_str, fmt)
        except ValueError:
            # Wrong layout for this format; try the next one.
            continue
        except TypeError:
            # Not a string at all (e.g. None): no format can match.
            return None
        # ``parsed`` is naive although the "Z" suffix means UTC. Calling
        # parsed.timestamp() would interpret it in the machine's local
        # timezone, so measure the distance to the (naive, UTC) epoch instead.
        return (parsed - epoch).total_seconds()
    return None

In [None]:
def get_all_items(collection_id, timeout=30):
    """Fetch every item of a STAC collection, following ``next`` page links.

    Parameters
    ----------
    collection_id : str
        Identifier of the collection, inserted into
        ``{BASE_URL_STAC}/collections/{collection_id}/items``.
    timeout : float or tuple, optional
        Timeout in seconds passed to ``requests.get`` for each page, so a
        stalled server cannot block the notebook forever.

    Returns
    -------
    list
        The ``features`` entries of all pages, in the order received.

    Raises
    ------
    requests.HTTPError
        If a page request returns an error status (via ``raise_for_status``).
    """
    items = []
    next_link = f"{BASE_URL_STAC}/collections/{collection_id}/items"
    visited = set()

    # Stop when there is no further page, or when the server points back to
    # a page that was already fetched (which would otherwise loop forever).
    while next_link and next_link not in visited:
        visited.add(next_link)
        response = requests.get(next_link, timeout=timeout)
        response.raise_for_status()
        data = response.json()
        # ``or []`` also covers keys that are present but null.
        items.extend(data.get("features") or [])
        # Pick the first link with rel == "next", if any, for pagination.
        next_link = next(
            (
                link.get("href")
                for link in (data.get("links") or [])
                if link.get("rel") == "next"
            ),
            None,
        )

    return items

def get_stac_collections(max_collections=2, timeout=30):
    """List STAC collections together with their items, sorted by datetime.

    Parameters
    ----------
    max_collections : int or None, optional
        Only the first ``max_collections`` collections returned by the API
        are processed (default 2, the limit previously hard-coded here).
        ``None`` processes all of them.
    timeout : float or tuple, optional
        Timeout in seconds for the request that lists the collections.

    Returns
    -------
    list of dict
        One ``{"id": ..., "items": [...]}`` entry per collection. Each item
        is ``{"id": ..., "datetime": ...}`` where ``datetime`` is the value
        returned by ``datetime_to_numeric`` (``None`` if it cannot be parsed).

    Raises
    ------
    requests.HTTPError
        If the collections request returns an error status.
    """
    response_col = requests.get(f"{BASE_URL_STAC}/collections", timeout=timeout)
    response_col.raise_for_status()
    collections = response_col.json()["collections"][:max_collections]

    collections_with_items = []
    for col in collections:
        col_id = col.get("id")
        items = get_all_items(col_id)
        print(len(items))
        items_filtered = [
            {
                "id": itm.get("id"),
                # ``or {}`` keeps items with a null "properties" from crashing.
                "datetime": datetime_to_numeric(
                    (itm.get("properties") or {}).get("datetime")
                ),
            }
            for itm in items
        ]
        # Sort items by datetime; items without a usable datetime sort as 0.
        items_filtered.sort(
            key=lambda x: x["datetime"] if x["datetime"] is not None else 0
        )
        collections_with_items.append({
            "id": col_id,
            "items": items_filtered
        })
    return collections_with_items

# Fetch the collections and their items (this performs the HTTP requests),
# then show the resulting list as the cell output.
collections = get_stac_collections()
collections

In [None]:

def generate_one_graph(item_=None):
    """Build one random path graph with node prices and a binary graph label.

    Parameters
    ----------
    item_ : object, optional
        Currently unused placeholder. It defaults to ``None`` so the function
        can be called without arguments, which is how this notebook calls it.

    Returns
    -------
    G : networkx.Graph
        Path graph with 4 nodes; each node has an integer ``"price"``
        attribute drawn from ``random.randint(1, 50)``.
    data : object returned by ``from_networkx(G)``
        PyTorch Geometric representation of ``G`` with ``x`` set to the
        prices as a float tensor of shape ``[num_nodes, 1]`` and ``y`` set to
        a float tensor ``[[label]]``, where ``label`` is 1 if the prices sum
        to at least 80 and 0 otherwise.
    """
    #num_nodes = random.randint(2, 5)
    num_nodes = 4

    G = nx.path_graph(num_nodes) # see https://networkx.org/documentation/stable/reference/generators.html

    # add prices to each node
    for node in G.nodes():
        G.nodes[node]["price"] = random.randint(1, 50)

    ## convert the graph information to data ready to use by Pytorch Geometric
    data = from_networkx(G)

    # Collect the prices once; they feed both the node features and the label.
    prices = [G.nodes[n]["price"] for n in G.nodes()]

    # convert price attribute to node feature tensor [num_nodes, 1]
    data.x = torch.tensor([[price] for price in prices], dtype=torch.float)

    label = 1 if sum(prices) >= 80 else 0
    #label = 1 if statistics.mean(prices) >= 30 else 0
    #label = 1 if G.number_of_nodes() >= 4  else 0 # works more or less nicely

    data.y = torch.tensor([[label]], dtype=torch.float)

    return G, data

def visualize_graph(G, graph_label=None):
    """Plot ``G`` with nodes coloured and labelled by their ``"price"``.

    Parameters
    ----------
    G : networkx.Graph
        Graph whose nodes all carry a numeric ``"price"`` attribute.
    graph_label : object, optional
        If given, it is shown in the title as ``Graph Label: <graph_label>``;
        otherwise a generic title is used.

    Notes
    -----
    The layout is computed with ``nx.spring_layout(G, seed=42)`` and the
    figure is displayed with ``plt.show()``; nothing is returned.
    """
    pos = nx.spring_layout(G, seed=42)

    # Key everything by the actual node id so graphs whose nodes are not
    # numbered 0..n-1 are labelled correctly too.
    prices = {n: G.nodes[n]["price"] for n in G.nodes()}
    labels = {n: f"{price:.1f}" for n, price in prices.items()}
    node_colors = list(prices.values())

    # Use one value range for both the node colours and the colorbar;
    # without an explicit norm the colorbar would show 0..1, not prices.
    if node_colors:
        vmin, vmax = min(node_colors), max(node_colors)
    else:
        vmin, vmax = 0, 1
    norm = plt.Normalize(vmin=vmin, vmax=vmax)

    fig, ax = plt.subplots(figsize=(5, 4))

    nx.draw(
        G,
        pos,
        ax=ax,
        node_color=node_colors,
        cmap=plt.cm.coolwarm,
        vmin=vmin,
        vmax=vmax,
        node_size=800,
        edge_color="gray",
        with_labels=False
    )
    nx.draw_networkx_labels(G, pos, labels=labels, ax=ax)

    fig.colorbar(
        plt.cm.ScalarMappable(norm=norm, cmap=plt.cm.coolwarm),
        ax=ax,
        label="House Price (€)"
    )

    if graph_label is not None:
        ax.set_title(f"Graph Label: {graph_label}")
    else:
        # Neutral title: this function does not check the graph's topology.
        ax.set_title("Graph\n(Node label = price)")
    ax.set_axis_off()
    plt.show()


def generate_many_graphs(ngraph):
    """Collect a class-balanced set of graphs from ``generate_one_graph``.

    Graphs are generated one at a time and kept only while their class
    (``data.y`` equal to 1 or 0) still has room; generation stops once both
    classes hold ``int(ngraph / 2)`` graphs.

    Parameters
    ----------
    ngraph : int
        Requested total number of graphs; each class gets half of it,
        truncated towards zero.

    Returns
    -------
    graphs : list
        The kept graphs, in the order they were accepted.
    dataset : list
        The matching data objects, aligned index-by-index with ``graphs``.
    """
    per_class = int(ngraph / 2)
    kept = {1: 0, 0: 0}  # number of accepted graphs per label
    graphs, dataset = [], []

    while kept[1] < per_class or kept[0] < per_class:
        G, data = generate_one_graph()

        if data.y == 1:
            cls = 1
        elif data.y == 0:
            cls = 0
        else:
            continue

        # Discard the sample if its class is already full.
        if kept[cls] >= per_class:
            continue

        kept[cls] += 1
        graphs.append(G)
        dataset.append(data)

    return graphs, dataset


# Demo: build a single graph, print its nodes with their attributes and the
# converted data object, then plot it with its label in the title.
G, data = generate_one_graph()
print(G.nodes(data=True))
print(data)
visualize_graph(G,  graph_label=data.y.item())