In [1]:
import torch
import torch.nn as nn

import wget as wget
from tqdm.notebook import tqdm

!pip install ipywidgets
!jupyter nbextension enable --py widgetsnbextension

print("PyTorch has version {}".format(torch.__version__))

usage: jupyter [-h] [--version] [--config-dir] [--data-dir] [--runtime-dir]
               [--paths] [--json] [--debug]
               [subcommand]

Jupyter: Interactive Computing

positional arguments:
  subcommand     the subcommand to launch

optional arguments:
  -h, --help     show this help message and exit
  --version      show the versions of core jupyter packages and exit
  --config-dir   show Jupyter config dir
  --data-dir     show Jupyter data dir
  --runtime-dir  show Jupyter runtime dir
  --paths        show all Jupyter paths. Add --json for machine-readable
                 format.
  --json         output paths as machine-readable json
  --debug        output debug information about paths

Available subcommands: bundlerextension-script.py bundlerextension.exe
console-script.py console.exe kernel kernel.exe kernelspec kernelspec.exe lab-
script.py lab.exe labextension-script.py labextension.exe labhub-script.py
labhub.exe migrate migrate.exe nbclassic-script.py nbclassic.

In [2]:
#!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
#!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.13.1+cu116.html
#!pip install torch-geometric
#!pip install ogb
#!pip install wget

In [3]:
from bs4 import BeautifulSoup as bs
from urllib.request import urlopen
import os


def get_graphs(url, folder):
  """
  Inspired by: https://python.plainenglish.io/notesdownloader-use-web-scraping-to-download-all-pdfs-with-python-511ea9f55e48
  """
  if not os.path.exists(folder):
    os.makedirs(folder)

  links = []
  html = urlopen(url).read()
  html_page = bs(html, features="lxml") 

  for link in html_page.find_all('a'):
    current_link = link.get('href')

    if current_link.endswith('gz'):
      links.append(url + current_link)

  for link in tqdm(links):
    try:
      wget.download(link, out = folder)
    except:
      print("Unable to Download A File")
  print("File download done!")

In [4]:
import gzip
import shutil
from pathlib import Path


def uncompress_files(path, outpath):
  """
  Inspired by https://stackoverflow.com/questions/3548673/how-can-i-replace-or-strip-an-extension-from-a-filename-in-python
  and https://stackoverflow.com/questions/31028815/how-to-unzip-gz-file-using-python
  """
  if not os.path.exists(outpath):
    os.makedirs(outpath)
  for file in tqdm(os.listdir(path)):
    if os.path.isdir(os.path.join(path, file)):
      continue
    new_filename = Path(file)
    extensions = "".join(new_filename.suffixes)
    new_ext = ".txt"
    new_filename = str(new_filename).replace(extensions, new_ext)
    new_filename = os.path.join(outpath, new_filename)

    with gzip.open(os.path.join(path, file), 'rb') as f_in:
        with open(new_filename, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)

In [5]:
#get_graphs("http://lime.cs.elte.hu/~kpeter/data/mcf/netgen/", './content/data/netgen')

In [6]:
get_graphs("http://lime.cs.elte.hu/~kpeter/data/mcf/road/", './data/road/compressed')

  0%|          | 0/70 [00:00<?, ?it/s]

Unable to Download A File


In [5]:
uncompress_files('./data/road/compressed', './resources')

  0%|          | 0/40 [00:00<?, ?it/s]

In [10]:
from min_cost_flow import Network, successive_shortest_paths, primal_value
from data_parsing import parse, build_networkx, build_pyg, build_network
import networkx as nx

In [11]:
def min_cost_flow(nodes, edges, flow_alg):
    if flow_alg == 'nx':
        G = build_networkx(nodes, edges)
        opt = nx.min_cost_flow_cost(G)
    if flow_alg == 'cbn':
        N = build_network(nodes, edges)
        _, _, _, opt = successive_shortest_paths(N, iter_limit=150)
    return opt

def process(filename, flow_alg):
    nodes, edges = parse(filename)
    if len(edges.keys()) <= 1e6:
        opt = min_cost_flow(nodes, edges, flow_alg)
        return build_pyg(nodes, edges, opt)
    else:
        return {"converged": False}

In [4]:
import os.path as osp

import torch
from torch_geometric.data import Data, Dataset, download_url


class MinCostDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """If these files are found in the raw directory, download is skipped"""
        return []

    @property
    def processed_file_names(self):
        """If these files are found in the processed directory, processing is skipped"""
        processed_files = []
        path = self.processed_dir
        for file in tqdm(os.listdir(path)):
            file_path = os.path.join(path, file)
            if not os.path.isdir(file_path) and not file == "pre_filter.pt" and not file == "pre_transform.pt":
                processed_files.append(file)

        return processed_files

    def download(self):
        pass

    def process(self):
        idx = 0
        path = self.raw_dir
        for file in tqdm(os.listdir(path)):
            print(file)
            file_path = os.path.join(path, file)
            if os.path.isdir(file_path):
                continue
            # Read data from `raw_path`.
            output = process(file_path)
            if output["converged"]:
                x = output["x"]
                edge_index = output["edge_index"]
                edge_attr = output["edge_attr"]
                y = output["y"]
                data = Data(x = x, edge_index = edge_index, edge_attr = edge_attr, y = y, filename = file_path)

                torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
                idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [5]:
dataset = MinCostDataset(root = "./data/")

  0%|          | 0/34 [00:00<?, ?it/s]

In [6]:
print(dataset)

  0%|          | 0/34 [00:00<?, ?it/s]

  0%|          | 0/34 [00:00<?, ?it/s]

MinCostDataset(32)


In [7]:
def i_messed_up(dataset, idx):
    data = dataset.get(idx)
    #data.edge_index = torch.reshape(data.edge_index, (2, -1))
    #data.edge_attr = torch.reshape(data.edge_attr, (-1, 2))
    #data.x = torch.reshape(data.x, (-1,1))
    data.x = data.x.type(torch.float32)
    torch.save(data, osp.join(dataset.processed_dir, f'data_{idx}.pt'))
for i in range(32):
    i_messed_up(dataset, i)