In [None]:
# Mount Google Drive at /content/drive; the dataset paths configured for this
# notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import json
import os
import random
import time
from collections import Counter

from tqdm.notebook import tqdm
import google.generativeai as genai

In [None]:
# ----------------------------
# USER CONFIGURATION
# ----------------------------
# Source dataset and destination of the augmented copy (both on the mounted Drive).
TRAIN_JSON    = "/content/drive/MyDrive/Gemini_Captions/kg_train.json"
OUTPUT_JSON   = "/content/drive/MyDrive/Gemini_Captions/kg_train_augmented.json"
# Do not paste a real key into the notebook: it is read from the GEMINI_API_KEY
# environment variable. The literal is only a placeholder fallback so the cell
# still defines API_KEY when the variable is unset or empty.
API_KEY       = os.environ.get("GEMINI_API_KEY") or "GEMINI_API_KEY"
MODEL_NAME    = "gemini-2.0-flash"
BATCH_SIZE    = 100   # maximum number of captions requested per generation call
SLEEP_SECONDS = 1.5   # seconds to sleep between generation batches

In [None]:
# ----------------------------
# LOAD EXISTING DATA
# ----------------------------
# Read the existing training records from the JSON file on Drive.
with open(TRAIN_JSON, "r") as train_file:
    train_data = json.load(train_file)

# Tally how many records carry each hazard rating (compared case-insensitively).
counts = Counter()
for record in train_data:
    rating_key = record["hazard_rating"].lower()
    counts[rating_key] += 1

print("📊 Current label counts:")
for rating in ("low", "medium", "high"):
    print(f"- {rating}: {counts[rating]}")

📊 Current label counts:
- low: 409
- medium: 8729
- high: 1086


In [None]:
# ----------------------------
# USER INPUT TARGETS
# ----------------------------
# Ask for the desired total number of records per label.
#   - blank input keeps the current count (nothing will be generated for that label)
#   - anything that is not a non-negative integer is rejected and asked again,
#     instead of crashing the cell with a ValueError
print("\n🎯 Enter target max per label:")
target_counts = {}
for label in ["low", "medium", "high"]:
    while True:
        user_input = input(f"Target total for '{label}' (current {counts[label]}): ").strip()
        if not user_input:
            target_counts[label] = counts[label]
            break
        try:
            requested_total = int(user_input)
        except ValueError:
            print(f"⚠️ '{user_input}' is not a whole number, please try again.")
            continue
        if requested_total < 0:
            print("⚠️ Target must be zero or greater, please try again.")
            continue
        target_counts[label] = requested_total
        break


🎯 Enter target max per label:
Target total for 'low' (current 409): 8800
Target total for 'medium' (current 8729): 8800
Target total for 'high' (current 1086): 8800


In [None]:
# ----------------------------
# HAND-CRAFTED EXAMPLES
# ----------------------------
# (description, hazard rating) pairs that are shown to the model in every prompt.
handcrafted = [
    (
        "Straight residential street, daytime, good visibility, zebra crossing ahead, pedestrian near on left.",
        "low",
    ),
    (
        "Narrow road, nighttime, poor visibility, pedestrian directly ahead, stopped vehicle on right.",
        "high",
    ),
]

# ----------------------------
# INIT GEMINI
# ----------------------------
genai.configure(api_key=API_KEY)
model = genai.GenerativeModel(MODEL_NAME)

# Existing filenames look like "file<N>.jpg"; pull out every N.
existing_indices = []
for record in train_data:
    stem = record["filename"].replace("file", "").replace(".jpg", "")
    existing_indices.append(int(stem))

# Despite its name, max_idx holds the first unused index (largest existing N plus one).
max_idx = max(existing_indices) + 1
# Accumulates every generated entry across all labels.
all_new_entries = []


In [None]:
# ----------------------------
# MAIN GENERATION LOOP
# ----------------------------
# For every label below its target, ask the model for new captions in batches
# until the shortfall is filled. Accepted captions are appended to
# all_new_entries and numbered from max_idx onward.

# Stop retrying a label after this many consecutive batches that produced
# nothing usable (API error, blocked response, or unparsable output).
MAX_CONSECUTIVE_FAILURES = 5


def _build_prompt(label, n, examples):
    """Return the full prompt asking for exactly `n` captions rated `label`.

    The batch size is formatted straight into the text. Patching a pre-built
    prompt with str.replace would also rewrite any identical digit sequence
    inside the example captions or the example-count header.
    """
    prompt_lines = [
        f"You are an autonomous driving assistant. Your task is to generate exactly **{n}** new one‑sentence driving‑scene captions, all labeled **{label}**, strictly following these guidelines:\n",
        "1. **Description** must include: road type, lighting/time, visibility, vehicles, obstacles, and hazards.\n"
        "2. **Hazard Rating** should be one of: low, medium, high — based on visibility, proximity, road condition, and obstacle.\n"
        "3. **Output format** (no extra text!):\n"
        "   Description: <caption sentence>\n"
        "   Hazard Rating: <low/medium/high>\n",
        "4. **No** lists, numbers, or extra explanation — just repeat that two-line block.\n",
        "\n---\n",
        "**Hand‑crafted examples:**\n"
    ]
    for desc, hr in handcrafted:
        prompt_lines.append(f"Description: {desc}\nHazard Rating: {hr}\n")

    prompt_lines.append("\n---\n")
    prompt_lines.append(f"**{len(examples)} real '{label}' examples from your dataset:**\n")
    for ex in examples:
        prompt_lines.append(f"Description: {ex['description']}\nHazard Rating: {ex['hazard_rating'].lower()}\n")

    prompt_lines.append("\n---\n")
    prompt_lines.append(f"Now generate exactly **{n}** new examples labeled **{label}**:\n")
    return "\n".join(prompt_lines)


