In [1]:
import torch

from torch.distributions import Bernoulli, Categorical, Independent, Multinomial, Normal
from longcapital.rl.utils.distributions import MultivariateHypergeometric
from longcapital.utils.constant import NEG_INF

In [2]:
class MyMultinomial(Multinomial):
    def entropy(self):
        return torch.zeros(self.batch_shape)
    

def test_dist(dist):
    sample = dist.sample()
    log_prob = dist.log_prob(sample)
    entropy = dist.entropy()
    print(f"sample ({sample.shape}): {sample}")
    print(f"log_prob ({log_prob.shape}): {log_prob}")
    print(f"entropy ({entropy.shape}): {entropy}")

In [3]:
batch_size = 4
stock_num = 5
topk = 2

In [4]:
# continuous.MetaPPO/MetaDDPG/MetaTD3/MetaSAC
# given a list of stocks, assign each stock a value for ranking (TopkDropoutStrategy) or weighting (WeightStrategy)

loc = torch.randn(batch_size, stock_num)
scale = torch.ones_like(loc)
dist = Independent(Normal(loc=loc, scale=scale), 1)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[-1.1518,  1.8908,  2.3198, -0.6100, -0.4247],
        [-1.3236, -1.3905, -1.6903,  1.7480, -1.2054],
        [-0.7742,  0.8872,  2.0958,  0.5362,  0.1253],
        [ 0.2307,  1.7311, -0.2698, -1.6820, -0.1353]])
log_prob (torch.Size([4])): tensor([-6.5114, -8.0955, -6.7380, -7.8099])
entropy (torch.Size([4])): tensor([7.0947, 7.0947, 7.0947, 7.0947])


In [5]:
# discrete.PPO
# given the state and a set of params candidates (n_drop), choose one param for trading

n_drop_list = list(range(topk))
probs = torch.rand(batch_size, len(n_drop_list))
dist = Categorical(probs)
test_dist(dist)

sample (torch.Size([4])): tensor([0, 1, 0, 1])
log_prob (torch.Size([4])): tensor([-0.7213, -0.3690, -0.1278, -0.6006])
entropy (torch.Size([4])): tensor([0.6928, 0.6180, 0.3668, 0.6884])


In [6]:
# discrete.MultiBinaryMetaPPO
# given a list of stocks, for each stock decide whether buy or not (each stock is independent)

probs = torch.rand(batch_size, stock_num)
dist = Independent(Bernoulli(probs), 1)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[0., 1., 0., 1., 1.],
        [1., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 1., 0., 0.]])
log_prob (torch.Size([4])): tensor([-2.9925, -2.0541, -2.1748, -3.8611])
entropy (torch.Size([4])): tensor([2.4892, 2.0643, 2.7707, 2.0659])


In [7]:
# discrete.StepByStepMetaPPO
# given a list of stocks, repeat topk times to select stock one by one to buy WITHOUT replacement
# if the state is not changed, this is equivalent to TopkMetaPPO

logits = torch.randn(batch_size, stock_num)
selected = torch.zeros_like(logits)
for i in range(topk):
    logits = (1 - selected) * logits + selected * NEG_INF
    dist = Categorical(logits=logits)
    test_dist(dist)
    selected.scatter_(1, logits.argmax(-1).unsqueeze(1), 1)

sample (torch.Size([4])): tensor([1, 2, 3, 0])
log_prob (torch.Size([4])): tensor([-1.2132, -1.6683, -2.2203, -1.3199])
entropy (torch.Size([4])): tensor([1.4314, 1.4539, 1.3485, 1.1749])
sample (torch.Size([4])): tensor([4, 4, 3, 0])
log_prob (torch.Size([4])): tensor([-1.0265, -1.8034, -1.5004, -0.5820])
entropy (torch.Size([4])): tensor([1.1869, 1.2424, 1.3469, 1.0097])


In [8]:
# discrete.TopkMetaPPO
# given a list of stocks, repeat topk times to select which stock to buy WITHOUT replacement

logits = torch.randn(batch_size, stock_num)
dist = MultivariateHypergeometric(logits=logits, topk=topk)
test_dist(dist)

sample (torch.Size([4, 2])): tensor([[0, 2],
        [2, 1],
        [4, 2],
        [2, 0]])
log_prob (torch.Size([4])): tensor([-2.6579, -2.7395, -2.1988, -1.9295])
entropy (torch.Size([4])): tensor([1.5391, 1.4907, 1.4681, 1.3701])


In [9]:
# discrete.WeightMetaPPO
# given a list of stocks, repeat topk times to select which stock to buy WITH replacement
# or given a budget (e.g., total_count=topk), chose one stock to buy each time

logits = torch.randn(batch_size, stock_num)
dist = MyMultinomial(logits=logits, total_count=topk)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[0., 0., 0., 1., 1.],
        [0., 0., 0., 2., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 0., 2., 0.]])
log_prob (torch.Size([4])): tensor([-2.0916, -1.4665, -1.1537, -2.4307])
entropy (torch.Size([4])): tensor([0., 0., 0., 0.])
