In [1]:
import torch

from torch.distributions import (
    Bernoulli, 
    Categorical, 
    Independent, 
    Multinomial, 
    Normal,
)

from longcapital.rl.utils.distributions import MultivariateHypergeometric
from longcapital.utils.constant import NEG_INF

In [2]:
class MyMultinomial(Multinomial):
    """``Multinomial`` whose ``entropy()`` is stubbed out to return zeros.

    Sampling and ``log_prob`` are inherited unchanged from
    ``torch.distributions.Multinomial``; only ``entropy`` is overridden.

    NOTE(review): presumably the exact multinomial entropy is skipped because it is
    costly to compute and not needed by the caller (e.g. a PPO entropy bonus) —
    confirm before relying on ``entropy()`` for anything meaningful.
    """

    def entropy(self):
        """Return a zero "entropy" for every batch element.

        Returns:
            Tensor of shape ``self.batch_shape`` filled with zeros. It is allocated
            with the same dtype and device as ``self.logits``, so it can be combined
            with other per-sample tensors (e.g. ``log_prob``) without a CPU/GPU or
            float32/float64 mismatch.
        """
        # new_zeros inherits dtype/device from the distribution's parameters;
        # a bare torch.zeros(...) would always land on CPU with the default dtype.
        return self.logits.new_zeros(self.batch_shape)
    

def test_dist(dist, return_sample=False):
    sample = dist.sample()
    log_prob = dist.log_prob(sample)
    entropy = dist.entropy()
    print(f"sample ({sample.shape}): {sample}")
    print(f"log_prob ({log_prob.shape}): {log_prob}")
    print(f"entropy ({entropy.shape}): {entropy}")
    if return_sample:
        return sample

In [3]:
# Seed torch's global RNG so that "Restart Kernel -> Run All" reproduces the
# sampled outputs of every cell below (they all draw from this generator).
SEED = 0
torch.manual_seed(SEED)

batch_size = 4  # rows per demo batch
stock_num = 5   # candidate stocks per row
topk = 2        # stocks to pick (or budget units to spend) per row

In [4]:
# continuous.MetaPPO / MetaDDPG / MetaTD3 / MetaSAC
# Given a list of stocks, assign each stock a real value, used either for ranking
# (TopkDropoutStrategy) or for weighting (WeightStrategy).
# Every stock gets its own unit-variance Gaussian; wrapping in Independent(..., 1)
# moves the stock axis into the event shape, so log_prob/entropy come out per batch row.

loc = torch.randn(batch_size, stock_num)
scale = torch.full_like(loc, 1.0)
per_stock = Normal(loc=loc, scale=scale)
dist = Independent(per_stock, reinterpreted_batch_ndims=1)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[-0.6264,  1.0204, -0.3040,  1.1052, -1.6539],
        [-1.1005,  1.3430,  0.2871, -0.8482, -1.9938],
        [-0.5263, -1.9366, -2.1826,  1.7727,  1.9515],
        [-0.4482, -1.4156, -0.3358, -2.0101, -0.0604]])
log_prob (torch.Size([4])): tensor([-7.4220, -6.4291, -6.7273, -8.9786])
entropy (torch.Size([4])): tensor([7.0947, 7.0947, 7.0947, 7.0947])


In [5]:
# discrete.PPO
# Given the state and a set of parameter candidates (n_drop values), choose one
# parameter for trading: a single Categorical draw over the candidate indices.

n_drop_list = [*range(topk)]
probs = torch.rand(batch_size, len(n_drop_list))
# Unnormalised weights are fine: Categorical normalises them along the last dim.
dist = Categorical(probs=probs)
test_dist(dist)

sample (torch.Size([4])): tensor([0, 0, 0, 0])
log_prob (torch.Size([4])): tensor([-0.2423, -0.8840, -0.5941, -0.4371])
entropy (torch.Size([4])): tensor([0.5207, 0.6780, 0.6877, 0.6500])


In [6]:
# discrete.MultiBinaryMetaPPO
# Given a list of stocks, decide independently for each stock whether to buy it:
# one Bernoulli per stock, with the stock axis folded into the event shape.

probs = torch.rand(batch_size, stock_num)
per_stock = Bernoulli(probs=probs)
dist = Independent(per_stock, reinterpreted_batch_ndims=1)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[1., 1., 1., 0., 1.],
        [1., 0., 0., 0., 0.],
        [0., 1., 1., 0., 0.],
        [0., 1., 1., 1., 1.]])
log_prob (torch.Size([4])): tensor([-3.0781, -3.9245, -2.2733, -3.1525])
entropy (torch.Size([4])): tensor([2.7842, 2.1866, 2.9176, 2.8146])


In [7]:
# discrete.StepByStepMetaPPO
# Given a list of stocks, pick one stock per step, `topk` times, WITHOUT replacement.
# If the state does not change between steps, this is equivalent to TopkMetaPPO.

logits = torch.randn(batch_size, stock_num)
for step in range(topk):
    dist = Categorical(logits=logits)
    sample = test_dist(dist, return_sample=True)
    # Push the just-picked stock's logit to NEG_INF so later steps cannot choose it
    # again; done out-of-place, rebinding `logits` to the masked copy.
    logits = logits.scatter(1, sample.unsqueeze(1), NEG_INF)

sample (torch.Size([4])): tensor([3, 2, 2, 1])
log_prob (torch.Size([4])): tensor([-0.2373, -0.3380, -0.1331, -1.1716])
entropy (torch.Size([4])): tensor([0.6888, 0.9824, 0.5290, 1.5170])
sample (torch.Size([4])): tensor([1, 0, 0, 4])
log_prob (torch.Size([4])): tensor([-2.5913, -1.4515, -1.3618, -0.8278])
entropy (torch.Size([4])): tensor([0.8199, 1.3361, 1.2264, 1.3012])


In [8]:
# discrete.TopkMetaPPO
# Given a list of stocks, select `topk` of them to buy WITHOUT replacement in a single
# draw from the project's MultivariateHypergeometric (one row of stock indices per batch row).

logit_shape = (batch_size, stock_num)
logits = torch.randn(logit_shape)
dist = MultivariateHypergeometric(topk=topk, logits=logits)
test_dist(dist)

sample (torch.Size([4, 2])): tensor([[4, 1],
        [3, 4],
        [4, 3],
        [4, 3]])
log_prob (torch.Size([4])): tensor([-1.1390, -4.6722, -3.4064, -1.6144])
entropy (torch.Size([4])): tensor([0.6774, 1.1958, 1.5731, 1.3330])


In [9]:
# discrete.WeightMetaPPO
# Given a list of stocks, pick a stock `topk` times WITH replacement; equivalently,
# given a budget (total_count=topk), choose one stock to buy per budget unit.
# The sample is a per-stock count vector whose rows sum to topk.

logit_shape = (batch_size, stock_num)
logits = torch.randn(logit_shape)
dist = MyMultinomial(total_count=topk, logits=logits)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[0., 1., 0., 0., 1.],
        [0., 0., 0., 2., 0.],
        [0., 0., 1., 0., 1.],
        [0., 0., 1., 1., 0.]])
log_prob (torch.Size([4])): tensor([-1.1677, -2.3930, -2.0860, -1.9851])
entropy (torch.Size([4])): tensor([0., 0., 0., 0.])