def _parse_captions(text, label, limit):
    """Extract at most `limit` caption sentences rated `label` from model output.

    A caption is accepted only when a non-empty "Description:" line is followed
    by a "Hazard Rating:" line whose value equals `label`. Rating lines with no
    description before them, and blocks carrying a different rating, are dropped.
    """
    captions = []
    pending = None
    for raw_line in text.splitlines():
        line = raw_line.strip()
        lowered = line.lower()
        if lowered.startswith("description:"):
            # Normalise to exactly one trailing period.
            pending = line.split(":", 1)[1].strip().rstrip(".") + "."
        elif lowered.startswith("hazard rating:"):
            # Tolerate trailing punctuation / markdown emphasis around the rating.
            rating = line.split(":", 1)[1].strip().strip("*. ").lower()
            if pending not in (None, ".") and rating == label:
                captions.append(pending)
                if len(captions) >= limit:
                    break
            pending = None
    return captions


for label in ["low", "medium", "high"]:
    current = counts[label]
    target = target_counts[label]
    to_generate = max(0, target - current)

    if to_generate <= 0:
        print(f"\n✅ '{label}' already has {current}, skipping (target was {target}).")
        continue

    # SAMPLE EXAMPLES: up to 20 real records of this label as few-shot context
    examples = [item for item in train_data if item["hazard_rating"].lower() == label]
    examples = random.sample(examples, min(20, len(examples)))

    print(f"\n🧠 Using {len(examples)} real '{label}' examples as context...")

    # GENERATE IN BATCHES until the shortfall is filled or the model keeps failing
    generated = 0
    failures = 0
    with tqdm(total=to_generate, desc=f"Generating '{label}'") as progress:
        while generated < to_generate and failures < MAX_CONSECUTIVE_FAILURES:
            n = min(BATCH_SIZE, to_generate - generated)
            prompt = _build_prompt(label, n, examples)

            try:
                resp = model.generate_content(prompt)
                captions = _parse_captions(resp.text, label, n)
            except Exception as e:
                print(f"❌ Error generating: {e}")
                captions = []

            # Only batches that yield nothing count towards the failure limit.
            failures = 0 if captions else failures + 1

            for caption in captions:
                # The requested label is stored, never a rating echoed by the model,
                # so the per-label totals cannot drift past their targets.
                all_new_entries.append({
                    "description": caption,
                    "hazard_rating": label,
                    "filename": f"file{max_idx}.jpg",
                })
                max_idx += 1

            generated += len(captions)
            progress.update(len(captions))

            # Pause between calls; wait longer after each consecutive failure.
            time.sleep(SLEEP_SECONDS * (1 + failures))

    if generated < to_generate:
        print(f"⚠️ '{label}': generated only {generated} of {to_generate} requested entries.")



🧠 Using 20 real 'low' examples as context...


Generating 'low':   0%|          | 0/84 [00:00<?, ?it/s]


🧠 Using 20 real 'medium' examples as context...


Generating 'medium':   0%|          | 0/1 [00:00<?, ?it/s]


🧠 Using 20 real 'high' examples as context...


Generating 'high':   0%|          | 0/78 [00:00<?, ?it/s]

ERROR:tornado.access:500 POST /v1beta/models/gemini-2.0-flash:generateContent?%24alt=json%3Benum-encoding%3Dint (127.0.0.1) 1265.53ms


❌ Error generating: 500 POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?%24alt=json%3Benum-encoding%3Dint: TypeError: Failed to fetch


In [None]:
# ----------------------------
# SAVE FINAL DATASET
# ----------------------------
# Original records first, generated entries after them.
final_data = [*train_data, *all_new_entries]
serialized = json.dumps(final_data, indent=4)
with open(OUTPUT_JSON, "w") as out_file:
    out_file.write(serialized)

print(f"\n✅ Added {len(all_new_entries)} new entries.")
print(f"📂 Saved augmented dataset to: {OUTPUT_JSON}")

# ----------------------------
# FINAL COUNT CHECK
# ----------------------------
# Recount labels over the combined dataset and show them next to the targets.
final_counts = Counter()
for record in final_data:
    final_counts[record["hazard_rating"].lower()] += 1

print("\n📊 Final label counts after augmentation:")
for rating in ("low", "medium", "high"):
    print(f"- {rating}: {final_counts[rating]} (target was {target_counts[rating]})")


✅ Added 16664 new entries.
📂 Saved augmented dataset to: /content/drive/MyDrive/Gemini_Captions/kg_train_augmented.json

📊 Final label counts after augmentation:
- low: 9118 (target was 8800)
- medium: 8801 (target was 8800)
- high: 8969 (target was 8800)
