<a href="https://colab.research.google.com/github/Moses05/CardiovascularDiseaseWebApp/blob/main/lingala_english_asr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Table of Contents

In [161]:
%%capture
!pip install datasets

In [162]:
import os

if not (os.path.exists("lingala-english-asr")):
  !git clone https://github.com/Moses05/lingala-english-asr.git

In [163]:
base_path = "lingala-english-asr/LRSC/lingala"

train_audio_path = f"{base_path}/train/audio"
train_transcript_path = f"{base_path}/train/transcript.txt"

valid_audio_path = f"{base_path}/valid/audio"
valid_transcript_path = f"{base_path}/valid/transcript.txt"

In [164]:
manifest_path = f"{base_path}/manifest"

dict_txt = f"{manifest_path}/dict.ltr.txt"

train_letter = f"{manifest_path}/train.ltr"
train_tsv = f"{manifest_path}/train.tsv"
train_word = f"{manifest_path}/train.wrd"

valid_letter = f"{manifest_path}/valid.ltr"
valid_tsv = f"{manifest_path}/valid.tsv"
valid_word = f"{manifest_path}/valid.wrd"

In [165]:
import wave

def check_sample_rate(dir):

  wrongFramerate = []

  for wav in os.listdir(dir):
    path = f"{dir}/{wav}"

    if os.path.isfile(path):
      with wave.open(path, "rb") as wav_file:
        if wav_file.getframerate() != 16000:
          wrongFramerate.append(wav_file)

  return wrongFramerate


train_wrongSample = check_sample_rate(train_audio_path)
valid_wrongSample = check_sample_rate(valid_audio_path)

print(f"list of train audio files not 16000hz: {train_wrongSample}")
print(f"list of valid audio files not 16000hz: {valid_wrongSample}")

list of train audio files not 16000hz: []
list of valid audio files not 16000hz: []


In [166]:
from datasets import load_dataset, Dataset

def load_manifest_data(tsv_file, ltr_file, audio_path):
    data = {
        "path": [],
        "duration": [],
        "text": [],
    }

    # Load the .tsv file
    with open(tsv_file, 'r') as tsv_f:
        lines = tsv_f.readlines()[1:]  # Skip header
        # lines = tsv_f.readlines()[1:]  # Skip header
        for line in lines:
            parts = line.strip().split("\t")

            # Ensure two columns path and duration
            if len(parts) != 2:
              print(f"skipping malformed line: {line}")
              continue

            path, duration = parts
            full_path = os.path.join(audio_path, path) # prepend base path

            data["path"].append(full_path)
            # data["duration"].append(int(duration) / 1000)
            data["duration"].append(duration)

    # Load the .ltr file for transcriptions
    with open(ltr_file, 'r') as ltr_f:
        transcriptions = ltr_f.readlines()
        data["text"] = [trans.strip() for trans in transcriptions]

    return Dataset.from_dict(data)

# Load the training and validation datasets
train_dataset = load_manifest_data(train_tsv, train_letter, train_audio_path)
valid_dataset = load_manifest_data(valid_tsv, valid_letter, valid_audio_path)

In [167]:
print(f"train dataset \n{train_dataset} \n\nvalid dataset \n{valid_dataset}")

train dataset 
Dataset({
    features: ['path', 'duration', 'text'],
    num_rows: 2557
}) 

valid dataset 
Dataset({
    features: ['path', 'duration', 'text'],
    num_rows: 383
})


In [168]:
import random
import pandas as pd
from IPython.display import display, HTML
import numpy as np

def show_random_elements(dataset, num_examples=10):
  assert num_examples <= len(dataset), "More specified examples than dataset elements"

  picks = np.random.randint(len(dataset), size = num_examples)

  random_samples = []

  for pick in picks:
    random_samples.append(dataset[int(pick)])


  df = pd.DataFrame(random_samples)

  display(HTML(df.to_html()))

In [169]:
show_random_elements(train_dataset)

