In [39]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, AutoConfig, PreTrainedModel
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

class TimeLLMConfig(PretrainedConfig):
    """HF-compatible configuration for the TimeLLM model.

    Non-dataclass version with full ``PretrainedConfig`` integration.

    Args:
        tokenizer_kwargs: Keyword arguments forwarded to the time-series
            tokenizer built by :meth:`create_tokenizer` (for
            ``MeanScaleQuantileBins`` these are ``low_limit`` and
            ``high_limit``). ``None`` means no extra arguments.
        prediction_length: Number of future steps to predict.
        n_tokens: Size of the time-series bin vocabulary.
        query_len: Number of learnable query vectors used to summarize the
            time series before it is prepended to the text embeddings.
        **kwargs: Forwarded to ``PretrainedConfig``. The model reads
            ``config.llm_name``, which reaches the config through here.
    """

    model_type = "time_llm"

    def __init__(
        self,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        prediction_length: int = 24,
        n_tokens: int = 4096,
        query_len: int = 36,
        **kwargs,
    ):
        # The parent initializer must run first: it handles the standard HF
        # arguments (and keeps extra ones such as ``llm_name``).
        super().__init__(**kwargs)

        # Core custom parameters. Copy the kwargs dict so later mutation of the
        # caller's dict does not silently change this config.
        self.tokenizer_kwargs = dict(tokenizer_kwargs or {})
        self.prediction_length = prediction_length
        self.n_tokens = n_tokens
        self.query_len = query_len

    def create_tokenizer(self) -> "TimeLLMTokenizer":
        """Build the time-series tokenizer for this config.

        Always constructs a ``MeanScaleQuantileBins`` from ``tokenizer_kwargs``
        plus this config; the ``tokenizer_class`` attribute is not consulted.
        """
        return MeanScaleQuantileBins(**self.tokenizer_kwargs, config=self)

class TimeLLMTokenizer:
    """Interface shared by all time-series tokenizers.

    Concrete subclasses convert raw series to token ids and back. Every
    method on this base class raises ``NotImplementedError``.
    """

    def context_input_transform(self, context: torch.Tensor) -> Tuple:
        """Tokenize a context window. Subclasses must override this."""
        raise NotImplementedError

    def label_input_transform(self, label: torch.Tensor, tokenizer_state: Any) -> Tuple:
        """Tokenize a label window given ``tokenizer_state``. Subclasses must override this."""
        raise NotImplementedError

    def output_transform(self, samples: torch.Tensor, tokenizer_state: Any) -> torch.Tensor:
        """Map sampled token ids back to values. Subclasses must override this."""
        raise NotImplementedError

