In [11]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from generation import MarkedIntensityHomogenuosPoisson, generate_samples_marked

class PointProcessDataset(Dataset):
    """Zero-padded view over a list of event sequences.

    Every item is padded with zero rows up to the length of the longest
    sequence in the dataset, so all items share one shape.

    Args:
        sequences: Indexable, sized collection of sequences. Each non-empty
            sequence must convert via ``np.array`` to shape ``(len(seq), 2)``.
        num_steps_timeseries: Size of the second axis of the placeholder
            ``'time_series'`` tensor. ``None`` (the default) falls back to the
            module-level ``NUM_STEPS_TIMESERIES`` when an item is fetched.
        timeseries_feature: Size of the last axis of the placeholder
            ``'time_series'`` tensor. ``None`` (the default) falls back to the
            module-level ``TIMESERIES_FEATURE`` when an item is fetched.
    """

    def __init__(self, sequences, num_steps_timeseries=None, timeseries_feature=None):
        self.sequences = sequences
        # default=0 keeps an empty dataset constructible instead of raising
        # ValueError from max() on an empty iterable.
        self.max_len = max((len(seq) for seq in sequences), default=0)
        self.num_steps_timeseries = num_steps_timeseries
        self.timeseries_feature = timeseries_feature

    def __len__(self):
        """Return the number of sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the padded sequence, its true length and a ones-filled
        ``'time_series'`` placeholder for the sequence at ``idx``.

        Returns:
            dict with keys
            ``'sequence'``: float32 tensor of shape ``(max_len, 2)``,
            ``'length'``: int64 tensor of shape ``(1,)`` with the unpadded length,
            ``'time_series'``: tensor of ones with shape
            ``(max_len, num_steps, num_features)``.
        """
        seq = self.sequences[idx]
        seq_len = len(seq)
        padded_seq = np.zeros((self.max_len, 2))
        # np.array([]) has shape (0,), which cannot be broadcast into a
        # (0, 2) slice, so empty sequences are simply left as all-zero rows.
        if seq_len:
            padded_seq[:seq_len] = np.array(seq)

        # Resolve the placeholder dimensions lazily so that notebooks relying
        # on the module-level constants keep working unchanged.
        num_steps = (
            NUM_STEPS_TIMESERIES
            if self.num_steps_timeseries is None
            else self.num_steps_timeseries
        )
        num_features = (
            TIMESERIES_FEATURE
            if self.timeseries_feature is None
            else self.timeseries_feature
        )

        return {
            'sequence': torch.tensor(padded_seq, dtype=torch.float32),
            'length': torch.tensor([seq_len], dtype=torch.long),
            'time_series': torch.ones(self.max_len, num_steps, num_features)
        }

# Number of dimensions handed to the intensity object; also the range of the
# index `u` initialised in the loop below.
DIM_SIZE = 7
mi = MarkedIntensityHomogenuosPoisson(DIM_SIZE)
# Initialise every index u in [0, DIM_SIZE) with the same value 1.0
# (presumably a constant per-mark rate -- confirm against `generation`).
for u in range(DIM_SIZE):
    mi.initialize(1.0, u)
# NOTE(review): 15.0 and 1000 look like the time horizon and the number of
# sequences to draw -- confirm against `generate_samples_marked`.
simulated_sequences = generate_samples_marked(mi, 15.0, 1000)
dataset = PointProcessDataset(simulated_sequences)


In [20]:
# Sanity check: number of events in the simulated sequence at index 3.
len(simulated_sequences[3])

135

