In [36]:
import os
import re
import json
import pandas as pd
from datetime import datetime
from google import genai
from google.genai import types


# Shared Gemini API client; classify_columns() calls client.models.generate_content on it.
# Constructed with no explicit arguments — presumably credentials/config are picked up
# from the environment; confirm against the google-genai setup used for this notebook.
client = genai.Client() 


## Instructions & Setup

In [None]:
# Policy text that classify_columns() embeds verbatim in the prompt sent to the model.
# It defines the four levels the model may assign: Restricted, Confidential,
# Internal-use only and Public.
# NOTE(review): several hyphens inside this string appear to be U+2011 (non-breaking
# hyphen) rather than ASCII "-". The model's labels can echo them, so an ASCII string
# comparison downstream may not match — confirm whether this is intentional.
classification_policy = """
Classification Policy:

Restricted:
Data that could cause criminal charges, massive fines, or serious harm if exposed.
Examples include social security number (SSN), complete date of birth (year‑month‑day), driver's license number, passport number, banking information, passwords, health information shared for insurance, payroll.

Confidential:
Data that if disclosed could cause a high risk. Examples include date of birth (year & month only), budgets, sales information, corporate financial info, strategic plans, third‑party confidential information, security plans, designs, and intellectual property.

Internal‑use only:
Accessible strictly to internal company personnel or employees granted access. Examples: internal memos, business plans.

Public:
Freely accessible to all employees/company personnel. Examples: job descriptions, press releases.
"""

# --- Read CSV or Excel ---
def load_dataset(file_path: str) -> pd.DataFrame:
    """Load a tabular file into a DataFrame, picking the reader from the extension.

    Args:
        file_path: Path to the dataset. A name ending in ``.xlsx`` or ``.xls``
            (compared case-insensitively) is read as Excel; anything else is
            read as CSV.

    Returns:
        The file contents as a DataFrame.
    """
    excel_suffixes = ('.xlsx', '.xls')
    looks_like_excel = file_path.lower().endswith(excel_suffixes)
    reader = pd.read_excel if looks_like_excel else pd.read_csv
    return reader(file_path)

# --- Classify columns using GenAI ---
def classify_columns(df: pd.DataFrame) -> str:
    """Ask the Gemini model to classify every column of ``df`` under the policy.

    The prompt is built from the module-level ``classification_policy`` text,
    the column names, and the first three rows of ``df`` serialised as JSON.
    Be aware that those sample rows are sent to the external model, so real
    values from the dataset leave the notebook.

    Args:
        df: Dataset whose columns should be classified.

    Returns:
        The model's raw text reply with surrounding whitespace removed, or an
        empty string when the response carries no text. The reply is requested
        as a JSON list but is not validated here; pass it to
        ``parse_classification_string`` to obtain a DataFrame.
    """
    column_names = df.columns.tolist()
    sample_data = df.head(3).to_dict(orient="records")

    # default=str keeps serialisation from failing on cells that are not
    # JSON-native (timestamps, Decimals, ...); they are sent as their str() form.
    sample_json = json.dumps(sample_data, indent=2, default=str)

    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=f"Classification policy:\n{classification_policy}"),
                types.Part.from_text(text=f"Dataset column names:\n{column_names}"),
                types.Part.from_text(text=f"Sample of first 3 rows:\n{sample_json}"),
                types.Part.from_text(text=(
                    "For each column name provided, assign exactly one classification level from "
                    "Restricted, Confidential, Internal‑use only, Public. "
                    "Provide a concise reason (1‑2 sentences). "
                    "Output strictly as a valid JSON list of objects with fields: "
                    '[{"header":"ColumnName","classification":"Confidential","reason":"Because ..."}]'
                )),
            ]
        )
    ]

    response = client.models.generate_content(
        model="gemini-2.5-flash",
        contents=contents,
        config=types.GenerateContentConfig(
            # Thinking disabled and temperature 0 to keep the output as repeatable as possible.
            thinking_config=types.ThinkingConfig(thinking_budget=0),
            temperature=0.0
        )
    )

    # response.text can be None (e.g. no text part in the reply); treat that as
    # an empty reply instead of raising AttributeError on .strip().
    return (response.text or "").strip()

# --- Parse classification string ---
def _classification_items_from_json(raw: str) -> list[dict]:
    """Parse the outermost JSON array in ``raw`` into item dicts; [] if that is not possible."""
    start, end = raw.find("["), raw.rfind("]")
    if start == -1 or end <= start:
        return []
    try:
        payload = json.loads(raw[start:end + 1])
    except ValueError:  # json.JSONDecodeError is a ValueError
        return []
    if not isinstance(payload, list):
        return []

    items = []
    for entry in payload:
        if not isinstance(entry, dict):
            continue
        # Match keys case-insensitively, mirroring the IGNORECASE regex fallback.
        fields = {str(key).lower(): value for key, value in entry.items()}
        header = fields.get("header")
        if header is None or header == "":
            continue
        classification = fields.get("classification")
        reason = fields.get("reason")
        items.append({
            "header": str(header),
            "classification": "" if classification is None else str(classification),
            "reason": "" if reason is None else str(reason),
        })
    return items


