In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk /kaggle/input recursively and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!add-apt-repository -y ppa:jonathonf/ffmpeg-4
!apt update
!apt install -y ffmpeg
!pip install datasets>=2.6.1
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install evaluate>=0.30
!pip install jiwer
!pip install gradio
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git@main

In [None]:
# Step 0: Kaggle notebook setup - install the remaining audio/text helpers (librosa, datasets, transformers and evaluate are installed in the previous cell)
!pip install torchaudio  soundfile gdown pandas chardet num2words jiwer tqdm 
import gdown
import os
import pandas as pd
import torchaudio
from datasets import Dataset, DatasetDict
import librosa
import soundfile as sf

# #modified for testing purpose 

In [None]:
def check_dataset(audio_dir, text_dir):
    """Print how many audio and transcript files pair up by file name.

    A ``.wav`` file in ``audio_dir`` and a ``.txt`` file in ``text_dir`` form a
    pair when their names are equal once the extension is removed.

    Args:
        audio_dir: Directory containing the ``.wav`` files.
        text_dir: Directory containing the ``.txt`` transcripts.

    Returns:
        None. The three counts are printed to stdout.
    """
    def _stems(directory, extension):
        # splitext removes only the final extension, so an ID that itself
        # contains dots (e.g. "spk.001.wav") is kept intact instead of being
        # truncated at its first dot.
        return {
            os.path.splitext(name)[0]
            for name in os.listdir(directory)
            if name.endswith(extension)
        }

    audio_files = _stems(audio_dir, '.wav')
    text_files = _stems(text_dir, '.txt')

    missing_audio = text_files - audio_files
    missing_text = audio_files - text_files

    print(f"Total pairs: {len(audio_files & text_files)}")
    print(f"Missing audio for {len(missing_audio)} text files")
    print(f"Missing text for {len(missing_text)} audio files")
    
check_dataset('/kaggle/input/tamil-dataset-for-asr/Training/Audio', '/kaggle/input/tamil-dataset-for-asr/Training/Transcripts')

# Build the list of file IDs from the audio directory.
# - sorted(): os.listdir returns entries in arbitrary order, so without sorting the
#   seeded sampling below could select different files from one run/filesystem to the next.
# - splitext(): strips only the final ".wav", keeping IDs that contain dots intact.
all_files = sorted(
    os.path.splitext(f)[0]
    for f in os.listdir('/kaggle/input/tamil-dataset-for-asr/Training/Audio')
    if f.endswith('.wav')
)
df = pd.DataFrame({'file_id': all_files})

# Kaggle: Use fixed random_state for reproducibility
# First split: take 5 samples for final test
final_test_df = df.sample(n=5, random_state=42)
remaining_df = df.drop(final_test_df.index)

# Then split the remaining data into train and validation
train_df = remaining_df.sample(frac=0.99, random_state=42)  # Using 99% of remaining for training
test_df = remaining_df.drop(train_df.index)  # The rest for validation (named test_df, used as the validation set)

# Save splits
train_df.to_csv('/kaggle/working/train_files.csv', index=False)
test_df.to_csv('/kaggle/working/val_files.csv', index=False)  # Validation set
final_test_df.to_csv('/kaggle/working/final_test_files.csv', index=False)  # Final test set

print(f"Training samples: {len(train_df)}")
print(f"Validation samples: {len(test_df)}")
print(f"Final test samples: {len(final_test_df)}")

In [None]:
final_test_df

In [None]:
train_df

In [None]:
!pip install pyloudnorm

In [None]:
import pyloudnorm

from pyloudnorm import Meter

# def preprocess_audio(audio_path):
#     waveform, sr = librosa.load(audio_path, sr=16000, mono=True)
#     meter = pyloudnorm.Meter(sr) 
#     loudness = meter.integrated_loudness(waveform)
#     target_loudness = -16  # Optimal for speech
#     waveform = pyloudnorm.normalize.loudness(waveform, loudness, target_loudness)
#     waveform = waveform - np.mean(waveform)
#     waveform = np.clip(waveform, -1.0, 1.0)
#     output_path = f"/kaggle/working/processed_audio/{os.path.basename(audio_path)}"
#     sf.write(output_path, waveform, 16000, subtype='PCM_16')
#     return output_path.strip



