In [1]:
import torch

from torch.distributions import (
    Bernoulli, 
    Categorical, 
    Independent, 
    Multinomial, 
    Normal,
)

from longcapital.rl.utils.distributions import MultivariateHypergeometric
from longcapital.utils.constant import NEG_INF

In [2]:
class MyMultinomial(Multinomial):
    def entropy(self):
        return torch.zeros(self.batch_shape)
    

def test_dist(dist):
    sample = dist.sample()
    log_prob = dist.log_prob(sample)
    entropy = dist.entropy()
    print(f"sample ({sample.shape}): {sample}")
    print(f"log_prob ({log_prob.shape}): {log_prob}")
    print(f"entropy ({entropy.shape}): {entropy}")

In [3]:
batch_size = 4
stock_num = 5
topk = 2

In [4]:
# continuous.MetaPPO/MetaDDPG/MetaTD3/MetaSAC
# given a list of stocks, assign each stock a value for ranking (TopkDropoutStrategy) or weighting (WeightStrategy)

loc = torch.randn(batch_size, stock_num)
scale = torch.ones_like(loc)
dist = Independent(Normal(loc=loc, scale=scale), 1)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[ 0.7195, -2.0932,  0.3654,  0.5980, -0.6570],
        [-0.9277,  2.1504, -1.9409, -4.8942, -0.4706],
        [ 1.8688,  2.6778, -0.0615, -1.4287, -0.8648],
        [ 2.1370, -0.8389,  1.8719,  0.7652, -2.9274]])
log_prob (torch.Size([4])): tensor([-6.2783, -9.4490, -7.3731, -7.8984])
entropy (torch.Size([4])): tensor([7.0947, 7.0947, 7.0947, 7.0947])


In [5]:
# discrete.PPO
# given the state and a set of params candidates (n_drop), choose one param for trading

n_drop_list = list(range(topk))
probs = torch.rand(batch_size, len(n_drop_list))
dist = Categorical(probs)
test_dist(dist)

sample (torch.Size([4])): tensor([1, 1, 1, 0])
log_prob (torch.Size([4])): tensor([-0.4659, -0.6305, -0.3924, -0.1561])
entropy (torch.Size([4])): tensor([0.6602, 0.6911, 0.6303, 0.4130])


In [6]:
# discrete.MultiBinaryMetaPPO
# given a list of stocks, for each stock decide whether buy or not (each stock is independent)

probs = torch.rand(batch_size, stock_num)
dist = Independent(Bernoulli(probs), 1)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[1., 1., 0., 1., 0.],
        [0., 1., 1., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 1., 0.]])
log_prob (torch.Size([4])): tensor([-1.3788, -1.7273, -3.1913, -2.8856])
entropy (torch.Size([4])): tensor([2.4236, 2.3921, 2.7697, 2.4109])


In [7]:
# discrete.StepByStepMetaPPO
# given a list of stocks, repeat topk times to select stock one by one to buy WITHOUT replacement
# if the state is not changed, this is equivalent to TopkMetaPPO

logits = torch.randn(batch_size, stock_num)
selected = torch.zeros_like(logits)
for i in range(topk):
    logits = (1 - selected) * logits + selected * NEG_INF
    dist = Categorical(logits=logits)
    test_dist(dist)
    selected.scatter_(1, logits.argmax(-1).unsqueeze(1), 1)

sample (torch.Size([4])): tensor([4, 4, 4, 3])
log_prob (torch.Size([4])): tensor([-1.0513, -0.9424, -2.2990, -0.4879])
entropy (torch.Size([4])): tensor([1.4960, 1.3942, 0.9058, 1.0789])
sample (torch.Size([4])): tensor([3, 3, 4, 2])
log_prob (torch.Size([4])): tensor([-1.2981, -0.9102, -0.9513, -0.5726])
entropy (torch.Size([4])): tensor([1.3049, 1.1890, 1.2812, 1.0670])


In [8]:
# discrete.TopkMetaPPO
# given a list of stocks, repeat topk times to select which stock to buy WITHOUT replacement

logits = torch.randn(batch_size, stock_num)
dist = MultivariateHypergeometric(logits=logits, topk=topk)
test_dist(dist)

sample (torch.Size([4, 2])): tensor([[3, 1],
        [2, 1],
        [0, 4],
        [0, 2]])
log_prob (torch.Size([4])): tensor([-3.9768, -2.6084, -2.0306, -1.5059])
entropy (torch.Size([4])): tensor([1.3623, 1.2967, 1.1532, 1.2774])


In [9]:
# discrete.WeightMetaPPO
# given a list of stocks, repeat topk times to select which stock to buy WITH replacement
# or given a budget (e.g., total_count=topk), chose one stock to buy each time

logits = torch.randn(batch_size, stock_num)
dist = MyMultinomial(logits=logits, total_count=topk)
test_dist(dist)

sample (torch.Size([4, 5])): tensor([[0., 1., 0., 0., 1.],
        [0., 0., 2., 0., 0.],
        [1., 0., 1., 0., 0.],
        [0., 0., 0., 1., 1.]])
log_prob (torch.Size([4])): tensor([-1.8317, -1.0850, -1.5075, -3.2836])
entropy (torch.Size([4])): tensor([0., 0., 0., 0.])