Unnamed: 0,path,duration,text
0,lingala-english-asr/LRSC/lingala/train/audio/david_221011-141252_lin_359_elicit_12.wav,50457,y ɔ k a | y e | o m o n i |
1,lingala-english-asr/LRSC/lingala/train/audio/edimon_221010-123027_lin_359_elicit_96.wav,70785,n a | o y ɔ | n y ɔ n s o | n a w u t i | k o t a n g e l a | p e | k o y e b i s a | b i n o |
2,lingala-english-asr/LRSC/lingala/train/audio/kerene_221011-105639_lin_359_elicit_136.wav,59895,b a m o n a | m ɔ t ɔ | a z a l i | l i s u s u | k o n i n g a n a | t e |
3,lingala-english-asr/LRSC/lingala/train/audio/david_221011-142032_lin_359_elicit_25.wav,80949,b a k o n z i | y a | m b o k a | n d e | m a w a | m i n g i |
4,lingala-english-asr/LRSC/lingala/train/audio/emma_221010-142655_lin_359_elicit_59.wav,60984,b a y e b i s a k i | n g a i | q u e | o b o t e r | p r o f | n a | e t e y e l o |
5,lingala-english-asr/LRSC/lingala/train/audio/kev_221010-150830_lin_359_elicit_9.wav,79860,s o k i | o z a l i | n a | y ɔ | m a y ε l ε | t e | t i k a | k o l u k a | b a r a i s o n s |
6,lingala-english-asr/LRSC/lingala/train/audio/rebecca_221011-120830_lin_359_elicit_67.wav,41019,a k o t i | n a | d e p o t | y a | m i b a l e |
7,lingala-english-asr/LRSC/lingala/train/audio/kangamotema_221011-145514_lin_359_elicit_1.wav,97647,a s i l a k i | k o z a l a | e n | c o n t a c t | n a | g a r ç o n | m o k o | y a | k i n g a s a n i |
8,lingala-english-asr/LRSC/lingala/train/audio/exauce1_221010-164503_lin_359_elicit_9.wav,121242,e z a l i | t e | m p o | n a | s o u c i s | y a | k o b o t a | b a b o s a n a | m w a n a | o y ɔ | b a z a l a k i | n a | y e |
9,lingala-english-asr/LRSC/lingala/train/audio/exauce2_221011-153959_lin_359_elicit_22.wav,92928,m ɔ t ɔ | n y ɔ n s o | o y ɔ | a z a l a k i | k o b i m a | l i b a n d a | a k o k a k i | k o b e l a | m p e n z a |


In [170]:
vocab_dict = dict()

with open(dict_txt, "r") as f:
  for line in f:
    char, index = line.strip().split()
    index = int(index)

    vocab_dict[char] = index

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)

print(vocab_dict)
print(f"length of vocab dict: {len(vocab_dict)}")

{'’': 0, 'x': 1, 'a': 2, 'w': 3, 'à': 4, 'f': 5, '-': 6, 'q': 7, '5': 8, 'ç': 9, 'd': 10, 'î': 11, 'j': 12, 'e': 13, '0': 14, 'g': 15, 's': 16, 'o': 17, 'c': 18, "'": 19, 'h': 20, '3': 21, 't': 22, 'l': 23, 'ǎ': 24, 'r': 25, 'ε': 26, 'ê': 27, '8': 28, 'y': 29, 'n': 30, '|': 31, 'u': 32, 'ɔ': 33, 'z': 34, 'ɛ': 35, 'k': 36, 'm': 37, 'å': 38, 'v': 39, 'i': 40, 'p': 41, 'b': 42, '[UNK]': 43, '[PAD]': 44}
length of vocab dict: 45


In [171]:
import json

with open("vocab.json", "w") as vocab_file:
  json.dump(vocab_dict, vocab_file)

print(os.path.exists("vocab.json"))

True


In [172]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [173]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalise=True, return_attention_mask=True)

In [174]:
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [175]:
from datasets import Audio

def raw_audio(dataset):
  if "path" in dataset.column_names:
    dataset = dataset.rename_column("path", "audio")

    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

  return dataset

train_dataset = raw_audio(train_dataset)
valid_dataset = raw_audio(valid_dataset)

In [176]:
print(train_dataset[0])

{'audio': {'path': 'lingala-english-asr/LRSC/lingala/train/audio/exauce1_221010-164503_lin_359_elicit_73.wav', 'array': array([ 0.        ,  0.        ,  0.        , ...,  0.00247192,
       -0.0140686 , -0.01870728]), 'sampling_rate': 16000}, 'duration': '71874', 'text': 'm ɔ t ɔ | a l o b i s a | y e | t e | t i i | n t a n g o | p o n d u | e k o b e l a |'}


In [185]:
import IPython.display as ipd
import random

rand_int = random.randint(0, len(train_dataset)-1)


print(train_dataset[rand_int]['audio']['path'])
print(train_dataset[rand_int]["text"])
print(f"Shape: {train_dataset[rand_int]['audio']['array'].shape}")
ipd.Audio(data=train_dataset[rand_int]["audio"]["array"], autoplay=True, rate=16000)

lingala-english-asr/LRSC/lingala/train/audio/rebecca_221011-120830_lin_359_elicit_51.wav
a k e y i | e s i k a | b a | t a x i | o y ɔ | b a b e n g a k a | c e n t c e n t | e t e l e m a k i |
Shape: (139392,)


