## Distillation Model

In [3]:
import os
import numpy as np
from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf
from tqdm import tqdm

# === Parameter ===
FEATURE_DIR = "data/en/features_clean"
LABEL_DIR = "data/en/labels_clean"
segment_len = 100
input_dim = 39
BATCH_SIZE = 64
EPOCHS = 20
TEMPERATURE = 5.0
ALPHA = 0.5

# === Teacher Modell laden ===
teacher_model = load_model("teacher_phoneme_model.h5")
teacher_model.trainable = False

# === Feature-Normalisierung laden ===
mean = np.load("feature_mean.npy")
std = np.load("feature_std.npy")

# === Label-Encoder vorbereiten ===
all_labels = []
for file in os.listdir(LABEL_DIR):
    with open(os.path.join(LABEL_DIR, file)) as f:
        all_labels.extend([line.strip() for line in f])
distinct_phonemes = sorted(set(all_labels))
label_encoder = LabelEncoder()
label_encoder.fit(distinct_phonemes)

# === Sequenzaufteilung ===
def split_sequence(X, y, segment_len):
    segments_X, segments_y = [], []
    max_start = (len(X) // segment_len) * segment_len
    for start in range(0, max_start, segment_len):
        segments_X.append(X[start:start + segment_len])
        segments_y.append(y[start:start + segment_len])
    return segments_X, segments_y

# === Daten laden ===
X_all, y_all = [], []
for file in tqdm(os.listdir(FEATURE_DIR), desc="Lade Daten"):
    if not file.endswith(".npy"):
        continue
    base = os.path.splitext(file)[0]
    features = np.load(os.path.join(FEATURE_DIR, file))
    features = (features - mean) / std
    with open(os.path.join(LABEL_DIR, f"{base}.txt")) as f:
        labels = [line.strip() for line in f]
    if len(features) != len(labels):
        continue
    labels_encoded = label_encoder.transform(labels)
    X_seg, y_seg = split_sequence(features, labels_encoded, segment_len)
    X_all.extend(X_seg)
    y_all.extend(y_seg)

X_all = np.array(X_all)
y_all = np.array(y_all)
print(f"✅ Student-Datensatz: {X_all.shape[0]} Sequenzen á {segment_len} Frames")

# === Teacher-Logits vorberechnen (Distillation Targets) ===
TEACHER_LOGITS_PATH = "teacher_logits.npy"
if not os.path.exists(TEACHER_LOGITS_PATH):
    teacher_logits = []
    for x in tqdm(X_all, desc="Berechne Teacher-Logits"):
        pred = teacher_model.predict(np.expand_dims(x, axis=0), verbose=0)
        teacher_logits.append(pred.squeeze() / TEMPERATURE)
    teacher_logits = np.array(teacher_logits)
    np.save(TEACHER_LOGITS_PATH, teacher_logits)
else:
    teacher_logits = np.load(TEACHER_LOGITS_PATH)

# === tf.data.Dataset ===
dataset = tf.data.Dataset.from_tensor_slices((X_all, y_all, teacher_logits))
dataset = dataset.shuffle(2048).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# === Kompaktes Student-Modell ===
inputs = Input(shape=(segment_len, input_dim))
x = layers.Conv1D(64, 3, padding='same', activation='relu')(inputs)
x = layers.Bidirectional(layers.GRU(64, return_sequences=True))(x)
outputs = layers.TimeDistributed(layers.Dense(len(distinct_phonemes), activation='softmax'))(x)
student_model = models.Model(inputs, outputs)

# === Distillation Loss ===
def distillation_loss(y_true, y_pred, teacher_soft):
    y_true = tf.one_hot(tf.cast(y_true, tf.int32), depth=len(distinct_phonemes))
    y_pred_soft = tf.nn.softmax(y_pred / TEMPERATURE)
    loss_true = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    loss_soft = tf.keras.losses.KLD(teacher_soft, y_pred_soft)
    return ALPHA * loss_soft + (1 - ALPHA) * loss_true

# === Training Setup ===
optimizer = Adam(1e-3)
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()

@tf.function
def train_step(x, y, t_soft):
    with tf.GradientTape() as tape:
        pred = student_model(x, training=True)
        loss = tf.reduce_mean(distillation_loss(y, pred, t_soft))
    grads = tape.gradient(loss, student_model.trainable_variables)
    optimizer.apply_gradients(zip(grads, student_model.trainable_variables))
    train_acc.update_state(y, pred)
    return loss

# === Training Loop ===
for epoch in range(EPOCHS):
    print(f"\n🧪 Epoch {epoch + 1}/{EPOCHS}")
    pbar = tqdm(dataset, desc="Training", unit="batch")
    for batch_x, batch_y, batch_teacher_soft in pbar:
        loss = train_step(batch_x, batch_y, batch_teacher_soft)
        pbar.set_postfix({"loss": loss.numpy(), "acc": train_acc.result().numpy()})
    print(f"🔍 Epoch Accuracy: {train_acc.result().numpy():.4f}")
    train_acc.reset_state()

# === Speichern ===
student_model.save("student_model_distilled.h5")
print("✅ Student-Modell gespeichert als student_model_distilled.h5")


Lade Daten: 100%|██████████████████████████████████████████████████████████████| 26482/26482 [00:17<00:00, 1473.30it/s]


✅ Student-Datensatz: 126219 Sequenzen á 100 Frames


Berechne Teacher-Logits: 100%|███████████████████████████████████████████████| 126219/126219 [1:31:30<00:00, 22.99it/s]



🧪 Epoch 1/20


Training: 100%|███████████████████████████████████████████| 1973/1973 [01:34<00:00, 20.91batch/s, loss=0.88, acc=0.604]


🔍 Epoch Accuracy: 0.6042

🧪 Epoch 2/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:32<00:00, 21.25batch/s, loss=0.908, acc=0.656]