class MeanScaleQuantileBins(TimeLLMTokenizer):
    """Mean-scaling + uniform binning tokenizer for time series.

    Each series is divided by its mean absolute value (NaNs ignored), and each
    scaled value is assigned to a bin. There are ``config.n_tokens - 1`` bin
    centers evenly spaced on ``[low_limit, high_limit]``. Bin edges are the
    midpoints between neighbouring centers plus ``-1e20`` / ``1e20`` at the
    ends. Because ``torch.bucketize(..., right=True)`` is used, a value nearest
    ``centers[k]`` receives token id ``k + 1``. Token id 0 is only produced for
    values below ``-1e20``.
    """

    def __init__(self, low_limit: float, high_limit: float, config: TimeLLMConfig):
        self.config = config
        self.centers = torch.linspace(
            low_limit, high_limit,
            config.n_tokens - 1,
        )
        # Outer sentinels make the first/last bins open-ended.
        self.boundaries = torch.concat([
            torch.tensor([-1e20]),
            (self.centers[1:] + self.centers[:-1]) / 2,
            torch.tensor([1e20]),
        ])

    def _input_transform(self, context: torch.Tensor, scale: Optional[torch.Tensor] = None):
        """Scale and bin ``context`` along its last dimension.

        Args:
            context: Tensor ``[..., T]``; NaN marks missing values.
            scale: Optional precomputed per-series scale ``[...]``. When
                omitted, the mean absolute value of the non-NaN entries is used.

        Returns:
            ``(token_ids, attention_mask, scale)``. ``token_ids`` is clamped to
            ``[0, n_tokens - 1]``; ``attention_mask`` is ``True`` where
            ``context`` is not NaN.
        """
        context = context.float()
        attention_mask = ~torch.isnan(context)

        if scale is None:
            scale = torch.nansum(torch.abs(context) * attention_mask, dim=-1) / \
                torch.nansum(attention_mask, dim=-1)
            # Zero or NaN scales (all-zero / all-NaN series) fall back to 1.
            scale[~(scale > 0)] = 1.0

        scaled_context = context / scale.unsqueeze(-1)
        boundaries = self.boundaries.to(scaled_context.device)
        # NOTE(review): NaN entries are not masked here; the id they receive
        # depends on torch.bucketize's NaN handling -- confirm, or overwrite
        # them using attention_mask.
        token_ids = torch.bucketize(scaled_context, boundaries, right=True).clamp(
            0, self.config.n_tokens - 1
        )

        return token_ids, attention_mask, scale

    def context_input_transform(self, context: torch.Tensor):
        """Tokenize a context batch ``[..., T]``.

        The full context is used (there is no truncation to a context length).

        Returns:
            ``(token_ids, attention_mask, scale)`` as in ``_input_transform``.
        """
        token_ids, attention_mask, scale = self._input_transform(context)
        return token_ids, attention_mask, scale

    def output_transform(self, samples: torch.Tensor, scale: torch.Tensor):
        """Convert bin token ids back to values in the original scale.

        This is the inverse of ``context_input_transform``. Token id ``k``
        decodes to ``centers[k - 1]`` (clamped to the valid center range), and
        the result is multiplied by the per-series ``scale``.

        Args:
            samples: Integer bin ids whose leading dimension(s) match ``scale``
                (e.g. ``[B, T]`` or ``[B, S, T]`` with ``scale`` of shape ``[B]``).
            scale: Per-series scale returned by ``context_input_transform``.

        Returns:
            Float tensor with the same shape as ``samples``.
        """
        centers = self.centers.to(samples.device)
        # bucketize(right=True) gave centers[k] the id k + 1; undo that offset.
        indices = torch.clamp(samples.long() - 1, min=0, max=len(centers) - 1)
        scale = torch.as_tensor(scale, dtype=centers.dtype, device=centers.device)
        # Append singleton dims so a per-series scale broadcasts over the
        # remaining dims of `samples` (matches the old [B, 1, 1] for 3-D input).
        scale = scale.reshape(tuple(scale.shape) + (1,) * (samples.dim() - scale.dim()))
        return centers[indices] * scale


class QueryAttention(nn.Module):
    """Summarize a sequence with a fixed set of learnable query vectors.

    ``latent_len`` learned queries cross-attend over the input sequence, so
    an input of any length ``T`` becomes ``latent_len`` output vectors.

    Args:
        embed_dim: Feature size ``D`` of inputs, queries and outputs.
        latent_len: Number of learnable queries ``L``.
        num_heads: Number of heads for ``nn.MultiheadAttention``.
    """

    def __init__(self, embed_dim: int, latent_len: int, num_heads: int = 8):
        super().__init__()
        self.latent_len = latent_len
        # Learnable (L, D) query matrix drawn from a standard normal.
        self.query = nn.Parameter(torch.randn(latent_len, embed_dim))
        # Batch-first multi-head cross-attention.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None):
        """Attend from the learned queries to ``x``.

        Args:
            x: Input sequence ``[B, T, D]``.
            key_padding_mask: Optional ``[B, T]`` mask of key positions to
                ignore (``nn.MultiheadAttention`` semantics).

        Returns:
            Tensor ``[B, L, D]``.
        """
        batch_size = x.size(0)
        # The same queries are shared by every sequence in the batch.
        shared_queries = self.query[None].expand(batch_size, -1, -1)
        pooled, _weights = self.attn(
            shared_queries, x, x, key_padding_mask=key_padding_mask
        )
        return pooled

