In [1]:
import os
import math
import json
import random
from dataclasses import dataclass, field
from typing import Optional, Dict, Any, List, Tuple

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
from datasets import Dataset, Audio
from transformers import (
    Wav2Vec2Model,
    Wav2Vec2Processor,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2CTCTokenizer,
    Trainer,
    TrainingArguments,
)
# Remove load_metric as it's deprecated - use evaluate library instead
from evaluate import load

In [4]:
# Optional imports for augmentation
try:
    import torchaudio
    TORCHAUDIO_AVAILABLE = True
except Exception:
    TORCHAUDIO_AVAILABLE = False

try:
    import librosa
    LIBROSA_AVAILABLE = True
except Exception:
    LIBROSA_AVAILABLE = False

In [5]:
# ---------------------------
# Config (edit to your paths)
# ---------------------------
TSV_PATH = "/home/sahil_duwal/MajorProject/Dataset/ne_np_female/line_index.tsv"
AUDIO_BASE = "/home/sahil_duwal/MajorProject/Dataset/ne_np_female/wavs/"
TOKENIZER_DIR = "./tokenizer_nepali"
PRETRAINED_W2V = "/home/sahil_duwal/MajorProject/Shruti---AVSR-in-Nepali-Language-/nep-2/wav2vec2-nepali-finetuned-v2"
OUTPUT_DIR = "./wav2vec2_custom_head"
SAMPLE_RATE = 16000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Flags
FREEZE_W2V = True
USE_ADAPTERS = True
ADAPTER_BOTTLENECK = 128
CUSTOM_TRANSFORMER_LAYERS = 4
CUSTOM_TRANSFORMER_DIM = 768
CUSTOM_TRANSFORMER_HEADS = 8
DROPOUT = 0.1
USE_SPEC_AUG = True
USE_AUDIO_AUG = True

# Training hyperparams
EPOCHS = 20
LR = 1e-4  # Reduced from 3e-4
BATCH_SIZE = 1  # Reduced from 2

In [7]:
# ---------------------------
# Utilities: tokenizer + processor
# ---------------------------
def load_or_create_tokenizer(tsv_path: str, tokenizer_dir: str):
    os.makedirs(tokenizer_dir, exist_ok=True)
    vocab_file = os.path.join(tokenizer_dir, "vocab.json")
    
    if os.path.exists(vocab_file):
        try:
            with open(vocab_file, 'r', encoding='utf-8') as f:
                vocab_dict = json.load(f)
            
            tokenizer = Wav2Vec2CTCTokenizer(
                vocab_file,
                unk_token="<unk>",
                pad_token="<pad>",
                word_delimiter_token="|"
            )
            return tokenizer
        except Exception as e:
            print(f"Loading existing tokenizer failed: {e}")

    # Create character-level tokenizer for CTC
    import pandas as pd
    df = pd.read_csv(tsv_path, sep="\t", low_memory=False)
    
    print(f"Available columns in TSV: {df.columns.tolist()}")
    
    # Find text column
    text_column = None
    possible_text_cols = ['text', 'transcription', 'transcript', 'sentence', 'label']
    for col in possible_text_cols:
        if col in df.columns:
            text_column = col
            break
    
    if text_column is None and len(df.columns) >= 2:
        text_column = df.columns[1]
        print(f"Using second column '{text_column}' as text data")
    
    texts = df[text_column].astype(str).tolist()
    
    # Build character-level vocabulary
    chars = set()
    for text in texts:
        chars.update(text.lower())
    
    # FIXED: Create vocabulary with proper CTC blank token
    # CTC needs: [blank, other_tokens...]
    vocab_list = ['<blank>', '<unk>', '<pad>'] + sorted(list(chars))
    vocab_dict = {char: i for i, char in enumerate(vocab_list)}
    
    # Save vocabulary
    with open(vocab_file, 'w', encoding='utf-8') as f:
        json.dump(vocab_dict, f, ensure_ascii=False, indent=2)
    
    # Create tokenizer with proper blank token
    tokenizer = Wav2Vec2CTCTokenizer(
        vocab_file,
        unk_token="<unk>",
        pad_token="<pad>",
        word_delimiter_token="|"
    )
    
    return tokenizer