🔍 Epoch Accuracy: 0.6557

🧪 Epoch 3/20


Training: 100%|███████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.18batch/s, loss=0.707, acc=0.67]


🔍 Epoch Accuracy: 0.6696

🧪 Epoch 4/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.18batch/s, loss=0.633, acc=0.678]


🔍 Epoch Accuracy: 0.6778

🧪 Epoch 5/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.16batch/s, loss=0.706, acc=0.684]


🔍 Epoch Accuracy: 0.6836

🧪 Epoch 6/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.06batch/s, loss=0.601, acc=0.688]


🔍 Epoch Accuracy: 0.6877

🧪 Epoch 7/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.16batch/s, loss=0.626, acc=0.691]


🔍 Epoch Accuracy: 0.6908

🧪 Epoch 8/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.15batch/s, loss=0.581, acc=0.694]


🔍 Epoch Accuracy: 0.6938

🧪 Epoch 9/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:32<00:00, 21.23batch/s, loss=0.735, acc=0.696]


🔍 Epoch Accuracy: 0.6958

🧪 Epoch 10/20


Training: 100%|███████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.17batch/s, loss=0.77, acc=0.698]


🔍 Epoch Accuracy: 0.6978

🧪 Epoch 11/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.13batch/s, loss=0.488, acc=0.699]


🔍 Epoch Accuracy: 0.6994

🧪 Epoch 12/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.10batch/s, loss=0.502, acc=0.701]


🔍 Epoch Accuracy: 0.7010

🧪 Epoch 13/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.16batch/s, loss=0.576, acc=0.702]


🔍 Epoch Accuracy: 0.7023

🧪 Epoch 14/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.13batch/s, loss=0.454, acc=0.703]


🔍 Epoch Accuracy: 0.7035

🧪 Epoch 15/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.13batch/s, loss=0.507, acc=0.705]


🔍 Epoch Accuracy: 0.7048

🧪 Epoch 16/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.09batch/s, loss=0.547, acc=0.706]


🔍 Epoch Accuracy: 0.7057

🧪 Epoch 17/20


Training: 100%|███████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.18batch/s, loss=0.61, acc=0.706]


🔍 Epoch Accuracy: 0.7063

🧪 Epoch 18/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.16batch/s, loss=0.447, acc=0.707]


🔍 Epoch Accuracy: 0.7073

🧪 Epoch 19/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.15batch/s, loss=0.556, acc=0.708]


🔍 Epoch Accuracy: 0.7081

🧪 Epoch 20/20


Training: 100%|██████████████████████████████████████████| 1973/1973 [01:33<00:00, 21.17batch/s, loss=0.754, acc=0.709]


🔍 Epoch Accuracy: 0.7089
✅ Student-Modell gespeichert als student_model_distilled.h5
