In [None]:
# Show attached NVIDIA GPUs (the saved output shows none on this runtime)
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [None]:
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.checkpoint import checkpoint
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

In [None]:
# Confirm PyTorch can see a CUDA device; False means everything below runs on CPU
torch.cuda.is_available()

False

In [None]:
# Mount Google Drive so the project data under /content/drive/MyDrive is reachable
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Work from the project folder on Drive
%cd /content/drive/MyDrive/DEPI_Project

/content/drive/MyDrive/DEPI_Project


In [None]:
# Confirm the working directory
!pwd

/content/drive/MyDrive/DEPI_Project


In [None]:
# Google Drive locations of the project data; edit these if your Drive layout differs.
# NOTE(review): in the visible code FEATURE_DIR is only copied into Config.feature_dir;
# the dataset reads per-row paths from the CSV's 'feature_path' column instead.
FEATURE_DIR = "/content/drive/MyDrive/DEPI_Project/featured_extracted/cv-other-dev"  # Update this
CSV_PATH = "/content/drive/MyDrive/DEPI_Project/mel_processed_dataset.csv"  # Update this

In [None]:
# Release cached CUDA memory held by PyTorch's allocator (no-op when CUDA is unavailable)
torch.cuda.empty_cache()

In [None]:
class Config:
    """Configuration for the audio transcription model.

    Every setting is a class attribute, so a ``Config()`` instance reads these
    defaults until an attribute is overridden on the instance.
    """
    # Data paths (FEATURE_DIR / CSV_PATH are defined in an earlier cell)
    feature_dir = FEATURE_DIR
    csv_path = CSV_PATH

    # Text processing: index 0 = <sos>, 1 = <eos>, then lowercase letters,
    # apostrophe and space. NOTE: pad_sequence pads targets with 0, so padded
    # positions share the <sos> id and must be masked out of the loss.
    vocab = ['<sos>', '<eos>'] + list("abcdefghijklmnopqrstuvwxyz' ")
    char_to_idx = {char: idx for idx, char in enumerate(vocab)}
    idx_to_char = {idx: char for idx, char in enumerate(vocab)}
    vocab_size = len(vocab)
    sos_idx = char_to_idx['<sos>']
    eos_idx = char_to_idx['<eos>']

    # Size of the per-frame feature vector
    feature_dim = 128

    # Model parameters
    hidden_size = 256      # LSTM hidden size (per direction in the encoder)
    num_layers = 2         # stacked LSTM layers in both encoder and decoder
    bidirectional = True   # applies to the encoder only
    dropout = 0.3

    # Training parameters
    batch_size = 16
    # 1e-3 is a conventional AdamW starting point; the previous 0.722081 is
    # orders of magnitude too large (main() also sets 0.001 explicitly).
    learning_rate = 1e-3
    weight_decay = 1e-5
    num_epochs = 100

    # Label smoothing for nn.CrossEntropyLoss (reduces over-confidence)
    label_smoothing = 0.1

    # Teacher forcing decays from the start value towards the end value over epochs
    teacher_forcing_ratio_start = 1.0
    teacher_forcing_ratio_end = 0.5

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Normalise feature matrices to a [time_steps, feature_dim] layout
def preprocess_features(features, feature_dim):
    """Return ``features`` laid out as ``[time_steps, feature_dim]`` when the layout is inferable.

    Args:
        features: 1-D or 2-D tensor of frame features.
        feature_dim: expected size of the feature axis.

    Returns:
        * 2-D ``[T, feature_dim]`` (including the ambiguous square case): unchanged.
        * 2-D ``[feature_dim, T]``: transposed to ``[T, feature_dim]``.
        * 1-D ``[feature_dim]``: reshaped to ``[1, feature_dim]``.
        * anything else: returned unchanged; callers must handle it.
    """
    # Check the rank first: indexing shape[1] on a 1-D tensor would raise IndexError.
    ndim = len(features.shape)

    if ndim == 2:
        # Axis 1 already holds features; this also covers the square
        # [feature_dim, feature_dim] case, where rows are treated as time steps.
        if features.shape[1] == feature_dim:
            return features
        if features.shape[0] == feature_dim:
            # [feature_dim, time] -> [time, feature_dim]
            return features.transpose(0, 1)
    elif ndim == 1 and features.shape[0] == feature_dim:
        # A single frame
        return features.reshape(1, -1)

    # Layout could not be determined
    return features


In [None]:
class AudioFeatureDataset(Dataset):
    """Pairs pre-extracted feature files (``.npy``) with their transcripts.

    The CSV must contain ``feature_path`` and ``transcript`` columns. Rows are
    split 80/20 into train/validation with a fixed seed (``random_state=42``),
    so a train and a validation instance built from the same CSV are disjoint.
    Rows whose transcript is missing, or whose feature file does not exist,
    are dropped with a warning.

    Each item is ``(features, text_indices, num_frames, num_tokens)``, where
    ``text_indices`` is ``[<sos>] + chars + [<eos>]``. Characters outside the
    vocabulary are skipped.
    """

    def __init__(self, csv_path, config, train=True):
        self.config = config
        self.train = train

        # Read the CSV once; both splits are derived from this frame.
        full_df = pd.read_csv(csv_path)

        if 'feature_path' not in full_df.columns or 'transcript' not in full_df.columns:
            raise ValueError("CSV file must contain 'feature_path' and 'transcript' columns")

        # Deterministic 80/20 split: validation is the complement of the same sample.
        train_df = full_df.sample(frac=0.8, random_state=42)
        if train:
            split_df = train_df
        else:
            split_df = full_df[~full_df.index.isin(train_df.index)]

        # A missing transcript would otherwise become the literal label "nan" via str().
        has_text = split_df['transcript'].notna()
        if not has_text.all():
            print(f"Warning: dropping {int((~has_text).sum())} rows with missing transcripts")
        split_df = split_df[has_text]

        # Vectorised existence check; unlike rebuilding the frame from iterrows()
        # rows, boolean indexing keeps the original column dtypes.
        exists = split_df['feature_path'].map(
            lambda path: isinstance(path, str) and os.path.exists(path)
        ).astype(bool)
        for missing_path in split_df.loc[~exists, 'feature_path']:
            print(f"Warning: Feature file not found: {missing_path}")

        if not exists.any():
            raise ValueError("No valid feature files found. Check your feature paths in the CSV file.")

        self.df = split_df[exists].reset_index(drop=True)
        print(f"Dataset loaded with {len(self.df)} valid samples")

    def augment_features(self, features):
        """SpecAugment-style augmentation, applied only when ``self.train`` is True.

        With probability 0.7 the sample is augmented. Each of these is then
        applied independently with probability 0.5:
          * zero out a ~5% block of time steps,
          * zero out a ~5% block of feature channels,
          * add Gaussian noise (std 0.01).

        The masking writes into ``features`` in place. That is safe here because
        ``__getitem__`` builds a fresh tensor for every call.
        """
        if self.train and np.random.random() < 0.7:
            # Time masking
            if np.random.random() < 0.5:
                time_mask_size = max(1, int(features.shape[0] * 0.05))
                start = np.random.randint(0, max(1, features.shape[0] - time_mask_size))
                features[start:start+time_mask_size, :] = 0

            # Feature (frequency) masking
            if np.random.random() < 0.5:
                freq_mask_size = max(1, int(features.shape[1] * 0.05))
                start = np.random.randint(0, max(1, features.shape[1] - freq_mask_size))
                features[:, start:start+freq_mask_size] = 0

            # Slight Gaussian noise
            if np.random.random() < 0.5:
                noise = torch.randn_like(features) * 0.01
                features = features + noise

        return features

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load, reshape and (optionally) augment one sample, and encode its transcript."""
        row = self.df.iloc[idx]
        transcript = str(row['transcript']).lower()

        # torch.tensor copies the loaded array into a new float32 tensor.
        features = torch.tensor(np.load(row['feature_path']), dtype=torch.float32)
        features = preprocess_features(features, self.config.feature_dim)
        features = self.augment_features(features)

        char_to_idx = self.config.char_to_idx
        text_indices = torch.tensor(
            [self.config.sos_idx]
            + [char_to_idx[c] for c in transcript if c in char_to_idx]
            + [self.config.eos_idx],
            dtype=torch.long,
        )

        return features, text_indices, features.shape[0], len(text_indices)

In [None]:
def collate_fn(batch):
    """Pad a list of dataset items into batch tensors.

    The list is sorted in place, longest feature sequence first, because the
    encoder packs with ``enforce_sorted=True``. Features and token ids are
    zero-padded along the time axis.

    Returns:
        ``(features_padded, text_padded, feature_lengths, text_lengths)``.
    """
    batch.sort(key=lambda sample: sample[2], reverse=True)
    columns = list(zip(*batch))
    feature_seqs, token_seqs = columns[0], columns[1]

    return (
        pad_sequence(feature_seqs, batch_first=True),
        pad_sequence(token_seqs, batch_first=True),
        torch.tensor(columns[2]),
        torch.tensor(columns[3]),
    )


In [None]:
class AudioEncoder(nn.Module):
    """Normalises padded feature sequences and encodes them with a (bi)LSTM.

    Args:
        input_dim: feature size per frame.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        bidirectional: if True, ``output_dim == 2 * hidden_dim``.
        dropout: inter-layer LSTM dropout (disabled when ``num_layers == 1``).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, bidirectional, dropout):
        super(AudioEncoder, self).__init__()

        # Input normalization
        self.layer_norm = nn.LayerNorm(input_dim)
        self.batch_norm = nn.BatchNorm1d(input_dim)

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )

        self.output_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init for every matrix-shaped parameter."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x, lengths):
        """Encode a padded batch.

        Args:
            x: ``(batch, time, input_dim)`` padded features.
            lengths: valid frame count per item, sorted descending
                (``enforce_sorted=True``). Moved to CPU for packing.

        Returns:
            ``(batch, max(lengths), output_dim)``, with zeros at padded steps.
        """
        x = self.layer_norm(x)

        # BatchNorm only over real frames. Counting padding would bias the batch
        # statistics, and the running stats reused at inference, toward the
        # padding values.
        _, time_steps, _ = x.shape
        lengths_cpu = torch.as_tensor(lengths, dtype=torch.long).cpu()
        valid = (torch.arange(time_steps)[None, :] < lengths_cpu[:, None]).to(x.device)
        normalized = torch.zeros_like(x)
        normalized[valid] = self.batch_norm(x[valid])

        packed = nn.utils.rnn.pack_padded_sequence(
            normalized, lengths_cpu, batch_first=True, enforce_sorted=True
        )

        # Gradient checkpointing: recompute the LSTM during backward to save activation memory.
        packed_outputs, _ = checkpoint(self.lstm, packed, use_reentrant=False)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs, batch_first=True)

        return outputs