def preprocess_audio(audio_path, output_dir="/kaggle/working/processed_audio", top_db=40):
    """Trim silence from one recording, normalize it and save it as 16-bit PCM WAV.

    Steps: load as 16 kHz mono, detect non-silent intervals, pad each interval by
    200 ms on both sides, merge intervals that overlap after padding, concatenate
    them, remove the DC offset, clip to [-1, 1] and scale to unit peak.

    Args:
        audio_path: Path of the input audio file.
        output_dir: Directory the processed file is written to (created if missing).
        top_db: Threshold in dB below the reference passed to
            ``librosa.effects.split`` to decide what counts as silence.

    Returns:
        Path of the written file (``output_dir`` joined with the input's base
        name), or ``None`` when no interval survives the length filter.
    """
    waveform, sr = librosa.load(audio_path, sr=16000, mono=True)

    # Non-silent (start, end) sample intervals, in ascending order.
    intervals = librosa.effects.split(waveform, top_db=top_db, frame_length=1024, hop_length=256)

    buffer_samples = int(0.2 * sr)  # 200 ms of context kept around each interval
    min_samples = int(0.3 * sr)     # padded intervals at or below 300 ms are dropped

    processed_intervals = []
    for start, end in intervals:
        buffered_start = max(0, start - buffer_samples)
        buffered_end = min(len(waveform), end + buffer_samples)
        if (buffered_end - buffered_start) <= min_samples:
            continue

        if processed_intervals and buffered_start <= processed_intervals[-1][1]:
            # Padding made this interval touch/overlap the previous one. Merge them;
            # concatenating both as-is would repeat the shared samples in the output.
            previous_start, previous_end = processed_intervals[-1]
            processed_intervals[-1] = (previous_start, max(previous_end, buffered_end))
        else:
            processed_intervals.append((buffered_start, buffered_end))

    if not processed_intervals:
        print(f"No speech found in {audio_path}")
        return None

    # np.concatenate returns a new array, so the in-place operations below do not
    # touch the loaded waveform.
    non_silent_audio = np.concatenate(
        [waveform[start:end] for start, end in processed_intervals]
    )

    # Post-processing
    non_silent_audio -= np.mean(non_silent_audio)  # DC offset
    non_silent_audio = np.clip(non_silent_audio, -1.0, 1.0)

    # Peak normalization; skipped for an all-zero signal to avoid dividing by zero.
    peak = np.max(np.abs(non_silent_audio))
    if peak > 0:
        non_silent_audio /= peak

    # Save output
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, os.path.basename(audio_path))
    sf.write(output_path, non_silent_audio, sr, subtype="PCM_16")

    return output_path
    
# def process_all_audio(df):
#     os.makedirs('/kaggle/working/processed_audio', exist_ok=True)
#     df['audio_path'] = df['file_id'].apply(
#         lambda x: preprocess_audio(f"/kaggle/input/tamil-dataset-for-asr/Training/Audio/{x}.wav")
#     )
#     return df

def process_all_audio(df):
    """Add an ``audio_path`` column pointing at each row's original WAV file.

    No audio is read or rewritten: the path is derived from ``file_id`` only.
    The frame is modified in place and the same object is returned.
    """
    def _wav_path(file_id):
        return f"/kaggle/input/tamil-dataset-for-asr/Training/Audio/{file_id}.wav"

    df['audio_path'] = df['file_id'].map(_wav_path)
    return df

# Attach the audio path column to each split (test_df is the validation split).
train_df = process_all_audio(train_df)
test_df = process_all_audio(test_df)
final_test_df=process_all_audio(final_test_df)

In [None]:
# Sanity check: show which columns the train and validation frames carry at this point.
print("Train columns:", train_df.columns.tolist())
print("Test columns:", test_df.columns.tolist())

In [None]:
# Positional lookup: train_df was produced by sampling, so it keeps the original row
# labels and label 0 is not guaranteed to be present (label-based [0] can raise KeyError).
train_df['audio_path'].iloc[0]

In [None]:
import re
import unicodedata

def normalize_tamil_text(text):
    """Clean a transcript: drop ``<|...|>`` tags, ZWNJ and punctuation, apply NFKC.

    Args:
        text: Raw transcript string.

    Returns:
        The cleaned string with surrounding whitespace stripped.
    """
    # Characters removed from transcripts. The closing single quote (’) is listed
    # alongside the opening one so quoted text is stripped symmetrically.
    tamil_punctuation = '''!(),–‘’“”…•॥ॐ।<>?@[\\]^_`{|}~'''
    translator = str.maketrans('', '', tamil_punctuation)

    # Remove special tokens of the form <|...|> before their delimiters are stripped.
    text = re.sub(r'<\|.*?\|>', '', text)
    # Remove zero-width non-joiners.
    text = text.replace('\u200c', '')
    # First pass BEFORE NFKC: NFKC rewrites compatibility characters, e.g. the
    # ellipsis "…" becomes "...", which the table would then no longer match.
    text = text.translate(translator)
    text = unicodedata.normalize('NFKC', text)
    # Second pass AFTER NFKC: catches punctuation that only appears once
    # compatibility forms (e.g. full-width "！" or "（") are folded to ASCII.
    text = text.translate(translator)
    return text.strip()

