Creation of the chain responsible for the generation of the star-schema structure

In [None]:
from typing_extensions import TypedDict
from typing import List, Literal

class AggregationState(TypedDict):
    selected_tables: List[str]
    aggregation_columns: List[str]
    selected_columns: List[str]
    selected_operations: List[str, Literal["COUNT", "SUM", "AVG", "MIN", "MAX"]]
    time_aggregation: List[str, Literal["HOUR", "WEEKDAY", "WEEK", "MONTH", "QUARTER"]]

In [3]:
# test data

test_schemas = [
    {
        "battles": """
        battle_id INT PRIMARY KEY,
        player_id INT NOT NULL,
        tank_id INT NOT NULL,
        damage_dealt INT NOT NULL,
        damage_blocked INT NOT NULL,
        damage_assisted INT NOT NULL,
        battle_played DATE NOT NULL,
        FOREIGN KEY (player_id) REFERENCES players(player_id),
        FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
        """,
        "players": """
        player_id int primary key,
        username varchar(30),
        battles int,
        winrate decimal(3,2)
        """,
        "tanks": """
        tank_id int primary_key,
        name varchar(100),
        tier int
        """
    },
    {
        "orders": """
        order_id int primary key,
        product_id int,
        order_date DATE,
        quantity decimal(4,2),
        price decimal(6,2),
        customer_id int, 
        FOREIGN KEY (product_id) REFERENCES products(product_id),
        FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
        """,
        "products": """
        product_id int primary key,
        product_name varchar(100),
        product_price decimal(4,2)
        """,
        "customers": """
        customer_id int primary key, 
        customer_name varchar(100),
        country varchar(100)
        """
    },
    {
        "orders": """
        order_id int primary key,
        order_date DATE,
        quantity decimal(4,2),
        price decimal(6,2),
        customer_id int, 
        FOREIGN KEY (product_id) REFERENCES products(product_id),
        FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
        """,
        "products": """
        product_id int primary key,
        product_name varchar(100),
        product_price decimal(4,2)
        """,
        "customers": """
        customer_id int primary key, 
        customer_name varchar(100),
        country varchar(100)
        """
    },
    {
        "calls":"""
        call_id int primary key,
        caller_id int,
        receiver_id int,
        date DATE,
        duration decimal(4,2),
        FOREIGN KEY (caller_id) REFERENCES customers(customer_id),
        FOREIGN KEY (receiver_id) REFERENCES customers(customer_id)
        """,
        "customers": """
        customer_id int primary key,
        name varchar(100),
        plan_id int,
        FOREIGN KEY (plan_id) REFERENCES plan(plan_id)
        """,
        "plan": """
        plan_id int primary key,
        description varchar(255)
        """
    },
    {
        "calls": """
        call_id int primary key,
        caller_id int,
        receiver_id int,
        date DATE,
        duration decimal(4,2),
        FOREIGN KEY (caller_id) REFERENCES customers(customer_id)
        FOREIGN KEY (receiver_id) REFERENCES customers(customer_id)
        """,
        "customers": """
        customer_id int primary key,
        name varchar(100),
        plan_id int,
        phone_id int,
        FOREIGN KEY (plan_id) REFERENCES plan(plan_id),
        FOREIGN KEY (phone_id) REFERENCES phone(phone_id)
        """,
        "phone": """
        phone_id int primary key,
        version varchar(50),
        company varchar(50)
        """,
        "plan": """
        plan_id int primary key,
        description varchar(255)
        """
    },
    {
       "drives": """
        drives_id primary key,
        driver_id int
        car_id int,
        fuel_burnt decimal(4,2),
        days_out int,
        mileage decimal(5,2),
        FOREIGN KEY (driver_id) REFERENCES drivers(driver_id),
        FOREIGN KEY (car_id) REFERENCES cars(car_id)
        """,
        "drivers": """
        driver_id primary key,
        driver_name varchar(100),
        driver_age int
        """,
        "cars": """
        car_id int primary key,
        car_model varchar(100),
        car_manufacturer int,
        FOREIGN KEY (car_manufacturer) REFERENCES car_manufacturer(manufacturer_id)
        """,
        "car_manufacturer": """
        manufacturer_id int primary key,
        name varchar(100), 
        country varchar(50)
        """ 
    },
    {
        "battles": """
        battle_id int primary key,
        player_id int,
        tank_id int,
        damage_dealt int,
        damage_blocked int,
        damage_assisted int,
        battle_played DATE,
        FOREIGN KEY (player_id) REFERENCES players(player_id),
        FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
        """ ,
        "players": """
        player_id int primary key,
        username varchar(30),
        battles int,
        winrate decimal(3,2)
        """ ,
        "tanks": """
        tank_id int primary_key,
        name varchar(100),
        tier int
        """ 
    },
    {
        "orders": """
        order_id int primary key,
        product_id int,
        order_date DATE,
        quantity decimal(4,2),
        price decimal(6,2),
        customer_id int, 
        FOREIGN KEY (product_id) REFERENCES products(product_id),
        FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
        """,
        "customers": """
        customer_id int primary key, 
        customer_name varchar(100),
        country varchar(100)
        """ ,
        "products": """
        product_id int primary key,
        product_name varchar(100),
        product_price decimal(4,2)
        """     
    },
    {
        "orders": """
        order_id int primary key,
        product_id int,
        order_date DATE,
        quantity decimal(4,2),
        total_price decimal(6,2),
        customer_id int, 
        FOREIGN KEY (product_id) REFERENCES products(product_id),
        FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
        """,
        "products": """
        product_id int primary key,
        product_name varchar(100),
        product_price decimal(4,2)
        """,
        "customers": """
        customer_id int primary key, 
        customer_name varchar(100),
        country varchar(100)
        """
    },
    {
        "calls": """
        call_id int primary key,
        caller_id int,
        receiver_id int,
        date DATE,
        duration decimal(4,2),
        FOREIGN KEY (caller_id) REFERENCES customers(customer_id),
        FOREIGN KEY (receiver_id) REFERENCES customers(customer_id)
        """,
        "customers": """
        customer_id int primary key,
        name varchar(100),
        plan_id int,
        phone_id int,
        FOREIGN KEY (plan_id) REFERENCES plan(plan_id),
        FOREIGN KEY (phone_id) REFERENCES phone(phone_id)
        """,
        "phone": """
        phone_id int primary key,
        version varchar(50),
        company varchar(50)
        """,
        "plan": """
        plan_id int primary key,
        description varchar(255)    
        """
    },
    {
        "calls": """
        call_id int primary key,
        caller_id int,
        receiver_id int,
        date DATE,
        duration decimal(4,2),
        FOREIGN KEY (caller_id) REFERENCES customers(customer_id),
        FOREIGN KEY (receiver_id) REFERENCES customers(customer_id)
        """,
        "customers": """
        customer_id int primary key,
        name varchar(100),
        plan_id int,
        phone_id int,
        FOREIGN KEY (plan_id) REFERENCES plan(plan_id),
        FOREIGN KEY (phone_id) REFERENCES phone(phone_id)
        """,
        "phone": """
        phone_id int primary key,
        version varchar(50),
        company varchar(50)
        """,
        "plan": """
        plan_id int primary key,
        description varchar(255)    
        """
    },
    {
        "battles": """
        battle_id int primary key,
        player_id int,
        tank_id int,
        damage_dealt int,
        damage_blocked int,
        damage_assisted int,
        battle_played DATE,
        FOREIGN KEY (player_id) REFERENCES players(player_id),
        FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)  
        """,
        "players": """
        player_id int primary key,
        username varchar(30),
        battles int,
        winrate decimal(3,2) 
        """,
        "tanks": """
        tank_id int primary_key,
        name varchar(100),
        tier int    
        """
    },
    {
        "battles": """
        battle_id int primary key,
        player_id int,
        tank_id int,
        damage_dealt int,
        damage_blocked int,
        damage_assisted int,
        battle_played DATE,
        FOREIGN KEY (player_id) REFERENCES players(player_id),
        FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)  
        """,
        "players": """
        player_id int primary key,
        username varchar(30),
        battles int,
        winrate decimal(3,2) 
        """,
        "tanks": """
        tank_id int primary_key,
        name varchar(100),
        tier int    
        """,
        
    },
    {
        "races": """
        race_id int primary key,
        location_id int,
        season_id int,
        race_winner_number int,
        fasted_lap decimal(3,4),
        FOREIGN KEY (location_id) REFERENCES locations(location_id),
        FOREIGN KEY (season_id) REFERENCES seasons(season_id),
        FOREIGN KEY (race_winner_number) REFERENCES racers(race_number)
        """,
        "locations": """
        location_id int primary key,
        country varchar(30),
        city varchar(25)
        """,
        "seasons": """
        season_id int primary key,
        season varchar(20),
        """,
        "racers": """
        race_number int primary key,
        racer_name varchar(100)
        """
    },
    {
        "flights": """
        flight_number int primary key,
        plane_number int,
        departure_place int,
        destination_place int,
        departure_time DATE,
        arrival_time DATE,
        FOREIGN KEY (plane_number) REFERENCES planes(plane_number),
        FOREIGN KEY (departure_place) REFERENCES place(place_id),
        FOREIGN KEY (destination_place) REFERENCES place(place_id)
        """,
        "planes": """
        plane_number int primary key,
        plane_company int,
        plane_model varchar(50),
        flights_number int,
        flight_hours decimal(5, 2),
        FOREIGN KEY (plane_company) REFERENCES companies(company_id)
        """,
        "companies": """
        company_id int primary key,
        name varchar(50),
        number_planes int
        """,
        "place": """
        place_id int primary key,
        country varchar(30)
        """
    },
    {
        "flights": """
        flight_number int primary key,
        plane_number int,
        departure_place int,
        destination_place int,
        departure_time DATE,
        arrival_time DATE,
        FOREIGN KEY (plane_number) REFERENCES planes(plane_number),
        FOREIGN KEY (departure_place) REFERENCES place(place_id),
        FOREIGN KEY (destination_place) REFERENCES place(place_id)
        """,
        "planes": """
        plane_number int primary key,
        plane_company int,
        plane_model varchar(50),
        flights_number int,
        flight_hours decimal(5, 2),
        FOREIGN KEY (plane_company) REFERENCES companies(company_id)
        """,
        "companies": """
        company_id int primary key,
        name varchar(50),
        number_planes int
        """,
        "place": """
        place_id int primary key,
        country varchar(30)
        """
    },
    {
        "flights": """
        flight_number int primary key,
        plane_number int,
        departure_place int,
        destination_place int,
        flight_duration decimal(5,2),
        FOREIGN KEY (plane_number) REFERENCES planes(plane_number),
        FOREIGN KEY (departure_place) REFERENCES place(place_id),
        FOREIGN KEY (destination_place) REFERENCES place(place_id)
        """,
        "planes": """
        plane_number int primary key,
        plane_company int,
        plane_model varchar(50),
        flights_number int,
        flight_hours decimal(5, 2),
        FOREIGN KEY (plane_company) REFERENCES companies(company_id)
        """,
        "companies": """
        company_id int primary key,
        name varchar(50),
        number_planes int
        """,
        "place": """
        place_id int primary key,
        country varchar(30)
        """,
    },
    {
        "exams": """
        exam_id int,
        student_id int,
        primary key (exam_id, student_id),
        score decimal(3,2),
        FOREIGN KEY (student_id) REFERENCES students(student_id),
        FOREIGN KEY (exam_id) REFERENCES exam_dim(exam_id)
        """,
        "students": """
        student_id int primary key,
        student_name varchar(100)
        class_id int,
        FOREIGN KEY (class_id) REFERENCES classes(class_id)
        """,
        "classes": """
        class_id int primary key,
        grade int,
        letter varchar(2),
        number_students int
        """,
        "exam_dim": """
        exam_id int,
        class varchar(20),
        date DATE
        """,
    },
    {
        "transaction": """
        transaction_id int primary key,
        date DATE,
        sender_id int,
        receiver_id int,
        amount decimal(6,2),
        FOREIGN KEY (sender_id) REFERENCES customers(customer_id),
        FOREIGN KEY (receiver_id) REFERENCES customers(customer_id)
        """,
        "customers": """
        customer_id int primary key,
        name varchar(100),
        balance decimal(7,2)
        """
    },
    {
        "exams": """
        exam_id in,
        student_id int,
        score decimal(3,2),
        primary key (exam_id, student_id),
        FOREIGN KEY (student_id) REFERENCES students(student_id),
        FOREIGN KEY (exam_id) REFERENCES exam_dim(exam_id)
        """,
        "exam_dim": """
        exam_id int primary key,
        subject varchar(20),
        date DATE
        """,
        "classes": """
        class_id int primary key,
        grade int,
        letter varchar(2),
        number_students int
        """,
        "students": """
        student_id int primary key,
        student_name varchar(100)
        class_id int,
        FOREIGN KEY (class_id) REFERENCES classes(class_id)
        """
    },
    {
        "flights": """
        flight_id int primary key,
        plane_id int,
        company_id int,
        tickets_sold int,
        ticket_price decimal(4,2),
        operation_day DATE,
        FOREIGN KEY (plane_id) REFERENCES planes(plane_id),
        FOREIGN KEY (company_id) REFERENCES companies(company_id)
        """,
        "planes": """
        plane_id int primary key,
        plane_model varchar(100),
        plane_manufacturer varchar(100)
        """,
        "companies": """
        company_id int primary key,
        name varchar(100)
        """
    },
    {
        "flights": """
        flight_id int primary key,
        plane_id int,
        company_id int,
        tickets_sold int,
        operation_day DATE,
        FOREIGN KEY (plane_id) REFERENCES planes(plane_id),
        FOREIGN KEY (company_id) REFERENCES companies(company_id)
        """,
        "planes": """
        plane_id int primary key,
        plane_model varchar(100),
        plane_manufacturer varchar(100)
        """,
        "companies": """
        company_id int primary key,
        name varchar(100)
        """
    },
    {
        "battles": """
        battle_id int primary key,
        player_id int,
        tank_id int,
        damage_dealt int,
        damage_blocked int,
        damage_assisted int,
        battle_played DATE,
        FOREIGN KEY (player_id) REFERENCES players(player_id),
        FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
        """,
        "players": """
        player_id int primary key,
        username varchar(30),
        battles int,
        winrate decimal(3,2)
        """,
        "tanks": """
        tank_id int primary_key,
        name varchar(100),
        tier int
        """
    },
    {
        "battles": """
        battle_id int primary key,
        player_id int,
        tank_id int,
        damage_dealt int,
        damage_blocked int,
        damage_assisted int,
        battle_played DATE,
        FOREIGN KEY (player_id) REFERENCES players(player_id),
        FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
        """,
        "players": """
        player_id int primary key,
        username varchar(30),
        battles int,
        winrate decimal(3,2)
        """,
        "tanks": """
        tank_id int primary_key,
        name varchar(100),
        tier int
        """
    },
    {
        "calls": """
        call_id int primary key,
        caller_id int,
        receiver_id int,
        date DATE,
        duration decimal(4,2),
        FOREIGN KEY (caller_id) REFERENCES customers(customer_id),
        FOREIGN KEY (receiver_id) REFERENCES customers(customer_id)
        """,
        "customers": """
        customer_id int primary key,
        name varchar(100),
        plan_id int,
        phone_id int,
        FOREIGN KEY (plan_id) REFERENCES plan(plan_id),
        FOREIGN KEY (phone_id) REFERENCES phone(phone_id)
        """,
        "phone": """
        phone_id int primary key,
        version varchar(50),
        company varchar(50)
        """,
        "plan": """
        plan_id int primary key,
        description varchar(255)
        """
    }
    
]

