# In [None]:
# -*- coding: utf-8 -*-
"""Copy of Untitled0.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1UBuX-QMznTW6z556hJGx0nqgG-twaUJa

下面的python代码定义了牌的花色、牌值和牌型（完整代码见github）：
"""

import numpy as np
from enum import Enum

# Card suits; each member's value is the glyph printed by Card.__repr__.
SUITS = Enum("suit", {"clubs": "♣", "diamonds": "♦", "hearts": "♥", "spades": "♠"})

# Card ranks in ascending order; values run from 0 ("2") to 12 ("A").
RANKS = Enum("rank", {name: index for index, name in enumerate(
    ["2", "3", "4", "5", "6", "7", "8", "9", "10", "J", "Q", "K", "A"])})

# Hand categories; values run from 0 ("straight flush") to 8 ("high card").
HANDS = Enum("hand", {name: index for index, name in enumerate(
    ["straight flush", "four of a kind", "full house", "flush", "three of a kind",
     "straight", "two pair", "one pair", "high card"])})

# Position i holds the member whose value is i, which gives value-to-member lookup.
HAND_NAMES = list(HANDS)
RANK_NAMES = list(RANKS)

"""下面定义牌的类（Card）和生成一手牌的工具函数（make_card，make_hand）等。"""

class Card:
  """A playing card identified by a rank and a suit."""

  def __init__(self, rank, suit):
    self.rank = rank
    self.suit = suit

  def __repr__(self):
    # Rank name followed by the suit's value, e.g. "10♣".
    return f"{self.rank.name}{self.suit.value}"

class Deck:
  """A deck holding one Card per (rank, suit) combination; cards are drawn at random."""

  def __init__(self):
    cards = []
    for rank in RANKS:
      for suit in SUITS:
        cards.append(Card(rank, suit))
    self.cards_ = cards

  def draw(self):
    """Remove and return a randomly chosen card from those still in the deck."""
    position = np.random.randint(len(self.cards_))
    return self.cards_.pop(position)

  def __len__(self):
    """Return the number of cards still in the deck."""
    return len(self.cards_)


def make_card(shorthand):
  """Build a Card from shorthand such as "10c" or "KH".

  The last character selects the suit (C, D, H or S, case-insensitive); everything
  before it is looked up as a rank name.
  """
  assert len(shorthand) in (2, 3)
  rank_text, suit_letter = shorthand[:-1], shorthand[-1]
  suit_by_letter = {
      "C": SUITS.clubs,
      "D": SUITS.diamonds,
      "H": SUITS.hearts,
      "S": SUITS.spades,
  }
  rank = RANKS[rank_text.upper()]
  return Card(rank, suit_by_letter[suit_letter.upper()])

def make_hand(shorthand):
  """Build a five-card hand from shorthand.

  Args:
    shorthand: either a space-separated string such as "jc 10c 9c 8c 7c", or an
      already-split sequence of five card shorthands.

  Returns:
    A list of five Card objects in input order.

  Raises:
    AssertionError: if the shorthand does not describe exactly five cards.
  """
  if isinstance(shorthand, str):
    shorthand = shorthand.split(" ")
  # Validation and conversion run for both input forms; previously they sat inside
  # the branch above, so a pre-split sequence fell through and returned None.
  assert len(shorthand) == 5
  return [make_card(card_shorthand) for card_shorthand in shorthand]

"""有了这些工具，我们就可以很容易写出一个基于规则的牌型分类函数[1]，作为训练机器学习模型的“裁判”。"""

def is_in_pair(hand):
  """For each card, tell whether exactly two cards in the hand (itself included) share its rank."""
  flags = []
  for card in hand:
    same_rank = sum(1 for other in hand if other.rank == card.rank)
    flags.append(same_rank == 2)
  return flags

def is_in_toak(hand):
  """For each card, tell whether exactly three cards in the hand (itself included) share its rank."""
  flags = []
  for card in hand:
    same_rank = sum(1 for other in hand if other.rank == card.rank)
    flags.append(same_rank == 3)
  return flags

def is_in_foak(hand):
  """For each card, tell whether exactly four cards in the hand (itself included) share its rank."""
  flags = []
  for card in hand:
    same_rank = sum(1 for other in hand if other.rank == card.rank)
    flags.append(same_rank == 4)
  return flags

def next_rank(rank):
  """Return the rank one step above `rank`, or None when `rank` is the ace."""
  if rank == RANKS.A:
    return None
  return RANK_NAMES[rank.value + 1]

def next_n_rank(rank, n):
  """Return the rank `n` steps above `rank`, or None once the sequence runs past the ace."""
  for _ in range(n):
    if rank is None:
      # Already past the ace; further steps cannot change the result.
      break
    rank = next_rank(rank)
  return rank

def is_start_of_straight(hand):  # note that this does not consider valid straights like "5432A".
  """For each card, tell whether the hand also holds the four next-higher ranks.

  A card is flagged when the ranks 1, 2, 3 and 4 steps above its own are all present
  in the hand, i.e. when it is the lowest card of a five-card straight.
  """
  def rank_present(target):
    # A target of None (past the ace) matches no card.
    return len([card for card in hand if card.rank == target]) > 0

  return [all(rank_present(next_n_rank(first.rank, step)) for step in range(1, 5))
          for first in hand]

def is_in_flush(hand):
  """For each card, tell whether exactly five cards in the hand (itself included) share its suit."""
  flags = []
  for card in hand:
    same_suit = sum(1 for other in hand if other.suit == card.suit)
    flags.append(same_suit == 5)
  return flags

def classify_hand(hand):
  """Classify a hand with hand-written rules.

  Returns:
    A `(hand_type, details)` tuple. `hand_type` is a HANDS member; `details` holds
    the same member plus the hand-level booleans and the per-card flag lists that
    the decision was based on.
  """
  # Per-card flag lists, one entry per card in hand order.
  flush_flags = is_in_flush(hand)
  straight_flags = is_start_of_straight(hand)
  foak_flags = is_in_foak(hand)
  toak_flags = is_in_toak(hand)
  pair_flags = is_in_pair(hand)

  is_flush = all(flush_flags)
  is_straight = any(straight_flags)
  is_foak = any(foak_flags)
  is_toak = any(toak_flags)
  # Each pair flags two cards, so halve the count of flagged cards.
  num_pairs = int(sum(pair_flags) / 2)

  # Ordered rules: the first one that holds decides the category.
  rules = [
      (is_flush and is_straight, "straight flush"),
      (is_foak, "four of a kind"),
      (is_toak and num_pairs == 1, "full house"),
      (is_flush, "flush"),
      (is_straight, "straight"),
      (is_toak, "three of a kind"),
      (num_pairs == 2, "two pair"),
      (num_pairs == 1, "one pair"),
  ]
  hand_type = HANDS["high card"]
  for holds, type_name in rules:
    if holds:
      hand_type = HANDS[type_name]
      break

  details = {
        "hand_type": hand_type,
        "is_flush": is_flush,
        "is_straight": is_straight,
        "is_foak": is_foak,
        "is_toak": is_toak,
        "is_one_pair": num_pairs == 1,
        "is_two_pair": num_pairs == 2,
        "is_in_flush": flush_flags,
        "is_start_of_straight": straight_flags,
        "is_in_foak": foak_flags,
        "is_in_toak": toak_flags,
        "is_in_pair": pair_flags,
  }

  return hand_type, details

"""对以上judgement进行测试"""

def make_hand_and_classify(shorthand):
  """Build a hand from shorthand and print it next to the category classify_hand assigns."""
  hand = make_hand(shorthand)
  hand_type, _ = classify_hand(hand)
  print(hand, hand_type)

# Run the rule-based classifier on one example per category
# (expected category noted next to each call).
make_hand_and_classify("jc 10c 9c 8c 7c")  # straight flush
make_hand_and_classify("5c 5d 5h 5s 2d")  # four of a kind
make_hand_and_classify("6s 6h 6d ks kh")  # full house
make_hand_and_classify("jh 9h 8h 4h 3h")  # flush
make_hand_and_classify("10d 9s 8h 7d 6c")  # straight
make_hand_and_classify("qc qs qh 9h 2s")  # three of a kind
make_hand_and_classify("jh js 3c 3s 2h")  # two pair
make_hand_and_classify("10s 10h 8s 7h 4c")  # one pair
make_hand_and_classify("kd qd 7s 4s 3h")  # high card

"""生成数据集"""

from tqdm import tqdm
import torch
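# tqdm is a fast, extensible progress bar for Python: wrapping any iterator as tqdm(iterator) adds progress output to long-running loops.
# It helps monitor how far a run has progressed, estimate the remaining time, and can even assist with debugging.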

def random_hand():
  """Draw five cards from a freshly built Deck and return them as a list."""
  deck = Deck()
  hand = []
  for _ in range(5):
    hand.append(deck.draw())
  return hand

def make_sample(hand):
  """Return a dataset record: the hand under "hand" plus every entry of classify_hand's details."""
  _, details = classify_hand(hand)
  sample = {"hand": hand}
  sample.update(details)
  return sample

def generate_dataset(num=10000):
    """Generate `num` labelled samples from uniformly random hands, showing a progress bar."""
    with tqdm(range(num)) as progress:
        return [make_sample(random_hand()) for _ in progress]

# Tiny dataset, printed to inspect the record layout.
ds5 = generate_dataset(5)
print(ds5)

# NOTE(review): despite its name, ds1k holds 10,000 samples.
ds1k = generate_dataset(10000)

# NOTE(review): despite its name, ds10k holds 100,000 samples.
ds10k = generate_dataset(100000)

# np.unique returns the sorted distinct values of the input; with return_counts=True
# it also returns how often each one occurs. The result is not stored here
# (in a notebook it would be displayed as the cell output).
np.unique([d["hand_type"].name for d in ds1k], return_counts=True)

def print_dataset_distribution(dataset, sort_by_count=False):
    """Print the percentage of samples per hand type.

    Args:
        dataset: records whose "hand_type" entry exposes a `.name`.
        sort_by_count: when True, list types from least to most frequent;
            otherwise list them in ascending HANDS value order.
    """
    names = [sample["hand_type"].name for sample in dataset]
    hand_types, counts = np.unique(names, return_counts=True)

    if sort_by_count:
        def sort_key(entry):
            return entry[1]
    else:
        def sort_key(entry):
            return HANDS[entry[0]].value

    rows = list(zip(hand_types, counts))
    rows.sort(key=sort_key)
    for name, count in rows:
        print(f"{count / len(dataset) * 100:.4f}%: {name}")

# Show the hand-type distribution of both uniformly sampled datasets.
print_dataset_distribution(ds1k)
print("\n")
print_dataset_distribution(ds10k)

def is_hand_valid(hand):
  """Return True when the hand has exactly five cards and no two share both rank and suit."""
  if len(hand) != 5:
    return False
  for position, first in enumerate(hand):
    for second in hand[position + 1:]:
      same_card = first.rank == second.rank and first.suit == second.suit
      if same_card:
        return False
  return True

# Check is_hand_valid: the first hand repeats "4c" (expected False),
# the second has five distinct cards (expected True).
print(is_hand_valid(make_hand("4c 4c 5c 6c 10c")))
print(is_hand_valid(make_hand("3c 4c 5c 6c 7c")))

def permute_hand(hand):
    """Return a new list holding the cards of `hand` in a random order."""
    order = np.random.permutation(len(hand))
    return [hand[index] for index in order]

def generate_random_hand_with_at_least_type(hand_type):
    """Generate a random, shuffled hand constructed to satisfy `hand_type`.

    Note that this could generate a higher type than hand_type, albeit not likely
    in general, e.g. when hand_type=HANDS["straight"] it could potentially
    generate a straight flush.

    Args:
        hand_type: a HANDS member naming the pattern to construct.

    Returns:
        A list of five Card objects in random order.

    Raises:
        ValueError: if `hand_type` is not one of the HANDS members handled below.
        AssertionError: if the constructed hand fails is_hand_valid.
    """
    if hand_type == HANDS["straight flush"]:
        # A straight needs four higher ranks above its start, so it cannot start at J or above.
        start_rank = np.random.choice([r for r in RANKS if r not in [RANKS.J, RANKS.Q, RANKS.K, RANKS.A]])
        suit = np.random.choice(SUITS)
        hand = [Card(next_n_rank(start_rank, i), suit) for i in range(5)]
    elif hand_type == HANDS["four of a kind"]:
        ranks = np.random.choice(RANKS, 2, replace=False)
        hand = [Card(ranks[0], suit) for suit in SUITS]
        hand.append(Card(ranks[1], np.random.choice(SUITS)))
    elif hand_type == HANDS["full house"]:
        ranks = np.random.choice(RANKS, 2, replace=False)
        suits3 = np.random.choice(SUITS, 3, replace=False)
        suits2 = np.random.choice(SUITS, 2, replace=False)
        hand = [Card(ranks[0], s) for s in suits3]
        hand.extend([Card(ranks[1], s) for s in suits2])
    elif hand_type == HANDS["flush"]:
        suit = np.random.choice(SUITS)
        ranks = np.random.choice(RANKS, 5, replace=False)
        hand = [Card(r, suit) for r in ranks]
    elif hand_type == HANDS["straight"]:
        start_rank = np.random.choice([r for r in RANKS if r not in [RANKS.J, RANKS.Q, RANKS.K, RANKS.A]])
        hand = [Card(next_n_rank(start_rank, i), np.random.choice(SUITS)) for i in range(5)]
    elif hand_type == HANDS["three of a kind"]:
        ranks = np.random.choice(RANKS, 3, replace=False)
        suits3 = np.random.choice(SUITS, 3, replace=False)
        hand = [Card(ranks[0], s) for s in suits3]
        hand.extend([Card(r, np.random.choice(SUITS)) for r in ranks[1:]])
    elif hand_type == HANDS["two pair"]:
        ranks = np.random.choice(RANKS, 3, replace=False)
        suits1 = np.random.choice(SUITS, 2, replace=False)
        suits2 = np.random.choice(SUITS, 2, replace=False)
        hand = [Card(ranks[0], s) for s in suits1]
        hand.extend([Card(ranks[1], s) for s in suits2])
        hand.append(Card(ranks[2], np.random.choice(SUITS)))
    elif hand_type == HANDS["one pair"]:
        ranks = np.random.choice(RANKS, 4, replace=False)
        suits = np.random.choice(SUITS, 2, replace=False)
        hand = [Card(ranks[0], s) for s in suits]
        hand.extend([Card(r, np.random.choice(SUITS)) for r in ranks[1:]])
    elif hand_type == HANDS["high card"]:
        ranks = np.random.choice(RANKS, 5, replace=False)
        hand = [Card(r, np.random.choice(SUITS)) for r in ranks]
    else:
        # Previously an unknown type left `hand` unbound and failed later with UnboundLocalError.
        raise ValueError("unsupported hand_type: {!r}".format(hand_type))
    # Shuffle so that card position carries no information about the pattern.
    hand = permute_hand(hand)
    assert is_hand_valid(hand)
    return hand

# Generate and print one constructed hand per category.
for hand_type in HANDS:
    print(generate_random_hand_with_at_least_type(hand_type), hand_type)

def generate_mined_dataset(num=1000):
    """Generate `num` labelled samples with rare hand types over-represented.

    For every sample a target is drawn from `type_probs`: a HANDS member means the
    hand is constructed by generate_random_hand_with_at_least_type, while None means
    a uniformly random hand from random_hand.

    Args:
        num: number of samples to generate.

    Returns:
        A list of records as produced by make_sample.
    """
    type_probs = {HANDS["straight flush"]:0.01, HANDS["four of a kind"]:0.1, HANDS["full house"]:0.1,
                  HANDS["flush"]:0.1, HANDS["straight"]:0.1, HANDS["three of a kind"]:0.1,
                  HANDS["two pair"]:0.1, HANDS["one pair"]:0.1, HANDS["high card"]:0.1, None:0.19}
    # Loop-invariant: build the candidate and probability lists once.
    candidates = list(type_probs.keys())
    probabilities = list(type_probs.values())
    dataset = []
    with tqdm(range(num)) as t:
        for _ in t:
            target_type = np.random.choice(candidates, p=probabilities)
            # Identity test: None is the sentinel for "no target, draw a random hand".
            if target_type is None:
                hand = random_hand()
            else:
                hand = generate_random_hand_with_at_least_type(target_type)
            dataset.append(make_sample(hand))
    return dataset

# NOTE(review): despite its name, mds1k holds 10,000 samples.
mds1k = generate_mined_dataset(10000)

# NOTE(review): despite its name, mds10k holds 100,000 samples.
mds10k = generate_mined_dataset(100000)

def serialize_dataset(dataset):
    """Return a copy of the records with enum-based fields replaced by plain names.

    Each card in "hand" becomes a `(rank name, suit name)` tuple and "hand_type"
    becomes its name; every other entry is carried over unchanged.
    """
    serialized = []
    for record in dataset:
        row = dict(record)
        row["hand"] = [(card.rank.name, card.suit.name) for card in record["hand"]]
        row["hand_type"] = record["hand_type"].name
        serialized.append(row)
    return serialized

def deserialize_dataset(dataset):
    """Rebuild Card objects and the HANDS member from records written by serialize_dataset.

    Each `(rank name, suit name)` pair in "hand" becomes a Card and the "hand_type"
    name becomes the matching HANDS member; every other entry is carried over unchanged.
    """
    restored = []
    for record in dataset:
        row = dict(record)
        row["hand"] = [Card(RANKS[rank_name], SUITS[suit_name]) for rank_name, suit_name in record["hand"]]
        row["hand_type"] = HANDS[record["hand_type"]]
        restored.append(row)
    return restored

# Round-trip check on two records: original, serialized, and deserialized again.
print(ds1k[:2])
print(serialize_dataset(ds1k[:2]))
print(deserialize_dataset(serialize_dataset(ds1k[:2])))

# Persist all four datasets in their serialized (plain-name) form.
torch.save(serialize_dataset(ds1k), "/tmp/ds1k.pt")
torch.save(serialize_dataset(ds10k), "/tmp/ds10k.pt")
torch.save(serialize_dataset(mds1k), "/tmp/mds1k.pt")
torch.save(serialize_dataset(mds10k), "/tmp/mds10k.pt")

# Reload them from disk and rebuild the Card / HANDS objects.
ds1k = deserialize_dataset(torch.load("/tmp/ds1k.pt"))
ds10k = deserialize_dataset(torch.load("/tmp/ds10k.pt"))
mds1k = deserialize_dataset(torch.load("/tmp/mds1k.pt"))
mds10k = deserialize_dataset(torch.load("/tmp/mds10k.pt"))

print_dataset_distribution(mds10k)

"""MODEL!!!!!! 训练模型"""

import torch
import torch.nn as nn

# Width of one encoded card in encode_hand: 13 rank slots followed by 4 suit slots.
NC = 17
# Use the GPU when one is available, otherwise fall back to the CPU so the script
# still runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def encode_hand(hand):
  """Encode a hand as a [5, NC] float tensor, one row per card.

  Each row has a 1 at the column of the card's rank value and a 1 at
  `len(RANKS) + suit index`; every other entry is 0.
  """
  suit_order = (SUITS["clubs"], SUITS["diamonds"], SUITS["hearts"], SUITS["spades"])
  suit_column = {suit: len(RANKS) + position for position, suit in enumerate(suit_order)}
  encoded = torch.zeros([5, NC])
  for row, card in enumerate(hand):
    encoded[row, card.rank.value] = 1
    encoded[row, suit_column[card.suit]] = 1
  return encoded

# torch.nn provides the layers, losses and the Module base class used to build the model.
class MultiHeadedAttentionBackbone(nn.Module):
  """Per-card feature extractor: one multi-head dot-product attention step plus a two-layer MLP.

  Args:
    num_heads: number of attention heads; must divide both `kdim` and `vdim`.
    feature_size: size of the per-card feature vector that is returned.
    kdim: total width of the key/query projections (split across heads).
    vdim: total width of the value projection (split across heads).
    hidden_size: width of the hidden layer between Linear1 and Linear2.

  Raises:
    ValueError: if `kdim` or `vdim` is not divisible by `num_heads`.
  """

  def __init__(self, num_heads =8, feature_size =32, kdim=32, vdim=32, hidden_size=32):
    super(MultiHeadedAttentionBackbone, self).__init__()
    if kdim % num_heads != 0 or vdim % num_heads != 0:
      # forward() splits the projections evenly across heads with view().
      raise ValueError("kdim and vdim must both be divisible by num_heads")
    self.num_heads = num_heads
    self.feature_size = feature_size
    self.kdim = kdim
    self.vdim = vdim
    self.softmax = torch.softmax
    self.proj_key = nn.Linear(NC, kdim)
    self.proj_query = nn.Linear(NC, kdim)
    self.proj_value = nn.Linear(NC, vdim)
    self.Linear1 = nn.Linear(vdim, hidden_size)
    self.Linear2 = nn.Linear(hidden_size, feature_size)
    self.act = nn.ReLU()

  def forward(self, x): # [B, 5, NC]
    """Map encoded hands [B, 5, NC] to per-card features [B, 5, feature_size]."""
    x = x.view(-1, NC) # [B*5, NC]
    # [B * 5, kdim/vdim]
    key, query, value = self.proj_key(x), self.proj_query(x), self.proj_value(x)

    # Split the projections into heads and move the head axis in front of the card axis.
    # [B, nh, 5, kdim / nh]
    key = key.view(-1, 5, self.num_heads, self.kdim // self.num_heads).transpose(1, 2)
    # [B, nh, 5, kdim / nh]
    query = query.view(-1, 5, self.num_heads, self.kdim // self.num_heads).transpose(1, 2)
    # [B, nh, 5, vdim / nh]
    value = value.view(-1, 5, self.num_heads, self.vdim // self.num_heads).transpose(1, 2)

    # [B, nh, 5, 5]: raw query-key dot products between every pair of cards.
    # NOTE(review): no scaling or softmax is applied to these scores, and
    # self.softmax is stored but never used — confirm this is intentional.
    attention = torch.matmul(query, key.transpose(-1, -2)) # Q * K

    x = torch.matmul(attention, value) # [B, nh, 5, vdim/nh]
    x = x.transpose(1, 2).reshape(-1, 5, self.vdim) # [B, 5, vdim]

    # vdim -> hidden_size -> feature_size. The original applied Linear2 twice and
    # left Linear1 unused, which only worked when all three sizes were equal.
    return self.Linear2(self.act(self.Linear1(x)))

class BinaryPredictor(nn.Module):
  """Two-layer MLP head producing one logit per input feature vector.

  Args:
    feature_size: width of the incoming feature vectors.
    hidden_size: width of the hidden layer.
  """

  def __init__(self, feature_size=32, hidden_size = 16):
    super(BinaryPredictor, self).__init__()
    self.feature_size = feature_size
    self.hidden_size = hidden_size
    self.Linear1 = nn.Linear(feature_size, hidden_size)
    self.Linear2 = nn.Linear(hidden_size, 1)
    self.act = nn.ReLU()

  def forward(self, x): # [N, feature_size] -> [N, 1]
    """Return one logit per row of `x`, with shape [N, 1]."""
    # The original squeezed the last dimension of the hidden activations. That was a
    # no-op unless hidden_size == 1, where it removed the dimension Linear2 needs.
    # The output shape is unchanged by dropping it.
    return self.Linear2(self.act(self.Linear1(x)))

class TypePredictor(nn.Module):
    """Two-layer MLP mapping 25 inputs per hand to 9 hand-type logits.

    Args:
        hidden_size: width of the hidden layer.
    """

    def __init__(self, hidden_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear1 = nn.Linear(25, hidden_size)
        self.linear2 = nn.Linear(hidden_size, 9)
        self.act = nn.ReLU()

    def forward(self, x):  # [B, 5, 5] -> [B, 9]
        """Flatten each hand's 5x5 input to 25 values and return 9 logits per hand."""
        flat = x.view(-1, 25)  # [B, 25]
        hidden = self.act(self.linear1(flat))
        return self.linear2(hidden)  # [B, 9]

class HandClassifier(nn.Module):
  """Full model: a backbone, five per-card binary heads and one hand-type head.

  Args:
    backbone: factory for the feature extractor; the instance must expose `feature_size`.
    binary_predictor: factory for the per-card heads, called with `feature_size=`.
    type_predictor: factory for the hand-type head.
    binary_loss: factory for the loss applied to the binary heads.
    type_loss: factory for the loss applied to the hand-type logits.
    loss_weights: optional dict with a weight for each of "pair", "toak", "foak",
      "sofs", "flush" and "type"; every weight is 1.0 when omitted.
  """

  def __init__(self, backbone = MultiHeadedAttentionBackbone,
               binary_predictor = BinaryPredictor,
               type_predictor = TypePredictor,
               # BCEWithLogitsLoss combines a sigmoid with binary cross-entropy,
               # so the binary heads output raw logits.
               binary_loss = nn.BCEWithLogitsLoss,
               # Cross-entropy over the hand-type logits.
               type_loss = nn.CrossEntropyLoss,
               loss_weights = None):
    super().__init__()
    default_weights = {k:1.0 for k in ["pair", "toak", "foak", "sofs", "flush", "type"]}
    self.loss_weights = default_weights if loss_weights is None else loss_weights

    self.backbone = backbone()
    self.c = self.backbone.feature_size
    # Registers pred_pair, pred_toak, pred_foak, pred_sofs and pred_flush, in that order.
    for head in ("pair", "toak", "foak", "sofs", "flush"):
      setattr(self, "pred_" + head, binary_predictor(feature_size=self.c))
    self.pred_type = type_predictor()
    self.binary_loss = binary_loss()
    self.type_loss = type_loss()

  def forward(self, batch):
    """Run the model on a list of sample dicts and compute the training losses.

    Returns:
      `(preds, loss, losses)`: the per-head logits, the weighted total loss, and
      the unweighted loss of each head.
    """
    n = len(batch)
    # Head name -> key of the per-card label list in each sample.
    head_labels = (("pair", "is_in_pair"), ("toak", "is_in_toak"), ("foak", "is_in_foak"),
                   ("sofs", "is_start_of_straight"), ("flush", "is_in_flush"))

    # Stack the per-hand encodings into one batch tensor on the target device.
    x = torch.stack([encode_hand(sample["hand"]) for sample in batch]).to(DEVICE) # [B, 5, NC]

    # Every entry except the hand itself and the enum label becomes a float tensor.
    labels = {}
    for key in batch[0].keys():
      if key != "hand" and key != "hand_type":
        labels[key] = torch.stack(
            [torch.tensor(sample[key], dtype=torch.float) for sample in batch], dim = 0).to(DEVICE)
    labels["hand_type"] = torch.stack(
        [torch.tensor(sample["hand_type"].value) for sample in batch]).to(DEVICE) # [B]

    features = self.backbone(x) # [B, 5, C]
    features = features.view([n * 5, -1]) # [B * 5, C]

    preds = {}
    for head, _ in head_labels:
      preds[head] = getattr(self, "pred_" + head)(features).view([n, 5]) # [B, 5]

    # The type head sees only the sign of each binary logit (+1 / -1), arranged as [B, 5, 5].
    stacked = torch.stack([preds[head] for head, _ in head_labels], dim=-1)
    preds["type"] = self.pred_type(torch.where(stacked > 0.0, 1.0, -1.0)) # [B, 9]

    losses = {}
    for head, label_key in head_labels:
      losses[head] = self.binary_loss(preds[head].flatten(), labels[label_key].flatten()) # [B * 5], [B * 5]
    losses["type"] = self.type_loss(preds["type"], labels["hand_type"]) # [B, 9], [B]

    loss = sum((losses[k] * self.loss_weights[k]) for k in losses)

    return preds, loss, losses

"""build:  Train and Evaluate"""

# 实现毫秒级别的计时
from timeit import default_timer as timer
from torch.utils.data import DataLoader

def collate_losses(losses_history, loss_history):
  """Turn per-batch loss records into one array per loss name.

  Args:
    losses_history: list of dicts, one per batch, mapping loss name to a float.
    loss_history: list of total-loss floats, one per batch.

  Returns:
    A dict mapping each loss name, plus "total", to a numpy array with one entry
    per batch. With an empty history only "total" (an empty array) is present.
  """
  # Take the key set from the first record; an empty history (e.g. an empty
  # dataset) has no per-head keys and previously raised IndexError here.
  keys = losses_history[0] if losses_history else ()
  losses = {k:np.array([l[k] for l in losses_history]) for k in keys}
  losses["total"] = np.array(loss_history)
  return losses

def collate_data(batch):
  """Identity collate function: hand the DataLoader's list of samples through unchanged."""
  return batch

def train_epoch(model, optimizer, dataset, batch_size = 16):
  """Train `model` for one pass over `dataset` and return the collated per-batch losses.

  Args:
    model: callable returning `(preds, loss, losses)` for a batch of samples.
    optimizer: optimizer that updates the model's parameters.
    dataset: sequence of sample dicts.
    batch_size: number of samples per optimisation step.

  Returns:
    The dict produced by collate_losses.
  """
  # Switch the model to training mode.
  model.train()
  # collate_data keeps each batch as a plain list of sample dicts.
  loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_data)
  per_head_history, total_history = [], []

  with tqdm(loader) as progress:
    for batch in progress:
      _, loss, losses = model(batch)

      # Clear the old gradients, back-propagate, then apply the parameter update.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # .item() converts each scalar tensor to a Python float.
      per_head_history.append({name: value.item() for name, value in losses.items()})
      total_history.append(loss.item())

  return collate_losses(per_head_history, total_history)

def evaluate(model, dataset):
  """Compute the model's losses over `dataset` without updating any parameters.

  Args:
    model: callable returning `(preds, loss, losses)` for a batch of samples.
    dataset: sequence of sample dicts.

  Returns:
    The dict produced by collate_losses.
  """
  # Switch the model to evaluation mode.
  model.eval()
  losses_history = []
  loss_history = []
  val_dataloader = DataLoader(dataset, batch_size = 16, collate_fn=collate_data)

  # no_grad disables gradient tracking for everything computed inside the block.
  with torch.no_grad():
    with tqdm(val_dataloader) as t:
      # Iterate the tqdm wrapper, not the raw loader, so the progress bar advances.
      for batch in t:
        _, loss, losses = model(batch)
        losses_history.append({k:v.item() for k, v in losses.items()})
        loss_history.append(loss.item())

  return collate_losses(losses_history, loss_history)

"""Run Training"""

model = HandClassifier().to(DEVICE)
# Adam arguments used below:
#   params: the tensors to optimise, here model.parameters().
#   lr: learning rate, the step size of each update (library default 0.001).
#   betas: decay factors for the running averages of the gradient and of its
#     square (library default (0.9, 0.999); this script uses (0.9, 0.98)).
#   eps: small constant added to the denominator for numerical stability
#     (library default 1e-8).
# Left at their defaults: weight_decay (0, i.e. no L2 penalty) and amsgrad (False).

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-8)

EPOCHS = 10 # number of passes over the training set

for epoch in range(EPOCHS):
  # Time the training pass and the evaluation pass separately.
  t0 = timer()
  train_losses = train_epoch(model, optimizer, mds10k, batch_size=256)
  t1 = timer()
  eval_losses = evaluate(model, mds1k)
  t2 = timer()

  print("Epoch {}: train_loss: {}, time: {:.3f}s; eval_time:{:.3f}s".format(
      epoch, train_losses["total"].mean(), t1-t0, t2-t1
  ))
  print("Eval losses:")
  for k, v in eval_losses.items():
    print(f"  {k}: {v.mean()}")

# state_dict() is a dict mapping each learnable parameter (and persistent buffer)
# name to its tensor; saving it stores the trained weights.
torch.save(model.state_dict(), "/tmp/model.pt") # full_model_rc1.pt

"""Run Evaluation"""

# Reload the weights saved after training and switch to evaluation mode.
model.load_state_dict(torch.load("/tmp/model.pt"))
model = model.eval()

# Inspect the predictions for the first ten samples only.
batch = mds1k[:10]
with torch.no_grad():
  pred, loss, losses = model(batch)
  for i in range(len(batch)):
    print("hand: ", batch[i]["hand"])
    # .cpu() moves a tensor to host memory; the result is still a Tensor.
    print("pred: ", {k:v[i].cpu() for k, v, in pred.items()})
    print("pred_results: ")
    for k in ["pair", "toak", "foak", "sofs", "flush"]:
      # .numpy() converts a CPU tensor into an ndarray of the same dtype.
      # torch.sigmoid, torch.nn.functional.sigmoid and the torch.nn.Sigmoid module
      # all map values into the range 0-1; the sigmoid turns each logit into a probability.
      print("  {}: {}".format(k, torch.nn.functional.sigmoid(
          pred[k][i]).cpu().numpy()))
    # The index of the largest type logit selects the predicted HANDS member.
    print("  hand type: {}".format(HAND_NAMES[torch.argmax(pred["type"][i], dim=-1)]))


# model.load_state_dict(torch.load("/tmp/model.pt"))
# model = model.eval()

# batch = mds1k[:10]
# with torch.no_grad():
#     pred, loss, losses = model(batch)
#     for i in range(len(batch)):
#         print("hand: ", batch[i]["hand"])
#         print("pred: ", {k:v[i].cpu() for k, v, in pred.items()})
#         print("pred results:")
#         for k in ["pair", "toak", "foak", "sofs", "flush"]:
#             print("  {}: {}".format(k, torch.nn.functional.sigmoid(pred[k][i]).cpu().numpy()))
#         print("  hand type: {}".format(HAND_NAMES[torch.argmax(pred["type"][i], dim=-1)]))

# Confusion matrix: [[TP, FN], [FP, TN]]
def binary_confusion(preds, labels):
    """Count the confusion-matrix cells for one binary head.

    A logit above 0.0 counts as a positive prediction. Returns a 2x2 integer
    numpy array laid out as [[TP, FN], [FP, TN]].
    """
    predicted_pos = preds > 0.0
    predicted_neg = preds <= 0.0
    actual_neg = torch.logical_not(labels)

    def count(prediction_mask, label_mask):
        # Move the count to the CPU before converting it for numpy.
        return torch.logical_and(prediction_mask, label_mask).sum().cpu().numpy()

    return np.array([[count(predicted_pos, labels), count(predicted_neg, labels)],
                     [count(predicted_pos, actual_neg), count(predicted_neg, actual_neg)]], dtype=int)

def type_confusion(preds, labels):
    """Build a 9x9 confusion matrix for the hand-type head.

    `preds` holds the type logits and `labels` the true type indices. In the
    returned integer array, rows are true types and columns are predicted types.
    """
    predicted = torch.argmax(preds, dim=-1).flatten()
    actual = torch.flatten(labels)
    # Distinct (predicted, actual) pairs together with how often each occurs.
    pairs, counts = torch.unique(torch.stack([predicted, actual], dim=1), dim=0, return_counts=True)
    matrix = np.zeros([9, 9], dtype=int)
    rows = pairs[:, 1].cpu().numpy()
    cols = pairs[:, 0].cpu().numpy()
    matrix[rows, cols] = counts.cpu().numpy()
    return matrix

def eval_accuracy(model, dataset):
    """Accumulate confusion matrices over `dataset` and print per-head metrics.

    For each binary head it prints precision, recall, accuracy and the share of
    positive labels. For the type head it prints the overall accuracy followed by
    the accuracy and label share of each hand type.
    """
    loader = DataLoader(dataset, batch_size=16, collate_fn=collate_data)
    # Binary head name -> key of its per-card labels in each sample.
    head_to_label = {"pair": "is_in_pair", "toak": "is_in_toak", "foak": "is_in_foak",
                     "sofs": "is_start_of_straight", "flush": "is_in_flush"}
    confusion = {head: np.zeros([2, 2], dtype=int) for head in head_to_label}
    confusion["type"] = np.zeros([9, 9], dtype=int)

    with torch.no_grad():
        with tqdm(loader) as progress:
            for batch in progress:
                pred, _, _ = model(batch)
                # Boolean labels for every entry except the hand and the enum label.
                labels = {}
                for key in batch[0].keys():
                    if key != "hand" and key != "hand_type":
                        labels[key] = torch.stack(
                            [torch.tensor(sample[key], dtype=torch.bool) for sample in batch], dim=0).to(DEVICE)
                labels["hand_type"] = torch.stack(
                    [torch.tensor(sample["hand_type"].value) for sample in batch]).to(DEVICE)  # [B]
                for head, label_key in head_to_label.items():
                    confusion[head] += binary_confusion(pred[head], labels[label_key])
                confusion["type"] += type_confusion(pred["type"], labels["hand_type"])

    for name, matrix in confusion.items():
        if name == "type":
            print("{}: accuracy {:.2f}%".format(name, np.diag(matrix).sum() / matrix.sum() * 100))
            for i, hand_type in enumerate(HANDS):
                print("    {:15s}: accuracy {:.2f}% labels {:.2f}%".format(
                    hand_type.name, matrix[i, i] / matrix[i].sum() * 100, matrix[i].sum() / matrix.sum() * 100))
            continue
        # Binary matrices are laid out as [[TP, FN], [FP, TN]].
        tp, fn = matrix[0, 0], matrix[0, 1]
        fp, tn = matrix[1, 0], matrix[1, 1]
        print("{}: precision {:.2f}% recall {:.2f}% accuracy {:.2f}% positive labels {:.2f}%".format(
            name, tp / (tp + fp) * 100, tp / (tp + fn) * 100,
            (tp + tn) / matrix.sum() * 100, (tp + fn) / matrix.sum() * 100))

# Print the per-head metrics for the trained model on the mined evaluation set.
eval_accuracy(model, mds1k)