def load_transcript(file_id, transcripts_dir="/kaggle/input/tamil-dataset-for-asr/Training/Transcripts"):
    """Read and normalize the transcript for ``file_id``, trying several encodings.

    Each candidate encoding is tried in order; a decoding is accepted only if one
    of its first 10 characters lies in the Tamil Unicode block (U+0B80-U+0BFF).

    Args:
        file_id: File name without extension; ``<file_id>.txt`` is read.
        transcripts_dir: Directory holding the transcript files.

    Returns:
        The normalized transcript, or ``None`` when the file does not exist or no
        encoding yields Tamil text (a message is printed in both cases).
    """
    txt_path = os.path.join(transcripts_dir, f"{file_id}.txt")
    # 'utf-8-sig' reads plain UTF-8 too, but also drops a leading byte-order mark;
    # with plain 'utf-8' the BOM would stay in the text as U+FEFF (strip() keeps it).
    encodings = ['utf-8-sig', 'utf-16', 'utf-16-le', 'utf-16-be', 'latin1']

    for encoding in encodings:
        try:
            with open(txt_path, 'r', encoding=encoding) as f:
                text = f.read().strip()
        except FileNotFoundError:
            # An audio file without a transcript: report it and return None like the
            # other failure path, instead of aborting the whole DataFrame.apply.
            print(f"Transcript not found: {txt_path}")
            return None
        except UnicodeError:
            continue

        if any('\u0B80' <= char <= '\u0BFF' for char in text[:10]):
            return normalize_tamil_text(text)

    print(f"Failed to decode {file_id}.txt with common encodings")
    return None

# Load the normalized transcript of every file into a 'text' column, per split.
train_df['text'] = train_df['file_id'].apply(load_transcript)
test_df['text'] = test_df['file_id'].apply(load_transcript)
final_test_df['text'] = final_test_df['file_id'].apply(load_transcript)
# load_transcript returns None when no usable transcript was read; drop those rows.
train_df = train_df.dropna(subset=['text'])
test_df = test_df.dropna(subset=['text'])
final_test_df = final_test_df.dropna(subset=['text'])

In [None]:
test_df

# Data Analysis

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import librosa
import numpy as np
import pandas as pd
from collections import Counter

# 1. Basic Dataset Statistics
def print_basic_stats(df, name):
    """Print size, duration and text-length summary lines for one split.

    Args:
        df: Frame with ``duration`` and ``text_length`` columns.
        name: Label printed in the heading (followed by " Set:").
    """
    durations = df['duration']
    total_hours = durations.sum() / 3600
    mean_seconds = durations.mean()
    mean_chars = df['text_length'].mean()

    report_lines = (
        f"\n{name} Set:",
        f"Samples: {len(df):,}",
        f"Total Duration: {total_hours:.2f} hours",
        f"Avg Duration: {mean_seconds:.2f}s",
        f"Avg Text Length: {mean_chars:.0f} characters",
    )
    for line in report_lines:
        print(line)

# Compute durations and text lengths
def compute_metrics(df):
    """Add ``duration`` and ``text_length`` columns to ``df`` in place and return it.

    ``duration`` comes from ``librosa.get_duration`` on each ``audio_path``;
    ``text_length`` is the character count of ``text``.

    NOTE: a later cell defines another function with this same name (the metric
    callback handed to the trainer), which replaces this one once it runs.
    """
    def _clip_duration(wav_path):
        return librosa.get_duration(path=wav_path)

    df['duration'] = df['audio_path'].apply(_clip_duration)
    df['text_length'] = df['text'].str.len()
    return df

# Add duration / text-length columns to every split (test_df is the validation split).
train_df = compute_metrics(train_df)
test_df = compute_metrics(test_df)
final_test_df = compute_metrics(final_test_df)

# Print basic statistics
print("=== Dataset Statistics ===")
print_basic_stats(train_df, "Training")
print_basic_stats(test_df, "Validation")
print_basic_stats(final_test_df, "Final Test")

# 2. Sample Distribution Visualization
# One bar per split showing its number of rows.
plt.figure(figsize=(10, 5))
sns.barplot(x=['Training', 'Validation', 'Final Test'],
            y=[len(train_df), len(test_df), len(final_test_df)])
plt.title('Sample Distribution Across Splits')
plt.ylabel('Number of Samples')
plt.show()

# 3. Audio Duration Distribution
# The three histograms are drawn on the same axes as raw counts.
plt.figure(figsize=(12, 6))
sns.histplot(train_df['duration'], bins=50, color='blue', alpha=0.5, label='Train')
sns.histplot(test_df['duration'], bins=50, color='orange', alpha=0.5, label='Validation')
sns.histplot(final_test_df['duration'], bins=50, color='green', alpha=0.5, label='Final Test')
plt.title('Audio Duration Distribution')
plt.xlabel('Duration (seconds)')
plt.ylabel('Count')
plt.legend()
plt.show()

# 4. Text Length Distribution
# Only the train and validation splits are plotted here (the final test split is not).
plt.figure(figsize=(12, 6))
sns.histplot(train_df['text_length'], bins=50, color='blue', alpha=0.5, label='Train')
sns.histplot(test_df['text_length'], bins=50, color='orange', alpha=0.5, label='Validation')
plt.title('Text Length Distribution (Characters)')
plt.xlabel('Text Length (characters)')
plt.ylabel('Count')
plt.legend()
plt.show()

# 5. Character Frequency Analysis
# Join every transcript of all three splits into one string and count its characters.
all_text = pd.concat([train_df['text'], test_df['text'], final_test_df['text']]).str.cat()
char_counts = Counter(all_text)
# Keep the 20 most frequent characters, highest count first.
top_chars = dict(sorted(char_counts.items(), key=lambda x: x[1], reverse=True)[:20])

