In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import hsv_to_rgb
import numpy as np
from numbers import Rational
import random
import json
from matplotlib import cm
from matplotlib import colormaps
import matplotlib.colors as mcolors
from tqdm import tqdm
from PIL import Image
import itertools
import copy
from collections import defaultdict
import uuid

In [None]:
# Grid geometry for an 18x18 canvas (both axes span 0..18).
# Station A sits on the top edge, B on the right, C on the bottom, D on the left.

# Unit steps a route from each station may take once it is inside the grid.
# Each station omits the single direction that points back out through its own edge.
valid_moves = {
    'A': [(0, -1), (-1, 0), (1, 0)],
    'B': [(0, -1), (0, 1), (-1, 0)],
    'C': [(0, 1), (-1, 0), (1, 0)],
    'D': [(0, -1), (0, 1), (1, 0)],
}

# The mandatory first step from a station's terminal point, pointing inwards.
starting_moves = {
    'A': (0, -1),
    'B': (-1, 0),
    'C': (0, 1),
    'D': (1, 0),
}

# Three terminal points per station, lying on the lines y = 16 (A), x = 16 (B),
# y = 2 (C) and x = 2 (D).
stations_points = {
    'A': [(8, 16), (9, 16), (10, 16)],
    'B': [(16, 8), (16, 9), (16, 10)],
    'C': [(8, 2), (9, 2), (10, 2)],
    'D': [(2, 8), (2, 9), (2, 10)],
}

# Reverse lookup keyed by str(point), e.g. '(8, 16)' -> 'A'.
node_to_label = {
    str(point): label
    for label, points in stations_points.items()
    for point in points
}


# ---------------------------------------------------------------------------
# Generate random route layouts.
#
# A layout is a list of routes; each route is a dict {'path': [[start, end], ...]}
# whose [start, end] pairs are unit grid segments leading from a terminal point
# of one station to a terminal point of a different station. An attempt is kept
# only when every station touches exactly `p` routes (route starts and route
# ends at that station both count).
# ---------------------------------------------------------------------------

# Accepted layouts, in generation order (p cycles through path_nums).
images = []

# Order in which stations emit their routes; shuffled once per run.
# NOTE(review): neither `random` nor numpy is seeded anywhere in this notebook,
# so the generated layouts are not reproducible across runs.
stations = ['A', 'B', 'C', 'D']
random.shuffle(stations)

# Target number of routes touching each station; every pass of the outer loop
# produces one accepted layout per value, in this order.
path_nums = [1, 2, 3]