In [None]:
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention with a softmax temperature.

    Args:
        encoder_dim: feature size of the encoder outputs.
        decoder_dim: size of the decoder hidden state.
        temperature: scores are divided by this before the softmax. Values
            below 1 sharpen the distribution. Defaults to 0.5, the previously
            hard-coded value.
    """

    def __init__(self, encoder_dim, decoder_dim, temperature=0.5):
        super(Attention, self).__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.encoder_attn = nn.Linear(encoder_dim, decoder_dim)
        self.decoder_attn = nn.Linear(decoder_dim, decoder_dim)
        self.full_attn = nn.Linear(decoder_dim, 1)
        self.temperature = temperature
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init for every matrix-shaped parameter."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        """Compute the attention context.

        Args:
            decoder_hidden: ``(batch, decoder_dim)``, or ``(decoder_dim,)`` for one item.
            encoder_outputs: ``(batch, seq_len, encoder_dim)``.
            mask: optional bool ``(batch, seq_len)``. False marks padded encoder
                steps, which receive exactly zero weight. Each row needs at least
                one True. ``None`` attends to every step.

        Returns:
            ``(context, weights)``: ``(batch, encoder_dim)`` and ``(batch, seq_len, 1)``.
        """
        if decoder_hidden.dim() == 1:
            decoder_hidden = decoder_hidden.unsqueeze(0)

        # [batch, seq_len, encoder_dim] -> [batch, seq_len, decoder_dim]
        encoder_transform = self.encoder_attn(encoder_outputs)
        # [batch, decoder_dim] -> [batch, 1, decoder_dim] (broadcast over seq_len)
        decoder_transform = self.decoder_attn(decoder_hidden).unsqueeze(1)

        attn_scores = self.full_attn(torch.tanh(encoder_transform + decoder_transform))
        attn_scores = attn_scores / self.temperature

        if mask is not None:
            keep = mask.to(device=attn_scores.device, dtype=torch.bool).unsqueeze(-1)
            attn_scores = attn_scores.masked_fill(~keep, float('-inf'))

        # Softmax over the sequence axis
        attn_weights = torch.softmax(attn_scores, dim=1)

        # [batch, 1, seq_len] @ [batch, seq_len, encoder_dim] -> [batch, 1, encoder_dim]
        context = torch.bmm(attn_weights.transpose(1, 2), encoder_outputs)

        return context.squeeze(1), attn_weights


In [None]:
class DecoderRNN(nn.Module):
    """Attention LSTM decoder that produces one token distribution per call.

    Args:
        vocab_size: number of output tokens.
        hidden_dim: embedding size and LSTM hidden size.
        encoder_dim: feature size of the encoder outputs (attention context size).
        num_layers: stacked LSTM layers.
        dropout: dropout on the LSTM output and inside the projection head.
    """

    def __init__(self, vocab_size, hidden_dim, encoder_dim, num_layers, dropout):
        super(DecoderRNN, self).__init__()

        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.lstm = nn.LSTM(
            input_size=hidden_dim + encoder_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        self.attention = Attention(encoder_dim, hidden_dim)

        # Projection head with dropout for better generalization
        self.dropout = nn.Dropout(dropout)
        self.output_projection = nn.Sequential(
            nn.Linear(hidden_dim + encoder_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, vocab_size)
        )

        self.hidden_dim = hidden_dim
        self.encoder_dim = encoder_dim
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init for every matrix-shaped parameter."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            decoder_input: token ids, ``(batch,)`` or ``(batch, 1)``.
            decoder_hidden: ``(h, c)``, each ``(num_layers, batch, hidden_dim)``.
            encoder_outputs: ``(batch, seq_len, encoder_dim)``.

        Returns:
            ``(logits, (h, c), attention_weights)``. ``logits`` is
            ``(batch, 1, vocab_size)``.
        """
        if decoder_input.dim() == 1:
            decoder_input = decoder_input.unsqueeze(1)

        # [batch, 1, hidden_dim]
        embedded = self.embedding(decoder_input)

        # Attend with the previous step's top-layer hidden state
        context, attention_weights = self.attention(decoder_hidden[0][-1], encoder_outputs)
        if context.dim() == 2:
            # [batch, encoder_dim] -> [batch, 1, encoder_dim]
            context = context.unsqueeze(1)

        lstm_input = torch.cat([embedded, context], dim=-1)
        output, hidden = self.lstm(lstm_input, decoder_hidden)
        output = self.dropout(output)

        # output [batch, 1, hidden_dim] and context [batch, 1, encoder_dim] are both 3-D
        output = self.output_projection(torch.cat([output, context], dim=-1))

        return output, hidden, attention_weights

    def init_hidden(self, batch_size):
        """Zero ``(h0, c0)`` on the same device and dtype as the decoder parameters."""
        reference = next(self.parameters())
        shape = (self.lstm.num_layers, batch_size, self.hidden_dim)
        return (reference.new_zeros(shape), reference.new_zeros(shape))


