In [1]:
import pygame
import numpy as np
import torch		
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

pygame 2.0.0.dev6 (SDL 2.0.10, python 3.8.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class Net(nn.Module):
    def __init__(self):
        # get init from parent class of Net
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(19, 100, 10) # 3: 3 channels, 6: num of filters, 5: dem of filters
        self.banorm1 = nn.BatchNorm1d(10)
        self.pool = nn.MaxPool1d(2,stride=2)
        self.conv2 = nn.Conv1d(100, 265, 3)
        self.banorm2 = nn.BatchNorm1d(3)
        self.convP = nn.Conv1d(265,1,1)
        self.convV = nn.Conv1d(265,1,1)
        self.banorm = nn.BatchNorm1d(1)
        self.fcp = nn.Linear(1, 19*19)
        self.fcv0 = nn.Linear(1, 500)
        self.fcv1 = nn.Linear(500, 300)
        self.fcv2 = nn.Linear(300, 150)
        self.fcv3 = nn.Linear(300, 100)
        self.fcv4 = nn.Linear(100, 50)
        self.fcv5 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.tensor(x)
        x = x.reshape(1,19*19)
        x = x.reshape(19,19)
        x = self.conv1(x)
        x = self.pool(F.relu(self.banorm1(x)))
        x = self.conv2(x)
        x = self.pool(F.relu(self.banorm2(x)))
        p = self.convP(x)
        p = F.softmax(self.fcp(p))
        p = p.reshape(19*19)
        v = self.convV(x)
        v = F.relu(self.fcv0(v))
        v = F.relu(self.fcv1(v))
        v = F.relu(self.fcv3(v))
        v = F.relu(self.fcv4(v))
        v = F.sigmoid(self.fcv5(v))
        v = v.reshape(1)
        return [p,v]

In [3]:
PATH = './T100.pth'
net = Net()
net.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [4]:
# COLOR (RED, GREEN, BLUE)
WHITE = (255,255,255)
BLACK = (0,0,0)
BOTTOM = (100,100,100)
BG_COLOR = (238,154,73)
# WIDOWS AND SCREEN
BOARD_DEM = 19-1
DIST = 50
GRID = 50
WindowLen = BOARD_DEM*GRID+2*DIST 

gap = WindowLen/3
dist = WindowLen/6
center = WindowLen/2 +50

# functions
def show_first_page(window):
    window.fill(BG_COLOR)
    gap = WindowLen/3
    dist = WindowLen/6
    center = WindowLen/2 +50
    # defining a font
    smallfont = pygame.font.SysFont('Corbel',40)
    version1 = smallfont.render('1 PLAYER' , True , WHITE)
    version2 = smallfont.render('2 PLAYER' , True , WHITE)
    Q = smallfont.render('Quit' , True , WHITE)
    pygame.draw.rect(window,BOTTOM,[center-150,dist,300,dist])
    pygame.draw.rect(window,BOTTOM,[center-150,2*dist+gap,300,dist])
    pygame.draw.rect(window,BOTTOM,[10,10,80,50])
    window.blit(version1, (center,1.5*dist))
    window.blit(version2, (center,2.5*dist+gap))
    window.blit(Q, (30,30))
    pygame.display.update()
    clicked  = False
    while not clicked:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
        x,y = pygame.mouse.get_pos()
        key = pygame.mouse.get_pressed()
        if key[0]:
            if center-150 < x < center+150 and dist < y < 2*dist:
                show_board(window)
                return 1
            if center-150 < x < center+150 and 2*dist+gap < y < 3*dist+gap:
                show_board(window)
                return 2
            if 10 < x < 90 and 10 < y < 60:
                pygame.quit()

def show_board(window,VisitedPos=[]):
    # Draw the board
    window.fill(BG_COLOR)
    LINE_COLOR = (0,0,0) # black
    smallfont = pygame.font.SysFont('Corbel',40)
    Q = smallfont.render('Quit' , True , WHITE)
    window.blit(Q, (WindowLen+30,10))
    pygame.draw.rect(window,BOTTOM,[WindowLen+10,10,50,50])
    pygame.draw.line(window,LINE_COLOR,[DIST,DIST],[WindowLen-DIST,DIST],4)
    pygame.draw.line(window,LINE_COLOR,[DIST,DIST],[DIST,WindowLen-DIST],4)
    pygame.draw.line(window,LINE_COLOR,[DIST,WindowLen-DIST],[WindowLen-DIST,WindowLen-DIST],4)
    pygame.draw.line(window,LINE_COLOR,[WindowLen-DIST,WindowLen-DIST],[WindowLen-DIST,DIST],4)
    for i in range (1,18):
        pygame.draw.line(window,LINE_COLOR,[DIST,DIST+i*GRID],[WindowLen-DIST,DIST+i*GRID],2)
    for j in range (1,18):
        pygame.draw.line(window,LINE_COLOR,[DIST+j*GRID,DIST],[DIST+j*GRID,WindowLen-DIST],2)
    pygame.draw.circle(window,LINE_COLOR,[DIST+GRID*(BOARD_DEM/2),DIST+GRID*(BOARD_DEM/2)],8)
    if len(VisitedPos) != 0:
        for ii in range (len(VisitedPos)):
            if (ii+1)%2 != 0: # black
                pygame.draw.circle(window,BLACK,VisitedPos[ii],20)
            else:
                pygame.draw.circle(window,WHITE,VisitedPos[ii],20)
    pygame.display.update()
    

def find_legal_position(x,y):
    BoardXPosition = x - DIST
    BoardYPosition = y - DIST
    Row = int(BoardXPosition/GRID)
    RowR = BoardXPosition%GRID
    Column = int(BoardYPosition/GRID)
    ColumnR = BoardYPosition%GRID
    if RowR > (GRID/2):
        Row = Row + 1
    if ColumnR > (GRID/2):
        Column = Column + 1
    return [GRID*Row+DIST,GRID*Column+DIST] # windows position

def win(player,VisitedPos):
    numElement = len(VisitedPos)
    if player%2 != 0: # black
        StoneList = [0]* (int(numElement/2)+1)
    else:
        StoneList = [0]* int(numElement/2)
    if player%2 != 0: #black
        for item in range (len(VisitedPos)):
            if item%2 == 0:
                StoneList[int(item/2)] = VisitedPos[item]
    else:
        for item in range (len(VisitedPos)):
            if item%2 != 0:
                StoneList[int(item/2)] = VisitedPos[item]
    
    for i in StoneList:
        xposition = i[0]
        yposition = i[1]
        # step1: check horizontally
        count = 0
        connection = 1
        x = xposition
        connectlist = [i]
        while x < DIST+GRID*BOARD_DEM and count < 4:
            count = count + 1
            if [x+GRID,yposition] in StoneList:
                connection = connection + 1
                connectlist.append([x+GRID,yposition])
            x = x + GRID
        if connection == 5:
            return [True,connectlist] # black win
        
        # step2: check vertically
        count = 0
        connection = 1
        y = yposition
        connectlist = [i]
        while y < DIST+GRID*BOARD_DEM and count < 4:
            count = count + 1
            if [xposition,y+50] in StoneList:
                connection = connection + 1
                connectlist.append([xposition,y+50])
            y = y + 50
        if connection == 5:
            return [True,connectlist]
        
        # step3: check diagnally
        count = 0
        connection = 1
        x = xposition
        y = yposition
        connectlist = [i]
        while y < DIST+GRID*BOARD_DEM and x < DIST+GRID*BOARD_DEM and count < 4:
            count = count + 1
            if [x+50,y+50] in StoneList:
                connection = connection + 1
                connectlist.append([x+50,y+50])
            y = y + 50
            x = x + 50
        if connection == 5:
            return [True,connectlist]
        
        # step3: check diagnally
        count = 0
        connection = 1
        x = xposition
        y = yposition
        connectlist = [i]
        while DIST <= x and y <= DIST+GRID*BOARD_DEM and count < 4:
            count = count + 1
            if [x-50,y+50] in StoneList:
                connection = connection + 1
                connectlist.append([x-50,y+50])
            y = y + 50
            x = x - 50
        if connection == 5:
            return [True,connectlist]
    return [False,[]]

def get_board(VisitedPos):
  StoneList = [0] * (19*19)
  for item in range (len(VisitedPos)):
    pos = VisitedPos[item]
    position = widows_to_position(pos)
    if item%2 == 0:
      StoneList[position] = 1.
    else:
      StoneList[position] = -1.
  return StoneList

def position_to_widows(pos):
  pos = int(pos)
  return [DIST+(pos%19)*GRID,DIST+int((pos/19))*GRID]

def widows_to_position(pos):
    x = pos[0]
    y = pos[1]
    x = x - DIST
    y = y - DIST
    return int((y/50)*19 + (x/50))

pygame.init()
window = pygame.display.set_mode((WindowLen+100,WindowLen)) #length,width
pygame.display.set_caption('ConnectFive')
window.fill(BG_COLOR)
pygame.display.flip() # check new update
running = True # play a game
while running:
    # initial state
    NumOfPlayer = show_first_page(window)
    pygame.time.delay(500)
    VisitedPos = []
    # version 1: 1 player (TBD, Now is for 2-players only)
    if NumOfPlayer == 1:
        # start with black (odd player)
        player = 1 
        # when the game not end
        end = False
        while not end:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
            # when no position has selected
            clicked = False
            while not clicked:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        pygame.quit()
                # initialize the current state to show the squares
                # when there is a position in VisitedPosition show the stones
                show_board(window,VisitedPos)
                # get intensive position
                # current mouse's position, if not clicked will update as the movements as the mouse
                x,y = pygame.mouse.get_pos()
                key = pygame.mouse.get_pressed()
                if key[0] and WindowLen+10 < x < WindowLen+60  and 10 < y < 60:
                    pygame.quit()
                pos  = find_legal_position(x,y)
                if DIST <= pos[0] <= DIST+GRID*BOARD_DEM and DIST <= pos[1] <= DIST+GRID*BOARD_DEM:
                    if pos in VisitedPos:
                        pygame.draw.rect(window,(255,0,0),[pos[0]-25,pos[1]-25,50,50],2) # cannot put
                    # pygame.display.update()
                    else:
                        pygame.draw.rect(window,(0,0,255),[pos[0]-25,pos[1]-25,50,50],2) # can put
                else:
                    pygame.draw.rect(window,(255,0,0),[pos[0]-25,pos[1]-25,50,50],2)
                    # pygame.display.update()
                if key[0] == True: # clicked
                    if DIST <= pos[0] <= DIST+GRID*BOARD_DEM and DIST <= pos[1] <= DIST+GRID*BOARD_DEM:
                        if pos in VisitedPos: # cannot put here
                            pygame.draw.rect(window,(255,0,0),[pos[0]-25,pos[1]-25,50,50],2) # 1: empty
                            # pygame.display.update()
                        else: # can put here
                            clicked = True
                            VisitedPos.append(pos)
                            is_end = win(player,VisitedPos)
                            if is_end[0]:
                                show_board(window,VisitedPos)
                                connectlist = is_end[1]
                                for item in connectlist:
                                    pygame.draw.rect(window,(255,0,0),[item[0]-25,item[1]-25,50,50],2)
                                smallfont = pygame.font.SysFont('Corbel',50)
                                Q = smallfont.render('Player ' + str(player%2) + " wins!" , True , (123,245,112))
                                window.blit(Q, (WindowLen/2+50,WindowLen/2))
                                pygame.display.update()
                                pygame.time.delay(10000)
                                end = True
                            else:
                                show_board(window,VisitedPos)
                                board = get_board(VisitedPos)
                                p = net(board)[0]
                                NextMove = torch.argmax(p)
                                pos = position_to_widows(NextMove)
                                while pos in VisitedPos:
                                    p = torch.cat([p[0:NextMove], p[NextMove+1:]])
                                    NextMove = torch.argmax(p)
                                    pos = position_to_widows(NextMove)
                                VisitedPos.append(pos)
                                is_end = win(player+1,VisitedPos)
                                if is_end[0]:
                                  show_board(window,VisitedPos)
                                  connectlist = is_end[1]
                                  for item in connectlist:
                                      pygame.draw.rect(window,(255,0,0),[item[0]-25,item[1]-25,50,50],2)
                                  smallfont = pygame.font.SysFont('Corbel',50)
                                  Q = smallfont.render('Player ' + str(player%2) + " wins!" , True , (123,245,112))
                                  window.blit(Q, (WindowLen/2+50,WindowLen/2))
                                  pygame.display.update()
                                  pygame.time.delay(10000)
                                  end = True
                                else:
                                  player += 2
                    else:
                        pygame.draw.rect(window,(255,0,0),[pos[0]-25,pos[1]-25,50,50],2) # 1: empty
                pygame.display.update()
    if NumOfPlayer == 2:
        # start with black (odd player)
        player = 1 
        # when the game not end
        end = False
        while not end:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
            # when no position has selected
            clicked = False
            while not clicked:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        pygame.quit()
                # initialize the current state to show the squares
                # when there is a position in VisitedPosition show the stones
                show_board(window,VisitedPos)
                # get intensive position
                # current mouse's position, if not clicked will update as the movements as the mouse
                x,y = pygame.mouse.get_pos()
                key = pygame.mouse.get_pressed()
                if key[0] and WindowLen+10 < x < WindowLen+60  and 10 < y < 60:
                    pygame.quit()
                pos  = find_legal_position(x,y)
                if DIST <= pos[0] <= DIST+GRID*BOARD_DEM and DIST <= pos[1] <= DIST+GRID*BOARD_DEM:
                    if pos in VisitedPos:
                        pygame.draw.rect(window,(255,0,0),[pos[0]-25,pos[1]-25,50,50],2) # cannot put
                    # pygame.display.update()
                    else:
                        pygame.draw.rect(window,(0,0,255),[pos[0]-25,pos[1]-25,50,50],2) # can put
                else:
                    pygame.draw.rect(window,(255,0,0),[pos[0]-25,pos[1]-25,50,50],2)
                    # pygame.display.update()
                if key[0] == True: # clicked
                    if DIST <= pos[0] <= DIST+GRID*BOARD_DEM and DIST <= pos[1] <= DIST+GRID*BOARD_DEM:
                        if pos in VisitedPos: # cannot put here
                            pygame.draw.rect(window,(255,0,0),[pos[0]-25,pos[1]-25,50,50],2) # 1: empty
                            # pygame.display.update()
                        else: # can put here
                            clicked = True
                            VisitedPos.append(pos)
                            is_end = win(player,VisitedPos)
                            if is_end[0]:
                                show_board(window,VisitedPos)
                                connectlist = is_end[1]
                                for item in connectlist:
                                    pygame.draw.rect(window,(255,0,0),[item[0]-25,item[1]-25,50,50],2)
                                smallfont = pygame.font.SysFont('Corbel',50)
                                Q = smallfont.render('Player ' + str(player%2) + " wins!" , True , (123,245,112))
                                window.blit(Q, (WindowLen/2+50,WindowLen/2))
                                pygame.display.update()
                                pygame.time.delay(10000)
                                end = True
                            else:
                                player += 1
                    else:
                        pygame.draw.rect(window,(255,0,0),[pos[0]-25,pos[1]-25,50,50],2) # 1: empty
                pygame.display.update()

  pygame.draw.rect(window,BOTTOM,[center-150,dist,300,dist])
  pygame.draw.rect(window,BOTTOM,[center-150,2*dist+gap,300,dist])
  window.blit(version1, (center,1.5*dist))
  window.blit(version2, (center,2.5*dist+gap))
  pygame.draw.circle(window,LINE_COLOR,[DIST+GRID*(BOARD_DEM/2),DIST+GRID*(BOARD_DEM/2)],8)
  p = F.softmax(self.fcp(p))
  window.blit(Q, (WindowLen/2+50,WindowLen/2))
  window.blit(Q, (WindowLen/2+50,WindowLen/2))


error: video system not initialized