In [1]:
import torch

In [12]:
device = torch.device("cuda") if torch.cuda.is_available() else (torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
device

device(type='mps')

In [76]:
# Representation Neural Network
from torch.nn import Module, ModuleList
from torch.nn import Conv2d, Linear
from torch.nn import ReLU, BatchNorm2d

class ConvolutionalBlock(Module):

    def __init__(self, in_channels, n_channels = 256, kernel_size = (3,3), stride=(1,1), padding=(1,1)):
        super(ConvolutionalBlock, self).__init__()
        self.convo = Conv2d(in_channels=in_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, stride=stride)
        self.norm = BatchNorm2d(num_features=n_channels)
        self.act = ReLU()

    def forward(self, x):
        return self.act(self.norm(self.convo(x)))
    
x = torch.randn((1,1,3,4), device=device)
convo_block = ConvolutionalBlock(1).to(device=device)

print(x.shape, convo_block(x).shape)

torch.Size([1, 1, 3, 4]) torch.Size([1, 256, 3, 4])


In [66]:
class ResidualBlock(Module):

    def __init__(self, in_channels, n_channels = 256, kernel_size = (3,3), stride=(1,1)):
        super(ResidualBlock, self).__init__()
        self.convo1 = ConvolutionalBlock(in_channels=in_channels, n_channels=n_channels, kernel_size=kernel_size, stride=stride)
        self.convo2 = Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=(1,1), stride=stride)
        self.norm = BatchNorm2d(num_features=n_channels)
        self.act = ReLU()

    def forward(self, x):
        skip_conn = x
        x = self.convo1(x)
        x = self.norm(self.convo2(x))
        x = x + skip_conn
        x = self.act(x)
        return x

x = torch.randn((1,1,3,4), device=device)
convo_block = ResidualBlock(1).to(device=device)

print(x.shape, convo_block(x).shape)

torch.Size([1, 1, 3, 4]) torch.Size([1, 256, 3, 4])


In [67]:
class GenericResidualNetwork(Module):

    def __init__(self, in_channels, n_channels = 256, n_layers=10, kernel_size = (3,3)):
        super(GenericResidualNetwork, self).__init__()

        # For now, simple, no downscale
        self.input_layer = ResidualBlock(in_channels=in_channels, n_channels=n_channels, kernel_size=kernel_size)

        self.residuals = ModuleList([ResidualBlock(in_channels=n_channels, n_channels=n_channels, kernel_size=kernel_size) for _ in range(n_layers-1)])

    def forward(self, x):

        x = self.input_layer(x)
        for res_layer in self.residuals:
            x = res_layer(x)
        return x
    
x = torch.randn((1,1,6,7), device=device)
convo_block = GenericResidualNetwork(1).to(device=device)

x = convo_block(x)
x.shape

torch.Size([1, 256, 6, 7])

In [None]:
class PolicyPredictor(Module):

    def __init__(self, in_channels, num_conv = 2 ):
        
        self.conv1 = Conv2d() 

In [70]:
class ContinousValuePredictor(Module):

    def __init__(self, in_channels, board_total_slots, n_outputs, n_convs = 2):
        super(ContinousValuePredictor, self).__init__()

        self.convos = ModuleList()
        
        for _ in range(n_convs):
            out_channels = in_channels // 2
            convo = ConvolutionalBlock(in_channels=in_channels, n_channels=out_channels, kernel_size=(1,1), padding=(0,0))
            self.convos.append(convo)
            in_channels = out_channels

        self.output_size = board_total_slots * out_channels
        self.linear = Linear(self.output_size, n_outputs)

    def forward(self, x):

        for conv in self.convos:
            x = conv(x)

        x = x.view(-1, self.output_size)
        return self.linear(x)       

x = torch.randn((1,256,6,7), device=device)
convo_block = ContinousValuePredictor(256, 42, 601).to(device=device)

x = convo_block(x)
x.shape 

torch.Size([1, 601])

In [73]:
class DynamicsNetwork(Module):

    def __init__(self, in_channels, board_total_slots, n_convs = 2, n_channels = 256, n_residual_layers=10, kernel_size = (3,3)):
        super(DynamicsNetwork, self).__init__()

        self.first_net = GenericResidualNetwork(in_channels=in_channels, n_channels=n_channels, n_layers=n_residual_layers, kernel_size=kernel_size)
        self.reward_predictor = ContinousValuePredictor(in_channels=n_channels, board_total_slots=board_total_slots, n_outputs=601, n_convs=n_convs)

    def forward(self, x):

        return self.reward_predictor(self.first_net(x))
    

x = torch.randn((1,256,6,7), device=device)
convo_block = DynamicsNetwork(256, 42).to(device=device)

x = convo_block(x)
x.shape 

torch.Size([1, 601])

In [86]:
class PredictionNetwork(Module):

    def __init__(self, in_channels, board_total_slots, action_space_size, n_convs = 2, n_channels = 256, n_residual_layers=10, kernel_size = (3,3)):
        super(PredictionNetwork, self).__init__()

        self.first_net = GenericResidualNetwork(in_channels=in_channels, n_channels=n_channels, n_layers=n_residual_layers, kernel_size=kernel_size)
        self.value_predictor = ContinousValuePredictor(in_channels=n_channels, board_total_slots=board_total_slots, n_outputs=601, n_convs=n_convs)
        self.policy_predictor = ContinousValuePredictor(in_channels=n_channels, board_total_slots=board_total_slots, n_outputs=action_space_size, n_convs=n_convs)

    def forward(self, x):

        x = self.first_net(x)
        pp = self.policy_predictor(x)
        return self.value_predictor(x), torch.nn.functional.softmax(pp, dim=1)
    

x = torch.randn((3,256,6,7), device=device)
convo_block = PredictionNetwork(256, 42, 7).to(device=device)

x = convo_block(x)
x[1].shape

torch.Size([3, 7])

In [87]:
class RepresentationNetwork(Module):

    def __init__(self, in_channels, n_channels = 256, n_residual_layers=10, kernel_size = (3,3)):
        super(RepresentationNetwork, self).__init__()

        self.net = GenericResidualNetwork(in_channels=in_channels, n_channels=n_channels, n_layers=n_residual_layers, kernel_size=kernel_size)

    def forward(self, x):

        return self.net(x)
    
x = torch.randn((3,4,6,7), device=device)
convo_block = RepresentationNetwork(4).to(device=device)

x = convo_block(x)
x[1].shape

RuntimeError: The size of tensor a (256) must match the size of tensor b (4) at non-singleton dimension 1