# Flow Theory

# IMPORTS

In [None]:
from typing import List, Tuple
import math
from collections import Counter, deque
import urllib.request
import time

# UTILS

In [123]:
def conv_seconds_milliseconds(seconds: float) -> float:
    """Convert a duration given in seconds into milliseconds."""
    milliseconds = seconds * 1000
    return milliseconds

def duration(start: float, end: float) -> float:
    """Return the elapsed time between two timestamps, i.e. ``end - start``."""
    elapsed = end - start
    return elapsed


# Max Flow

## [Download Speed](https://cses.fi/problemset/task/1694)

In [118]:
class MaxFlow:
    """Maximum source-to-sink flow on a directed graph with integer capacities.

    Two augmenting-path strategies are provided:

    * ``main_dfs`` -- Ford-Fulkerson, finding each augmenting path by DFS.
    * ``main_edmonds_karp`` -- Edmonds-Karp, finding each augmenting path by BFS.

    Both rebuild the residual graph from ``edges`` before running, so they can
    be called repeatedly (and interchangeably) on the same instance.

    Nodes are integers in ``range(n)``; edges are ``(u, v, capacity)`` triples.
    Parallel edges are merged by summing their capacities.
    """

    def __init__(self, n: int, edges: List[Tuple[int, int, int]]):
        self.size = n
        self.edges = edges

    def build(self, n: int, edges: List[Tuple[int, int, int]]) -> None:
        """Build the residual graph ``self.adj_list`` from scratch.

        ``adj_list[u][v]`` holds the remaining capacity of the edge ``u -> v``.
        Every node in ``range(n)`` gets an entry even if it touches no edge,
        so a search starting from an isolated source finds no path instead of
        raising ``KeyError``.
        """
        self.adj_list = {node: Counter() for node in range(n)}
        for u, v, cap in edges:
            self.adj_list.setdefault(u, Counter())[v] += cap
            # Ensure v has an entry so reverse (residual) edges can be added.
            self.adj_list.setdefault(v, Counter())

    def main_dfs(self, source: int, sink: int) -> int:
        """Return the maximum flow from ``source`` to ``sink`` using DFS paths.

        Raises:
            ValueError: if ``source == sink``; the flow would be unbounded and
                the augmenting loop would never terminate.
        """
        if source == sink:
            raise ValueError('source and sink must be different nodes')
        self.build(self.size, self.edges)
        maxflow = 0
        while True:
            self.reset()
            cur_flow = self.dfs(source, sink, math.inf)
            if cur_flow == 0:
                break
            maxflow += cur_flow
        return maxflow

    def reset(self) -> None:
        """Clear per-search bookkeeping; -1 marks a node as not yet visited."""
        self.parents = [-1] * self.size

    def dfs(self, node: int, sink: int, flow: float) -> int:
        """Push one augmenting path from ``node`` to ``sink``; return its flow.

        ``flow`` is the bottleneck capacity seen so far along the path. The
        residual capacities of the path found are updated in place. Returns 0
        when no augmenting path exists. Recursive: the depth can reach the
        number of nodes.
        """
        if node == sink:
            return flow
        self.parents[node] = 1  # DFS only needs a "visited" flag, not a parent
        for nei, cap in self.adj_list[node].items():
            if self.parents[nei] == -1 and cap > 0:
                cur_flow = self.dfs(nei, sink, min(flow, cap))
                if cur_flow > 0:
                    # Returning immediately keeps these in-place updates from
                    # disturbing the items() iteration above.
                    self.adj_list[node][nei] -= cur_flow
                    self.adj_list[nei][node] += cur_flow  # residual edge
                    return cur_flow
        return 0

    def main_edmonds_karp(self, source: int, sink: int) -> int:
        """Return the maximum flow from ``source`` to ``sink`` using BFS paths.

        Raises:
            ValueError: if ``source == sink``; the flow would be unbounded and
                the augmenting loop would never terminate.
        """
        if source == sink:
            raise ValueError('source and sink must be different nodes')
        self.build(self.size, self.edges)
        maxflow = 0
        while True:
            self.reset()
            cur_flow = self.edmonds_karp(source, sink)
            if cur_flow == 0:
                break
            maxflow += cur_flow
        return maxflow

    def edmonds_karp(self, source: int, sink: int) -> int:
        """Find one shortest augmenting path by BFS, apply it, return its flow.

        Expects ``reset()`` to have been called. ``self.parents[v]`` records
        the BFS predecessor of ``v`` (-1 = unvisited, -2 = the source itself).
        Returns 0 when ``sink`` is unreachable in the residual graph.
        """
        queue = deque([(source, math.inf)])
        self.parents[source] = -2
        while queue:
            node, flow = queue.popleft()
            if node == sink:
                break
            for nei, cap in self.adj_list[node].items():
                if self.parents[nei] == -1 and cap > 0:
                    self.parents[nei] = node
                    queue.append((nei, min(flow, cap)))
        if node != sink:
            return 0
        # Walk back along the BFS tree, moving `flow` units onto reverse edges.
        while node != source:
            parent = self.parents[node]
            self.adj_list[parent][node] -= flow
            self.adj_list[node][parent] += flow  # residual edge
            node = parent
        return flow