In [4]:
# expected results

test_outcomes = [
    """
    user_id int,
    tank_id int,
    total_damage int,
    total_blocked int,
    total_assisted int,
    primary key(user_id, tank_id),
    FOREIGN KEY (player_id) REFERENCES players(player_id),
    FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
    """,
    """
    create table product_customer_fact(
    product_id int,
    customer_id int,
    total_price decimal(7,2),
    primary key (product_id, customer_id),
    FOREIGN KEY (product_id) REFERENCES products(product_id),
    FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
    );
    """,
    """
    create table order_fact(
    product_id int primary key,
    total_price decimal(7,2),
    FOREIGN KEY (product_id) REFERENCES products(product_id),
    );
    """,
    """
    create table call_length(
    first_caller_id int,
    second_caller_id int,
    total_duration decimal(5,2),
    primary key(first_caller_id, second_caller_id),
    FOREIGN KEY (first_caller_id) REFERENCES customers(customer_id)
    FOREIGN KEY (second_caller_id) REFERENCES customers(customer_id)
    );
    """,
    """
    create table call_length(
    plan_id int,
    phone_id int,
    total_duration decimal(5,2),
    primary key(plan_id, phone_id),
    FOREIGN KEY (plan_id) REFERENCES plan(plan_id),
    FOREIGN KEY (phone_id) REFERENCES phone(phone_id)
    );
    """,
    """
    create table table_fact(
    manufacturer_id int primary key,
    total_mileage decimal(8,2),
    avg_days_out decimal(3,2),
    FOREIGN KEY (manufacturer_id) REFERENCES car_manufacturer(manufacturer_id)
    );
    """,
    """
    create table total_tank_fact(
    tank_id int primary key,
    avg_damage decimal(5,2),
    avg_blocked decimal(5,2),
    avg_assisted decimal(5,2),
    counter int,
    FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
    );
    """,
    """
    create table product_hourly_fact(
    product_id int,
    order_date_hour DATETIME,
    total_product_price decimal(7,2),
    primary key (product_id, order_date_hour),
    FOREIGN KEY (product_id) REFERENCES product_dim(product_id),
    );
    """,
    """
    create table product_fact(
    product_id int primary key,
    products_over_month DATETIME,
    count int,
    primary key(product_id, products_over_month),
    FOREIGN KEY (product_id) REFERENCES products(product_id)
    );
    """,
    """
    create table count_plan_phone_fact(
    phone_id int,
    plan_id int,
    count_calls int,
    primary key(phone_id, plan_id),
    FOREIGN KEY (plan_id) REFERENCES plan(plan_id),
    FOREIGN KEY (phone_id) REFERENCES phone(phone_id)
    );
    """,
    """
    create table count_plan_phone_weekly_fact(
    phone_id int,
    plan_id int,
    count_calls int,
    Call_Date_Weekly DATETIME,
    primary key(phone_id, plan_id, Call_Date_Weekly),
    FOREIGN KEY (plan_id) REFERENCES plan(plan_id),
    FOREIGN KEY (phone_id) REFERENCES phone(phone_id)
    );
    """,
    """
    create table count_player_tank_fact(
    player_id int,
    tank_id int,
    count_battles int,
    total_damage int,
    avg_damage decimal(10,2),
    primery key(player_id, tank_id),
    FOREIGN KEY (player_id) REFERENCES players(player_id),
    FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
    );
    """,
    """
    create table count_player_tank_daily_fact(
    player_id int,
    tank_id int,
    count_battles int,
    total_damage int,
    battle_date_daily DATETIME,
    avg_damage decimal(10,2),
    primery key(player_id, tank_id, battle_date_daily),
    FOREIGN KEY (player_id) REFERENCES players(player_id),
    FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
    );
    """,
    """
    create table count_races_season_fact(
    season_id int,
    racer_id int,
    count_winner int,
    primary key(season_id, racer_id),
    FOREIGN KEY (season_id) REFERENCES seasons(season_id),
    FOREIGN KEY (racer_id) REFERENCES racers(race_number)
    );
    """,
    """
    create table count_flights_company_daily(
    company_id int,
    flights_over_daily DATETIME,
    flights_count int,
    primary key(company_id, flights_over_daily),
    FOREIGN KEY (company_id) REFERENCES companies(company_id)
    );
    """,
    """
    create table count_flights_company_daily(
    departure_place int,
    destination_place int,
    flights_count int,
    flights_over_daily DATETIME,
    primary key(company_id, departure_place, destination_place, flights_over_daily),
    FOREIGN KEY (company_id) REFERENCES companies(company_id),
    FOREIGN KEY (departure_place) REFERENCES place(place_id),
    FOREIGN KEY (destination_place) REFERENCES palce(place_id)
    );
    """,
    """
    create table flights_fact(
    flight_number int,
    departure_place int,
    destination_place int,
    min_duration decimal(5,2),
    primary key (departure_place, destination_place),
    FOREIGN KEY (departure_place) REFERENCES place(place_id),
    FOREIGN KEY (destination_place) REFERENCES place(place_id)
    );
    """,
    """
    create table min_max_score_exam_class_fact(
    class_id int,
    exam_id int,
    min_student_id int,
    max_student_id int,
    min_score decimal(3,2),
    max_score decimal(3,2),
    primary key(class_id, exam_id),
    FOREIGN KEY (class_id) REFERENCES classes(class_id),
    FOREIGN KEY (exam_id) REFERENCES exam_dim(exam_id),
    FOREIGN KEY (min_student_id) REFERENCES students(student_id),
    FOREIGN KEY (max_student_id) REFERENCES students(student_id)
    );
    """,
    """
    create table montly_customer_spendings(
    customer_id int,
    spendings_over_month DATETIME,
    total_spendings decimal(7,2),
    primary key (customer_id, spendings_over_month),
    FOREIGN KEY (receiver_id) REFERENCES customers(customer_id)
    );
    """,
    """
    create table avg_score_exam_class_fact(
    class_id int,
    exam_id int,
    counter int,
    avg_score decimal(3,2),
    primary key(class_id, exam_id),
    FOREIGN KEY (class_id) REFERENCES classes(class_id),
    FOREIGN KEY (exam_id) REFERENCES exam_dim(exam_id)
    );
    """,
    """
    create table company_fact(
    company_id int primary key,
    flights_count int,
    total_revenue decimal(5,2),
    avg_tickets_sold decimal(4,2),
    FOREIGN KEY (company_id) REFERENCES companies(company_id)
    );
    """,
    """
    create table avg_tickets_sold_company_monthly_fact(
    company_id int,
    tickets_sold_over_month DATETIME,
    flights_count int,
    total_tickets int,
    primary key(company_id, tickets_sold_over_month),
    avg_tickets_sold decimal(4,2),
    FOREIGN KEY (company_id) REFERENCES companies(company_id)
    );
    """,
    """
    create table avg_player_tank_fact(
    player_id int,
    tank_id int,
    count_battles int,
    avg_damage decimal(10,2),
    primery key(player_id, tank_id),
    FOREIGN KEY (player_id) REFERENCES players(player_id),
    FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
    );
    """,
    """
    create table avg_damage_player_month_fact(
    player_id int,
    count_battles int,
    avg_damage decimal(10,2),
    damage_over_month DATETIME,
    primery key(player_id, damage_over_month),
    FOREIGN KEY (player_id) REFERENCES players(player_id),
    );
    """,
    """
    CREATE TABLE fact_calls_weekday (
    caller_id INT,
    weekday VARCHAR(10),
    avg_calls_per_weekday DECIMAL(5,2),
    total_duration_per_weekday int,
    PRIMARY KEY (caller_id, weekday),
    FOREIGN KEY (caller_id) REFERENCES customers(customer_id)
    );
    """
]