class TimeLLMModel(PreTrainedModel):
    """Causal LM conditioned on a binned time series.

    Bin ids are embedded, summarized into ``config.query_len`` vectors by
    ``QueryAttention``, and passed through a small MLP. The result is
    prepended to the LLM's text embeddings.
    """

    config_class = TimeLLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Text side: pretrained LLM config, tokenizer and model.
        self.llm_config = AutoConfig.from_pretrained(config.llm_name, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(config.llm_name, trust_remote_code=True)
        self.llm = AutoModelForCausalLM.from_pretrained(
            config.llm_name,
            config=self.llm_config,
            trust_remote_code=True,
        )
        self.llm_dim = self.llm_config.hidden_size  # LLM hidden size

        # Time-series side: bin embedding -> query attention -> alignment MLP.
        self.bin_embed = nn.Embedding(config.n_tokens, self.llm_dim)
        self.query_attn = QueryAttention(self.llm_dim, config.query_len)
        self.alignment = nn.Sequential(
            nn.Linear(self.llm_dim, self.llm_dim),
            nn.GELU(),
            nn.Linear(self.llm_dim, self.llm_dim),
        )

    def process_ts(self, bin_id):
        """Encode bin ids ``[B, T_ts]`` into prefix embeddings ``[B, query_len, llm_dim]``.

        The pipeline is embedding -> query attention -> projection.
        """
        bin_id = bin_id.to(self.bin_embed.weight.device)
        bin_embedding = self.bin_embed(bin_id)  # [B, T_ts, llm_dim]
        bin_feat = self.query_attn(bin_embedding)  # [B, query_len, llm_dim]
        return self.alignment(bin_feat)

    def forward(self, input_ids, attention_mask, bin_ids, labels=None, scales=None):
        """Run the LLM on ``[time-series prefix][text]`` input embeddings.

        Args:
            input_ids: Text token ids ``[B, L]``.
            attention_mask: Mask covering prefix + text, ``[B, query_len + L]``.
            bin_ids: Time-series bin ids ``[B, T_ts]``.
            labels: Optional ``[B, query_len + L]`` labels (-100 = ignored). The
                HF causal-LM head shifts them internally, which gives
                autoregressive training.
            scales: Unused. It is accepted so collated batches (which include a
                ``scales`` key) can be passed with ``**batch``.

        Returns:
            The LLM's output object (``loss`` is set when ``labels`` are given).
        """
        text_emb = self.llm.get_input_embeddings()(input_ids)
        # Align the prefix with the LLM embeddings (e.g. a bf16-loaded LLM).
        ts_emb = self.process_ts(bin_ids).to(dtype=text_emb.dtype, device=text_emb.device)

        inputs_embeds = torch.cat([ts_emb, text_emb], dim=1)

        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )

    @torch.no_grad()
    def generate(self, ts_data, text_input, max_new_tokens=48):
        """Greedy decoding conditioned on a time series and a text prompt.

        The full sequence is re-encoded at every step (no KV cache), and
        decoding does not stop early on EOS.

        Args:
            ts_data: Bin ids ``[B, T_ts]``. B must match the batch of the
                tokenized prompt (1 for a single string).
            text_input: Prompt passed to the LLM tokenizer.
            max_new_tokens: Number of tokens to generate.

        Returns:
            LongTensor ``[B, max_new_tokens]`` of LLM-vocabulary token ids.
        """
        embed = self.llm.get_input_embeddings()
        device, dtype = embed.weight.device, embed.weight.dtype

        input_ids = self.tokenizer(text_input, return_tensors='pt').input_ids.to(device)
        prompt_len = input_ids.size(1)
        ts_emb = self.process_ts(ts_data).to(dtype=dtype, device=device)

        for _ in range(max_new_tokens):
            inputs_embeds = torch.cat([ts_emb, embed(input_ids)], dim=1)
            logits = self.llm(inputs_embeds=inputs_embeds).logits
            next_token = logits[:, -1, :].argmax(-1)  # [B]
            # [B] -> [B, 1] so one column is appended per sequence.
            input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)

        # Slice by prompt length; `[:, -0:]` would return the prompt when
        # max_new_tokens == 0.
        return input_ids[:, prompt_len:]

    
