In [1]:
import torch
import matplotlib.pyplot as plt
from transformers import StoppingCriteria, StoppingCriteriaList

from m2_utilities.load_data import load_trajectories
from m2_utilities.preprocessor import Preprocessor, batch_trim_sequence
from m2_utilities.qwen import load_qwen
from m2_utilities.metrics import forecast_points, compute_mae

%load_ext autoreload
%autoreload 2

In [2]:
model, tokenizer = load_qwen()

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [3]:
trajectories = load_trajectories("data/lotka_volterra_data.h5")[:, :20, :]
preprocessor = Preprocessor(decimals=2)
input_ids = preprocessor.encode(trajectories)

print(input_ids.shape)

torch.Size([1000, 199])


In [4]:
class MaxSemicolonCriteria(StoppingCriteria):
    def __init__(self, n_points, input_ids):
        self.max_semicolons = n_points + 1
        self.n_semicolons = torch.zeros(len(input_ids), dtype=torch.int)

    def __call__(self, input_ids, scores, **kwargs):
        SEMICOLON_ID = 26
        # Count semicolons
        for i in range(len(input_ids)):
            last_id = input_ids[i][-1].item()
            if last_id == SEMICOLON_ID:
                self.n_semicolons[i] += 1

        return torch.all(self.n_semicolons >= self.max_semicolons)

In [5]:
stopping_criteria = StoppingCriteriaList([MaxSemicolonCriteria(n_points=10, input_ids=input_ids[:1])])

In [6]:
output_ids = model.generate(input_ids[:1], max_length=100000, stopping_criteria=stopping_criteria, do_sample=False)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [17]:
output_trimmed = batch_trim_sequence(output_ids, 25)

all_trajectories = preprocessor.decode(output_trimmed)


In [18]:
print(all_trajectories.shape)

torch.Size([1, 25, 2])