In [8]:
# ---------------------------
# Data augmentation helpers
# ---------------------------
def speed_perturb(wave: np.ndarray, sr: int, factors=(0.9, 1.0, 1.1)) -> np.ndarray:
    if not LIBROSA_AVAILABLE:
        return wave
    f = random.choice(factors)
    if f == 1.0:
        return wave
    return librosa.effects.time_stretch(wave, rate=f)

def pitch_shift(wave: np.ndarray, sr: int, n_steps=(-2, 0, 2)) -> np.ndarray:
    if not LIBROSA_AVAILABLE:
        return wave
    step = random.choice(n_steps)
    return librosa.effects.pitch_shift(wave, sr, n_steps=step)

def add_background_noise(wave: np.ndarray, snr_db_min=5, snr_db_max=20) -> np.ndarray:
    rms = np.sqrt(np.mean(wave**2))
    if rms == 0:
        return wave
    snr_db = random.uniform(snr_db_min, snr_db_max)
    snr = 10 ** (snr_db / 20.0)
    noise_rms = rms / snr
    noise = np.random.normal(0, noise_rms, wave.shape)
    return wave + noise

def spec_augment(features: torch.Tensor, time_mask_param=30, freq_mask_param=13, num_time_masks=2, num_freq_masks=2):
    B, T, D = features.shape
    for _ in range(num_time_masks):
        t = random.randint(0, time_mask_param)
        t0 = random.randint(0, max(0, T - t)) if T - t > 0 else 0
        features[:, t0:t0+t, :] = 0
    for _ in range(num_freq_masks):
        f = random.randint(0, freq_mask_param)
        f0 = random.randint(0, max(0, D - f)) if D - f > 0 else 0
        features[:, :, f0:f0+f] = 0
    return features

In [9]:
# ---------------------------
# Adapter module
# ---------------------------
class Adapter(nn.Module):
    def __init__(self, dim, bottleneck=128):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.relu = nn.ReLU()
        self.up = nn.Linear(bottleneck, dim)

    def forward(self, x):
        residual = x
        x = self.down(x)
        x = self.relu(x)
        x = self.up(x)
        return x + residual

In [10]:
# ---------------------------
# Custom head: Transformer encoder + linear -> vocab (CTC)
# ---------------------------
class CustomTransformerCTCHead(nn.Module):
    def __init__(self, input_dim: int, model_dim: int, num_layers: int, num_heads: int, dropout: float, vocab_size: int):
        super().__init__()
        self.input_dim = input_dim
        self.model_dim = model_dim
        self.vocab_size = vocab_size
        
        # Project input to model dimension if needed
        if input_dim != model_dim:
            self.input_proj = nn.Linear(input_dim, model_dim)
        else:
            self.input_proj = nn.Identity()
        
        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=model_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Output projection to vocabulary
        self.output_proj = nn.Linear(model_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, features: torch.Tensor) -> torch.Tensor:
        x = self.input_proj(features)
        x = self.dropout(x)
        x = self.transformer(x)
        logits = self.output_proj(x)
        return logits

