# TP4 — Phase 4 : Filtrage DAS (Divergence-Aware Sampling)

> **⚠️ Ce notebook nécessite un GPU** — À exécuter sur Google Colab (T4) ou Kaggle (2×T4)

**Objectif** : Filtrer les exemples générés en Phase 3 selon leur valeur pédagogique.  
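**Principe** : On conserve les exemples où le Teacher est confiant mais le Student hésite (*Teacher Sentences*), ainsi que ceux où les deux modèles sont d'accord (*Shared Sentences*) ; seuls les exemples où le Student est plus confiant que le Teacher (*Student Sentences*) sont rejetés.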

| Type | Condition | Action |
|------|-----------|--------|
| Teacher Sentence | P_teacher >> P_student | **GARDER** |
| Shared Sentence  | P_teacher ≈ P_student  | garder |
| Student Sentence | P_student > P_teacher  | **REJETER** |

## 1. Installation

In [None]:
# Install the third-party packages this notebook imports further down.
# `!` is the IPython shell escape, so this line only works inside a notebook
# kernel (not as a plain Python script). `-q` keeps pip's output short.
!pip install -q transformers accelerate bitsandbytes openai matplotlib
# Printed once the pip command above has returned.
print("Installation terminée.")

## 2. Imports

In [None]:
import json
import numpy as np
import torch
import matplotlib.pyplot as plt
from pathlib import Path
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Report the installed PyTorch build and whether a CUDA device is visible.
print(f"PyTorch : {torch.__version__}")
gpu_ok = torch.cuda.is_available()
print(f"GPU disponible : {gpu_ok}")
if gpu_ok:
    # Device index 0 is the first visible CUDA device.
    print(f"GPU : {torch.cuda.get_device_name(0)}")

## 3. Upload des données (depuis Phase 3)

Uploadez les fichiers `stage1_raw.json` et `stage2_raw.json` générés localement.

In [None]:
import json
from pathlib import Path

# `google.colab` only exists on Colab. Import it defensively so the cell does
# not crash on other hosts (e.g. Kaggle), where the files must be placed in
# DATA_DIR by other means.
try:
    from google.colab import files
except ImportError:
    files = None

# ── Option A: upload from your machine ──────────────────────────────────
# Click "Choose files" and select BOTH:
#   stage1_raw.json   AND   stage2_raw.json
if files is not None:
    uploaded = files.upload()
else:
    uploaded = {}
    print("google.colab indisponible — placez les fichiers manuellement dans DATA_DIR.")

# Uploaded files land in /content/ on Colab.
DATA_DIR   = Path("/content")
OUTPUT_DIR = Path("/content/data_filtered")

# ── Option B: read from Google Drive ────────────────────────────────────
# Uncomment the following lines if you prefer Drive. They are placed BEFORE
# the verification loop so that the check below looks in the right folder.
# from google.colab import drive
# drive.mount('/content/drive')
# DATA_DIR = Path('/content/drive/MyDrive/tp4/data')

# parents=True so the output folder can be created even if /content is absent.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Sanity check: each expected file must exist and contain valid JSON.
for fname in ["stage1_raw.json", "stage2_raw.json"]:
    fpath = DATA_DIR / fname
    if not fpath.exists():
        print(f"  ✗ {fname} — MANQUANT, re-uploadez ce fichier")
        continue
    try:
        with open(fpath, encoding="utf-8") as f:
            n = len(json.load(f))
    except json.JSONDecodeError as err:
        # A truncated/corrupt upload should be reported, not abort the cell.
        print(f"  ✗ {fname} — JSON invalide ({err}), re-uploadez ce fichier")
    else:
        print(f"  ✓ {fname} — {n} exemples")

print(f"\nSortie filtrée : {OUTPUT_DIR}")

## 4. Classe DASPipelineQwen

In [None]:
import time


class DASPipelineQwen:
    """Divergence-Aware Sampling (DAS) filter with a local student model.

    For every (instruction, response) example the pipeline compares:

    * the teacher's confidence, taken from the per-token logprobs already
      stored in the example (key ``"logprobs"``), and
    * the student's confidence, obtained by scoring the same response with
      the locally loaded causal LM.

    Both confidences are geometric-mean token probabilities, i.e.
    ``exp(mean(logprob))``. Examples are kept unless the student is more
    confident than the teacher by more than the threshold.
    """

    def __init__(self, openai_api_key, student_model_id="unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit"):
        """Load the student model/tokenizer and create the teacher API client.

        Args:
            openai_api_key: API key passed to the OpenAI-compatible client.
            student_model_id: Hugging Face model id of the student model.
        """
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.student_model_id = student_model_id
        print(f"Chargement du modèle étudiant : {self.student_model_id}...")

        self.tokenizer = AutoTokenizer.from_pretrained(self.student_model_id, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Left padding for batched causal-LM inputs. The label mask built in
        # `_build_labels` accounts for the resulting per-row offset.
        self.tokenizer.padding_side = "left"

        self.model = AutoModelForCausalLM.from_pretrained(
            self.student_model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.float16,
        )
        self.model.eval()
        print("Modèle étudiant chargé.")

        # Teacher client. NOTE(review): neither `client` nor
        # `teacher_model_name` is used by the methods visible in this class;
        # teacher logprobs are read from the input file instead.
        self.client = OpenAI(
            api_key=openai_api_key,
            base_url="https://api.infomaniak.com/2/ai/48/openai/v1",
        )
        self.teacher_model_name = "openai/gpt-oss-120b"

    # ----------------------------------------------------------------
    # Label construction (pure tensor logic, independent of the model)
    # ----------------------------------------------------------------
    @staticmethod
    def _build_labels(input_ids, attention_mask, prompt_lengths):
        """Return a copy of ``input_ids`` with prompt and padding set to -100.

        The prompt of row ``i`` occupies ``prompt_lengths[i]`` tokens starting
        at the first non-padding position of that row. With left padding that
        position is not column 0, so the mask must be shifted by the number of
        leading pad tokens; masking ``[:p_len]`` alone would leave the tail of
        the prompt unmasked for every row shorter than the longest one.

        Args:
            input_ids: 2-D integer tensor (rows = examples).
            attention_mask: tensor of the same shape, 1 on real tokens.
            prompt_lengths: number of prompt tokens for each row.
        """
        labels = input_ids.clone()
        seq_len = attention_mask.shape[1]
        for row, p_len in enumerate(prompt_lengths):
            real_positions = attention_mask[row].nonzero()
            # A row with no real token is fully masked.
            start = int(real_positions[0]) if real_positions.numel() > 0 else seq_len
            labels[row, : start + p_len] = -100
        labels[attention_mask == 0] = -100
        return labels

    # ----------------------------------------------------------------
    # Batched forward pass
    # ----------------------------------------------------------------
    def get_student_logprobs_batch(
        self,
        prompts: list,
        responses: list,
        max_length: int = 768,
    ) -> list:
        """Score N (prompt, response) pairs with the student in one forward pass.

        Args:
            prompts: user instructions.
            responses: assistant responses, aligned with ``prompts``.
            max_length: maximum sequence length in tokens; longer sequences
                are truncated, so only the response tokens that fit are scored.

        Returns:
            One dict per example with:
            ``"mean_logprob"`` — despite its name, the geometric-mean token
            probability ``exp(mean(logprob))`` over the scored response tokens
            (0.0 when no response token was scored);
            ``"num_tokens"`` — number of scored response tokens.
        """
        if not prompts:
            return []

        full_texts, prompt_lengths = [], []
        for prompt, response in zip(prompts, responses):
            # Full conversation: prompt + response.
            full = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}],
                tokenize=False,
            )
            # Prompt alone, up to and including the assistant header, so its
            # token count tells how many leading tokens must be masked.
            prompt_only = self.tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False, add_generation_prompt=True,
            )
            n_prompt = self.tokenizer(
                prompt_only, return_tensors="pt", add_special_tokens=False
            ).input_ids.shape[1]

            full_texts.append(full)
            prompt_lengths.append(n_prompt)

        # Batched tokenization with padding + truncation. The chat template
        # already produced the text, and the prompt length above was measured
        # with add_special_tokens=False, so use the same setting here to keep
        # both tokenizations aligned.
        enc = self.tokenizer(
            full_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            add_special_tokens=False,
        )
        input_ids      = enc.input_ids.to(self.model.device)
        attention_mask = enc.attention_mask.to(self.model.device)

        labels = self._build_labels(input_ids, attention_mask, prompt_lengths)

        with torch.no_grad():
            outputs      = self.model(input_ids, attention_mask=attention_mask)
            # Position t predicts token t+1: drop the last logit / first label.
            shift_logits = outputs.logits[..., :-1, :].contiguous().float()
            # With device_map="auto" the logits may live on another device.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct     = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
            # CrossEntropyLoss expects the class dimension second.
            token_losses = loss_fct(shift_logits.transpose(1, 2), shift_labels)

        results = []
        for row_losses, row_labels in zip(token_losses, shift_labels):
            valid_mask = row_labels != -100
            # Per-token loss is -log p, so negate to get logprobs.
            valid_logprobs = (-row_losses[valid_mask]).cpu().numpy()
            mean_lp = float(np.exp(np.mean(valid_logprobs))) if len(valid_logprobs) > 0 else 0.0
            results.append({"mean_logprob": mean_lp, "num_tokens": len(valid_logprobs)})

        return results

    # ----------------------------------------------------------------
    # DAS decision
    # ----------------------------------------------------------------
    def decide_keep_prompt(self, teacher_mean_prob, student_mean_prob, threshold=0.1):
        """Classify one example from the teacher/student probability gap.

        divergence = teacher_mean_prob - student_mean_prob
          * > threshold   -> "TEACHER_SENTENCE", kept
          * < -threshold  -> "STUDENT_SENTENCE", rejected
          * otherwise     -> "SHARED_SENTENCE", kept

        Returns:
            dict with keys ``keep``, ``label``, ``divergence``,
            ``teacher_prob`` and ``student_prob``.
        """
        divergence = teacher_mean_prob - student_mean_prob
        if divergence > threshold:
            label, keep = "TEACHER_SENTENCE", True
        elif divergence < -threshold:
            label, keep = "STUDENT_SENTENCE", False
        else:
            label, keep = "SHARED_SENTENCE", True
        return {"keep": keep, "label": label, "divergence": divergence,
                "teacher_prob": teacher_mean_prob, "student_prob": student_mean_prob}

    # ----------------------------------------------------------------
    # Helpers for filter_dataset
    # ----------------------------------------------------------------
    @staticmethod
    def _teacher_mean_prob(example):
        """Return exp(mean(logprob)) over the example's stored teacher logprobs.

        Returns 0.0 when the example has no (or an empty/None) ``"logprobs"``
        entry; such an example can then never be a TEACHER_SENTENCE.
        """
        lps = [lp["logprob"] for lp in (example.get("logprobs") or [])]
        return float(np.exp(np.mean(lps))) if lps else 0.0

    @staticmethod
    def _save_score_histogram(scores, threshold, stage, plot_out):
        """Save and display the histogram of divergences with the thresholds."""
        fig, ax = plt.subplots(figsize=(9, 5))
        ax.hist(scores, bins=30, color="steelblue", edgecolor="white", alpha=0.85)
        ax.axvline(x= threshold, color="green",  linestyle="--", linewidth=1.5, label=f"Seuil KEEP (+{threshold})")
        ax.axvline(x=-threshold, color="red",    linestyle="--", linewidth=1.5, label=f"Seuil REJECT (-{threshold})")
        ax.axvline(x=0,          color="orange", linestyle=":",  linewidth=1.2, label="Divergence = 0")
        ax.set_xlabel("Divergence (P_teacher − P_student)", fontsize=12)
        ax.set_ylabel("Nombre d'exemples", fontsize=12)
        ax.set_title(f"Distribution DAS — Stage {stage}", fontsize=14)
        ax.legend()
        fig.tight_layout()
        fig.savefig(plot_out, dpi=150)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    # ----------------------------------------------------------------
    # Dataset filtering — batched version
    # ----------------------------------------------------------------
    def filter_dataset(self, input_path, output_dir, stage, threshold=0.1,
                       batch_size=8, max_length=768):
        """Filter one stage file with DAS and write the kept examples.

        Reads a JSON list of examples (each with ``"instruction"``,
        ``"response"`` and optionally ``"logprobs"`` / ``"artist_name"``),
        scores them in batches and writes to ``output_dir``:

        * ``stage{stage}_filtered_raw.json`` — kept examples plus ``das_*`` fields
        * ``stage{stage}_filtered_llamafactory.json`` — kept examples as
          human/gpt conversations
        * ``stage{stage}_das_scores.png`` — divergence histogram (only when at
          least one example was scored)

        Args:
            input_path: path of the input JSON file.
            output_dir: output directory (created if missing).
            stage: stage identifier used in file names and messages.
            threshold: DAS divergence threshold.
            batch_size: examples scored per forward pass; raise it if the GPU
                does not run out of memory (8 -> 16 -> 32).
            max_length: token limit passed to the student scorer. Note that
                the teacher mean uses all stored logprobs while the student
                mean only covers the tokens that fit in ``max_length``.

        Returns:
            The list of kept examples (same content as the raw output file).
            A batch whose scoring raises is reported and skipped; its examples
            are counted as "not evaluated", not as rejected.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        with open(input_path, "r", encoding="utf-8") as f:
            examples = json.load(f)

        n_total = len(examples)
        print(f"\n{'='*60}")
        print(f"FILTRAGE DAS — Stage {stage} | {n_total} ex | batch={batch_size} | max_len={max_length}")
        print(f"{'='*60}")

        filtered_raw, llamafactory_data, all_scores = [], [], []
        t_start = time.time()

        for b_start in range(0, n_total, batch_size):
            batch = examples[b_start : b_start + batch_size]

            prompts   = [ex["instruction"] for ex in batch]
            responses = [ex["response"]    for ex in batch]

            # Teacher confidences come from the logprobs stored in the file.
            teacher_mean_probs = [self._teacher_mean_prob(ex) for ex in batch]

            try:
                # Single forward pass for the whole batch.
                student_results = self.get_student_logprobs_batch(prompts, responses, max_length)
            except Exception as e:
                # Deliberate best effort: report the batch and move on.
                print(f"\n  ERREUR batch [{b_start}:{b_start+len(batch)}] : {e}")
                continue

            for j, (ex, t_prob, s_res) in enumerate(zip(batch, teacher_mean_probs, student_results)):
                decision = self.decide_keep_prompt(t_prob, s_res["mean_logprob"], threshold)
                all_scores.append(decision["divergence"])

                idx    = b_start + j + 1
                # `or` also covers an explicit null artist name in the data.
                artist = str(ex.get("artist_name") or "Unknown")[:25]
                status = "KEEP" if decision["keep"] else "SKIP"
                print(f"[{idx:>4}/{n_total}] {artist:<25} {status} | div={decision['divergence']:+.4f} | {decision['label']}")

                if decision["keep"]:
                    filtered_raw.append({**ex, **{f"das_{k}": v for k, v in decision.items()}})
                    llamafactory_data.append({
                        "conversations": [
                            {"from": "human", "value": ex["instruction"]},
                            {"from": "gpt",   "value": ex["response"]},
                        ]
                    })

            # ETA after each successfully scored batch.
            elapsed   = time.time() - t_start
            done      = b_start + len(batch)
            per_ex    = elapsed / done
            remaining = (n_total - done) * per_ex
            print(f"  ↳ batch terminé — {elapsed:.0f}s écoulées — ETA {remaining:.0f}s ({remaining/60:.1f} min)")

        # Save outputs.
        raw_out  = output_dir / f"stage{stage}_filtered_raw.json"
        lmf_out  = output_dir / f"stage{stage}_filtered_llamafactory.json"
        plot_out = output_dir / f"stage{stage}_das_scores.png"

        with open(raw_out, "w", encoding="utf-8") as f:
            json.dump(filtered_raw, f, ensure_ascii=False, indent=2)
        with open(lmf_out, "w", encoding="utf-8") as f:
            json.dump(llamafactory_data, f, ensure_ascii=False, indent=2)

        if all_scores:
            self._save_score_histogram(all_scores, threshold, stage, plot_out)

        n_kept     = len(filtered_raw)
        n_scored   = len(all_scores)
        n_failed   = n_total - n_scored      # examples in batches that raised
        n_rejected = n_scored - n_kept
        total_time = time.time() - t_start
        # Guard every ratio: the dataset may be empty.
        kept_pct   = 100 * n_kept / n_total if n_total else 0.0
        per_item   = total_time / n_total if n_total else 0.0

        print(f"\n── Résultats Stage {stage} ──────────────────────────────────")
        print(f"  Conservés      : {n_kept}/{n_total} ({kept_pct:.1f}%)")
        print(f"  Rejetés        : {n_rejected}/{n_total}")
        if n_failed:
            print(f"  Non évalués    : {n_failed}/{n_total} (batches en erreur)")
        if all_scores:
            print(f"  Div. moyenne   : {float(np.mean(all_scores)):+.4f}")
        else:
            print("  Div. moyenne   : n/a")
        print(f"  Temps total    : {total_time:.0f}s ({total_time/60:.1f} min)")
        print(f"  Temps/exemple  : {per_item:.1f}s")
        return filtered_raw

print("Classe DASPipelineQwen (batched) définie.")

## 5. Chargement du pipeline (Student + Teacher)

In [None]:
import os
from getpass import getpass

# Never hard-code the API key in the notebook: read it from the environment,
# and fall back to an interactive hidden prompt when it is not set.
# NOTE: any key that was ever committed in a notebook must be considered
# leaked and should be revoked.
API_KEY    = os.environ.get("INFOMANIAK_API_KEY") or getpass("Clé API Infomaniak : ")
STUDENT_ID = "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit"

# Loads the student model and creates the teacher API client.
pipeline = DASPipelineQwen(openai_api_key=API_KEY, student_model_id=STUDENT_ID)

## 6. Filtrage Stage 1 (τ = 0.3 ; seuil de divergence DAS = 0.1)

In [None]:
# Run the DAS filter on the stage 1 file; outputs are written to OUTPUT_DIR.
# NOTE(review): the divergence threshold passed here is 0.1 — confirm this is
# the intended value for this stage.
stage1_filtered = pipeline.filter_dataset(
    DATA_DIR / "stage1_raw.json",
    OUTPUT_DIR,
    1,
    threshold=0.1,
)

## 7. Filtrage Stage 2 (τ = 0.9 ; seuil de divergence DAS = 0.1)

In [None]:
# Run the DAS filter on the stage 2 file; outputs are written to OUTPUT_DIR.
# NOTE(review): the divergence threshold passed here is 0.1 — confirm this is
# the intended value for this stage.
stage2_filtered = pipeline.filter_dataset(
    DATA_DIR / "stage2_raw.json",
    OUTPUT_DIR,
    2,
    threshold=0.1,
)

## 8. Récapitulatif & Téléchargement des fichiers

In [None]:
# Summary of how many examples each stage kept, then the produced files.
banner = "=" * 50
n_stage1 = len(stage1_filtered)
n_stage2 = len(stage2_filtered)

print(banner)
print("RÉCAPITULATIF PHASE 4 — DAS")
print(banner)
print(f"Stage 1 conservés : {n_stage1}")
print(f"Stage 2 conservés : {n_stage2}")
print(f"Total pour entraînement : {n_stage1 + n_stage2}")
print()
print("Fichiers prêts pour LLaMA-Factory :")
# Sorted so the listing order is stable.
for out_file in sorted(OUTPUT_DIR.iterdir()):
    print(f"  {out_file.name}")

In [None]:
# Download the filtered files to your machine: JSON outputs first, then plots.
from google.colab import files

for pattern in ("*.json", "*.png"):
    for out_file in OUTPUT_DIR.glob(pattern):
        files.download(str(out_file))