# AI4D Baamtu Datamation - Automatic Speech Recognition in WOLOF: Third Place Solution
By Geoffrey Frost, Matthew Baas and Kevin Eloff

**Disclaimer:** Much of this notebook was derived from this Hugging Face [blog post](https://www.google.com/search?q=hugging+faceblog+wav2vec2&client=ubuntu&hs=AJZ&channel=fs&sxsrf=ALeKk03DKQIae5Klbf2ezEi7kP0cv4hXow%3A1621943696618&ei=kOWsYKuoJaWEhbIP4t2GsAI&oq=hugging+faceblog+wav2vec2&gs_lcp=Cgdnd3Mtd2l6EAM6BwgAEEcQsAM6BwgjELACECc6BggAEA0QCjoECAAQDToGCAAQDRAeOgcIIRAKEKABUPQgWOAwYPMxaAFwAngBgAG9AogBlxuSAQYyLTEyLjGYAQCgAQGqAQdnd3Mtd2l6yAEIwAEB&sclient=gws-wiz&ved=0ahUKEwirsIeP4-TwAhUlQkEAHeKuASYQ4dUDCA0&uact=5) by Patrick von Platen.

## Imports
TODO: Talk about version info here

In [1]:
from pathlib import Path
import os
import warnings
import random

import pandas as pd
from datasets import ClassLabel, Dataset, DatasetDict
import librosa

import re
import json

import torch
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2Processor

## Data preperation
- Load in Train.csv as pandas dataframes
- Remove bad transcriptions from training data
- Remove unneccsery cloumns
- Add audio file paths
- Convert both dataframes to a DatasetDict object (makes working with Hugging Face easier)
- Preproccess training transcriptions (remove some special tokens that are hard for the AM model to learn)
- Create char vocab dict and correspinding JSON file for the Wav2Vec2Processor 

In [3]:
wolof_train = pd.read_csv('Train.csv')

In [5]:
# remove bad transcriptions from df
wolof_train = wolof_train[wolof_train.down_votes == 0]

In [6]:
wolof_train = wolof_train.drop(["up_votes", "down_votes", "age", "gender"], axis = 1)

In [7]:
# Create dict of ID's and file paths
paths = {}
path = 'clips/'
for root, dirs, files in os.walk(path):
    for file in files:
        if file.endswith(".mp3"):
            audio_path = os.path.join(root, file)
            p = Path(audio_path)
            id = p.parts[-1].split('.')[0]
            paths[id] = audio_path

In [8]:
wolof_train['file'] = [paths[id] for id in list(wolof_train['ID'])]
wolof_test['file'] = [paths[id] for id in list(wolof_test['ID'])]
wolof_test = wolof_test.drop(["ID"], axis = 1)
wolof_train = wolof_train.drop(["ID"], axis = 1)

In [9]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [11]:
wolof = DatasetDict({'train': Dataset.from_pandas(wolof_train[0:6000]), 'test': Dataset.from_pandas(wolof_train[6000:])})

In [12]:
wolof_test = DatasetDict({'test': Dataset.from_pandas(wolof_test)})

In [24]:
chars_to_ignore_regex = '[\"\?\.\!\-\;\:\(\)\,]'

def remove_special_characters(batch):
    batch["transcription"] = re.sub(chars_to_ignore_regex, '', batch["transcription"]).lower() + " "
    return batch

In [25]:
wolof = wolof.map(remove_special_characters)

HBox(children=(FloatProgress(value=0.0, max=6000.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=251.0), HTML(value='')))




In [26]:
show_random_elements(wolof['train'], num_examples=20)

Unnamed: 0,__index_level_0__,file,transcription
0,2034,clips/4ccea33037df8cbbe5b2bf5e3fd220c9b67be2256a42eda02feda16fb0a7b1ab096480d0ced0283493754d3fa98cba79c2e86013fabd79b737bc8347ea431adf.mp3,orca
1,229,clips/097c3cfeb70b3a4d89422b8748e3649fefa4f06b680dcf38d9f6d719b4879a39cc2bc67c562dce96e5a3481ca83c611378583991d3837041444901578a596a0a.mp3,avenue el mansour sy
2,3263,clips/7de2e7fca8b788bdb9f920acc363757d55be040109ad725445f999ccaea8ddaa4dab33390c2e4c1c17877afbac6856f5a291a52950a4b24abdc72f2f42db5f48.mp3,rond point jet d'eau
3,2539,clips/60ec84e2a019a78a21553bf483d647b9284dd0a1f009d79ae411e1c15eae6ef325b216e8ad1ed2c9ea43f383a9f2ac8ea4d77d3d676905f489f993de39c7b439.mp3,scat urbam
4,1580,clips/3c21d8c0a9f6d8985ce43936fb3a86779a87c009c380260565cfc75dced69ee586f9b7786841208ace6dba985ae039c32f93ca4fbbbc80d8aea316b1ec618da1.mp3,bargny extension
5,1912,clips/483f16fa03936f2826047ad0c5c6fa4ec657a90ab96517d068a15f28d703301fe17dd10560a206ff56cddd33b3450c763daab3f949b4fcb89b02dd5301b6d5f9.mp3,gare feroviére de dakar
6,3048,clips/7512bf7a889851f86c0e6d57220754c29f1263f581662f3ba5ea0be1f92d30061262455b9a5535c8fcc656582bd29051d6c2cffa5c838be896dfb3adf8d5501d.mp3,ndeureuhlou
7,5302,clips/ca0d2826467873315e2f90d1ade0831cf298d560bc94e61215ac6875f1ea659086a43796f97b8577f746c6b06ef83b8c4c0f6560efee375e2e0c1dc02bb6ad5f.mp3,marché yeumbeul
8,5219,clips/c74752dd88feaf5c20bd8f69c9d0ece4c37c256d6de43788d7ea7cc8f14f10627d52759e9d12982e465d17e076d80897dd1229cdeed221e5f8192bdcd2f3519b.mp3,sococim
9,4731,clips/b50f16524d06796d00934ba606fd819b82146440eb6d143e20d1dcb9c15262194a63cecede7de4b50aa474f29ab0ed9b036759c449322a8ff8b7c3d94530a395.mp3,cité gendarmerie


In [27]:
def extract_all_chars(batch):
    all_text = " ".join(batch["transcription"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

In [29]:
vocabs = wolof.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=wolof.column_names["train"])

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [30]:
vocab_list = list(set(vocabs["train"]["vocab"][0]))

In [32]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'j': 0,
 'œ': 1,
 'u': 2,
 'd': 3,
 'ô': 4,
 'a': 5,
 'â': 6,
 'y': 7,
 'î': 8,
 'ë': 9,
 's': 10,
 'w': 11,
 'ç': 12,
 'p': 13,
 'k': 14,
 'z': 15,
 'b': 16,
 't': 17,
 "'": 18,
 'é': 19,
 'm': 20,
 'q': 21,
 'x': 22,
 'h': 23,
 'è': 24,
 ' ': 25,
 'c': 26,
 'n': 27,
 'v': 28,
 'i': 29,
 'g': 30,
 '’': 31,
 'o': 32,
 'r': 33,
 'e': 34,
 'f': 35,
 'l': 36}

In [33]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [34]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

39

In [94]:
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

## Initializing for wav2vec2
- Create Wav2Vec2CTCTokenizer using the recently created vocab JSON file
- Create Wav2Vec2FeatureExtractor
- Read in audio files and add them to the DatasetDict
- Process inputs and targets 
- Define Data Collator
- Load and define WER metric
- Load the base wav2vec2-large-xlsr-53 model

In [95]:
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|", additional_special_tokens = extra_vocab_list)

In [97]:
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

In [98]:
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [100]:
#import soundfile as sf
import librosa
import warnings

warnings.filterwarnings("ignore")

def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = librosa.load(batch["file"], sr = 16000)
    batch["speech"] = speech_array.astype('float16')
    batch["sampling_rate"] = sampling_rate
    batch["target_text"] = batch["transcription"]
    return batch

In [None]:
wolof = wolof.map(speech_file_to_array_fn, remove_columns=wolof.column_names["train"], num_proc=1)

HBox(children=(FloatProgress(value=0.0, max=6000.0), HTML(value='')))

In [45]:
def prepare_dataset(batch):
    # check that all files have the correct sampling rate
    assert (
        len(set(batch["sampling_rate"])) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch

wolof_prepared = wolof.map(prepare_dataset, remove_columns=wolof.column_names["train"], batch_size=32, batched=True)

In [47]:
show_random_elements(wolof_prepared['train'], num_examples=20)

Unnamed: 0,input_values,labels
0,"[-0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, -0.001213734048021106, ...]","[17, 31, 16, 30, 13, 11, 31, 18, 16, 8, 13, 14, 19, 20, 31, 33, 13]"
1,"[-0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, -0.00021853389636763854, ...]","[8, 14, 17, 34, 18, 16, 10, 19, 13, 12, 33, 20, 13, 17, 31, 19, 14, 13]"
2,"[-2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, -2.7819283846619815e-05, ...]","[33, 17, 17, 14, 8, 13, 14, 8, 31, 18, 12, 14, 13]"
3,"[0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, 0.00023669257760074424, ...]","[19, 33, 34, 10, 13, 34, 10, 16, 18, 20, 18, 11, 33, 12, 14, 13, 21, 17, 33, 16, 30, 7, 31, 1, 1, 13]"
4,"[-0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, -0.0009510048703075828, ...]","[20, 18, 8, 23, 13, 14, 12, 13, 28, 33, 30, 4, 18, 13, 18, 15, 17, 33, 28, 18, 34, 33, 13, 16, 18, 33, 19, 19, 14, 13]"
5,"[-0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, -0.0002308140028933689, ...]","[20, 18, 8, 23, 13, 11, 31, 12, 18, 20, 14, 13]"
6,"[0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, 0.0005393354025109129, ...]","[15, 33, 16, 13, 15, 10, 19, 13, 34, 31, 31, 7, 13, 4, 33, 33, 17, 13, 14, 20, 31, 12, 14, 13, 20, 33, 19, 8, 31, 17, 13, 17, 10, 1, 18, 19, 5, 10, 14, 13]"
7,"[-0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, -0.00028954084964580274, ...]","[20, 33, 1, 23, 8, 23, 17, 18, 33, 13]"
8,"[0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, 0.00017815115478002586, ...]","[19, 18, 20, 33, 11, 13, 15, 33, 31, 15, 33, 15, 13]"
9,"[-0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, -0.0013596487697993516, ...]","[14, 20, 31, 12, 14, 13, 12, 14, 19, 13, 11, 23, 30, 33, 21, 31, 21, 10, 14, 19, 13]"


In [48]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [49]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [50]:
from datasets import load_metric
wer_metric = load_metric("wer")

In [51]:
import numpy as np
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [52]:
# Base model
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53", 
    gradient_checkpointing=True, 
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53 and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
model.freeze_feature_extractor()

In [55]:
torch.cuda.is_available()

True

## Train model
What else to add?

In [56]:
# KEVIN NEEDS TO ADD THE CORRECT HYPER PARAMS FOR "MORE POWER"
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./wav2vec2-wolof",
  group_by_length=True,
  per_device_train_batch_size=4,
  evaluation_strategy="steps",
  num_train_epochs=25,
  fp16=True,
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
)

In [57]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=wolof_prepared["train"],
    eval_dataset=wolof_prepared["test"],
    tokenizer=processor.feature_extractor,
)

