In [None]:
# Switch into the backend directory; presumably needed so the `database` and
# `models` imports in the next cell resolve — confirm against the repo layout.
%cd ../BackEnd

In [None]:
import ast
import json

import pandas as pd
import sqlalchemy as sql
from sqlalchemy.orm import Session

from database import SessionLocal
from models.genre import Genre
from models.tv_show import TVShow, tv_genres

# Move to the DB directory before loading; the `__main__` block opens "tv.csv"
# relative to the working directory, so the file presumably lives there — confirm.
%cd ../DB

# Optional CSV columns copied into each TVShow row (missing cells become NULL).
_TVSHOW_COLUMNS = (
    "name",
    "original_name",
    "overview",
    "tagline",
    "first_air_date",
    "popularity",
    "vote_average",
    "vote_count",
    "poster_path",
    "backdrop_path",
    "type",
)


def _none_if_missing(value):
    """Return ``None`` for missing scalars (``None``/``NaN``/``NaT``), else ``value``.

    ``pd.read_csv`` turns empty cells into ``NaN``; mapping them to ``None``
    makes the database store NULL instead of a float NaN.
    """
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    return value


def _parse_genres(raw):
    """Parse one ``genres`` cell into a list of ``(genre_id, name)`` pairs.

    The cell is expected to be a JSON list of objects with ``id`` and ``name``
    keys. A Python-literal rendering of the same list (single-quoted keys,
    which ``json`` rejects) is accepted as a fallback.

    Raises:
        ValueError: if the cell is not a list of mappings with ``id``/``name``.
    """
    if not isinstance(raw, str):
        raise ValueError(f"genres value is not text: {raw!r}")
    try:
        data = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        try:
            # literal_eval evaluates literals only, so it is safe on CSV input.
            data = ast.literal_eval(raw)
        except (ValueError, SyntaxError, TypeError) as exc:
            raise ValueError(f"unparseable genres value: {raw!r}") from exc
    if not isinstance(data, list):
        raise ValueError(f"genres value is not a list: {raw!r}")

    pairs = []
    for entry in data:
        try:
            pairs.append((int(entry["id"]), entry["name"]))
        except (KeyError, TypeError, ValueError) as exc:
            raise ValueError(f"malformed genre entry: {entry!r}") from exc
    return pairs


def load_tv_bulk(csv_path: str, batch_size: int = 1000):
    """Bulk-load TV shows, their genres and ``tv_genres`` links from a CSV.

    The CSV is read ``batch_size`` rows at a time. For each chunk this inserts
    (1) genres not yet in the ``Genre`` table, (2) the TV shows and (3) the
    deduplicated ``(tmdb_id, genre_id)`` rows of ``tv_genres``, then commits
    once, so every chunk is stored all-or-nothing.

    Args:
        csv_path: Path of the CSV file. It must contain an ``id`` column; the
            columns in ``_TVSHOW_COLUMNS`` and ``genres`` are optional.
        batch_size: Number of CSV rows processed and committed per chunk.

    Raises:
        Exception: any error while reading or inserting is printed, the
            current chunk is rolled back and the error is re-raised, so callers
            cannot mistake a partial import for a complete one. Chunks already
            committed stay in the database.
    """
    chunk_iter = pd.read_csv(
        csv_path,
        chunksize=batch_size,
        sep=",",
        encoding="utf-8",
        quotechar='"',
    )

    db = SessionLocal()
    try:
        # Seed with genres already stored (a previous run, or a genre table
        # shared with another import) so they are not re-inserted.
        known_genre_ids = {gid for (gid,) in db.query(Genre.genre_id)}

        for chunk_index, chunk_df in enumerate(chunk_iter, start=1):
            print(f"Processing chunk #{chunk_index} with {len(chunk_df)} rows...")
            if "id" not in chunk_df.columns:
                raise KeyError(f"CSV {csv_path!r} has no 'id' column")

            tvshows_to_insert = []
            new_genres = []
            # dict keys drop duplicate pairs while keeping insertion order.
            bridging_pairs = {}
            skipped_rows = 0

            for row in chunk_df.to_dict("records"):
                raw_id = _none_if_missing(row.get("id"))
                if raw_id is None:
                    skipped_rows += 1
                    continue
                # pandas reads an id column containing blanks as float.
                tmdb_id = int(raw_id)

                show = {"tmdb_id": tmdb_id}
                for column in _TVSHOW_COLUMNS:
                    show[column] = _none_if_missing(row.get(column))
                tvshows_to_insert.append(show)

                raw_genres = _none_if_missing(row.get("genres"))
                if raw_genres is None:
                    continue
                try:
                    genre_pairs = _parse_genres(raw_genres)
                except ValueError as exc:
                    # Best effort: keep the show, only drop its genre links.
                    print(f"Warning: ignoring genres of tmdb_id={tmdb_id}: {exc}")
                    continue

                for gid, gname in genre_pairs:
                    if gid not in known_genre_ids:
                        known_genre_ids.add(gid)
                        new_genres.append({"genre_id": gid, "name": gname})
                    bridging_pairs[(tmdb_id, gid)] = None

            if skipped_rows:
                print(f"Skipped {skipped_rows} rows without an id in chunk #{chunk_index}.")

            # Genres and shows go first so the tv_genres foreign keys resolve.
            if new_genres:
                print(f"Inserting {len(new_genres)} new genres discovered...")
                db.bulk_insert_mappings(Genre, new_genres)

            if tvshows_to_insert:
                print(
                    f"Bulk inserting {len(tvshows_to_insert)} tv shows "
                    f"for chunk #{chunk_index}..."
                )
                db.bulk_insert_mappings(TVShow, tvshows_to_insert)

            if bridging_pairs:
                print(
                    f"Bulk inserting {len(bridging_pairs)} unique bridging rows "
                    f"for chunk #{chunk_index}..."
                )
                db.execute(
                    tv_genres.insert(),
                    [{"tmdb_id": t, "genre_id": g} for t, g in bridging_pairs],
                )

            db.commit()
            print(f"Chunk #{chunk_index} committed.")

        print("All chunks processed successfully!")

    except Exception as e:
        db.rollback()
        print(f"Error: {e}")
        raise
    finally:
        db.close()

if __name__ == "__main__":
    # Relative path: resolved against the current working directory.
    source_csv = "tv.csv"
    rows_per_chunk = 100_000
    load_tv_bulk(csv_path=source_csv, batch_size=rows_per_chunk)
    print("TV import complete!")