In [None]:
class Seq2SeqModel(nn.Module):
    """Attention encoder-decoder mapping audio feature frames to character ids.

    Pipeline: input BatchNorm (valid frames only) -> AudioEncoder -> DecoderRNN.
    """

    def __init__(self, config):
        super(Seq2SeqModel, self).__init__()

        # Per-channel batch normalization of the raw input features
        self.batch_norm = nn.BatchNorm1d(config.feature_dim)

        self.encoder = AudioEncoder(
            input_dim=config.feature_dim,
            hidden_dim=config.hidden_size,
            num_layers=config.num_layers,
            bidirectional=config.bidirectional,
            dropout=config.dropout
        )

        encoder_dim = config.hidden_size * 2 if config.bidirectional else config.hidden_size

        # NOTE(review): not used by forward()/predict(); kept so the state_dict
        # keys of existing checkpoints still match.
        self.encoder_projection = nn.Linear(encoder_dim, config.hidden_size)

        self.decoder = DecoderRNN(
            vocab_size=config.vocab_size,
            hidden_dim=config.hidden_size,
            encoder_dim=encoder_dim,
            num_layers=config.num_layers,
            dropout=config.dropout
        )

        self.config = config

    def _normalize_features(self, features, feature_lengths):
        """Apply ``self.batch_norm`` to valid frames only; padded frames become 0.

        Padding is excluded so it does not skew the batch (or running) statistics.
        """
        _, time_steps, _ = features.shape
        lengths = torch.as_tensor(feature_lengths, dtype=torch.long).cpu()
        valid = (torch.arange(time_steps)[None, :] < lengths[:, None]).to(features.device)
        normalized = torch.zeros_like(features)
        normalized[valid] = self.batch_norm(features[valid])
        return normalized

    def forward(self, features, target_texts, feature_lengths, text_lengths, teacher_forcing_ratio=0.9):
        """Decode a padded batch step by step, optionally with teacher forcing.

        Args:
            features: ``(batch, time, feature_dim)`` padded input features.
            target_texts: ``(batch, max_text_length)`` padded target ids.
            feature_lengths: valid frame counts, sorted descending.
            text_lengths: unused here; accepted for a uniform call signature.
            teacher_forcing_ratio: per-step probability, shared by the whole
                batch, of feeding the ground-truth token instead of the model's
                top-1 prediction.

        Returns:
            ``(batch, max_text_length, vocab_size)`` logits.
        """
        batch_size = features.size(0)
        max_text_length = target_texts.size(1)
        device = features.device

        features = self._normalize_features(features, feature_lengths)
        encoder_outputs = self.encoder(features, feature_lengths)
        decoder_hidden = self.decoder.init_hidden(batch_size)

        # Start from the configured <sos> id (same as predict())
        decoder_input = torch.full((batch_size, 1), self.config.sos_idx, dtype=torch.long, device=device)
        outputs = torch.zeros(batch_size, max_text_length, self.config.vocab_size, device=device)

        for t in range(max_text_length):
            output, decoder_hidden, _ = self.decoder.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            outputs[:, t:t+1] = output

            use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio
            if use_teacher_forcing and t < max_text_length - 1:
                decoder_input = target_texts[:, t:t+1]
            else:
                # top-1 ids (batch, 1, 1) -> (batch, 1) for the next step
                _, topi = output.topk(1)
                decoder_input = topi.squeeze(-1).detach()

        return outputs

    def predict(self, features, feature_lengths, max_text_length=200):
        """Greedy decoding of a single utterance.

        Args:
            features: ``(1, time, feature_dim)``.
            feature_lengths: frame count(s) for packing.
            max_text_length: maximum number of generated tokens.

        Returns:
            ``(predicted_indices, attention_weights_list)``. The indices include
            the final ``<eos>`` id when one is produced.

        Raises:
            ValueError: if the batch size is not 1.
        """
        if features.size(0) != 1:
            raise ValueError(f"predict() decodes one utterance at a time; got batch size {features.size(0)}")

        features = self._normalize_features(features, feature_lengths)
        encoder_outputs = self.encoder(features, feature_lengths)
        decoder_hidden = self.decoder.init_hidden(1)

        decoder_input = torch.full((1, 1), self.config.sos_idx, dtype=torch.long, device=features.device)
        predicted_indices = []
        attention_weights_list = []

        for _ in range(max_text_length):
            output, decoder_hidden, attention_weights = self.decoder.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )

            _, topi = output.topk(1)  # (1, 1, 1)
            predicted_index = topi.item()
            predicted_indices.append(predicted_index)
            attention_weights_list.append(attention_weights)

            if predicted_index == self.config.eos_idx:
                break

            # Reshape to (batch, 1): a 3-D input would make the embedding 4-D
            # and break the concatenation with the attention context.
            decoder_input = topi.reshape(1, 1).detach()

        return predicted_indices, attention_weights_list


In [None]:
def beam_search_decode(model, features, feature_lengths, beam_width=5, max_length=200, config=None):
    """Decode one utterance with length-normalised beam search.

    Args:
        model: a ``Seq2SeqModel``-like object exposing ``batch_norm``,
            ``encoder``, ``decoder`` (``init_hidden`` / ``forward_step``) and
            ``config``.
        features: ``(1, time, feature_dim)`` on the model's device. Only batch
            size 1 is supported (the first decoder row is scored).
        feature_lengths: frame counts for packing (CPU tensor or list).
        beam_width: number of hypotheses kept per step.
        max_length: maximum number of decoding steps.
        config: source of ``sos_idx`` / ``eos_idx``; defaults to ``model.config``.

    Returns:
        Token ids of the best hypothesis without the leading ``<sos>``. A
        finished hypothesis keeps its trailing ``<eos>`` id.
    """
    model.eval()
    cfg = config if config is not None else model.config
    sos_idx, eos_idx = cfg.sos_idx, cfg.eos_idx
    batch_size = features.shape[0]

    # Seq2SeqModel normalises inputs with model.batch_norm before the encoder.
    # Skipping it here would give the encoder a different input distribution
    # at inference than during training.
    input_norm = getattr(model, 'batch_norm', None)
    if input_norm is not None:
        _, time_steps, feature_dim = features.shape
        features = input_norm(features.reshape(-1, feature_dim)).reshape(batch_size, time_steps, feature_dim)

    encoder_outputs = model.encoder(features, feature_lengths)

    def normalized_score(beam):
        # Length-normalised log-probability (the <sos> token counts toward length)
        return beam['score'] / len(beam['sequence'])

    beams = [{'sequence': [sos_idx], 'score': 0.0, 'hidden': model.decoder.init_hidden(batch_size)}]
    finished_beams = []

    for _ in range(max_length):
        candidates = []

        for beam in beams:
            last_token = torch.tensor([beam['sequence'][-1]], dtype=torch.long, device=features.device)
            output, hidden, _ = model.decoder.forward_step(last_token, beam['hidden'], encoder_outputs)

            # forward_step yields (batch, 1, vocab); take a 1-D vocab row.
            # output[0] alone would still be (1, vocab), making per-token .item() fail.
            logits = output.reshape(-1, output.size(-1))[0]
            log_probs = F.log_softmax(logits, dim=-1)
            top_log_probs, top_tokens = log_probs.topk(min(beam_width, log_probs.size(-1)))

            for log_prob, token in zip(top_log_probs.tolist(), top_tokens.tolist()):
                candidates.append({
                    'sequence': beam['sequence'] + [token],
                    'score': beam['score'] + log_prob,
                    # The hidden state is never modified in place, so sibling
                    # candidates can share it.
                    'hidden': hidden,
                })

        # Prune globally first, then retire <eos> hypotheses. Only endings that
        # survive the global top-k count as finished, so low-ranked <eos>
        # expansions no longer trigger early stopping.
        candidates.sort(key=normalized_score, reverse=True)
        beams = []
        for candidate in candidates[:beam_width]:
            if candidate['sequence'][-1] == eos_idx:
                finished_beams.append(candidate)
            else:
                beams.append(candidate)

        if not beams:
            break
        # Stop once enough hypotheses finished and no active one is currently ahead
        if len(finished_beams) >= beam_width:
            best_finished = max(normalized_score(b) for b in finished_beams)
            best_active = max(normalized_score(b) for b in beams)
            if best_finished >= best_active:
                break

    pool = finished_beams if finished_beams else beams
    best_beam = max(pool, key=normalized_score)

    return best_beam['sequence'][1:]  # drop <sos>



