In [1]:
import torch

class TemperatureSampler:
    """Draw a random index from logits after temperature scaling.

    The logits are divided by ``temperature``, turned into a probability
    distribution with softmax over the last dimension, and one index is
    drawn from that distribution with ``torch.multinomial``.

    Args:
        temperature: Positive divisor applied to the logits. Values below
            1.0 sharpen the distribution, values above 1.0 flatten it, and
            1.0 leaves the logits unchanged.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature: float = 1.0):
        # A zero temperature would divide by zero and yield NaN/inf
        # probabilities; a negative one would invert the ranking.
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        """Sample one index per distribution in ``logits``.

        Args:
            logits: Unnormalized scores; the last dimension indexes the
                candidates. The tensor is not modified.

        Returns:
            The sampled index tensor with size-1 dimensions squeezed out
            (a 0-dim tensor for 1-D input).
        """
        # Out-of-place division on purpose: `logits /= t` would rescale the
        # caller's tensor, so repeated calls would compound the scaling.
        scaled_logits = logits / self.temperature

        # Convert the scaled logits into a probability distribution.
        probs = torch.softmax(scaled_logits, dim=-1)

        # Draw a single index according to those probabilities.
        sample = torch.multinomial(probs, 1)

        return sample.squeeze()  # drop the trailing size-1 sample dimension
    

In [7]:
# Demo: ten unnormalized scores, one per candidate index.
logits = torch.tensor([1.0, 2.0, 1.5, 0.5, 0.8, 2.5, 1.2, 1.3, 0.9, 1.8])

sampler = TemperatureSampler(temperature=1.0)  # temperature 1.0: logits are divided by 1, i.e. unscaled
sampler_idx = sampler(logits)
# The draw is random and no seed is set here, so the printed index can differ between runs.
print(f"sampled index at {sampler_idx.item()}")

sampled index at 5