In [5]:
# below is the code for testing 

test_agg_columns = [
    ["user_id", "tank_id"],
    ["product_id", "customer_id"],
    ["product_id"],
    ["caller_id", "receiver_id"],
    ["plan_id", "phone_id"],
    ["manufacturer_id"],
    ["tank_id"],
    ["product_id", "order_date"],
    ["product_id", "order_date"],
    ["phone_id", "plan_id"],
    ["phone_id", "plan_id", "date"],
    ["player_id", "tank_id"],
    ["player_id", "tank_id" , "battle_played"],
    ["season_id", "racer_id"],
    ["company_id", "departure_time"],
    ["company_id", "departure_place", "destination_place", "departure_time"],
    ["departure_place", "destination_place"],
    ["class_id", "exam_id"],
    ["sender_id", "date"],
    ["class_id", "exam_id"],
    ["company_id"],
    ["company_id", "operation_day"],
    ["player_id", "tank_id"],
    ["player_id", "battle_played"],
    ["customer_id", "date"]
]

test_columns = [
    ["damage_dealt", "damage_blocked", "damage_assisted"],
    ["price", "discount"],
    ["price"],
    ["duration"],
    ["duration"],
    ["mileage", "days_out"],
    ["damage_dealt", "damage_blocked", "damage_assisted", "battle_id"],
    ["price"],
    ["product_id"],
    ["call_id"],
    ["call_id"],
    ["battle_id", "damage_dealt", "damage_dealt"],
    ["battle_id", "damage_dealt", "damage_dealt"],
    ["race_winner_number"],
    ["flight_id"],
    ["flight_id"],
    ["flight_duration"],
    ["score", "score"],
    ["amount"],
    ["score"],
    ["tickets_sold", "ticket_price"],
    ["tickets_sold"],
    ["damage_dealt"],
    ["damage_dealt"],
    ["call_id", "duration"]
]