In [119]:
# Test inputs for the CSES "Download Speed" problem (task 1694).
# Brackets already provide implicit line continuation, so no backslashes needed.
urls = [
    'https://cses.fi/file/acec992e42fe3462f07114ad2d5f7ce9ff27434922e9b52f39006310ca79d019/1/1/',
    'https://cses.fi/file/558f035a5dce8931e19371bda522b5a81d28a9a1a1835a6205e566ca9de324c8/1/1/',
    'https://cses.fi/file/9201642e4901d251a2c18f26429a67089a018a07cd3aa6025cb5fd12d4f88126/1/1/',
    'https://cses.fi/file/654fbbbac2b61ff15187c1d399394ea7e27b05b3dddf32bdba1bb1c6708e3593/1/1/',
    'https://cses.fi/file/8286fe339a5312417d20620138dec793deb78cd8960f33ddc4f521982e71f046/1/1/',
    'https://cses.fi/file/297a2fce46a4102cbd86bea796751acd566fcae258aa00b62d34f5436e441b27/1/1/',
    'https://cses.fi/file/f1cb0fbf03699e8e91a47846d49e084dae8ec899186d7766461383b5bf562452/1/1/',
    'https://cses.fi/file/d31400a9196af8d78037127201e471353fcd1f5aaecc9939ea4740a054559c0f/1/1/',
    'https://cses.fi/file/963201f693af2a27f8d43a78a6213b938576971e71fd5270ad62b538cae9cd47/1/1/',
    'https://cses.fi/file/9b1a8c894a16cc3228c663a38b764156f7f47183b2f7b206866f935d693dbae7/1/1/',
    'https://cses.fi/file/e27523c04940efd4cddc19cb7ad99a65635c2fb88cf1c86a7492e08089e8c942/1/1/',
    'https://cses.fi/file/a09e3665a05e05a0f4e6d590b271ba889bed02c5aae69522d38d8bf1c62aa371/1/1/',
    'https://cses.fi/file/ec19840ed099c8e55fd77bf40b1cf4f6fdbd43c0a63c74dfe736de4d38cb67cd/1/1/',
]

In [None]:
"""
Using the dfs implementation as the base case, it was tested to work in the online judge.
"""
results = [0]*len(urls)
for i, url in enumerate(urls):
    data = urllib.request.urlopen(url)
    for j, line in enumerate(map(lambda line: line.decode('utf-8').strip('\n'), data)):
        if j == 0:
            n, m = map(int, line.split())
            edges = []
        else:
            u, v, cap = map(int, line.split())
            edges.append((u - 1, v - 1, cap))
    start_time = time.perf_counter()
    mf = MaxFlow(n, edges).main_dfs(0, n - 1)
    end_time = time.perf_counter()
    results[i] = mf
    print(f'Finished testcase: {i} in {end_time - start_time} seconds')