plt.figure(figsize=(15, 5))
sns.barplot(x=list(top_chars.keys()), y=list(top_chars.values()))
plt.title('Top 20 Most Frequent Characters')
plt.xlabel('Character')
plt.ylabel('Count')
plt.show()

# 6. Data Leakage Check
# Count transcripts whose (normalized) text occurs in more than one split.
train_texts, test_texts, final_test_texts = (
    set(split['text']) for split in (train_df, test_df, final_test_df)
)

overlap_train_test = train_texts.intersection(test_texts)
overlap_train_final = train_texts.intersection(final_test_texts)
overlap_test_final = test_texts.intersection(final_test_texts)

print("\n=== Data Leakage Check ===")
print(f"Common texts between Train-Validation: {len(overlap_train_test)}")
print(f"Common texts between Train-Final Test: {len(overlap_train_final)}")
print(f"Common texts between Validation-Final Test: {len(overlap_test_final)}")

# 7. Audio Waveform Examples (for 3 random samples)
def plot_waveforms(df, title):
    """Plot the waveforms of three randomly drawn rows of ``df`` side by side.

    Rows are drawn with ``np.random.choice`` on the index, so the same row may be
    picked more than once. Audio is loaded at 16 kHz.
    """
    fig = plt.figure(figsize=(15, 3))
    for panel in range(1, 4):
        sample_id = np.random.choice(df.index)
        signal, rate = librosa.load(df.loc[sample_id, 'audio_path'], sr=16000)
        ax = fig.add_subplot(1, 3, panel)
        librosa.display.waveshow(signal, sr=rate, ax=ax)
        ax.set_title(f"Sample {sample_id}\nDuration: {df.loc[sample_id, 'duration']:.2f}s")
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

plot_waveforms(train_df, "Training Set Waveform Examples")
plot_waveforms(test_df, "Validation Set Waveform Examples")

In [None]:
def plot_spectrograms(df, title, sr=16000):
    """Plot waveform (left) and log-frequency dB spectrogram (right) for three random rows.

    Args:
        df: Frame with an ``audio_path`` column.
        title: Figure-level title.
        sr: Sample rate the audio is loaded at.
    """
    fig = plt.figure(figsize=(15, 10))
    for row_number in range(3):
        sample_id = np.random.choice(df.index)
        y, sr = librosa.load(df.loc[sample_id, 'audio_path'], sr=sr)

        wave_ax = fig.add_subplot(3, 2, 2 * row_number + 1)
        librosa.display.waveshow(y, sr=sr, ax=wave_ax)
        wave_ax.set_title(f"Waveform - Sample {sample_id}")

        spec_ax = fig.add_subplot(3, 2, 2 * row_number + 2)
        # STFT magnitude converted to dB relative to its own maximum.
        magnitude = np.abs(librosa.stft(y))
        D = librosa.amplitude_to_db(magnitude, ref=np.max)
        image = librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='log', ax=spec_ax)
        fig.colorbar(image, ax=spec_ax, format='%+2.0f dB')
        spec_ax.set_title(f"Spectrogram - Sample {sample_id}")

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

plot_spectrograms(train_df, "Training Set Audio Examples")

In [None]:
def plot_mfccs(df, title, sr=16000):
    """Plot 13 MFCCs over time for three randomly drawn rows of ``df``.

    Args:
        df: Frame with an ``audio_path`` column.
        title: Figure-level title.
        sr: Sample rate the audio is loaded at.
    """
    plt.figure(figsize=(15, 5))
    for i in range(3):
        idx = np.random.choice(df.index)
        audio_path = df.loc[idx, 'audio_path']
        y, sr = librosa.load(audio_path, sr=sr)

        mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13)

        plt.subplot(1, 3, i+1)
        librosa.display.specshow(mfccs, x_axis='time')
        plt.colorbar()
        plt.title(f"MFCCs - Sample {idx}")

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# The samples are drawn from train_df, so the title names the training set
# (it previously read "Validation Set").
plot_mfccs(train_df, "Training Set MFCC Examples")

In [None]:
def plot_pitch_distribution(df, title, n_samples=100, random_state=None):
    """Plot a KDE of YIN pitch estimates pooled over a random subset of ``df``.

    Args:
        df: Frame with an ``audio_path`` column.
        title: Prefix of the plot title.
        n_samples: Maximum number of rows to analyse; capped at ``len(df)`` so
            small frames no longer make ``DataFrame.sample`` raise.
        random_state: Optional seed forwarded to ``DataFrame.sample``.
    """
    pitches = []
    subset = df.sample(n=min(n_samples, len(df)), random_state=random_state)  # Subsample for speed
    for audio_path in subset['audio_path']:
        y, sr = librosa.load(audio_path, sr=None)  # Keep original SR
        pitch = librosa.yin(y, fmin=50, fmax=2000, sr=sr)
        # Keep only estimates strictly inside the search range.
        valid_pitches = pitch[(pitch > 50) & (pitch < 2000)]
        pitches.extend(valid_pitches)

    if not pitches:
        print(f"No pitch estimates available for {title}")
        return

    plt.figure(figsize=(12, 6))

    # Use KDE for smoother distribution
    sns.kdeplot(pitches, fill=True, log_scale=True)

    # Add musical reference lines
    for note in [55, 110, 220, 440, 880]:  # A1, A2, A3, A4, A5
        plt.axvline(note, color='red', linestyle='--', alpha=0.3)

    plt.title(f'{title} - Pitch Distribution')
    plt.xlabel('Fundamental Frequency (Hz) [Log Scale]')
    plt.ylabel('Density')
    plt.xscale('log')
    plt.xticks([50, 100, 200, 400, 800, 1600],
               ['50', '100', '200', '400', '800', '1600'])
    plt.grid(True, which='both', linestyle='--', alpha=0.5)
    plt.show()

