In [1]:

from pathlib import Path
import torch, random, datasets, numpy as np, os
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoTokenizer,
    WhisperConfig,
    WhisperFeatureExtractor,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
if hasattr(datasets, "set_seed"):      # ↓ старые версии datasets этого не имеют
    datasets.set_seed(seed)

teacher_id = "openai/whisper-large-v3"
work_dir   = Path("distil_largev3_student"); work_dir.mkdir(exist_ok=True)
device     = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
# %% [code]  (замена старой ячейки 2)
from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig
import torch

teacher = AutoModelForSpeechSeq2Seq.from_pretrained(
    teacher_id, torch_dtype=torch.float32
).to(device).eval()

cfg = WhisperConfig.from_pretrained(teacher_id).to_dict()
cfg["decoder_layers"] = 8                       # ← 8 вместо 2
student_cfg = WhisperConfig(**cfg)
student = AutoModelForSpeechSeq2Seq.from_config(student_cfg)

# conv‑фронт
for name in ("conv1", "conv2"):
    getattr(student.model.encoder, name).load_state_dict(
        getattr(teacher.model.encoder, name).state_dict()
    )

# копируем 4 слоя декодера (0,1,6,7)
for tgt, src in zip([0,1,6,7], [0,1,-2,-1]):
    student.model.decoder.layers[tgt].load_state_dict(
        teacher.model.decoder.layers[src].state_dict()
    )

student = student.float(); teacher = teacher.float()

# замораживаем энкодер — разморозим позже колбэком
for p in student.model.encoder.parameters():
    p.requires_grad = False

student.to(device)
print("Student 8‑layer decoder — params:", sum(p.numel() for p in student.parameters())/1e6,"M")


Student 8‑layer decoder — params: 913.82272 M


In [3]:
# %% [code] — Ячейка 3 (целиком)

import os, random, pathlib
from datasets import Dataset, Audio

def parse_transcripts(txt_path):
    mapping = {}
    with open(txt_path, "r", encoding="utf-8") as f:
        for line in f:
            utt_id, text = line.strip().split(" ", 1)
            mapping[utt_id] = text
    return mapping

def collect_examples(subset_dir, max_items=None):
    subset = pathlib.Path(subset_dir)
    examples = []
    for flac_path in subset.rglob("*.flac"):
        speaker = flac_path.parent.parent.name
        chapter = flac_path.parent.name
        trans_path = flac_path.parent / f"{speaker}-{chapter}.trans.txt"
        if not trans_path.exists():
            continue
        transcripts = parse_transcripts(trans_path)
        utt_id = flac_path.stem
        if utt_id not in transcripts:
            continue
        examples.append({"audio": str(flac_path), "text": transcripts[utt_id]})
        if max_items and len(examples) >= max_items:
            break
    return examples

# ── каталоги с уже скачанными данными ─────────────────────────────
DATA_ROOT = os.path.expanduser("~/librispeech_local")
TRAIN_DIR = os.path.join(DATA_ROOT, "train-clean-100")
VAL_DIR   = os.path.join(DATA_ROOT, "dev-clean")

# ►► ЗДЕСЬ изменение: берём ВСЁ train‑clean‑100 ◄◄
train_examples = collect_examples(TRAIN_DIR, max_items=None)   # ≈ 28 000
val_examples   = collect_examples(VAL_DIR,   max_items=100)

print(f"собрано: train={len(train_examples)},  val={len(val_examples)}")

random.shuffle(train_examples)

train_ds = Dataset.from_list(train_examples).cast_column("audio", Audio(sampling_rate=16000))
val_ds   = Dataset.from_list(val_examples).cast_column("audio",   Audio(sampling_rate=16000))

print("пример:", train_ds[0]["text"][:60], "...")


собрано: train=28539,  val=100
пример: THE EYES OF DUKES OF THE BLOOD ROYAL HAVE BEEN PLUCKED OUT F ...


In [4]:
# %% [code]  (полная замена prepare_batch)
from transformers import AutoTokenizer, WhisperFeatureExtractor
import torch, torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