In [104]:
class PointProcessDataset(Dataset):
    """Variable-length event sequences stored as float32 tensors, unpadded.

    Each input sequence is converted once, at construction time, to a tensor
    of shape ``(num_events, 2)`` holding the first two fields of every event.
    ``SimpleGRU.forward`` reads column 0 as the mark and column 1 as the time.

    Args:
        sequences: Iterable of sequences; each event must support ``event[0]``
            and ``event[1]``. Any further fields of an event are ignored.
    """

    def __init__(self, sequences):
        # Convert every sequence to a tensor up front. reshape(-1, 2) is a
        # no-op for non-empty sequences and turns the shape-(0,) tensor built
        # from an empty sequence into shape (0, 2), so column indexing
        # downstream keeps working.
        self.sequences = [
            torch.tensor(
                [[event[0], event[1]] for event in seq], dtype=torch.float32
            ).reshape(-1, 2)
            for seq in sequences
        ]

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the stored tensor for sequence ``idx`` as-is (no padding)."""
        return self.sequences[idx]
def collate_fn(batch):
    """Identity collate: hand the list of samples back untouched.

    Used so that a ``DataLoader`` yields each batch as the plain list of
    per-sample objects instead of trying to stack them into one tensor.
    """
    return batch
# Same simulation setup as the earlier cell, re-run here so the sequences are
# regenerated and wrapped in the unpadded PointProcessDataset defined above.
# NOTE(review): no random seed is set in the visible cells; if
# `generate_samples_marked` is stochastic, these sequences will differ from
# the ones produced earlier -- confirm.
DIM_SIZE = 7
BATCH_SIZE=256
mi = MarkedIntensityHomogenuosPoisson(DIM_SIZE)
# Initialise every index u in [0, DIM_SIZE) with the same value 1.0
# (presumably a constant per-mark rate -- confirm against `generation`).
for u in range(DIM_SIZE):
    mi.initialize(1.0, u)
simulated_sequences = generate_samples_marked(mi, 15.0, 1000)
dataset = PointProcessDataset(simulated_sequences)


# Use the custom collate_fn when creating the DataLoader, so each batch is
# delivered as a plain list of variable-length sequence tensors.
dataloader = DataLoader(
    dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=collate_fn
)


In [120]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleGRU(nn.Module):
    """Two-branch GRU over (mark, time) event sequences.

    One GRU consumes embedded marks and another consumes the raw time values.
    Their per-step outputs are concatenated and fed to two linear heads: a
    scalar time prediction and ``num_classes`` mark logits.

    Args:
        num_classes: Number of distinct marks; also the embedding width.
        hidden_size_event: Hidden size of the mark GRU.
        hidden_size_time: Hidden size of the time GRU.
        reg: Weight applied to the time (MSE) loss in the total loss.
    """

    def __init__(self, num_classes, hidden_size_event, hidden_size_time, reg=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.reg = reg

        # Mark branch: embedding followed by a GRU.
        self.event_embedding = nn.Embedding(num_classes, num_classes)
        self.event_gru = nn.GRU(num_classes, hidden_size_event, batch_first=True)

        # Time branch: GRU over a single scalar feature per step.
        self.time_gru = nn.GRU(1, hidden_size_time, batch_first=True)

        # Output heads on top of the concatenated branch outputs.
        combined_size = hidden_size_event + hidden_size_time
        self.time_output = nn.Linear(combined_size, 1)
        self.mark_output = nn.Linear(combined_size, num_classes)

    def forward(self, event_sequence):
        """Run both branches over ``event_sequence``.

        Args:
            event_sequence: Tensor whose last axis holds the mark at index 0
                and the time at index 1.

        Returns:
            ``(time_pred, mark_logits)`` with last-axis sizes ``1`` and
            ``num_classes`` respectively.
        """
        mark_ids = event_sequence[..., 0].to(torch.long)
        time_values = torch.unsqueeze(event_sequence[..., 1], dim=-1)

        # Marks go through the embedding before entering their GRU.
        event_states, _ = self.event_gru(self.event_embedding(mark_ids))
        time_states, _ = self.time_gru(time_values)

        features = torch.cat((event_states, time_states), dim=-1)
        return self.time_output(features), self.mark_output(features)

    def compute_loss(self, time_pred, mark_logits, targets):
        """Combine mark cross-entropy with ``reg``-weighted time MSE.

        Predictions are compared with ``targets`` position by position; no
        shifting of the targets is performed here.

        Args:
            time_pred: Time predictions as returned by ``forward``.
            mark_logits: Mark logits as returned by ``forward``.
            targets: Tensor laid out like the ``forward`` input
                (mark at index 0, time at index 1 of the last axis).

        Returns:
            ``(total_loss, mark_loss, time_loss)`` where
            ``total_loss = mark_loss + reg * time_loss``.
        """
        target_marks = targets[..., 0].to(torch.long)
        target_times = torch.unsqueeze(targets[..., 1], dim=-1)

        time_loss = F.mse_loss(time_pred, target_times)
        # Flatten every leading axis so cross-entropy sees (N, num_classes)
        # logits against (N,) class indices.
        mark_loss = F.cross_entropy(
            mark_logits.view(-1, self.num_classes),
            target_marks.view(-1),
        )

        return mark_loss + self.reg * time_loss, mark_loss, time_loss


In [121]:
def train_model(model, train_loader, num_epochs=10, learning_rate=0.001, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Train ``model`` with Adam, taking one optimizer step per batch.

    Every batch is a collection of individual sequence tensors. Each sequence
    is pushed through the model on its own (with a leading batch axis of 1),
    the per-sequence losses are averaged over the batch, and a single
    backward pass / optimizer step is performed. The sequence itself is used
    as the target passed to ``model.compute_loss``.

    Relies on ``tqdm`` being available in the enclosing namespace.

    Args:
        model: Module exposing ``forward(sequence)`` returning
            ``(time_pred, mark_logits)`` and
            ``compute_loss(time_pred, mark_logits, targets)`` returning
            ``(total_loss, mark_loss, time_loss)`` (e.g. ``SimpleGRU``).
        train_loader: Iterable yielding batches; each batch is a sized
            iterable of sequence tensors.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: Device to train on ('cuda' or 'cpu').

    Returns:
        A list with one dict per epoch holding ``'epoch'``, ``'loss'``,
        ``'mark_loss'``, ``'time_loss'`` (batch-averaged values, ``nan`` when
        the epoch saw no non-empty batch) and ``'num_batches'``.
    """
    # Local import so the function does not depend on `time` having been
    # imported by some other notebook cell.
    import time

    print(f"Training on {device}")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    history = []

    for epoch in range(num_epochs):
        model.train()
        train_total_loss = 0.0
        train_mark_loss_sum = 0.0
        train_time_loss_sum = 0.0
        num_batches = 0

        epoch_start_time = time.time()
        train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for batch in train_bar:
            batch_size = len(batch)
            if batch_size == 0:
                # Nothing to average over; skipping avoids a division by zero.
                continue

            optimizer.zero_grad()
            batch_loss = 0
            batch_mark_loss = 0
            batch_time_loss = 0

            for sequence in batch:
                # Add the batch axis expected by the batch_first GRUs.
                sequence = sequence.to(device).unsqueeze(0)
                time_pred, mark_logits = model(sequence)
                loss, mark_loss, time_loss = model.compute_loss(time_pred, mark_logits, sequence)
                batch_loss = batch_loss + loss
                # The component losses are only reported, never
                # back-propagated, so keep them out of the autograd graph.
                batch_mark_loss = batch_mark_loss + mark_loss.detach()
                batch_time_loss = batch_time_loss + time_loss.detach()

            batch_loss = batch_loss / batch_size
            batch_loss.backward()
            optimizer.step()

            batch_loss_value = batch_loss.item()
            batch_mark_value = (batch_mark_loss / batch_size).item()
            batch_time_value = (batch_time_loss / batch_size).item()

            train_total_loss += batch_loss_value
            train_mark_loss_sum += batch_mark_value
            train_time_loss_sum += batch_time_value
            num_batches += 1

            train_bar.set_postfix({
                'loss': f'{batch_loss_value:.4f}',
                'mark_loss': f'{batch_mark_value:.4f}',
                'time_loss': f'{batch_time_value:.4f}'
            })

        # Compute and print the epoch's average losses. An epoch without any
        # usable batch reports nan instead of raising ZeroDivisionError.
        if num_batches:
            avg_train_loss = train_total_loss / num_batches
            avg_train_mark_loss = train_mark_loss_sum / num_batches
            avg_train_time_loss = train_time_loss_sum / num_batches
        else:
            avg_train_loss = avg_train_mark_loss = avg_train_time_loss = float('nan')

        epoch_time = time.time() - epoch_start_time
        print(f'\nEpoch {epoch+1}/{num_epochs} - {epoch_time:.2f}s')
        print(f'Train Loss: {avg_train_loss:.4f} (Mark: {avg_train_mark_loss:.4f}, Time: {avg_train_time_loss:.4f})')
        print('-' * 80)

        history.append({
            'epoch': epoch + 1,
            'loss': avg_train_loss,
            'mark_loss': avg_train_mark_loss,
            'time_loss': avg_train_time_loss,
            'num_batches': num_batches,
        })

    return history