In [None]:
%%time
# Re-run the DFS solver and confirm it reproduces the recorded reference answers.
for i, url in enumerate(urls):
    # `with` closes the HTTP response instead of leaking the connection.
    with urllib.request.urlopen(url) as response:
        # Token-based parsing ignores CRLF line endings and blank lines.
        tokens = list(map(int, response.read().split()))
    n, m = tokens[0], tokens[1]
    # The input lists m "a b c" triples with 1-indexed nodes; shift to 0-indexed.
    edges = [
        (tokens[k] - 1, tokens[k + 1] - 1, tokens[k + 2])
        for k in range(2, 2 + 3 * m, 3)
    ]
    start_time = time.perf_counter()
    mf = MaxFlow(n, edges).main_dfs(0, n - 1)
    end_time = time.perf_counter()
    assert mf == results[i], f'Failed on testcase: {i}'
    print(f'Finished testcase: {i} in {duration(start_time, end_time)} seconds')


In [124]:
%%time
# Time the DFS and Edmonds-Karp solvers on every test case and validate BOTH
# answers against the recorded reference results.
for i, url in enumerate(urls):
    # `with` closes the HTTP response instead of leaking the connection.
    with urllib.request.urlopen(url) as response:
        # Token-based parsing ignores CRLF line endings and blank lines.
        tokens = list(map(int, response.read().split()))
    n, m = tokens[0], tokens[1]
    # The input lists m "a b c" triples with 1-indexed nodes; shift to 0-indexed.
    edges = [
        (tokens[k] - 1, tokens[k + 1] - 1, tokens[k + 2])
        for k in range(2, 2 + 3 * m, 3)
    ]

    start_time = time.perf_counter()
    mf_dfs = MaxFlow(n, edges).main_dfs(0, n - 1)
    end_time = time.perf_counter()
    duration_dfs = duration(start_time, end_time)

    start_time = time.perf_counter()
    mf_edmonds_karp = MaxFlow(n, edges).main_edmonds_karp(0, n - 1)
    end_time = time.perf_counter()
    duration_edmonds_karp = duration(start_time, end_time)

    # Each solver's result is kept in its own variable so neither check is lost.
    assert mf_dfs == results[i], \
        f'dfs failed on testcase: {i}, output: {mf_dfs}, expected: {results[i]}'
    assert mf_edmonds_karp == results[i], \
        f'edmonds-karp failed on testcase: {i}, output: {mf_edmonds_karp}, expected: {results[i]}'
    print(f'Test case {i} passed')
    print(f'dfs: {conv_seconds_milliseconds(duration_dfs)} milliseconds')
    print(f'edmonds-karp: {conv_seconds_milliseconds(duration_edmonds_karp)} milliseconds')


Test case 0 passed
dfs: 0.15285199879144784 milliseconds
edmonds-karp: 0.027812000553240068 milliseconds
Test case 1 passed
dfs: 0.13514599959307816 milliseconds
edmonds-karp: 0.038557998777832836 milliseconds
Test case 2 passed
dfs: 0.042644000131986104 milliseconds
edmonds-karp: 0.02200800008722581 milliseconds
Test case 3 passed
dfs: 0.05922600030316971 milliseconds
edmonds-karp: 0.022959000489208847 milliseconds
Test case 4 passed
dfs: 1.222647000759025 milliseconds
edmonds-karp: 1.0794099998747697 milliseconds
Test case 5 passed
dfs: 2.9189720007707365 milliseconds
edmonds-karp: 2.7333910002198536 milliseconds
Test case 6 passed
dfs: 15.551608999885502 milliseconds
edmonds-karp: 5.754584000897012 milliseconds
Test case 7 passed
dfs: 0.07178200030466542 milliseconds
edmonds-karp: 0.04002900095656514 milliseconds
Test case 8 passed
dfs: 0.11093200009781867 milliseconds
edmonds-karp: 0.08192099994630553 milliseconds
Test case 9 passed
dfs: 0.12475500079744961 milliseconds
edmonds-kar