In [1]:
import sys
import string
import itertools
from collections import Counter, defaultdict
import re

from pathlib import Path
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
import networkx as nx
from copy import deepcopy

In [2]:
%load_ext line_profiler

In [3]:
data = Path('../data/day_12.txt').read_text()

In [4]:
inps = [k.split('-') for k in data.splitlines()]

In [5]:
graph = defaultdict(list)
for start, end in inps:
    graph[start].append(end)
    graph[end].append(start)

In [6]:
def part_a(graph):
    paths = []
    def solve(graph, node, path=None, path_set=None):
        if path is None:
            path = []
            path_set = set()
        else:
            path = path.copy()
            path_set = path_set.copy()
            
        path.append(node)
        if node.islower():
            path_set.add(node)
        if node == 'end':
            paths.append(path)
            return
        for k in graph[node]:
            if k.islower() and k in path_set:
                continue
            solve(graph, k, path, path_set)
        # print(paths)
    solve(graph, 'start')

    return len(paths)

print(part_a(graph))
%timeit part_a(graph)

4720
11.8 ms ± 209 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
def part_b(graph):
    paths = []
    def solve(graph, node, path=None, lower_counts=None):
        if path is None:
            path = []
            lower_counts = {}
        else:
            path = path.copy()
            lower_counts = lower_counts.copy()
        path.append(node)
        if node == 'end':
            paths.append(path)
            return
        if node.islower():
            lower_counts[node] = lower_counts.get(node, 0) + 1
        for k in graph[node]:
            if k == 'start':
                continue
            elif k.islower():
                counts = {k for k,v in lower_counts.items() if v > 1}
                if counts and k in path:
                    continue
            solve(graph, k, path, lower_counts)
        # print(paths)
    solve(graph, 'start')

    return len(paths)

print(part_b(graph))
%timeit part_b(graph)

147848
1.08 s ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