In [11]:
# ---------------------------
# Combined Model wrapper
# ---------------------------
class W2V2WithCustomHead(nn.Module):
    def __init__(self, pretrained_name: str, processor: Wav2Vec2Processor, tokenizer: Wav2Vec2CTCTokenizer,
                 freeze_w2v: bool=True, use_adapters: bool=True, adapter_dim: int=128,
                 trans_layers: int=4, trans_dim: int=768, trans_heads: int=8, dropout: float=0.1):
        super().__init__()
        self.processor = processor
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.get_vocab())

        # Base wav2vec model
        self.wav2vec = Wav2Vec2Model.from_pretrained(pretrained_name)
        self.wav2vec_feature_dim = self.wav2vec.config.hidden_size

        if freeze_w2v:
            for p in self.wav2vec.parameters():
                p.requires_grad = False

        self.use_adapters = use_adapters
        if use_adapters:
            self.adapter = Adapter(self.wav2vec_feature_dim, bottleneck=adapter_dim)
        else:
            self.adapter = None

        self.custom_head = CustomTransformerCTCHead(
            input_dim=self.wav2vec_feature_dim, 
            model_dim=trans_dim, 
            num_layers=trans_layers, 
            num_heads=trans_heads, 
            dropout=dropout, 
            vocab_size=self.vocab_size
        )

    def forward(self, input_values: torch.Tensor, attention_mask: Optional[torch.Tensor]=None, 
            labels: Optional[torch.Tensor]=None):
        
        # Debug input shapes
        print(f"Forward pass - input_values shape: {input_values.shape}")
        if labels is not None:
            print(f"Forward pass - labels shape: {labels.shape}")
        
        outputs = self.wav2vec(input_values, attention_mask=attention_mask)
        features = outputs.last_hidden_state
        
        if self.adapter is not None:
            features = self.adapter(features)

        logits = self.custom_head(features)
        print(f"Forward pass - logits shape: {logits.shape}")

        loss = None
        if labels is not None:
            # CTC loss computation
            log_probs = F.log_softmax(logits, dim=-1)
            
            # Get input lengths
            input_lengths = torch.full(
                size=(log_probs.shape[0],), 
                fill_value=log_probs.shape[1], 
                dtype=torch.long, 
                device=log_probs.device
            )
            
            # Get target lengths (exclude pad tokens which are ID=2)
            target_lengths = (labels != 2).sum(dim=-1)
            print(f"Target lengths: {target_lengths}")
            
            # Flatten labels for CTC loss
            targets_list = []
            for i, length in enumerate(target_lengths):
                valid_labels = [label.item() for label in labels[i, :length] if label.item() != 2]
                targets_list.extend(valid_labels)
            
            print(f"CTC targets: {len(targets_list)} tokens")
            
            if len(targets_list) > 0 and target_lengths.sum() > 0:
                targets = torch.tensor(targets_list, dtype=torch.long, device=logits.device)
                
                log_probs_t = log_probs.transpose(0, 1)
                
                ctc_loss_fn = nn.CTCLoss(
                    blank=0,
                    zero_infinity=True, 
                    reduction='mean'
                )
                
                target_lengths = torch.clamp(target_lengths, min=1)
                
                try:
                    loss = ctc_loss_fn(log_probs_t, targets, input_lengths, target_lengths)
                    print(f"CTC loss computed: {loss.item()}")
                    
                    if torch.isnan(loss):
                        print("Warning: NaN loss detected")
                        loss = torch.tensor(0.1, device=logits.device, requires_grad=True)
                        
                except Exception as e:
                    print(f"CTC loss computation failed: {e}")
                    loss = torch.tensor(0.1, device=logits.device, requires_grad=True)
            else:
                print("No valid targets for CTC loss")
                loss = torch.tensor(0.1, device=logits.device, requires_grad=True)

        return {
            'loss': loss,
            'logits': logits,
        }

In [12]:
# ---------------------------
# Data Collator
# ---------------------------
@dataclass
class DataCollatorCTCWithAugment:
    processor: Wav2Vec2Processor
    tokenizer: Wav2Vec2CTCTokenizer
    sample_rate: int = SAMPLE_RATE
    padding: bool = True
    apply_spec_augment: bool = True
    apply_audio_aug: bool = True
    min_audio_length: int = 1600

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        input_values = []
        labels = []

        for f in features:
            try:
                # Get audio file path
                audio_path = f.get("audio_path", "")
                
                if not audio_path or not os.path.exists(audio_path):
                    print(f"Audio file not found: {audio_path}")
                    continue

                # Load audio using librosa
                if LIBROSA_AVAILABLE:
                    arr, sr = librosa.load(audio_path, sr=self.sample_rate)
                else:
                    print("librosa not available - cannot load audio")
                    continue

                # Ensure 1D array
                if arr.ndim > 1:
                    arr = arr.flatten()

                # More lenient minimum length check
                if len(arr) < 800:  # Reduced from 1600 to 800
                    print(f"Skipping very short audio: {len(arr)} samples")
                    continue

                # Process with wav2vec2
                inputs = self.processor(
                    arr,
                    sampling_rate=self.sample_rate,
                    return_tensors="pt",
                )
                
                input_tensor = inputs.input_values.squeeze()
                if input_tensor.dim() == 0:
                    input_tensor = input_tensor.unsqueeze(0)
                input_values.append(input_tensor)

                # Process text
                text = str(f.get("text", "")).lower().strip()
                if not text:
                    print("Skipping empty text")
                    continue

                # Character-level tokenization
                token_ids = []
                vocab = self.tokenizer.get_vocab()
                
                for char in text:
                    if char in vocab:
                        token_ids.append(vocab[char])
                    else:
                        token_ids.append(vocab.get("<unk>", 1))  # unk is ID=1

                if token_ids:
                    labels.append(torch.tensor(token_ids, dtype=torch.long))
                else:
                    print("No valid tokens found")

            except Exception as e:
                print(f"Error processing sample {f.get('audio_path', 'unknown')}: {e}")
                import traceback
                traceback.print_exc()
                continue

        # Handle empty batch - IMPROVED ERROR HANDLING
        if len(input_values) == 0 or len(labels) == 0:
            print(f"WARNING: Empty batch - input_values: {len(input_values)}, labels: {len(labels)}")
            print("This indicates a data loading problem!")
            
            # Create a minimal valid batch instead of dummy
            vocab = self.tokenizer.get_vocab()
            dummy_audio = torch.zeros(self.sample_rate // 2)  # 0.5 second of silence
            dummy_labels = torch.tensor([vocab.get("a", 3)], dtype=torch.long)  # Use a common character
            
            return {
                "input_values": dummy_audio.unsqueeze(0),
                "labels": dummy_labels.unsqueeze(0),
            }

        # Pad sequences
        input_values = nn.utils.rnn.pad_sequence(input_values, batch_first=True, padding_value=0.0)
        labels = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=2)  # pad token is ID=2

        print(f"Batch created: input_values shape: {input_values.shape}, labels shape: {labels.shape}")
        
        return {
            "input_values": input_values,
            "labels": labels,
        }