tokenizer         = AutoTokenizer.from_pretrained(teacher_id)
feature_extractor = WhisperFeatureExtractor.from_pretrained(teacher_id)
teacher.eval()

def prepare_batch(batch):
    # ----- mel‑фичи -------------------------------------------------
    audio_list = batch["audio"]            # list[dict{array, sampling_rate}]
    feats = [
        feature_extractor(a["array"], sampling_rate=16000,
                          return_tensors="pt").input_features[0]
        for a in audio_list
    ]
    max_T = max(f.shape[1] for f in feats)
    feats_padded = torch.stack(
        [F.pad(f, (0, max_T - f.shape[1])) for f in feats]
    )                                       # (B, 128, T_max)

    # ----- ground‑truth labels -------------------------------------
    text_list = batch["text"]               # list[str]
    labels = tokenizer(
        text_list,
        return_tensors="pt",
        padding=True,
        truncation=True
    ).input_ids                              # (B, L)

    return {"input_features": feats_padded, "labels": labels}


train_ds = train_ds.map(
    prepare_batch, batched=True, batch_size=32,
    remove_columns=train_ds.column_names, num_proc=1,
    desc="prep train"
)
val_ds = val_ds.map(
    prepare_batch, batched=True, batch_size=32,
    remove_columns=val_ds.column_names, num_proc=1,
    desc="prep val"
)
train_ds.set_format(type="torch"); val_ds.set_format(type="torch")


prep train:   0%|          | 0/28539 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


prep val:   0%|          | 0/100 [00:00<?, ? examples/s]

In [7]:
# %% [code]
class WhisperDataCollator:
    def __init__(self, feat_extractor, tokenizer, fp16=True):
        self.fe = feat_extractor
        self.tok = tokenizer
        self.pad = tokenizer.pad_token_id
        self.fp16 = fp16
    def __call__(self, batch):
        feats = [b["input_features"] for b in batch]
        audio = self.fe.pad({"input_features": feats}, padding="longest", return_tensors="pt")["input_features"]
        if self.fp16:
            audio = audio.to(torch.float16)
        labels = self.tok.pad({"input_ids": [b["labels"] for b in batch]},
                              padding="longest", return_tensors="pt")["input_ids"]
        labels[labels == self.pad] = -100
        return {"input_features": audio, "labels": labels}

data_collator = WhisperDataCollator(
    feature_extractor,
    tokenizer,
    fp16=False        # ← не трогаем dtype
)


In [15]:
def shift_tokens_right(labels, bos_id, pad_id):
    """
    labels:  (B, L)  с ‑100 на masked позициях
    returns: (B, L)  decoder_input_ids
    """
    # replace -100 -> pad, затем сдвинуть и вставить BOS
    shifted = labels.clone()
    shifted[shifted == -100] = pad_id
    shifted = torch.cat(
        [torch.full((shifted.size(0), 1),
                    bos_id, dtype=shifted.dtype, device=shifted.device),
         shifted[:, :-1]],
        dim=1,
    )
    return shifted


In [16]:
# %% [code]  (замена DistillSeq2SeqTrainer)
from transformers import Seq2SeqTrainer, TrainerCallback
import torch.nn.functional as F

class DistillSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(self, *args, teacher_model, alpha_kd=0.3, **kw):
        super().__init__(*args, **kw)
        self.teacher   = teacher_model.eval()
        self.alpha_kd  = alpha_kd

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        labels = inputs.pop("labels")
        feats  = inputs["input_features"]

        # ---- decoder_input_ids ---------------------------------------
        dec_in = shift_tokens_right(
            labels,
            bos_id = self.tokenizer.bos_token_id,   # 50257
            pad_id = self.tokenizer.pad_token_id    # 50257
        )

        # ---- forward --------------------------------------------------
        st_out = model(input_features=feats, decoder_input_ids=dec_in)

        # ---- KD logits (teacher) -------------------------------------
        with torch.no_grad():
            tch_out = self.teacher(input_features=feats, decoder_input_ids=dec_in)

        kd = F.kl_div(
            F.log_softmax(st_out.logits / 2.0, -1),
            F.softmax(  tch_out.logits / 2.0, -1),
            reduction="batchmean"
        ) * 4.0

        ce = F.cross_entropy(
            st_out.logits.view(-1, st_out.logits.size(-1)),
            labels.view(-1), ignore_index=-100
        )

        loss = (1 - self.alpha_kd)*ce + self.alpha_kd*kd
        return (loss, st_out) if return_outputs else loss


