In [1]:
import yaml
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoProcessor
from dataset.dataset import TsQaDataset, DataCollator
from models.TimeLanguageModel import TLM, TLMConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [4]:
with open('yaml/infer.yaml', 'r') as f:
    config = yaml.safe_load(f)
print("üõ†Ô∏è Âä†ËΩΩÈÖçÁΩÆÊñá‰ª∂ÂÜÖÂÆπ:")
for k, v in config.items():
    print(f"  {k}: {v}")

class ConfigObject:
    """ÈÄíÂΩíÂ∞ÜÂ≠óÂÖ∏ËΩ¨Êç¢‰∏∫ÂµåÂ•óÂØπË±°"""
    def __init__(self, data):
        for key, value in data.items():
            # ÈÄíÂΩíÂ§ÑÁêÜÂµåÂ•óÂ≠óÂÖ∏
            if isinstance(value, dict):
                setattr(self, key, ConfigObject(value))
            # Áõ¥Êé•ËµãÂÄºÈùûÂ≠óÂÖ∏Á±ªÂûã
            else:
                setattr(self, key, value)
args=ConfigObject(config)

üõ†Ô∏è Âä†ËΩΩÈÖçÁΩÆÊñá‰ª∂ÂÜÖÂÆπ:
  model: TimeSeriesEncoder
  d_model: 512
  n_heads: 8
  e_layers: 4
  patch_len: 60
  stride: 60
  input_len: 600
  dropout: 0.1
  tt_d_model: 896
  tt_n_heads: 16
  tt_layers: 2
  tt_dropout: 0.1
  prefix_num: 25
  pretrain: False
  min_mask_ratio: 0.7
  max_mask_ratio: 0.8
  ts_path_test: dataset/dataset_processing/data_merged_new.h5
  qa_path_test: dataset/dataset_processing/test_sw3000.jsonl
  fp16: True
  dataloader_pin_memory: True
  dataloader_num_workers: 4


In [5]:
print(type(args))

<class '__main__.ConfigObject'>


In [6]:
tlmconfig = TLMConfig(
    ts_pad_num=config['prefix_num']
)

In [7]:
tlmconfig

TLMConfig {
  "model_type": "vlm_model",
  "transformers_version": "4.47.1",
  "ts_pad_num": 25
}

In [13]:
model = TLM.from_pretrained('checkpoints/Qwen-0.5B', config=tlmconfig, ts_config=args)

In [15]:
model._initialize_llm_components('checkpoints/Qwen-0.5B')
model = model.cuda()

ValueError: The checkpoint you are trying to load has model type `vlm_model` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.

In [10]:
model.eval()