test_operations = [
    ["SUM", "SUM", "SUM"],
    ["SUM", "AVG"],
    ["SUM"],
    ["SUM"],
    ["SUM"],
    ["SUM", "AVG"],
    ["AVG","AVG","AVG","COUNT"],
    ["SUM"],
    ["COUNT"],
    ["COUNT"],
    ["COUNT"],
    ["COUNT", "SUM", "AVG"],
    ["COUNT", "SUM", "AVG"],
    ["COUNT"],
    ["COUNT"],
    ["COUNT"],
    ["MIN"],
    ["MIN", "MAX"],
    ["SUM"],
    ["AVG"],
    ["AVG", "SUM"],
    ["AVG"],
    ["AVG"],
    ["AVG"],
    ["AVG", "SUM"]
]

time_aggregation = [
    "NONE",
    "NONE",
    "NONE",
    "NONE",
    "NONE",
    "NONE",
    "NONE",
    "HOUR",
    "MONTH",
    "NONE",
    "WEEK",
    "NONE",
    "WEEKDAY",
    "NONE",
    "WEEKDAY",
    "WEEKDAY",
    "NONE",
    "NONE",
    "MONTH",
    "NONE",
    "NONE",
    "MONTH",
    "NONE",
    "WEEKDAY",
    "WEEKDAY"
]