class TimeLLMPipeline:
    """End-to-end inference wrapper around the bin tokenizer and ``TimeLLMModel``.

    Args:
        config: ``TimeLLMConfig`` used to build both the bin tokenizer and the model.
        model_path: Optional path to a saved ``state_dict`` to load into the model.
    """

    def __init__(self, config, model_path=None):
        self.config = config
        self.tokenizer = config.create_tokenizer()
        self.model = TimeLLMModel(config)

        if model_path:
            # The model is built without a device_map (CPU), so map the
            # checkpoint to CPU too; this also lets GPU-saved files load on
            # CPU-only hosts.
            # SECURITY: torch.load unpickles the file. Only load trusted
            # checkpoints, or pass weights_only=True where supported.
            state_dict = torch.load(model_path, map_location="cpu")
            self.model.load_state_dict(state_dict)

    def preprocess(self, raw_ts: List[float], text: str) -> Dict[str, torch.Tensor]:
        """Convert a raw series and prompt text into model inputs.

        Returns:
            Dict with ``bin_ids`` ``[1, T]``, ``input_ids`` / ``attention_mask``
            ``[1, L]`` for ``text``, and ``scale`` ``[1]`` (kept so the binning
            can be inverted).
        """
        # 1. Bin the series (as a batch of one).
        bin_ids, _, scale = self.tokenizer.context_input_transform(
            torch.tensor([raw_ts])
        )

        # 2. Tokenize the text. A single prompt needs no padding, and
        #    padding="max_length" without max_length pads up to
        #    tokenizer.model_max_length, which is huge for modern LLMs.
        text_enc = self.model.tokenizer(text, return_tensors="pt")

        return {
            "bin_ids": bin_ids,
            "input_ids": text_enc.input_ids,
            "attention_mask": text_enc.attention_mask,
            "scale": scale,  # kept for the inverse transform
        }

    def postprocess(self, pred_tokens: torch.Tensor, scale: torch.Tensor) -> List:
        """Convert predicted token ids back to values (as nested Python lists).

        NOTE(review): ``TimeLLMModel.generate`` returns LLM-vocabulary ids
        (argmax over the LLM logits), and training supervises the text of
        ``example["output"]``. ``output_transform`` expects bin ids, though.
        Confirm whether these tokens should instead be decoded with
        ``self.model.tokenizer`` and parsed.
        """
        # The bin tokenizer's tables live on CPU.
        return self.tokenizer.output_transform(pred_tokens.cpu(), scale).tolist()

    def predict(self, raw_ts: List[float], text: str) -> List:
        """Run preprocess -> generate -> postprocess for one series/prompt.

        NOTE(review): training prompts are built as ``f"{input}\\nanswer:"``;
        confirm that callers pass ``text`` in the same format.
        """
        inputs = self.preprocess(raw_ts, text)
        with torch.no_grad():
            pred_tokens = self.model.generate(
                inputs["bin_ids"],
                text_input=text,
                max_new_tokens=self.config.prediction_length,
            )
        return self.postprocess(pred_tokens, inputs["scale"])




In [40]:
# Experiment configuration: Qwen2.5-1.5B-Instruct backbone, 4096 bin tokens
# spanning scaled values in [-15, 15], 36 query vectors, 48-step horizon.
config = TimeLLMConfig(
    tokenizer_class="MeanScaleQuantileBins",
    tokenizer_kwargs={"low_limit": -15.0, "high_limit": 15.0},
    prediction_length=48,
    n_tokens=4096,
    query_len=36,
    llm_name="Qwen/Qwen2.5-1.5B-Instruct",
)

In [42]:
from datasets import load_dataset

# Each JSON file is loaded as its own single "train" split.
dataset_test = load_dataset(
    'json',
    data_files='/opt/tiger/dyf/data/AULF_test_data_2021.json',
    split='train',
)
dataset_train = load_dataset(
    'json',
    data_files='/opt/tiger/dyf/data/AULF_train_data_2019-2020.json',
    split='train',
)
(dataset_test, dataset_train)

(Dataset({
     features: ['instruction', 'input', 'output'],
     num_rows: 100
 }),
 Dataset({
     features: ['instruction', 'input', 'output'],
     num_rows: 3655
 }))

