In [1]:
import math
import os
import h5py
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from constants import KEY_LM_HIDDEN_STATES, KEY_LM_INPUT_IDS, KEY_LM_LABELS
from utils import load_config


class HDF5Dataset(Dataset):
    def __init__(self, h5_files_dir, split):
        """
        Args:
            h5_files_dir (str): 
            split (str): train/validation/test
        """
        h5_file_path = os.path.join(h5_files_dir, split+'.h5')
        self.h5_file = h5py.File(h5_file_path, 'r')
        
        self.hidden_states = self.h5_file[KEY_LM_HIDDEN_STATES]
        self.input_ids = self.h5_file[KEY_LM_INPUT_IDS]
        self.labels = self.h5_file[KEY_LM_LABELS]
        self.total_samples = int(self.h5_file.attrs['total_samples'])

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.input_ids[idx], dtype=torch.long)
        labels = torch.tensor(self.labels[idx], dtype=torch.long)
        hidden_states = torch.tensor(self.hidden_states[idx], dtype=torch.float)

        return {
            KEY_LM_INPUT_IDS: input_ids,
            KEY_LM_LABELS: labels,
            KEY_LM_HIDDEN_STATES: hidden_states
        }

    def close(self):
        self.h5_file.close()


class ChunkedHDF5Dataset(HDF5Dataset):
    def __init__(self, h5_files_dir, split, chunk_size:int):
        super().__init__(h5_files_dir, split)
        self.chunk_size = chunk_size
    
    def __getitem__(self, idx):
        item = super().__getitem__(idx)
        hidden_states = item['hidden_states']
        # last
        # item['hidden_states'] = hidden_states[self.chunk_size-1::self.chunk_size, :]

        # mean pooling
        assert hidden_states.shape[0] % self.chunk_size == 0, "行数不能被 chunk_size 整除"

        # 重塑张量为 (num_chunks, chunk_size, num_features)
        hidden_states_reshaped = hidden_states.view(-1, self.chunk_size, hidden_states.shape[1])

        # 对每个 chunk 计算均值，axis=1 表示按第二个维度（即每个块的行）计算均值
        pooled_hidden_states = hidden_states_reshaped.mean(dim=1)

        # 结果的形状是 (256, 768)
        item['hidden_states'] = pooled_hidden_states
        
        
        return item


def get_chunked_h5dataloader(config_path, split):
    config = load_config(config_path=config_path)
    num_workers = 1  # Set num workers to 0 to enable debugging
    shuffle = split == 'train'
    dataset = ChunkedHDF5Dataset(config['h5_file_path'], split, chunk_size=config['chunk_size'])
    # import pdb; pdb.set_trace()
    dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=shuffle, num_workers=num_workers)
    # import pdb; pdb.set_trace()
    return dataloader

dataloader = get_chunked_h5dataloader('conf/data/norm_layer10.yaml', 'test')

In [2]:
import torch

def calculate_dataset_stats(dataloader, ratio=0.7):
    total_loss = 0
    # 遍历整个数据集并收集所有的 hidden_states
    for i, batch in enumerate(dataloader):
        hidden_states = batch[KEY_LM_HIDDEN_STATES]
        mean = hidden_states.mean()
        var = hidden_states.var()
        # 创建一个形状为 (256,) 的掩码，标记出比例为 (1-ratio) 的部分
        mask = torch.rand(hidden_states.size(1)) < ratio  # mask的大小是 (256,)
        
        # 扩展掩码到 (128, 256, 768) 形状
        mask = mask.unsqueeze(0).unsqueeze(2).expand(hidden_states.size(0), hidden_states.size(1), hidden_states.size(2))
        
        # 生成一个形状与 hidden_states 相同的高斯分布张量
        gaussian_fill = mean + torch.randn_like(hidden_states)*var  # 生成均值为0，方差为1的高斯分布
        
        # 将被掩盖为零的部分用高斯分布填充
        masked_hidden_states = hidden_states * mask.float() + gaussian_fill * (1 - mask.float())
        
        # 计算重构损失
        rec_loss = (masked_hidden_states - hidden_states).abs().mean()
        total_loss += rec_loss.item()

    total_loss /= len(dataloader)
    return total_loss

In [3]:
ratios = np.arange(0.6, 1.02, 0.02)  # 1.02 是为了确保包含 1.0

for ratio in ratios:
    losses = []  # 用于存储每个 ratio 下的 5 次损失值
    
    # 运行 5 次计算损失
    for _ in range(5):
        loss = calculate_dataset_stats(dataloader, ratio)
        losses.append(loss)
    
    # 计算损失的均值
    avg_loss = np.mean(losses)
    
    # 输出均值损失
    print(f'Average Loss for ratio {ratio:.2f}: {avg_loss:.5f}')

Average Loss for ratio 0.60: 0.05291
Average Loss for ratio 0.62: 0.05001
Average Loss for ratio 0.64: 0.04793
Average Loss for ratio 0.66: 0.04343
Average Loss for ratio 0.68: 0.03823
Average Loss for ratio 0.70: 0.03833
Average Loss for ratio 0.72: 0.03672
Average Loss for ratio 0.74: 0.03375
Average Loss for ratio 0.76: 0.03272
Average Loss for ratio 0.78: 0.02834
Average Loss for ratio 0.80: 0.02459
Average Loss for ratio 0.82: 0.02221
Average Loss for ratio 0.84: 0.01989
Average Loss for ratio 0.86: 0.01952
Average Loss for ratio 0.88: 0.01600
Average Loss for ratio 0.90: 0.01327
Average Loss for ratio 0.92: 0.00984
Average Loss for ratio 0.94: 0.00751
Average Loss for ratio 0.96: 0.00468
Average Loss for ratio 0.98: 0.00284
Average Loss for ratio 1.00: 0.00000


In [20]:
loss[1].shape

torch.Size([18, 256, 768])