## Create Engine

### Log in

In [1]:
import os
from urllib.parse import quote_plus

import mysql.connector as mydb
from dotenv import load_dotenv

load_dotenv()  # take environment variables from .env.

# Connection settings are read from the environment (populated from .env above).
db_name = os.getenv('DB_NAME')
db_user = os.getenv('DB_USER')
db_password = os.getenv('DB_PASSWORD')
db_host = os.getenv('DB_HOST')
# Environment values are strings; fall back to MySQL's standard port when
# DB_PORT is unset or empty and normalise to int.
db_port = int(os.getenv('DB_PORT') or 3306)

# Fail early with a readable message instead of silently building a URL
# that contains the literal text "None".
_required_settings = {
    'DB_NAME': db_name,
    'DB_USER': db_user,
    'DB_PASSWORD': db_password,
    'DB_HOST': db_host,
}
_missing_settings = [key for key, value in _required_settings.items() if value is None]
if _missing_settings:
    raise RuntimeError(
        "Missing database settings in the environment/.env: "
        + ", ".join(_missing_settings)
    )

# create connection (used here only as a connectivity check)
conn = mydb.connect(
    host=db_host,
    port=db_port,
    user=db_user,
    password=db_password,
    database=db_name
)
conn.ping(reconnect=True)
print(conn.is_connected())

# User and password are percent-encoded so characters such as '@', ':' or '/'
# cannot corrupt the URL. The scheme is left as plain "mysql://", exactly as before.
_url_head = f"mysql://{quote_plus(db_user)}"
_url_tail = f"@{db_host}:{db_port}/{db_name}?charset=utf8mb4"
SQLALCHEMY_DATABASE_URL = f"{_url_head}:{quote_plus(db_password)}{_url_tail}"

# Never echo the real password into the saved notebook output.
print(f"{_url_head}:***{_url_tail}")

True
mysql://client:***@localhost:3306/project2?charset=utf8mb4


In [2]:
from sqlalchemy import create_engine
# declarative_base lives in sqlalchemy.orm since SQLAlchemy 1.4; the old
# sqlalchemy.ext.declarative location is deprecated and emits a warning.
from sqlalchemy.orm import declarative_base, sessionmaker

# Engine for the URL assembled in the log-in cell.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL
)
# Session factory: changes are flushed and committed only when asked explicitly.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class for ORM models defined against this engine.
Base = declarative_base()

  Base = declarative_base()


In [3]:
from sqlalchemy import Column, ForeignKey, Integer, String, Float, DateTime, CheckConstraint, Text
from sqlalchemy.orm import relationship

from backend.database import Base
from backend.models import Base, Station, Line, LineDetail
from sqlalchemy import MetaData

# WARNING: destructive. This cell drops every table that currently exists in
# the target database before recreating the schema from the ORM models.

# Scratch MetaData used only to discover the tables present right now.
# (No `metadata.bind = engine`: bound metadata was removed in SQLAlchemy 2.0,
# and every call below already receives the engine explicitly.)
metadata = MetaData()
# reflect db schema to MetaData
metadata.reflect(bind=engine)
# drop all reflected tables
metadata.drop_all(bind=engine)

# recreate all tables registered on the models' Base
Base.metadata.create_all(bind=engine)

In [4]:
import backend.models as models
import backend.schemas as schemas
import backend.crud as crud
# Open one session from the SessionLocal factory; the cells below pass this
# same `db` object to every crud call. NOTE(review): it is never closed in
# the visible part of the notebook — consider db.close() at the end.
db = SessionLocal()


* 'orm_mode' has been renamed to 'from_attributes'


create a line:

In [5]:
# Create the demo line: only the name and colour carry values, every other
# field is passed explicitly as None.
crud.create_line(
    db,
    schemas.LineCreate(
        name='1号线',
        color='red',
        start_time=None,
        end_time=None,
        mileage=None,
        first_opening=None,
        url=None,
        intro=None,
    ),
)


create a station:

In [6]:
# Create the first station and keep the returned object; its id is used
# later when the station is attached to the line.
db_station1 = crud.create_station(
    db,
    schemas.StationCreate(
        name='Pingguoyuan',
        district='昌平区',
        chinese_name='苹果园',
        intro='苹果园站是北京地铁1号线的起点站，位于北京市昌平区。',
    ),
)


get station by name:

In [7]:

# Create the second station, discard the return value, then read it back
# through the name lookup to demonstrate crud.get_station_by_name.
gucheng_payload = schemas.StationCreate(
    name='Gucheng',
    district='昌平区',
    chinese_name='古城',
    intro='古城站是北京地铁1号线的起点站，位于北京市昌平区。',
)
crud.create_station(db, gucheng_payload)
db_station2 = crud.get_station_by_name(db, 'Gucheng')