In [None]:
def train(model, train_loader, criterion, optimizer, device, epoch, config, clip=1.0):
    """Run one training epoch and return the mean per-batch loss.

    Teacher forcing decays geometrically with the epoch:
    ``max(config.teacher_forcing_ratio_end, config.teacher_forcing_ratio_start * 0.95 ** epoch)``.
    Padded target positions are excluded from the loss with a length mask.
    Gradients are clipped to a global norm of ``clip`` before each step.

    Returns:
        Mean loss over batches (0.0 for an empty loader).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Curriculum learning: gradually reduce teacher forcing
    teacher_forcing_ratio = max(
        config.teacher_forcing_ratio_end,
        config.teacher_forcing_ratio_start * (0.95 ** epoch)
    )

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")

    for features, target_texts, feature_lengths, text_lengths in progress_bar:
        features = features.to(device)
        target_texts = target_texts.to(device)

        optimizer.zero_grad()

        outputs = model(features, target_texts, feature_lengths, text_lengths,
                        teacher_forcing_ratio=teacher_forcing_ratio)

        # True for real tokens, False for padding: [batch, max_text_length]
        positions = torch.arange(target_texts.size(1), device=device)
        mask = positions.unsqueeze(0) < torch.as_tensor(text_lengths, device=device).unsqueeze(1)

        # Boolean indexing flattens in row-major order: [N, vocab] vs [N]
        loss = criterion(outputs[mask], target_texts[mask])

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix({"loss": batch_loss})

    return total_loss / max(num_batches, 1)

In [None]:
def evaluate(model, eval_loader, criterion, device, config):
    """Compute the mean masked loss with free-running decoding and print an error analysis.

    Decoding uses ``teacher_forcing_ratio=0``, so the loss reflects the model's
    own predictions. Up to 5 samples per batch are converted to text, with the
    ``<sos>``/``<eos>`` markers removed so they don't inflate accuracy, and
    passed to ``analyze_predictions``.

    Returns:
        Mean per-batch loss (0.0 for an empty loader).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_true_texts = []
    all_pred_texts = []
    special_ids = {config.sos_idx, config.eos_idx}

    def to_text(ids):
        # ids are Python ints (from .tolist()), matching idx_to_char's keys
        return ''.join(config.idx_to_char[i] for i in ids if i not in special_ids)

    with torch.no_grad():
        for features, target_texts, feature_lengths, text_lengths in tqdm(eval_loader, desc="Evaluating"):
            features = features.to(device)
            target_texts = target_texts.to(device)

            outputs = model(features, target_texts, feature_lengths, text_lengths, teacher_forcing_ratio=0)

            # True for real tokens, False for padding
            positions = torch.arange(target_texts.size(1), device=device)
            mask = positions.unsqueeze(0) < torch.as_tensor(text_lengths, device=device).unsqueeze(1)

            loss = criterion(outputs[mask], target_texts[mask])
            total_loss += loss.item()
            num_batches += 1

            # Keep only a few samples per batch to bound memory/printing
            pred_ids = outputs.argmax(dim=-1)
            for i in range(min(5, len(features))):
                n = int(text_lengths[i])
                all_pred_texts.append(to_text(pred_ids[i, :n].tolist()))
                all_true_texts.append(to_text(target_texts[i, :n].tolist()))

    if all_true_texts:
        analyze_predictions(all_true_texts, all_pred_texts)

    return total_loss / max(num_batches, 1)


In [None]:
def analyze_predictions(true_texts, pred_texts):
    """Print position-aligned accuracy statistics for reference/prediction pairs.

    Characters are compared index by index up to the shorter string, and
    words likewise after whitespace splitting. There is no edit-distance
    alignment, so a single inserted or dropped character shifts every later
    comparison.

    Args:
        true_texts: reference strings.
        pred_texts: predicted strings, paired with ``true_texts`` by position.
    """
    print("\nPrediction Analysis:")

    char_seen = char_match = 0
    word_seen = word_match = 0
    confusions = {}  # (reference char, predicted char) -> count, in first-seen order

    for reference, hypothesis in zip(true_texts, pred_texts):
        for ref_ch, hyp_ch in zip(reference, hypothesis):
            char_seen += 1
            if ref_ch != hyp_ch:
                confusions[(ref_ch, hyp_ch)] = confusions.get((ref_ch, hyp_ch), 0) + 1
            else:
                char_match += 1

        word_pairs = list(zip(reference.split(), hypothesis.split()))
        word_seen += len(word_pairs)
        word_match += sum(ref_w == hyp_w for ref_w, hyp_w in word_pairs)

    if char_seen:
        char_acc = char_match / char_seen * 100
        print(f"Character Accuracy: {char_acc:.2f}%")

    if word_seen:
        word_acc = word_match / word_seen * 100
        print(f"Word Accuracy: {word_acc:.2f}%")

    # Stable ascending sort on negated counts: ties keep first-seen order
    most_common = sorted(confusions.items(), key=lambda item: -item[1])[:10]
    if most_common:
        print("Top substitution errors (true → pred):")
        for (true_char, pred_char), count in most_common:
            print(f"  '{true_char}' → '{pred_char}': {count} times")

    print("\nExample predictions:")
    for i, reference in enumerate(true_texts[:5]):
        print(f"True: \"{reference}\"")
        print(f"Pred: \"{pred_texts[i]}\"")
        print("-" * 50)

In [None]:
def transcribe(model, feature_path, config):
    """Transcribe one pre-extracted feature file with beam search.

    Args:
        model: trained ``Seq2SeqModel``.
        feature_path: path to a ``.npy`` feature matrix, either
            ``[time, feature_dim]`` or ``[feature_dim, time]``.
        config: provides ``feature_dim``, ``device``, ``idx_to_char``,
            ``sos_idx`` and ``eos_idx``.

    Returns:
        The transcription string, without ``<sos>``/``<eos>`` markers.

    Raises:
        ValueError: if the features are not 2-D with one axis equal to ``feature_dim``.
    """
    model.eval()

    features = torch.tensor(np.load(feature_path), dtype=torch.float32)

    # Check the rank first so 1-D/3-D arrays raise ValueError, not IndexError.
    if features.dim() != 2:
        raise ValueError(f"Unexpected feature shape: {features.shape}")

    # [time, feature_dim] is used as-is (including the square case);
    # [feature_dim, time] is transposed.
    if features.shape[1] != config.feature_dim:
        if features.shape[0] == config.feature_dim:
            features = features.transpose(0, 1)
        else:
            raise ValueError(f"Unexpected feature shape: {features.shape}")

    feature_length = features.shape[0]
    features = features.unsqueeze(0).to(config.device)  # add batch dimension

    with torch.no_grad():
        predicted_indices = beam_search_decode(
            model, features, torch.tensor([feature_length]), beam_width=5, config=config
        )

    # beam_search_decode keeps the trailing <eos> id; drop special tokens from the text.
    special_ids = {config.sos_idx, config.eos_idx}
    return "".join(
        config.idx_to_char.get(idx, "") for idx in predicted_indices if idx not in special_ids
    )


In [None]:
def early_stopping(val_losses, patience=10):
    """Decide whether training should stop because validation loss has plateaued.

    Args:
        val_losses: per-epoch validation losses, oldest first.
        patience: how many epochs after the best one are tolerated.

    Returns:
        True once more than ``patience`` epochs have passed since the first
        occurrence of the minimum loss, counting the best epoch itself.
        Returns False while fewer than ``patience + 1`` losses exist.
    """
    if len(val_losses) <= patience:
        return False

    # min over indices keyed by value picks the first index of the minimum
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
    epochs_since_best = len(val_losses) - best_epoch
    return epochs_since_best > patience


