In [None]:
import dataclasses
import random
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    TypedDict,
    TypeVar,
    Union,
)

import numpy as np
from hanabi_learning_environment.rl_env import Agent
from numpy.random import default_rng

# Module-level NumPy random generator.
rng: np.random.Generator = default_rng()

# The five suit letters used as card colours.
Color = Literal["B", "G", "R", "W", "Y"]
# A colour as it may appear on a card in a hand; None means "not known".
CardColor = Literal[None, Color]
# Zero-based card rank: 0 is the lowest card of a suit and 4 the highest.
Rank = Literal[0, 1, 2, 3, 4]
# A rank as it may appear on a card in a hand; -1 means "not known".
CardRank = Literal[-1, Rank]
# ``action_type`` strings of the play/discard actions.
ActionPD = Literal["PLAY", "DISCARD"]
# ``action_type`` string of a colour hint.
ActionColor = Literal["REVEAL_COLOR"]
# ``action_type`` string of a rank hint.
ActionRank = Literal["REVEAL_RANK"]

class BaseActionDict(TypedDict):
    """Common base of every action dict; declares no keys of its own."""
    pass

class ActionPDDict(BaseActionDict):
    """Action dict for playing or discarding the card at ``card_index``."""
    action_type: ActionPD
    card_index: int

class BaseActionRevealDict(BaseActionDict):
    """Shared base of the hint (reveal) actions, which name a target player by offset."""
    target_offset: int

class ActionColorDict(BaseActionRevealDict):
    """Action dict for hinting the colour ``color`` to the player at ``target_offset``."""
    action_type: ActionColor
    color: Color

class ActionRankDict(BaseActionRevealDict):
    """Action dict for hinting the rank ``rank`` to the player at ``target_offset``."""
    action_type: ActionRank
    rank: Rank
    
def actplay(ind: int) -> ActionPDDict:
    """
    Build the action dict that plays one card.
    :param ind: position (index) of the card to play
    :return: an ActionPDDict whose action_type is "PLAY" for that card
    """
    play_action: ActionPDDict = dict(action_type="PLAY", card_index=ind)
    return play_action

def actdiscard(ind: int) -> ActionPDDict:
    """
    Build the action dict that discards one card.
    :param ind: position (index) of the card to discard
    :return: an ActionPDDict whose action_type is "DISCARD" for that card
    """
    discard_action: ActionPDDict = dict(action_type="DISCARD", card_index=ind)
    return discard_action

def actcolor(offset: Union[int, "OtherPlayerData"], colour: Color) -> ActionColorDict:
    """
    Build the action dict that hints a colour to another player.
    :param offset: the target player, given either as an integer offset or as
        an OtherPlayerData (whose ``offset`` attribute is then used)
    :param colour: the colour being hinted
    :return: an ActionColorDict whose action_type is "REVEAL_COLOR"
    """
    target = offset.offset if isinstance(offset, OtherPlayerData) else offset
    return dict(action_type="REVEAL_COLOR", color=colour, target_offset=target)

# NOTE(review): this re-declares ActionRankDict with the same fields as the
# definition earlier in this file; this later binding replaces the earlier one.
class ActionRankDict(BaseActionRevealDict):
    """Action dict for hinting the rank ``rank`` to the player at ``target_offset``."""
    action_type: ActionRank
    rank: Rank

def actrank(offset: Union[int, "OtherPlayerData"], rank: Rank) -> ActionRankDict:
    """
    Build the action dict that hints a rank to another player.
    :param offset: the target player, given either as an integer offset or as
        an OtherPlayerData (whose ``offset`` attribute is then used)
    :param rank: the rank being hinted
    :return: an ActionRankDict whose action_type is "REVEAL_RANK"
    """
    target = offset.offset if isinstance(offset, OtherPlayerData) else offset
    return dict(action_type="REVEAL_RANK", rank=rank, target_offset=target)

# Any of the action dicts built by the helpers above.
ActionDict = Union[ActionPDDict, ActionColorDict, ActionRankDict]
# Every valid ``action_type`` string.
ActionType = Literal[ActionPD, ActionColor, ActionRank]

class HandCard(TypedDict):
    """A card whose colour and/or rank may be unknown (None colour, -1 rank)."""
    color: CardColor
    rank: CardRank

