In [5]:
import math
def generate_target_matrix(labels, batch_size, seq_len, num_classes, device):
    """Build a pairwise "shares an action class" target matrix over all frames.

    Every action in ``labels[b]`` is indexed as ``(start, end, class_id)``.
    ``start`` and ``end`` are multiplied by ``seq_len`` to obtain frame
    positions, expanded outward (floor for the start, ceil for the end),
    clamped to ``[0, seq_len]`` and marked as the half-open frame range
    ``[start, end)`` in a multi-hot tensor.

    Args:
        labels: per-sample iterables of actions; each action must support
            ``action[i].item()`` for ``i`` in 0..2.
        batch_size: number of samples read from ``labels``.
        seq_len: number of frames per sample.
        num_classes: size of the class axis of the multi-hot tensor.
        device: device on which the tensors are allocated.

    Returns:
        Float tensor of shape ``(batch_size * seq_len, batch_size * seq_len)``
        whose entry ``(i, j)`` is 1.0 when flattened frames ``i`` and ``j``
        have at least one class in common, else 0.0.
    """
    multi_hot = torch.zeros((batch_size, seq_len, num_classes), device=device)

    # Values stored as float32 carry rounding noise (e.g. 0.3 * 10 becomes
    # 3.0000001), which would make floor/ceil grow the interval by a whole
    # frame. Positions closer than this tolerance to an integer frame index
    # are snapped onto it first. The cap keeps the tolerance well below one
    # frame for very long sequences.
    snap_tol = min(1e-6 * seq_len, 1e-2)

    for batch_idx in range(batch_size):
        for action in labels[batch_idx]:
            # Frame positions as floats
            start_f = action[0].item() * seq_len
            end_f = action[1].item() * seq_len

            # Remove floating-point noise around exact frame boundaries
            if abs(start_f - round(start_f)) <= snap_tol:
                start_f = round(start_f)
            if abs(end_f - round(end_f)) <= snap_tol:
                end_f = round(end_f)

            # Expand outward so partially covered frames are included
            start = int(math.floor(start_f))
            end = int(math.ceil(end_f))

            # Clamp to the valid frame range and skip empty intervals
            start = max(0, min(start, seq_len))
            end = max(0, min(end, seq_len))
            if start >= end:
                continue

            # Mark the half-open interval [start, end) for this class
            multi_hot[batch_idx, start:end, int(action[2].item())] = 1

    # Two frames are positives for each other when their class sets overlap
    flattened = multi_hot.view(-1, num_classes)
    return (flattened @ flattened.T > 0).float()
    

