## Introduction

Key steps:

1. Set data paths
2. PyTorch setting
3. Determine the device (i.e., cpu or cuda)
4. Encoding data
5. Split training, validation, and evaluation sets
6. Building and training model
7. Inference

1. Set data paths

In [None]:
# Label for this experiment; it names the output sub-directory so that
# different runs do not overwrite each other.
test_version = "test_v4.0"


# Input data locations (placeholders: replace with real values before running).
input_filename = '<your source data file>.json.gz'
project_data_path = r"<your project path>"
input_data_path = f"{project_data_path}/data/out/shared_across_models/{input_filename}"

tag_pool_file_path = '<your path to the tag pool file>.json'

# Output directory where processed data will be saved.
# `output_path` already ends with "/", so the sub-directory names are appended
# directly; adding another "/" would yield ".../test_v4.0//encoded".
output_path = f"{project_data_path}/data/out/bi_lstm/{test_version}/"
output_encoded_path = f"{output_path}encoded"
output_models_path = f"{output_path}models"


2. PyTorch setting

In [None]:
import os
# Set before torch is imported so the value is already in the environment when CUDA/cuBLAS
# start up. NOTE(review): PyTorch's reproducibility notes name ":16:8" and ":4096:8" as the
# workspace settings that permit deterministic cuBLAS — confirm for the installed CUDA version.
os.environ["CUBLAS_WORKSPACE_CONFIG"]= ":16:8"
import torch

# Request deterministic algorithm implementations from PyTorch.
torch.use_deterministic_algorithms(True)
# Disable the cuDNN autotuner and ask cuDNN for deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


3. Determine the device (i.e., cpu or cuda)

In [None]:
# Choose the compute device: the CUDA GPU with the most total memory, or the CPU
# when CUDA is unavailable.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")
else:
    # Report how many GPUs are visible to this process
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available: {num_gpus}")

    # Scan every GPU and remember the first one with the largest VRAM
    selected_gpu, max_vram = 0, 0
    for gpu_index in range(num_gpus):
        total_bytes = torch.cuda.get_device_properties(gpu_index).total_memory
        vram = total_bytes / (1024 ** 3)  # bytes -> GB
        print(f"GPU {gpu_index}: {vram:.2f} GB of VRAM")
        if vram > max_vram:
            selected_gpu, max_vram = gpu_index, vram

    # Bind the device handle to the winner
    device = torch.device(f"cuda:{selected_gpu}")
    print(f"Selected GPU {selected_gpu} with {max_vram:.2f} GB of VRAM")

4. Encoding data

In [None]:
import os
import json
# import torch
import tarfile
import shutil
from typing import Dict
from tqdm import tqdm
from collections import Counter