TLM(
  (ts_encoder): Model(
    (patchfy): Patchfy()
    (layers): ModuleList(
      (0-3): 4 x BasicBlock(
        (seq_att_block): SeqAttBlock(
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (attn_seq): SeqAttention(
            (qkv): Linear(in_features=512, out_features=1536, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.1, inplace=False)
            (proj): Linear(in_features=512, out_features=512, bias=True)
            (proj_drop): Dropout(p=0.1, inplace=False)
          )
          (drop_path1): Identity()
        )
        (var_att_block): VarAttBlock(
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (attn_var): VarAttention(
            (qkv): Linear(in_features=512, out_features=1536, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.1, inplace=False)
            (proj): Linea

In [11]:
print(f"‚úÖ Ê®°ÂûãÂä†ËΩΩÂÆåÊàê! ÂèÇÊï∞Èáè: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

‚úÖ Ê®°ÂûãÂä†ËΩΩÂÆåÊàê! ÂèÇÊï∞Èáè: 64.78M


In [11]:
tokenizer = AutoTokenizer.from_pretrained(tlmconfig.llm_model_path)
tokenizer.padding_side = 'left'
processor = AutoProcessor.from_pretrained(tlmconfig.llm_model_path)

In [12]:
test_dataset = TsQaDataset(
    config['ts_path_test'],
    config['qa_path_test'],
    tokenizer,
    processor,
    tlmconfig
)

üìä Vocab size: 151665
üîç È™åËØÅÁâπÊÆätoken:
‚úÖ pad_token_id = 151643
‚úÖ eos_token_id = 151645


In [13]:
print(len(test_dataset))

42477


In [24]:
random_index = random.randint(1, len(test_dataset)) 
sample = test_dataset[random_index]
sample


{'form': 'open',
 'stage': 1,
 'query_ids': [641,
  279,
  2266,
  315,
  279,
  3897,
  4712,
  8286,
  11,
  1128,
  374,
  279,
  23560,
  13042,
  323,
  24586,
  25361,
  315,
  279,
  2297,
  304,
  10657,
  37022,
  518,
  279,
  12041,
  12,
  68269,
  1198,
  56220,
  320,
  43,
  4872,
  8,
  26389,
  13166,
  2337,
  264,
  3175,
  10775,
  5267,
  366,
  2576,
  29],
 'input_ids': [151644,
  8948,
  198,
  2610,
  525,
  264,
  10950,
  17847,
  13,
  151645,
  198,
  151644,
  872,
  198,
  641,
  279,
  2266,
  315,
  279,
  3897,
  4712,
  8286,
  11,
  1128,
  374,
  279,
  23560,
  13042,
  323,
  24586,
  25361,
  315,
  279,
  2297,
  304,
  10657,
  37022,
  518,
  279,
  12041,
  12,
  68269,
  1198,
  56220,
  320,
  43,
  4872,
  8,
  26389,
  13166,
  2337,
  264,
  3175,
  10775,
  5267,
  220,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,
  151655,

In [25]:
for key, value in sample.items():
    if torch.is_tensor(value):
        print(f"  {key}: ÂΩ¢Áä∂ {value.shape}, dtype {value.dtype}")
    else:
        print(f"  {key}: {value}")

  form: open
  stage: 1
  query_ids: [641, 279, 2266, 315, 279, 3897, 4712, 8286, 11, 1128, 374, 279, 23560, 13042, 323, 24586, 25361, 315, 279, 2297, 304, 10657, 37022, 518, 279, 12041, 12, 68269, 1198, 56220, 320, 43, 4872, 8, 26389, 13166, 2337, 264, 3175, 10775, 5267, 366, 2576, 29]
  input_ids: [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 641, 279, 2266, 315, 279, 3897, 4712, 8286, 11, 1128, 374, 279, 23560, 13042, 323, 24586, 25361, 315, 279, 2297, 304, 10657, 37022, 518, 279, 12041, 12, 68269, 1198, 56220, 320, 43, 4872, 8, 26389, 13166, 2337, 264, 3175, 10775, 5267, 220, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151645, 198, 151644, 77091, 198]
  labels: [785, 10657, 37022, 518, 279, 49075, 26389, 8458, 20699, 15175, 13, 1096, 19753, 77764, 264, 14720, 323, 24020, 58877, 5068, 11, 44

In [26]:
data_collator = DataCollator(tokenizer=tokenizer)
collated = data_collator([sample])
print(f"Collated keys: {list(collated.keys())}")

Collated keys: ['input_ids', 'attention_mask', 'labels', 'ts_values', 'stage', 'index', 'query_ids']


In [27]:
input_ids = collated['input_ids'].cuda()
ts_values = collated['ts_values'].cuda()
attention_mask = collated['attention_mask'].cuda()
query_ids = collated['query_ids'].cuda()
stages = collated['stage'].cuda()

In [28]:
type(collated['ts_values'])

torch.Tensor

In [29]:
raw_text = tokenizer.decode(input_ids[0].cpu().numpy(), skip_special_tokens=True)
raw_text

'system\nYou are a helpful assistant.\nuser\nIn the context of the provided engine signal, what is the precise representation and operational significance of the change in Total Temperature at the Low-Pressure Compressor (LPC) outlet observed during a single cycle?\n \nassistant\n'

In [30]:
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=input_ids,
        query_ids=query_ids,
        ts_values=ts_values,
        stage=stages,
        attention_mask=attention_mask,
        max_new_tokens=128,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        temperature=0.7,
        num_beams=1
    )

In [31]:
decoded_full = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
prediction = decoded_full.split('assistant\n')[-1] if 'assistant\n' in decoded_full else decoded_full
print(decoded_full)
print(prediction)

system
You are a helpful assistant.
user
In the context of the provided engine signal, what is the precise representation and operational significance of the change in Total Temperature at the Low-Pressure Compressor (LPC) outlet observed during a single cycle?
 
assistant
The Total Temperature at the LPC outlet remains constant throughout the cycle, indicating stable operational conditions.
-threats_to_the_environment
The Total Temperature at the LPC outlet remains constant throughout the cycle, indicating stable operational conditions.
-threats_to_the_environment


In [32]:
label_text = tokenizer.batch_decode(collated['labels'], skip_special_tokens=True)[0]
label_text

'The Total Temperature at the LPC outlet remains consistently stable. This stability signifies a reliable and steady compressor performance, demonstrating the absence of significant temperature fluctuations throughout the cycle.'