In [3]:
def generate_enhanced_target_matrix(labels, batch_size, seq_len, num_classes, device):
    """Build a pairwise target matrix that distinguishes action phases.

    Each class owns three consecutive channels: ``class * 3 + 0`` for
    "in progress", ``+ 1`` for "start" and ``+ 2`` for "end". For every
    action ``(start, end, class_id)`` in ``labels[b]`` the first covered
    frame is marked "start"; when the action spans more than one frame the
    last covered frame is marked "end" and the frames in between are marked
    "in progress".

    Args:
        labels: per-sample iterables of actions; each action must support
            ``action[i].item()`` for ``i`` in 0..2.
        batch_size: number of samples read from ``labels``.
        seq_len: number of frames per sample.
        num_classes: number of action classes.
        device: device on which the tensors are allocated.

    Returns:
        Float tensor of shape ``(batch_size * seq_len, batch_size * seq_len)``
        holding 1.0 where two flattened frames share an active channel and
        also pass the class/state mask computed below, else 0.0.
    """
    state_types = 3
    multi_hot = torch.zeros((batch_size, seq_len, num_classes * state_types), device=device)

    # Values stored as float32 carry rounding noise (e.g. 0.3 * 10 becomes
    # 3.0000001), which would shift the start/end markers by a whole frame.
    # Positions closer than this tolerance to an integer frame index are
    # snapped onto it before floor/ceil. The cap keeps the tolerance well
    # below one frame for very long sequences.
    snap_tol = min(1e-6 * seq_len, 1e-2)

    # Fill the multi-hot phase tensor
    for batch_idx in range(batch_size):
        for action in labels[batch_idx]:
            start_f = action[0].item() * seq_len
            end_f = action[1].item() * seq_len
            label = int(action[2].item())

            # Remove floating-point noise around exact frame boundaries
            if abs(start_f - round(start_f)) <= snap_tol:
                start_f = round(start_f)
            if abs(end_f - round(end_f)) <= snap_tol:
                end_f = round(end_f)

            # Expand outward, clamp to the valid range, skip empty intervals
            start = max(0, min(math.floor(start_f), seq_len))
            end = max(0, min(math.ceil(end_f), seq_len))
            if start >= end:
                continue

            base = label * state_types
            multi_hot[batch_idx, start, base + 1] = 1  # start state
            if end - start > 1:
                multi_hot[batch_idx, end - 1, base + 2] = 1  # end state
                multi_hot[batch_idx, start + 1:end - 1, base] = 1  # in progress

    # Flatten and compute which frame pairs share at least one channel
    flattened = multi_hot.view(-1, num_classes * state_types)
    similarity = (flattened @ flattened.T) > 0  # [B*seq_len, B*seq_len]

    # Build a state-compatibility mask.
    # NOTE(review): `similarity` is only true for frames sharing an identical
    # (class, state) channel, so the off-diagonal entries of `compatibility`
    # can never switch a pair on, and for frames with several active channels
    # argmax keeps just one of them. Confirm this is the intended semantics.

    # Step 1: derive one class id and one state type per frame
    active = (flattened.sum(dim=1) > 0)  # frames with any active channel
    channel_idx = torch.argmax(flattened, dim=1)  # a maximally activated channel per frame

    # class = channel // 3, state = channel % 3; inactive frames are set to 0
    class_ids = (channel_idx // state_types) * active.long()
    state_idx = (channel_idx % state_types) * active.long()

    # Step 2: compatibility rules, rows/columns ordered in-progress/start/end
    compatibility = torch.tensor(
        [[1, 1, 0], [1, 1, 1], [0, 1, 1]],
        dtype=torch.bool,
        device=device,
    )

    # Step 3: pairwise state compatibility, [B*seq_len, B*seq_len]
    state_mask = compatibility[state_idx][:, state_idx]

    # Step 4: pairwise class equality, [B*seq_len, B*seq_len]
    class_match = (class_ids.unsqueeze(1) == class_ids.unsqueeze(0))

    # Step 5: combine; both frames of a pair must be active
    final_mask = (state_mask & class_match & active.unsqueeze(1) & active.unsqueeze(0))

    # Apply the mask to the channel-overlap matrix
    return (similarity.float() * final_mask.float())

In [9]:
def normalize_target(target_matrix, eps=1e-8):
    """Return a row-normalised copy of ``target_matrix``.

    ``eps`` is added to every entry so that all-zero rows do not divide by
    zero, then each row is divided by its sum along ``dim=1``. The input
    tensor is left untouched, so calling this repeatedly on the same tensor
    always gives the same result.

    Args:
        target_matrix: tensor with at least two dimensions.
        eps: small constant added to every entry before normalising.

    Returns:
        A new tensor of the same shape whose rows (along ``dim=1``) sum to 1.
    """
    # Out-of-place add: an in-place `+=` would silently modify the caller's tensor
    smoothed = target_matrix + eps
    # Row normalisation
    return smoothed / smoothed.sum(dim=1, keepdim=True)



In [10]:
import torch
# Demo parameters
batch_size = 2
seq_len = 3
num_classes = 5  # must be chosen according to the actual range of label ids
# Prefer the GPU, but fall back to the CPU so the cell runs on any machine
device = "cuda" if torch.cuda.is_available() else "cpu"

# Toy input: two samples with two actions each, given as [start, end, class_id];
# start/end are scaled by seq_len inside generate_target_matrix
labels = [
    [torch.tensor([0.0, 0.3, 2]), torch.tensor([0.5, 0.7, 3])],  # sample 1
    [torch.tensor([0.2, 0.4, 1]), torch.tensor([0.6, 0.9, 3])]   # sample 2
]

# Build the pairwise target matrix
targets = generate_target_matrix(
    labels = labels,
    batch_size = batch_size,
    seq_len = seq_len,
    num_classes = num_classes,
    device = device
)

print(targets)  # (batch_size * seq_len) x (batch_size * seq_len) matrix, i.e. 6 x 6 here

tensor([[1., 0., 0., 0., 0., 0.],
        [0., 1., 1., 0., 1., 1.],
        [0., 1., 1., 0., 1., 1.],
        [0., 0., 0., 1., 1., 0.],
        [0., 1., 1., 1., 1., 1.],
        [0., 1., 1., 0., 1., 1.]], device='cuda:0')


In [11]:
# Row-normalise the pairwise target matrix (normalize_target adds a small
# epsilon and divides each row by its sum), then print the result.
targets_nor = normalize_target(targets)
print(targets_nor)

tensor([[1.0000e+00, 1.0000e-08, 1.0000e-08, 1.0000e-08, 1.0000e-08, 1.0000e-08],
        [2.5000e-09, 2.5000e-01, 2.5000e-01, 2.5000e-09, 2.5000e-01, 2.5000e-01],
        [2.5000e-09, 2.5000e-01, 2.5000e-01, 2.5000e-09, 2.5000e-01, 2.5000e-01],
        [5.0000e-09, 5.0000e-09, 5.0000e-09, 5.0000e-01, 5.0000e-01, 5.0000e-09],
        [2.0000e-09, 2.0000e-01, 2.0000e-01, 2.0000e-01, 2.0000e-01, 2.0000e-01],
        [2.5000e-09, 2.5000e-01, 2.5000e-01, 2.5000e-09, 2.5000e-01, 2.5000e-01]],
       device='cuda:0')