In [59]:
trainer.train()

RuntimeError: blank must be in label range

In [99]:
model.save_pretrained('wav2vec-wolof-model')
processor.save_pretrained('wav2vec-wolof-processor')

## Make validation set predictions

In [1]:
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained('./sub1/wav2vec-wolof-trained-model').to('cuda')
processor = Wav2Vec2Processor.from_pretrained("./sub1/wav2vec-wolof-trained-processor")

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


In [47]:
def map_to_result(batch):
    model.to("cuda")
    input_values = processor(
      batch["speech"], 
      sampling_rate=batch["sampling_rate"], 
      return_tensors="pt"
    ).input_values.to("cuda")

    with torch.no_grad():
        logits = model(input_values).logits

        pred_ids = torch.argmax(logits, dim=-1)
        batch["pred_str"] = processor.batch_decode(pred_ids)[0]

    return batch

In [48]:
results = wolof["test"].map(map_to_result)

HBox(children=(FloatProgress(value=0.0, max=251.0), HTML(value='')))




In [49]:
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["target_text"])))

Test WER: 0.145


In [50]:
show_random_elements(results.remove_columns(["speech", "sampling_rate"]))

Unnamed: 0,pred_str,target_text
0,diakay peulh fouta,diaakay peulh fouta
1,croisement camberéne,croisement camberéne
2,dama bëgg dem lymodac,dama bëgg dem lymodac
3,ecole bachir,ecole bachir
4,tigo almadies,tigo almadies
5,avenue peytavin,avenue peytavin
6,cité djily m'baye,cité djily m'baye
7,marché yeumbeul laa bëgg dem,marché yeumbeul laa bëgg dem
8,sicap mbao,sicap mbao
9,ban oto mooy dem saveurs d'asie,ban oto mooy dem saveurs d'asie


