# In [12]:
import numpy as np
from numpy.linalg import matrix_power
from scipy.stats import rv_discrete
from scipy.special import factorial
from dataclasses import dataclass
from functools import cache

# Stats for each leaf tier, listed from rarity 21 ('plasma') down to 0 ('basic').
# Every entry maps to a dict with the keys 'rarity', 'craft_level',
# 'craft_bonus' and 'prop_count', in that order.
leaf_data = {
    name: dict(zip(('rarity', 'craft_level', 'craft_bonus', 'prop_count'), stats))
    for name, *stats in (
        # name          rarity  craft_level  craft_bonus  prop_count
        ('plasma',      21,     13,          896,         7),
        ('hematite',    20,     13,          768,         7),
        ('malachite',   19,     12,          640,         7),
        ('biotite',     18,     11,          512,         7),
        ('sacred',      17,     10,          384,         7),
        ('ancient',     16,      9,          256,         7),
        ('sand',        15,      9,          164,         6),
        ('moonstone',   14,      8,          132,         6),
        ('benitoite',   13,      7,          112,         6),
        ('silicon',     12,      6,           96,         6),
        ('obsidian',    11,      5,           80,         6),
        ('ice',         10,      5,           64,         5),
        ('lava',         9,      4,           48,         5),
        ('mythical',     8,      4,           32,         5),
        ('celestial',    7,      3,           20,         5),
        ('exotic',       6,      3,           16,         4),
        ('void',         5,      3,           12,         4),
        ('cosmic',       4,      2,            8,         3),
        ('bismuth',      3,      1,            6,         3),
        ('platinum',     2,      1,            4,         2),
        ('gold',         1,      1,            2,         2),
        ('basic',        0,      0,            1,         1),
    )
}

# One entry per property; presumably the minimum craft level at which that
# property can appear (inferred from the name — confirm).
min_craft_levels = [0,8,10,0,3,1,8,1,0,9,9,0,7,12,5,10,10,11,11,11,11,5,7,0,2,10,0,4,10,12,3,8,9,1,2,9,9]
# Craft levels 0 through 13 inclusive.
craft_levels = [*range(14)]
# For each craft level, how many min_craft_levels entries are <= that level.
possible_properties_per_craft_level = {
    level: sum(threshold <= level for threshold in min_craft_levels)
    for level in craft_levels
}

def probability_of_transformation_success(leaf_type, number_wanted_properties):
    """Return the per-attempt success probability used by transition_matrix().

    The value is ``number_wanted_properties`` divided by
    ``possible_properties_per_craft_level[craft_level] - prop_count + 1``,
    where ``craft_level`` and ``prop_count`` come from ``leaf_data[leaf_type]``.
    No clamping is applied: the result exceeds 1 if more properties are wanted
    than that denominator, and a zero denominator raises ZeroDivisionError.
    """
    leaf = leaf_data[leaf_type]
    pool_size = possible_properties_per_craft_level[leaf['craft_level']]
    candidates = pool_size - leaf['prop_count'] + 1
    return number_wanted_properties / candidates

def transition_matrix(transformations):
    """Build the absorbing Markov chain for a series of transformations.

    ``transformations`` is an iterable of ``(count, leaf, num_total_props)``
    triples. Each triple contributes ``count`` repetitions of the stages
    wanting ``num_total_props``, ``num_total_props - 1``, ..., 1 properties,
    in that order. From stage i the chain stays put with probability 1 - p
    and advances to stage i + 1 with probability p, where
    ``p = probability_of_transformation_success(leaf, wanted)``. The final
    state (index ``len(stages)``) is absorbing.

    Returns an ``Args`` whose ``initial_state`` puts all mass on state 0 and
    whose ``transition_matrix`` is square with side ``len(stages) + 1``.
    """
    stages = []
    for count, leaf, num_total_props in transformations:
        for _ in range(count):
            stages.extend((leaf, wanted) for wanted in range(num_total_props, 0, -1))

    size = len(stages) + 1
    matrix = np.zeros((size, size), dtype=float)
    start = np.zeros(size, dtype=float)

    start[0] = 1
    matrix[-1, -1] = 1  # absorbing state

    for row, (leaf, wanted) in enumerate(stages):
        success = probability_of_transformation_success(leaf, wanted)
        matrix[row, row] = 1 - success
        matrix[row, row + 1] = success

    return Args(initial_state=start, transition_matrix=matrix)

@dataclass
class Args:
    """Markov-chain parameters handed to DiscretePhaseTypeDistribution.

    As produced by transition_matrix(), the last state is absorbing and
    ``initial_state`` puts all of its mass on state 0.
    """

    initial_state: np.ndarray  # start-state probability vector
    transition_matrix: np.ndarray  # square matrix of transition probabilities

    def item(self):
        # Returns the instance itself. NOTE(review): presumably lets code that
        # calls `.item()` to unwrap a NumPy 0-d object array also accept a bare
        # Args; that caller is not visible in this file — confirm before removing.
        return self

