In [None]:
from os import getenv
from connect_db import connect_db
# Connection settings come from the environment so no credentials are stored in the notebook.
oracle_port = getenv("ORACLE_PORT")
if oracle_port is None:
    # getenv() returns None for an unset variable; int(None) would otherwise fail with an opaque TypeError.
    raise RuntimeError("ORACLE_PORT environment variable is not set")

con = connect_db(
    username=getenv("ORACLE_USERNAME"),
    password=getenv("ORACLE_PASSWORD"),
    host=getenv("ORACLE_HOST"),
    port=int(oracle_port)
)
# One cursor is shared by every query cell below.
cur = con.cursor()

In [None]:
# Presumably the number of seats that share one boarding group (inferred from the name — confirm).
# No other cell shown in this notebook reads it; the generation loop currently assigns boarding
# groups at random (see its TODO).
SEATS_PER_BOARDING_GROUP = 40

In [None]:
import networkx as nx

# Airport graph: one edge per ROUTE row, weighted by distance and tagged with its route id.
# NOTE(review): nx.Graph is undirected and keeps a single edge per airport pair, so if ROUTE holds
# both A->B and B->A the row read last overwrites the other's weight/route_id — confirm this is intended.
graph = nx.Graph()
stmt = "SELECT ROUTE_ID, FROM_AIRPORT_ID, TO_AIRPORT_ID, DISTANCE_IN_KM FROM ROUTE"

for route_id, from_airport_id, to_airport_id, distance_in_km in cur.execute(stmt):
    graph.add_edge(
        from_airport_id,
        to_airport_id,
        weight=distance_in_km,
        route_id=route_id,
    )

In [None]:
import numpy as np
# Distinct departure airports found in ROUTE. The list order comes from a set, so it is arbitrary.
stmt = "SELECT FROM_AIRPORT_ID FROM ROUTE"
AIRPORTS = list({airport_id for (airport_id,) in cur.execute(stmt)})

# Relative airport weights, normalised so they sum to 1.
# NOTE(review): no cell shown in this notebook passes AIRPORT_WEIGHTS to a sampler — confirm whether
# get_random_airports() was meant to use it.
AIRPORT_WEIGHTS = np.array([1, 2, 1, 1, 6, 2, 2, 1, 3, 4, 1, 1, 2, 2, 1, 3, 1], dtype=float)
AIRPORT_WEIGHTS /= AIRPORT_WEIGHTS.sum()

In [None]:
from typing import Iterable

def get_random_airports() -> Iterable[str]:
    """Yield a random-length sample of airport ids from ``AIRPORTS``.

    The sample size starts at 1 and grows by one at step ``i`` with probability
    ``1 / 2**i`` (so the first growth is certain), stopping at the first failed
    draw or after ``len(AIRPORTS)`` steps. The ids are then drawn uniformly with
    replacement, so duplicates are possible.
    """
    sample_size = 1
    for step in range(len(AIRPORTS)):
        # randint's upper bound is exclusive: the draw is uniform on 1..2**step.
        grows = np.random.randint(1, np.power(2, step) + 1) == 1
        if not grows:
            break
        sample_size += 1

    picked = np.random.choice(AIRPORTS, size=sample_size)
    yield from picked

In [None]:
def get_random_number_of_tickets(max_tickets: int = 10) -> int:
    """Draw how many tickets one booking contains.

    For ``i = 0, 1, ...`` a uniform draw on ``1..2**i`` is made; the first ``i``
    whose draw is not 1 is returned. If every draw up to ``max_tickets - 1``
    is 1 (or there is nothing to draw), ``max_tickets`` itself is returned.
    """
    # The generator is lazy, so draws stop at the first failing step.
    first_failed_step = (
        step
        for step in range(max_tickets)
        if np.random.randint(1, np.power(2, step) + 1) != 1
    )
    return next(first_failed_step, max_tickets)

In [None]:
from collections import defaultdict
# Every account gets its own random list of airports (duplicates possible); the generation loop
# later picks a start and an end airport from this list.
ACCOUNT_AIRPORT_PREFERENCES: defaultdict[str, list[str]] = defaultdict(list)
stmt = "SELECT ACCOUNT_ID FROM ACCOUNT"

for (account_id,) in cur.execute(stmt):
    ACCOUNT_AIRPORT_PREFERENCES[account_id].extend(get_random_airports())

In [None]:
def get_random_travel_rate():
    """Return a random integer travel rate in ``1..50``, skewed towards low values.

    Each block of ten consecutive rates (1-10, 11-20, ..., 41-50) carries a total
    probability of 0.8, 0.1, 0.06, 0.03 and 0.01 respectively, spread evenly
    inside the block.
    """
    block_mass = (0.8, 0.1, 0.06, 0.03, 0.01)
    probabilities = np.repeat(block_mass, 10) / 10
    rates = np.arange(1, 51)
    # .item() converts the numpy scalar into a plain Python int.
    return np.random.choice(rates, p=probabilities).item()