class BiLSTMAddressEncoder:
    """Encode tokenised, tagged address records into padded feature records.

    The pipeline (see `run`) loads a JSON list of records (optionally gzipped),
    builds token / lemma / tag vocabularies, encodes every record padded to the
    longest token sequence, and writes ``dataset.tar.gz`` plus JSON metadata
    into ``output_encoded_path``.

    Each input record is read as ``{"text": str, "tokens": [...]}``; a token
    carries ``start``/``end`` character offsets into ``text`` and may carry
    ``_lemma``, ``tags_std`` and ``tags_maybe``.

    Config keys:
        input_data_path (required): ``.json`` or ``.json.gz`` file to encode.
        output_encoded_path (required): directory for all outputs (created).
        batch_size (default 10000): number of records per progress-bar step.
        maybe_tag_set (optional): tags counted by the "maybe tag" positional
            features; defaults to the twelve address component tags below.
    """

    def __init__(self, config: Dict):
        self.input_data_path = config["input_data_path"]
        self.output_encoded_path = config["output_encoded_path"]

        os.makedirs(self.output_encoded_path, exist_ok=True)
        # Scratch directory for the un-archived dataset.pt; removed by cleanup_temp_dir().
        self.temp_dir = os.path.join(self.output_encoded_path, "temp_dir")
        os.makedirs(self.temp_dir, exist_ok=True)

        self.batch_size = config.get("batch_size", 10000)
        self.max_token_sequence_len = 0
        self.tag_vocab_size = 0

        self.maybe_tag_set = config.get('maybe_tag_set', {
            "unit_designator", "unit_number", "unit_number_suffix",
            "civic_number", "civic_number_suffix",
            "street_name", "street_type", "street_direction", "street_qualifier",
            "locality_name", "locality_type",
            "province_code",
        })

        self.token_vocab = {}
        self.lemma_vocab = {}
        self.tag_vocab = []
        self.tag_to_id = {}
        self.data = []

    def load_data(self):
        """Read the input JSON into ``self.data`` (gzip is used for ``.gz`` paths)."""
        if self.input_data_path.endswith(".gz"):
            import gzip
            open_fn = gzip.open
        else:
            open_fn = open
        with open_fn(self.input_data_path, "rt", encoding="utf-8") as f:
            self.data = json.load(f)

    def normalize(self, text):
        """Collapse numeric strings to ``__num__`` and mixed letter/digit strings to ``__alnum__``.

        Any other string (including the empty string) is returned unchanged.
        """
        if text.isdigit():
            return "__num__"
        if any(c.isalpha() for c in text) and any(c.isdigit() for c in text):
            return "__alnum__"
        return text

    @staticmethod
    def _real_tags(tok, key):
        """Return the tags stored under *key* on a token, without the "none" placeholder."""
        return [t for t in tok.get(key, []) if t != "none"]

    def _has_maybe_tag(self, tok):
        """Return True when any non-"none" ``tags_maybe`` entry of *tok* is in ``maybe_tag_set``."""
        return any(tag in self.maybe_tag_set for tag in self._real_tags(tok, "tags_maybe"))

    def build_vocabs(self):
        """Build the token, lemma and tag vocabularies from ``self.data``.

        Token and lemma ids start at 1 (sorted order) so that 0 stays free for
        padding / unknown entries; tag ids start at 0. Also records the longest
        token sequence seen, which becomes the padded length.
        """
        tokens_set, lemmas_set, tags_set = set(), set(), set()

        for item in self.data:
            text = item["text"]
            self.max_token_sequence_len = max(self.max_token_sequence_len, len(item["tokens"]))
            for tok in item["tokens"]:
                word = self.normalize(text[tok["start"]:tok["end"]].lower())
                # A token without "_lemma" falls back to its normalised surface form.
                lemma = self.normalize(tok.get("_lemma", word).lower())
                tokens_set.add(word)
                lemmas_set.add(lemma)

                tags_set.update(self._real_tags(tok, "tags_std"))
                tags_set.update(self._real_tags(tok, "tags_maybe"))

        self.token_vocab = {tok: i + 1 for i, tok in enumerate(sorted(tokens_set))}
        self.lemma_vocab = {lem: i + 1 for i, lem in enumerate(sorted(lemmas_set))}
        self.tag_vocab = sorted(tags_set)
        self.tag_to_id = {tag: i for i, tag in enumerate(self.tag_vocab)}
        self.tag_vocab_size = len(self.tag_vocab)

    def compute_num_prev_maybe_tags(self, tokens):
        """For each token, count the earlier tokens that carry a maybe tag."""
        result = []
        count = 0
        for tok in tokens:
            result.append(count)
            if self._has_maybe_tag(tok):
                count += 1
        return result

    def compute_num_next_maybe_tags(self, tokens):
        """For each token, count the later tokens that carry a maybe tag."""
        # Fill a pre-sized list from the right; inserting at index 0 would be quadratic.
        result = [0] * len(tokens)
        count = 0
        for i in range(len(tokens) - 1, -1, -1):
            result[i] = count
            if self._has_maybe_tag(tokens[i]):
                count += 1
        return result

    def compute_distance_to_prev_maybe_tag(self, tokens):
        """For each token, distance to the closest earlier maybe-tagged token (-1 if none)."""
        result = []
        last_index = -1
        for i, tok in enumerate(tokens):
            result.append(-1 if last_index == -1 else i - last_index)
            if self._has_maybe_tag(tok):
                last_index = i
        return result

    def compute_distance_to_next_maybe_tag(self, tokens):
        """For each token, distance to the closest later maybe-tagged token (-1 if none)."""
        result = [-1] * len(tokens)
        next_index = -1
        for i in range(len(tokens) - 1, -1, -1):
            if next_index != -1:
                result[i] = next_index - i
            if self._has_maybe_tag(tokens[i]):
                next_index = i
        return result

    def _encode_item(self, item, global_index):
        """Encode one record into a dict of lists padded to ``max_token_sequence_len``."""
        text = item["text"]
        tokens = item["tokens"]
        token_ids, labels, maybe_features, lemma_ids = [], [], [], []

        for tok in tokens:
            word_raw = text[tok["start"]:tok["end"]].lower()
            lemma_raw = tok.get("_lemma", word_raw).lower()

            # Id 0 is used for anything missing from the vocabularies.
            token_ids.append(self.token_vocab.get(self.normalize(word_raw), 0))
            lemma_ids.append(self.lemma_vocab.get(self.normalize(lemma_raw), 0))

            # Only the first standard tag becomes the label; -100 marks "no label".
            std_tags = self._real_tags(tok, "tags_std")
            labels.append(self.tag_to_id.get(std_tags[0], -100) if std_tags else -100)

            # Relative frequency of every vocabulary tag among this token's maybe tags.
            tag_counts = Counter(self._real_tags(tok, "tags_maybe"))
            total = sum(tag_counts.values())
            maybe_features.append([
                tag_counts.get(tag, 0) / total if total > 0 else 0.0
                for tag in self.tag_vocab
            ])

        pad_len = self.max_token_sequence_len - len(token_ids)
        return {
            "sequences": token_ids + [0] * pad_len,
            "labels": labels + [-100] * pad_len,
            "tags_maybe_features": maybe_features + [[0.0] * self.tag_vocab_size for _ in range(pad_len)],
            "lemma_features": lemma_ids + [0] * pad_len,
            "prev_maybe_counts": self.compute_num_prev_maybe_tags(tokens) + [-1] * pad_len,
            "next_maybe_counts": self.compute_num_next_maybe_tags(tokens) + [-1] * pad_len,
            "distance_to_prev_maybe": self.compute_distance_to_prev_maybe_tag(tokens) + [-1] * pad_len,
            "distance_to_next_maybe": self.compute_distance_to_next_maybe_tag(tokens) + [-1] * pad_len,
            "original_index": global_index
        }

    def encode_and_save_dataset(self):
        """Encode all records and write them as ``dataset.pt`` inside ``dataset.tar.gz``.

        The whole encoded dataset is held in memory; ``batch_size`` only sets
        the granularity of the progress bar.
        """
        num_batches = (len(self.data) + self.batch_size - 1) // self.batch_size
        dataset = []

        for batch_idx in tqdm(range(num_batches), desc="Encoding batches"):
            start = batch_idx * self.batch_size
            for offset, item in enumerate(self.data[start:start + self.batch_size]):
                dataset.append(self._encode_item(item, start + offset))

        # Recreate the scratch directory in case a previous run() already removed it.
        os.makedirs(self.temp_dir, exist_ok=True)
        dataset_pt_path = os.path.join(self.temp_dir, "dataset.pt")
        torch.save(dataset, dataset_pt_path)

        tar_path = os.path.join(self.output_encoded_path, "dataset.tar.gz")
        with tarfile.open(tar_path, "w:gz") as tar:
            tar.add(dataset_pt_path, arcname="dataset.pt")

    def save_metadata(self):
        """Write the vocabularies and size metadata as UTF-8 JSON files."""
        def dump_json(filename, obj):
            path = os.path.join(self.output_encoded_path, filename)
            with open(path, "w", encoding="utf-8") as f:
                json.dump(obj, f, ensure_ascii=False)

        dump_json("tag_vocab.json", self.tag_vocab)
        dump_json("token_vocab.json", self.token_vocab)
        dump_json("lemma_vocab.json", self.lemma_vocab)
        dump_json("length_metadata.json", {
            "max_token_sequence_len": self.max_token_sequence_len,
            "tag_vocab_size": self.tag_vocab_size,
            "token_vocab_size": len(self.token_vocab),
            "lemma_vocab_size": len(self.lemma_vocab)
        })

    def cleanup_temp_dir(self):
        """Delete the scratch directory if it exists."""
        if os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    def run(self):
        """Run the full pipeline; the scratch directory is removed even if a step fails."""
        try:
            self.load_data()
            self.build_vocabs()
            self.encode_and_save_dataset()
            self.save_metadata()
        finally:
            self.cleanup_temp_dir()


