In [25]:
import torch
import numpy as np
from itertools import product
import string

# Load the complete pickled model object onto the CPU.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects -- only run this on a checkpoint from a trusted source.
model_path = "ner_trained_model_window_size_30.pt"
model = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)

# Character-convolution kernels as a NumPy array. The three-way unpack below
# requires the weight array to be 3-D.
filters = model.char_conv.weight.data.cpu().numpy()
out_channels, in_channels, kernel_size = filters.shape

# Variance of each filter's weights (taken over the last two axes); keep the
# indices of the five highest-variance filters, highest first.
filter_variances = filters.var(axis=(1, 2))
top_indices = np.argsort(filter_variances)[::-1][:5]

# Character table used to index the embedding matrix: a-z followed by
# '-', '.', "'" and cut to at most vocab_size entries.
# NOTE(review): this assumes embedding row i corresponds to the i-th character
# of this list (no padding/unknown row, same ordering as at training time) --
# confirm against the vocabulary the model was trained with.
vocab_size = model.char_embeddings.weight.shape[0]
char_embedding_weights = model.char_embeddings.weight.data.cpu().numpy()
all_possible_chars = [*string.ascii_lowercase, '-', '.', "'"]
char_set = all_possible_chars[:vocab_size]
char_to_idx = dict(zip(char_set, range(len(char_set))))

def get_trigram_response(trigram, char_to_idx, char_embedding_weights, filter_weights):
    """Return the raw response of one convolution filter to a character n-gram.

    Each character is looked up in ``char_to_idx``, its embedding row is taken
    from ``char_embedding_weights``, and the stacked embeddings are transposed
    so that positions run along the last axis. The response is the sum of the
    element-wise product with ``filter_weights``; no bias term is added.

    Args:
        trigram: Iterable of characters (a string or a tuple of 1-char strings).
        char_to_idx: Mapping from character to embedding row index. Characters
            missing from the mapping fall back to row 0.
        char_embedding_weights: 2-D array indexed by row as
            ``char_embedding_weights[idxs]``.
        filter_weights: 2-D array holding one filter's weights; it must have
            exactly the shape of the transposed embedding stack.

    Returns:
        The scalar produced by ``np.sum`` over the element-wise product.

    Raises:
        ValueError: If the transposed embedding stack and ``filter_weights``
            differ in shape (e.g. the n-gram length is not the kernel width).
    """
    idxs = [char_to_idx.get(c, 0) for c in trigram]
    emb = char_embedding_weights[idxs].T
    # Without this check a mismatched-but-broadcastable input (for example a
    # single character against a width-3 kernel) would be silently broadcast
    # and yield a number that is not a filter response at all.
    if emb.shape != np.shape(filter_weights):
        raise ValueError(
            f"embedding stack has shape {emb.shape}, "
            f"but filter_weights has shape {np.shape(filter_weights)}"
        )
    return np.sum(emb * filter_weights)

# Every ordered 3-character combination over char_set, built once and reused
# for each inspected filter.
all_trigrams = list(product(char_set, repeat=3))

for idx in top_indices:
    filter_weights = filters[idx]
    # Score every candidate trigram against this filter.
    responses = [
        (
            candidate,
            get_trigram_response(candidate, char_to_idx, char_embedding_weights, filter_weights),
        )
        for candidate in all_trigrams
    ]
    # Order by response magnitude, largest first; the sign is kept for display.
    responses = sorted(responses, key=lambda pair: -abs(pair[1]))
    print(f"\nTop 5 trigrams for filter {idx}:")
    for trigram, resp in responses[:5]:
        print(f"{''.join(trigram)} Response: {resp:.2f}")


Top 5 trigrams for filter 7:
ovx Response: -4.06
'vx Response: -3.92
tvx Response: -3.71
ovr Response: -3.67
onx Response: -3.63

Top 5 trigrams for filter 10:
er' Response: -4.43
es' Response: -4.42
gr' Response: -4.12
gs' Response: -4.10
wr' Response: -4.10

Top 5 trigrams for filter 0:
t'i Response: -4.36
t'. Response: -4.29
tji Response: -4.08
t'z Response: -4.05
tj. Response: -4.01

Top 5 trigrams for filter 9:
vzo Response: -5.32
zzo Response: -5.09
-zo Response: -4.95
bzo Response: -4.88
.zo Response: -4.79

Top 5 trigrams for filter 25:
oxg Response: 3.81
ovg Response: 3.74
ohg Response: 3.66
vyq Response: -3.55
odg Response: 3.49