# One independently drawn travel rate per account that has airport preferences.
ACCOUNT_TRAVEL_RATE = dict(
    (account_id, get_random_travel_rate())
    for account_id in ACCOUNT_AIRPORT_PREFERENCES
)

In [None]:
from functools import cache
import itertools

@cache
def get_airports_stops(from_airport: str, to_airport: str) -> list[str]:
    """Return the airports on the shortest (distance-weighted) path, endpoints included.

    Results are memoised per airport pair; the returned list is the cached
    object itself, so callers must not mutate it.
    """
    stops = nx.dijkstra_path(graph, source=from_airport, target=to_airport)
    return stops

@cache
def get_routes(from_airport: str, to_airport: str) -> list[str]:
    out = []
    for i, j in itertools.pairwise(get_airports_stops(from_airport, to_airport)):
        out.append(graph.get_edge_data(i, j)["route_id"])
    return out

In [None]:
# Base fare charged per kilometre flown.
PRICE_PER_KM = 0.238
def calc_flight_price(distance: int, number_of_seats: int):
    """Return the base price of a flight: ``distance`` times ``PRICE_PER_KM``.

    ``number_of_seats`` is accepted but does not influence the result.
    """
    base_price = PRICE_PER_KM * distance
    return base_price

# Pre-compute the base price of every flight that joins to a route, an aircraft and an aircraft model.
stmt = """SELECT FLIGHT_ID, DISTANCE_IN_KM, NUMBER_OF_SEAT FROM FLIGHT
    JOIN ROUTE ON FLIGHT.ROUTE_ID = ROUTE.ROUTE_ID
    JOIN AIRCRAFT ON FLIGHT.AIRCRAFT_ID = AIRCRAFT.AIRCRAFT_ID
    JOIN AIRCRAFT_MODEL ON AIRCRAFT.AIRCRAFT_MODEL_ID = AIRCRAFT_MODEL.AIRCRAFT_MODEL_ID"""
FLIGHT_PRICE = {}
for flight_id, distance_in_km, number_of_seat in cur.execute(stmt):
    FLIGHT_PRICE[flight_id] = calc_flight_price(distance_in_km, number_of_seat)

In [None]:
import random
@cache
def _seats_of_class(aircraft_id: str, seat_class: str) -> tuple[tuple[str, str], ...]:
    """Return every ``(SEAT_ID, SEAT_CLASS_ID)`` of one seat class on one aircraft (memoised per pair)."""
    stmt = f"""SELECT SEAT_ID, SEAT_CLASS_ID FROM SEAT WHERE AIRCRAFT_ID = '{aircraft_id}' AND SEAT_CLASS_ID = '{seat_class}'"""
    return tuple((row[0], row[1]) for row in cur.execute(stmt))


def find_seats(aircraft_id: str, seat_class_weight: tuple[int, ...] = (1, 5, 20)) -> list[tuple[str, str]]:
    """Pick a random seat class and return all seats of that class on the aircraft.

    Args:
        aircraft_id: Aircraft whose seats are listed.
        seat_class_weight: Integer weights for ``("ECON", "BUSI", "FIRS")``, in
            that order, passed as ``counts`` to ``random.sample``.
            NOTE(review): the default makes FIRS the most likely class (20:5:1) —
            confirm this is the intended skew.

    Returns:
        A new list of ``(SEAT_ID, SEAT_CLASS_ID)`` tuples for the drawn class.

    The seat class is drawn on every call. Only the database lookup is memoised;
    caching the whole function would freeze the first random draw for each aircraft.
    """
    seat_class = random.sample(("ECON", "BUSI", "FIRS"), k=1, counts=seat_class_weight)[0]
    return list(_seats_of_class(aircraft_id, seat_class))

In [None]:
from datetime import datetime, timedelta
USED_SEAT: defaultdict[str, set[str]] = defaultdict(set)