def _classification_items_from_regex(raw: str) -> list[dict]:
    """Extract header/classification/reason triples from loosely formatted text."""
    pattern = re.compile(
        r'"header"\s*:\s*"(?P<header>[^"]+)"\s*,\s*"classification"\s*:\s*"(?P<classification>[^"]+)"\s*,\s*"reason"\s*:\s*"(?P<reason>[^"]+)"',
        flags=re.IGNORECASE
    )
    return [
        {
            "header": m.group("header"),
            "classification": m.group("classification"),
            "reason": m.group("reason"),
        }
        for m in pattern.finditer(raw)
    ]


def parse_classification_string(raw: str, column_names: list[str]) -> pd.DataFrame:
    """Turn the model's raw reply into a header/classification/reason table.

    The reply is parsed as JSON first, so escaped quotes inside a reason and
    any key order are handled correctly. If no JSON array can be parsed, the
    original regex extraction is used as a fallback for malformed output.

    Args:
        raw: Model output, optionally wrapped in a markdown code fence.
        column_names: Column names of the dataset, in their original order.

    Returns:
        DataFrame with columns ``header``, ``classification`` and ``reason``.
        Every name in ``column_names`` is present; names the model did not
        cover get empty strings. Rows follow the order of ``column_names``;
        headers the model returned that are not in ``column_names`` are kept
        and sorted first.
    """
    # 1. Clean code fences or markdown
    raw = (raw or "").strip()
    if raw.startswith("```"):
        parts = raw.split("```")
        if len(parts) >= 3:
            # parts[1] may still start with a language tag such as "json";
            # both extractors below skip over it.
            raw = parts[1].strip()

    # 2. Extract items: strict JSON first, regex as a fallback
    items = _classification_items_from_json(raw) or _classification_items_from_regex(raw)

    # 3. Make sure every original column is represented
    if not items:
        items = [{"header": h, "classification": "", "reason": ""} for h in column_names]
    else:
        found_headers = {it["header"] for it in items}
        # dict.fromkeys de-duplicates while keeping the original order.
        for name in dict.fromkeys(column_names):
            if name not in found_headers:
                items.append({"header": name, "classification": "", "reason": ""})

    # 4. Build DataFrame and reorder by original column_names
    position = {}
    for idx, name in enumerate(column_names):
        position.setdefault(name, idx)  # first occurrence wins, like list.index

    # Explicit columns keep the frame well-formed even when there are no rows.
    df = pd.DataFrame(items, columns=["header", "classification", "reason"])
    df["sort_index"] = [position.get(h, -1) for h in df["header"]]
    # Stable sort so rows sharing a sort key keep their extraction order.
    df = (
        df.sort_values("sort_index", kind="stable")
        .drop(columns=["sort_index"])
        .reset_index(drop=True)
    )

    return df


## Application

In [None]:
# Load the dataset, ask the model for per-column classifications, then parse the reply.
# NOTE(review): the '.../' path prefix looks like a placeholder — point it at the real file.
data = load_dataset('.../synthetic_user_records.csv')
# Keep the raw model text under its own name so `classifications` is only ever a DataFrame.
raw_classifications = classify_columns(data)
classifications = parse_classification_string(raw_classifications, data.columns.tolist())


In [42]:
# Display the parsed table: one row per column with its classification and reason.
classifications

Unnamed: 0,header,classification,reason
0,ssn,Restricted,Social security numbers are explicitly listed ...
1,passport_number,Restricted,Passport numbers are explicitly listed as Rest...
2,dob_full,Restricted,Complete date of birth (year-month-day) is exp...
3,dob_year_month,Confidential,Date of birth (year & month only) is explicitl...
4,monthly_budget,Confidential,Budgets are explicitly listed as Confidential ...
5,internal_memo,Internal‑use only,Internal memos are explicitly listed as Intern...
6,project_plan,Confidential,Strategic plans and designs are listed as Conf...
7,job_title,Public,Job descriptions are explicitly listed as Publ...
8,press_release,Public,Press releases are explicitly listed as Public...


## Save Results to Excel

In [43]:
# --- Save results to Excel ---
def summary_to_excel(df, output_dir="./c_tool_output"):
    """Write the classification table and a per-level count summary to one workbook.

    The workbook has two sheets: ``Classifications`` (the rows of ``df``) and
    ``Summary`` (number of rows per ``classification`` value).

    Args:
        df: Classification results — a DataFrame, or anything ``pd.DataFrame``
            accepts, that has a ``classification`` column.
        output_dir: Directory the workbook is written to; created if missing.

    Returns:
        Path of the written ``classification_summary_<timestamp>.xlsx`` file.

    Raises:
        KeyError: If the data has no ``classification`` column.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Plain ASCII hyphen: the previous format string contained a non-breaking
    # hyphen (U+2011), which produced non-ASCII file names.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    output_path = os.path.join(output_dir, f"classification_summary_{timestamp}.xlsx")

    df_results = pd.DataFrame(df)
    summary = (
        df_results["classification"]
        .value_counts()
        .rename_axis("classification")
        .reset_index(name="count")
    )

    # Single write: both sheets go into the workbook in one pass (the file was
    # previously written once and then immediately overwritten).
    with pd.ExcelWriter(output_path, engine="openpyxl") as writer:
        df_results.to_excel(writer, sheet_name="Classifications", index=False)
        summary.to_excel(writer, sheet_name="Summary", index=False)

    return output_path

In [44]:
# Write the classification workbook and show the path it was saved to.
output = summary_to_excel(classifications)
output

'./c_tool_output/classification_summary_20251103‑150720.xlsx'