plot_pitch_distribution(train_df, 'Training Set')

In [None]:
def plot_silence_distribution(df, title, top_db=30, sr=16000):
    """Plot the distribution of the per-file silence percentage.

    For every row the audio is loaded, non-silent intervals are detected with
    ``librosa.effects.split`` and the remaining share of the file is counted as
    silence.

    Args:
        df: Frame with an ``audio_path`` column.
        title: Prefix of the plot title.
        top_db: Silence threshold passed to ``librosa.effects.split``. Note that
            ``preprocess_audio`` in this notebook uses 40 by default, so the two
            thresholds differ unless the same value is passed here.
        sr: Sample rate the audio is loaded at.
    """
    silence_percentages = []

    for _, row in df.iterrows():
        y, sr = librosa.load(row['audio_path'], sr=sr)

        # effects.split returns the NON-silent intervals.
        non_silent_intervals = librosa.effects.split(
            y,
            top_db=top_db,
            frame_length=1024,
            hop_length=256
        )
        # Calculate SILENT duration (total - non-silent)
        total_duration = len(y) / sr
        non_silent_duration = sum(end - start for (start, end) in non_silent_intervals) / sr
        silent_duration = total_duration - non_silent_duration

        # Handle edge case: avoid division by zero
        if total_duration > 0:
            silence_percent = (silent_duration / total_duration) * 100
            silence_percentages.append(silence_percent)
        else:
            silence_percentages.append(0)  # Handle 0-duration files

    # Create visualization
    plt.figure(figsize=(12, 6))

    # Histogram (as percent of files) with a KDE overlay.
    sns.histplot(
        silence_percentages,
        bins=30,
        kde=True,
        stat='percent',
        color='skyblue',
        edgecolor='black'
    )

    # Mark mean and median.
    mean_val = np.mean(silence_percentages)
    median_val = np.median(silence_percentages)
    plt.axvline(mean_val, color='red', linestyle='--', label=f'Mean: {mean_val:.1f}%')
    plt.axvline(median_val, color='green', linestyle='--', label=f'Median: {median_val:.1f}%')

    # The subtitle reports the settings actually used ("top_db=30, SR=16kHz" by default).
    plt.title(f'{title} - Silence Percentage Distribution\n(top_db={top_db}, SR={sr / 1000:g}kHz)')
    plt.xlabel('Percentage of Silence in Audio')
    plt.ylabel('Percentage of Samples')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

plot_silence_distribution(train_df, 'Training Set')

In [None]:
def plot_temporal_features(df, title):
    """Plot waveform, RMS energy and spectral rolloff for one random row of ``df``.

    Args:
        df: Frame with an ``audio_path`` column.
        title: Text appended to the waveform panel title.
    """
    idx = np.random.choice(df.index)
    audio_path = df.loc[idx, 'audio_path']
    y, sr = librosa.load(audio_path)  # librosa's default sample rate

    plt.figure(figsize=(15, 8))

    # Waveform
    plt.subplot(3, 1, 1)
    librosa.display.waveshow(y, sr=sr)
    plt.title(f'Waveform - {title}')

    # RMS Energy
    plt.subplot(3, 1, 2)
    S, phase = librosa.magphase(librosa.stft(y))
    rms = librosa.feature.rms(S=S)
    # Pass the actual sample rate: without it the frame times are computed from
    # times_like's own default, which is only right while it equals the load rate.
    times = librosa.times_like(rms, sr=sr)
    plt.plot(times, rms[0])
    plt.title('RMS Energy')

    # Spectral Rolloff
    plt.subplot(3, 1, 3)
    rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)
    plt.plot(times, rolloff[0])
    plt.title('Spectral Rolloff')

    plt.tight_layout()
    plt.show()

plot_temporal_features(train_df, 'Training Sample Temporal Features')

In [None]:
def plot_sample_fft(df, title):
    """Plot the magnitude spectrum (0 Hz to Nyquist) of one random row of ``df``.

    Args:
        df: Frame with ``audio_path`` and ``duration`` columns.
        title: Prefix of the plot title.
    """
    idx = np.random.choice(df.index)
    y, sr = librosa.load(df.loc[idx, 'audio_path'])

    # Real-input FFT: returns only the non-negative frequencies, so no half of the
    # result has to be discarded.
    magnitude = np.abs(np.fft.rfft(y))
    # Exact bin centre frequencies (k * sr / N). The previous
    # np.linspace(0, sr, N) spaced bins by sr / (N - 1) and was slightly off.
    frequency = np.fft.rfftfreq(len(y), d=1.0 / sr)

    plt.figure(figsize=(15, 5))
    plt.plot(frequency, magnitude)
    plt.title(f'{title} - FFT Analysis\nSample {idx} (Duration: {df.loc[idx, "duration"]:.2f}s)')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude')
    plt.grid()
    plt.show()

