# 01_make_dataset_colab

このノートブックは、7クラス（アナグマ・アライグマ・ハクビシン・タヌキ・ネコ・ノウサギ・テン）分類用データセットの作成を目的としています。

実施内容:
1. iNaturalist API から画像を収集
2. 手元の動画からフレームを抽出（任意）
3. `metadata/sources.csv` を作成
4. `train/val/test` のCSVを作成


## 使いやすい公開データソース

- iNaturalist API: https://api.inaturalist.org/v1/observations
- iNaturalist API docs: https://www.inaturalist.org/api
- GBIF API docs: https://techdocs.gbif.org/en/openapi/images
- LILA BC datasets: https://lila.science/datasets

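注意:
- 画像ライセンスは必ず確認してください（このノートブックは CC0 / CC BY / CC BY-SA / CC BY-NC を指定して検索します。CC BY 系は出典表示が必要で、CC BY-NC は商用利用できません）
- 学習データに使う前に、クラスが誤っている画像を目視で確認して除去してください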


In [None]:
!pip -q install requests pandas tqdm scikit-learn opencv-python-headless

In [None]:
import os
import re
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd
import requests
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

# Colab is detected by checking whether its package is already loaded.
IN_COLAB = "google.colab" in sys.modules
if not IN_COLAB:
    # Outside Colab the current working directory acts as the project root.
    PROJECT_ROOT = Path.cwd()
else:
    from google.colab import drive

    # Mount Google Drive; the project lives in a fixed folder under MyDrive.
    drive.mount("/content/drive")
    PROJECT_ROOT = Path("/content/drive/MyDrive/code260212")

# Downloaded images are stored per class under data/raw; CSVs under metadata.
RAW_DIR = PROJECT_ROOT.joinpath("data", "raw")
META_DIR = PROJECT_ROOT.joinpath("metadata")
for p in (RAW_DIR, META_DIR):
    p.mkdir(parents=True, exist_ok=True)

# Class label -> scientific name used as the iNaturalist `taxon_name` query.
CLASS_TAXA: Dict[str, str] = {
    "アナグマ": "Meles anakuma",
    "アライグマ": "Procyon lotor",
    "ハクビシン": "Paguma larvata",
    "タヌキ": "Nyctereutes viverrinus",
    "ネコ": "Felis catus",
    "ノウサギ": "Lepus brachyurus",
    "テン": "Martes melampus",
}

print(f"PROJECT_ROOT: {PROJECT_ROOT}")
print(f"RAW_DIR: {RAW_DIR}")
print(f"classes: {list(CLASS_TAXA)}")

In [None]:
# iNaturalist API v1 endpoint that is queried for observations.
INAT_ENDPOINT = "https://api.inaturalist.org/v1/observations"
# Comma-separated license codes sent as the `photo_license` query parameter.
PHOTO_LICENSE = "cc0,cc-by,cc-by-sa,cc-by-nc"

def _safe_name(text: str) -> str:
    return re.sub(r"[^0-9A-Za-z._-]+", "_", str(text))

def _inat_large_url(photo: dict) -> str:
    url = photo.get("url", "")
    if "/square." in url:
        return url.replace("/square.", "/large.")
    return url.replace("square", "large")

def _download_image(url: str, out_path: Path, timeout: int = 30) -> bool:
    try:
        r = requests.get(url, timeout=timeout)
        r.raise_for_status()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "wb") as f:
            f.write(r.content)
        return True
    except Exception:
        return False

def collect_inat_images(
    class_name: str,
    taxon_name: str,
    max_images: int = 300,
    max_pages: int = 15,
    per_page: int = 200,
    place_id: Optional[str] = None,
) -> List[dict]:
    """Download research-grade iNaturalist photos for one class.

    Pages through the observations endpoint for ``taxon_name`` and stores each
    accepted photo in ``RAW_DIR / class_name`` as ``inat_<obs>_<photo>.jpg``.
    A non-empty file that already exists locally is reused, not re-downloaded.

    Args:
        class_name: Label used for the output sub-directory and the metadata.
        taxon_name: Scientific name sent as the ``taxon_name`` query parameter.
        max_images: Stop once this many images have been recorded.
        max_pages: Maximum number of result pages to request.
        per_page: Results per page; capped at 200.
        place_id: Optional iNaturalist place id; omitted from the query when
            falsy.

    Returns:
        One metadata dict per recorded image (class, taxon, observation and
        photo ids, URLs, license code and ``file_path`` relative to
        ``PROJECT_ROOT``).

    Raises:
        requests.HTTPError: If a listing request returns an error status.
    """
    rows: List[dict] = []
    out_dir = RAW_DIR / class_name
    out_dir.mkdir(parents=True, exist_ok=True)
    if max_images <= 0:
        return rows

    # The `photo_license` query parameter selects observations (see the
    # iNaturalist API docs), so a matching observation can still carry photos
    # under other licenses. Each photo is therefore checked individually.
    allowed_licenses = {
        code.strip().lower() for code in PHOTO_LICENSE.split(",") if code.strip()
    }

    base_params = {
        "taxon_name": taxon_name,
        "quality_grade": "research",
        "photos": "true",
        "photo_license": PHOTO_LICENSE,
        "per_page": min(per_page, 200),
    }
    if place_id:
        base_params["place_id"] = place_id

    # The context manager closes the session's connections on every exit path.
    with requests.Session() as session:
        for page in range(1, max_pages + 1):
            resp = session.get(
                INAT_ENDPOINT, params={**base_params, "page": page}, timeout=30
            )
            resp.raise_for_status()
            results = resp.json().get("results", [])
            if not results:
                break

            for obs in results:
                obs_id = obs.get("id")
                obs_url = obs.get("uri")
                observed_on = obs.get("observed_on")
                for photo in obs.get("photos") or []:
                    license_code = photo.get("license_code")
                    if str(license_code or "").lower() not in allowed_licenses:
                        continue
                    if not photo.get("url"):
                        # Nothing to download for a record without a URL.
                        continue

                    photo_id = photo.get("id")
                    image_url = _inat_large_url(photo)
                    file_name = _safe_name(f"inat_{obs_id}_{photo_id}.jpg")
                    local_path = out_dir / file_name

                    # A zero-byte leftover does not count as already fetched.
                    cached = local_path.is_file() and local_path.stat().st_size > 0
                    if not (cached or _download_image(image_url, local_path)):
                        continue

                    rows.append({
                        "class_name": class_name,
                        "taxon_name": taxon_name,
                        "source_dataset": "iNaturalist",
                        "observation_id": obs_id,
                        "photo_id": photo_id,
                        "observed_on": observed_on,
                        "observation_url": obs_url,
                        "source_url": image_url,
                        "license_code": license_code,
                        "file_path": str(local_path.relative_to(PROJECT_ROOT)),
                    })
                    if len(rows) >= max_images:
                        return rows

            # Short pause between listing requests to go easy on the API.
            time.sleep(0.3)

    return rows

In [None]:
# Collection limits applied to every class.
MAX_IMAGES_PER_CLASS = 300
MAX_PAGES = 15
PLACE_ID = None  # set an iNaturalist place_id to restrict results (e.g. to Japan)

# Collect each class in turn and accumulate the per-image metadata rows.
all_rows = []
class_progress = tqdm(CLASS_TAXA.items(), desc="collect classes")
for class_name, taxon_name in class_progress:
    rows = collect_inat_images(
        class_name,
        taxon_name,
        max_images=MAX_IMAGES_PER_CLASS,
        max_pages=MAX_PAGES,
        place_id=PLACE_ID,
    )
    all_rows += rows
    print(f"{class_name}: {len(rows)} images")

# One DataFrame row per collected image; preview the first rows in the notebook.
inat_df = pd.DataFrame(all_rows)
inat_df.head()

In [None]:
import cv2

def extract_frames_from_video(video_path: Path, out_dir: Path, every_n_frames: int = 10, max_frames: int = 200) -> int:
    """Save every N-th frame of a video as a JPEG file.

    Frames are written to ``out_dir`` as ``<video stem>_f<frame index>.jpg``
    with a zero-padded six-digit frame index.

    Args:
        video_path: Video file to read.
        out_dir: Directory for the extracted frames; created if missing.
        every_n_frames: Keep one frame out of this many; must be >= 1.
        max_frames: Stop after this many frames were saved; values <= 0 save
            nothing.

    Returns:
        Number of frames actually written (0 if the video cannot be opened).

    Raises:
        ValueError: If ``every_n_frames`` is smaller than 1.
    """
    if every_n_frames < 1:
        # Zero would otherwise surface as a ZeroDivisionError in the modulo.
        raise ValueError(f"every_n_frames must be >= 1, got {every_n_frames}")

    out_dir.mkdir(parents=True, exist_ok=True)
    if max_frames <= 0:
        return 0

    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            print(f"failed to open: {video_path}")
            return 0

        frame_idx = 0
        saved = 0
        failed = 0
        while saved < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % every_n_frames == 0:
                out_path = out_dir / f"{video_path.stem}_f{frame_idx:06d}.jpg"
                # imwrite reports failure via its return value; only frames
                # that reached the disk are counted.
                if cv2.imwrite(str(out_path), frame):
                    saved += 1
                else:
                    failed += 1
            frame_idx += 1

        if failed:
            print(f"failed to write {failed} frame(s) from: {video_path}")
        return saved
    finally:
        # Release the capture even when reading or writing raised.
        cap.release()

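# Usage example (uncomment to run). Note: the extracted frames are NOT added
# to metadata/sources.csv by the cells below; register them there if they
# should be part of the train/val/test split.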
# saved = extract_frames_from_video(
#     video_path=PROJECT_ROOT / "data" / "videos" / "sample.mp4",
#     out_dir=RAW_DIR / "タヌキ",
#     every_n_frames=8,
#     max_frames=300,
# )
# print("saved:", saved)

In [None]:
sources_csv = META_DIR / "sources.csv"

# Load the rows of a previous run, if any, so repeated runs accumulate data.
old_df = pd.DataFrame()
if sources_csv.exists() and sources_csv.stat().st_size > 0:
    try:
        old_df = pd.read_csv(sources_csv)
    except pd.errors.EmptyDataError:
        # File holds only whitespace: treat it like a missing file.
        pass

# Only non-empty frames are concatenated; when nothing is available at all,
# an empty frame that still has the key columns keeps the lookups below (and
# the split step) from failing on missing columns.
frames = [frame for frame in (old_df, inat_df) if not frame.empty]
if frames:
    all_df = pd.concat(frames, ignore_index=True, sort=False)
    # Newly collected rows win over older rows for the same file.
    all_df = all_df.drop_duplicates(subset=["file_path"], keep="last")
else:
    all_df = pd.DataFrame(columns=["class_name", "file_path"])
all_df.to_csv(sources_csv, index=False)
print(f"saved: {sources_csv}")
print(all_df["class_name"].value_counts(dropna=False))

def make_split(
    df: pd.DataFrame, test_size: float = 0.15, val_size: float = 0.15, seed: int = 42
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split image metadata into train, validation and test frames.

    Rows without ``class_name`` or ``file_path`` are dropped first. The split
    is stratified by ``class_name``; if scikit-learn rejects that (for example
    when a class has too few images) a plain shuffled split is used instead.

    NOTE: rows are split individually, so several rows that share one
    ``observation_id`` can end up in different splits.

    Args:
        df: Metadata with at least ``class_name`` and ``file_path`` columns.
        test_size: Fraction of all usable rows assigned to the test split.
        val_size: Fraction of all usable rows assigned to the validation split.
        seed: Random seed for both the stratified and the fallback split.

    Returns:
        ``(train, val, test)``; three empty frames when there is nothing to
        split.
    """
    # Check emptiness before dropna: a frame without any columns would make
    # `dropna(subset=...)` raise KeyError.
    if df.empty:
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    x = df.dropna(subset=["class_name", "file_path"]).reset_index(drop=True)
    if x.empty:
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    try:
        train_val, test = train_test_split(
            x, test_size=test_size, random_state=seed, stratify=x["class_name"]
        )
        # val_size is relative to the full data set; rescale it to the part
        # that remains after the test rows were removed.
        val_ratio_in_trainval = val_size / (1.0 - test_size)
        train, val = train_test_split(
            train_val,
            test_size=val_ratio_in_trainval,
            random_state=seed,
            stratify=train_val["class_name"],
        )
    except ValueError as exc:
        # Too few images per class for stratification: fall back to a random
        # split, and say so because class balance is no longer guaranteed.
        print(f"stratified split failed ({exc}); falling back to random split")
        x = x.sample(frac=1.0, random_state=seed).reset_index(drop=True)
        n = len(x)
        n_test = int(n * test_size)
        n_val = int(n * val_size)
        test = x.iloc[:n_test]
        val = x.iloc[n_test:n_test + n_val]
        train = x.iloc[n_test + n_val:]

    return train, val, test

# Split the merged metadata and write one CSV per split into META_DIR.
train_df, val_df, test_df = make_split(all_df)
split_outputs = [
    ("train.csv", train_df),
    ("val.csv", val_df),
    ("test.csv", test_df),
]
for csv_name, split_df in split_outputs:
    split_df.to_csv(META_DIR / csv_name, index=False)

# Report the split sizes, then the location of every written file.
print("train/val/test:", len(train_df), len(val_df), len(test_df))
for csv_name, _ in split_outputs:
    print("saved:", META_DIR / csv_name)