while True:
    
    # cnt only advances after a successful attempt, so each pass keeps retrying
    # until it has one layout for every p in path_nums.
    cnt = 0
    while cnt < len(path_nums):
        p = path_nums[cnt]
        # Per-attempt state, reset on every retry:
        #   path_counter - number of routes started or ended at each station
        #   visited      - every segment accepted during this attempt, including
        #                  segments later dropped by backtracking or belonging
        #                  to routes that were discarded
        #   all_routes   - completed routes of this attempt
        #   cross_dest   - terminal points already used as a route start or end
        path_counter = defaultdict(int)
        visited = []
        all_routes = []
        cross_dest = []

        # Give every station an explicit 0 entry.
        for station in stations:
            path_counter[station] += 0

        for station in stations:

            # This station's start points, tried in random order.
            # NOTE(review): not filtered by cross_dest, so a start point can
            # coincide with a point an earlier station already routed to.
            start_list = copy.deepcopy(stations_points[station])
            random.shuffle(start_list)

            # Unused terminal points of the other stations that are not yet full.
            # NOTE(review): computed once per station rather than after each
            # route, so a later route from this station can still target a
            # point or station filled by an earlier one. Over-full stations are
            # rejected by the final check below; a reused end point appears not
            # to be.
            possible_destinations = []

            for k, v in stations_points.items():
                if k == station:
                    continue

                if path_counter[k] == p:
                    continue
                
                for i in range(len(v)):
                    if v[i] not in cross_dest:
                        possible_destinations.append(v[i])

            # Up to p attempts to start a route here; stop early once the
            # station already touches p routes (e.g. as an earlier destination).
            for i in list(np.arange(1, p+1)):

                if path_counter[station] == p:
                    break
                
                start = start_list.pop()
                root = start
                # The first step always leaves the station edge inwards.
                end = (start[0] + starting_moves[station][0], start[1] + starting_moves[station][1])

                routes = dict()
                routes['path'] = []
                
                # Random walk with backtracking. (start, end) is the candidate
                # segment; temp_moves holds the directions not yet tried from
                # the current head of the path.
                while True:
                    # Reached an allowed destination: the final segment is
                    # validated after the loop.
                    if end in possible_destinations:
                        break

                    # Accept the step only if the segment is unused in either
                    # direction and its end stays strictly inside 2 < x, y < 16.
                    if [start, end] not in visited and [end, start] not in visited and end[0] > 2 and end[0] < 16 and end[1] > 2 and end[1] < 16:
                        routes['path'].append([start, end])
                        visited.append([start, end])
                        # New head: every direction is available again.
                        temp_moves = copy.deepcopy(valid_moves[station])

                    else:
                        # The very first step was rejected: abandon this route.
                        if len(routes['path']) < 1:
                            break
                        # Otherwise rewind the candidate to the last accepted
                        # segment and keep its remaining untried directions.
                        old_data = routes['path'][-1]
                        start, end = old_data[0], old_data[1]

                    # Dead end: drop the last segment and retry from the one
                    # before it. The dropped edge stays in `visited`, so it is
                    # never re-taken. With a single segment left the route fails.
                    if len(temp_moves) < 1:
                        if len(routes['path']) > 1:
                            routes['path'].pop(-1)
                        else:
                            break
                        old_data = routes['path'][-1]
                        start, end = old_data[0], old_data[1]
                        temp_moves = copy.deepcopy(valid_moves[station])
                    
                    # Step in a random, not-yet-tried direction from the head.
                    route = random.choice(temp_moves)
                    temp_moves.pop(temp_moves.index(route))
                    start = end
                    end = (start[0] + route[0], start[1] + route[1])

                # Keep the route only if the walk ended on a destination via an
                # unused final segment. Otherwise it is silently discarded, and
                # this station may end the attempt short of p routes.
                if end in possible_destinations and [start, end] not in visited and [end, start] not in visited:
                    routes['path'].append([start, end])
                    visited.append([start, end])
                    all_routes.append(routes)
                    path_counter[station] += 1
                    # Both endpoints of the route are now taken.
                    cross_dest.append(end)
                    cross_dest.append(root)

                    # Also credit the destination station.
                    for k, v in stations_points.items():
                        if end in v:
                            path_counter[k] += 1

        # The attempt succeeds only if every station touches exactly p routes.
        sw = 1
        for station in stations:
            if path_counter[station] == p: 
                sw *= 1
            else:
                sw *= 0

        if sw == 1:              
            images.append(all_routes)
            cnt += 1

    # Stop after 15 passes (45 layouts = 15 x len(path_nums)).
    # NOTE(review): this exact-equality test terminates only because 45 is a
    # multiple of len(path_nums); `>=` would be safer if either value changes.
    if len(images) == 45:
        break

# Number of generated layouts (shown as the cell output).
len(images)

In [None]:
import matplotlib.pyplot as plt

def get_colors_from_colormap(colormap_name, num_colors):
    """Return ``num_colors`` RGBA tuples taken from a registered Matplotlib colormap.

    Listed (qualitative) colormaps such as ``'tab10'`` are read entry by entry,
    wrapping around when more colours are requested than the colormap holds.
    Continuous colormaps are sampled at evenly spaced points over [0, 1].

    Parameters
    ----------
    colormap_name : str
        Name looked up in ``matplotlib.colormaps``.
    num_colors : int
        Number of colours to return; 0 yields an empty list.

    Returns
    -------
    list of tuple
        One ``(r, g, b, a)`` tuple per requested colour.
    """
    colormap = colormaps[colormap_name]
    if isinstance(colormap, mcolors.ListedColormap):
        # Integer arguments address individual lookup-table entries. Wrap
        # explicitly: out-of-range integers would otherwise clamp to the
        # colormap's "over" colour (by default its last entry).
        return [colormap(i % colormap.N) for i in range(num_colors)]
    # On a continuous colormap, small integers select adjacent (nearly
    # identical) entries of its lookup table, so sample the full range instead.
    return [colormap(x) for x in np.linspace(0.0, 1.0, num_colors)]

def rgba_to_color_name(rgba):
    """Name the CSS4 colour nearest to ``rgba`` (Euclidean distance in RGB space).

    Only the first three components of ``rgba`` are compared, so alpha is
    ignored. On a tie, the name that comes first in
    ``matplotlib.colors.CSS4_COLORS`` wins.
    """
    target = np.array(rgba[:3])

    best_name = None
    best_dist = None
    for name, spec in mcolors.CSS4_COLORS.items():
        candidate = np.array(mcolors.to_rgba(spec)[:3])
        dist = np.linalg.norm(candidate - target)
        # Strict comparison keeps the earliest name among equal distances.
        if best_dist is None or dist < best_dist:
            best_name, best_dist = name, dist

    return best_name