@cache
def stirling2(n, k):
    """Stirling number of the second kind S(n, k), memoised with functools.cache.

    Uses the recurrence S(n, k) = k * S(n - 1, k) + S(n - 1, k - 1) with
    S(0, 0) = 1 and S(n, 0) = S(0, k) = 0 otherwise. Recursion depth grows
    linearly with n.
    """
    if n == 0 or k == 0:
        # Only S(0, 0) is non-zero among the boundary cases.
        return int(n == k)
    return stirling2(n - 1, k - 1) + k * stirling2(n - 1, k)
    
class DiscretePhaseTypeDistribution(rv_discrete):
    """Number of transformations distribution.

    Discrete phase-type distribution of the number of steps until absorption
    in the Markov chain described by an ``args`` object exposing
    ``initial_state`` and ``transition_matrix``, where the last state is the
    absorbing one. With alpha the transient part of ``initial_state``, T the
    transient-to-transient block and t the transient-to-absorbing column:

        pmf(n) = alpha T^(n-1) t
        cdf(n) = 1 - sum(alpha T^n)
    """

    def _argcheck(self, args):
        # Every args object is accepted; no shape-parameter validation is done.
        return True

    def _pmf(self, n, args):
        """P(N = n) = alpha T^(n-1) t, evaluated element-wise over ``n``."""
        def impl(n, args):
            transient = args.transition_matrix[:-1, :-1]
            # 1-D exit vector (not a (K, 1) column) so impl returns a scalar
            # rather than a length-1 array that np.vectorize must coerce.
            exit_vector = args.transition_matrix[:-1, -1]
            return args.initial_state[:-1] @ matrix_power(transient, int(n) - 1) @ exit_vector

        return np.vectorize(impl)(n, args)

    def _cdf(self, n, args):
        """P(N <= n) = 1 - sum(alpha T^n), evaluated element-wise over ``n``."""
        def impl(n, args):
            transient = args.transition_matrix[:-1, :-1]
            survival = np.sum(args.initial_state[:-1] @ matrix_power(transient, int(n)))
            return 1 - survival

        return np.vectorize(impl)(n, args)

    def _get_support(self, args):
        # Lower bound: number of transient states (each needs at least one step).
        return np.vectorize(lambda args: (len(args.initial_state) - 1, np.inf))(args)

    def factorial_moment(self, n, args):
        """Return E[N (N-1) ... (N-n+1)] = n! alpha (I - T)^(-n) T^(n-1) 1.

        For ``n == 0`` this is the total initial mass on the transient states,
        computed directly: the general formula would need T^(-1), which is
        wrong for the zeroth moment and fails when T is singular.
        """
        alpha = args.initial_state[:-1]
        transient = args.transition_matrix[:-1, :-1]
        if n == 0:
            return np.sum(alpha)
        fundamental_power = matrix_power(np.eye(transient.shape[0]) - transient, -n)
        return np.sum(factorial(n) * alpha @ fundamental_power @ matrix_power(transient, n - 1))

    def _munp(self, n, args):
        """Raw moment E[N^n] = sum over k of S(n, k) times the k-th factorial moment."""
        def impl(n, args):
            total = 0
            for i in range(0, n + 1):
                total += stirling2(n, i) * self.factorial_moment(i, args)
            return total
        return np.vectorize(impl)(n, args)
    
# Shared distribution instance; the print() calls below pass an Args object as
# its single positional shape argument, e.g. dist.ppf(q, args).
dist = DiscretePhaseTypeDistribution()


def transform_shard_cost(leaf_type):
    """Return the per-transformation shard cost for ``leaf_type``.

    Computed as 2 * (rarity - 15) + 1 from ``leaf_data``: rarity 15 costs 1 and
    each further rarity step adds 2. NOTE(review): the formula yields a
    negative value for any rarity below 15; confirm such leaves never reach here.
    """
    rarity = leaf_data[leaf_type]['rarity']
    return 2 * (rarity - 15) + 1

# Shard totals for transition_matrix([[16, 'hematite', 2]]), printed as:
# 70th percentile, mean, median, 99th percentile, 1st percentile.
hematite_chain = transition_matrix([[16, 'hematite', 2]])
hematite_shards = transform_shard_cost('hematite')
print(dist.ppf(0.7, hematite_chain) * hematite_shards)
print(dist.mean(hematite_chain) * hematite_shards)
for quantile in (0.5, 0.99, 0.01):
    print(dist.ppf(quantile, hematite_chain) * hematite_shards)

# Output (as recorded in the notebook) of the five print() calls above:
# 8888.0
# 8184.00000000012
# 8085.0
# 12100.0
# 5148.0
