## STAR Algorithm Adaptation

In [None]:
import os
import gc
import torch
import torchaudio
import numpy as np
import random
import copy
from torch import nn
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from jiwer import wer
from whisper.normalizers import EnglishTextNormalizer
import pandas as pd
import librosa
import logging
from transformers import (
    WhisperForConditionalGeneration,
    WhisperProcessor,
    WhisperTokenizer,
    WhisperFeatureExtractor
)


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): a later configuration cell rebinds DEVICE to a plain string
# ("cuda"/"cpu"); both forms are accepted by `.to(...)`.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Shared instance of Whisper's English text normalizer.
# NOTE(review): not referenced by the visible cells — evaluation only lower-cases text.
NORMALIZER = EnglishTextNormalizer()

In [None]:
# ================= CONFIGURATION =================
# --- Column names used in both metadata CSVs ---
COL_FILENAME = "path"
COL_TEXT = "sentence"

# --- Raw inputs: audio folders and their metadata CSVs ---
TRAIN_AUDIO = "data/train_data"
TEST_AUDIO = "data/test_data"
TRAIN_METADATA = "data/final_train.csv"
TEST_METADATA = "data/final_test.csv"

# --- Serialized training data ---
OUTPUT_TRAIN = "data/star_process/train_all.pt"
OUTPUT_FULL = "data/star_process/train_final_ready.pt"

# --- Serialized per-accent test sets (one file per entry of `accents`, same order) ---
TEST_PATH= "data/test"
TEST_IN = "data/test/test_in.pt"
TEST_AU = "data/test/test_au.pt"
TEST_SA = "data/test/test_sa.pt"
TEST_SG = "data/test/test_sg.pt"

# Values of the `accents` column in the test CSV that are evaluated separately.
accents = [
    "India and South Asia (India, Pakistan, Sri Lanka)",
    "Australian English",
    "Southern African (South Africa, Zimbabwe, Namibia)",
    "Singaporean English",
]
# =================================================

## Get torch dataset

In [None]:

def _load_audio_entries(metadata, audio_dir, sample_rate=16000):
    """Load every audio file listed in ``metadata`` from ``audio_dir``.

    Args:
        metadata: DataFrame providing the ``COL_FILENAME`` and ``COL_TEXT`` columns.
        audio_dir: Directory the file names in ``COL_FILENAME`` are relative to.
        sample_rate: Rate passed to ``librosa.load`` and recorded in every entry.

    Returns:
        A list of dicts shaped like
        ``{"audio": {"array": ..., "sampling_rate": ...}, "text": ..., "file_path": ...}``.
        Rows whose file cannot be found are skipped silently; files that fail
        to decode are reported on stdout and skipped.
    """
    entries = []
    for _, row in tqdm(metadata.iterrows(), total=len(metadata)):
        filename = str(row[COL_FILENAME])
        file_path = os.path.join(audio_dir, filename)

        # Handle missing .mp3 extension if CSV doesn't have it
        if not os.path.exists(file_path):
            if not os.path.exists(file_path + ".mp3"):
                continue
            file_path += ".mp3"

        try:
            audio_array, _ = librosa.load(file_path, sr=sample_rate)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole export.
            print(f"Failed to load {filename}: {e}")
            continue

        entries.append({
            "audio": {
                "array": audio_array,
                "sampling_rate": sample_rate
            },
            "text": row[COL_TEXT],
            "file_path": file_path
        })
    return entries


def main():
    """Convert the train/test CSV + audio folders into ``.pt`` datasets.

    Writes the full training list to ``OUTPUT_TRAIN`` and one file per accent
    (``TEST_IN``, ``TEST_AU``, ``TEST_SA``, ``TEST_SG``, matched to ``accents``
    by position). Returns ``None``; prints an error and returns early when
    either metadata CSV is missing.
    """
    test_outputs = [TEST_IN, TEST_AU, TEST_SA, TEST_SG]

    # Create every output directory up front so torch.save cannot fail on a
    # missing folder (the test sets live in a different folder than the train set).
    for out_path in [OUTPUT_TRAIN] + test_outputs:
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    if not os.path.exists(TRAIN_METADATA):
        print(f"Error: Could not find CSV at {TRAIN_METADATA}")
        return
    if not os.path.exists(TEST_METADATA):
        print(f"Error: Could not find CSV at {TEST_METADATA}")
        return

    df = pd.read_csv(TRAIN_METADATA)
    df_test = pd.read_csv(TEST_METADATA)
    print(f"Found {len(df)} rows in train CSV.")
    print(f"Found {len(df_test)} rows in test CSV.")

    print(f"Processing audio from: {TRAIN_AUDIO}...")
    dataset_list = _load_audio_entries(df, TRAIN_AUDIO)

    # Persist the training set: later cells read it back from disk.
    print(f"Saving {len(dataset_list)} training samples to {OUTPUT_TRAIN}...")
    torch.save(dataset_list, OUTPUT_TRAIN)

    test_list = []

    # process each test accent separately
    for accent, output_file in zip(accents, test_outputs):
        print(f"Processing test audio for accent: {accent}...")
        accent_rows = df_test[df_test['accents'] == accent]  # filter once, reuse for the progress total
        accent_test_list = _load_audio_entries(accent_rows, TEST_AUDIO)

        # save accent-specific test set
        print(f"Saving test dataset for accent {accent} to {output_file}...")
        torch.save(accent_test_list, output_file)
        test_list.extend(accent_test_list)
    print(f"Total test samples processed: {len(test_list)}")

# Run the dataset export when this cell/script is executed directly
# (skipped when the code is imported as a module).
if __name__ == "__main__":
    main()

In [None]:
# preview train_all.pt
# Read from the configured OUTPUT_TRAIN location instead of a separately
# hard-coded path, so the preview always inspects the configured training file.
train_all = torch.load(OUTPUT_TRAIN, weights_only=False)
train_all[0]

{'audio': {'array': array([-1.0339758e-23, -2.6883370e-23, -8.2718061e-24, ...,
         -2.5170762e-04, -2.0353909e-04, -6.1307161e-05],
        shape=(65280,), dtype=float32),
  'sampling_rate': 16000},
 'text': 'She told me she feels hopeful and optimistic.',
 'file_path': 'data/train_data/common_voice_en_134463.mp3'}

## Get Pseudo-Labels and Train Model

**Training Configuration:**
- Optimizer: Adam
- Learning Rate: Initial rate of $1 \times 10^{-5}$
- Duration: 2 epochs (note: the run recorded below uses the `train_with_star` default of `num_epochs=10`)
- Batch Strategy: Batch size of 1 with 16 gradient accumulation steps

**Hyperparameters & Filtering:**
- Threshold ($\lambda$): 2
- Temperature ($\tau$): 10
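- Utterance-level Filtering ($\alpha$): 20%, i.e. the 20% of utterances with the highest uncertainty metric are discarded and the remaining 80% are kept (`TOP_PERCENT = 0.8`); chosen for consistent effectiveness across datasets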

In [None]:
# ================= CONFIGURATION =================
# Each setting is defined exactly once (the duplicate DEVICE / MODEL_NAME
# assignments were removed).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "openai/whisper-tiny"
logging.getLogger("transformers").setLevel(logging.ERROR)  # hide transformers warnings

# Paths
DATA_PATH = "data/"
STAR_PATH = "data/star_process"
OUTPUT_DIR = STAR_PATH  # chunk files and train_final_ready.pt are written here
# Raw training set; kept identical to OUTPUT_TRAIN from the dataset-export
# configuration so the pseudo-labeling step reads the configured export path.
INPUT_PT_FILE = "data/star_process/train_all.pt"

# STAR Hyperparameters
THRESHOLD = 2.0    # default `threshold` of calculate_star_weights
TAU = 10           # default `tau` of calculate_star_weights
TOP_PERCENT = 0.8  # fraction of samples kept after sorting by filter_metric (lowest first)
NUM_CHUNKS = 4     # number of chunks the raw data is split into while pseudo-labeling
# =================================================

def sigmoid(x):
    """Numerically stable logistic function ``1 / (1 + exp(-x))``.

    Evaluated as ``exp(-log(1 + exp(-x)))`` via ``np.logaddexp`` so that large
    negative inputs return 0.0 instead of overflowing ``np.exp`` (which emits
    a RuntimeWarning). Accepts scalars or NumPy arrays.
    """
    return np.exp(-np.logaddexp(0, -x))

def calculate_star_weights(probs, weights, threshold=THRESHOLD, tau=TAU):
    """Fuse per-token confidence (``probs``) with attentive score (``weights``).

    The two sequences are paired element-wise (``zip``, so the shorter one
    bounds the result). Both values of a pair are floored at 1e-6. The score
    is the sum of two sigmoid-gated terms:

    * conflict term: ``[σ((c²/a − threshold)·tau) + σ((a²/c − threshold)·tau)] · a``
    * agreement term: ``σ((threshold − c²/a)·tau) · σ((threshold − a²/c)·tau) · a · exp((c − a)/tau)``

    Returns:
        list with one fused score per (confidence, attentive score) pair.
    """
    def _fuse(conf, attn):
        # Floor both inputs so neither ratio divides by zero.
        conf = max(conf, 1e-6)
        attn = max(attn, 1e-6)
        conf_ratio = (conf**2) / attn
        attn_ratio = (attn**2) / conf

        conflict_gate = sigmoid((conf_ratio - threshold) * tau) + sigmoid((attn_ratio - threshold) * tau)
        agree_gate = sigmoid((threshold - conf_ratio) * tau) * sigmoid((threshold - attn_ratio) * tau)

        conflict_term = conflict_gate * attn
        agree_term = agree_gate * attn * np.exp((conf - attn) / tau)
        return conflict_term + agree_term

    return [_fuse(conf, attn) for conf, attn in zip(probs, weights)]

def _override_generation_flags(model):
    """Turn on attention/score/dict outputs on the model's configs.

    Returns a list of ``(holder, attribute, previous_value)`` so the caller
    can undo the change with ``_restore_generation_flags``.
    """
    holders = [model.config]
    if getattr(model, "generation_config", None):
        holders.append(model.generation_config)

    previous = []
    for holder in holders:
        for name in ("output_attentions", "output_scores", "return_dict_in_generate"):
            previous.append((holder, name, getattr(holder, name, None)))
            setattr(holder, name, True)
    return previous


def _restore_generation_flags(previous):
    """Undo ``_override_generation_flags``."""
    for holder, name, value in previous:
        setattr(holder, name, value)


def _item_to_mel(item, feature_extractor):
    """Convert one raw dataset item into input features on ``DEVICE``.

    ``item['audio']`` is either a dict with ``'array'`` (and optionally
    ``'sampling_rate'``, default 16000) or the raw samples themselves.
    Audio at another rate is resampled to 16 kHz first.
    """
    if isinstance(item['audio'], dict):
        audio_array = item['audio']['array']
        sr = item['audio'].get('sampling_rate', 16000)
    else:
        audio_array = item['audio']
        sr = 16000

    if isinstance(audio_array, np.ndarray):
        audio_tensor = torch.from_numpy(audio_array).float()
    else:
        audio_tensor = torch.tensor(audio_array).float()

    if sr != 16000:
        resampler = torchaudio.transforms.Resample(sr, 16000)
        audio_tensor = resampler(audio_tensor.cpu().unsqueeze(0)).squeeze(0)

    inputs = feature_extractor(audio_tensor.cpu().numpy(), sampling_rate=16000, return_tensors="pt")
    return inputs.input_features.to(DEVICE)


def _attentive_scores(decoder_attentions, generated_len):
    """Rebuild the last-layer decoder self-attention matrix and score each token.

    ``decoder_attentions[t][-1]`` is the last layer's attention at generation
    step ``t``. Row ``t`` of the rebuilt matrix is the head-averaged attention
    of that step's newest query. The score of generated token ``k`` is
    (row sum over generated columns) + (column sum) - (diagonal entry).

    Returns:
        list[float] of length ``generated_len``.
    """
    # Key length of the first step; all but its last position are treated as prompt.
    n_prompt_toks = decoder_attentions[0][-1].shape[-1] - 1
    full_matrix_len = n_prompt_toks + generated_len
    attn_matrix = torch.zeros((generated_len, full_matrix_len))

    for t in range(generated_len):
        step_attn = decoder_attentions[t][-1]
        # Average heads. Take the LAST query position: the first step may carry
        # several prompt queries at once, and the token produced at that step
        # comes from the final one (index 0 would be the first prompt token).
        # For later single-query steps -1 and 0 are the same row.
        step_attn_avg = step_attn[0, :, -1, :].mean(dim=0).cpu()
        cur_len = step_attn_avg.shape[0]
        attn_matrix[t, :cur_len] = step_attn_avg

    weights = []
    for k in range(generated_len):
        global_idx = n_prompt_toks + k
        row_sum = attn_matrix[k, n_prompt_toks:].sum().item()
        col_sum = attn_matrix[:, global_idx].sum().item()
        diag = attn_matrix[k, global_idx].item()
        weights.append(float(row_sum + col_sum - diag))
    return weights


def _perturbation_uncertainty(model, mel, tokenizer, pseudo_text, params_to_noise,
                              clean_params, n_trials=3, noise_scale=0.1, max_new_tokens=80):
    """Estimate utterance uncertainty by decoding with noisy weights.

    For each trial, Gaussian noise scaled by ``noise_scale * p.std()`` is added
    to every parameter, the utterance is decoded again, and the weights are
    restored from ``clean_params``. Restoring from a snapshot in ``finally``
    guarantees the model is exactly unchanged even if generation raises.

    Returns:
        ``mean WER against pseudo_text`` multiplied by the number of distinct
        transcripts produced across the trials.
    """
    total_wer = 0.0
    generated_texts = []

    for _ in range(n_trials):
        try:
            with torch.no_grad():
                for p in params_to_noise:
                    p.data.add_(torch.randn_like(p) * p.std() * noise_scale)
                g_out = model.generate(mel, max_new_tokens=max_new_tokens)
        finally:
            for p, clean in zip(params_to_noise, clean_params):
                p.data.copy_(clean)

        # generate() returns a tensor or a ModelOutput depending on the config.
        g_seq = g_out[0] if isinstance(g_out, torch.Tensor) else g_out.sequences[0]
        g_text = tokenizer.decode(g_seq, skip_special_tokens=True)
        generated_texts.append(g_text)

        try:
            total_wer += wer(pseudo_text, g_text)
        except Exception:
            # Best effort, as before: an unscorable pair counts as fully wrong.
            total_wer += 1.0

    return (total_wer / n_trials) * len(set(generated_texts))


def _pseudo_label_sample(model, item, sample_idx, feature_extractor, tokenizer,
                         forced_decoder_ids, params_to_noise, clean_params):
    """Pseudo-label one item; return its training record or ``None`` if skipped."""
    mel = _item_to_mel(item, feature_extractor)

    # --- Generation with scores and attentions ---
    with torch.no_grad():
        outputs = model.generate(
            mel,
            max_new_tokens=128,
            output_attentions=True,
            output_scores=True,
            return_dict_in_generate=True,
            forced_decoder_ids=forced_decoder_ids
        )

    # --- Confidence: top softmax probability at each generated step ---
    if outputs.scores is None:
        print(f"WARNING: outputs.scores is None, skipping sample {sample_idx}.")
        return None
    scores_stack = torch.stack(outputs.scores, dim=0).squeeze(1)
    probs = torch.softmax(scores_stack, dim=-1).max(dim=-1).values.tolist()

    # --- Attentive score from the rebuilt attention matrix ---
    if outputs.decoder_attentions is None:
        print(f"WARNING: outputs.decoder_attentions is None, skipping sample {sample_idx}.")
        return None
    weights = _attentive_scores(outputs.decoder_attentions, len(probs))

    if len(probs) == 0 or len(weights) == 0:
        return None

    # --- Normalize both signals by their mean, then combine ---
    p_mean = np.mean(probs) if np.mean(probs) > 0 else 1e-6
    w_mean = np.mean(weights) if np.mean(weights) > 0 else 1e-6
    norm_probs = [p / p_mean for p in probs]
    norm_weights = [w / w_mean for w in weights]
    star_weights = calculate_star_weights(norm_probs, norm_weights)

    pred_ids = outputs.sequences[0]
    pseudo_text = tokenizer.decode(pred_ids, skip_special_tokens=True)

    filter_metric = _perturbation_uncertainty(
        model, mel, tokenizer, pseudo_text, params_to_noise, clean_params
    )

    return {
        'mel': mel.squeeze(0).cpu(),
        'labels': pred_ids.cpu(),
        'star_weights': torch.tensor(star_weights).cpu(),
        'filter_metric': filter_metric,
        'text': pseudo_text
    }


def prepare_star_dataset_chunked(model, input_file, processor, feature_extractor, tokenizer):
    """Pseudo-label a raw dataset, score it, filter it, and save the result.

    The raw list loaded from ``input_file`` is processed in ``NUM_CHUNKS``
    chunks; each chunk is saved to ``OUTPUT_DIR/chunk_<i>.pt``. The chunks are
    then merged, sorted by ``filter_metric`` (lowest first), cut to the first
    ``TOP_PERCENT`` fraction, and saved to ``OUTPUT_DIR/train_final_ready.pt``.

    Args:
        model: Whisper model used for generation; put in eval mode. Its
            weights and generation flags are restored before returning.
        input_file: Path of a ``torch.save``d list of raw items.
        processor: Used only to build the English transcription prompt ids.
        feature_extractor: Turns raw audio into input features.
        tokenizer: Decodes generated ids to text.

    Returns:
        The filtered list of sample dicts (``mel``, ``labels``,
        ``star_weights``, ``filter_metric``, ``text``); an empty list when no
        sample could be processed. Samples that raise are reported and skipped.
    """
    model.eval()
    if OUTPUT_DIR:
        os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Get Prompt Info
    forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")

    # NOTE(review): torch.load with weights_only=False unpickles arbitrary
    # objects — only use it on files you produced yourself.
    print(f"Loading raw data from {input_file}...")
    raw_data_list = torch.load(input_file, weights_only=False)

    total_size = len(raw_data_list)
    print(f"Loaded {total_size} samples. Starting Fixed Processing...")

    chunk_size = (total_size + NUM_CHUNKS - 1) // NUM_CHUNKS
    temp_files = []

    # Parameters perturbed by the uncertainty estimate, plus a clean snapshot
    # used to restore them exactly after every trial.
    params_to_noise = [p for p in model.parameters() if p.requires_grad and p.dtype == torch.float32]
    clean_params = [p.detach().clone() for p in params_to_noise]

    # Force the generation flags for the duration of this function only, so
    # later generate() calls elsewhere keep returning what they did before.
    saved_flags = _override_generation_flags(model)
    try:
        # --- CHUNK LOOP ---
        for chunk_idx in range(NUM_CHUNKS):
            start = chunk_idx * chunk_size
            end = min((chunk_idx + 1) * chunk_size, total_size)
            if start >= total_size:
                break

            print(f"\nProcessing Chunk {chunk_idx + 1}/{NUM_CHUNKS} (Samples {start} to {end})...")
            current_batch = raw_data_list[start:end]
            chunk_data = []

            for i, item in enumerate(tqdm(current_batch)):
                try:
                    sample_data = _pseudo_label_sample(
                        model, item, i, feature_extractor, tokenizer,
                        forced_decoder_ids, params_to_noise, clean_params
                    )
                except Exception as e:
                    print(f"ERROR: {e}")
                    continue
                if sample_data is not None:
                    chunk_data.append(sample_data)

            # Save Chunk
            if len(chunk_data) > 0:
                temp_file = os.path.join(OUTPUT_DIR, f"chunk_{chunk_idx}.pt")
                print(temp_file)
                torch.save(chunk_data, temp_file)
                temp_files.append(temp_file)

            del chunk_data
            gc.collect()
    finally:
        _restore_generation_flags(saved_flags)

    # --- MERGE & FILTER ---
    print("\nMerging and Filtering...")
    full_data = []
    for f in temp_files:
        if os.path.exists(f):
            full_data.extend(torch.load(f, weights_only=False))

    if len(full_data) == 0:
        print("CRITICAL: Final dataset is empty.")
        return []

    # Keep the TOP_PERCENT fraction with the lowest filter_metric.
    full_data.sort(key=lambda x: x['filter_metric'])
    cutoff = int(len(full_data) * TOP_PERCENT)
    final_data = full_data[:cutoff]

    save_path = os.path.join(OUTPUT_DIR, "train_final_ready.pt")
    torch.save(final_data, save_path)
    print(f"Success! Saved {len(final_data)} samples to {save_path}")
    return final_data


def train_with_star(model, train_dataset, device, num_epochs=10, batch_size=1, accum_steps=16):
    """
    STAR Informed Fine-Tuning with Loss Re-weighting.

    Per sample the loss is
    ``(sum(prompt_loss) + sum(generated_loss * star_weights)) / n_tokens``,
    where the first ``n_prompt_toks - 1`` target positions are treated as
    prompt and the rest are re-weighted by the mean-normalized STAR weights.

    Args:
        model: Model to fine-tune in place (all parameters are optimized).
        train_dataset: Sequence of dicts with ``'mel'``, ``'labels'`` and
            ``'star_weights'`` tensors. It is not modified; shuffling happens
            on a private copy.
        device: Device the batches are moved to.
        num_epochs: Number of passes over the data.
        batch_size: Samples per forward pass.
        accum_steps: Batches accumulated per optimizer step.

    Returns:
        The same ``model`` object after training.
    """

    # The original paper states that full fine-tuning performs better than freezing encoder
    # therefore we'll be using full fine-tuning (the encoder is left trainable).
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100, reduction='none')

    print(f"Phase 2: Training on {len(train_dataset)} samples")
    print(f"Config: BS={batch_size}, Accum={accum_steps}, Epochs={num_epochs}")

    # We explicitly define this to match the STAR loss splitting logic
    n_prompt_toks = 4

    # Shuffle a private copy so the caller's list keeps its order.
    samples = list(train_dataset)

    # Mixed precision only makes sense on CUDA; elsewhere autocast is disabled.
    # NOTE(review): float16 autocast is used without a GradScaler — confirm
    # gradients do not underflow for larger models.
    use_amp = torch.device(device).type == "cuda"

    # Token fed to the decoder at padded positions (labels use -100 there,
    # which is not a valid embedding index).
    pad_token_id = getattr(model.config, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = getattr(model.config, "eos_token_id", None)
    if pad_token_id is None:
        pad_token_id = 0

    for epoch in range(num_epochs):
        model.train()
        random.shuffle(samples)
        epoch_loss = 0.0
        loss_batches = 0
        step_count = 0

        optimizer.zero_grad()

        for i in tqdm(range(0, len(samples), batch_size)):
            batch = samples[i:i + batch_size]

            # --- 1. Prepare Inputs ---
            input_features = torch.stack([item['mel'] for item in batch]).to(device)

            # Labels
            padded_labels = pad_sequence([item['labels'] for item in batch],
                                         batch_first=True, padding_value=-100).to(device)

            # STAR Weights
            star_weights = [item['star_weights'].to(device) for item in batch]

            # --- 2. Manual Shift for Causal Loss ---
            # We slice manually to ensure alignment with our weights
            decoder_input_ids = padded_labels[:, :-1]
            decoder_input_ids = decoder_input_ids.masked_fill(decoder_input_ids == -100, pad_token_id)
            targets = padded_labels[:, 1:]

            # --- 3. Forward Pass ---
            with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                outputs = model(
                    input_features=input_features,
                    decoder_input_ids=decoder_input_ids
                )
                logits = outputs.logits

                # Calculate Raw Loss (Unreduced)
                # Shape: (Batch, Seq_Len); padded targets contribute 0.
                loss_tensor = loss_fn(logits.permute(0, 2, 1), targets)

                # --- 4. Apply STAR Weights (Informed Fine-Tuning) ---
                sample_losses = []
                for b in range(len(batch)):
                    sample_loss = loss_tensor[b]
                    weights = star_weights[b]

                    # Normalize weights
                    if weights.numel() > 0 and weights.mean() > 0:
                        weights = weights / weights.mean()

                    # The prompt loss is the first (n_prompt_toks - 1) tokens
                    prompt_loss_sum = torch.sum(sample_loss[:n_prompt_toks-1])

                    # The generated loss is the rest
                    gen_loss = sample_loss[n_prompt_toks-1:]

                    # Truncate to matching length (handle padding/EOS mismatches safely)
                    valid_len = min(len(gen_loss), len(weights))
                    if valid_len == 0:
                        continue

                    # Prompt Loss (Unweighted) + Generated Loss (Weighted)
                    weighted_gen_loss = gen_loss[:valid_len] * weights[:valid_len]
                    total_sample_loss = prompt_loss_sum + torch.sum(weighted_gen_loss)

                    # Normalize by total active tokens
                    normalization_factor = (n_prompt_toks - 1) + valid_len
                    sample_losses.append(total_sample_loss / normalization_factor)

            # --- 5. Backward & Step ---
            if sample_losses:
                # Average over batch
                batch_loss = torch.stack(sample_losses).mean()

                # Scale for accumulation
                (batch_loss / accum_steps).backward()

                epoch_loss += batch_loss.item()
                loss_batches += 1

            step_count += 1

            if step_count % accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

        # Apply gradients left over from an incomplete accumulation window;
        # otherwise the next epoch's zero_grad() would silently discard them.
        if step_count % accum_steps != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        # Average over batches that produced a loss (equals the per-sample
        # average when batch_size is 1).
        print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {epoch_loss / max(loss_batches, 1):.4f}")

    return model


def calculate_wer(reference, hypothesis):
    """Return jiwer's word error rate of ``hypothesis`` measured against ``reference``."""
    error_rate = wer(reference, hypothesis)
    return error_rate

def evaluate(model, dataset, processor, feature_extractor, device, normalizer=None):
    """Transcribe every item of ``dataset`` and return the corpus WER.

    Args:
        model: Model used for generation; put in eval mode.
        dataset: Iterable of dicts with ``'audio'`` (``'array'`` and
            ``'sampling_rate'``) and the reference ``'text'``.
        processor: Decodes generated ids to text.
        feature_extractor: Turns raw audio into input features.
        device: Device the input features are moved to.
        normalizer: Optional callable applied to both prediction and
            reference before scoring. Defaults to ``None``, which keeps the
            plain lower-casing used so far.

    Returns:
        Word error rate (a fraction, not a percentage) over all items.
    """
    model.eval()
    preds, targets = [], []
    print(" Phase 3: Evaluating ")

    def _clean(text):
        # Same text preparation for hypothesis and reference.
        return normalizer(text) if normalizer is not None else text.lower()

    for item in tqdm(dataset):
        # Process raw audio on the fly
        audio_array = item['audio']['array']
        sr = item['audio']['sampling_rate']

        inputs = feature_extractor(audio_array, sampling_rate=sr, return_tensors="pt")
        mel = inputs.input_features.to(device)

        with torch.no_grad():
            gen_out = model.generate(mel, max_new_tokens=225)

        # generate() returns a ModelOutput instead of a tensor when
        # return_dict_in_generate is enabled on the model's configs.
        gen_ids = gen_out if isinstance(gen_out, torch.Tensor) else gen_out.sequences
        pred_text = processor.batch_decode(gen_ids, skip_special_tokens=True)[0]

        preds.append(_clean(pred_text))
        targets.append(_clean(item['text']))

    return calculate_wer(targets, preds)


In [None]:
# ================= MAIN EXECUTION =================

# 1. Setup Model
# Eager attention is requested explicitly for the attention-based scoring in phase 1.
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME)
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_NAME)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, attn_implementation="eager").to(DEVICE)

# 2. Load Training Data
# (os.path is a module, not a function: the path constant is used directly.)
train_file = INPUT_PT_FILE

# 3. Phase 1: Prepare STAR Data (Pseudo-labeling)
# The filtered dataset is cached on disk; pseudo-labeling only runs when the cache is missing.
star_data_file = os.path.join(STAR_PATH, "train_final_ready.pt")
if os.path.exists(star_data_file):
    print(f"Loading STAR training data from {star_data_file}...")
    star_data = torch.load(star_data_file, weights_only=False)
elif os.path.exists(train_file):
    print(f"Loading data from {train_file}...")
    star_data = prepare_star_dataset_chunked(model, train_file, processor, feature_extractor, tokenizer)
    # The function also writes its result to disk; read it back if nothing was returned.
    if star_data is None and os.path.exists(star_data_file):
        star_data = torch.load(star_data_file, weights_only=False)
else:
    raise FileNotFoundError(
        f"Neither {star_data_file} nor {train_file} exists; run the dataset export first."
    )

if not star_data:
    raise RuntimeError("STAR training data is empty; nothing to train on.")

# Clear cache to free VRAM
torch.cuda.empty_cache()
gc.collect()

# 4. Phase 2: Train
trained_model = train_with_star(model, star_data, DEVICE)

# Save Model
torch.save(trained_model.state_dict(), "star_model_final.pt")
print("Model saved.")

# 5. Phase 3: Evaluate on Accents
# We look for all 'test_*.pt' files in the folder (sorted for a stable report order)
test_files = sorted(f for f in os.listdir(TEST_PATH) if f.startswith("test_") and f.endswith(".pt"))

print(f"\nFound {len(test_files)} test sets.")

for t_file in test_files:
    test_path = os.path.join(TEST_PATH, t_file)
    test_dataset = torch.load(test_path, weights_only=False)

    accent_name = t_file.replace("test_", "").replace(".pt", "")
    print(f"Evaluating {accent_name}...")

    score = evaluate(trained_model, test_dataset, processor, feature_extractor, DEVICE)
    print(f"WER for {accent_name}: {score * 100:.2f}%")


Loading STAR training data from data/star_process/train_final_ready.pt...
Phase 2: Training on 6400 samples
Config: BS=1, Accum=16, Epochs=10


100%|██████████| 6400/6400 [05:38<00:00, 18.89it/s]


Epoch 1/10 - Avg Loss: 0.4467


100%|██████████| 6400/6400 [05:38<00:00, 18.92it/s]


Epoch 2/10 - Avg Loss: 0.3012


100%|██████████| 6400/6400 [05:35<00:00, 19.07it/s]


Epoch 3/10 - Avg Loss: 0.1839


100%|██████████| 6400/6400 [05:38<00:00, 18.92it/s]


Epoch 4/10 - Avg Loss: 0.1183


100%|██████████| 6400/6400 [05:39<00:00, 18.87it/s]


Epoch 5/10 - Avg Loss: 0.0840


100%|██████████| 6400/6400 [05:37<00:00, 18.96it/s]


Epoch 6/10 - Avg Loss: 0.0628


100%|██████████| 6400/6400 [05:40<00:00, 18.82it/s]


Epoch 7/10 - Avg Loss: 0.0500


100%|██████████| 6400/6400 [05:43<00:00, 18.63it/s]


Epoch 8/10 - Avg Loss: 0.0414


100%|██████████| 6400/6400 [05:54<00:00, 18.08it/s]


Epoch 9/10 - Avg Loss: 0.0364


100%|██████████| 6400/6400 [05:59<00:00, 17.81it/s]


Epoch 10/10 - Avg Loss: 0.0328
Model saved.

Found 4 test sets.
Evaluating in...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:12<00:00,  7.93it/s]


WER for in: 42.29%
Evaluating au...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:11<00:00,  8.43it/s]


WER for au: 29.72%
Evaluating sa...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:13<00:00,  7.60it/s]


WER for sa: 45.31%
Evaluating sg...
 Phase 3: Evaluating 


100%|██████████| 100/100 [00:12<00:00,  8.11it/s]

WER for sg: 46.30%



