In [1]:
pip install transformers datasets librosa torch torchaudio scipy

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import pandas as pd
from datasets import Dataset, load_metric
import librosa
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, TrainingArguments, Trainer
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union

# Load the CSV file
csv_file_path = "/home/5489/Downloads/archive/medical speech transcription and intent/Medical Speech, Transcription, and Intent/overview-of-recordings.csv"
df = pd.read_csv(csv_file_path)

# Paths to the folders
train_folder = "/home/5489/Downloads/archive/medical speech transcription and intent/Medical Speech, Transcription, and Intent/recordings/train"
test_folder = "/home/5489/Downloads/archive/medical speech transcription and intent/Medical Speech, Transcription, and Intent/recordings/test"
validate_folder = "/home/5489/Downloads/archive/medical speech transcription and intent/Medical Speech, Transcription, and Intent/recordings/validate"

# Function to get file paths from folder
def get_file_paths(folder):
    return [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".wav")]

# Get file paths for each split
train_files = get_file_paths(train_folder)
test_files = get_file_paths(test_folder)
validate_files = get_file_paths(validate_folder)

# Function to create dataset from file paths
def create_dataset(file_paths, df):
    data = {"audio": [], "transcription": []}
    for file_path in file_paths:
        filename = os.path.basename(file_path)
        transcription = df[df["file_name"] == filename]["phrase"].values[0]
        data["audio"].append(file_path)
        data["transcription"].append(transcription)
    return Dataset.from_dict(data)

# Create datasets
train_dataset = create_dataset(train_files, df)
test_dataset = create_dataset(test_files, df)
validate_dataset = create_dataset(validate_files, df)

# Preprocess the datasets
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

def preprocess(batch):
    audio, _ = librosa.load(batch["audio"], sr=16000)
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    with processor.as_target_processor():
        batch["labels"] = processor(batch["transcription"]).input_ids
    return batch

train_dataset = train_dataset.map(preprocess, remove_columns=["audio", "transcription"])
test_dataset = test_dataset.map(preprocess, remove_columns=["audio", "transcription"])
validate_dataset = validate_dataset.map(preprocess, remove_columns=["audio", "transcription"])

# Data collator
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # Replace padding with -100 to ignore in loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor)

# Initialize the Wav2Vec2 model
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)

# Define metrics
wer_metric = load_metric("wer")

def compute_metrics(pred):
    pred_logits = torch.tensor(pred.predictions)
    pred_ids = torch.argmax(pred_logits, dim=-1)
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

# Clear GPU cache
torch.cuda.empty_cache()

# Training arguments
training_args = TrainingArguments(
    output_dir="/home/5489/checkpoints",
    group_by_length=True,
    per_device_train_batch_size=2,  # Further reduced batch size
    per_device_eval_batch_size=2,   # Further reduced batch size
    gradient_accumulation_steps=8,  # Increased gradient accumulation steps
    evaluation_strategy="steps",
    num_train_epochs=10,
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=3e-4,
    warmup_steps=500,
    save_total_limit=2,
    fp16=True,  # Enable mixed precision training
    push_to_hub=False,
)

# Trainer
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=validate_dataset,
    tokenizer=processor.feature_extractor,
)

# Train the model
trainer.train()

# Evaluate the model
results = trainer.evaluate(test_dataset)
print(results)

# Save the model and processor
model.save_pretrained("/home/5489/checkpoints/wav2vec2-medical")
processor.save_pretrained("/home/5489/checkpoints/wav2vec2-medical")


2024-06-11 21:14:02.037592: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-11 21:14:02.037629: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-11 21:14:02.037659: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-11 21:14:02.044379: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Map:   0%|          | 0/5895 [00:00<?, ? examples/s]



Map:   0%|          | 0/381 [00:00<?, ? examples/s]

Map:   0%|          | 0/385 [00:00<?, ? examples/s]

Some weights of the model checkpoint at facebook/wav2vec2-base-960h were not used when initializing Wav2Vec2ForCTC: ['wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.masked_spec_embed', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You sho

  0%|          | 0/3680 [00:00<?, ?it/s]



{'loss': 2.7596, 'learning_rate': 0.00029039999999999996, 'epoch': 1.36}


  0%|          | 0/193 [00:00<?, ?it/s]

{'eval_loss': 0.9946462512016296, 'eval_wer': 1.0, 'eval_runtime': 14.4768, 'eval_samples_per_second': 26.594, 'eval_steps_per_second': 13.332, 'epoch': 1.36}




{'loss': 1.011, 'learning_rate': 0.00025547169811320755, 'epoch': 2.71}


  0%|          | 0/193 [00:00<?, ?it/s]

{'eval_loss': 1.246246099472046, 'eval_wer': 1.0, 'eval_runtime': 14.6166, 'eval_samples_per_second': 26.34, 'eval_steps_per_second': 13.204, 'epoch': 2.71}




{'loss': 1.1005, 'learning_rate': 0.0002093396226415094, 'epoch': 4.07}


  0%|          | 0/193 [00:00<?, ?it/s]

{'eval_loss': 1.7169311046600342, 'eval_wer': 1.0, 'eval_runtime': 14.2145, 'eval_samples_per_second': 27.085, 'eval_steps_per_second': 13.578, 'epoch': 4.07}




{'loss': 1.157, 'learning_rate': 0.00016339622641509433, 'epoch': 5.43}


  0%|          | 0/193 [00:00<?, ?it/s]

{'eval_loss': 1.7169311046600342, 'eval_wer': 1.0, 'eval_runtime': 14.7387, 'eval_samples_per_second': 26.122, 'eval_steps_per_second': 13.095, 'epoch': 5.43}




{'loss': 1.1389, 'learning_rate': 0.00011698113207547169, 'epoch': 6.78}


  0%|          | 0/193 [00:00<?, ?it/s]

{'eval_loss': 1.7169311046600342, 'eval_wer': 1.0, 'eval_runtime': 14.2329, 'eval_samples_per_second': 27.05, 'eval_steps_per_second': 13.56, 'epoch': 6.78}




{'loss': 1.1519, 'learning_rate': 7.113207547169811e-05, 'epoch': 8.14}


  0%|          | 0/193 [00:00<?, ?it/s]

{'eval_loss': 1.7169311046600342, 'eval_wer': 1.0, 'eval_runtime': 14.4107, 'eval_samples_per_second': 26.716, 'eval_steps_per_second': 13.393, 'epoch': 8.14}




{'loss': 1.1458, 'learning_rate': 2.4905660377358486e-05, 'epoch': 9.5}


  0%|          | 0/193 [00:00<?, ?it/s]

{'eval_loss': 1.7169311046600342, 'eval_wer': 1.0, 'eval_runtime': 14.7081, 'eval_samples_per_second': 26.176, 'eval_steps_per_second': 13.122, 'epoch': 9.5}




{'train_runtime': 3152.1999, 'train_samples_per_second': 18.701, 'train_steps_per_second': 1.167, 'train_loss': 1.341972243267557, 'epoch': 9.99}


  0%|          | 0/191 [00:00<?, ?it/s]

{'eval_loss': 1.8148024082183838, 'eval_wer': 1.0, 'eval_runtime': 13.6137, 'eval_samples_per_second': 27.987, 'eval_steps_per_second': 14.03, 'epoch': 9.99}