In [None]:
def load_and_use_model(model_path, feature_path, config=None):
    """Restore a trained model from a checkpoint and transcribe one feature file.

    Args:
        model_path: path to a ``torch.save``d dict containing
            ``'model_state_dict'``. When ``config`` is None it must also contain
            a ``'config'`` dict with feature_dim, hidden_size, num_layers,
            bidirectional, dropout, vocab_size, char_to_idx and idx_to_char.
        feature_path: ``.npy`` feature file to transcribe.
        config: optional ready-made ``Config``. If given, the checkpoint's
            config is ignored.

    Returns:
        The transcription string.
    """
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
    # load_state_dict then copies the weights onto config.device.
    saved = torch.load(model_path, map_location="cpu")

    if config is None:
        config = Config()
        saved_config = saved['config']
        for key in ('feature_dim', 'hidden_size', 'num_layers', 'bidirectional',
                    'dropout', 'vocab_size', 'char_to_idx', 'idx_to_char'):
            setattr(config, key, saved_config[key])
        # Keep the special-token ids consistent with the restored vocabulary
        config.sos_idx = config.char_to_idx.get('<sos>', config.sos_idx)
        config.eos_idx = config.char_to_idx.get('<eos>', config.eos_idx)

    model = Seq2SeqModel(config).to(config.device)
    model.load_state_dict(saved['model_state_dict'])
    model.eval()

    # transcribe() returns a single string (unpacking it would split characters)
    return transcribe(model, feature_path, config)

In [None]:
def find_learning_rate(model, train_loader, criterion, device, start_lr=1e-7, end_lr=1, num_steps=100):
    """Learning-rate range test: train briefly while the LR grows log-linearly.

    One batch is used per learning rate; the loader restarts when exhausted.
    The sweep stops early if the loss jumps by more than 4x or an exception
    occurs. The model's ``state_dict`` is restored afterwards. A loss-vs-LR
    plot is saved to ``learning_rate_finder.png``.

    Returns:
        The LR at the steepest loss decrease (argmin of ``np.gradient(losses)``).
        Falls back to (min-loss LR / 10) when the gradient is non-finite, and
        to ``1e-3`` otherwise.
    """
    model.train()

    # Log-spaced learning rates
    lrs = np.logspace(np.log10(start_lr), np.log10(end_lr), num_steps)
    losses = []

    optimizer = optim.AdamW(model.parameters(), lr=start_lr)

    # Snapshot weights and BatchNorm running stats so the test is side-effect free
    original_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    data_iter = iter(train_loader)

    for lr in tqdm(lrs, desc="Finding optimal learning rate"):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        try:
            # Get a batch; restart the iterator when the loader is exhausted
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            features, target_texts, feature_lengths, text_lengths = batch

            features = features.to(device)
            target_texts = target_texts.to(device)

            optimizer.zero_grad()
            outputs = model(features, target_texts, feature_lengths, text_lengths)

            # True for real tokens, False for padding
            positions = torch.arange(target_texts.size(1), device=device)
            mask = positions.unsqueeze(0) < torch.as_tensor(text_lengths, device=device).unsqueeze(1)
            loss = criterion(outputs[mask], target_texts[mask])

            loss.backward()
            # Without this step the weights never change and the sweep only
            # records batch-to-batch noise.
            optimizer.step()

            current_loss = loss.item()
            losses.append(current_loss)

            # If loss explodes, stop
            if len(losses) > 1 and current_loss > 4 * losses[-2]:
                print(f"Loss exploded from {losses[-2]} to {current_loss} at learning rate {lr}. Stopping.")
                break

        except Exception as e:
            # Best effort: report and end the sweep rather than abort the caller
            print(f"Error occurred at learning rate {lr}: {e}")
            break

    model.load_state_dict(original_state)

    if not losses:
        print("Could not determine optimal learning rate. Using default value.")
        return 1e-3

    if len(losses) > 1:
        plt.figure(figsize=(10, 6))
        plt.plot(lrs[:len(losses)], losses)
        plt.xscale('log')
        plt.xlabel('Learning Rate')
        plt.ylabel('Loss')
        plt.title('Learning Rate Finder')
        plt.savefig('learning_rate_finder.png')

        loss_values = np.asarray(losses, dtype=float)
        derivative = np.gradient(loss_values)
        if np.all(np.isfinite(derivative)):
            # Where the loss drops the fastest
            optimal_lr = lrs[int(np.argmin(derivative))]
            print(f"Suggested learning rate: {optimal_lr:.6f}")
            return optimal_lr

        # Non-finite losses make the gradient meaningless: fall back to the
        # best finite loss, scaled down conservatively.
        if np.isfinite(loss_values).any():
            min_loss_idx = int(np.nanargmin(np.where(np.isfinite(loss_values), loss_values, np.nan)))
            if min_loss_idx > 0:
                optimal_lr = lrs[min_loss_idx] / 10
                print(f"Falling back to conservative learning rate: {optimal_lr:.6f}")
                return optimal_lr

    print("Using default learning rate")
    return 1e-3