get station by station id:

In [8]:
# Create the third station, then read it back through the primary-key lookup
# to demonstrate crud.get_station. The id comes from the object returned by
# create_station rather than a hard-coded 3, which is only correct right after
# a clean rebuild with exactly two earlier inserts.
_created_station3 = crud.create_station(
    db,
    schemas.StationCreate(
        name='Bajiao Amusement Park',
        district='昌平区',
        chinese_name='八角游乐园',
        intro='八角游乐园站是北京地铁1号线的起点站，位于北京市昌平区。',
    ),
)
db_station3 = crud.get_station(db, _created_station3.id)

get line by name:

In [9]:
# Look the line created earlier up by its name; db_line.id is what the
# following cells pass to the station/line helpers.
db_line = crud.get_line_by_name(db, '1号线')

append stations to the end of the line, one call per station:

In [10]:
# Attach station 1, then station 2, to the line.
# NOTE(review): presumably each call places the station after the current last
# stop — the line-detail listing further down shows station 1 at order 0;
# confirm against crud.add_station_to_line.
crud.add_station_to_line(db, db_line.id, db_station1.id)
crud.add_station_to_line(db, db_line.id, db_station2.id)

insert a station at a given position in the line:

In [11]:
# Insert station 3 into the line. NOTE(review): the trailing 1 looks like the
# target order index — the line-detail listing in the next cell shows station 3
# at order 1 with station 2 moved to order 2; confirm against
# crud.insert_station_to_line.
crud.insert_station_to_line(db, db_line.id, db_station3.id,1)


get line details:

In [None]:
# Fetch the stops of the line looked up earlier (db_line.id rather than a
# hard-coded 1), sorted by their order on the line.
line_details = (
    db.query(models.LineDetail)
    .filter(models.LineDetail.line_id == db_line.id)
    .order_by(models.LineDetail.order)
    .all()
)
# Show [station_id, order] pairs; the loop variable has its own name so it
# no longer shadows the list being iterated.
[[detail.station_id, detail.order] for detail in line_details]

[[1, 0], [3, 1], [2, 2]]

get stations ahead and behind:

In [12]:
# Station one stop "behind" station 3 on the line. Only the last expression of
# a cell is displayed, so the former bare `query_station.name` line had no
# visible effect and was dropped.
query_station = crud.get_nth_station_behind(db, db_line.id, db_station3.id, 1)
query_station.chinese_name

'苹果园'

In [13]:
# Station one stop "ahead" of station 3 on the line; the cell displays its
# Chinese name as the last expression.
query_station = crud.get_nth_station_ahead(db, db_line.id, db_station3.id, 1)
query_station.chinese_name

'古城'

In [14]:
from collections import defaultdict
import heapq