start_prompt_tempalate = "You are an SQL expert. Your task is to create the table schema for so-called fact table that follows star-schema design. You are provided with the schemas of available tables: {schemas}"
generated_column_part = ". Include fields that compute "
output_instructions = "Output should contain only the sql query to create a table. Don't repeat the question, and don't provide an explanation"

# generates a test prompt based on the table name, aggregation fields, and the desired operation
def generate_test_prompt(counter):
    agg_columns = test_agg_columns[counter]
    t_columns = test_columns[counter]
    operations = test_operations[counter]

    start_prompt = genegare_table_schema(counter)

    generated_agg_part = f" Generated table should aggregate data by: {', '.join(agg_columns)}"
    if time_aggregation[counter] != "NONE":
        a = ". Please note that from " + agg_columns[-1] + " should be extracted " + time_aggregation[counter] + " and new field should be of type DATETIME"
        generated_agg_part += a
    operations_prompt = ", ".join(f"{op} of {col}" for op, col in zip(operations, t_columns))
    output = ". The output should be the defined fact table schema in SQL (CREATE TABLE)."
    if len(agg_columns) > 1:
        output + f" Ensure that the primary key is a composite key ({', '.join(agg_columns)})"
    return start_prompt + generated_agg_part + generated_column_part + operations_prompt + output + output_instructions