In [63]:
d = dataset_test[0]  # first raw test example (instruction / input / output)
print(d)

{'instruction': 'The historical load data is: 1189.4,1147.1,1136.2,1143.0,1148.7,1156.9,1161.1,1136.7,1128.3,1145.7,1131.3,1168.6,1225.8,1300.6,1394.4,1468.7,1540.0,1550.1,1488.8,1456.8,1422.4,1372.1,1342.9,1294.8,1247.7,1221.3,1189.8,1203.8,1196.2,1210.2,1230.6,1260.5,1311.1,1349.7,1423.7,1487.1,1549.6,1570.6,1565.4,1536.9,1500.6,1495.2,1409.5,1370.1,1345.4,1291.9,1254.8,1192.7', 'input': 'Based on the historical load data, please predict the load consumption in the next day. The region for prediction is TAS. The start date of historical data was on 2021-8-3 that is Weekday, and it is not a public holiday. The data frequency is 30 minutes per point. Historical data covers 1 day. The date of prediction is on 2021-8-4 that is Weekday, and it is not a public holiday. Weather of the start date: the minimum temperature is 279.71; the maximum temperature is 285.83; the humidity is 85.0; the pressure is 1003.0.  Weather of the prediction date: the minimum temperature is 280.54; the maximum t

In [44]:
bin_tokenizer = config.create_tokenizer()
llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_name)
def process_example(example):
    """Convert one raw instruction example into model-ready fields.

    The numeric history after the first ':' in ``example["instruction"]`` is
    binned with ``bin_tokenizer``. The prompt (``example["input"]`` followed by
    a newline and ``answer:``) and the target (``example["output"]``) are
    tokenized with ``llm_tokenizer``.

    Returns:
        Dict with ``bin_ids`` (list[int]), ``input_ids`` (prompt + target ids),
        ``attention_mask`` (all ones over the ``config.query_len`` prefix plus
        the text), ``labels`` (-100 on the prefix and prompt, target ids
        elsewhere) and ``scale`` (float).
    """
    # Everything after the first ':' is the comma-separated history.
    series_text = example["instruction"].partition(":")[2]
    ts_values = [float(x) for x in series_text.strip().split(",")]
    bin_ids, _, scale = bin_tokenizer.context_input_transform(torch.tensor([ts_values]))

    input_text = f"{example['input']}\nanswer:"
    target_text = example["output"]

    # Tokenize prompt and target separately, then concatenate, so the loss-mask
    # boundary is exact. With joint tokenization, a token can merge across the
    # seam, and special tokens added only to the joint call shift the mask.
    # The prompt keeps the tokenizer's default special tokens, matching
    # TimeLLMModel.generate.
    prompt_ids = llm_tokenizer(input_text)["input_ids"]
    target_ids = llm_tokenizer(target_text, add_special_tokens=False)["input_ids"]
    input_ids = list(prompt_ids) + list(target_ids)

    # Only target tokens are supervised: mask the time-series prefix and the prompt.
    labels = [-100] * (config.query_len + len(prompt_ids)) + list(target_ids)

    return {
        "bin_ids": bin_ids[0].tolist(),
        "input_ids": input_ids,
        "attention_mask": [1] * (config.query_len + len(input_ids)),  # prefix + text
        "labels": labels,
        "scale": scale.item(),
    }

In [51]:
a = process_example(d)
for key, value in a.items():
    print(key, value)