plot_sample_fft(train_df, 'Training Set')

In [None]:
def plot_average_fft(df, title, n_samples=10, n_fft=2048):
    """Plot the magnitude spectrum averaged over a random subset of clips.

    Each clip is reduced to its time-averaged STFT magnitude, which has
    ``1 + n_fft // 2`` bins whatever the clip's duration. Averaging the raw FFT of
    whole clips only works when every clip has exactly the same number of samples.

    Args:
        df: Frame with an ``audio_path`` column.
        title: Prefix of the plot title.
        n_samples: Maximum number of clips to average; capped at ``len(df)``.
        n_fft: STFT window size, which fixes the frequency resolution.
    """
    clip_spectra = []
    sr = None

    for path in df['audio_path'].sample(min(n_samples, len(df))):
        y, sr = librosa.load(path)
        # |STFT| has shape (1 + n_fft // 2, n_frames); average over the frame axis.
        clip_spectra.append(np.abs(librosa.stft(y, n_fft=n_fft)).mean(axis=1))

    if not clip_spectra:
        print(f"No audio available for {title}")
        return

    # Compute mean spectrum across clips
    avg_magnitude = np.mean(clip_spectra, axis=0)
    frequency = np.fft.rfftfreq(n_fft, d=1.0 / sr)

    plt.figure(figsize=(15, 5))
    plt.plot(frequency, avg_magnitude)
    plt.title(f'{title} - Average FFT Spectrum ({len(clip_spectra)} samples)')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Average Magnitude')
    plt.grid()
    plt.show()

plot_average_fft(train_df, 'Training Set')

In [None]:
# plot_average_fft is already defined in the previous cell; the byte-identical second
# definition that used to live here was redundant. This cell only draws the average
# spectrum again for a fresh random subset.
plot_average_fft(train_df, 'Training Set')

In [None]:
def plot_frequency_bands(df, title, n_samples=100):
    """Box-plot the summed spectral magnitude per frequency band over a clip subset.

    Bands are half-open so no FFT bin is counted twice:
    low = [0, 200) Hz, mid = [200, 2000) Hz, high = [2000 Hz, Nyquist].

    Args:
        df: Frame with an ``audio_path`` column.
        title: Prefix of the plot title.
        n_samples: Maximum number of clips to analyse; capped at ``len(df)`` so
            small frames no longer make ``Series.sample`` raise.
    """
    # Upper edge None means "up to and including the last bin".
    band_edges = {
        'low': (0, 200),
        'mid': (200, 2000),
        'high': (2000, None),
    }
    band_energies = {band: [] for band in band_edges}

    for path in df['audio_path'].sample(min(n_samples, len(df))):  # Subsample
        y, sr = librosa.load(path)
        magnitude = np.abs(np.fft.rfft(y))
        frequency = np.fft.rfftfreq(len(y), d=1.0 / sr)  # Same length as magnitude

        for band, (low, high) in band_edges.items():
            in_band = frequency >= low
            if high is not None:
                in_band &= frequency < high
            band_energies[band].append(np.sum(magnitude[in_band]))

    plt.figure(figsize=(10, 5))
    sns.boxplot(data=pd.DataFrame(band_energies))
    plt.title(f'{title} - Frequency Band Energy Distribution')
    plt.ylabel('Energy Sum')
    plt.show()


plot_frequency_bands(train_df, 'Training Set')

# Model Training 

In [None]:
from datasets import Dataset

# preserve_index=False: the frames carry a non-default index after sampling/dropna;
# by default from_pandas would store it as an extra "__index_level_0__" column.
train_dataset = Dataset.from_pandas(train_df[['audio_path', 'text']], preserve_index=False)
test_dataset = Dataset.from_pandas(test_df[['audio_path', 'text']], preserve_index=False)

In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import WhisperFeatureExtractor

# Standalone feature extractor loaded from the Tamil Whisper checkpoint.
feature_extractor = WhisperFeatureExtractor.from_pretrained("vasista22/whisper-tamil-large-v2")
from transformers import WhisperTokenizer

# Standalone tokenizer, configured for Tamil transcription.
tokenizer = WhisperTokenizer.from_pretrained("vasista22/whisper-tamil-large-v2", language="ta", task="transcribe")
from transformers import WhisperProcessor

# The processor exposes both pieces (processor.feature_extractor / processor.tokenizer)
# and is the object the later cells use.
processor = WhisperProcessor.from_pretrained("vasista22/whisper-tamil-large-v2", language="ta", task="transcribe")

In [None]:
import librosa