# Usage example:
if __name__ == "__main__":
    # Create the model; DIM_SIZE is used as the number of mark classes.
    model = SimpleGRU(
        num_classes=DIM_SIZE,
        hidden_size_event=16,
        hidden_size_time=32
    )
    
    # The training DataLoader (`dataloader`) is created in an earlier cell.
    
    # Train the model
    train_model(
        model=model,
        train_loader=dataloader,
        num_epochs=10,
        learning_rate=0.001
    )

Training on cpu


Epoch 1/10: 100%|██████████| 4/4 [00:22<00:00,  5.52s/it, loss=8.7964, mark_loss=1.9855, time_loss=68.1082]



Epoch 1/10 - 22.09s
Train Loss: 8.8937 (Mark: 1.9917, Time: 69.0204)
--------------------------------------------------------------------------------


Epoch 2/10: 100%|██████████| 4/4 [00:20<00:00,  5.13s/it, loss=8.4259, mark_loss=1.9702, time_loss=64.5570]



Epoch 2/10 - 20.53s
Train Loss: 8.5696 (Mark: 1.9736, Time: 65.9599)
--------------------------------------------------------------------------------


Epoch 3/10: 100%|██████████| 4/4 [00:21<00:00,  5.34s/it, loss=8.1609, mark_loss=1.9536, time_loss=62.0733]