# Make Test set AM model predictions
- Load in wav2vec2 model and processor
- Do all the same data prep we did for training

In [4]:
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained('./wav2vec-wolof-trained-model-large').to('cuda')
processor = Wav2Vec2Processor.from_pretrained("./wav2vec-wolof-trained-processor-large")

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


In [None]:
wolof_test = pd.read_csv('Test.csv')
wolof_test = wolof_test.drop(["up_votes", "down_votes", "age", "gender"], axis = 1)

In [41]:
wolof_test['file'] = [paths[id] for id in list(wolof_test['ID'])]

In [54]:
wolof = DatasetDict({'test': Dataset.from_pandas(wolof_test)})

In [56]:
import librosa
import warnings

warnings.filterwarnings("ignore")
def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = librosa.load(batch["file"], sr = 16000)
    batch["speech"] = speech_array.astype('float16')
    batch["sampling_rate"] = sampling_rate
    return batch

In [None]:
wolof = wolof.map(speech_file_to_array_fn, num_proc=1)

HBox(children=(FloatProgress(value=0.0, max=1564.0), HTML(value='')))

In [6]:
import pickle
infile = open('wolof_test.pkl','rb')
wolof = pickle.load(infile)

In [7]:
def map_to_result(batch):
    model.to("cuda")
    input_values = processor(
      batch["speech"], 
      sampling_rate=batch["sampling_rate"], 
      return_tensors="pt"
    ).input_values.to("cuda")

    with torch.no_grad():
        logits = model(input_values).logits

        pred_ids = torch.argmax(logits, dim=-1)
        batch["pred_str"] = processor.batch_decode(pred_ids)[0]
  
    return batch