In [None]:
# def load_vocab(vocab_path):
#     char_to_index = {}
#     index_to_char = {}

#     with open(vocab_path, "r") as f:
#         for line in f:
#             char, index = line.strip().split()
#             index = int(index)
#             char_to_index[char] = index
#             index_to_char[index] = char

#     # Add the <blank> token if not already present
#     if "<blank>" not in char_to_index:
#         char_to_index["<blank>"] = len(char_to_index)
#         index_to_char[len(index_to_char)] = "<blank>"

#     return char_to_index, index_to_char

# char_to_index, index_to_char = load_vocab(dict_txt)
# vocab_size = len(char_to_index)

# char_to_index, index_to_char = load_vocab(dict_txt)
# vocab_size = len(char_to_index)

# print(f"Loaded vocabulary with {vocab_size} characters")

Loaded vocabulary with 44 characters


In [None]:
# def text_to_ids(text, char_to_index):
#     return [char_to_index[char] for char in text if char in char_to_index]

# def add_labels(batch):
#     # Convert text into numerical labels
#     batch["labels"] = text_to_ids(batch["text"], char_to_index)
#     return batch

In [None]:
import torchaudio

def preprocess_audio(batch):
  waveform, sample_rate = torchaudio.load(batch["path"])

  if waveform.shape[0] > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)

  if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)

  waveform = waveform / torch.max(torch.abs(waveform))

  inputs = feature_extractor(
      waveform.squeeze().numpy(),
      sampling_rate=16000,
      return_tensors="pt",
      padding=True,
  )

  batch["input_values"] = inputs.input_values[0].float()
  return batch

In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
  input_values = [
      torch.tensor(sample["input_values"], dtype=torch.float32)
      if not isinstance(sample["input_values"], torch.Tensor)
      else sample["input_values"].float()
      for sample in batch
  ]

  input_values_padded = pad_sequence(input_values, batch_first=True, padding_value=0.0)

  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]
  labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100)

  return {
      "input_values": input_values_padded,
      "labels": labels_padded,
      "input_lengths": torch.tensor([len(x) for x in input_values]),
      "label_lengths": torch.tensor([len(x) for x in labels])
  }

In [None]:
def truncate_audio(sample, max_length=80000):
  if len(sample["input_values"]) > max_length:
    sample["input_values"] = sample["input_values"][:max_length]
  return sample

In [None]:
def prepare_dataset(dataset):
  dataset = dataset.map(preprocess_audio, remove_columns=["path", "duration"])

  dataset = dataset.map(add_labels, remove_columns=["text"])

  dataset.set_format(type="torch", columns=["input_values", "labels"])

  dataset = truncate_audio(dataset)

  return dataset

In [None]:
train_dataset = prepare_dataset(train_dataset)
valid_dataset = prepare_dataset(valid_dataset)

Map:   0%|          | 0/2557 [00:00<?, ? examples/s]

Map:   0%|          | 0/2557 [00:00<?, ? examples/s]

Map:   0%|          | 0/383 [00:00<?, ? examples/s]

Map:   0%|          | 0/383 [00:00<?, ? examples/s]

In [None]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=2
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=1,
    collate_fn=collate_fn,
    num_workers=2
)

In [None]:
for batch in train_loader:
    input_values = batch["input_values"]
    print(f"Batch Shape:", input_values.shape)
    break

  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]
  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]


Batch Shape: torch.Size([1, 71148])


In [None]:
torch.cuda.empty_cache()

In [None]:
!nvidia-smi

Mon Nov 18 02:56:46 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P8              11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model

class Wav2Vec2CTC(nn.Module):
  def __init__(self, model, vocab_size):
    super(Wav2Vec2CTC, self).__init__()
    self.feature_extractor = model
    self.ctc_head = nn.Linear(self.feature_extractor.config.hidden_size, vocab_size)

  def forward(self, input_values):
    features = self.feature_extractor(input_values).last_hidden_state

    logits = self.ctc_head(features)
    return logits

wav2vec2_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")


# wav2vec2_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53")

model = Wav2Vec2CTC(wav2vec2_model, vocab_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)