# A hand made of possibly-unknown cards.
OwnHand = List[HandCard]

class KnownCard(TypedDict):
    """A card whose colour and rank are both known."""
    color: Color
    rank: Rank

# A hand made entirely of fully known cards.
KnownHand = List[KnownCard]

# Any card dict, whether fully known or not.
Card = Union[HandCard, KnownCard]

@dataclass(init=True, repr=True, eq=True, frozen=True)
class CardData:
    """Immutable (colour, rank) pair; None colour / -1 rank mean "unknown"."""

    color: CardColor
    rank: CardRank

    @property
    def r_known(self) -> bool:
        """True when the rank of this card is known (i.e. it is not -1)."""
        return not (self.rank == -1)

    @classmethod
    def make(cls, c: Card) -> "CardData":
        """Build an instance from a card dict with "color" and "rank" keys."""
        colour, rank = c["color"], c["rank"]
        return cls(colour, rank)

    @property
    def iter_higher(self) -> Iterator["CardData"]:
        """Iterate over the cards of the same colour with a higher rank, up to rank 4."""
        higher_ranks = range(self.rank + 1, 5)
        return (CardData(self.color, rnk) for rnk in higher_ranks)

    @property
    def is_known(self) -> bool:
        """True when both the colour and the rank of this card are known."""
        if self.color is None:
            return False
        return self.rank != -1

    def match_potential_other_known(self, crd: "CardData") -> bool:
        """
        Tell whether this (possibly partly unknown) card could be ``crd``.
        :param crd: a fully known card (asserted)
        :return: equality when this card is fully known; otherwise a colour
            match when the colour is set, else a rank match (an unknown rank
            matches anything)
        """
        assert crd.is_known
        if self.is_known:
            return self == crd
        if self.color:
            return self.color == crd.color
        return self.rank in (-1, crd.rank)

    def potential_matches(self, other_cards: Iterable["CardData"]) -> Iterator["CardData"]:
        """
        Lazily select the cards from ``other_cards`` that this card could be.
        :param other_cards: candidate cards
        :return: an iterator over the candidates equal to this card when it is
            fully known, otherwise over those accepted by
            match_potential_other_known
        """
        if self.is_known:
            accepts = lambda candidate: candidate == self
        else:
            accepts = self.match_potential_other_known
        return (crd for crd in other_cards if accepts(crd))
    
# A card about which nothing is known: no colour (None) and no rank (-1).
UNKNOWN_CARD: CardData = CardData(None, -1)


def card_to_dc(c: Card) -> CardData:
    """Convert a card dict (keys "color" and "rank") into a CardData."""
    return CardData(color=c["color"], rank=c["rank"])

def cardlist_to_dc(clist: Iterable[Card]) -> List[CardData]:
    """Convert every card dict in ``clist`` into a CardData, keeping the order."""
    return list(map(CardData.make, clist))

# Type variable restricted to the card dict types.
TCard = TypeVar("TCard", bound=Card)

class FireworksDict(TypedDict):
    """Height of each colour's firework.

    The agent in this file treats a card as playable when its rank equals the
    height stored for its colour.
    """
    B: int
    G: int
    R: int
    W: int
    Y: int

class ObservationDict(TypedDict):
    """Shape of the per-player observation handed to an agent's ``act``."""
    current_player: int
    # 0 when it is the observing player's turn; the agent only acts in that case.
    current_player_offset: int
    deck_size: int
    # Every card discarded so far.
    discard_pile: List[KnownCard]
    # Current height of each colour's firework.
    fireworks: FireworksDict
    # Hints are possible while this is above 0; the agent discards only while
    # it is below the configured maximum.
    information_tokens: int
    legal_moves: List[ActionDict]
    life_tokens: int
    # Per player offset, the hints that player holds about their own cards;
    # index 0 is the observing player's own hand.
    card_knowledge: List[OwnHand]
    # Per player offset, the cards in that player's hand; index 0 is the
    # observing player's own (hidden) hand.
    observed_hands: List[Union[OwnHand, KnownHand]]
    num_players: int
    vectorized: List[Literal[0, 1]]
    
    
T = TypeVar("T")