In [13]:
# # ---------------------------
# # Simple greedy + shallow fusion beam search (stub)
# # ---------------------------


# def greedy_decode_logits(logits: np.ndarray, tokenizer: PreTrainedTokenizerFast):
#     # logits: (B, T, V)
#     pred_ids = np.argmax(logits, axis=-1)
#     texts = []
#     for b in range(pred_ids.shape[0]):
#         ids = pred_ids[b].tolist()
#         # collapse repeats and remove pad (assume pad id is 0)
#         prev = None
#         out = []
#         for i in ids:
#             if i != prev and i != tokenizer.pad_token_id:
#                 out.append(i)
#             prev = i
#         texts.append(tokenizer.decode(out, clean_up_tokenization_spaces=False))
#     return texts


# # Shallow fusion requires an LM that can provide log-prob for token sequences. We'll provide a simple
# # integration hook using a pretrained causal LM (like GPT-2); in practice you'd prefer a kenlm or an ngram.


# class SimpleShallowFusionLM:
#     def __init__(self, model_name=LM_MODEL_NAME, device='cpu'):
#         try:
#             from transformers import AutoModelForCausalLM, AutoTokenizer
#             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
#             self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
#             self.model.eval()
#             self.device = device
#         except Exception as e:
#             print("Warning: Could not load LM for shallow fusion:", e)
#             self.model = None
#             self.tokenizer = None


#     def score_sequences(self, sequences: List[str]) -> List[float]:
#         # return average log-prob per token (negative loss)
#         if self.model is None:
#             return [0.0] * len(sequences)
#         enc = self.tokenizer(sequences, return_tensors='pt', padding=True).to(self.device)
#         with torch.no_grad():
#             out = self.model(**enc, labels=enc['input_ids'])
#             # negative loss gives log-prob
#             loss = out.loss.cpu().numpy()
#             # convert to per-sequence score; huggingface returns mean loss over tokens & batch
#             return [-float(loss)] * len(sequences)

In [14]:
# ---------------------------
# Dataset preparation
# ---------------------------
def prepare_dataset(tsv_path: str, audio_base: str) -> Dataset:
    import pandas as pd
    df = pd.read_csv(tsv_path, sep='\t', low_memory=False)
    
    # Handle non-standard column names
    columns = df.columns.tolist()
    if len(columns) >= 2:
        df = df.rename(columns={columns[0]: 'path', columns[1]: 'text'})
    
    if 'path' not in df.columns:
        raise ValueError('TSV must have at least 2 columns: audio path and text')

    # Create full audio paths
    def get_full_path(p):
        if os.path.isabs(p):
            return p
        else:
            base_path = os.path.join(audio_base, p)
            if os.path.exists(base_path):
                return base_path
            for ext in ['.wav', '.mp3', '.flac', '.ogg']:
                if os.path.exists(base_path + ext):
                    return base_path + ext
            return base_path
    
    df['audio_path'] = df['path'].apply(get_full_path)
    
    # Filter valid files
    initial_count = len(df)
    df = df[df['audio_path'].apply(os.path.exists)]
    final_count = len(df)
    
    print(f"Filtered dataset: {initial_count} -> {final_count} samples")
    
    if final_count == 0:
        raise ValueError("No valid audio files found!")
    
    # Keep as simple Dataset without Audio casting
    dataset = Dataset.from_pandas(df[['audio_path', 'text']])
    
    return dataset