def genegare_table_schema(counter):
    cur_dict = test_schemas[counter]
    formatted_str = ""
    edited = []
    for key in cur_dict.keys():
        string = cur_dict[key]
        values = string.split(",")
        edited = [s.strip("\n").strip() for s in values]
        formatted_str += key + ": " + ",".join(f"{value}" for value in edited)
    return start_prompt_tempalate.format(schemas = formatted_str)




def init_table_fields():
    for sample in test_schemas:
        cur_dict = sample[0]
        all_keys = []
        all_columns = []
        for key, value in cur_dict.items():
            print(key)
            cur_columns = []
            for column in value.split(","):
                column = column.strip("\n").strip()
                if not column.lower().startswith("foreign"):
                    cur_columns.append(column.split()[0])
            all_keys.append(key)
            all_columns.append(cur_columns)
        a = dict(zip(all_keys, all_columns))
        test_columns.append(a)


def get_agg_columns_from_selected_tables_test(counter):
    return test_agg_columns[counter]



def get_columns_from_selected_tables_test(counter):
    return test_columns[counter]



In [6]:
generate_test_prompt(0)

"You are an SQL expert. Your task is to create the table schema for so-called fact table that follows star-schema design. You are provided with the schemas of available tables: battles: battle_id INT PRIMARY KEY,player_id INT NOT NULL,tank_id INT NOT NULL,damage_dealt INT NOT NULL,damage_blocked INT NOT NULL,damage_assisted INT NOT NULL,battle_played DATE NOT NULL,FOREIGN KEY (player_id) REFERENCES players(player_id),FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)players: player_id int primary key,username varchar(30),battles int,winrate decimal(3,2)tanks: tank_id int primary_key,name varchar(100),tier int Generated table should aggregate data by: user_id, tank_id. Include fields that compute SUM of damage_dealt, SUM of damage_blocked, SUM of damage_assisted. The output should be the defined fact table schema in SQL (CREATE TABLE).Output should contain only the sql query to create a table. Don't repeat the question, and don't provide an explanation"