Epoch 3/10 - 21.36s
Train Loss: 8.2578 (Mark: 1.9582, Time: 62.9961)
--------------------------------------------------------------------------------


Epoch 4/10: 100%|██████████| 4/4 [00:20<00:00,  5.16s/it, loss=7.8092, mark_loss=1.9403, time_loss=58.6896]



Epoch 4/10 - 20.63s
Train Loss: 7.9512 (Mark: 1.9448, Time: 60.0643)
--------------------------------------------------------------------------------


Epoch 5/10: 100%|██████████| 4/4 [00:20<00:00,  5.18s/it, loss=7.5525, mark_loss=1.9286, time_loss=56.2385]



Epoch 5/10 - 20.70s
Train Loss: 7.6516 (Mark: 1.9326, Time: 57.1898)
--------------------------------------------------------------------------------


Epoch 6/10: 100%|██████████| 4/4 [00:21<00:00,  5.33s/it, loss=7.2936, mark_loss=1.9172, time_loss=53.7640]



Epoch 6/10 - 21.34s
Train Loss: 7.3569 (Mark: 1.9214, Time: 54.3549)
--------------------------------------------------------------------------------


Epoch 7/10: 100%|██████████| 4/4 [00:22<00:00,  5.58s/it, loss=6.9781, mark_loss=1.9073, time_loss=50.7074]



Epoch 7/10 - 22.31s
Train Loss: 7.0648 (Mark: 1.9109, Time: 51.5392)
--------------------------------------------------------------------------------


Epoch 8/10: 100%|██████████| 4/4 [00:21<00:00,  5.41s/it, loss=6.6628, mark_loss=1.8971, time_loss=47.6572]



Epoch 8/10 - 21.64s
Train Loss: 6.7755 (Mark: 1.9010, Time: 48.7452)
--------------------------------------------------------------------------------


Epoch 9/10: 100%|██████████| 4/4 [00:20<00:00,  5.09s/it, loss=6.3574, mark_loss=1.8879, time_loss=44.6946]



Epoch 9/10 - 20.37s
Train Loss: 6.4891 (Mark: 1.8917, Time: 45.9741)
--------------------------------------------------------------------------------


Epoch 10/10: 100%|██████████| 4/4 [00:20<00:00,  5.15s/it, loss=6.0773, mark_loss=1.8790, time_loss=41.9826]



Epoch 10/10 - 20.59s
Train Loss: 6.2075 (Mark: 1.8824, Time: 43.2509)
--------------------------------------------------------------------------------