class UnfreezeEncoder(TrainerCallback):
    "Разморозить энкодер после первой эпохи"
    def on_epoch_end(self, args, state, control, **kw):
        if state.epoch == 1:
            for p in kw["model"].model.encoder.parameters():
                p.requires_grad = True
            print("\nEncoder unfrozen → fine‑tune everything")



In [9]:
# %% [code]
from transformers import Seq2SeqTrainingArguments   # можно и TrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=work_dir,
    per_device_train_batch_size=2,      # VRAM ~18GB
    gradient_accumulation_steps=16,      # effective 32
    num_train_epochs=6,
    learning_rate=1e-4,
    warmup_steps=1000,
    lr_scheduler_type="cosine",
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=200,
    fp16=False,
    gradient_checkpointing=True,
    report_to="none",
)




In [21]:
import torch, gc; gc.collect(); torch.cuda.empty_cache()


In [24]:
from transformers.utils import logging as hf_logging
hf_logging.set_verbosity_error()      # только ошибки


In [25]:
trainer = DistillSeq2SeqTrainer(
    model           = student,
    args            = training_args,
    train_dataset   = train_ds,
    eval_dataset    = val_ds,
    processing_class= tokenizer,      # <-- новый аргумент
    data_collator   = data_collator,
    teacher_model   = teacher,
    callbacks       = [UnfreezeEncoder()],
)
trainer.train()
trainer.save_model(work_dir / "checkpoint-final")


Epoch,Training Loss,Validation Loss
0,43.7924,36.114197
1,37.583,32.073261
2,34.1062,29.661638
3,32.2484,28.316227
4,30.564,28.063118
5,29.802,28.044121




In [26]:
# %% [code]
import torch, time, psutil, os, gc
from datasets import Dataset, Audio

# 30 секунд тестового аудио (берём любые 20 примеров val_ds)
bench_ds = val_ds.select(range(20))
BATCH = feature_extractor.pad(
    {"input_features": [ex["input_features"] for ex in bench_ds]},
    padding="longest", return_tensors="pt"
)["input_features"].to(student.device)

print("bench shape:", BATCH.shape)    # (B, 128, T_max)


bench shape: torch.Size([20, 128, 3000])


In [27]:
# %% [code]
def model_size_mb(model):
    return sum(p.numel()*4 for p in model.parameters()) / 1e6  # fp32=4 байта

print(f"{'Model':35}  Size (MB)")
print("-"*48)
print(f"{'Teacher large‑v3':35}  {model_size_mb(teacher):7.0f}")
print(f"{'6‑layer student':35}  {model_size_mb(student):7.0f}")


Model                                Size (MB)
------------------------------------------------
Teacher large‑v3                        6174
6‑layer student                         3655


In [28]:
# %% [code]
def bench(model, batch, n_rep=5):
    torch.cuda.empty_cache(); gc.collect()
    # прогрев
    with torch.no_grad():
        _ = model.generate(batch, max_new_tokens=128, num_beams=1)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        for _ in range(n_rep):
            _ = model.generate(batch, max_new_tokens=128, num_beams=1)
    torch.cuda.synchronize()
    return (time.perf_counter()-t0)/n_rep*1000   # ms

lat_gpu_teacher  = bench(teacher,  BATCH, n_rep=3)
lat_gpu_student  = bench(student,  BATCH, n_rep=3)