In [None]:

# Encoder settings: source file, destination folder, and records per progress step.
config = dict(
    input_data_path=input_data_path,
    output_encoded_path=output_encoded_path,
    batch_size=10000,
)
# Load, build vocabularies, encode, write metadata, then clean up.
encoder = BiLSTMAddressEncoder(config)
encoder.run()

5. Split training, validation, and evaluation sets

In [None]:
import os
import tarfile
# import torch
import random

def split_and_save_dataset(
    output_encoded_path,
    seed=42,
    split_ratio=(0.4, 0.3, 0.3),
    input_dataset_name="dataset",
    output_dataset_names=("train", "valid", "eval")
):
    """Shuffle the encoded dataset and write three split archives.

    Reads ``<input_dataset_name>.pt`` from ``<input_dataset_name>.tar.gz`` in
    *output_encoded_path*, shuffles the record indices with *seed*, and writes
    one ``<name>.tar.gz`` per entry of *output_dataset_names* into the same
    directory.

    Args:
        output_encoded_path: Directory holding the input archive; also the
            destination of the split archives.
        seed: Seed passed to the global ``random`` module before shuffling
            (the global generator is reseeded as a side effect).
        split_ratio: Fractions for the first and second split; the third
            split receives every remaining record. Must sum to 1.0.
        input_dataset_name: Base name of the input archive and of the ``.pt``
            member inside it.
        output_dataset_names: Three base names for the output archives.

    Raises:
        ValueError: If the ratios do not sum to 1.0 (within floating-point
            tolerance) or *output_dataset_names* does not have 3 entries.
    """
    import math
    import shutil

    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 is 0.9999999999999999 in binary floating point.
    if not math.isclose(sum(split_ratio), 1.0, abs_tol=1e-9):
        raise ValueError("Split ratios must sum to 1.0")
    if len(output_dataset_names) != 3:
        raise ValueError("output_dataset_names must be a tuple of 3 strings")

    # Step 1: Extract the .pt file from the archive, load it, and always remove the scratch dir
    tar_path = os.path.join(output_encoded_path, f"{input_dataset_name}.tar.gz")
    extract_dir = os.path.join(output_encoded_path, "temp_extract")
    os.makedirs(extract_dir, exist_ok=True)
    try:
        with tarfile.open(tar_path, "r:gz") as tar:
            tar.extract(f"{input_dataset_name}.pt", path=extract_dir)
        dataset = torch.load(os.path.join(extract_dir, f"{input_dataset_name}.pt"))
    finally:
        shutil.rmtree(extract_dir, ignore_errors=True)

    # Step 2: Shuffle and split indices
    total = len(dataset)
    indices = list(range(total))
    random.seed(seed)
    random.shuffle(indices)

    n_first = int(split_ratio[0] * total)
    n_second = int(split_ratio[1] * total)
    # The last split takes the remainder so that no record is lost to rounding.
    split_indices = {
        output_dataset_names[0]: indices[:n_first],
        output_dataset_names[1]: indices[n_first:n_first + n_second],
        output_dataset_names[2]: indices[n_first + n_second:]
    }

    # Step 3: Save each split
    for name, index_list in split_indices.items():
        subset = [dataset[i] for i in index_list]
        temp_pt_path = os.path.join(output_encoded_path, f"{name}.pt")
        try:
            torch.save(subset, temp_pt_path)
            with tarfile.open(os.path.join(output_encoded_path, f"{name}.tar.gz"), "w:gz") as tar:
                # Every split archive stores its member under the input dataset's
                # file name, regardless of the split's own name.
                tar.add(temp_pt_path, arcname=f"{input_dataset_name}.pt")
        finally:
            if os.path.exists(temp_pt_path):
                os.remove(temp_pt_path)


In [None]:
# Write the three split archives next to the encoded dataset, shuffling with seed 123.
split_and_save_dataset(
    output_encoded_path=output_encoded_path,
    seed=123,
    # Optional overrides (uncomment to use). Note that the trainer in this notebook
    # reads "train.tar.gz" and "valid.tar.gz", so renaming the outputs requires
    # changing the trainer as well.
    # input_dataset_name = 'dataset',
    # split_ratio=(0.5, 0.25, 0.25),
    # output_dataset_names=("train_set", "validation_set", "test_set")
)