def draw_lines(all_routes, path, node_to_label, color, dpi=500):
    """Render one route layout on an 18x18 canvas and summarise its connections.

    Parameters
    ----------
    all_routes : list of dict
        Routes to draw. Each has a ``'path'`` key holding ``[start, end]`` pairs
        of grid points; the first start and the last end are station terminal
        points.
    path : float
        Line width (in points) used for every segment. The parameter name is
        historical and is kept so existing callers keep working.
    node_to_label : dict
        Maps ``str(point)`` of each terminal point to its station letter.
    color : sequence
        RGBA colours; ``color[i]`` is used for ``all_routes[i]``.
    dpi : int, optional
        Figure resolution. At 18x18 inches, the default of 500 renders a
        9000x9000 pixel canvas.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The rendered figure. It is not closed here.
    connections : list of dict
        One ``{'start', 'end', 'color'}`` entry per route: the start/end station
        letters and the nearest CSS4 colour name of the route's colour.
    """
    # Use the argument rather than a global: previously this read the
    # module-level ``thickness`` and silently ignored ``path``.
    linewidth = path

    fig, ax = plt.subplots(figsize=(18, 18), dpi=dpi)
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlim(0, 18)
    ax.set_ylim(0, 18)
    ax.axis('off')

    # Station labels just outside the terminal points (A top, B right, C bottom, D left).
    ax.text(8.45, 16.5, 'A', fontsize=100, color='k', fontweight='bold')
    ax.text(16.5, 8.6, 'B', fontsize=100, color='k', fontweight='bold')
    ax.text(8.45, 0.6, 'C', fontsize=100, color='k', fontweight='bold')
    ax.text(0.4, 8.6, 'D', fontsize=100, color='k', fontweight='bold')

    connections = []
    for i, route in enumerate(all_routes):
        segments = route['path']
        for (x1, y1), (x2, y2) in segments:
            ax.plot([x1, x2], [y1, y2], color=color[i], linestyle='solid', linewidth=linewidth)

        connections.append({
            'start': node_to_label[str(segments[0][0])],
            'end': node_to_label[str(segments[-1][-1])],
            'color': rgba_to_color_name(color[i]),
        })

    return fig, connections
    
def convert_fig_to_pil(fig):
    """Rasterise a Matplotlib figure and return it as an RGB PIL image.

    Reads the Agg canvas through ``buffer_rgba()``, because
    ``FigureCanvas.tostring_rgb`` was deprecated in Matplotlib 3.8 and removed
    in 3.10. The alpha channel is dropped, matching the old ``tostring_rgb``
    output.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to rasterise. It is drawn but not closed.

    Returns
    -------
    PIL.Image.Image
        An RGB image with the canvas's pixel dimensions.
    """
    fig.canvas.draw()
    # Take the (height, width, 4) shape from the buffer itself. On HiDPI
    # backends the logical get_width_height() can differ from the physical
    # buffer size.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Copy into a contiguous RGB array so the image no longer references the
    # canvas buffer once the figure is closed.
    rgb = np.ascontiguousarray(rgba[..., :3])
    return Image.fromarray(rgb)


import os

# Every rendered image, plus metadata.json, is written to this directory.
OUTPUT_DIR = './SubwayConnection/'
# PIL's Image.save does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

metadata = []
# Layouts in `images` were generated cycling through path_nums = [1, 2, 3], so
# this counter tracks the routes-per-station value of the current layout.
counter = 1
for all_routes in tqdm(images, desc='rendering'):
    color = get_colors_from_colormap('tab10', len(all_routes))
    for image_size in [512, 1024]:
        # Keep the module-level name `thickness`: older draw_lines versions
        # read it as a global.
        for thickness in [10, 20]:
            fig, connections = draw_lines(all_routes, thickness, node_to_label, color)
            uid = str(uuid.uuid4())

            name = "pixels_" + str(image_size) + "_linewidth_" + str(thickness) + "_path_" + str(counter) + "_" + uid
            # Act on this figure explicitly rather than pyplot's "current" figure.
            fig.tight_layout(pad=0.0)
            image = convert_fig_to_pil(fig)
            # Close immediately: every figure is a large high-dpi canvas.
            plt.close(fig)
            image = image.resize((image_size, image_size))
            image.save(OUTPUT_DIR + name + '.png')
            image.save(OUTPUT_DIR + name + '.pdf')

            metadata.append({
                'image_id': uid,
                'linewidth': thickness,
                'name': name,
                'path_outs': counter,
                'size': image_size,
                'connections': connections,
            })
    counter += 1
    if counter == 4:
        counter = 1

with open(OUTPUT_DIR + "metadata.json", "w") as fp:
    json.dump(metadata, fp)