bin_ids [2172, 2167, 2166, 2167, 2167, 2168, 2169, 2166, 2165, 2167, 2166, 2169, 2175, 2183, 2193, 2201, 2208, 2209, 2203, 2199, 2196, 2191, 2187, 2182, 2178, 2175, 2172, 2173, 2172, 2174, 2176, 2179, 2184, 2188, 2196, 2202, 2209, 2211, 2211, 2208, 2204, 2203, 2194, 2190, 2188, 2182, 2178, 2172]
input_ids [28715, 389, 279, 13656, 2795, 821, 11, 4486, 7023, 279, 2795, 15293, 304, 279, 1790, 1899, 13, 576, 5537, 369, 19639, 374, 91288, 13, 576, 1191, 2400, 315, 13656, 821, 572, 389, 220, 17, 15, 17, 16, 12, 23, 12, 18, 429, 374, 10348, 1292, 11, 323, 432, 374, 537, 264, 584, 13257, 13, 576, 821, 11639, 374, 220, 18, 15, 4420, 817, 1459, 13, 40043, 821, 14521, 220, 16, 1899, 13, 576, 2400, 315, 19639, 374, 389, 220, 17, 15, 17, 16, 12, 23, 12, 19, 429, 374, 10348, 1292, 11, 323, 432, 374, 537, 264, 584, 13257, 13, 22629, 315, 279, 1191, 2400, 25, 279, 8028, 9315, 374, 220, 17, 22, 24, 13, 22, 16, 26, 279, 7192, 9315, 374, 220, 17, 23, 20, 13, 23, 18, 26, 279, 37093, 374, 220, 23, 20, 13, 

In [60]:
# labels / attention_mask should be config.query_len longer than input_ids.
for tag, field in (('inputid', 'input_ids'), ('label:', 'labels'), ('attention', 'attention_mask')):
    print(tag, len(a[field]))
print(len(a['bin_ids']))

inputid 539
label: 575
attention 575
48


In [66]:
deco = llm_tokenizer.decode(token_ids=[11])  # id 11 is the "," separator
print(deco)

,


In [67]:
# Convert every test example; drop the raw text columns so only model fields remain.
processed_dataset = dataset_test.map(
    process_example,
    batched=False,
    remove_columns=dataset_test.column_names,
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [71]:
# Display the processed dataset summary (features and row count).
processed_dataset

Dataset({
    features: ['bin_ids', 'input_ids', 'attention_mask', 'labels', 'scale'],
    num_rows: 100
})

In [72]:
def print_sample(dataset, index):
    """Print every field of ``dataset[index]``.

    Each field is printed as ``name: value``. List fields are followed by
    their length, and tensor fields by their shape.
    """
    record = dataset[index]
    for field_name, field_value in record.items():
        print(f"{field_name}: {field_value}")
        if isinstance(field_value, torch.Tensor):
            print(f"Shape of {field_name}: {field_value.shape}")
        elif isinstance(field_value, list):
            print(f"Length of {field_name}: {len(field_value)}")

# Print the first few samples of the processed dataset, separated by blank lines.
for i in range(3):
    header = f"Sample {i}:"
    print(header)
    print_sample(processed_dataset, i)
    print("\n")

Sample 0:
bin_ids: [2172, 2167, 2166, 2167, 2167, 2168, 2169, 2166, 2165, 2167, 2166, 2169, 2175, 2183, 2193, 2201, 2208, 2209, 2203, 2199, 2196, 2191, 2187, 2182, 2178, 2175, 2172, 2173, 2172, 2174, 2176, 2179, 2184, 2188, 2196, 2202, 2209, 2211, 2211, 2208, 2204, 2203, 2194, 2190, 2188, 2182, 2178, 2172]
Length of bin_ids: 48
input_ids: [28715, 389, 279, 13656, 2795, 821, 11, 4486, 7023, 279, 2795, 15293, 304, 279, 1790, 1899, 13, 576, 5537, 369, 19639, 374, 91288, 13, 576, 1191, 2400, 315, 13656, 821, 572, 389, 220, 17, 15, 17, 16, 12, 23, 12, 18, 429, 374, 10348, 1292, 11, 323, 432, 374, 537, 264, 584, 13257, 13, 576, 821, 11639, 374, 220, 18, 15, 4420, 817, 1459, 13, 40043, 821, 14521, 220, 16, 1899, 13, 576, 2400, 315, 19639, 374, 389, 220, 17, 15, 17, 16, 12, 23, 12, 19, 429, 374, 10348, 1292, 11, 323, 432, 374, 537, 264, 584, 13257, 13, 22629, 315, 279, 1191, 2400, 25, 279, 8028, 9315, 374, 220, 17, 22, 24, 13, 22, 16, 26, 279, 7192, 9315, 374, 220, 17, 23, 20, 13, 23, 18, 26, 

In [73]:
def collate_fn(batch, query_len=None, pad_token_id=None):
    """Collate processed examples into padded tensors for ``TimeLLMModel.forward``.

    ``input_ids`` covers only the text, while ``attention_mask`` and ``labels``
    also cover the ``query_len`` time-series prefix. Text ids are therefore
    right-padded to the longest text in the batch (so ``input_ids[:, j]``
    lines up with ``labels[:, query_len + j]``). The prefix-aware fields are
    right-padded to ``query_len + max_text_len``.

    Args:
        batch: List of dicts with keys ``bin_ids``, ``input_ids``,
            ``attention_mask``, ``labels`` and ``scale``.
        query_len: Time-series prefix length; defaults to ``config.query_len``.
        pad_token_id: Padding id for ``input_ids``; defaults to
            ``llm_tokenizer.pad_token_id``.

    Returns:
        Dict with ``bin_ids`` ``[B, T_ts]``, ``input_ids`` ``[B, max_text_len]``,
        ``attention_mask`` / ``labels`` ``[B, query_len + max_text_len]`` and
        ``scales`` ``[B]``.
    """
    if query_len is None:
        query_len = config.query_len
    if pad_token_id is None:
        pad_token_id = llm_tokenizer.pad_token_id

    max_text_len = max(len(x["input_ids"]) for x in batch)
    total_len = query_len + max_text_len

    def pad_field(values, pad_value, length):
        # Right-pad every sequence with pad_value up to `length`, then stack.
        return torch.stack([
            torch.cat([
                torch.tensor(v, dtype=torch.long),
                torch.full((length - len(v),), pad_value, dtype=torch.long),
            ])
            for v in values
        ])

    return {
        "bin_ids": torch.stack([torch.tensor(x["bin_ids"]) for x in batch]),
        # Pad text ids to the text length only; padding to total_len and then
        # slicing off query_len columns would drop real prompt tokens.
        "input_ids": pad_field([x["input_ids"] for x in batch], pad_token_id, max_text_len),
        "attention_mask": pad_field([x["attention_mask"] for x in batch], 0, total_len),
        "labels": pad_field([x["labels"] for x in batch], -100, total_len),
        "scales": torch.tensor([x["scale"] for x in batch]),
    }

In [75]:
# Smoke-test collate_fn on the first three samples of the already
# preprocessed dataset `processed_dataset`.
batch = [processed_dataset[idx] for idx in range(3)]

collated_batch = collate_fn(batch)

# Show every collated field and, for tensors, its shape.
for key, value in collated_batch.items():
    print(f"{key}: {value}")
    if not isinstance(value, torch.Tensor):
        continue
    print(f"Shape of {key}: {value.shape}")

bin_ids: tensor([[2172, 2167, 2166, 2167, 2167, 2168, 2169, 2166, 2165, 2167, 2166, 2169,
         2175, 2183, 2193, 2201, 2208, 2209, 2203, 2199, 2196, 2191, 2187, 2182,
         2178, 2175, 2172, 2173, 2172, 2174, 2176, 2179, 2184, 2188, 2196, 2202,
         2209, 2211, 2211, 2208, 2204, 2203, 2194, 2190, 2188, 2182, 2178, 2172],
        [2186, 2184, 2182, 2180, 2177, 2172, 2168, 2165, 2164, 2164, 2166, 2170,
         2176, 2187, 2194, 2199, 2199, 2197, 2192, 2184, 2171, 2171, 2167, 2163,
         2159, 2159, 2159, 2158, 2159, 2160, 2164, 2169, 2175, 2183, 2195, 2206,
         2221, 2228, 2227, 2222, 2219, 2215, 2211, 2207, 2202, 2198, 2191, 2189],
        [2178, 2175, 2172, 2166, 2163, 2160, 2160, 2159, 2159, 2160, 2161, 2163,
         2164, 2167, 2168, 2171, 2172, 2174, 2176, 2176, 2175, 2176, 2177, 2178,
         2179, 2181, 2185, 2188, 2191, 2193, 2197, 2203, 2208, 2210, 2215, 2218,
         2219, 2218, 2217, 2216, 2212, 2207, 2201, 2197, 2191, 2189, 2186, 2181]])
Shape of bin_id