6. Building and training model

In [None]:
import os
import json
import tarfile
import shutil
import random

import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader

class BiLSTMMultiLabelTagger(nn.Module):
    """Bidirectional LSTM that emits one logit per tag for every token position.

    Every position is described by a token embedding, a lemma embedding and a
    vector of "maybe tag" features; the three are concatenated, passed through
    a BiLSTM, dropout, and a linear layer.

    Args:
        token_vocab_size: Number of token ids excluding id 0 (the padding row).
        lemma_vocab_size: Number of lemma ids excluding id 0 (the padding row).
        tag_maybe_feature_dim: Width of the per-token maybe-tag feature vector.
        tag_output_dim: Number of output logits per token.
        embedding_dim: Width of each embedding table.
        hidden_dim: Hidden size of the LSTM in each direction.
        dropout: Dropout probability applied to the LSTM output.
    """

    def __init__(
        self,
        token_vocab_size,
        lemma_vocab_size,
        tag_maybe_feature_dim,
        tag_output_dim,
        embedding_dim=128,
        hidden_dim=256,
        dropout=0.1
    ):
        super().__init__()
        # One extra row in each table because id 0 is reserved for padding.
        self.token_embedding = nn.Embedding(
            num_embeddings=token_vocab_size + 1,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.lemma_embedding = nn.Embedding(
            num_embeddings=lemma_vocab_size + 1,
            embedding_dim=embedding_dim,
            padding_idx=0,
        )
        self.input_dim = embedding_dim * 2 + tag_maybe_feature_dim
        self.lstm = nn.LSTM(
            input_size=self.input_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(p=dropout)
        # The two LSTM directions are concatenated, hence twice the hidden size.
        self.classifier = nn.Linear(in_features=hidden_dim * 2, out_features=tag_output_dim)

    def forward(self, sequences, lemmas, tags_maybe_features):
        """Return logits whose last dimension has ``tag_output_dim`` entries.

        The three inputs are concatenated along the last dimension, so their
        leading dimensions must agree; the LSTM is built with
        ``batch_first=True``.
        """
        features = torch.cat(
            (
                self.token_embedding(sequences),
                self.lemma_embedding(lemmas),
                tags_maybe_features,
            ),
            dim=-1,
        )
        hidden_states = self.lstm(features)[0]
        return self.classifier(self.dropout(hidden_states))


class BiLSTMMultiLabelTrainer:
    """Train a ``BiLSTMMultiLabelTagger`` on the encoded train/valid archives.

    Config keys:
        output_encoded_path (required): directory holding ``train.tar.gz``,
            ``valid.tar.gz``, ``length_metadata.json`` and ``tag_vocab.json``.
        output_models_path (required): directory for the saved model (created).
        device (default "cpu"): device the model and batches are moved to.
        early_stopping_patience (default 2): epochs without sufficient
            validation improvement before stopping.
        min_loss_delta (default 0.001): minimum validation-loss improvement
            that counts as progress.
        min_train_loss (default 0.00005): stop once the training loss drops
            below this value.
        min_train_delta (default 0.001): stop once the training loss improves
            by less than this between epochs.

    All loss thresholds are compared against losses summed over the batches of
    an epoch (not averaged).
    """

    def __init__(self, config):
        self.config = config
        self.output_encoded_path = config["output_encoded_path"]
        self.train_tar = os.path.join(self.output_encoded_path, "train.tar.gz")
        self.valid_tar = os.path.join(self.output_encoded_path, "valid.tar.gz")
        self.output_models_path = config["output_models_path"]
        os.makedirs(self.output_models_path, exist_ok=True)

        self.device = config.get("device", "cpu")
        self.patience = config.get("early_stopping_patience", 2)
        self.min_delta = config.get("min_loss_delta", 0.001)
        self.min_train_loss = config.get("min_train_loss", 0.00005)
        self.min_train_delta = config.get("min_train_delta", 0.001)

    def load_metadata(self):
        """Read vocabulary sizes and the tag vocabulary written alongside the encoded data."""
        # The metadata files are written as UTF-8, so read them the same way
        # instead of relying on the platform's default encoding.
        with open(os.path.join(self.output_encoded_path, "length_metadata.json"), encoding="utf-8") as f:
            meta = json.load(f)
        with open(os.path.join(self.output_encoded_path, "tag_vocab.json"), encoding="utf-8") as f:
            tag_vocab = json.load(f)

        self.tag_output_dim = len(tag_vocab)
        # The maybe-tag feature vector has one entry per tag in the vocabulary.
        self.tag_maybe_feature_dim = self.tag_output_dim
        self.token_vocab_size = meta["token_vocab_size"]
        self.lemma_vocab_size = meta["lemma_vocab_size"]

    def _load_dataset(self, tar_path):
        """Extract ``dataset.pt`` from *tar_path*, load it, and remove the scratch directory."""
        extract_dir = os.path.join(self.output_encoded_path, "temp_data")
        os.makedirs(extract_dir, exist_ok=True)
        try:
            with tarfile.open(tar_path, "r:gz") as tar:
                tar.extract("dataset.pt", path=extract_dir)
            return torch.load(os.path.join(extract_dir, "dataset.pt"))
        finally:
            # Clean up even when extraction or loading fails.
            shutil.rmtree(extract_dir, ignore_errors=True)

    def _build_tensor_dataset(self, records):
        """Stack encoded records into a TensorDataset of (tokens, lemmas, maybe features, labels)."""
        return torch.utils.data.TensorDataset(
            torch.tensor([d["sequences"] for d in records], dtype=torch.long),
            torch.tensor([d["lemma_features"] for d in records], dtype=torch.long),
            torch.tensor([d["tags_maybe_features"] for d in records], dtype=torch.float),
            torch.stack([self._labels_to_multihot(d["labels"]) for d in records])
        )

    def _batch_loss(self, model, loss_fn, batch):
        """Move one batch to the target device and return its loss."""
        seqs, lems, feats, lbls = (t.to(self.device) for t in batch)
        return loss_fn(model(seqs, lems, feats), lbls)

    def train(self, epochs=1000, batch_size=32, lr=1e-3, seed=42):
        """Train with Adam and BCE-with-logits until an early-stopping rule fires.

        After every epoch, in this order:
          1. training loss below ``min_train_loss`` -> save the model and stop;
          2. validation loss improved by at least ``min_loss_delta`` -> save the
             model and reset patience, otherwise increase the patience counter
             and stop once it reaches ``early_stopping_patience``;
          3. training loss improved by less than ``min_train_delta`` -> stop.

        Args:
            epochs: Maximum number of epochs.
            batch_size: Batch size for both loaders.
            lr: Adam learning rate.
            seed: Seed for ``torch`` and ``random``.
        """
        self.load_metadata()

        torch.manual_seed(seed)
        random.seed(seed)
        torch.use_deterministic_algorithms(True)

        train_set = self._build_tensor_dataset(self._load_dataset(self.train_tar))
        valid_set = self._build_tensor_dataset(self._load_dataset(self.valid_tar))

        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        valid_loader = DataLoader(valid_set, batch_size=batch_size)

        model = BiLSTMMultiLabelTagger(
            token_vocab_size=self.token_vocab_size,
            lemma_vocab_size=self.lemma_vocab_size,
            tag_maybe_feature_dim=self.tag_maybe_feature_dim,
            tag_output_dim=self.tag_output_dim
        ).to(self.device)

        loss_fn = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        best_val_loss = float("inf")
        prev_train_loss = float("inf")
        patience_counter = 0

        # Use the thresholds resolved once in __init__ rather than re-reading the config.
        min_train_loss = self.min_train_loss
        min_delta = self.min_delta
        min_train_delta = self.min_train_delta

        for epoch in range(epochs):
            model.train()
            total_train_loss = 0

            for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]"):
                loss = self._batch_loss(model, loss_fn, batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_train_loss += loss.item()

            model.eval()
            total_val_loss = 0
            with torch.no_grad():
                for batch in tqdm(valid_loader, desc=f"Epoch {epoch+1} [Valid]"):
                    total_val_loss += self._batch_loss(model, loss_fn, batch).item()

            # Both are +inf on the first epoch, so the first epoch always counts as an improvement.
            train_improvement = prev_train_loss - total_train_loss
            val_improvement = best_val_loss - total_val_loss

            print(f"Epoch {epoch+1} | Train Loss: {total_train_loss:.6f} | ΔTrain: {train_improvement:.6f} | Val Loss: {total_val_loss:.6f} | ΔVal: {val_improvement:.6f}")

            # Condition 1: train loss below minimum
            if total_train_loss < min_train_loss:
                print(f"Early stopping: total_train_loss < min_train_loss ({min_train_loss})")
                self.save_model(model)
                break

            # Condition 2: validation loss improvement
            if val_improvement >= min_delta:
                best_val_loss = total_val_loss
                patience_counter = 0
                self.save_model(model)
            else:
                patience_counter += 1
                print(f"Patience: {patience_counter}/{self.patience} (ΔVal < {min_delta})")
                if patience_counter >= self.patience:
                    print("Early stopping: validation loss not improving")
                    break

            # Condition 3: training loss improvement too small
            if train_improvement < min_train_delta:
                print(f"Early stopping: training improvement ΔTrain < min_train_delta ({min_train_delta})")
                break

            prev_train_loss = total_train_loss

    def _labels_to_multihot(self, label_indices):
        """Convert per-token label ids into a (tokens, tag_output_dim) float matrix.

        A row has a single 1.0 at the label's column; rows whose label is -100
        (no label / padding) stay all zero.
        """
        labels = torch.as_tensor(label_indices, dtype=torch.long)
        vec = torch.zeros((labels.shape[0], self.tag_output_dim), dtype=torch.float)
        rows = torch.nonzero(labels != -100, as_tuple=True)[0]
        vec[rows, labels[rows]] = 1.0
        return vec

    def save_model(self, model):
        """Write the model's state_dict as ``model_after_train.pt`` inside ``model_after_train.tar.gz``."""
        model_path = os.path.join(self.output_models_path, "model_after_train.pt")
        tar_path = os.path.join(self.output_models_path, "model_after_train.tar.gz")
        try:
            torch.save(model.state_dict(), model_path)
            with tarfile.open(tar_path, "w:gz") as tar:
                tar.add(model_path, arcname="model_after_train.pt")
        finally:
            # Only the archive is kept; drop the loose .pt file even on failure.
            if os.path.exists(model_path):
                os.remove(model_path)


In [None]:
# Trainer settings: where the encoded splits live, where to store the model, and the device.
config = dict(
    output_encoded_path=output_encoded_path,
    output_models_path=output_models_path,
    device=device,
)

# Train for at most 1000 epochs; the trainer's early-stopping rules may end it sooner.
trainer = BiLSTMMultiLabelTrainer(config)
trainer.train(epochs=1000, batch_size=64, lr=1e-3)


7. Inference

In [None]:
import os
import json
import spacy
import csv
from torch.nn.functional import sigmoid
from spacy.tokenizer import Tokenizer
from spacy.util import compile_infix_regex
from rapidfuzz import process, fuzz
from collections import defaultdict, Counter


class BiLSTMAddressPredictor:
    """Tag the tokens of raw address strings with a trained BiLSTM tagger.

    `predict` runs: tokenise -> attach candidate ("maybe") tags from the tag
    pool -> encode to ids/features -> run the model -> decode per-token tags ->
    write JSON and CSV results.

    Config keys:
        model (required): the model object to call; it is moved to ``device``
            and put in eval mode.
        tokenizer_path (required key): stored but not used by this class.
        tag_vocab_path (required): path to ``tag_vocab.json``; the token and
            lemma vocabularies are read from the same directory.
        sequence_length_path (required): JSON file with ``max_token_sequence_len``.
        tag_pool_path (required): JSON list of ``{"text", "tag"}`` entries.
        result_path (required): output directory for the result files.
        device (default "cpu"), use_fuzzy_tags (default True),
        threshold (default 0.5): minimum probability for a predicted tag to be
            kept in the per-address result.
    """

    def __init__(self, config):
        self.model = config["model"]
        self.tokenizer_path = config["tokenizer_path"]
        self.tag_vocab_path = config["tag_vocab_path"]
        self.sequence_length_path = config["sequence_length_path"]
        self.tag_pool_path = config["tag_pool_path"]
        self.device = config.get("device", "cpu")
        self.result_path = config["result_path"]
        self.use_fuzzy_tags = config.get("use_fuzzy_tags", True)
        self.threshold = config.get("threshold", 0.5)

        self._init_spacy()
        self._load_metadata()

    def _init_spacy(self):
        """Load ``en_core_web_sm`` and replace its tokenizer with custom infix rules.

        The default infix patterns mentioning "-" or "'" are dropped and
        replaced by rules that leave a single hyphen/apostrophe between two
        letters unsplit; a fixed set of punctuation characters is always an
        infix. No prefix, suffix or exception rules are passed to the Tokenizer.
        """
        self.nlp = spacy.load("en_core_web_sm")

        def custom_tokenizer(nlp):
            infixes = [x for x in nlp.Defaults.infixes if "-" not in x and "'" not in x]
            infixes += [
                r"(?<![a-zA-Z])-+|-(?![a-zA-Z])|(?<=[a-zA-Z])-{2,}(?=[a-zA-Z])",
                r"(?<![a-zA-Z])'+|'(?![a-zA-Z])|(?<=[a-zA-Z])'{2,}(?=[a-zA-Z])",
                r"[,\.\?\(\)\[\]\+_#/&]"
            ]
            return Tokenizer(nlp.vocab, infix_finditer=compile_infix_regex(infixes).finditer, token_match=None)

        self.nlp.tokenizer = custom_tokenizer(self.nlp)

    @staticmethod
    def _read_json(path):
        """Load a JSON file as UTF-8."""
        # The encoder writes its JSON with encoding="utf-8" and ensure_ascii=False,
        # so the platform default encoding must not be used for reading.
        # NOTE(review): the tag pool file is assumed to be UTF-8 as well — confirm.
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def _load_metadata(self):
        """Load the tag/token/lemma vocabularies, the padded length and the tag pool."""
        base_path = os.path.dirname(self.tag_vocab_path)

        self.tag_vocab = self._read_json(self.tag_vocab_path)
        self.max_token_sequence_len = self._read_json(self.sequence_length_path)["max_token_sequence_len"]

        # tag_pool maps lower-cased text -> Counter of tags seen for it. Tags that start
        # with "o_" or contain no "_" are skipped; everything up to the first "_" is removed.
        self.tag_pool = defaultdict(Counter)
        for entry in self._read_json(self.tag_pool_path):
            tag = entry["tag"]
            if not tag.startswith("o_") and "_" in tag:
                clean_tag = tag.split("_", 1)[-1]
                self.tag_pool[entry["text"].lower()][clean_tag] += 1

        self.token_vocab = self._read_json(os.path.join(base_path, "token_vocab.json"))
        self.lemma_vocab = self._read_json(os.path.join(base_path, "lemma_vocab.json"))

        self.tag_to_id = {tag: i for i, tag in enumerate(self.tag_vocab)}
        self.id_to_tag = {i: tag for tag, i in self.tag_to_id.items()}
        self.tag_vocab_size = len(self.tag_vocab)

    def predict(self, address_list):
        """Run the whole pipeline on *address_list* and write the result files.

        Results are left in ``self.results`` and ``self.token_level_outputs``.
        """
        self.address_list = address_list
        self._tokenize()
        self._add_tags_maybe()
        self._encode()
        self._predict_logits()
        self._decode()
        self.save_results()

    def _tokenize(self):
        """Tokenise every address, keeping each token's lower-cased lemma and character span."""
        self.tokenized = []
        for address in self.address_list:
            doc = self.nlp(address)
            self.tokenized.append({
                "text": address,
                "tokens": [
                    {
                        "_lemma": tok.lemma_.lower(),
                        "start": tok.idx,
                        "end": tok.idx + len(tok),
                    }
                    for tok in doc
                ]
            })

    @staticmethod
    def _is_numeric_like(text):
        """Return True for all-digit strings and for strings mixing digits and letters."""
        return text.isdigit() or (any(c.isdigit() for c in text) and any(c.isalpha() for c in text))

    def _add_tags_maybe(self):
        """Attach a ``tag_probs`` dict (candidate tag -> weight) to every token."""
        self.tagged = []
        for addr in self.tokenized:
            for tok in addr["tokens"]:
                lemma = tok["_lemma"]
                # .get() avoids creating empty entries in the defaultdict.
                tag_counts = self.tag_pool.get(lemma, {})
                total = sum(tag_counts.values())
                tag_probs = {tag: count / total for tag, count in tag_counts.items()} if total > 0 else {}

                # Default tags for numeric/alphanumeric lemmas
                if self._is_numeric_like(lemma):
                    for tag in ["civic_number", "unit_number", "street_name"]:
                        tag_probs[tag] = max(tag_probs.get(tag, 0.0), 1.0)

                # Fuzzy match if empty
                if not tag_probs and self.use_fuzzy_tags:
                    for t in self._fuzzy_match_tags(lemma):
                        tag_probs[t] = 1.0

                tok["tag_probs"] = tag_probs if tag_probs else {"none": 1.0}
            self.tagged.append(addr)

    def _fuzzy_match_tags(self, token, threshold=90):
        """Return the tags of the closest tag-pool key when its ``fuzz.ratio`` score reaches *threshold*."""
        if not self.tag_pool:
            return []
        best = process.extractOne(token, self.tag_pool.keys(), scorer=fuzz.ratio)
        if best is None:
            return []
        match, score, _ = best
        if score >= threshold:
            return list(self.tag_pool[match].keys())
        return []

    @staticmethod
    def _normalize_for_vocab(text, vocab):
        """Map *text* to ``__num__``, ``__alnum__``, itself (if in *vocab*) or ``__unk__``."""
        if text.isdigit():
            return "__num__"
        if any(c.isdigit() for c in text) and any(c.isalpha() for c in text):
            return "__alnum__"
        return text if text in vocab else "__unk__"

    def _encode(self):
        """Convert tagged addresses into id sequences and maybe-tag feature vectors.

        Sequences are padded with 0 up to ``max_token_sequence_len``; longer
        addresses are left unpadded. Unknown entries also map to id 0.
        """
        self.encoded = []
        for item in self.tagged:
            seq, lem, maybes = [], [], []
            for tok in item["tokens"]:
                word = item['text'][tok["start"]:tok["end"]].lower()
                lemma = tok["_lemma"]

                seq.append(self.token_vocab.get(self._normalize_for_vocab(word, self.token_vocab), 0))
                lem.append(self.lemma_vocab.get(self._normalize_for_vocab(lemma, self.lemma_vocab), 0))

                # Candidate tags outside the tag vocabulary (e.g. "none") are ignored.
                maybe_vec = [0.0] * self.tag_vocab_size
                for tag, prob in tok.get("tag_probs", {}).items():
                    if tag in self.tag_to_id:
                        maybe_vec[self.tag_to_id[tag]] = float(prob)
                maybes.append(maybe_vec)

            pad_len = self.max_token_sequence_len - len(seq)
            self.encoded.append({
                "sequences": seq + [0] * pad_len,
                "lemmas": lem + [0] * pad_len,
                # Independent rows, so the padding vectors do not alias one list object.
                "tags_maybe_features": maybes + [[0.0] * self.tag_vocab_size for _ in range(pad_len)],
                "valid_len": len(seq)
            })

    def _predict_logits(self):
        """Run the model on each encoded address, one address per forward pass.

        Despite its name, ``self.raw_logits`` holds the sigmoid of the logits,
        one 2-D tensor (positions x tags) per address.
        """
        self.raw_logits = []
        self.model.to(self.device)
        self.model.eval()

        with torch.no_grad():
            for item in self.encoded:
                seqs = torch.tensor([item["sequences"]], dtype=torch.long).to(self.device)
                lems = torch.tensor([item["lemmas"]], dtype=torch.long).to(self.device)
                maybes = torch.tensor([item["tags_maybe_features"]], dtype=torch.float).to(self.device)

                logits = self.model(seqs, lems, maybes)
                self.raw_logits.append(torch.sigmoid(logits).squeeze(0))

    def _decode(self):
        """Pick the highest-probability tag per token and group token texts by tag.

        ``self.results`` gets one dict per address (tag -> list of token texts,
        plus ``"text"``), containing only tokens whose best probability reaches
        ``self.threshold``. ``self.token_level_outputs`` keeps the details of
        every token regardless of the threshold.
        """
        self.results = []
        self.token_level_outputs = []  # per-token prediction details

        for item, probs in zip(self.tagged, self.raw_logits):
            tag_map = {}
            token_outputs = []

            # zip stops at the last real token, so padding positions are never decoded.
            for tok, token_probs in zip(item["tokens"], probs):
                token_text = item['text'][tok["start"]:tok["end"]].lower()
                max_prob, tag_idx = torch.max(token_probs, dim=0)
                tag = self.id_to_tag.get(tag_idx.item())

                token_outputs.append({
                    "text": token_text,
                    "start": tok["start"],
                    "end": tok["end"],
                    "predicted_tag": tag,
                    "probability": round(max_prob.item(), 4),
                    "all_probs": {
                        self.id_to_tag[k]: round(v.item(), 4)
                        for k, v in enumerate(token_probs)
                    }
                })

                if max_prob.item() >= self.threshold and tag:
                    tag_map.setdefault(tag, []).append(token_text)

            tag_map["text"] = item["text"]
            self.results.append(tag_map)
            self.token_level_outputs.append(token_outputs)

    def save_results(self):
        """Write ``predicted_components.json`` and ``predicted_components.csv`` to ``result_path``.

        The CSV has one row per address with a fixed set of tag columns; the
        token texts of a tag are joined with spaces.
        """
        os.makedirs(self.result_path, exist_ok=True)

        json_path = os.path.join(self.result_path, "predicted_components.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.results, f, ensure_ascii=False, indent=2)

        fixed_tags = [
            "unit_designator", "unit_number", "unit_number_suffix",
            "civic_number", "civic_number_suffix",
            "street_name", "street_type", "street_direction", "street_qualifier",
            "locality_name", "locality_type",
            "province_code"
        ]

        csv_path = os.path.join(self.result_path, "predicted_components.csv")
        with open(csv_path, "w", encoding="utf-8", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["text"] + fixed_tags)
            for res in self.results:
                writer.writerow([res["text"]] + [" ".join(res.get(tag, [])) for tag in fixed_tags])



In [None]:
import os
import json
import tarfile
import torch
import shutil

def load_model_config_from_dataset_pt(output_encoded_path):
    """Return the constructor arguments needed to rebuild the trained tagger.

    The vocabulary sizes are read from ``length_metadata.json`` when it
    provides them; this is the same source the trainer uses, so the rebuilt
    model matches the saved weights without unpacking the dataset. Only when
    that file or its size fields are missing are the sizes derived by scanning
    ``dataset.pt`` inside ``dataset.tar.gz`` for the largest token and lemma id.

    Args:
        output_encoded_path: Directory with ``tag_vocab.json`` and either
            ``length_metadata.json`` or ``dataset.tar.gz``.

    Returns:
        dict with ``token_vocab_size``, ``lemma_vocab_size``,
        ``tag_maybe_feature_dim`` and ``tag_output_dim`` (the last two are both
        the number of tags in ``tag_vocab.json``).
    """
    with open(os.path.join(output_encoded_path, "tag_vocab.json"), encoding="utf-8") as f:
        tag_vocab = json.load(f)
    tag_maybe_feature_dim = len(tag_vocab)

    meta = {}
    meta_path = os.path.join(output_encoded_path, "length_metadata.json")
    if os.path.exists(meta_path):
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)

    if "token_vocab_size" in meta and "lemma_vocab_size" in meta:
        token_vocab_size = meta["token_vocab_size"]
        lemma_vocab_size = meta["lemma_vocab_size"]
    else:
        # Fallback: recover the sizes from the encoded records themselves.
        tar_path = os.path.join(output_encoded_path, "dataset.tar.gz")
        extract_dir = os.path.join(output_encoded_path, "temp_model_extract")
        os.makedirs(extract_dir, exist_ok=True)
        try:
            with tarfile.open(tar_path, "r:gz") as tar:
                tar.extract("dataset.pt", path=extract_dir)
            dataset = torch.load(os.path.join(extract_dir, "dataset.pt"))  # list of dicts
        finally:
            shutil.rmtree(extract_dir, ignore_errors=True)

        token_vocab_size = max(max(d["sequences"]) for d in dataset)
        lemma_vocab_size = max(max(d["lemma_features"]) for d in dataset)

    return {
        "token_vocab_size": token_vocab_size,
        "lemma_vocab_size": lemma_vocab_size,
        "tag_maybe_feature_dim": tag_maybe_feature_dim,
        "tag_output_dim": tag_maybe_feature_dim  # same
    }

def load_trained_model(output_models_path, model_config, device="cpu"):
    """Rebuild the tagger and load the weights stored in ``model_after_train.tar.gz``.

    Args:
        output_models_path: Directory containing ``model_after_train.tar.gz``.
        model_config: Keyword arguments for ``BiLSTMMultiLabelTagger``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model in eval mode on *device*.
    """
    tar_path = os.path.join(output_models_path, "model_after_train.tar.gz")
    extract_dir = os.path.join(output_models_path, "temp_model_load")
    os.makedirs(extract_dir, exist_ok=True)

    try:
        with tarfile.open(tar_path, "r:gz") as tar:
            tar.extract("model_after_train.pt", path=extract_dir)

        model = BiLSTMMultiLabelTagger(**model_config)
        model_path = os.path.join(extract_dir, "model_after_train.pt")
        model.load_state_dict(torch.load(model_path, map_location=device))
    finally:
        # Remove the extracted weights even if building or loading the model fails.
        shutil.rmtree(extract_dir, ignore_errors=True)

    model.to(device)
    model.eval()
    return model



# Folders for this run's encoded data, trained model, and inference results.
output_encoded_path, output_models_path, result_path = (
    os.path.join(output_path, subdir) for subdir in ("encoded", "models", "results")
)

# Derive the model's constructor arguments from the encoded-data folder, then
# restore the trained weights on the selected device.
model_config = load_model_config_from_dataset_pt(output_encoded_path)
model = load_trained_model(output_models_path, model_config, device=device)


In [None]:
# Addresses to parse.
addresses = [
    "12040 Horseshoe Way, Richmond, BC",
    "108 East Hastings Street, Vancouver, BC",
]

# Predictor settings: the model, the metadata files, the tag pool, and output options.
inference_config = dict(
    model=model,
    tokenizer_path=None,
    tag_vocab_path=os.path.join(output_encoded_path, "tag_vocab.json"),
    sequence_length_path=os.path.join(output_encoded_path, "length_metadata.json"),
    tag_pool_path=tag_pool_file_path,
    device=device,
    result_path=result_path,
    use_fuzzy_tags=True,
    threshold=0.5,  # predictions whose best probability is below this are left out of the results
)

# Tokenise, tag, and write the predicted components to `result_path`.
predictor = BiLSTMAddressPredictor(inference_config)
predictor.predict(addresses)