def prepare_dataset(batch):
    """Turn one example into model inputs: audio features plus label token ids.

    Reads ``batch["audio_path"]`` at 16 kHz, runs it through the processor and
    stores the first (only) row of the result as ``input_features``; tokenizes
    ``batch["text"]`` into ``labels``. Returns the same mapping with both keys added.
    """
    waveform, sample_rate = librosa.load(batch["audio_path"], sr=16000)
    encoded_audio = processor(waveform, sampling_rate=sample_rate, return_tensors="pt")
    token_ids = processor.tokenizer(batch["text"]).input_ids

    batch["input_features"] = encoded_audio.input_features[0]
    batch["labels"] = token_ids
    return batch

# Apply the conversion example by example to both datasets.
train_dataset = train_dataset.map(prepare_dataset)
test_dataset = test_dataset.map(prepare_dataset)

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech examples into a padded batch of input features and labels.

    Audio features are padded by ``processor.feature_extractor`` and label ids by
    ``processor.tokenizer``. Padded label positions are set to -100, and a leading
    BOS column is dropped when every row starts with the tokenizer's BOS id.
    """

    # Object exposing ``feature_extractor`` and ``tokenizer``, each with a ``pad`` method.
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        tokenizer = self.processor.tokenizer

        # Audio and text need different padding, so separate them first.
        audio_inputs = []
        label_inputs = []
        for example in features:
            audio_inputs.append({"input_features": example["input_features"]})
            label_inputs.append({"input_ids": example["labels"]})

        batch = self.processor.feature_extractor.pad(audio_inputs, return_tensors="pt")
        padded_labels = tokenizer.pad(label_inputs, return_tensors="pt")

        # Positions outside the attention mask are padding: mark them with -100.
        is_padding = padded_labels.attention_mask.ne(1)
        labels = padded_labels["input_ids"].masked_fill(is_padding, -100)

        # Drop the first column only if it is BOS in every row of the batch.
        starts_with_bos = (labels[:, 0] == tokenizer.bos_token_id).all().cpu().item()
        if starts_with_bos:
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

# Initialize collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [None]:
train_dataset

In [None]:
import evaluate

# Scorers loaded through `evaluate`: `metric` is "wer", `cer_metric` is "cer".
metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")


In [None]:

def compute_metrics(pred):
    """Decode predictions and labels and return WER and CER as percentages.

    ``pred.predictions`` may be a tuple (its first element is used) and may have
    more than two dimensions, in which case the argmax over the last axis is
    taken. ``pred.label_ids`` is modified in place: -100 entries are replaced by
    the tokenizer's pad id before decoding.

    NOTE: this reuses the name of the earlier DataFrame helper ``compute_metrics(df)``
    and replaces it from this cell onwards.

    Returns:
        ``{"wer": ..., "cer": ...}`` with both values multiplied by 100.
    """
    predictions = pred.predictions
    references = pred.label_ids

    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if len(predictions.shape) > 2:
        predictions = predictions.argmax(axis=-1)

    references[references == -100] = processor.tokenizer.pad_token_id

    decoded = {
        "predictions": processor.batch_decode(predictions, skip_special_tokens=True),
        "references": processor.batch_decode(references, skip_special_tokens=True),
    }

    return {
        "wer": 100 * metric.compute(**decoded),
        "cer": 100 * cer_metric.compute(**decoded),
    }

In [None]:
from transformers import WhisperForConditionalGeneration

from transformers import BitsAndBytesConfig

# Load the checkpoint with 8-bit weights. The quantization settings are passed as a
# BitsAndBytesConfig; the bare `load_in_8bit=True` keyword is the deprecated spelling.
model = WhisperForConditionalGeneration.from_pretrained(
    "vasista22/whisper-tamil-large-v2",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
    use_safetensors=True,
)


In [None]:
# Clear the forced decoder ids and the suppressed-token list stored on the model config.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

In [None]:
from peft import prepare_model_for_kbit_training

# Run PEFT's k-bit preparation helper on the quantized model before adding adapters.
model = prepare_model_for_kbit_training(model)


In [None]:
from peft import LoraConfig, LoraModel, PeftModel, get_peft_model

# DoRA-style LoRA adapters on the attention query/key/value projections.
# "dense" was dropped from target_modules: Whisper has no submodule with that name
# (its other linear layers are out_proj, fc1 and fc2), so the entry matched nothing.
# Add those names here if wider adaptation is wanted.
config = LoraConfig(r=64,
                    lora_alpha=128,
                    target_modules=["k_proj", "q_proj", "v_proj"],
                    lora_dropout=0.05,
                    bias="none",
                    use_dora=True)

model = get_peft_model(model, config)
model.print_trainable_parameters()

In [None]:
from transformers import Seq2SeqTrainingArguments

# Training configuration for the adapter fine-tuning run.
# NOTE(review): predict_with_generate is not set here, so evaluation predictions
# presumably arrive as logits (compute_metrics handles that via argmax) - confirm
# this teacher-forced evaluation is intended.
training_args = Seq2SeqTrainingArguments(
    output_dir="temp",  
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,  
    learning_rate=1e-5,
    warmup_steps=50,
    num_train_epochs=3,
    eval_strategy="epoch",  # evaluate once per epoch
    fp16=True,
    per_device_eval_batch_size=8,
    generation_max_length=128,
    logging_steps=25,
    remove_unused_columns=False,  # keep dataset columns; the collator picks the keys it needs
    label_names=["labels"],  
    report_to="none", 
    logging_dir="./logs" 
)


In [None]:
from transformers import Seq2SeqTrainer

import inspect

# Recent transformers releases take the processor through `processing_class`; older
# ones only accept the (since deprecated) `tokenizer` argument. Use whichever the
# installed Seq2SeqTrainer declares so the cell works on both.
_trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
_processor_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,  # validation split
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_processor_kwarg: processor.feature_extractor},
)
# Turn off the decoder key/value cache for training.
model.config.use_cache = False

In [None]:
import os
# Set the WANDB_DISABLED environment variable before training starts.
os.environ["WANDB_DISABLED"] = "true"


In [None]:
trainer.train() 

# Prediction Generation

In [None]:
# Save the trained model (trainer.model) and the processor to the same output directory.
trainer.model.save_pretrained("/kaggle/working/whisper-tamil-lora")
processor.save_pretrained("/kaggle/working/whisper-tamil-lora")

# Load the final test file list written by the split cell.
# NOTE: this rebinds test_df, which until here held the validation split.
test_df = pd.read_csv('/kaggle/working/final_test_files.csv')
audio_dir = '/kaggle/input/tamil-dataset-for-asr/Training/Audio'
text_dir = '/kaggle/input/tamil-dataset-for-asr/Training/Transcripts'



In [None]:
from peft import PeftModel, PeftConfig

# Reload the base checkpoint and attach the saved adapter weights for inference.
base_model = WhisperForConditionalGeneration.from_pretrained("vasista22/whisper-tamil-large-v2",
                                                            device_map="auto",
                                                            use_safetensors=True)
model = PeftModel.from_pretrained(base_model, "/kaggle/working/whisper-tamil-lora")
processor = WhisperProcessor.from_pretrained("/kaggle/working/whisper-tamil-lora")

# device_map="auto" has already placed the weights, so the model is not moved again
# with .to(...): relocating a dispatched model can leave modules and inputs on
# different devices. Inputs are instead sent to the device of the first parameter.
model.eval()
device = next(model.parameters()).device




In [None]:
# Function to get actual transcript
def get_actual_transcript(file_id):
    """Return the reference transcript for ``file_id``, normalized like the training labels.

    The file ``<text_dir>/<file_id>.txt`` is read as UTF-8 (a leading byte-order
    mark is dropped) and passed through ``normalize_tamil_text``, the same
    normalization ``load_transcript`` applies to the training text. Scoring
    against the raw file would count punctuation and tags as errors.
    """
    text_path = os.path.join(text_dir, f"{file_id}.txt")
    with open(text_path, "r", encoding="utf-8-sig") as f:
        raw_text = f.read().strip()
    return normalize_tamil_text(raw_text)

# Function to calculate WER and CER
def calculate_metrics(pred_text, actual_text):
    """Score one prediction against one reference.

    Returns:
        ``{"wer": ..., "cer": ...}`` holding the raw values returned by the two
        metric objects (not multiplied by 100).
    """
    predictions = [pred_text]
    references = [actual_text]
    wer_score = metric.compute(predictions=predictions, references=references)
    cer_score = cer_metric.compute(predictions=predictions, references=references)
    return {"wer": wer_score, "cer": cer_score}


In [None]:
!pip install tabulate

In [None]:
# Prediction and evaluation
# `tabulate` was installed in the previous cell but never imported, so the table
# print at the end of this cell raised NameError.
from tabulate import tabulate

results = []

for file_id in test_df['file_id']:
    # Load audio
    audio_path = os.path.join(audio_dir, f"{file_id}.wav")
    audio, sr = librosa.load(audio_path, sr=16000)

    # Process audio into model input features on the model's device
    inputs = processor(
        audio,
        sampling_rate=16000,
        return_tensors="pt"
    ).input_features.to(device)

    # Generate prediction (no gradients needed). Whisper's main input is named
    # `input_features`; the generic `inputs=` keyword is the deprecated spelling.
    with torch.no_grad():
        generated_ids = model.generate(input_features=inputs, max_length=448)

    # Decode prediction
    pred_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Get actual text
    actual_text = get_actual_transcript(file_id)

    # Calculate metrics
    metrics = calculate_metrics(pred_text, actual_text)

    print("Current file ID : ",file_id)

    results.append({
        "file_id": file_id,
        "prediction": pred_text,
        "actual": actual_text,
        "WER": f"{metrics['wer']*100:.2f}%",
        "CER": f"{metrics['cer']*100:.2f}%"
    })

# Create and display results table
results_df = pd.DataFrame(results)
print("\nEvaluation Results:")
print(tabulate(results_df[["prediction", "actual", "WER", "CER"]],
             headers=["Prediction", "Actual", "WER", "CER"],
             tablefmt="pretty"))

In [None]:
results_df

In [None]:
results_df['actual'][1]

In [None]:
results_df['prediction'][1]