In [8]:
import torch
results = wolof["test"].map(map_to_result)

HBox(children=(FloatProgress(value=0.0, max=1564.0), HTML(value='')))




In [12]:
show_random_elements(results.remove_columns(["speech", "sampling_rate", "file", "ID"]))

Unnamed: 0,pred_str
0,pont mariste
1,centre socio culturel wakhinane nimzatt
2,pharmacie sedina issa laye cambéréne
3,rond point les grands mounues de dakar
4,grand yoff
5,vdn
6,marché zingu
7,parmerie pharmacie akhaya ouakam
8,mairiede parcelle assainies
9,dakar nave


In [13]:
sub = results.remove_columns(["speech", "sampling_rate", "file"])

In [16]:
sub.rename_column("pred_str", "Transcription")

Dataset({
    features: ['ID', 'Transcription'],
    num_rows: 1564
})

In [40]:
sub_df = sub.to_pandas()

In [43]:
# This ID wasn't in our Test.csv so we had to add it manually. It was a blank audio file anyway.
sub_df = sub_df.append({"ID":"e3a74a8998f03c320f5a4923272247485832b1cd803528f5eb5a50aef3d29a78b436b3ea37c47763e9b9be8b3ee53435b51d3466345217ce5d6fcb9b48a53c63",  "pred_str":" "}, ignore_index=True)

In [44]:
sub_df

Unnamed: 0,ID,pred_str
0,00416cff4f818d3dfd99c9178ff0e268e7575500c8baa5...,africatel avs
1,00891ba561e80e135f9d12b9fa1347f0a2560998f7ea16...,nan laay def ngir dem tally bou bess
2,00a508027ed4edf0bd3db79f45f4ed6e1b89fba6482c10...,africatel avs
3,00ac13cd0d93e35c1ff672cc106ad94d1ea9b93fcf049a...,mosquée de cambérène
4,00c2d5baf4719bf01b990a8924e99bda043cd462147193...,cité safco tivaoune peulh
...,...,...
1560,ff1808218a15fa576c405314e4de4bda56c44f849ff1b5...,tigo almadies
1561,ff5b9a45d60600e875e0a031b1d7076c9cbdeb1c48c09c...,gouy gui grand mbao
1562,ff98e108ec61d3bd485734b83f21be77820549dab1cac1...,pharmacie rokaa ouakam
1563,ffb6873f183e8995e50d1079f60f8e9d1018092e421578...,orabank


In [54]:
sub_df.to_csv("AM_model_preds.csv", index=False)

## Sentence-level autocorect
TODO