def find_closest_flight(current_dt: datetime, route_id: str, n_tickets: int) -> tuple[str, datetime, list[str], list[str]] | None:
    stmt = f"""SELECT * FROM (
        SELECT FLIGHT.FLIGHT_ID, DEPARTURE_DATETIME, AIRCRAFT.AIRCRAFT_ID, EST_DURATION_IN_HOUR
        FROM FLIGHT
            JOIN AIRCRAFT ON FLIGHT.AIRCRAFT_ID = AIRCRAFT.AIRCRAFT_ID
            JOIN AIRCRAFT_MODEL ON AIRCRAFT.AIRCRAFT_MODEL_ID = AIRCRAFT_MODEL.AIRCRAFT_MODEL_ID
        WHERE ROUTE_ID = '{route_id}' AND DEPARTURE_DATETIME >= TO_TIMESTAMP('{current_dt:%Y-%m-%d %H:%M:%S}', 'YYYY-MM-DD HH24:MI:SS')
                  )
    WHERE ROWNUM <= 50"""
    for row in cur.execute(stmt):
        seats = set()
        seat_classes = {}
        for seat, seat_class in find_seats(row[2]):
            seats.add(seat)
            seat_classes[seat] = seat_class
        seats -= USED_SEAT[row[0]]

        if len(seats) < n_tickets:
            continue

        seats = list(seats)
        seats = seats[:n_tickets]
        for seat in seats:
            USED_SEAT[row[0]].add(seat)
        return row[0], row[1] + timedelta(hours=row[3] + 12), seats, list(seat_classes[seat] for seat in seats)

In [None]:
# Per-account timestamp used as the start of that account's next flight search.
# It starts at the account's join time and is later replaced by find_closest_flight's wait_until.
stmt = "SELECT ACCOUNT_ID, JOINED_DATETIME FROM ACCOUNT"
ACCOUNT_WAIT_UNTIL: dict[str, datetime] = {
    account_id: joined_datetime
    for account_id, joined_datetime in cur.execute(stmt)
}

In [None]:
def get_payment_method() -> str:
    """Return a random payment-method code, weighted 30/20/10/20/20 over CRCD, DBCD, PYPL, BANK, CASH."""
    method_codes = ("CRCD", "DBCD", "PYPL", "BANK", "CASH")
    method_weights = [0.3, 0.2, 0.1, 0.2, 0.2]
    (chosen,) = random.choices(method_codes, k=1, weights=method_weights)
    return chosen

In [None]:
from collections import Counter

# Next payment sequence number to hand out.
payment_i = 0
def PAYMENT_ID_GEN():
    """Return the next payment id: ``P`` followed by a zero-padded 8-digit counter starting at 0."""
    global payment_i
    current, payment_i = payment_i, payment_i + 1
    return f"P{current:08}"

# Next ticket sequence number per "yymm" month key. Counter is generic in its key type only.
flight_ticket_i: Counter[str] = Counter()
def TICKET_ID_GEN(dt: datetime):
    """Return the next ticket id for the month of ``dt``.

    The id is ``SFS`` + two-digit year and month of ``dt`` + a zero-padded
    6-digit counter that starts at 0 and is kept separately for each month.
    """
    # No ``global`` needed: the counter is mutated in place, never rebound.
    month_key = f"{dt:%y%m}"
    sequence_no = flight_ticket_i[month_key]
    flight_ticket_i[month_key] += 1
    return f"SFS{month_key}{sequence_no:06}"

In [None]:
class _InsertAllWriter:
    """Append rows to a SQL script as Oracle ``INSERT ALL`` statements of at most ``batch_size`` rows.

    The ``INSERT ALL`` header is written lazily with the first row of each
    statement, so the script never contains a statement without ``INTO`` clauses.
    """

    def __init__(self, fp, row_fmt: str, batch_size: int = 2000):
        self._fp = fp
        self._row_fmt = row_fmt
        self._batch_size = batch_size
        self._rows_in_statement = 0

    def write_row(self, *values) -> None:
        """Write one formatted ``INTO ...`` clause, closing the statement once it is full."""
        if self._rows_in_statement == 0:
            self._fp.write("INSERT ALL\n")
        self._fp.write(self._row_fmt.format(*values))
        self._rows_in_statement += 1
        if self._rows_in_statement >= self._batch_size:
            self.finish_statement()

    def finish_statement(self) -> None:
        """Terminate the currently open statement; does nothing if no row is pending."""
        if self._rows_in_statement:
            self._fp.write("SELECT 1 FROM DUAL;\n")
            self._rows_in_statement = 0


flight_ticket_fmt = "    INTO FLIGHT_TICKET (ACCOUNT_ID, PAYMENT_ID, CREATED_AT, BOARDING_GROUP, TICKET_PRICE, TICKET_STATUS) VALUES ({}, '{}', TO_TIMESTAMP('{:%Y-%m-%d %H:%M:%S}', 'YYYY-MM-DD HH24:MI:SS'), {}, {:.2f}, 'Completed')\n"
flight_sequence_fmt = "    INTO FLIGHT_SEQUENCE (FLIGHT_TICKET_ID, FLIGHT_ID, SEAT_ID, FLIGHT_SEQUENCE) VALUES ('{}', {}, '{}', {})\n"
payment_fmt = "    INTO PAYMENT (PAYMENT_METHOD_ID, ACCOUNT_ID, CREATED_AT, AMOUNT) VALUES ('{}', {}, TO_TIMESTAMP('{:%Y-%m-%d %H:%M:%S}', 'YYYY-MM-DD HH24:MI:SS'), {:.2f})\n"