In [None]:
# Modify the main function to address the model loading issue
def main():
    """Train the seq2seq transcription model, keep the best checkpoint, and show sample predictions.

    Pipeline:
      1. Build train/validation ``AudioFeatureDataset`` objects from ``config.csv_path``
         and wrap them in ``DataLoader``s that use ``collate_fn``.
      2. Train ``Seq2SeqModel`` with AdamW and label-smoothed cross-entropy. The
         learning rate is halved by ``ReduceLROnPlateau`` after 5 epochs without
         validation improvement.
      3. Save the checkpoint with the lowest validation loss to
         ``best_audio_transcription_model.pth``. Stop early after
         ``early_stop_patience`` epochs without improvement (``config.early_stop_patience``
         if defined, otherwise 15).
      4. Plot train/validation loss curves to ``training_history.png``.
      5. Reload the best checkpoint saved during *this* run (if any) and print
         predictions for the first few validation samples.

    Side effects: writes ``best_audio_transcription_model.pth`` and
    ``training_history.png`` to the current working directory and prints progress.
    """
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    config = Config()
    checkpoint_path = 'best_audio_transcription_model.pth'
    early_stop_patience = getattr(config, 'early_stop_patience', 15)
    num_examples = 5

    # Create datasets (same CSV, split by the `train` flag inside the dataset class)
    train_dataset = AudioFeatureDataset(
        csv_path=config.csv_path,
        config=config,
        train=True
    )
    val_dataset = AudioFeatureDataset(
        csv_path=config.csv_path,
        config=config,
        train=False
    )

    # Sanity-check feature shapes before building loaders. min() keeps tiny
    # datasets from raising IndexError.
    print(f"Expected feature dimension: {config.feature_dim}")
    for i in range(min(num_examples, len(train_dataset))):
        features, _, _, _ = train_dataset[i]
        print(f"Sample {i} feature shape: {features.shape}")

    # Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
    use_pin_memory = torch.device(config.device).type == 'cuda'

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=use_pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=use_pin_memory
    )

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Skip the learning rate finder. The Config default (0.722081) is far too
    # large for AdamW, so override it with a conventional value.
    config.learning_rate = 0.001

    model = Seq2SeqModel(config).to(config.device)
    print(f"Model initialized on {config.device}")

    # Loss and optimizer with label smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Halve the LR when validation loss plateaus. `verbose=` is deprecated in
    # PyTorch, so LR changes are logged manually in the loop below.
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5
    )

    best_val_loss = float('inf')
    best_epoch = None  # epoch of the checkpoint saved by *this* run, if any
    train_losses = []
    val_losses = []
    no_improvement_count = 0

    for epoch in range(1, config.num_epochs + 1):
        train_loss = train(model, train_loader, criterion, optimizer, config.device, epoch, config)
        train_losses.append(train_loss)

        val_loss = evaluate(model, val_loader, criterion, config.device, config)
        val_losses.append(val_loss)

        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        lrs_after = [group['lr'] for group in optimizer.param_groups]
        if lrs_after != lrs_before:
            print(f"Reducing learning rate to {lrs_after[0]:.2e}")

        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # A NaN val_loss never compares smaller, so it counts as "no improvement".
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            no_improvement_count = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'config': {
                    'feature_dim': config.feature_dim,
                    'hidden_size': config.hidden_size,
                    'num_layers': config.num_layers,
                    'bidirectional': config.bidirectional,
                    'dropout': config.dropout,
                    'vocab_size': config.vocab_size,
                    'char_to_idx': config.char_to_idx,
                    'idx_to_char': config.idx_to_char
                }
            }, checkpoint_path)
            print("Saved new best model!")
        else:
            no_improvement_count += 1
            print(f"No improvement for {no_improvement_count} epochs")

        if no_improvement_count >= early_stop_patience:
            print(f"Early stopping triggered after {epoch} epochs (no improvement for {early_stop_patience} epochs)")
            break

    # Plot training history
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses, label='Training Loss')
    ax.plot(val_losses, label='Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    fig.savefig('training_history.png')
    plt.show()
    plt.close(fig)

    # Load the best model for testing. Only the load sits inside try/except, so a
    # failure in transcribe() is reported as itself rather than as a load error.
    print("\nLoading best model for testing...")
    header = "\nTesting model on a few examples:"
    if best_epoch is None:
        # Nothing was saved this run; don't pick up a stale file from an earlier run.
        print("No checkpoint was saved during this run.")
        print("Continuing with the current model state without loading checkpoint...")
        header = "\nTesting current model on a few examples:"
    else:
        try:
            # weights_only=False: the checkpoint also holds plain-Python config dicts.
            # This file is written by this function, so it is trusted input.
            # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
            best_ckpt = torch.load(checkpoint_path, map_location=config.device, weights_only=False)
            model.load_state_dict(best_ckpt['model_state_dict'])
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Continuing with the current model state without loading checkpoint...")
            header = "\nTesting current model on a few examples:"

    print(header)
    _print_example_predictions(model, val_dataset, config, num_examples)


def _print_example_predictions(model, dataset, config, num_examples=5):
    """Print ground-truth vs. predicted transcripts for the first ``num_examples`` rows of ``dataset``.

    Assumes ``dataset.df`` has ``feature_path`` and ``transcript`` columns and that the
    notebook's ``transcribe(model, feature_path, config)`` helper returns a string.
    """
    for i in range(min(num_examples, len(dataset))):
        row = dataset.df.iloc[i]
        feature_path = row['feature_path']
        true_transcription = row['transcript']

        pred_transcription = transcribe(model, feature_path, config)

        print(f"\nSample {i+1}: {os.path.basename(feature_path)}")
        print(f"True: {true_transcription}")
        print(f"Pred: {pred_transcription}")

In [None]:
# Jupyter runs notebook cells with __name__ == "__main__", so executing this cell
# launches the full training + evaluation pipeline defined in main().
if __name__ == "__main__":
    main()

Dataset loaded with 2404 valid samples




Dataset loaded with 601 valid samples
Expected feature dimension: 128
Sample 0 feature shape: torch.Size([128, 128])
Sample 1 feature shape: torch.Size([128, 128])
Sample 2 feature shape: torch.Size([128, 128])
Sample 3 feature shape: torch.Size([128, 128])
Sample 4 feature shape: torch.Size([128, 128])
Training samples: 2404
Validation samples: 601
Model initialized on cuda


Epoch 1: 100%|██████████| 151/151 [16:34<00:00,  6.59s/it, loss=2.58]
Evaluating: 100%|██████████| 38/38 [04:41<00:00,  7.41s/it]



Prediction Analysis:
Character Accuracy: 20.73%
Word Accuracy: 6.15%
Top substitution errors (true → pred):
  ' ' → 't': 339 times
  ' ' → 'h': 291 times
  ' ' → 'e': 284 times
  'e' → 'h': 206 times
  'e' → 't': 185 times
  'e' → ' ': 178 times
  't' → ' ': 154 times
  't' → 'h': 141 times
  'a' → 't': 140 times
  'o' → ' ': 130 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the the the the the the the the the the the the"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>the cout the the the the the the the th"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the sout the the the the the the the the the the the the the the the the the the the the the the the the the the the"
-------------------------------------

Epoch 2: 100%|██████████| 151/151 [00:54<00:00,  2.78it/s, loss=2.47]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.58it/s]



Prediction Analysis:
Character Accuracy: 18.90%
Word Accuracy: 3.56%
Top substitution errors (true → pred):
  ' ' → 'e': 246 times
  'e' → ' ': 169 times
  ' ' → 't': 164 times
  ' ' → 'h': 142 times
  't' → ' ': 142 times
  ' ' → 's': 134 times
  'h' → ' ': 131 times
  'o' → ' ': 131 times
  'e' → 't': 114 times
  't' → 'e': 110 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>i don't the sare the sare the boy the boy the b"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i was the sare the boy the boy the boy "
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>i don't the sare the sare the sare the boy the boy the boy the boy the boy the boy the boy the boy the boy the boy t"
-------------------------------------

Epoch 3: 100%|██████████| 151/151 [00:52<00:00,  2.88it/s, loss=2.24]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  7.60it/s]



Prediction Analysis:
Character Accuracy: 18.71%
Word Accuracy: 3.69%
Top substitution errors (true → pred):
  ' ' → 'e': 222 times
  'e' → ' ': 211 times
  ' ' → 't': 179 times
  ' ' → 'h': 173 times
  't' → ' ': 147 times
  'o' → ' ': 144 times
  'a' → ' ': 136 times
  'h' → ' ': 123 times
  ' ' → 'b': 117 times
  'n' → ' ': 101 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>i was the see the boy the boy the boy the boy t"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i was the see the boy the boy the boy t"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>i was the see the boy the boy the boy the boy the boy the boy the boy the boy the boy the boy the boy the boy the bo"
-------------------------------------

Epoch 4: 100%|██████████| 151/151 [00:53<00:00,  2.83it/s, loss=2.3]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  7.82it/s]



Prediction Analysis:
Character Accuracy: 19.15%
Word Accuracy: 3.71%
Top substitution errors (true → pred):
  'e' → ' ': 195 times
  ' ' → 't': 151 times
  ' ' → 'e': 148 times
  ' ' → 'i': 135 times
  ' ' → 'h': 133 times
  ' ' → 'd': 128 times
  'o' → ' ': 123 times
  ' ' → 'a': 117 times
  ' ' → 's': 108 times
  'a' → ' ': 107 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the said the said the said the said the said th"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i don't to the said the boy to the boy<eos>"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the said the said the said the said the said the said the said the said the said the said the said the said the said"
---------------------------------

Epoch 5: 100%|██████████| 151/151 [00:54<00:00,  2.79it/s, loss=2.15]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.51it/s]



Prediction Analysis:
Character Accuracy: 18.49%
Word Accuracy: 3.89%
Top substitution errors (true → pred):
  'e' → ' ': 219 times
  'h' → ' ': 151 times
  't' → ' ': 150 times
  ' ' → 't': 143 times
  ' ' → 'e': 139 times
  ' ' → 'h': 138 times
  'o' → ' ': 135 times
  ' ' → 'o': 131 times
  'a' → ' ': 130 times
  ' ' → 'b': 109 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>i was the boy a sear the boy the boy the boy th"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i was the boy to the boy to the boy<eos><eos>an"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>i was a could the boy a sear the boy the boy the boy the boy the boy the boy the boy the boy the boy the boy the boy"
-----------------------------

Epoch 6: 100%|██████████| 151/151 [00:53<00:00,  2.81it/s, loss=2.49]
Evaluating: 100%|██████████| 38/38 [00:05<00:00,  7.35it/s]



Prediction Analysis:
Character Accuracy: 19.71%
Word Accuracy: 4.02%
Top substitution errors (true → pred):
  ' ' → 't': 177 times
  'e' → ' ': 169 times
  ' ' → 'e': 152 times
  ' ' → 'h': 145 times
  't' → ' ': 145 times
  ' ' → 'o': 117 times
  'e' → 'o': 107 times
  'o' → ' ': 106 times
  'a' → ' ': 101 times
  's' → ' ': 101 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the boy to the boy the boy the boy the said<eos>the"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>the boy to the boy to the boy<eos><eos><eos>at<eos><eos><eos>th"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the boy the said the boy the boy the said the boy the said<eos>the boy the said<eos>the boy the said<eos>the boy the said<eos>t

