In [1]:
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [49]:
import wandb

LEARNING_RATE = 0.001
BATCH_SIZE = 256
HIDDEN_SIZE = 16
NUM_EPOCHS = 10
CONTEXT_SIZE = 24

EXTRA_MODEL_PARAMS = dict(hidden_size=HIDDEN_SIZE, num_heads=4, embedding_size=32, num_outputs=2)

experiment_config = {
    "learning_rate": LEARNING_RATE,
    "architecture": "transformer",
    "dataset": "synthetic-wiki-one-meelyun-sentences",
    "epochs": NUM_EPOCHS,
    "batch_size": BATCH_SIZE,
    "hidden_size": HIDDEN_SIZE,
    "context_size": CONTEXT_SIZE,
    "extra_model_params": EXTRA_MODEL_PARAMS
}

wandb.init(
    project="tiny_sentence_tokenizer",  # https://wandb.ai/josephc/tiny_sentence_tokenizer
    config=experiment_config
)

# Datasets:

In [3]:
import bisect
import bz2
import os
import random
from typing import List, Optional

from torch.utils.data import Dataset

class SentenceSplitDataset:
    def __init__(self, path_to_sentences: os.PathLike, context_size: int, sentence_breaks: List[str] = [" ", "  ", "\n", "\r\n", "\t", "", "\n\n\n\n", "    "]):
        self.context_size = context_size
        self.characters_read = 0
        self.sentence_offsets = list()
        self.sentences = list()
        self.sentence_breaks = sentence_breaks
        with bz2.open(path_to_sentences, 'rt') as fin:
            for line in fin:
                self.sentence_offsets.append(self.characters_read)
                line = line.strip()
                self.sentences.append(line)
                self.characters_read += len(line)+1
    def __len__(self):
        #return len(self.sentences)
        return self.characters_read
    def get_sentence(self, idx: int) -> str:
        return self.sentences[idx]
    def get_sentence_idx_with_character(self, idx: int) -> int:
        return bisect.bisect_right(self.sentence_offsets, idx)-1
    def get_sentence_with_character(self, idx: int, context: Optional[int] = None) -> str:
        # Map the index to the sentence.
        sentence_idx = self.get_sentence_idx_with_character(idx)
        sentence = self.sentences[sentence_idx]
        position_in_sentence = idx - self.sentence_offsets[sentence_idx]
        assert position_in_sentence >= 0
        if context is None:
            return sentence
        return sentence[max(0, position_in_sentence-context):position_in_sentence], sentence[position_in_sentence:position_in_sentence+context]
    def __getitem__(self, idx: int):
        prefix, suffix = self.get_sentence_with_character(idx, context=self.context_size)
        end_of_sentence = len(suffix) == 0
        suffix += random.choice(self.sentence_breaks) + self.get_sentence(self.get_sentence_idx_with_character(idx+1))
        prefix = prefix.rjust(self.context_size)[-self.context_size:]
        suffix = suffix.ljust(self.context_size)[:self.context_size]
        #suffix = suffix.ljust(self.context_size)
        return prefix, suffix, end_of_sentence

In [14]:
ds = SentenceSplitDataset(path_to_sentences="./one_meelyun_sentences.bz2", context_size=CONTEXT_SIZE)

In [15]:
print(ds[0])
print(ds[5001])
print(ds[5036])
print(ds[5037])
print(ds[5038])
for i in range(0, 2000):
    a, b, c = ds[i]
    if c:
        print(f"{i}: {ds[i]}")