def sany(it: Iterable[T]) -> Optional[T]:
    """
    Attempts to return an arbitrary item from an iterable.

    Note that a ``None`` result is ambiguous when the iterable can itself
    contain ``None``. If ``it`` is an iterator, the returned item is consumed.

    :param it: the iterable we want an arbitrary item from
    :return: the first item from that iterable, or None if it's empty
    """
    # The builtins honour the full iteration protocol (including objects that
    # only implement __getitem__) and supply the default without a try/except.
    return next(iter(it), None)
    
    
class RulesEnum(Enum):
    """The rules a chromosome can be built from, numbered contiguously from 0 to 23."""

    PLAY_MOST_PLAYABLE_CARD = 0
    PLAY_MOST_PLAYABLE_CARD_THRESHOLD_HIGH = 1
    PLAY_MOST_PLAYABLE_CARD_THRESHOLD_LOW = 2
    PLAY_MOST_DEFINITELY_PLAYABLE_CARD = 3
    TELL_PLAYER_ABOUT_PLAYABLE_CARD_RANK = 4
    TELL_PLAYER_ABOUT_PLAYABLE_CARD_COLOUR = 5
    TELL_NEXT_PLAYER_ABOUT_PLAYABLE_CARD_RANK = 6
    TELL_NEXT_PLAYER_ABOUT_PLAYABLE_CARD_COLOUR = 7
    TELL_PLAYER_ABOUT_UNPLAYABLE_RANK = 8
    TELL_PLAYER_ABOUT_UNPLAYABLE_COLOR = 9
    TELL_NEXT_PLAYER_ABOUT_UNPLAYABLE_RANK = 10
    TELL_NEXT_PLAYER_ABOUT_UNPLAYABLE_COLOR = 11
    TELL_PLAYER_WITH_MOST_PLAYABLE_CARDS_ABOUT_PLAYABLE_CARDS_RANKS = 12
    TELL_PLAYER_WITH_MOST_PLAYABLE_CARDS_ABOUT_PLAYABLE_CARDS_COLOURS = 13
    DISCARD_OLDEST_UNPLAYABLE_CARD = 14
    DISCARD_OLDEST_UNKNOWN_CARD = 15
    TELL_PLAYER_ABOUT_ONES = 16
    TELL_PLAYER_ABOUT_FIVES = 17
    TELL_PLAYER_ABOUT_MOST_COMMON_COLOR = 18
    TELL_PLAYER_ABOUT_LEAST_COMMON_COLOR = 19
    TELL_PLAYER_ABOUT_MOST_PLAYED_COLOR = 20
    TELL_PLAYER_ABOUT_LEAST_PLAYED_COLOR = 21
    TELL_PLAYER_WITH_MOST_USELESS_CARDS_ABOUT_USELESS_RANKS = 22
    TELL_PLAYER_WITH_MOST_USELESS_CARDS_ABOUT_USELESS_COLORS = 23

    @classmethod
    def _missing_(cls, value):
        """Resolve values that are not a member's value.

        TELL_PLAYER_ABOUT_LEAST_PLAYED_COLOR used to carry the value 211; that
        number is still accepted so previously stored rule values keep
        resolving to the same member.
        """
        if value == 211:
            return cls.TELL_PLAYER_ABOUT_LEAST_PLAYED_COLOR
        return None
    