# Price multiplier per seat class; any class not listed is charged the base price.
seat_class_rate = {"BUSI": 2, "FIRS": 3}

# Simulated clock. It advances by a random 1-360 minutes after every booking that is written.
dt = datetime(2023, 1, 1)
generation_end = datetime(2024, 4, 19, 15)

# The with-block closes all three scripts even if generation fails part-way.
with (
    open("flight_ticket.sql", "w") as flight_ticket_file,
    open("flight_seq.sql", "w") as flight_sequence_file,
    open("payment.sql", "w") as payment_file,
):
    flight_ticket_writer = _InsertAllWriter(flight_ticket_file, flight_ticket_fmt)
    flight_sequence_writer = _InsertAllWriter(flight_sequence_file, flight_sequence_fmt)
    payment_writer = _InsertAllWriter(payment_file, payment_fmt)

    try:
        while dt < generation_end:
            # NOTE(review): this keeps accounts whose wait-until time is still in the future
            # relative to dt. The name suggests the opposite test (wait_until <= dt) may have
            # been intended — confirm before changing, because the flight search below starts
            # at the account's wait-until time.
            eligible_accounts = [
                account for account, wait_until in ACCOUNT_WAIT_UNTIL.items() if wait_until > dt
            ]
            if not eligible_accounts:
                # dt only advances when a booking is written, so with nobody eligible
                # the loop could never make progress again.
                print("no eligible accounts left; stopping")
                break

            # Each eligible account books with probability 1 / (99 - travel_rate).
            accounts_selected = [
                account
                for account in eligible_accounts
                if np.random.randint(1, 100 - ACCOUNT_TRAVEL_RATE[account]) == 1
            ]

            for account_selected in accounts_selected:
                start_place, end_place = random.sample(ACCOUNT_AIRPORT_PREFERENCES[account_selected], k=2)
                routes = get_routes(start_place, end_place)
                n_tickets = get_random_number_of_tickets()

                # Reserve one flight per leg. Each leg is searched from the previous leg's
                # wait-until time so that a connection cannot depart before the leg before it.
                legs = []
                search_from = ACCOUNT_WAIT_UNTIL[account_selected]
                for route in routes:
                    leg = find_closest_flight(search_from, route, n_tickets)
                    if leg is None:
                        break
                    legs.append(leg)
                    search_from = leg[1]

                if not legs or len(legs) < len(routes):
                    # Booking abandoned (no route, or a leg without capacity):
                    # give back the seats that the earlier legs had already reserved.
                    for flight_id, _, seat_ids, _ in legs:
                        USED_SEAT[flight_id].difference_update(seat_ids)
                    continue

                # Ids are allocated only for bookings that are actually written, so the
                # generated sequences contain no gaps for abandoned bookings.
                payment_id = PAYMENT_ID_GEN()
                ticket_ids = [TICKET_ID_GEN(dt) for _ in range(n_tickets)]
                ticket_prices = [0] * n_tickets
                payment_amt = 0

                # FLIGHT_SEQUENCE rows are written leg by leg, one row per ticket, e.g. for
                # two tickets over two legs: (ticket 1, leg 1), (ticket 2, leg 1),
                # (ticket 1, leg 2), (ticket 2, leg 2) — each with its own seat.
                for route_seq, (flight_id, wait_until, seat_ids, seat_classes) in enumerate(legs, start=1):
                    ACCOUNT_WAIT_UNTIL[account_selected] = wait_until
                    flight_price = FLIGHT_PRICE[flight_id]

                    for i, (ticket_id, seat_id, seat_class) in enumerate(zip(ticket_ids, seat_ids, seat_classes)):
                        flight_sequence_writer.write_row(ticket_id, flight_id, seat_id, route_seq)

                        seat_price = flight_price * seat_class_rate.get(seat_class, 1)
                        ticket_prices[i] += seat_price
                        payment_amt += seat_price

                for ticket_price in ticket_prices:
                    flight_ticket_writer.write_row(account_selected, payment_id, dt, random.randint(1, 10), ticket_price)  # TODO: fix boarding group

                payment_writer.write_row(get_payment_method(), account_selected, dt, payment_amt)
                dt += timedelta(minutes=random.randint(1, 360))
    except KeyboardInterrupt:
        print("stopping")
    finally:
        # Terminate any statement still open so the rows written so far form valid SQL.
        for writer in (flight_ticket_writer, flight_sequence_writer, payment_writer):
            writer.finish_statement()

In [None]:
# Close the database connection opened in the first cell; query cells cannot run after this.
con.close()