In [52]:
import torch
from torch import nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out
    
class ResNet(nn.Module):
    def __init__(self, block, inital_channel = 7, num_blocks=20, num_classes=15):
        super(ResNet, self).__init__()
        self.in_channels = 256
        self.value_num = 256
        self.policy_num = 2
        self.classes = num_classes**2

        #initial block
        self.conv1 = nn.Conv2d(inital_channel, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU()

        #RNN Blocks
        self.layers = []
        for _ in range(num_blocks):
            self.layers.append(block(self.in_channels, self.in_channels))
        self.layers = nn.Sequential(*self.layers)

        # Policy Block
        self.policy_conv = nn.Conv2d(self.in_channels, self.policy_num, kernel_size=1, stride=1, padding=1, bias = False)
        self.policy_bn = nn.BatchNorm2d(self.policy_num)
        self.policy_relu = nn.ReLU()
        self.policy_linear = nn.Linear((self.policy_num+num_classes)**2*self.policy_num, self.classes)  # 수정된 부분

        # Value Block
        self.value_conv = nn.Conv2d(self.in_channels, 1, kernel_size=1, stride=1, padding = 1, bias=False)
        self.value_bn1 = nn.BatchNorm2d(1)
        self.value_linear = nn.Linear((self.policy_num+num_classes)**2, 1)  # 수정된 부분
        self.value_relu = nn.ReLU()
        self.value_output = nn.Tanh()


    def forward(self, x):
        # Initial block
        out = self.relu(self.bn1(self.conv1(x)))

        # Residual blocks
        out = self.layers(out)

        # Policy head
        policy = self.policy_relu(self.policy_bn(self.policy_conv(out)))
        policy = policy.view(policy.size(0), -1)  # 평탄화
        policy = self.policy_linear(policy)  # 최종 정책 출력

        # Value head
        value = self.value_relu(self.value_bn1(self.value_conv(out)))
        value = value.view(value.size(0), -1)  # 평탄화
        value = self.value_linear(value)
        value = self.value_relu(value)
        value = self.value_output(value)  # 최종 값 출력

        return policy, value

    

In [53]:
import numpy as np

model = ResNet(ResidualBlock)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

test_board = np.array([[[0 for _ in range (15)] for _ in range (15)] for _ in range (7)])
board_tensor = torch.tensor(test_board, dtype=torch.float32)


# 모델, 테스트 데이터를 GPU로 이동
board_tensor = board_tensor.to(device)
board_tensor = board_tensor.unsqueeze(0) 
model.to(device)

ResNet(
  (conv1): Conv2d(7, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (layers): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): ResidualBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1,

In [54]:
from torchsummary import summary
summary(model, input_size=(7, 15, 15))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 256, 15, 15]          16,128
       BatchNorm2d-2          [-1, 256, 15, 15]             512
              ReLU-3          [-1, 256, 15, 15]               0
            Conv2d-4          [-1, 256, 15, 15]         589,824
       BatchNorm2d-5          [-1, 256, 15, 15]             512
              ReLU-6          [-1, 256, 15, 15]               0
            Conv2d-7          [-1, 256, 15, 15]         589,824
       BatchNorm2d-8          [-1, 256, 15, 15]             512
              ReLU-9          [-1, 256, 15, 15]               0
    ResidualBlock-10          [-1, 256, 15, 15]               0
           Conv2d-11          [-1, 256, 15, 15]         589,824
      BatchNorm2d-12          [-1, 256, 15, 15]             512
             ReLU-13          [-1, 256, 15, 15]               0
           Conv2d-14          [-1, 256,

In [55]:
# 모델의 첫 번째 파라미터가 위치한 디바이스 확인
model_device = next(model.parameters()).device
print(model_device)

# board_tensor의 디바이스 확인
print(board_tensor.device)


cuda:0
cuda:0


In [56]:
model.eval()
with torch.no_grad():  # Gradient 계산 비활성화
    policy, value = model(board_tensor)

In [57]:
print(policy)
print(len(policy[0]))

tensor([[-1.7608e-03,  8.5580e-03, -4.0488e-03, -3.4897e-02,  5.1077e-02,
          3.1379e-02, -1.6601e-02,  2.1862e-02, -1.9335e-02, -3.0400e-02,
          5.4215e-02, -6.3216e-03, -7.9829e-03,  5.2556e-03,  2.8475e-02,
          2.8338e-02,  1.9801e-02,  5.4724e-03,  5.1513e-02, -2.1454e-02,
         -3.4921e-02, -3.8315e-02,  3.9937e-02, -5.0859e-02,  3.4419e-02,
         -3.2661e-02, -2.3469e-02, -3.6093e-02,  2.9734e-02, -1.9743e-02,
          4.1360e-02,  1.7285e-02, -1.0399e-02, -2.4012e-02, -1.0489e-02,
         -1.3591e-02, -4.9252e-02, -3.0738e-02, -3.7061e-02,  1.3450e-02,
         -2.6896e-02,  6.3587e-03, -3.1419e-02, -2.5669e-02, -2.8090e-02,
         -2.4290e-02,  1.8134e-02,  2.2544e-02, -5.7966e-02,  3.7540e-02,
          4.3163e-02, -2.7966e-02,  8.6479e-03,  3.5397e-02, -3.1726e-02,
         -2.2607e-02, -3.1928e-02, -2.8477e-02,  2.6253e-02,  2.6936e-02,
          2.8879e-02,  2.1345e-02, -4.0308e-02,  2.1084e-02, -6.0633e-02,
          1.9520e-02,  1.9328e-02, -1.

In [58]:
print(value)
print(len(value[0]))

tensor([[0.0027]], device='cuda:0')
1