Wav2Vec2CTC(
  (feature_extractor): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (en

In [None]:
import torch.optim as optim

blank_index = char_to_index["<blank>"]
print(f"Blank index: {blank_index}")

criterion = nn.CTCLoss(blank=blank_index, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

Blank index: 43


In [None]:
torch.cuda.empty_cache()

In [None]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0

        for batch in train_loader:
            inputs = batch["input_values"].to(device)
            labels = batch["labels"].to(device)
            # inputs = batch["input_values"]
            # labels = batch["labels"]

            # Forward pass through the model
            logits = model(inputs)
            logits = logits.log_softmax(2).permute(1, 0, 2)  # Shape: (seq_len, batch, vocab_size)

            # Calculate input lengths based on model output
            input_lengths = torch.full((logits.size(1),), logits.size(0), dtype=torch.long)
            label_lengths = torch.sum(labels != -100, dim=1)

            # Zero gradients
            optimizer.zero_grad()

            # Compute the loss
            loss = criterion(logits, labels, input_lengths, label_lengths)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

# Run the training
train_model(model, train_loader, criterion, optimizer)

  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]
  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]


Epoch 1/5, Loss: 2.936726912139308


  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]
  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]


Epoch 2/5, Loss: 2.8926950453965983


  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]
  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]


Epoch 3/5, Loss: 2.8898955950477787


  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]
  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]


Epoch 4/5, Loss: 2.8900353307130984


  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]
  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]


Epoch 5/5, Loss: 2.8899185226725375


In [None]:
!pip install evaluate jiwer

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting jiwer
  Downloading jiwer-3.0.5-py3-none-any.whl.metadata (2.7 kB)
Collecting rapidfuzz<4,>=3 (from jiwer)
  Downloading rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jiwer-3.0.5-py3-none-any.whl (21 kB)
Downloading rapidfuzz-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m40.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer, evaluate
Successfully installed evaluate-0.4.3 jiwer-3.0.5 rapidfuzz-3.10.1


In [None]:
from evaluate import load
import torch

wer_metric = load("wer")

def evaluate_model(model, valid_loader, index_to_char):
    model.eval()
    total_wer = 0
    num_batches = 0

    with torch.no_grad():
        for batch in valid_loader:
            inputs = batch["input_values"].to(device)
            labels = batch["labels"].to(device)

            # Forward pass through model to get logits
            logits = model(inputs).log_softmax(2)

            # Get predicted IDs from the logits
            predicted_ids = torch.argmax(logits, dim=-1)

            # Decode predictions and labels
            pred_texts = ["".join([index_to_char[i] for i in pred if i in index_to_char]) for pred in predicted_ids]
            label_texts = ["".join([index_to_char[i] for i in label if i in index_to_char]) for label in labels]

            # Debugging: Print out a few predictions and labels
            print("\nPredicted Texts:", pred_texts[:3])
            print("Label Texts:", label_texts[:3])

            # Check for empty labels
            non_empty_indices = [i for i, label in enumerate(label_texts) if label]
            if not non_empty_indices:
                print("Skipping batch due to empty labels.")
                continue

            # Filter out empty references
            pred_texts = [pred_texts[i] for i in non_empty_indices]
            label_texts = [label_texts[i] for i in non_empty_indices]

            # Compute WER for the current batch
            wer = wer_metric.compute(predictions=pred_texts, references=label_texts)
            print(f"Batch WER: {wer:.4f}")

            total_wer += wer
            num_batches += 1

    avg_wer = total_wer / num_batches if num_batches > 0 else float("inf")
    print(f"\nValidation WER: {avg_wer:.4f}")

# Run the evaluation
evaluate_model(model, valid_loader, index_to_char)

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]
  labels = [torch.tensor(sample["labels"], dtype=torch.long) for sample in batch]



Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

Predicted Texts: ['']
Label Texts: ['']
Skipping batch due to empty labels.

In [None]:
print("Inspecting validation dataset after preprocessing:")
for i in range(5):
    input_values = valid_dataset[i]['input_values']
    labels = valid_dataset[i]['labels']
    print(f"Sample {i+1}:")
    print(f"Input Shape: {input_values.shape if isinstance(input_values, torch.Tensor) else 'Not a tensor'}, Labels Length: {len(labels)}")
    print(f"Decoded Labels: {''.join([index_to_char[c] for c in labels if c in index_to_char])}")
    print("\n")

Inspecting validation dataset after preprocessing:
Sample 1:
Input Shape: torch.Size([57354]), Labels Length: 25
Decoded Labels: 


Sample 2:
Input Shape: torch.Size([115071]), Labels Length: 90
Decoded Labels: 


Sample 3:
Input Shape: torch.Size([52635]), Labels Length: 12
Decoded Labels: 


Sample 4:
Input Shape: torch.Size([149193]), Labels Length: 71
Decoded Labels: 


Sample 5:
Input Shape: torch.Size([83127]), Labels Length: 46
Decoded Labels: 


