In [5]:
import torch

a = torch.randn(4, 5)
print(a)
b = torch.argmax(a, dim=1)

print(b)
print(b.dtype)


tensor([[-2.2573,  0.5230, -1.5216, -0.4470, -0.4594],
        [ 0.0504,  0.3612, -0.8903, -0.6420,  1.7113],
        [ 0.6373,  1.6904,  0.9238,  0.9839, -1.2449],
        [ 0.1692,  0.8605,  0.5169,  1.6001, -1.4548]])
tensor([1, 4, 1, 3])
torch.int64


In [6]:
import torch
from torch.distributions import Transform, Uniform
from torch.distributions.transforms import AffineTransform
from torch.distributions.utils import broadcast_all

class TruncatedLaplace(torch.distributions.Distribution):
    def __init__(self, loc, scale, low, high, validate_args=None):
        """
        截断拉普拉斯分布的初始化

        参数:
        loc : tensor
            分布的位置参数（均值）
        scale : tensor
            分布的尺度参数
        low : tensor
            截断的下界
        high : tensor
            截断的上界
        validate_args : bool
            是否验证输入
        """
        super(TruncatedLaplace, self).__init__(validate_args=validate_args)
        self.loc, self.scale = broadcast_all(loc, scale)
        self.low, self.high = broadcast_all(low, high)
        if validate_args:
            assert torch.all(self.scale > 0)
            assert torch.all(self.low < self.high)

        # 计算标准化常数，用于后续的概率密度函数和抽样
        self._standardized_constant = 2 * (torch.exp(-torch.abs((self.low - self.loc) / self.scale)) 
                                           - torch.exp(-torch.abs((self.high - self.loc) / self.scale)))

    def cdf(self, value):
        """
        计算累积分布函数（CDF）
        """
        # 由于截断，CDF 在 low 和 high 之间是线性的
        return torch.clamp((value - self.low) / (self.high - self.low), 0, 1)

    def icdf(self, value):
        """
        计算逆累积分布函数（CDF）
        """
        # 由于截断，逆 CDF 在 [0, 1] 范围内是线性的
        return value * (self.high - self.low) + self.low

    def log_prob(self, value):
        """
        计算对数概率密度函数（log PDF）
        """
        if self._validate_args:
            self._validate_sample(value)
        return -torch.abs(value - self.loc) / self.scale - torch.log(self._standardized_constant / (2 * self.scale))

    def sample(self, sample_shape=torch.Size()):
        """
        从截断拉普拉斯分布中抽取样本
        """
        # 使用逆变换抽样方法
        with torch.no_grad():
            uniforms = torch.distributions.Uniform(self.low, self.high).sample(sample_shape)
            return uniforms

# 示例：创建截断拉普拉斯分布实例
loc = torch.tensor(0.0)
scale = torch.tensor(1.0)
low = torch.tensor(-2.0)
high = torch.tensor(2.0)
truncated_laplace = TruncatedLaplace(loc, scale, low, high)

# 抽样
samples = truncated_laplace.sample((1000,))

# 计算概率密度
prob = truncated_laplace.log_prob(samples)




In [12]:
import numpy as np
from scipy.stats import rv_continuous

class truncated_laplace_gen(rv_continuous):
    def __init__(self, loc=0, scale=1, lower=-np.inf, upper=np.inf):
        super().__init__(a=lower, b=upper)
        # self.x = x
        # self.p = p
        self.loc = loc, 
        self.scale = scale
        self.lower = lower
        self.upper = upper

    def _pdf(self, x):

        # no probability for generating value outside the bound
        if x <= self.lower or x >= self.upper:
            return 0.0

        # Standard Laplace PDF
        laplace_pdf = 1/(2*self.scale) * np.exp(-np.abs(x - self.loc)/self.scale)
        
        # Calculate the normalization constant (CDF of the upper bound - CDF of the lower bound)
        normalization_constant = np.exp(-np.abs(self.lower - self.loc)/self.scale) - np.exp(-np.abs(self.upper - self.loc) / self.scale)
        
        return laplace_pdf / normalization_constant

# Define parameters
loc = 0       # Mean of the Laplace distribution
scale = 1     # Scale parameter (similar to standard deviation)
lower = -2    # Lower truncation bound
upper = 2     # Upper truncation bound

# Create an instance of the distribution without shape defaults
truncated_laplace = truncated_laplace_gen(loc=loc, scale=scale, lower=lower, upper=upper)

# Generate random samples
samples = truncated_laplace.rvs(size=1000)

# Print some samples
print(samples[:10])


TypeError: unsupported operand type(s) for -: 'float' and 'tuple'