In [63]:
import pickle
import json
import argparse
from tqdm import tqdm
from copy import deepcopy
from queue import PriorityQueue
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'


In [64]:
with open('data/data.pkl', 'rb') as fp:
        data = pickle.load(fp)
with open('data/phoneme_table.json', 'r') as fp:
        phenome_table = json.load(fp)
with open('data/vocabulary.json', 'r') as fp:
        vocabulary = json.load(fp)

In [68]:
class CostModel(object):
    def __init__(self) -> None:
        # Load Whisper model and processor
        self.__processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")
        self.__model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en").to(DEVICE)
        self.__audio_inputs = None

    def set_audio(self, audio, sampling_rate):
        self.__audio_inputs = self.__processor(
            audio, sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features.to(DEVICE)

    def get_loss(self, text):
        # Prepare the target text input IDs
        target = self.__processor(
            text=text, return_tensors="pt", padding=True
        ).input_ids.to(DEVICE)

        # Make sure to set the decoder input IDs
        with torch.no_grad():
            outputs = self.__model(input_features=self.__audio_inputs, labels=target)

        return outputs.loss.item()


class Environment(object):
    def __init__(self, init_state, cost_function, phoneme_table) -> None:
        self.init_state = init_state
        self.phoneme_table = deepcopy(phoneme_table)
        self.__cost_function = cost_function

    def compute_cost(self, text):
        # try:
        cost = self.__cost_function(text)
        # except:
        #     cost = 1e6
        return cost
replacement_lens = []
matrix = {}
for char in phenome_table:
	for replacement in phenome_table[char]:
		if(replacement not in matrix):
			matrix[replacement] = []
		matrix[replacement].append(char)

for rep in matrix:
    if(len(rep) not in replacement_lens):
        replacement_lens.append(len(rep))



In [69]:
audio = None
sr = None
text = None
pred = None
corrected_texts = []
cost_model = CostModel()
for sample in tqdm(data[:3]):
        audio = sample['audio']['array']
        sr = sample['audio']['sampling_rate']
        text = sample['text']
        cost_model.set_audio(audio, sr)
        # print(text)
        # environment = Environment(text, cost_model.get_loss, phenome_table)

        # # try:
        # agent.asr_corrector(environment)
        # pred = agent.best_state
        # except:
        #     pred = None
        corrected_texts.append(text)
        # break


100%|██████████| 3/3 [00:00<00:00, 154.03it/s]


In [71]:
cost_model.set_audio(audio,sr)
environment = Environment(text, cost_model.get_loss, phenome_table)

def cost(text):
    return environment.compute_cost(text)

start_state = "I CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE"
words = start_state.split()
beam = []
wbeams =[[word] for word in words ]
wpq =[PriorityQueue() for _ in words]
d = [{} for _ in words]
best_words = [None for _ in words]
def optimize_word(word,idx,epsilon,beam_size=30000,beam_depth=2,best_n=20,):
    best_word  = word
    global current_state
    best_cost = cost(current_state)
    initial_word = word
    inital_cost = best_cost
    wpq[idx].put((best_cost,best_word))
    d[idx][best_word] = best_cost
    f = 1 +epsilon*len(best_word)
    count = 0
    for depth in range(beam_depth):
        queue = PriorityQueue()
        new_beam = []
        for current_word in wbeams[idx]:
            for l in replacement_lens:
                for j in range(len(current_word)):
                    if(j+l>len(current_word)): continue
                    to_replace = current_word[j:j+l]
                    if to_replace in matrix:
                        for replacement in matrix[to_replace]:
                            new_word = current_word[:j]+replacement + current_word[j+l:]
                            words[idx] = new_word
                            new_sentence = ' '.join(words)
                            if(new_word not in d[idx]):
                                c = cost(new_sentence)
                                d[idx][new_word] = c
                                queue.put((c,new_word))
                                wpq[idx].put((c,new_word))
                                print((c,new_sentence))
                                count +=1
        next_beam_size = 0
        while((next_beam_size< beam_size) and not queue.empty()):
            word_cost,beam_word, = queue.get()
            if(word_cost<f*best_cost):
                new_beam.append(beam_word)
                next_beam_size+=1
                if(word_cost<best_cost):
                    best_cost = word_cost
                    best_word = beam_word
            else: break
        wbeams[idx] = new_beam
        print((count,initial_word,inital_cost,best_word,best_cost))
        if(count>=100 and best_word == initial_word):
            break

    ans = []
    a = wpq[idx].get()
    ans.append(a[1])
    u = []
    for _ in range(best_n):
        if not wpq[idx].empty():
            cw  = wpq[idx].get()
            if(a[0]*1.15> cw[0]):
                ans.append(cw[1])
            else: break
        else: break
    best_words[idx] = ans
    words[idx] = ans[0]
    current_state = ' '.join(words)
current_state = start_state 
for i,w in enumerate(words):
	optimize_word(w,i,beam_size=3000,beam_depth=3,epsilon=0.035,best_n=20)

# f = ' '.join(words)
# print(f)

(3.0743865966796875, "AE CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE")
(3.0368447303771973, "AI CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE")
(3.00717830657959, "IE CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE")
(3, 'I', 2.9487361907958984, 'I', 2.9487361907958984)
(3.2747488021850586, "AEE CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE")
(3.3451881408691406, "AIE CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE")
(3.2717819213867188, "IEE CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE")
(3.112644672393799, "IAE CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE")
(2.8874752521514893, "IAI CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE")
(2.996678590774536, "IEU CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF SOMEDING ELSE")
(3.1863222122192383, "IIE CAN'T ZAY WHEDER THERE IS A WILL OR NOT LET UZ TALK OF

In [78]:
beam_depth= 40
beam_size = 
for _ in range(beam_depth):
	prq = PriorityQueue()
		for cs in beam:
			cs_words = cs.split(' ')
			for i,word in enumerate(cs_words):
				for pos_rep in best_words[i]:
					new_sol = ' '.join(cs_words[:i]+[pos_rep]+cs_words[i+1:])
					c = cost(new_sol)
					prq.put((c, new_sol))
		next_beam = []
		for _ in range(beam_size):
			possol = prq.get()
			next_beam.append(possol[1])
			if (possol[0] < self.cost_fn(self.best_state)):
				self.best_cost = possol[0]
				self.best_state = possol[1]
		self.beam = next_beam

1.6277482509613037