Epoch 7: 100%|██████████| 151/151 [00:53<00:00,  2.81it/s, loss=2.09]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.43it/s]



Prediction Analysis:
Character Accuracy: 18.99%
Word Accuracy: 4.41%
Top substitution errors (true → pred):
  'e' → ' ': 183 times
  ' ' → 'e': 176 times
  'o' → ' ': 140 times
  ' ' → 'h': 135 times
  ' ' → 'a': 128 times
  ' ' → 't': 126 times
  't' → ' ': 118 times
  'a' → ' ': 110 times
  'h' → ' ': 109 times
  'e' → 't': 99 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the e the said the boy and the said the boy tan"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i was the boy to the boy the boy<eos><eos>one<eos>t"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the e the said the boy and the boy and the said<eos>the sare the said<eos>the sare the said<eos>the sare the said<eos>the sare the s"
----------

Epoch 8: 100%|██████████| 151/151 [00:53<00:00,  2.80it/s, loss=2.21]
Evaluating: 100%|██████████| 38/38 [00:05<00:00,  7.01it/s]



Prediction Analysis:
Character Accuracy: 20.36%
Word Accuracy: 3.97%
Top substitution errors (true → pred):
  'e' → ' ': 258 times
  ' ' → 't': 244 times
  ' ' → 'e': 169 times
  't' → ' ': 168 times
  'a' → ' ': 160 times
  'o' → ' ': 155 times
  ' ' → 'o': 150 times
  'e' → 't': 149 times
  'i' → ' ': 129 times
  'h' → ' ': 128 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  out the  out the  out the  out the  out th"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>the  out the  out the  out the boy took"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  out the  out the  out the  out the  out the  out the  out the boy and to the tor the sere to the tor the sere t"
-------------------------------------

Epoch 9: 100%|██████████| 151/151 [00:54<00:00,  2.77it/s, loss=2.38]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.60it/s]



Prediction Analysis:
Character Accuracy: 20.54%
Word Accuracy: 3.77%
Top substitution errors (true → pred):
  'e' → ' ': 220 times
  ' ' → 'e': 176 times
  't' → ' ': 172 times
  ' ' → 'o': 164 times
  ' ' → 't': 159 times
  ' ' → 'h': 136 times
  'a' → ' ': 129 times
  'e' → 'o': 128 times
  'o' → ' ': 124 times
  'i' → ' ': 114 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the boy a see the said the boy the said the boy"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>the boy the  out the boy to the boy<eos><eos>on"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the boy a see the  on the boy the boy the boy the boy the boy the boy the boy the said<eos>the boy the said<eos>the boy the "
---------------------

Epoch 10: 100%|██████████| 151/151 [00:53<00:00,  2.83it/s, loss=2.41]
Evaluating: 100%|██████████| 38/38 [00:05<00:00,  6.99it/s]



Prediction Analysis:
Character Accuracy: 20.81%
Word Accuracy: 4.22%
Top substitution errors (true → pred):
  'e' → ' ': 202 times
  't' → ' ': 176 times
  ' ' → 'e': 163 times
  'o' → ' ': 131 times
  ' ' → 'o': 121 times
  'a' → ' ': 120 times
  ' ' → 'a': 113 times
  ' ' → 't': 110 times
  'r' → ' ': 106 times
  'h' → ' ': 102 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the boy was a see the boy and the boy and the b"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>ihe boy a see  on the boy took<eos>the boy<eos>"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the boy was a see  the boy and the boy and the boy and the boy and the boy and the boy and the boy and the boy and t"
-----------------------------

Epoch 11: 100%|██████████| 151/151 [00:53<00:00,  2.81it/s, loss=2.56]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.68it/s]



Prediction Analysis:
Character Accuracy: 21.62%
Word Accuracy: 4.16%
Top substitution errors (true → pred):
  'e' → ' ': 309 times
  't' → ' ': 241 times
  'a' → ' ': 208 times
  'o' → ' ': 197 times
  'h' → ' ': 164 times
  ' ' → 'e': 158 times
  ' ' → 'o': 155 times
  'i' → ' ': 153 times
  'r' → ' ': 149 times
  'n' → ' ': 147 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  oy a s the  ore   the  ore   the  ore  the"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i was  the  ore  on the boy to the boy<eos>"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  oo   a   the  ore   the  ore  to the boy and the boy and the boy and the boy and the boy and the boy and the bo"
---------------------------------

Epoch 12: 100%|██████████| 151/151 [00:53<00:00,  2.81it/s, loss=2.37]
Evaluating: 100%|██████████| 38/38 [00:05<00:00,  7.60it/s]



Prediction Analysis:
Character Accuracy: 23.91%
Word Accuracy: 4.27%
Top substitution errors (true → pred):
  'e' → ' ': 511 times
  't' → ' ': 388 times
  'a' → ' ': 325 times
  'o' → ' ': 293 times
  's' → ' ': 252 times
  'i' → ' ': 249 times
  'n' → ' ': 241 times
  'h' → ' ': 233 times
  'r' → ' ': 226 times
  'l' → ' ': 147 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  ee      a   the  e     the  e    the  e   "
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i  an       t to   the  e t to   a   th"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  oo   a    a    the  e    a   the  e    a   the  ee    a   the  ee    a   the  ee    a   the  ee    a   the  ee "
-------------------------------------

Epoch 13: 100%|██████████| 151/151 [00:54<00:00,  2.78it/s, loss=2.77]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.37it/s]



Prediction Analysis:
Character Accuracy: 23.09%
Word Accuracy: 3.86%
Top substitution errors (true → pred):
  'e' → ' ': 443 times
  't' → ' ': 318 times
  'a' → ' ': 285 times
  'o' → ' ': 252 times
  'r' → ' ': 217 times
  'i' → ' ': 216 times
  'h' → ' ': 216 times
  's' → ' ': 216 times
  'n' → ' ': 199 times
  'd' → ' ': 131 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  e          the  an    the  o the  o the  o"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i           t the  oon <eos>o the  oon<eos><eos><eos><eos><eos>"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  o             a   the  an   a   the  an   a   the  an  an  e the  an  an  e the  an  an  e the  an  an  e the  "
-------------

Epoch 14: 100%|██████████| 151/151 [00:53<00:00,  2.83it/s, loss=2.46]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.50it/s]



Prediction Analysis:
Character Accuracy: 25.01%
Word Accuracy: 3.25%
Top substitution errors (true → pred):
  'e' → ' ': 554 times
  't' → ' ': 431 times
  'a' → ' ': 347 times
  'o' → ' ': 334 times
  'n' → ' ': 273 times
  'i' → ' ': 272 times
  's' → ' ': 272 times
  'h' → ' ': 258 times
  'r' → ' ': 258 times
  'l' → ' ': 160 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  oo                e  o  the  or        the"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>ihe  oo               e  oo  oo  oo the"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  oo                                     the  or         the  or       the  or       the  or      the  or    the "
-------------------------------------

Epoch 15: 100%|██████████| 151/151 [00:53<00:00,  2.82it/s, loss=2.62]
Evaluating: 100%|██████████| 38/38 [00:05<00:00,  7.03it/s]



Prediction Analysis:
Character Accuracy: 23.28%
Word Accuracy: 4.61%
Top substitution errors (true → pred):
  'e' → ' ': 498 times
  't' → ' ': 356 times
  'a' → ' ': 292 times
  'o' → ' ': 262 times
  'h' → ' ': 227 times
  's' → ' ': 225 times
  'i' → ' ': 222 times
  'n' → ' ': 222 times
  'r' → ' ': 218 times
  'd' → ' ': 131 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  is a   t the  at  a   the  at the  ore  an"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i   a            t to   the  oy toe <eos>oo"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  are  an   a   t the  are  an   an   an  an  and the  are the  are the  are the  are the  are the  are the  are "
---------------------------------

Epoch 16: 100%|██████████| 151/151 [00:53<00:00,  2.81it/s, loss=2.54]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.57it/s]