print(f"GPU latency per batch (ms):  teacher {lat_gpu_teacher:6.1f}   student {lat_gpu_student:6.1f}")
print("Current VRAM (MB):", torch.cuda.memory_allocated()/1e6)


GPU latency per batch (ms):  teacher 3894.9   student 1547.0
Current VRAM (MB): 12546.884096


In [29]:
# %% [code]  — restore models, tokenizer, feature_extractor

from pathlib import Path
import torch, os
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoTokenizer,
    WhisperFeatureExtractor,
)

device      = "cuda" if torch.cuda.is_available() else "cpu"
teacher_id  = "openai/whisper-large-v3"          # при желании укажите local path

# директория, куда мы ранее сохраняли финальный чек‑пойнт студента
work_dir    = Path("distil_largev3_student")
student_ckpt= work_dir / "checkpoint-final"

# ---------- загрузка ----------
print("Loading teacher...")
teacher  = AutoModelForSpeechSeq2Seq.from_pretrained(
    teacher_id, torch_dtype=torch.float32
).to(device).eval()

print("Loading student...")
student  = AutoModelForSpeechSeq2Seq.from_pretrained(
    student_ckpt, torch_dtype=torch.float32
).to(device).eval()

tokenizer         = AutoTokenizer.from_pretrained(teacher_id)
feature_extractor = WhisperFeatureExtractor.from_pretrained(teacher_id)

print("Models ready  →  teacher layers:",
      teacher.config.decoder_layers, "  student layers:",
      student.config.decoder_layers)


Loading teacher...
Loading student...
Models ready  →  teacher layers: 32   student layers: 8


In [30]:
# %% [code]  — WER по real‑GT (dev‑clean) с корректным generate

from datasets import Dataset, Audio
from jiwer import wer
from tqdm.notebook import tqdm
import torch, os, pathlib

# ---------- 1. собираем 100 примеров dev‑clean локально -------------
DATA_ROOT = os.path.expanduser("~/librispeech_local")
DEV_DIR   = os.path.join(DATA_ROOT, "dev-clean")

def parse_transcripts(p):
    with open(p, "r", encoding="utf-8") as f:
        return dict(line.strip().split(" ", 1) for line in f)

ex = []
for flac in pathlib.Path(DEV_DIR).rglob("*.flac"):
    spk = flac.parent.parent.name
    chap = flac.parent.name
    txt  = parse_transcripts(flac.parent / f"{spk}-{chap}.trans.txt")
    if flac.stem in txt:
        ex.append({"audio": str(flac), "text": txt[flac.stem]})
        if len(ex) == 100:
            break

gt_val = Dataset.from_list(ex).cast_column("audio", Audio(sampling_rate=16000))
print("Loaded GT examples:", len(gt_val))

# ---------- 2. функция WER -----------------------------------------
def wer_on_gt(ds, model):
    refs, hyps = [], []
    for ex in tqdm(ds, total=len(ds), desc=f"{model.config.decoder_layers}-layer"):
        feats = feature_extractor(
            ex["audio"]["array"], sampling_rate=16000, return_tensors="pt"
        ).input_features.to(model.device)

        attn = torch.ones(feats.shape[:-1], dtype=torch.long, device=feats.device)

        with torch.no_grad():
            ids = model.generate(
                feats,
                attention_mask=attn,
                max_new_tokens=128,
                num_beams=1,              # greedy
            )[0]

        hyps.append(tokenizer.decode(ids, skip_special_tokens=True).lower())
        refs.append(ex["text"].lower())
    return wer(refs, hyps)

# ---------- 3. вывод -----------------------------------------------
print(f"WER teacher : {wer_on_gt(gt_val, teacher)*100:.2f}%")
print(f"WER student : {wer_on_gt(gt_val, student)*100:.2f}%")


Loaded GT examples: 100


32-layer:   0%|          | 0/100 [00:00<?, ?it/s]

WER teacher : 8.74%


8-layer:   0%|          | 0/100 [00:00<?, ?it/s]

WER student : 96.57%