In [7]:
import getpass
import os


if not os.environ.get("OPENAI_API_KEY"):
  os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

from langchain.chat_models import init_chat_model

llm = init_chat_model("gpt-4o-mini", model_provider="openai")

In [None]:
# testing loop
response = llm.invoke(generate_test_prompt(0))



content='```sql\nCREATE TABLE battle_facts (\n    user_id INT NOT NULL,\n    tank_id INT NOT NULL,\n    total_damage_dealt INT NOT NULL,\n    total_damage_blocked INT NOT NULL,\n    total_damage_assisted INT NOT NULL,\n    PRIMARY KEY (user_id, tank_id),\n    FOREIGN KEY (user_id) REFERENCES players(player_id),\n    FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)\n);\n```' additional_kwargs={'refusal': None} response_metadata={'token_usage': {'completion_tokens': 88, 'prompt_tokens': 227, 'total_tokens': 315, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-mini-2024-07-18', 'system_fingerprint': 'fp_86d0290411', 'finish_reason': 'stop', 'logprobs': None} id='run-091a1bf1-67e1-472d-8cee-fd7c8a36c151-0' usage_metadata={'input_tokens': 227, 'output_tokens': 88, 'total_tokens': 315, 'input_token_details': {'audio

In [None]:
def process_testing():
    count_correct = 0
    count_wrong = 0
    for i in range(0, len(test_schemas)):
        response = llm.invoke(generate_test_prompt(i))
        print("Got the nexr result from the model")
        print(response.content)
        print("----------- Expected result -----------")
        print(test_outcomes[i])
        print("----------- Desired operations ----------------")
        print(test_agg_columns[i])
        print(test_columns[i])
        print(test_operations[i])
        print()
        user_input = input("Is it correct")
        if user_input.lower() == "yes":
            count_correct += 1
        else:
            count_wrong += 1
    total = count_wrong + count_correct
    return (count_correct / total, count_wrong / total)



In [12]:
(correct, wrong) = process_testing()

APIConnectionError: Connection error.

In [41]:
print(response.content)

```sql
CREATE TABLE battle_facts (
    user_id INT NOT NULL,
    tank_id INT NOT NULL,
    total_damage_dealt INT NOT NULL,
    total_damage_blocked INT NOT NULL,
    total_damage_assisted INT NOT NULL,
    PRIMARY KEY (user_id, tank_id),
    FOREIGN KEY (user_id) REFERENCES players(player_id),
    FOREIGN KEY (tank_id) REFERENCES tanks(tank_id)
);
```


In [None]:
# prompt user to answer whether he wants to generate a star-schema
def ask_star_schema(state):
    user_choice = input("Do you want to generate a star schema? (yes/no) ")
    if user_choice.lower() == "yes":
        state.current_step = "generate_star_schema"
    else:
        state.current_step = "fetch_tables"
    return state


def show_table_names(state):
    tables_list = get_tables_list_test(0)
    print("Available tables: ", tables_list)
    selected_tables = input("Select tables (comma-separated): ").split(",")
    fields = get_fields(selected_tables)

    