class RuleAgentChromosome(Agent):
    """Agent that applies a simple heuristic."""

    # The five suit letters.
    colors: Tuple[Color, ...] = ('Y', 'B', 'W', 'R', 'G')
    # The zero-based ranks that exist in every suit.
    ranks: Tuple[int, ...] = (0, 1, 2, 3, 4)
    # One entry per distinct card, i.e. every colour/rank combination.
    individual_hanabi_cards: Tuple[KnownCard, ...] = tuple({"color": c, "rank": r} for c in colors for r in
                                         [0, 1, 2, 3, 4])
    # The complete deck with duplicates: per colour three 0s, two each of 1-3 and a single 4.
    full_hanabi_deck: Tuple[KnownCard, ...] = tuple({"color": c, "rank": r} for c in colors for r in
                                         [0, 0, 0, 1, 1, 2, 2, 3, 3, 4])
    # Number of copies of each distinct card, keyed by the string "<color>,<rank>".
    individual_cards_and_quantities: Dict[str, int] = dict(
        ("{},{}".format(c["color"],c["rank"]), 3 if c["rank"] == 0 else 1 if c["rank"] == 4 else 2) for c in individual_hanabi_cards
    )

    def __init__(self, config, chromosome: List[RulesEnum]=None, *args, **kwargs):
        """
        Set up the agent.
        :param config: configuration mapping; its 'information_tokens' entry
            gives the maximum number of information tokens (8 when absent)
        :param chromosome: ordered list of rules to try each turn; when None,
            the order returned by generate_premade_rules_order() is used
        """
        # TODO replace this default chromosome with something better, if possible.  Plus, Add new bespoke rules below if necessary.
        self.config = config
        self.chromosome: List[RulesEnum] = (
            generate_premade_rules_order() if chromosome is None else chromosome
        )
        assert isinstance(self.chromosome, list)

        # Maximum number of information tokens; falls back to 8.
        self.max_information_tokens = config.get('information_tokens', 8)

    def calculate_all_unseen_cards(self, discard_pile: List[KnownCard], player_hands: List[List[KnownCard]], fireworks: FireworksDict) -> List[KnownCard]:
        """
        Work out which cards of the full deck have not been seen yet.

        Starting from the full deck, one copy is struck off for every card in
        the discard pile, in the other players' hands (index 0, our own hand,
        is skipped) and on the fireworks.  What remains must be in the deck or
        in our own hand.

        :param discard_pile: the cards discarded so far
        :param player_hands: the hands of all players, our own at index 0
        :param fireworks: current height of each colour's firework
        :return: the cards that have not been seen anywhere
        """
        assert len(RuleAgentChromosome.full_hanabi_deck)==50 # full hanabi deck size.

        unseen: List[KnownCard] = [*RuleAgentChromosome.full_hanabi_deck]

        def strike_off(seen_card: KnownCard) -> None:
            # Remove a single copy; cards no longer on the list are ignored.
            if seen_card in unseen:
                unseen.remove(seen_card)

        # Cards that have been discarded...
        for discarded_card in discard_pile:
            strike_off(discarded_card)

        # ...cards visible in the other players' hands...
        for other_hand in player_hands[1:]:
            for held_card in other_hand:
                strike_off(held_card)

        # ...and the cards already played onto each firework.
        for suit, height in fireworks.items():
            for played_rank in range(height):
                strike_off({"color": suit, "rank": played_rank})

        return unseen

    def filter_card_list_by_hint(self, card_list: List[KnownCard], hint: Card) -> List[Card]:
        """
        Keep the cards of ``card_list`` that are compatible with a hint.

        A colour of None places no constraint on the colour.  A rank of None
        or of -1 (the marker CardRank uses for an unknown rank) places no
        constraint on the rank; -1 can never equal a real rank, so treating it
        as a constraint would only ever produce an empty list.

        This could be enhanced by using negative hint information,
        available from observation['pyhanabi'].card_knowledge()[player_offset][card_number]

        :param card_list: the candidate cards
        :param hint: what is known about the card (colour and/or rank)
        :return: a new list with the candidates that match every known part of the hint
        """
        hinted_color = hint["color"]
        hinted_rank = hint["rank"]
        if hinted_rank == -1:
            hinted_rank = None
        return [
            c for c in card_list
            if (hinted_color is None or c["color"] == hinted_color)
            and (hinted_rank is None or c["rank"] == hinted_rank)
        ]

    def filter_card_list_by_playability(self, card_list: List[KnownCard], fireworks: FireworksDict) -> List[KnownCard]:
        """
        Keep the cards that fit exactly onto the next value of their colour's firework.
        :param card_list: the candidate cards
        :param fireworks: current height of each colour's firework
        :return: a new list with the candidates accepted by is_card_playable
        """
        return list(filter(lambda candidate: self.is_card_playable(candidate, fireworks), card_list))

    @classmethod
    def get_unplayables_from_discard_pile(cls, discard_pile: List[KnownCard]) -> List[KnownCard]:
        """
        List the cards that can never be played because of what was discarded.

        Once every copy of a card has been discarded, all higher ranks of the
        same colour are listed.

        :param discard_pile: the cards discarded so far
        :return: the cards above each fully discarded card
        """
        copies_left: Dict[str, int] = dict(cls.individual_cards_and_quantities)
        for thrown in discard_pile:
            copies_left["{},{}".format(thrown["color"],thrown["rank"])] -= 1

        dead_cards: List[KnownCard] = []
        for key, remaining in copies_left.items():
            if remaining != 0:
                continue
            parts = key.split(",")
            try:
                exhausted_rank: int = int(parts[1])
            except ValueError:
                continue
            # Already listed because a lower rank of this colour is exhausted:
            # the ranks above it are on the list too.
            if {"color":parts[0], "rank":exhausted_rank} in dead_cards:
                continue
            # noinspection PyTypeChecker
            dead_cards.extend(
                {"color":parts[0], "rank":higher} for higher in range(exhausted_rank + 1, 5)
            )
        return dead_cards


    def filter_card_list_by_unplayable(self, card_list: List[KnownCard], fireworks: FireworksDict, discard_unplayable: List[KnownCard]) -> List[KnownCard]:
        """
        Keep the cards that will never be playable on their colour's firework.

        A card is useless when its rank is already covered by its colour's
        firework, or when it appears in ``discard_unplayable`` (for example a
        4 whose colour has had every copy of a lower rank discarded).

        :param card_list: the candidate cards
        :param fireworks: current height of each colour's firework
        :param discard_unplayable: cards made unplayable by the discard pile
        :return: a new list with the useless candidates
        """
        return [
            c for c in card_list
            if c["rank"] < fireworks[c["color"]] or c in discard_unplayable
        ]



    def is_card_playable(self, card: KnownCard, fireworks: FireworksDict) -> bool:
        """Tell whether ``card`` has exactly the rank its colour's firework currently stands at."""
        firework_height = fireworks[card['color']]
        return card['rank'] == firework_height

    def act(self, observation: ObservationDict) -> Union[ActionDict, None]:
        """Pick this agent's move by trying the chromosome's rules in order.

        This function is called for every player on every turn, but only the
        player whose ``current_player_offset`` is 0 is allowed to act; the
        other players are just observing.

        Play rules fire on the estimated probability that a card of our own
        hand is playable, discard rules are only considered while fewer than
        the maximum number of information tokens are available, and hint rules
        only while at least one token is available.  A rule that cannot fire
        is skipped and the next rule of the chromosome is tried.  Rules
        without an implementation here (e.g. the
        TELL_PLAYER_WITH_MOST_PLAYABLE_CARDS_* pair) never fire.

        :param observation: the observation dict for the current turn
        :return: None when it is not our turn; otherwise the action of the
            first rule that fires, or, when no rule fires, a discard of the
            most probably useless card (a play of the most probably playable
            card when discarding is not allowed)
        """
        if observation['current_player_offset'] != 0:
            return None

        fireworks: FireworksDict = observation['fireworks']
        # Index 0 of card_knowledge holds the hints about OUR own hand.
        card_hints: OwnHand = observation['card_knowledge'][0]
        discarded: List[KnownCard] = observation['discard_pile']
        discard_unplayables: List[KnownCard] = RuleAgentChromosome.get_unplayables_from_discard_pile(discarded)

        # For every card in our hand, work out which unseen cards it could be
        # and which share of those candidates is playable / useless.
        all_unseen_cards: List[KnownCard] = self.calculate_all_unseen_cards(
            discarded, observation['observed_hands'], fireworks
        )
        possible_cards_by_hand: List[List[Card]] = [
            self.filter_card_list_by_hint(all_unseen_cards, h) for h in card_hints
        ]

        def share(subset: List[Card], whole: List[Card]) -> float:
            # A slot without any candidate gets probability 0 instead of
            # dividing by zero.
            return len(subset) / len(whole) if whole else 0.0

        probability_cards_playable: List[float] = [
            share(self.filter_card_list_by_playability(candidates, fireworks), candidates)
            for candidates in possible_cards_by_hand
        ]
        probability_cards_useless: List[float] = [
            share(self.filter_card_list_by_unplayable(candidates, fireworks, discard_unplayables), candidates)
            for candidates in possible_cards_by_hand
        ]

        # What the team-mates hold and what they do not know about it yet.
        # The checks below treat a hint entry of None as "not known".
        other_offsets: List[int] = list(range(1, observation['num_players']))
        next_offset: List[int] = other_offsets[:1]
        others_info: Dict[int, Dict[str, List[KnownCard]]] = {}
        for offset in other_offsets:
            seen_cards: List[KnownCard] = observation['observed_hands'][offset]
            their_hints: List[HandCard] = observation['card_knowledge'][offset]
            others_info[offset] = {
                "playable": self.filter_card_list_by_playability(seen_cards, fireworks),
                "useless": self.filter_card_list_by_unplayable(seen_cards, fireworks, discard_unplayables),
                "unknown ranks": [card for card, hint in zip(seen_cards, their_hints) if hint["rank"] is None],
                "unknown colors": [card for card, hint in zip(seen_cards, their_hints) if hint["color"] is None],
            }

        # Position of the first card in our hand with an unknown rank or colour.
        first_unknown_index = next(
            (idx for idx, c in enumerate(card_hints) if c["rank"] is None or c["color"] is None),
            None,
        )

        can_discard: bool = observation['information_tokens'] < self.max_information_tokens
        can_inform: bool = observation['information_tokens'] > 0

        tallest = max(fireworks.values())
        shortest = min(fireworks.values())
        # noinspection PyTypeChecker
        most_played_colours: List[Color] = [col for col, height in fireworks.items() if height == tallest]
        # noinspection PyTypeChecker
        least_played_colors: List[Color] = [col for col, height in fireworks.items() if height == shortest]

        def play(index):
            return {'action_type': 'PLAY', 'card_index': index}

        def discard(index):
            return {'action_type': 'DISCARD', 'card_index': index}

        def reveal_color(offset, color):
            return {'action_type': 'REVEAL_COLOR', 'color': color, 'target_offset': offset}

        def reveal_rank(offset, rank):
            return {'action_type': 'REVEAL_RANK', 'rank': rank, 'target_offset': offset}

        def first_hint(offsets, knowledge_key, wanted):
            # (offset, card) of the first card, in player order, that its
            # holder lacks the given knowledge about and that `wanted` accepts.
            for offset in offsets:
                for card in others_info[offset][knowledge_key]:
                    if wanted(card):
                        return offset, card
            return None

        def playable(card):
            return self.is_card_playable(card, fireworks)

        def unplayable(card):
            return not self.is_card_playable(card, fireworks)

        # Hint rules that reveal the attribute of the first matching card:
        # rule -> (players to consider, condition on the card).  The
        # TELL_NEXT_PLAYER_* rules only look at the player right after us.
        color_hint_rules = {
            RulesEnum.TELL_PLAYER_ABOUT_PLAYABLE_CARD_COLOUR: (other_offsets, playable),
            RulesEnum.TELL_NEXT_PLAYER_ABOUT_PLAYABLE_CARD_COLOUR: (next_offset, playable),
            RulesEnum.TELL_PLAYER_ABOUT_UNPLAYABLE_COLOR: (other_offsets, unplayable),
            RulesEnum.TELL_NEXT_PLAYER_ABOUT_UNPLAYABLE_COLOR: (next_offset, unplayable),
            RulesEnum.TELL_PLAYER_ABOUT_MOST_PLAYED_COLOR: (other_offsets, lambda c: c["color"] in most_played_colours),
            RulesEnum.TELL_PLAYER_ABOUT_LEAST_PLAYED_COLOR: (other_offsets, lambda c: c["color"] in least_played_colors),
        }
        rank_hint_rules = {
            RulesEnum.TELL_PLAYER_ABOUT_ONES: (other_offsets, lambda c: c["rank"] == 0),
            RulesEnum.TELL_PLAYER_ABOUT_FIVES: (other_offsets, lambda c: c["rank"] == 4),
            RulesEnum.TELL_PLAYER_ABOUT_PLAYABLE_CARD_RANK: (other_offsets, playable),
            RulesEnum.TELL_NEXT_PLAYER_ABOUT_PLAYABLE_CARD_RANK: (next_offset, playable),
            RulesEnum.TELL_PLAYER_ABOUT_UNPLAYABLE_RANK: (other_offsets, unplayable),
            RulesEnum.TELL_NEXT_PLAYER_ABOUT_UNPLAYABLE_RANK: (next_offset, unplayable),
        }
        common_color_rules = (
            RulesEnum.TELL_PLAYER_ABOUT_MOST_COMMON_COLOR,
            RulesEnum.TELL_PLAYER_ABOUT_LEAST_COMMON_COLOR,
        )
        useless_hint_rules = (
            RulesEnum.TELL_PLAYER_WITH_MOST_USELESS_CARDS_ABOUT_USELESS_COLORS,
            RulesEnum.TELL_PLAYER_WITH_MOST_USELESS_CARDS_ABOUT_USELESS_RANKS,
        )

        # Try the rules in chromosome order and perform the first one that applies.
        for rule in self.chromosome:

            # --- play rules ---
            if rule == RulesEnum.PLAY_MOST_PLAYABLE_CARD:
                return play(argmax(probability_cards_playable))
            elif rule == RulesEnum.PLAY_MOST_DEFINITELY_PLAYABLE_CARD:
                if max(probability_cards_playable) == 1:
                    return play(argmax(probability_cards_playable))
            elif rule == RulesEnum.PLAY_MOST_PLAYABLE_CARD_THRESHOLD_HIGH:
                if max(probability_cards_playable) > 0.8:
                    return play(argmax(probability_cards_playable))
            elif rule == RulesEnum.PLAY_MOST_PLAYABLE_CARD_THRESHOLD_LOW:
                if max(probability_cards_playable) > 0.5:
                    return play(argmax(probability_cards_playable))

            # --- discard rules ---
            elif rule == RulesEnum.DISCARD_OLDEST_UNPLAYABLE_CARD:
                # TODO: this currently discards the most probably useless card
                # rather than the oldest one.
                if can_discard and max(probability_cards_useless) > 0.5:
                    return discard(argmax(probability_cards_useless))
            elif rule == RulesEnum.DISCARD_OLDEST_UNKNOWN_CARD:
                if can_discard and first_unknown_index is not None:
                    return discard(first_unknown_index)

            # --- hint rules ---
            elif not can_inform:
                continue
            elif rule in color_hint_rules:
                offsets, wanted = color_hint_rules[rule]
                hit = first_hint(offsets, "unknown colors", wanted)
                if hit is not None:
                    return reveal_color(hit[0], hit[1]["color"])
            elif rule in rank_hint_rules:
                offsets, wanted = rank_hint_rules[rule]
                hit = first_hint(offsets, "unknown ranks", wanted)
                if hit is not None:
                    return reveal_rank(hit[0], hit[1]["rank"])
            elif rule in common_color_rules:
                # The first player who has cards of unknown colour is told
                # about the most (or least) frequent colour among those cards.
                for offset in other_offsets:
                    unknown_colors = others_info[offset]["unknown colors"]
                    if not unknown_colors:
                        continue
                    col_counts: Dict[Color, int] = {}
                    for c in unknown_colors:
                        col_counts[c["color"]] = col_counts.get(c["color"], 0) + 1
                    pick = max if rule == RulesEnum.TELL_PLAYER_ABOUT_MOST_COMMON_COLOR else min
                    return reveal_color(offset, pick(col_counts.items(), key=lambda kv: kv[1])[0])
            elif rule in useless_hint_rules:
                # Players holding the most useless cards are considered first.
                by_useless = sorted(
                    others_info.items(), key=lambda kv: len(kv[1]["useless"]), reverse=True
                )
                about_colors = rule == RulesEnum.TELL_PLAYER_WITH_MOST_USELESS_CARDS_ABOUT_USELESS_COLORS
                knowledge_key = "unknown colors" if about_colors else "unknown ranks"
                for offset, info in by_useless:
                    hidden_useless = [c for c in info["useless"] if c in info[knowledge_key]]
                    if hidden_useless:
                        if about_colors:
                            return reveal_color(offset, hidden_useless[0]["color"])
                        return reveal_rank(offset, hidden_useless[0]["rank"])

        # No rule fired: fall back to a discard when that is allowed, otherwise
        # to playing the most probably playable card.
        if can_discard:
            return discard(argmax(probability_cards_useless))
        return play(argmax(probability_cards_playable))