<a href="https://colab.research.google.com/github/Nikhil-Khetani/Dino-RL/blob/main/train_GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pygame

Collecting pygame
[?25l  Downloading https://files.pythonhosted.org/packages/01/da/4ff439558641a26dd29b04c25947e6c0ace041f56b2aa2ef1134edab06b8/pygame-2.0.1-cp36-cp36m-manylinux1_x86_64.whl (11.8MB)
[K     |████████████████████████████████| 11.8MB 271kB/s 
[?25hInstalling collected packages: pygame
Successfully installed pygame-2.0.1


In [6]:
import cv2
import time 
import os, sys
import game
import math
import random
import numpy as np
from collections import namedtuple
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import pygame
import time

In [3]:
os.environ["SDL_VIDEODRIVER"] = "dummy"

In [7]:



DISPLAY_HEIGHT=400
DISPLAY_WIDTH=400
STATE_HEIGHT = DISPLAY_HEIGHT-1
STATE_WIDTH = DISPLAY_WIDTH-1

pygame.init()

image_size=84
batch_size=32
lr=1e-6
gamma=0.99
initial_epsilon=0.1
final_epsilon=1e-4
num_iters=2000000
replay_memory_size=50000

In [8]:

Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))

def pre_processing(image, width, height):
    image = cv2.cvtColor(cv2.resize(image, (width, height)), cv2.COLOR_BGR2GRAY)
    _, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
    return image[None, :, :].astype(np.float32)


In [9]:

class DeepQNetwork(nn.Module):
    def __init__(self):
        super(DeepQNetwork, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))

        self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))
        self.fc2 = nn.Linear(512, 2)
        self._create_weights()

    def _create_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.uniform(m.weight, -0.01, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, input):
        output = self.conv1(input)
        output = self.conv2(output)
        output = self.conv3(output)
        output = output.view(output.size(0), -1)
        output = self.fc1(output)
        output = self.fc2(output)

        return output

In [11]:

def train(episodes):
    real_time_start = time.time()
    CPU_time_start = time.process_time()
    model = DeepQNetwork()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
    criterion = torch.nn.MSELoss()
    current_game = game.DinoGame(None,400,400)
    image, reward, endgame = current_game.nextframe(0)
    
    image = pre_processing(image, 84,84)
    #image = torch.from_numpy(np.expand_dims(np.transpose(image,(2,0,1)),0))
    image = torch.from_numpy(image)
    if torch.cuda.is_available():
        model.cuda()
        image = image.cuda()
    #image=image.float()
    state=torch.cat(tuple(image for _ in range (4)))[None, :, :, :]
    #print(state.shape)
    replay_memory = []
    episode = 0
    checkpoint_real_time_elapsed = 0
    checkpoint_CPU_time_elapsed = 0


    for i in reversed(range(1000)):
      print(i)
      checkpoint_path = "checkpoint_{}.pt".format(i*1000)
      if os.path.exists(checkpoint_path):
          checkpoint = torch.load(checkpoint_path,map_location=torch.device('cpu'))
          model.load_state_dict(checkpoint['model_state_dict'])
          optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
          episode = checkpoint['checkpoint_episode']
          loss = checkpoint['loss']
          checkpoint_real_time_elapsed = checkpoint['real_time_elapsed']
          checkpoint_CPU_time_elapsed = checkpoint['CPU_time_elapsed']
          print("checkpoint loaded")
          print("checkpoint episode:" + str(episode))
          break
    
    while episode<episodes:
        pred = model(state)[0]
        epsilon = final_epsilon+((episodes-episode)*(initial_epsilon-final_epsilon)/episodes)
        take_random_action = random.random()<=epsilon
        if take_random_action:
            print('random')
            action = random.randint(0,1)
        else:
            action=torch.argmax(pred)

        next_image, reward, endgame = current_game.nextframe(action)
        #next_image = torch.from_numpy(np.expand_dims(np.transpose(next_image,(2,0,1)),0))
        #next_image = next_image.float()
        #next_state = torch.cat((state.squeeze(0)[3:,:,:], next_image))

        next_image = pre_processing(next_image, 84,84)

        #action = action.unsqueeze(0)
        #reward = torch.from_numpy(np.array([reward],dtype=np.float32)).unsqueeze(0)
        next_image = torch.from_numpy(next_image)
        if torch.cuda.is_available():
            next_image = next_image.cuda()
        next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]

        replay_memory.append([state, action, reward, next_state, endgame])
        if len(replay_memory) > replay_memory_size:
            del replay_memory[0]
        batch = random.sample(replay_memory, min(len(replay_memory), batch_size))
        state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)
        state_batch = torch.cat(tuple(state for state in state_batch))
        action_batch = torch.from_numpy(
            np.array([[1, 0] if action == 0 else [0, 1] for action in action_batch], dtype=np.float32))
        reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.cat(tuple(state for state in next_state_batch))
        if torch.cuda.is_available():
            state_batch = state_batch.cuda()
            action_batch = action_batch.cuda()
            reward_batch = reward_batch.cuda()
            next_state_batch = next_state_batch.cuda()
        
        current_prediction_batch = model(state_batch)
        next_prediction_batch = model(next_state_batch)

        y_batch = torch.cat(
            tuple(reward if terminal else reward + gamma * torch.max(prediction) for reward, terminal, prediction in
                  zip(reward_batch, terminal_batch, next_prediction_batch)))
        q_value = torch.sum(current_prediction_batch * action_batch, dim=1)
        optimizer.zero_grad()
        loss = criterion(q_value, y_batch)
        loss.backward()
        optimizer.step()
        state = next_state
        
        if episode %50 == 0:
            print("Episode: {}/{}, Action: {}, Loss: {}, Epsilon {}, Reward: {}, Q-value: {}".format(
            episode + 1, episodes, action, loss, epsilon, reward, torch.max(pred)))
        episode+=1

        if episode % 5000 == 0:
            real_time_elapsed = time.time()-real_time_start + checkpoint_real_time_elapsed
            CPU_time_elapsed =time.process_time()-CPU_time_start + checkpoint_CPU_time_elapsed
            print("Real time elapsed : {}, CPU time elapsed : {}".format(real_time_elapsed,CPU_time_elapsed))
            checkpoint_path = "checkpoint_{}.pt".format(episode)
            torch.save({
            'checkpoint_episode': episode,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss':loss,
            'real_time_elapsed' : real_time_elapsed,
            'CPU_time_elapsed' : CPU_time_elapsed
            }, checkpoint_path)


train(1000000)

999
998
997
996
995
994
993
992
991
990
989
988
987
986
985
984
983
982
981
980
979
978
977
976
975
974
973
972
971
970
969
968
967
966
965
964
963
962
961
960
959
958
957
956
955
954
953
952
951
950
949
948
947
946
945
944
943
942
941
940
939
938
937
936
935
934
933
932
931
930
929
928
927
926
925
924
923
922
921
920
919
918
917
916
915
914
913
912
911
910
909
908
907
906
905
904
903
902
901
900
899
898
897
896
895
894
893
892
891
890
889
888
887
886
885
884
883
882
881
880
879
878
877
876
875
874
873
872
871
870
869
868
867
866
865
864
863
862
861
860
859
858
857
856
855
854
853
852
851
850
849
848
847
846
845
844
843
842
841
840
839
838
837
836
835
834
833
832
831
830
829
828
827
826
825
824
823
822
821
820
819
818
817
816
815
814
813
812
811
810
809
808
807
806
805
804
803
802
801
800
799
798
797
796
795
794
793
792
791
790
789
788
787
786
785
784
783
782
781
780
779
778
777
776
775
774
773
772
771
770
769
768
767
766
765
764
763
762
761
760
759
758
757
756
755
754
753
752
751
750


error: display Surface quit