In [114]:
from collections import deque

import numpy as np
from pyxdameraulevenshtein import damerau_levenshtein_distance_seqs
from scipy.spatial.distance import cdist
from sklearn.metrics.pairwise import manhattan_distances
from tqdm import tqdm
def _read_stripped_lines(path):
    """Return every line of the text file at ``path`` with surrounding whitespace stripped."""
    with open(path, 'r') as handle:
        return [line.strip() for line in handle]


# Dictionary words (dict.txt) and queries to correct (queries.txt), one per line.
data = _read_stripped_lines('dict.txt')
queries = _read_stripped_lines('queries.txt')
print(len(data), len(queries))

# Sorted set of every character that occurs in the dictionary; a character's
# position in this string is its coordinate in the vectors built by get_vector.
alphabet = ''.join(sorted({char for word in data for char in word}))
def get_vector(word, alphabet):
    """Letter-count vector of ``word`` over ``alphabet``.

    Entry ``k`` of the returned list is the number of characters of ``word``
    whose first position in ``alphabet`` (as found by ``alphabet.index``) is ``k``.

    Args:
        word: String to vectorise.
        alphabet: Sequence of characters defining the vector coordinates.

    Returns:
        List of ints with ``len(alphabet)`` entries.

    Raises:
        ValueError: if ``word`` contains a character that is not in ``alphabet``.
            NOTE(review): in this notebook ``alphabet`` is built from dict.txt
            only, so a query with an unseen character would raise here.
    """
    counts = [0 for _ in range(len(alphabet))]
    for slot in map(alphabet.index, word):
        counts[slot] += 1
    return counts

# Letter-count matrix for the whole dictionary: row i is get_vector(data[i]),
# so row indices returned by the distance filters index straight into `data`.
vectorized_data = np.array(
    [get_vector(entry, alphabet) for entry in data]
)

# scikit-learn variant of the L1 (Manhattan) candidate filter.
def get_words_with_distance(query_vector, vectorized_data, threshold):
    """Return indices of rows of ``vectorized_data`` within L1 distance ``threshold``.

    Counterpart of ``get_words_with_distance_scipy`` that computes the
    Manhattan distances with ``sklearn.metrics.pairwise.manhattan_distances``.

    Args:
        query_vector: Query vector (list or array); reshaped into a single row.
        vectorized_data: 2-D array with one row per candidate vector.
        threshold: Inclusive upper bound on the L1 distance.

    Returns:
        Ascending list of int row indices whose distance is ``<= threshold``.
    """
    query_row = np.asarray(query_vector).reshape(1, -1)
    l1_to_query = manhattan_distances(query_row, vectorized_data).ravel()
    return np.flatnonzero(l1_to_query <= threshold).tolist()

# SciPy variant of the L1 (Manhattan / "cityblock") candidate filter.
def get_words_with_distance_scipy(query_vector, vectorized_data, threshold):
    """Return indices of rows of ``vectorized_data`` within L1 distance ``threshold``.

    Args:
        query_vector: Query vector (list or array); reshaped into a single row.
        vectorized_data: 2-D array with one row per candidate vector.
        threshold: Inclusive upper bound on the city-block distance.

    Returns:
        Ascending list of int row indices whose distance is ``<= threshold``.
    """
    as_row = np.asarray(query_vector).reshape(1, -1)
    # cdist yields a (1, n) matrix for a single query row; unpack its only row.
    (city_block,) = cdist(as_row, vectorized_data, metric='cityblock')
    within = city_block <= threshold
    return np.nonzero(within)[0].tolist()


62027 100000