# Hard-coded line-detail records for the route-finding demo.
# The text is split on whitespace, so `data` is a list with one string per
# record. Each record is "line,station,position": the graph-building cell
# parses the three comma-separated integers in exactly that order.
data = """1,1,1
1,2,2
1,3,3
1,4,4
1,5,5
1,6,6
1,7,7
1,8,8
1,9,9
1,10,10
1,11,11
1,12,12
1,13,13
1,14,14
1,15,15
1,16,16
1,17,17
1,18,18
1,19,19
1,20,20
1,21,21
1,22,22
1,23,23
1,24,24
1,25,25
1,26,26
1,27,27
1,28,28
1,29,29
1,30,30
2,49,1
2,50,2
2,51,3
2,52,4
2,53,5
2,54,6
2,55,7
2,56,8
2,44,9
2,57,10
2,58,11
2,15,12
2,59,13
2,60,14
2,61,15
2,62,16
2,63,17
2,64,18
2,65,19
2,66,20
2,46,21
2,67,22
2,47,23
2,68,24
2,69,25
2,4,26
2,70,27
2,71,28
2,72,29
3,73,1
3,74,2
3,9,3
3,46,4
3,75,5
3,76,6
3,77,7
3,78,8
3,79,9
3,3,10
3,80,11
3,81,12
3,82,13
3,83,14
3,84,15
3,85,16
3,86,17
3,87,18
3,88,19
3,89,20
3,90,21
3,91,22
3,92,23
3,93,24
3,94,25
3,95,26
3,96,27
3,97,28
3,98,29
3,99,30
4,100,1
4,101,2
4,8,3
4,67,4
4,75,5
4,102,6
4,103,7
4,104,8
4,105,9
4,106,10
4,107,11
4,108,12
4,109,13
4,110,14
4,111,15
4,112,16
4,113,17
4,114,18
4,115,19
4,116,20
4,117,21
4,118,22
4,119,23
5,22,1
5,120,2
5,121,3
5,24,4
5,122,5
5,123,6
5,124,7
5,125,8
5,126,9
5,127,10
5,128,11
5,129,12
5,130,13
5,106,14
5,131,15
5,132,16
5,133,17
5,134,18
5,135,19
5,136,20
5,137,21
5,85,22
5,138,23
5,139,24
5,140,25
5,141,26
5,71,27
7,142,1
7,127,2
7,143,3
7,144,4
7,145,5
7,146,6
7,147,7
7,61,8
7,148,9
7,11,10
7,149,11
7,150,12
7,74,13
7,151,14
7,101,15
7,152,16
7,153,17
7,154,18
7,155,19
7,68,20
7,77,21
7,156,22
7,157,23
7,158,24
7,159,25
7,160,26
7,82,27
7,140,28
7,161,29
9,162,1
9,163,2
9,164,3
9,165,4
9,166,5
9,167,6
9,168,7
9,169,8
9,170,9
9,171,10
9,45,11
9,172,12
9,173,13
9,174,14
9,11,15
9,175,16
9,65,17
9,176,18
9,177,19
9,178,20
9,103,21
9,179,22
9,180,23
9,181,24
9,158,25
9,182,26
9,79,27
9,183,28
9,184,29
9,185,30
9,186,31
9,187,32
11,31,1
11,32,2
11,33,3
11,34,4
11,35,5
11,36,6
11,37,7
11,38,8
11,39,9
11,40,10
11,41,11
11,42,12
11,22,13
11,43,14
11,44,15
11,45,16
11,11,17
11,46,18
11,47,19
11,48,20
11,4,21
""".split()


In [16]:
import pickle

# `data` holds "line,station,position" strings (see the cell above); the same
# records could alternatively be loaded from the LineDetail table.

# Group the records by line as (position, station) tuples.
edges = defaultdict(list)
for record in data:
    line_id, station_id, order = (int(field) for field in record.split(','))
    edges[line_id].append((order, station_id))

# Build the adjacency list: consecutive stops on a line are joined in both
# directions, every hop costing 1.
graph = defaultdict(list)
for stops in edges.values():
    stops.sort()  # order the stops of this line by position (in place)
    for current_stop, next_stop in zip(stops, stops[1:]):
        here, there = current_stop[1], next_stop[1]
        graph[here].append((there, 1))
        graph[there].append((here, 1))

# Persist the graph so the route search can run without rebuilding it.
with open('graph.pkl', 'wb') as graph_file:
    pickle.dump(graph, graph_file)


In [17]:
import pickle
import heapq
# Restore the adjacency list written by the previous cell.
# NOTE: pickle executes arbitrary code on load — only open files this
# notebook produced itself.
with open('graph.pkl', 'rb') as graph_file:
    graph = pickle.load(graph_file)

# Dijkstra's algorithm
def shortest_path(start, end, adjacency=None):
    """Find the cheapest route between two stations with Dijkstra's algorithm.

    Args:
        start: Id of the departure station.
        end: Id of the destination station.
        adjacency: Optional mapping ``station -> [(neighbour, cost), ...]``.
            When omitted, the notebook-level ``graph`` is searched.

    Returns:
        A ``(cost, path)`` tuple where ``path`` lists the station ids from
        ``start`` to ``end`` inclusive, or ``(float('inf'), [])`` when ``end``
        cannot be reached from ``start``.
    """
    if adjacency is None:
        adjacency = graph  # default: the graph loaded by the notebook
    # Heap entries are (cost so far, station, path leading up to the station).
    queue = [(0, start, [])]
    seen = set()
    while queue:
        cost, node, path = heapq.heappop(queue)
        if node in seen:
            continue  # already settled through an equal or cheaper route
        seen.add(node)
        path = path + [node]  # new list: queued entries must not share state
        if node == end:
            return cost, path
        # .get() keeps the lookup read-only: indexing a defaultdict would add
        # an empty entry for unknown stations, and a plain dict would raise.
        for next_node, next_cost in adjacency.get(node, ()):
            if next_node not in seen:
                heapq.heappush(queue, (cost + next_cost, next_node, path))
    return float('inf'), []


# Demo run: print the (cost, path) tuple for a route between two station ids
# taken from the hard-coded records above.
print(shortest_path(17, 100))  # Replace with your station ids

(11, [17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 101, 100])