('                        ', 'County and municipal cou', False)
('            The Exchange', ' Building opened in 1854', False)
('pened in 1854, part of t', 'he building was later us', False)
('ened in 1854, part of th', 'e building was later use', False)
('ned in 1854, part of the', ' building was later used', False)
69: ('lected every four years.', '  The by-census indicate', True)
215: (' by more than 1 million.', " 'On the Marble Cliffs' ", True)
378: ("ion in Hitler's Germany.", '\tHomer brief description', True)
423: ("ion in the 'Iliad'Homer.", '\n\n\n\nPublic hearings were', True)
545: ('co, and Washington, D.C.', '\tOn July 10, both forces', True)
596: ('ced each other in Kyoto.', "    Monmouth's status as", True)
733: (' November 2002 election.', '\n\n\n\nOpiates are hypothes', True)
796: ('ate aggression and rage.', ' The town celebrated its', True)
840: (' its centennial in 2004.', '  In 1681 Anthony Ashley', True)
964: (' or recourse to a trial.', " Crater Lake's features 

In [160]:
import bisect
import bz2
import os
import random
from typing import List, Optional

from torch.utils.data import Dataset

class BalancedEOSDataset:
    def __init__(self, path_to_sentences: os.PathLike, context_size: int, sentence_breaks: List[str] = [" ", "  ", "\n", "\r\n", "\t", "", "\n\n\n\n", "    ",]):
        self.context_size = context_size
        self.sentences = list()
        self.sentence_breaks = sentence_breaks
        with bz2.open(path_to_sentences, 'rt') as fin:
            for line in fin:
                line = line.strip()
                self.sentences.append(line)
    def __len__(self):
        return (len(self.sentences)-2)*2  # Double since evens will be 'not end of sentence'.
    def __getitem__(self, idx: int):
        end_of_sentence = (idx%2 != 0)
        if not end_of_sentence:
            sentence = self.sentences[idx//2]
            split_point = random.randint(1, len(sentence)-1)
            prefix = sentence[:split_point]
            suffix = sentence[split_point:]
            prefix = prefix.rjust(self.context_size)[-self.context_size:]
            suffix = suffix.ljust(self.context_size)[:self.context_size]
            return prefix, suffix, False
        else:
            prefix = (self.sentences[idx//2].rjust(self.context_size+1))[-self.context_size:]
            suffix = (random.choice(self.sentence_breaks) + self.sentences[(idx//2)+1].ljust(self.context_size))[:self.context_size]
            return prefix, suffix, True

In [161]:
ds = BalancedEOSDataset(path_to_sentences="./one_meelyun_sentences.bz2", context_size=CONTEXT_SIZE)

In [162]:
print(ds[0])
print(ds[5001])
print(ds[5036])
print(ds[5037])
print(ds[5038])

('ty and municipal council', 's are popularly elected ', False)
('e President Dick Cheney.', '\n\n\n\nSabine Baring-Gould ', True)
('ere were 36 airports and', ' one heliport.          ', False)
('rports and one heliport.', '\n\n\n\nKennedy later said t', True)
('Kennedy later said that ', 'his four day-visit to Ir', False)


# Models:

In [None]:
from typing import Iterator, List, Union
import torch.nn as nn
import torch.nn.functional as F
from unidecode import unidecode


class RNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(256, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    @staticmethod
    def strings_to_tensor(sentences: list) -> torch.Tensor:
        longest = max([len(sentence.encode("utf-8")) for sentence in sentences])
        out = torch.zeros(len(sentences), longest, 256)
        for batch_idx, s in enumerate(sentences):
            s = s.encode("utf-8")
            for byte_idx in range(len(s)):
                out[batch_idx, byte_idx, int(s[byte_idx])] = 1.0
        return out

    def forward(self, x, hidden = None):
        hidden = F.tanh(self.i2h(x) + self.h2h(hidden))
        output = self.h2o(hidden)
        output = self.softmax(output)
        return output, hidden

    def _init_hidden(self, height: int = 1):
        return torch.zeros(height, self.hidden_size)

    def split_paragraph_iter(self, p: str, min_threshold: Optional[float] = None) -> Iterator[str]:
        self.eval()
        h = self._init_hidden()
        i = torch.zeros(1, 256)
        last_sentence = ""
        for character in p:
            last_sentence += character
            # Convert character to a byte.
            b = character.encode("utf-8")
            for b_value in b:
                i[0, int(b_value)] = 1.0
                out, h = self.forward(i, h)
                i[0, int(b_value)] = 0.0
            out = out.cpu().numpy()
            if out[0,1] >= out[0,0] or (min_threshold is not None and out[0,1] > min_threshold):
                yield last_sentence
                last_sentence = ""
        yield last_sentence


def run_inference(m, prefix: List[str], suffix: Optional[List[str]]):
    prefix = RNN.strings_to_tensor(prefix).to(DEVICE)
    if suffix is not None:
        suffix = RNN.strings_to_tensor(suffix).to(DEVICE)
    hidden = m._init_hidden().to(DEVICE)
    for i in range(0, prefix.shape[1]):
        out, hidden = m(prefix[:,i,:], hidden)
    if suffix is not None:
        for i in range(0, suffix.shape[1]):
            out, hidden = m(suffix[:,i,:], hidden)
    return out


n_hidden = HIDDEN_SIZE
n_categories = 2
model = RNN(n_hidden, n_categories)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.NLLLoss()

In [153]:
from typing import Iterator, List, Union
import torch.nn as nn
import torch.nn.functional as F
from unidecode import unidecode

class TFSentenceSplit(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, embedding_size: int, num_outputs: int):
        super().__init__()
        self.position_embedding = nn.Embedding(num_embeddings=256, embedding_dim=embedding_size)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_size, nhead=num_heads, batch_first=True)
        self.inference_head = nn.Linear(embedding_size, num_outputs)

    def forward(self, x):
        # Assumes x is batch first like a normal, sensible person.
        # out = self.encode_layer(torch.rand(batch_size, seq_length, embedding_size))
        # Use torch.LongTensor to encode.
        x = self.position_embedding(x)
        x = self.encoder_layer(x) # Out: [batch_size, seq_len, embedding_size]
        x = self.inference_head(x)
        x = F.softmax(x, dim=-1)
        return x[:,-1,:].squeeze(1)

    @staticmethod
    def strings_to_tensor(sentences: list) -> torch.Tensor:
        longest = max([len(sentence.encode("utf-8")) for sentence in sentences])
        out = torch.zeros((len(sentences), longest), dtype=torch.int64)  # torch.LongTensor
        for batch_idx, s in enumerate(sentences):
            s = s.rjust(longest).encode("utf-8")[-longest:]  # Pad the left with spaces so it's aligned, then convert to bytes and truncate.
            for byte_idx in range(len(s)):
                out[batch_idx, byte_idx] = int(s[byte_idx])
        return out

def run_inference(m, prefix: List[str], suffix: List[str]):
    prefix = TFSentenceSplit.strings_to_tensor(prefix).to(DEVICE)
    #suffix = TFSentenceSplit.strings_to_tensor(suffix).to(DEVICE)
    return m.forward(prefix)

model = TFSentenceSplit(**EXTRA_MODEL_PARAMS)
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

# Evaluation and Training Prep:

In [165]:
from torch.utils.data import DataLoader
total_examples = len(ds)  # Hack -- dunno the total example count.
train_size = int(total_examples*0.7)
validation_size = int(total_examples*0.1)
test_size = total_examples - (train_size + validation_size)
train_ds, validate_ds, test_ds = torch.utils.data.random_split(ds, [train_size, validation_size, test_size])

In [166]:
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validate_dataloader = DataLoader(validate_ds, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE)

In [167]:
def compute_tpr_fpr_tnr_fnr(predictions_logits, ground_truth_labels):
    # Assume predictions are a tensor of shape [batch, 2], same with ground truth.
    # Assume predictions are NORMALIZED along axis=1 ([_,0], [_,1]).
    # Assume [:,1] means 'yes, this is a break'.
    with torch.no_grad():
        pred = predictions_logits.cpu().numpy()
        gt = ground_truth_labels.cpu().numpy()
        
        tpr = 0
        fpr = 0
        tnr = 0
        fnr = 0
        for idx in range(0, pred.shape[0]):
            if gt[idx] < 0.5 or gt[idx] == False: # GT: Negative
                if pred[idx,0] > pred[idx,1]: # Pred: Negative
                    tnr += 1
                else: # Pred: Positive
                    fpr += 1
            else: # GT: Positive
                if pred[idx,0] > pred[idx,1]: # Pred: Negative:
                    fnr += 1
                else: # Pred: Positive
                    tpr += 1
        return tpr, fpr, tnr, fnr
                

In [168]:
assert compute_tpr_fpr_tnr_fnr(torch.tensor([[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0]]), torch.tensor([0, 1, 0, 1])) == (1, 1, 1, 1)
print(compute_tpr_fpr_tnr_fnr(torch.tensor([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6], [0.6, 0.4]]), torch.tensor([0, 0, 0, 0, 0])))
print(compute_tpr_fpr_tnr_fnr(torch.tensor([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7], [0.4, 0.6], [0.6, 0.4]]), torch.tensor([1, 1, 1, 1, 1])))
print(run_inference(model, ["Only a."], []))
print(compute_tpr_fpr_tnr_fnr(run_inference(model, ["Only a test", "Mostly a test", "Ignore me"], []), torch.tensor([0, 1, 0])))

(0, 4, 1, 0)
(4, 0, 0, 1)
tensor([[7.6257e-06, 9.9999e-01]], device='cuda:0', grad_fn=<SqueezeBackward1>)
(0, 0, 2, 1)


In [169]:
print(TFSentenceSplit.strings_to_tensor(["Only a test.", "Mostly a", "Ignore"]))
print(run_inference(model, ["Only a test."], []))
print(compute_tpr_fpr_tnr_fnr(run_inference(model, ["Only a test.", "Mostly a", "Ignore"], []), torch.tensor([1, 0, 0])))

tensor([[ 79, 110, 108, 121,  32,  97,  32, 116, 101, 115, 116,  46],
        [ 32,  32,  32,  32,  77, 111, 115, 116, 108, 121,  32,  97],
        [ 32,  32,  32,  32,  32,  32,  73, 103, 110, 111, 114, 101]])
tensor([[7.0584e-06, 9.9999e-01]], device='cuda:0', grad_fn=<SqueezeBackward1>)
(1, 0, 2, 0)


# Training:

In [170]:
#from tqdm import tqdm
from tqdm.notebook import trange, tqdm

lowest_loss = 1e10
best_counts = 0

for epoch in trange(NUM_EPOCHS):
    model.train()
    examples_seen = 0
    positives_seen = 0
    total_train_loss = 0.0
    running_train_loss = 0.0
    validation_loss = 0.0
    tpr, tnr, fpr, fnr = 0, 0, 0, 0
    for batch_idx, (pre, suf, label) in tqdm(enumerate(train_dataloader)):
        label = (torch.Tensor(label) * 1).to(DEVICE)
        out = run_inference(model, pre, suf)
        loss = loss_fn(out, label)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        examples_seen += BATCH_SIZE
        positives_seen += torch.sum(label).item()
        per_example_loss = loss.item()/BATCH_SIZE
        total_train_loss += per_example_loss
        running_train_loss = running_train_loss * 0.9 + per_example_loss * 0.1
        batch_tp, batch_fp, batch_tn, batch_fn = compute_tpr_fpr_tnr_fnr(out, label)
        tpr += batch_tp
        tnr += batch_tn
        fpr += batch_fp
        fnr += batch_fn
        if batch_idx % 100 == 0:
            rate = tpr+tnr+fpr+fnr
            wandb.log({
                "batch_loss": per_example_loss, 
                "false_positive_rate": fpr/rate, 
                "true_positive_rate": tpr/rate,
                "false_negative_rate": fnr/rate,
                "true_negative_rate": tnr/rate,
                "positives_seen": positives_seen,
            }, commit=(batch_idx%1000)==0)
            tpr, tnr, fpr, fnr = 0, 0, 0, 0
        if (batch_idx+1) % 2500 == 0:
            print(f"{epoch}: {batch_idx}: {running_train_loss}")
    print(f"END OF EPOCH {epoch}: {total_train_loss} train loss")
    model.eval()
    for batch_idx, (pre, suf, label) in enumerate(validate_dataloader):
        label = (torch.Tensor(label) * 1).to(DEVICE)
        out = run_inference(model, pre, suf)
        loss = loss_fn(out, label)
        validation_loss += loss.item()/BATCH_SIZE
    wandb.log({"validation_loss": validation_loss})
    print(f"END OF EPOCH {epoch}: {validation_loss} validation loss")
    torch.save(model, f"checkpoint_epoch_{epoch}.pt")
    if validation_loss < lowest_loss:
        lowest_loss = validation_loss
        torch.save(model, f"best_{best_counts}.pt")
        best_counts += 1
wandb.finish()

  0%|          | 0/10 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0: 2499: 0.0012829395431837563
0: 4999: 0.0012842619629034535
END OF EPOCH 0: 7.023348012357019 train loss
END OF EPOCH 0: 1.003268651664257 validation loss


0it [00:00, ?it/s]

1: 2499: 0.0012785414394387651
1: 4999: 0.0012735982540847894
END OF EPOCH 1: 6.985950689762831 train loss
END OF EPOCH 1: 0.9949951990274712 validation loss


0it [00:00, ?it/s]

2: 2499: 0.0012766401339488929
2: 4999: 0.0012784149233055961
END OF EPOCH 2: 6.948916549794376 train loss
END OF EPOCH 2: 0.9929292330052704 validation loss


0it [00:00, ?it/s]

3: 2499: 0.00125973992400006
3: 4999: 0.0012639339528977417
END OF EPOCH 3: 6.940024675917812 train loss
END OF EPOCH 3: 0.9912436854792759 validation loss


0it [00:00, ?it/s]

4: 2499: 0.0012602235510088462
4: 4999: 0.0012700707688301776
END OF EPOCH 4: 6.931345159187913 train loss
END OF EPOCH 4: 0.9907174380496144 validation loss


0it [00:00, ?it/s]

5: 2499: 0.0012708762592163613
5: 4999: 0.0012662896689727863
END OF EPOCH 5: 6.92873289482668 train loss
END OF EPOCH 5: 0.9905290466267616 validation loss


0it [00:00, ?it/s]

6: 2499: 0.0012666972434059946
6: 4999: 0.0012735800933405099
END OF EPOCH 6: 6.928655483876355 train loss
END OF EPOCH 6: 0.9917419753037393 validation loss


0it [00:00, ?it/s]

7: 2499: 0.0012828623781695162
7: 4999: 0.0012689588803401516
END OF EPOCH 7: 6.927611568477005 train loss
END OF EPOCH 7: 0.9910406979033723 validation loss


0it [00:00, ?it/s]

8: 2499: 0.0012641329340382943
8: 4999: 0.0012601744025184038
END OF EPOCH 8: 6.9264642025809735 train loss
END OF EPOCH 8: 0.9905455211410299 validation loss


0it [00:00, ?it/s]

9: 2499: 0.0012674037941124432
9: 4999: 0.0012662072212573457
END OF EPOCH 9: 6.926161372219212 train loss
END OF EPOCH 9: 0.9894104808336124 validation loss


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁
false_negative_rate,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
false_positive_rate,▁▆▆▆▆▆▇▆▆▆▅▆▆▆▆▆▆▆▆▆▆▅▆▆▆█▆▆▆▆▆▆▆▆▆▆▇▆▆▆
positives_seen,▁▂▅▂▅▇▁▄▅▇▁▄▆█▂▅▆█▂▅▇▁▄▆▇▁▄▆█▂▅▇█▂▅▇▁▄▆█
true_negative_rate,█▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
true_positive_rate,▁████████████████████████▇██████████▇███
validation_loss,█▄▃▂▂▂▂▂▂▁

0,1
batch_loss,0.00123
false_negative_rate,0.00742
false_positive_rate,0.49816
positives_seen,691005.0
true_negative_rate,0.00355
true_positive_rate,0.49086
validation_loss,0.98941


Lifted from https://damdid2022.frccsc.ru/files/article/DAMDID_2022_paper_2646.pdf
Zavyalova, Martynyuk, and Samarev

"Testing will be performed on 5840 sentences from “The GUM Corpus” [16]."

|Tool Name                |tp  |fp | tn   |fn  |accuracy|error|precision|recall|f1   |
|---                      |--- |---|---   |--- |---     |---  |---      |---   |---  |
|Sentencize.jl            |6330|254|107813|1078|0.99    |0.01 |0.96     |0.85  |0.905|
|NLTK                     |6269|283|107787|1139|0.99    |0.01 |0.96     |0.85  |0.898|
|OpenNLP                  |6255|276|107791|1153|0.99    |0.01 |0.96     |0.84  |0.897|
|CoreNLP                  |6278|362|107786|1130|0.99    |0.01 |0.95     |0.85  |0.894|
|WordTokenizers.jl        |6140|264|107809|1268|0.99    |0.01 |0.96     |0.83  |0.889|
|Spacy (Dependency parser)|6631|934|107268|777 |0.99    |0.01 |0.88     |0.90  |0.886|
|Spacy (Rule-based)       |6183|994|107531|1225|0.98    |0.02 |0.86     |0.83  |0.848|
|SimpleSplitter           |5760|772|107847|1648|0.98    |0.02 |0.88     |0.78  |0.826|
|Julia split()            |5760|878|107780|1648|0.98    |0.02 |0.87     |0.78  |0.820|

In [203]:
import numpy
print("Raw:")
print(ds[0])
print("DL raw:")
for pre, suf, label in train_dataloader:
    print(pre[0])
    print(label[0])
    break
print("Processed:")
model.eval()
with torch.no_grad():
    for pre, suf, label in train_dataloader:
        label = (torch.Tensor(label) * 1).to(DEVICE)
        out = run_inference(model, pre, suf)
        tp, fp, tn, fn = compute_tpr_fpr_tnr_fnr(out, label)
        loss = loss_fn(out, label).cpu().numpy()
        out = out.cpu().numpy()
        confidence = numpy.abs(out[:, 0] - out[:, 1])
        print(f"TP: {tp}  TN: {tn}  FP: {fp}  FN: {fn}")
        errors = 0
        for idx in range(out.shape[0]):
            model_guess_eos = out[idx,1]>out[idx,0]
            gt_eos = label[idx]>0.5
            if model_guess_eos != gt_eos:
                print("!!!")
                errors += 1
                print(f"Sent: {pre[idx]}")
                print(f"Model guess: EOS: {model_guess_eos}")
                print(f"Truth: EOS: {gt_eos}")
                print()
        print(f"Total errors: {errors}")
        break

Raw:
('councils are popularly e', 'lected every four years.', False)
DL raw:
omemade drum as a child.
tensor(True)
Processed:
TP: 124  TN: 128  FP: 0  FN: 4
!!!
Sent: S YOU HAVE A REFERENCE!!
Model guess: EOS: False
Truth: EOS: True

!!!
Sent:  of the cross to be .178
Model guess: EOS: False
Truth: EOS: True

!!!
Sent: or 'Intervention Order '
Model guess: EOS: False
Truth: EOS: True

!!!
Sent:  do país' - Papo de Bola
Model guess: EOS: False
Truth: EOS: True

Total errors: 4


In [232]:
# Save model:
placeholder_x = torch.zeros([BATCH_SIZE, CONTEXT_SIZE], dtype=torch.int64)
model.eval().to('cpu')
out = model(placeholder_x)
torch.onnx.export(
    model, 
    placeholder_x,
    f"sentence_tokenizer_v6_{CONTEXT_SIZE}x256.onnx",
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=14,          # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names = ['input'],   # the model's input names
    output_names = ['output'], # the model's output names
    dynamic_axes={'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}}
)


In [228]:
placeholder_x = torch.zeros([1, CONTEXT_SIZE], dtype=torch.int64)
model.eval().to('cpu')
exporter = torch.onnx.dynamo_export(model, placeholder_x)
exporter.save(f"sentence_tokenizer_v6_{CONTEXT_SIZE}x256_dynamo.onnx")



OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: https://github.com/pytorch/pytorch/issues