In [None]:
def _candidate_edits(current, target):
    """One-edit neighbours of ``current`` steered towards ``target``, deduplicated in a fixed order.

    Order: deletions, insertions, substitutions, transpositions (first
    occurrence wins). A list is used instead of a set so that iteration order
    does not depend on per-process string-hash randomisation.
    """
    length = len(current)
    target_len = len(target)
    ordered = []

    # Deletions: any single character may be removed.
    ordered.extend(current[:i] + current[i + 1:] for i in range(length))

    # Insertions: only the character `target` has at the insertion index.
    ordered.extend(
        current[:i] + target[i] + current[i:]
        for i in range(min(length + 1, target_len))
    )

    # Substitutions: only where `current` differs from `target` at the same index.
    ordered.extend(
        current[:i] + target[i] + current[i + 1:]
        for i in range(min(length, target_len))
        if current[i] != target[i]
    )

    # Adjacent transpositions: only when the swap reproduces target's pair there.
    ordered.extend(
        current[:i] + current[i + 1] + current[i] + current[i + 2:]
        for i in range(length - 1)
        if i + 1 < target_len
        and current[i] == target[i + 1]
        and current[i + 1] == target[i]
    )

    return list(dict.fromkeys(ordered))


def find_correction_path(word1, word2, max_dist):
    """Find a shortest chain of single edits turning ``word1`` into ``word2``.

    Breadth-first search over strings where one step is a deletion, insertion,
    substitution or adjacent transposition. To keep the search small,
    insertions/substitutions only use the character ``word2`` has at the same
    index, and transpositions only swap a pair into ``word2``'s order at that index.

    Args:
        word1: Source word.
        word2: Target word.
        max_dist: Maximum number of edits to explore.

    Returns:
        ``"word1 0"`` if the words are equal;
        ``"word1 k w_1 ... w_k"`` where ``k`` is the number of edits on the path
        found (``k <= max_dist``) and ``w_1 ... w_k`` are the successive words,
        ending with ``word2``;
        ``"word1 {max_dist}+ "`` (note the trailing space) if no path of at most
        ``max_dist`` edits is found.
    """
    if word1 == word2:
        return f"{word1} 0"

    queue = deque([(word1, [word1], max_dist)])
    visited = {word1}

    while queue:
        current, path, remaining = queue.popleft()
        # `<= 0` (not `== 0`) so a negative budget cannot cause unbounded search.
        if remaining <= 0:
            continue

        for neighbor in _candidate_edits(current, word2):
            if neighbor in visited:
                continue
            # Goal test at generation time: BFS pops nodes in depth order, so the
            # first time word2 is generated it is at minimal depth, and the
            # (largest) last frontier never needs to be expanded.
            if neighbor == word2:
                steps = path[1:] + [neighbor]
                # Report the length of the path actually found, not the budget.
                return f"{word1} {len(steps)} {' '.join(steps)}"
            visited.add(neighbor)
            queue.append((neighbor, path + [neighbor], remaining - 1))

    return f"{word1} {max_dist}+ "

In [None]:
# Correct every query: prefilter dictionary words by letter-count (L1) distance,
# pick the closest survivor by Damerau-Levenshtein distance, and write one line
# per query (in query order) to results.txt.
#
# NOTE(review): an insertion or deletion changes the letter-count vector by 1 in
# L1, a substitution by at most 2 and a transposition by 0, so
# L1 <= 2 * edit_distance. An L1 threshold of 4 therefore only guarantees that
# words within edit distance 2 survive the prefilter; a query whose nearest word
# is at distance 3-4 can be reported as '5+' or matched to a farther word.
# Use a threshold of 8 if exact results matter (at a runtime cost).
with open('results.txt', 'w') as file:
    for query in tqdm(queries):
        vectorized_query = get_vector(query, alphabet)
        candidate_ids = get_words_with_distance_scipy(vectorized_query, vectorized_data, 4)
        if candidate_ids:
            candidates = [data[i] for i in candidate_ids]
            distances = damerau_levenshtein_distance_seqs(query, candidates)
            best = int(np.argmin(distances))
            match, dist = candidates[best], distances[best]
            correct_path = find_correction_path(query, match, dist)
            # Emit the corrected line; without this only the '5+' queries
            # would ever reach results.txt.
            file.write(f'{correct_path}\n')
        else:
            file.write(f'{query} 5+\n')

100%|██████████| 100000/100000 [06:02<00:00, 275.91it/s]
