In [1]:
import torch

from torch.distributions import (
    Bernoulli, 
    Categorical, 
    Independent, 
    Multinomial, 
    Normal,
)

from longcapital.rl.utils.distributions import MultivariateHypergeometric
from longcapital.utils.constant import NEG_INF

In [2]:
class MyMultinomial(Multinomial):
    """``Multinomial`` whose ``entropy()`` always returns zeros.

    torch's own ``Multinomial.entropy`` is replaced by a constant zero tensor
    of shape ``batch_shape``. NOTE(review): presumably this is meant to switch
    off the entropy term for the policy head that uses it — confirm the intent
    with the policy code.
    """

    def entropy(self):
        """Return a zero tensor of shape ``batch_shape``.

        The zeros are built with ``new_zeros`` on the distribution's own
        ``logits`` so they share its dtype and device; a bare ``torch.zeros``
        would always be float32 on the default device and could raise
        device-mismatch errors (or silently differ in dtype) when combined
        with ``log_prob`` of a CUDA / float64 distribution.
        """
        return self.logits.new_zeros(self.batch_shape)
    

def test_dist(dist, return_sample=False):
    sample = dist.sample()
    log_prob = dist.log_prob(sample)
    entropy = dist.entropy()
    print(f"sample ({sample.shape}): {sample}")
    print(f"log_prob ({log_prob.shape}): {log_prob}")
    print(f"entropy ({entropy.shape}): {entropy}")
    if return_sample:
        return sample

In [3]:
# Toy problem sizes shared by the demo cells.
batch_size = 4
stock_num = 5
topk = 2
seed = 0  # fixed RNG seed so sampled outputs are reproducible on Restart & Run All

# Sampling topk stocks WITHOUT replacement needs 1 <= topk <= stock_num.
assert 0 < topk <= stock_num, "topk must be in [1, stock_num]"
torch.manual_seed(seed)

In [4]:
# continuous.MetaPPO / MetaDDPG / MetaTD3 / MetaSAC
# Given a list of stocks, give each stock a real-valued score that is used either
# for ranking (TopkDropoutStrategy) or for weighting (WeightStrategy).

loc = torch.randn((batch_size, stock_num))
scale = torch.ones_like(loc)
# reinterpreted_batch_ndims=1 folds the stock axis into the event, so
# log_prob / entropy come out with one value per batch row.
dist = Independent(Normal(loc, scale), reinterpreted_batch_ndims=1)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[-7.5870e-01,  8.8834e-01,  4.3071e-01,  1.5514e+00,  1.6033e+00],
        [ 7.1720e-01, -2.0396e-01,  5.1662e-01, -3.3107e+00, -2.9578e-02],
        [ 8.1710e-02, -6.7831e-01,  5.1750e-01, -8.1944e-01, -1.0100e+00],
        [-2.2577e+00, -2.9829e-03, -2.3051e+00, -1.1247e+00,  8.3585e-01]])
log_prob (torch.Size([4])): tensor([-7.4591, -8.3586, -5.7100, -6.3486])
entropy (torch.Size([4])): tensor([7.0947, 7.0947, 7.0947, 7.0947])


In [5]:
# discrete.PPO
# Given the state and a set of candidate parameter values (n_drop), choose one of
# them for trading; the sampled value is an index into n_drop_list.

n_drop_list = [*range(topk)]
probs = torch.rand((batch_size, len(n_drop_list)))
dist = Categorical(probs=probs)  # Categorical normalises each row of probs itself
test_dist(dist)

sample (torch.Size([4])): tensor([1, 1, 0, 0])
log_prob (torch.Size([4])): tensor([-0.3730, -0.6471, -0.1849, -0.5225])
entropy (torch.Size([4])): tensor([0.6202, 0.6920, 0.4540, 0.6757])


In [6]:
# discrete.MultiBinaryMetaPPO
# Given a list of stocks, decide separately for each stock whether to buy it
# (one independent Bernoulli per stock).

probs = torch.rand((batch_size, stock_num))
# Treat the stock axis as the event so log_prob / entropy are summed per batch row.
dist = Independent(Bernoulli(probs=probs), reinterpreted_batch_ndims=1)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[0., 0., 0., 1., 1.],
        [0., 1., 0., 1., 0.],
        [1., 1., 0., 1., 0.],
        [1., 1., 1., 1., 1.]])
log_prob (torch.Size([4])): tensor([-2.9593, -2.5406, -2.2767, -2.0295])
entropy (torch.Size([4])): tensor([2.9821, 2.6975, 3.0601, 2.4700])


In [7]:
# discrete.StepByStepMetaPPO
# Given a list of stocks, pick one stock per step for topk steps, WITHOUT replacement.
# If the state does not change between steps, this is equivalent to TopkMetaPPO,
# which selects all topk stocks at once.

logits = torch.randn((batch_size, stock_num))
for i in range(topk):
    dist = Categorical(logits=logits)
    sample = test_dist(dist, return_sample=True)
    # Overwrite each row's just-picked logit with NEG_INF (from longcapital.utils.constant)
    # so that stock is masked out of later steps. Out-of-place scatter rebinds
    # `logits` instead of mutating the tensor.
    logits = logits.scatter(1, sample[:, None], NEG_INF)

sample (torch.Size([4])): tensor([1, 4, 4, 1])
log_prob (torch.Size([4])): tensor([-0.9068, -2.2671, -0.7272, -2.6500])
entropy (torch.Size([4])): tensor([1.4198, 1.3411, 1.2733, 1.3611])
sample (torch.Size([4])): tensor([2, 0, 3, 0])
log_prob (torch.Size([4])): tensor([-0.7142, -1.6009, -0.6231, -0.9216])
entropy (torch.Size([4])): tensor([1.2501, 1.1246, 1.1237, 1.1899])


In [8]:
# discrete.TopkMetaPPO
# Given a list of stocks, pick topk stocks to buy in a single draw (WITHOUT
# replacement); each sample row holds topk stock indices.

logits = torch.randn((batch_size, stock_num))
dist = MultivariateHypergeometric(topk=topk, logits=logits)
test_dist(dist=dist)

sample (torch.Size([4, 2])): tensor([[0, 4],
        [4, 3],
        [4, 2],
        [4, 0]])
log_prob (torch.Size([4])): tensor([-3.1807, -2.2466, -2.7735, -1.7747])
entropy (torch.Size([4])): tensor([1.4737, 1.3995, 1.5030, 1.2700])


In [9]:
# discrete.WeightMetaPPO
# Given a list of stocks, make topk picks WITH replacement; equivalently, given a
# budget (e.g., total_count=topk), choose one stock to buy each time. Each sample
# row holds per-stock purchase counts; MyMultinomial reports zero entropy.

logits = torch.randn((batch_size, stock_num))
dist = MyMultinomial(total_count=topk, logits=logits)
test_dist(dist=dist)

sample (torch.Size([4, 5])): tensor([[0., 1., 0., 0., 1.],
        [0., 2., 0., 0., 0.],
        [1., 0., 0., 1., 0.],
        [0., 1., 0., 0., 1.]])
log_prob (torch.Size([4])): tensor([-2.2070, -1.5353, -2.2308, -1.2772])
entropy (torch.Size([4])): tensor([0., 0., 0., 0.])