Prediction Analysis:
Character Accuracy: 24.84%
Word Accuracy: 4.37%
Top substitution errors (true → pred):
  'e' → ' ': 619 times
  't' → ' ': 452 times
  'a' → ' ': 380 times
  'o' → ' ': 340 times
  's' → ' ': 300 times
  'h' → ' ': 294 times
  'i' → ' ': 286 times
  'n' → ' ': 282 times
  'r' → ' ': 275 times
  'l' → ' ': 169 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the e the  an                   the  o the  oe "
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i                        t to the e<eos>ter"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the e the  an                                   the  an  e  on  the  on  the  on  the  on  the  on  the  on  the  on"
---------------------------------

Epoch 17: 100%|██████████| 151/151 [00:54<00:00,  2.80it/s, loss=2.57]
Evaluating: 100%|██████████| 38/38 [00:05<00:00,  6.92it/s]



Prediction Analysis:
Character Accuracy: 22.99%
Word Accuracy: 3.19%
Top substitution errors (true → pred):
  'e' → ' ': 431 times
  't' → ' ': 329 times
  'a' → ' ': 277 times
  'o' → ' ': 264 times
  'h' → ' ': 229 times
  'n' → ' ': 212 times
  'i' → ' ': 209 times
  's' → ' ': 208 times
  'r' → ' ': 198 times
  ' ' → 'o': 144 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the boy s a   the  oo      the  oo  and the boy"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>ihe  oo            the  ooe<eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos>"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  oan             a   the  oan      the  oan      the  oan  o  e  o  e  o  e  on  o  e  on  and the  oan

Epoch 18: 100%|██████████| 151/151 [00:53<00:00,  2.82it/s, loss=2.65]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.46it/s]



Prediction Analysis:
Character Accuracy: 24.03%
Word Accuracy: 4.10%
Top substitution errors (true → pred):
  'e' → ' ': 473 times
  't' → ' ': 377 times
  'a' → ' ': 307 times
  'o' → ' ': 259 times
  'i' → ' ': 250 times
  'r' → ' ': 248 times
  's' → ' ': 229 times
  'h' → ' ': 225 times
  'n' → ' ': 213 times
  'd' → ' ': 146 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the boy a s a                   the  ore  the  "
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i e  a                the  ore the  ore"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  are  a              the  are the  an  an  the  are the  an  an  the  are the  an  an  the  are the  an  an  the"
-------------------------------------

Epoch 19: 100%|██████████| 151/151 [00:52<00:00,  2.88it/s, loss=2.83]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.10it/s]



Prediction Analysis:
Character Accuracy: 24.10%
Word Accuracy: 4.58%
Top substitution errors (true → pred):
  'e' → ' ': 508 times
  't' → ' ': 387 times
  'a' → ' ': 315 times
  'o' → ' ': 296 times
  'i' → ' ': 267 times
  'h' → ' ': 260 times
  's' → ' ': 248 times
  'n' → ' ': 236 times
  'r' → ' ': 232 times
  'd' → ' ': 147 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>i  as                                          "
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i e  oo               e  oo  e  one<eos><eos><eos><eos>"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  are  a   a   a   a   a   a   the  an  an  an  an  an  an  an  an  an  an  an  an  an  an  an  an  an  an  an  a"
---------------------

Epoch 20: 100%|██████████| 151/151 [00:54<00:00,  2.78it/s, loss=2.49]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  7.71it/s]



Prediction Analysis:
Character Accuracy: 23.53%
Word Accuracy: 2.86%
Top substitution errors (true → pred):
  'e' → ' ': 492 times
  't' → ' ': 370 times
  'a' → ' ': 302 times
  'o' → ' ': 267 times
  'n' → ' ': 241 times
  'i' → ' ': 227 times
  'r' → ' ': 224 times
  's' → ' ': 222 times
  'h' → ' ': 210 times
  ' ' → 'e': 183 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  oe                    the  ear    the  ore"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>ihe  oo             e  o  e  o  e e<eos><eos>er"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the e was a e the  are  an  the  are  an  the  are the  are the  are the  are the  are the  are the  are the  are th"
-----------------------------

Epoch 21: 100%|██████████| 151/151 [00:53<00:00,  2.81it/s, loss=2.64]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.51it/s]



Prediction Analysis:
Character Accuracy: 24.94%
Word Accuracy: 3.08%
Top substitution errors (true → pred):
  'e' → ' ': 627 times
  't' → ' ': 476 times
  'a' → ' ': 392 times
  'o' → ' ': 332 times
  'n' → ' ': 312 times
  'h' → ' ': 301 times
  'i' → ' ': 295 times
  's' → ' ': 294 times
  'r' → ' ': 283 times
  'd' → ' ': 176 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  oo                                        "
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i    ou t    e  e            e  ee  e  "
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  are  a                                        e  an  an  ear  an  an  ear  an  an  ear  an  an  ear  an  an  ea"
-------------------------------------

Epoch 22: 100%|██████████| 151/151 [00:53<00:00,  2.85it/s, loss=2.5]
Evaluating: 100%|██████████| 38/38 [00:05<00:00,  6.89it/s]



Prediction Analysis:
Character Accuracy: 22.27%
Word Accuracy: 4.24%
Top substitution errors (true → pred):
  'e' → ' ': 418 times
  't' → ' ': 334 times
  'o' → ' ': 266 times
  'a' → ' ': 261 times
  'h' → ' ': 226 times
  's' → ' ': 216 times
  'i' → ' ': 211 times
  'n' → ' ': 209 times
  'r' → ' ': 192 times
  ' ' → 'e': 128 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>i  an  t the  or                             th"
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i  as the  oo   a   the boy<eos><eos><eos>an<eos><eos><eos><eos><eos><eos><eos>"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  are  a   t the  are  an  the  are  an  the  ore the  ore the  ore the  ore the  ore the  ore the  ore the  ore

Epoch 23: 100%|██████████| 151/151 [00:53<00:00,  2.81it/s, loss=2.7]
Evaluating: 100%|██████████| 38/38 [00:04<00:00,  8.55it/s]



Prediction Analysis:
Character Accuracy: 23.87%
Word Accuracy: 3.12%
Top substitution errors (true → pred):
  'e' → ' ': 494 times
  't' → ' ': 384 times
  'a' → ' ': 309 times
  'o' → ' ': 280 times
  'i' → ' ': 238 times
  's' → ' ': 226 times
  'n' → ' ': 225 times
  'r' → ' ': 223 times
  'h' → ' ': 219 times
  'l' → ' ': 151 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  oe   e   the  ore                         "
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i  ould  a   to   a   to  e to the  oot"
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the  are a   a   the  are and the  are and the  are the  are the  are the  are the  are the  are the  are the  are t"
-------------------------------------

Epoch 24: 100%|██████████| 151/151 [00:54<00:00,  2.79it/s, loss=2.42]
Evaluating: 100%|██████████| 38/38 [00:05<00:00,  6.88it/s]



Prediction Analysis:
Character Accuracy: 24.43%
Word Accuracy: 4.28%
Top substitution errors (true → pred):
  'e' → ' ': 574 times
  't' → ' ': 426 times
  'a' → ' ': 342 times
  'o' → ' ': 319 times
  'i' → ' ': 275 times
  'r' → ' ': 270 times
  's' → ' ': 270 times
  'h' → ' ': 262 times
  'n' → ' ': 257 times
  'd' → ' ': 155 times

Example predictions:
True: "<sos>i could die happily and that made me feel good<eos>"
Pred: "<sos>the  oe                                        "
--------------------------------------------------
True: "<sos>are you going to live with your mother<eos>"
Pred: "<sos>i                                  to  "
--------------------------------------------------
True: "<sos>the waitress was carrying an impressive amount of dinnerware but then an earthquake occurred and she dropped it all<eos>"
Pred: "<sos>the e the  e                                          e  e e the  eee the  eeer<eos><eos>n the sand<eos><eos><eos> an  e e the  eeer<eos><eos>n "
---------

Epoch 25:  94%|█████████▍| 142/151 [00:50<00:02,  3.04it/s, loss=2.86]