In [15]:
def main():
    print("🚀 Starting Wav2Vec2 training with custom head...")
    
    # Load tokenizer and processor
    tokenizer = load_or_create_tokenizer(TSV_PATH, TOKENIZER_DIR)
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1, 
        sampling_rate=SAMPLE_RATE, 
        padding_value=0.0, 
        do_normalize=True, 
        return_attention_mask=True
    )
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
    
    print(f"✅ Tokenizer vocabulary size: {len(tokenizer.get_vocab())}")
    
    # Prepare dataset
    dataset = prepare_dataset(TSV_PATH, AUDIO_BASE)
    print(f"✅ Dataset loaded: {len(dataset)} samples")
    
    # DEBUG: Check what's in the first few samples
    print("DEBUG: First sample keys:", dataset[0].keys())
    print("DEBUG: First sample audio_path:", dataset[0]['audio_path'])
    print("DEBUG: First sample text:", dataset[0]['text'])
    
    # Check if audio files actually exist
    sample_paths = [dataset[i]['audio_path'] for i in range(min(5, len(dataset)))]
    existing_files = [path for path in sample_paths if os.path.exists(path)]
    print(f"DEBUG: Sample audio files exist: {len(existing_files)}/{len(sample_paths)}")
    
    # Split dataset
    dataset = dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = dataset['train']
    eval_dataset = dataset['test']
    
    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Eval dataset size: {len(eval_dataset)}")
    
    # CREATE MODEL INSTANCE HERE - BEFORE USING IT
    print("🏗️ Creating model...")
    model = W2V2WithCustomHead(
        pretrained_name=PRETRAINED_W2V,
        processor=processor,
        tokenizer=tokenizer,
        freeze_w2v=FREEZE_W2V,
        use_adapters=USE_ADAPTERS,
        adapter_dim=ADAPTER_BOTTLENECK,
        trans_layers=CUSTOM_TRANSFORMER_LAYERS,
        trans_dim=CUSTOM_TRANSFORMER_DIM,
        trans_heads=CUSTOM_TRANSFORMER_HEADS,
        dropout=DROPOUT
    )
    model.to(DEVICE)
    print(f"✅ Model loaded on {DEVICE}")
    
    # Test data collator with a small batch
    data_collator = DataCollatorCTCWithAugment(
        processor=processor,
        tokenizer=tokenizer,
        sample_rate=SAMPLE_RATE,
        apply_spec_augment=False,  # Disable for debugging
        apply_audio_aug=False      # Disable for debugging
    )
    
    # Test with first few samples
    test_batch = [train_dataset[i] for i in range(min(3, len(train_dataset)))]
    print("Testing data collator...")
    try:
        batch = data_collator(test_batch)
        print(f"Test batch successful: {batch['input_values'].shape}")
    except Exception as e:
        print(f"Data collator test failed: {e}")
        import traceback
        traceback.print_exc()
        return
    
    # Metrics
    wer_metric = load('wer')
    
    def compute_metrics(pred):
        logits = pred.predictions
        if isinstance(logits, tuple):
            logits = logits[0]
            
        pred_ids = np.argmax(logits, axis=-1)
        label_ids = pred.label_ids
        
        pred_texts = []
        ref_texts = []
        
        for i in range(pred_ids.shape[0]):
            # Decode predictions (remove repeats and padding)
            pred_seq = []
            prev_id = None
            for token_id in pred_ids[i]:
                if token_id != prev_id and token_id != 0:  # 0 is blank token
                    pred_seq.append(token_id)
                prev_id = token_id
            pred_text = tokenizer.decode(pred_seq, skip_special_tokens=True)
            pred_texts.append(pred_text)
            
            # Decode references
            ref_seq = [t for t in label_ids[i] if t != 2]  # 2 is pad token
            ref_text = tokenizer.decode(ref_seq, skip_special_tokens=True)
            ref_texts.append(ref_text)
        
        wer_score = wer_metric.compute(predictions=pred_texts, references=ref_texts)
        return {"wer": wer_score}
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        learning_rate=LR,
        warmup_steps=500,
        logging_steps=50,
        eval_steps=500,
        save_steps=500,
        eval_strategy="steps",
        save_total_limit=3,
        remove_unused_columns=False,
        dataloader_pin_memory=False,
        group_by_length=False,
        fp16=False,
        push_to_hub=False,
        report_to=None,
        max_grad_norm=1.0,
        dataloader_num_workers=0,
    )
    
    # Wrapper for HuggingFace Trainer
    class HFWrapperModel(nn.Module):
        def __init__(self, inner_model):
            super().__init__()
            self.inner = inner_model
        
        def forward(self, input_values=None, labels=None, **kwargs):
            outputs = self.inner(input_values=input_values, labels=labels)
            return outputs['loss'], outputs['logits']
    
    # NOW model is defined, so this will work
    wrapped_model = HFWrapperModel(model)
    
    # Trainer
    trainer = Trainer(
        model=wrapped_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    
    # Train
    print("🏋️ Starting training...")
    trainer.train()
    
    # Final evaluation
    print("📊 Final evaluation...")
    metrics = trainer.evaluate()
    print("Final metrics:", metrics)
    
    # Save model
    trainer.save_model()
    processor.save_pretrained(OUTPUT_DIR)
    print(f"✅ Model saved to {OUTPUT_DIR}")

In [16]:
if __name__ == "__main__":
    main()

🚀 Starting Wav2Vec2 training with custom head...
✅ Tokenizer vocabulary size: 72
Filtered dataset: 2063 -> 2063 samples
✅ Dataset loaded: 2063 samples
DEBUG: First sample keys: dict_keys(['audio_path', 'text'])
DEBUG: First sample audio_path: /home/sahil_duwal/MajorProject/Dataset/ne_np_female/wavs/nep_0258_0461984530.wav
DEBUG: First sample text: डिग्रा देवीको जन्म सुदूरपश्चिम नेपालको बझाङ जिल्लामा भएको हो
DEBUG: Sample audio files exist: 5/5
Train dataset size: 1856
Eval dataset size: 207
🏗️ Creating model...
✅ Model loaded on cuda
Testing data collator...
Batch created: input_values shape: torch.Size([3, 55366]), labels shape: torch.Size([3, 44])
Test batch successful: torch.Size([3, 55366])
🏋️ Starting training...
Batch created: input_values shape: torch.Size([2, 50336]), labels shape: torch.Size([2, 45])
Batch created: input_values shape: torch.Size([2, 111638]), labels shape: torch.Size([2, 93])
Forward pass - input_values shape: torch.Size([2, 50336])
Forward pass - labels shape

Step,Training Loss,Validation Loss,Wer
500,5.1623,4.912009,1.0
1000,6.6034,4.427245,0.998865
1500,5.5869,4.361509,1.0
2000,5.8549,4.379185,0.999432
2500,5.7663,4.204469,1.002838
3000,4.9051,4.113732,1.001703
3500,6.3492,4.186207,1.0


Batch created: input_values shape: torch.Size([2, 64070]), labels shape: torch.Size([2, 44])
Forward pass - input_values shape: torch.Size([2, 111638])
Forward pass - labels shape: torch.Size([2, 93])
Forward pass - logits shape: torch.Size([2, 348, 72])
Target lengths: tensor([93, 49], device='cuda:0')
CTC targets: 142 tokens
CTC loss computed: 18.33310890197754
Batch created: input_values shape: torch.Size([2, 96726]), labels shape: torch.Size([2, 73])
Forward pass - input_values shape: torch.Size([2, 64070])
Forward pass - labels shape: torch.Size([2, 44])
Forward pass - logits shape: torch.Size([2, 199, 72])
Target lengths: tensor([44, 35], device='cuda:0')
CTC targets: 79 tokens
CTC loss computed: 17.356586456298828
Batch created: input_values shape: torch.Size([2, 115190]), labels shape: torch.Size([2, 79])
Forward pass - input_values shape: torch.Size([2, 96726])
Forward pass - labels shape: torch.Size([2, 73])
Forward pass - logits shape: torch.Size([2, 302, 72])
Target lengths

KeyboardInterrupt: 