In [5]:
import numpy as np
import time

In [6]:
class EE21B093_Q4:
    def __init__(self, num_balls, num_matches):
        self.num_balls = num_balls
        self.num_wickets = 4
        self.num_matches = num_matches

        self.q_values = np.zeros(
            (self.num_balls, self.num_wickets, 6), dtype=np.float32
        )
        self.balls = 0
        self.wickets = 0

        self.alpha = 0.1
        self.gamma = 0.9
        self.epsilon = 0.1

        self.policy = "e-greedy"
        self.algo = "q-learning"

        self.last_action = None
        self.next_action = None

    def policy_action(self):
        q_values = self.q_values[self.balls][self.wickets]

        if self.policy == "e-greedy":
            p = np.random.rand()
            if p < self.epsilon:
                return np.random.randint(0, 6)
            else:
                return np.argmax(q_values)
        elif self.policy == "softmax":
            vals = np.exp(q_values)

            if np.sum(vals) == 0:
                return np.random.randint(0, 6)

            probs = vals / np.sum(vals)
            return np.random.choice(range(6), p=probs)

    def get_action(self, wicket, runs_scored):
        if self.balls == 0:
            self.last_action = self.policy_action()
            self.balls += 1
            return self.last_action

        self.balls += 1
        self.wickets += wicket

        if self.balls == self.num_balls or self.wickets == self.num_wickets:
            self.balls = 0
            self.wickets = 0
            self.last_action = None
            self.next_action = None

            self.last_action = self.policy_action()
            self.balls += 1
            return self.last_action

        q_values_old = self.q_values[self.balls - 1][self.wickets - wicket]
        q_values_new = self.q_values[self.balls][self.wickets]

        new_q_values = 0
        if self.algo == "q-learning":
            new_q_values = (
                runs_scored
                + self.gamma * np.max(q_values_new)
                - q_values_old[self.last_action]
                - 0.5
            )
        elif self.algo == "sarsa":
            self.next_action = self.policy_action()
            new_q_values = (
                runs_scored
                + self.gamma * q_values_new[self.next_action]
                - q_values_old[self.last_action]
                - 0.5
            )

        q_values_old[self.last_action] += self.alpha * new_q_values

        if self.algo == "q-learning":
            self.next_action = self.policy_action()

        self.last_action = self.next_action
        return self.last_action

In [7]:
wickets = 4


class Environment:
    def __init__(self, num_balls, agent):
        self.num_balls = num_balls
        self.agent = agent
        self.__run_time = 0
        self.__total_runs = 0
        self.__total_wickets = 0
        self.__runs_scored = 0
        self.__start_time = 0
        self.__end_time = 0
        self.__p_out = np.array([0.001, 0.01, 0.02, 0.03, 0.1, 0.3])
        self.__p_run = np.array([1, 0.9, 0.85, 0.8, 0.75, 0.7])
        self.__action_runs_map = np.array([0, 1, 2, 3, 4, 6])
        self.__wickets_left = wickets
        self.__wicket = 0
        self.__runs_scored = 0
        self.__start_time = 0
        self.__end_time = 0
        self.__batting_order = np.array([0, 1, 2, 3])

    def __get_action(self):
        self.__start_time = time.time()
        action = self.agent.get_action(self.__wicket, self.__runs_scored)
        self.__end_time = time.time()
        self.__run_time = self.__run_time + self.__end_time - self.__start_time
        return action

    def __get_outcome(self, action):
        pout = self.__p_out[action]
        prun = self.__p_run[action]
        wicket = np.random.choice(2, 1, p=[1 - pout, pout])[0]
        runs = 0
        if wicket == 0:
            runs = (
                self.__action_runs_map[action]
                * np.random.choice(2, 1, p=[1 - prun, prun])[0]
            )
        return wicket, runs

    def innings(self):
        self.__wickets_left = wickets
        self.__runs_scored = 0
        self.__total_runs = 0
        self.__total_wickets = 0
        self.__run_time = 0
        self.__start_time = 0
        self.__end_time = 0

        for ball in range(self.num_balls):
            if self.__wickets_left > 0:
                action = self.__get_action()
                self.__wicket, self.__runs_scored = self.__get_outcome(action)
                self.__total_runs = self.__total_runs + self.__runs_scored
                if self.__wicket > 0:
                    self.__wickets_left = self.__wickets_left - 1
                self.__total_wickets = self.__total_wickets + self.__wicket
                if self.__wickets_left == 0:
                    self.__get_action()
        return self.__total_runs, self.__total_wickets, self.__run_time

In [8]:
num_matches = 1000
num_balls = 60
agent = EE21B093_Q4(num_balls, num_matches)
environment = Environment(num_balls, agent)
score = np.zeros((num_matches, 1))
run_time = np.zeros((num_matches, 1))
wicket = np.zeros((num_matches, 1))

last_100_avgs = []
for i in range(num_matches):
    score[i], wicket[i], run_time[i] = environment.innings()
    last_100_avgs.append(score[i])
    if (i + 1) % 100 == 0:
        print("Match: ", i + 1, "Average: ", np.mean(last_100_avgs))
        last_100_avgs = []

Match:  100 Average:  76.3
Match:  200 Average:  78.92
Match:  300 Average:  80.72
Match:  400 Average:  79.89
Match:  500 Average:  79.9
Match:  600 Average:  76.59
Match:  700 Average:  76.07
Match:  800 Average:  77.1
Match:  900 Average:  80.68
Match:  1000 Average:  78.7
