In [25]:
# Report the GPU model, driver / CUDA version and current memory usage.
!nvidia-smi

Wed Nov  8 17:17:35 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 546.01                 Driver Version: 546.01       CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3060 ...  WDDM  | 00000000:01:00.0 Off |                  N/A |
| N/A   60C    P0              15W /  95W |    396MiB /  6144MiB |     15%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Set random seeds for reproducibility

In [26]:
import random
import numpy as np
import torch


def set_seed(seed):
    """
    Seeds every random number generator used in this notebook so runs are reproducible.

    Args:
        seed (int): Seed applied to Python's `random`, NumPy's global RNG and torch
            (CPU and, when available, every CUDA device).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all seeds every visible GPU, so the separate (and previously
        # duplicated, unguarded) torch.cuda.manual_seed call is redundant.
        torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Global seed for the whole notebook, applied immediately.
seed = 11_032_006
set_seed(seed=seed)


In [27]:
import torch

# Sanity check: does PyTorch see a CUDA device?
print(bool(torch.cuda.is_available()))

True


# Implement the Go game logic

In [28]:
import torch


class GoGame:
    """
    Go game class.
    This class implements the Go game logic (stone placement and capture) used to
    build training samples for the neural network.

    The board is a float32 tensor of shape (2, board_size, board_size): plane
    `BLACK_DIM` (0) holds black stones and plane `WHITE_DIM` (1) holds white stones,
    each marked with 1; empty points are 0 on both planes.
    """

    def __init__(self, board_size=19) -> None:
        """
        Initializes the Go game with the given board size.
        Args:
            board_size (int): Size of the Go board (default is 19).
        """
        self.BLACK_DIM = 0
        self.WHITE_DIM = 1

        self.board_size = board_size
        self.board = torch.zeros((2, board_size, board_size), dtype=torch.float32)

    def __neighbors(self, x, y):
        """
        Yields the orthogonal neighbours of (x, y) that lie on the board.
        """
        for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nx, ny = x + dx, y + dy
            if 0 <= nx < self.board_size and 0 <= ny < self.board_size:
                yield nx, ny

    def place_stone(self, x, y, dim) -> None:
        """
        Places a stone of the specified color at (x, y) and removes any adjacent
        opponent groups that are left without liberties.
        Args:
            x (int): First board index of the position (0 <= x < board_size).
            y (int): Second board index of the position (0 <= y < board_size).
            dim (int): Color of the stone (0 for black, 1 for white).
        Note:
            Only opponent captures are resolved; the move itself is not checked for
            legality (occupied point, suicide, ko).
        """
        self.board[dim, x, y] = 1
        self.__remove_dead_stones(x, y, dim)

    def __find_group_to_be_remove(self, x, y, dim) -> set:
        """
        Collects the group of `dim`-colored stones connected to (x, y) if that group
        has no liberties.
        Args:
            x (int): First board index of a stone in the group.
            y (int): Second board index of a stone in the group.
            dim (int): Color of the group (0 for black, 1 for white).
        Returns:
            set: Positions (x, y) of the group's stones when the group is dead; an
            empty set when (x, y) holds no such stone or the group has a liberty.
        """
        own = self.board[dim]
        opponent = self.board[1 - dim]

        # No stone of this color at the start position -> no group.
        if own[x, y] == 0:
            return set()

        group = set()
        to_be_check = {(x, y)}
        while to_be_check:
            cx, cy = to_be_check.pop()
            if (cx, cy) in group:
                continue
            # Opponent stones bound the group but are not part of it.
            if opponent[cx, cy] == 1:
                continue
            # An empty point touching the group is a liberty: the group lives.
            if opponent[cx, cy] == 0 and own[cx, cy] == 0:
                return set()
            group.add((cx, cy))
            to_be_check.update(self.__neighbors(cx, cy))

        return group

    def __remove_dead_stones(self, x, y, dim) -> None:
        """
        Removes opponent groups adjacent to (x, y) that have no liberties left.

        Args:
            x (int): First board index of the stone just placed.
            y (int): Second board index of the stone just placed.
            dim (int): Color of the stone just placed (0 for black, 1 for white).
        """
        opponent_dim = 1 - dim

        dead_stones = set()
        for nx, ny in self.__neighbors(x, y):
            # Skip points without an opponent stone, and stones already found dead
            # through another neighbour (same group, no need to search again).
            if self.board[opponent_dim, nx, ny] == 0 or (nx, ny) in dead_stones:
                continue
            dead_stones |= self.__find_group_to_be_remove(nx, ny, opponent_dim)

        for px, py in dead_stones:
            self.board[opponent_dim, px, py] = 0

    def get_board(self) -> torch.Tensor:
        """
        Returns the current game board (the live tensor, not a copy).
        Returns:
            torch.Tensor: Current game board of shape (2, board_size, board_size).
        """
        return self.board

    def reset(self) -> None:
        """
        Resets the game board to the initial (empty) state with a fresh tensor.
        """
        self.board = torch.zeros(
            (2, self.board_size, self.board_size), dtype=torch.float32
        )

# Define the dataset for the model

In [29]:
import csv
import numpy as np
from torch.utils.data import Dataset
import time


class GoDataset(Dataset):
    """
    Dataset of Go game records that yields (board history, next move) pairs.

    Each CSV row is one game. Entries from index 2 onward are moves; a move string
    is read as the color at index 0 ("B" = black, anything else = white) and the
    coordinate letters "a".."s" at indices 2 and 3 (presumably SGF-style such as
    "B[pd]" - confirm against the data files).
    """

    def __init__(self, path, length):
        """
        Initializes the GoDataset with the given CSV file path.
        Args:
            path (str): Path to the CSV file containing Go game data.
            length (int): Number of consecutive board states returned per sample.
        """
        super().__init__()
        self.path = path
        self.length = length
        self.goGame = GoGame()
        self.char2idx = {c: i for i, c in enumerate("abcdefghijklmnopqrs")}
        # Number of 90-degree rotations applied to the current sample; redrawn in
        # __getitem__.
        self.rotate_times = 0

        # Load data from CSV file: one list of strings per game.
        with open(self.path, newline="") as csvfile:
            reader = csv.reader(csvfile, delimiter=",")
            self.data = list(reader)  # dtype: list[list[str]]

    def __rotate_point(self, x, y, n, max_val=18):
        """
        Rotates (x, y) by 90 degrees `n` times on a square board whose largest index
        is `max_val`.
        """
        for _ in range(n):
            x, y = y, max_val - x

        return x, y

    def __step(self, step):
        """
        Plays one move on the internal board, applying the current rotation.
        Args:
            step (str): Move string with the color at index 0 and the coordinate
                letters at indices 2 and 3.
        """
        dim = 0 if step[0] == "B" else 1
        x = self.char2idx[step[2]]
        y = self.char2idx[step[3]]
        # Rotate within the actual board rather than an assumed 19x19 grid.
        x, y = self.__rotate_point(
            x, y, self.rotate_times, max_val=self.goGame.board_size - 1
        )
        self.goGame.place_stone(x, y, dim)

    def __transform(self, data):
        """
        Turns one game record into `length` consecutive boards plus the next move.
        Args:
            data (list): One CSV row; moves start at index 2.
        Returns:
            to_model (torch.Tensor): Stacked boards, shape (length, *board_shape).
            label (torch.Tensor): Flattened board containing only the next move's stone.
        Raises:
            ValueError: If the game has `length` moves or fewer.
        """
        first_move = 2  # columns 0 and 1 are not moves
        # randint's upper bound is exclusive; with this bound the label index
        # random_start + length can reach the final move at len(data) - 1.
        high = len(data) - self.length
        if high <= first_move:
            raise ValueError(
                f"Game has {len(data) - first_move} moves; "
                f"at least {self.length + 1} are needed for length={self.length}."
            )
        random_start = np.random.randint(first_move, high)

        transformed_data = []
        # Replay from the first move; keep the `length` boards ending right before
        # the label move.
        for i in range(first_move, random_start + self.length):
            self.__step(data[i])
            if i >= random_start:
                transformed_data.append(self.goGame.get_board().clone())

        to_model = torch.stack(transformed_data)

        # Label: an otherwise empty board holding only the next move's stone.
        self.goGame.reset()
        self.__step(data[random_start + self.length])
        label = self.goGame.get_board().clone().reshape(-1)

        return to_model, label

    def __len__(self):
        """
        Returns the number of samples in the dataset (one per game).
        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Get a randomly positioned, randomly rotated sample from game `idx`.
        Args:
            idx (int): Index of the game.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: (boards, label) as produced by __transform.
        """
        row = self.data[idx]

        # Draw one of the four 90-degree rotations (0, 90, 180, 270 degrees);
        # the upper bound of randint is exclusive.
        self.rotate_times = np.random.randint(4)

        # Transform data into boards, starting from an empty game.
        self.goGame.reset()
        processed_data, label = self.__transform(row)
        return processed_data, label


# goDataset = GoDataset("data/train/dan_train.csv", 5)
# print(f"data: {goDataset[0][0].shape}, dtype: {goDataset[0][0].dtype}")
# print(f"label: {goDataset[0][1].shape}, dtype: {goDataset[0][1].dtype}")
# # stop execution here
# raise Exception("Stop execution")

# visualize game

In [30]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import time
from IPython.display import clear_output


def draw_board(board, cell_size=20):
    """
    Draws the Go board with stones based on the provided board configuration.

    Args:
        board (numpy.ndarray | torch.Tensor | nested list): Board of shape (2, N, N);
            board[0] marks black stones and board[1] marks white stones with 1.
        cell_size (int): Pixel spacing between grid lines (default 20, which gives a
            400x400 image for a 19x19 board).

    Returns:
        numpy.ndarray: uint8 image of shape ((N + 1) * cell_size, (N + 1) * cell_size, 3)
        with a light grey background, the grid lines and the stones.
    """
    board_size = len(board[0])
    image_size = (board_size + 1) * cell_size
    last = image_size - cell_size  # coordinate of the last grid line

    # 173 on every channel -> light grey background.
    image = np.full((image_size, image_size, 3), 173, dtype=np.uint8)

    # One vertical and one horizontal line per board row/column.
    for i in range(1, board_size + 1):
        offset = i * cell_size
        cv2.line(image, (offset, cell_size), (offset, last), color=(0, 0, 0), thickness=1)
        cv2.line(image, (cell_size, offset), (last, offset), color=(0, 0, 0), thickness=1)

    black = (0, 0, 0)
    white = (255, 255, 255)
    radius = max(1, cell_size * 2 // 5)  # 8 px for the default 20 px cell
    # Draw stones as filled circles centred on the grid intersections.
    for row in range(board_size):
        for col in range(board_size):
            center = (col * cell_size + cell_size, row * cell_size + cell_size)
            if board[0][row][col] == 1:  # Black stone
                cv2.circle(image, center, radius, black, -1)
            elif board[1][row][col] == 1:  # White stone
                cv2.circle(image, center, radius, white, -1)

    return image

# goDataset = GoDataset("data/train/dan_train.csv", 5)
# for _ in range(10):
#     boards, y = goDataset.__getitem__(0)
#     for i in range(0, len(boards)):
#         image = draw_board(boards[i].numpy())
#         plt.imshow(image)
#         plt.axis('off')
#         plt.show()
#         time.sleep(1)
#         clear_output(wait=True)

# # stop execution here
# raise Exception("Stop execution")

In [31]:
import cv2


def save_as_video(boards, path="test.mp4", fps=1):
    """
    Saves a sequence of Go board states as an MP4 video, one board per frame.

    Args:
        boards (iterable): Board states accepted by `draw_board`.
        path (str): Output video file (default "test.mp4").
        fps (int | float): Frames per second (default 1).
    """
    # Render first so the writer's frame size is taken from draw_board's output
    # instead of a hard-coded 400x400; written frames must match frameSize.
    frames = [draw_board(board) for board in boards]
    if frames:
        height, width = frames[0].shape[:2]
    else:
        height, width = 20 * 20, 20 * 20

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for video format (MP4)
    # VideoWriter(filename, codec, fps, frameSize=(width, height))
    video = cv2.VideoWriter(path, fourcc, fps, (width, height))
    try:
        for frame in frames:
            video.write(frame)
    finally:
        # Always finalize the file, even if a write fails.
        video.release()

# boards, _ = goDataset.__getitem__(0)
# save_as_video(boards)

# # stop execution here
# raise Exception("Stop execution")

# PredEncoder

In [32]:
import torch
import torch.nn as nn


class PredEncoder(nn.Module):
    """
    PredNet encoder: a Conformer-style convolution stack applied to a flat feature
    vector, which is treated as `input_dim` channels over a length-1 sequence.

    Args:
        input_dim (int): input dimension.
        num_channels (int): number of depthwise convolution layer input channels.
        depthwise_kernel_size (int): kernel size of depthwise convolution layer (must be odd).
        dropout (float, optional): dropout probability. (Default: 0.0)
        bias (bool, optional): add a bias term to each convolution layer. (Default: ``False``)
        use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        num_channels: int,
        depthwise_kernel_size: int,
        dropout: float = 0.0,
        bias: bool = False,
        use_group_norm: bool = False,
    ) -> None:
        super().__init__()
        if depthwise_kernel_size % 2 == 0:
            raise ValueError(
                "depthwise_kernel_size must be odd to achieve 'SAME' padding.")

        same_padding = (depthwise_kernel_size - 1) // 2

        # Layers are created in pipeline order so parameter initialisation
        # consumes the RNG exactly as before.
        expand = nn.Conv1d(
            input_dim, 2 * num_channels, 1, stride=1, padding=0, bias=bias
        )
        gate = nn.GLU(dim=1)  # halves 2 * num_channels back to num_channels
        depthwise = nn.Conv1d(
            num_channels,
            num_channels,
            depthwise_kernel_size,
            stride=1,
            padding=same_padding,
            groups=num_channels,  # one filter per channel
            bias=bias,
        )
        if use_group_norm:
            norm = nn.GroupNorm(num_groups=1, num_channels=num_channels)
        else:
            norm = nn.BatchNorm1d(num_channels)
        activation = nn.SiLU()
        project = nn.Conv1d(
            num_channels, input_dim, kernel_size=1, stride=1, padding=0, bias=bias
        )
        drop = nn.Dropout(dropout)

        self.sequential = nn.Sequential(
            expand, gate, depthwise, norm, activation, project, drop
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Encodes a batch of feature vectors.

        Args:
            input (torch.Tensor): shape `(B, D)` (batch size, input dimension).

        Returns:
            torch.Tensor: shape `(B, D)`.
        """
        # (B, D) -> (B, D, 1) for Conv1d, then drop the length-1 axis again.
        as_sequence = input.unsqueeze(-1)
        encoded = self.sequential(as_sequence)
        return encoded.squeeze(2)


# conformer

In [33]:
import torch
import torch.nn as nn


class ConvModule(nn.Module):
    """
    Conformer convolution module: LayerNorm, then pointwise conv + GLU, depthwise
    conv, normalisation, SiLU, pointwise conv and dropout over the time axis.

    Args:
        input_dim (int): input dimension.
        num_channels (int): number of depthwise convolution layer input channels.
        depthwise_kernel_size (int): kernel size of depthwise convolution layer (must be odd).
        dropout (float, optional): dropout probability. (Default: 0.0)
        bias (bool, optional): add a bias term to each convolution layer. (Default: ``False``)
        use_group_norm (bool, optional): use GroupNorm rather than BatchNorm. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim: int,
        num_channels: int,
        depthwise_kernel_size: int,
        dropout: float = 0.0,
        bias: bool = False,
        use_group_norm: bool = False,
    ) -> None:
        super().__init__()
        if depthwise_kernel_size % 2 == 0:
            raise ValueError(
                "depthwise_kernel_size must be odd to achieve 'SAME' padding."
            )

        # Normalises the feature axis before the convolutions.
        self.layer_norm = nn.LayerNorm(input_dim)

        half_window = (depthwise_kernel_size - 1) // 2
        norm_layer = (
            nn.GroupNorm(num_groups=1, num_channels=num_channels)
            if use_group_norm
            else None
        )
        layers = [
            nn.Conv1d(input_dim, 2 * num_channels, 1, stride=1, padding=0, bias=bias),
            nn.GLU(dim=1),
            nn.Conv1d(
                num_channels,
                num_channels,
                depthwise_kernel_size,
                stride=1,
                padding=half_window,
                groups=num_channels,
                bias=bias,
            ),
        ]
        layers.append(norm_layer if norm_layer is not None else nn.BatchNorm1d(num_channels))
        layers.append(nn.SiLU())  # Sigmoid Linear Unit
        layers.append(
            nn.Conv1d(num_channels, input_dim, kernel_size=1, stride=1, padding=0, bias=bias)
        )
        layers.append(nn.Dropout(dropout))
        self.sequential = nn.Sequential(*layers)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Applies the convolution module along the time axis.

        Args:
            input (torch.Tensor): shape `(B, T, D)` (batch, sequence length, input dimension).

        Returns:
            torch.Tensor: shape `(B, T, D)`.
        """
        normalized = self.layer_norm(input)
        # Conv1d convolves over the last axis, so move time there: (B, D, T).
        convolved = self.sequential(normalized.transpose(1, 2))
        return convolved.transpose(1, 2)


class FeedForwardModule(nn.Module):
    """
    Conformer feed-forward module: LayerNorm -> Linear -> SiLU -> Dropout ->
    Linear -> Dropout, mapping the last dimension back to `input_dim`.

    Args:
        input_dim (int): Input dimension.
        hidden_dim (int): Hidden layer dimension.
        dropout (float, optional): Dropout probability. (Default: 0.1)
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        layers = [
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),  # Sigmoid Linear Unit
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, input_dim),
            nn.Dropout(dropout),
        ]
        self.module = nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): tensor whose last dimension is `input_dim`, e.g. `(B, T, D)`.

        Returns:
            torch.Tensor: tensor of the same shape as `x`.
        """
        out = self.module(x)
        return out


class ConformerBlock(nn.Module):
    """
    Conformer layer that constitutes Conformer: half-step FFN, self-attention,
    convolution module (before or after attention), half-step FFN, final LayerNorm.

    Args:
        input_dim (int): input dimension.
        ffn_dim (int): hidden layer dimension of the feedforward network.
        num_attention_heads (int): number of attention heads.
        depthwise_conv_kernel_size (int): kernel size of the depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.1)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim,
        ffn_dim,
        num_attention_heads,
        depthwise_conv_kernel_size,
        dropout=0.1,
        use_group_norm=False,
        convolution_first=False,
    ):
        super().__init__()
        self.ffn1 = FeedForwardModule(input_dim, ffn_dim, dropout)
        self.ffn2 = FeedForwardModule(input_dim, ffn_dim, dropout)
        self.conv = ConvModule(
            input_dim,
            input_dim,
            depthwise_conv_kernel_size,
            dropout,
            use_group_norm=use_group_norm,
        )
        self.self_attn = nn.MultiheadAttention(
            input_dim, num_attention_heads, dropout=dropout
        )
        self.self_attn_dropout = nn.Dropout(dropout)
        # Pre-attention normalisation.
        self.layer_norm = nn.LayerNorm(input_dim)
        self.convolution_first = convolution_first
        # Separate output normalisation: previously `layer_norm` was reused here,
        # which tied the pre-attention and final LayerNorm parameters together.
        self.final_layer_norm = nn.LayerNorm(input_dim)

    def __apply_conv(self, x):
        """
        Apply the convolution module with a residual connection.

        Args:
            x (torch.Tensor): Input tensor with shape `(T, B, D)`.

        Returns:
            torch.Tensor: Output tensor with shape `(T, B, D)`.
        """
        residual = x
        # ConvModule expects `(B, T, D)`.
        x = x.transpose(0, 1)
        x = self.conv(x)
        x = x.transpose(0, 1)  # back to `(T, B, D)`
        return x + residual

    def forward(self, x):
        """
        Forward pass of the ConformerBlock.

        Args:
            x (torch.Tensor): Input tensor with shape `(T, B, D)`.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input tensor.
        """
        residual = x
        x = self.ffn1(x)
        x = 0.5 * x + residual  # half-step feed-forward residual

        if self.convolution_first:
            x = self.__apply_conv(x)

        residual = x
        x = self.layer_norm(x)
        x, _ = self.self_attn(x, x, x)  # self-attention over T
        x = self.self_attn_dropout(x)
        x = x + residual

        if not self.convolution_first:
            x = self.__apply_conv(x)

        residual = x
        x = self.ffn2(x)
        x = 0.5 * x + residual  # half-step feed-forward residual

        # Whole modules unpickled from checkpoints saved before `final_layer_norm`
        # existed keep using the shared norm they were trained with.
        final_norm = getattr(self, "final_layer_norm", self.layer_norm)
        return final_norm(x)


class Conformer(nn.Module):
    """
    Stack of ConformerBlock layers applied along the time axis.

    Args:
        input_dim (int): input dimension.
        num_heads (int): number of attention heads in each Conformer layer.
        ffn_dim (int): hidden layer dimension of feedforward networks.
        num_layers (int): number of Conformer layers to instantiate.
        depthwise_conv_kernel_size (int): kernel size of each Conformer layer's depthwise convolution layer.
        dropout (float, optional): dropout probability. (Default: 0.1)
        use_group_norm (bool, optional): use ``GroupNorm`` rather than ``BatchNorm1d``
            in the convolution module. (Default: ``False``)
        convolution_first (bool, optional): apply the convolution module ahead of
            the attention module. (Default: ``False``)
    """

    def __init__(
        self,
        input_dim,
        num_heads,
        ffn_dim,
        num_layers,
        depthwise_conv_kernel_size,
        dropout=0.1,
        use_group_norm=False,
        convolution_first=False,
    ):
        super().__init__()

        # Instantiate Conformer blocks
        self.conformer_blocks = nn.ModuleList(
            [
                ConformerBlock(
                    input_dim,
                    ffn_dim,
                    num_heads,
                    depthwise_conv_kernel_size,
                    dropout,
                    use_group_norm,
                    convolution_first,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x: torch.Tensor):
        """
        Forward pass through all Conformer blocks.

        Args:
            x (torch.Tensor): input of shape `(B, T, input_dim)` or `(B, T, *feature_dims)`
                such as stacked boards `(B, T, 2, H, W)`; every dimension after T is
                flattened, so their product must equal `input_dim`.

        Returns:
            torch.Tensor: output with shape `(B, T, input_dim)`.

        Raises:
            ValueError: If `x` has fewer than 3 dimensions.
        """
        if x.dim() < 3:
            raise ValueError(
                f"expected input of shape (B, T, ...), got {tuple(x.shape)}"
            )
        # Flatten all feature dims into one; unlike `view`, this also accepts
        # non-contiguous inputs and is a no-op for already-flat `(B, T, D)` input.
        x = x.flatten(start_dim=2)

        x = x.transpose(0, 1)  # `(T, B, input_dim)`, the layout ConformerBlock expects

        for layer in self.conformer_blocks:
            x = layer(x)

        return x.transpose(0, 1)  # back to `(B, T, input_dim)`


# generator

In [34]:
import torch
import torch.nn as nn


class Generator(nn.Module):
    """
    Generator: a Conformer encoder followed by a Linear + Softmax head applied to
    the encoder output of the last time step.

    Args:
        input_dim (int): Input dimension (also the size of the output distribution).
        num_heads (int): Number of attention heads in each Conformer layer.
        ffn_dim (int): Hidden layer dimension of feedforward networks in Conformer layers.
        num_layers (int): Number of Conformer layers.
        depthwise_conv_kernel_size (int): Kernel size of depthwise convolution in Conformer layers.
        dropout (float, optional): Dropout probability. (Default: 0.1)
        use_group_norm (bool, optional): Use GroupNorm instead of BatchNorm1d in Conformer layers. (Default: False)
        convolution_first (bool, optional): Apply convolution module ahead of attention module. (Default: False)
    """

    def __init__(
        self,
        input_dim,
        num_heads,
        ffn_dim,
        num_layers,
        depthwise_conv_kernel_size,
        dropout=0.1,
        use_group_norm=False,
        convolution_first=False,
    ):
        super().__init__()

        conformer_args = (
            input_dim,
            num_heads,
            ffn_dim,
            num_layers,
            depthwise_conv_kernel_size,
            dropout,
            use_group_norm,
            convolution_first,
        )
        self.conformer = Conformer(*conformer_args)

        # TODO: try different activation functions
        head = [nn.Linear(input_dim, input_dim), nn.Softmax(dim=-1)]
        self.output_layer = nn.Sequential(*head)

    def forward(self, x: torch.Tensor):
        """
        Forward pass of the Generator.

        Args:
            x (torch.Tensor): Input accepted by the Conformer, batch-first with time
                on axis 1.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: `(probs, x)` where `probs` has shape
            `(B, input_dim)` and sums to 1 over the last axis, and `x` is the very
            same input tensor object (not a copy).
        """
        encoded = self.conformer(x)
        last_step = encoded[:, -1, :]  # keep only the final time step
        probs = self.output_layer(last_step)
        return probs, x


# discriminator

In [35]:
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """
    Discriminator model using Conformer and PredEncoder architectures.

    The board history `x` is encoded by a Generator-style Conformer, the candidate
    move `y` by a PredEncoder; the two `(B, input_dim)` feature vectors are
    concatenated and mapped by an MLP to a single score in (-1, 1).

    Args:
        input_dim (int): Input dimension for Conformer and PredEncoder.
        num_heads (int): Number of attention heads in each Conformer layer.
        ffn_dim (int): Hidden layer dimension of feedforward networks in Conformer.
        num_layers (int): Number of Conformer layers.
        depthwise_conv_kernel_size (int): Kernel size of depthwise convolution in Conformer.
        dropout (float, optional): Dropout probability. (Default: 0.1)
        use_group_norm (bool, optional): Use GroupNorm instead of BatchNorm1d in Conformer layers. (Default: False)
        convolution_first (bool, optional): Apply convolution module ahead of attention module. (Default: False)
    """

    def __init__(
        self,
        input_dim,
        num_heads,
        ffn_dim,
        num_layers,
        depthwise_conv_kernel_size,
        dropout=0.1,
        use_group_norm=False,
        convolution_first=False,
    ):
        super(Discriminator, self).__init__()

        # Encoder for the board history (same architecture as the Generator).
        self.generator = Generator(
            input_dim,
            num_heads,
            ffn_dim,
            num_layers,
            depthwise_conv_kernel_size,
            dropout,
            use_group_norm,
            convolution_first,
        )

        # Encoder for the candidate move.
        self.pred_encoder = PredEncoder(
            input_dim=input_dim,
            num_channels=input_dim,
            depthwise_kernel_size=3,
            bias=False,
            use_group_norm=False,
        )

        # TODO: try different activation functions
        # MLP head: (B, 2 * input_dim) -> (B, input_dim // 4) -> (B, input_dim // 16) -> (B, 1)
        self.linear = nn.Sequential(
            nn.Linear(2 * input_dim, input_dim // 4),
            nn.LeakyReLU(0.2),
            nn.Linear(input_dim // 4, input_dim // 16),
            nn.LeakyReLU(0.2),
            # Single real/fake score, squashed to (-1, 1).
            nn.Linear(input_dim // 16, 1),
            nn.Tanh(),
        )

    def forward(self, x, y):
        """
        Forward pass of the Discriminator.

        Args:
            x (torch.Tensor): Board history accepted by the Conformer, batch-first.
            y (torch.Tensor): Candidate move with shape `(B, input_dim)` (for PredEncoder).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: `(score, x)` where `score` has shape
            `(B, 1)` and `x` is the input tensor returned by the generator.
        """
        generator_output, input_copy = self.generator(x)

        pred_encoder_output = self.pred_encoder(y)

        # Concatenate only the two (B, input_dim) feature vectors: the head expects
        # 2 * input_dim features, and the multi-dimensional raw input `x` cannot be
        # concatenated with 2-D tensors along dim 1.
        concatenated_input = torch.cat((generator_output, pred_encoder_output), dim=1)

        output = self.linear(concatenated_input)

        return output, input_copy

# trainer

In [36]:
from tqdm import tqdm
from torch.utils.data import DataLoader
import time


class Trainer:
    """
    Adversarial trainer for a Generator / Discriminator pair.

    * The discriminator (critic) is trained with ``-mean(real) + mean(fake)`` plus a
      gradient penalty (see ``cal_gradient_penalty``).
    * The generator is trained with ``-mean(fake_pred) + criterion(output, y)``, where
      ``criterion`` is an MSE loss.
    * ``normal_train_G`` / ``normal_evaluate_G`` train and evaluate the generator on the
      MSE loss alone.

    Keys read from ``config``: the model hyper-parameters passed to ``Generator`` /
    ``Discriminator`` ("input_dim", "num_heads", "ffn_dim", "num_layers",
    "depthwise_conv_kernel_size", "dropout", "use_group_norm", "convolution_first"),
    plus "lr", "gen_path", "dis_path", "device", "clip_value", "epochs", "early_stop".
    """

    def __init__(
        self,
        config: dict,
        train_loader: DataLoader,
        val_loader: DataLoader,
        load_model: bool = False,
    ):
        """
        Args:
            config: Hyper-parameters and paths (see the class docstring).
            train_loader: Yields ``(x, y)`` training batches.
            val_loader: Yields ``(x, y)`` validation batches.
            load_model: If True, load pickled models from ``gen_path`` / ``dis_path``
                instead of building fresh ones.
        """
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader

        if load_model:
            # torch.load unpickles whole nn.Module objects: only load trusted files.
            # map_location lets a checkpoint saved on GPU load on the configured device.
            # NOTE(review): newer PyTorch releases default to weights_only=True, which
            # rejects pickled modules; consider saving/loading state_dicts instead.
            self.gen = torch.load(config["gen_path"], map_location=config["device"])
            self.dis = torch.load(config["dis_path"], map_location=config["device"])
        else:
            self.gen = Generator(
                input_dim=config["input_dim"],
                num_heads=config["num_heads"],
                ffn_dim=config["ffn_dim"],
                num_layers=config["num_layers"],
                depthwise_conv_kernel_size=config["depthwise_conv_kernel_size"],
                dropout=config["dropout"],
                use_group_norm=config["use_group_norm"],
                convolution_first=config["convolution_first"],
            )
            self.dis = Discriminator(
                input_dim=config["input_dim"],
                num_heads=config["num_heads"],
                ffn_dim=config["ffn_dim"],
                num_layers=config["num_layers"],
                depthwise_conv_kernel_size=config["depthwise_conv_kernel_size"],
                dropout=config["dropout"],
                use_group_norm=config["use_group_norm"],
                convolution_first=config["convolution_first"],
            )
        self.gen.to(self.config["device"])
        self.dis.to(self.config["device"])

        self.G_optimizer = torch.optim.Adam(self.gen.parameters(), lr=config["lr"])
        self.D_optimizer = torch.optim.Adam(self.dis.parameters(), lr=config["lr"])
        self.G_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.G_optimizer)
        self.D_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.D_optimizer)

        self.criterion = nn.MSELoss()

        self.clip_value = config["clip_value"]

        # Early-stopping state, updated by evaluate_G.
        self.early_count = 0
        self.best_val_loss = float("inf")

    def normal_evaluate_G(self, G_val_losses: list):
        """Append the generator's MSE loss for every validation batch to ``G_val_losses``."""
        self.gen.eval()
        # Evaluation only: no autograd graph is needed, which saves memory.
        with torch.no_grad():
            for x, y in tqdm(self.val_loader):
                x = x.to(self.config["device"])
                y = y.to(self.config["device"])
                output, _ = self.gen(x)
                loss = self.criterion(output, y)
                G_val_losses.append(loss.item())

    def evaluate_G(self, G_val_losses: list, G_accs: list):
        """
        Evaluate the generator on the validation set.

        Appends the mean adversarial+MSE loss to ``G_val_losses`` and the move accuracy
        (arg-max of the generator output vs. arg-max of ``y``) to ``G_accs``. Steps the
        generator LR scheduler. Saves both models when the loss improves, otherwise
        increments the early-stopping counter.
        """
        print("Evaluating generator:")

        self.gen.eval()
        self.dis.eval()
        total_loss = 0.0
        total_correct = 0

        for x, y in tqdm(self.val_loader):
            x = x.to(self.config["device"])
            y = y.to(self.config["device"])

            with torch.no_grad():
                # Generator prediction and the conditioning copy it returns.
                output, condition = self.gen(x)
                # Critic score of the generated prediction.
                fake_pred, _ = self.dis(condition, output)

            # Accuracy must come from the generator's prediction: the critic emits one
            # score per sample, so arg-max over its dim 1 would always be index 0.
            pred_indices = torch.argmax(output, dim=1)
            real_indices = torch.argmax(y, dim=1)
            total_correct += torch.sum(pred_indices == real_indices).item()

            # Generator objective: MSE to the target plus the adversarial term.
            loss = self.criterion(output, y) - torch.mean(fake_pred)
            total_loss += loss.item()

        self.G_scheduler.step(total_loss)

        average_loss = total_loss / len(self.val_loader)
        G_val_losses.append(average_loss)

        accuracy = total_correct / len(self.val_loader.dataset)
        G_accs.append(accuracy)
        print(f"G Validation accuracy: {accuracy}")

        if average_loss < self.best_val_loss:
            self.best_val_loss = average_loss
            torch.save(self.gen, self.config["gen_path"])
            torch.save(self.dis, self.config["dis_path"])
            self.early_count = 0
        else:
            self.early_count += 1

    def evaluate_D(self, D_val_losses: list, D_accs: list):
        """
        Evaluate the discriminator on the validation set.

        Appends the mean critic loss (``-mean(real) + mean(fake)``) to ``D_val_losses``.
        Appends to ``D_accs`` the fraction of scores on the expected side of 0.5
        (fake < 0.5, real > 0.5). Steps the discriminator LR scheduler.

        Returns:
            float: the validation accuracy.
        """
        print("Evaluating discriminator:")

        self.gen.eval()
        self.dis.eval()

        total_loss = 0.0
        total_correct = 0
        total_fake_loss = 0.0
        total_real_loss = 0.0

        for x, y in tqdm(self.val_loader):
            x = x.to(self.config["device"])
            y = y.to(self.config["device"])

            # Pure evaluation: keep the generator forward pass out of autograd too.
            with torch.no_grad():
                output, condition = self.gen(x)
                fake_pred, condition = self.dis(condition, output)
                real_pred, _ = self.dis(condition, y)

            # NOTE(review): the critic ends in tanh (range [-1, 1]); 0.5 is kept as the
            # decision threshold, but 0 may be the more natural midpoint.
            total_correct += (
                torch.sum(fake_pred < 0.5) + torch.sum(real_pred > 0.5)
            ).item()

            fake_score = torch.mean(fake_pred).item()
            real_score = torch.mean(real_pred).item()
            total_loss += -real_score + fake_score
            total_fake_loss += fake_score
            total_real_loss += -real_score

        self.D_scheduler.step(total_loss)

        num_batches = len(self.val_loader)
        average_loss = total_loss / num_batches
        D_val_losses.append(average_loss)
        print(f"Discriminator loss: {average_loss}")
        print(f"Fake loss: {total_fake_loss / num_batches}")
        print(f"Real loss: {total_real_loss / num_batches}")

        # Every sample is scored twice (once as fake, once as real), hence the / 2.
        accuracy = total_correct / len(self.val_loader.dataset) / 2
        D_accs.append(accuracy)
        print(f"D Validation accuracy: {accuracy}")

        return accuracy

    def normal_train_G(self, G_losses: list):
        """Train the generator alone for one epoch on the MSE loss; append the mean batch loss."""
        self.gen.train()
        total_loss = 0.0
        for x, y in tqdm(self.train_loader):
            self.G_optimizer.zero_grad()
            x = x.to(self.config["device"])
            y = y.to(self.config["device"])
            output, _ = self.gen(x)
            loss = self.criterion(output, y)
            loss.backward()
            self.G_optimizer.step()
            total_loss += loss.item()

        average_loss = total_loss / len(self.train_loader)
        G_losses.append(average_loss)
        print(f"Generator loss: {average_loss}")

    def train_G(self, x, y):
        """
        One generator update on a single batch.

        Returns:
            float: the generator loss ``-mean(critic(fake)) + MSE(fake, y)``.
        """
        x = x.to(self.config["device"])
        y = y.to(self.config["device"])

        # Generate fake data and conditioning information from the generator
        fake_output, condition = self.gen(x)

        # Critic score of the generated data
        fake_pred, _ = self.dis(condition, fake_output)

        # Adversarial term (push critic scores up) and MSE reconstruction term
        D_loss = -torch.mean(fake_pred)
        normal_loss = self.criterion(fake_output, y)
        loss = D_loss + normal_loss

        # Only G's optimizer steps; gradients that reach D are cleared by train_D's
        # zero_grad before D's next update.
        self.G_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.gen.parameters(), max_norm=self.clip_value)
        self.G_optimizer.step()

        return loss.item()

    def cal_gradient_penalty(self, real, fake, condition, lambda_gp=10):
        """
        Gradient penalty ``lambda_gp * mean((||grad_xhat D(condition, xhat)||_2 - 1)^2)``.

        ``xhat`` is a random per-sample interpolation between ``real`` and ``fake``.
        Inputs may have any number of dimensions after the batch dimension.
        """
        batch_size = real.size(0)
        # One mixing weight per sample, broadcast over all remaining dimensions
        # (for 2-D inputs this is a (batch_size, 1) tensor).
        alpha_shape = (batch_size,) + (1,) * (real.dim() - 1)
        alpha = torch.rand(alpha_shape, dtype=real.dtype, device=real.device)

        # Detached leaf that requires grad, so the gradient below is taken w.r.t. the
        # interpolates only. This replaces the deprecated
        # torch.autograd.Variable(..., requires_grad=True), which also detached.
        interpolates = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)

        disc_interpolates, _ = self.dis(condition, interpolates)

        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True,  # the penalty itself is back-propagated into D
            retain_graph=True,
            only_inputs=True,
        )[0]

        # Per-sample L2 norm of the flattened gradient
        gradient_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)

        gradient_penalty = ((gradient_norm - 1) ** 2).mean()
        return lambda_gp * gradient_penalty

    def train_D(self, x, y):
        """
        One discriminator update on a single batch.

        Returns:
            float: ``-mean(real) + mean(fake) + gradient_penalty``.
        """
        x = x.to(self.config["device"])
        y = y.to(self.config["device"])

        # G is not updated here, so no autograd graph is needed through it.
        # Gradients into G would be discarded anyway: train_G zeroes them first.
        with torch.no_grad():
            fake_output, condition = self.gen(x)

        # Clones guard against in-place changes inside the critic, because
        # fake_output and y are reused for the gradient penalty below.
        fake_pred, condition = self.dis(condition, fake_output.clone())
        real_pred, condition = self.dis(condition, y.clone())

        gradient_penalty = self.cal_gradient_penalty(y, fake_output, condition)

        # Critic loss: -real + fake + gradient penalty
        loss = -torch.mean(real_pred) + torch.mean(fake_pred) + gradient_penalty

        self.D_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.dis.parameters(), max_norm=self.clip_value)
        self.D_optimizer.step()

        return loss.item()

    def train(self):
        """
        Run adversarial training for up to ``config["epochs"]`` epochs, stopping early
        once ``evaluate_G`` has not improved for ``config["early_stop"]`` epochs.

        Returns:
            dict: per-epoch histories under "G_losses", "G_val_losses", "G_accs",
            "D_losses", "D_val_losses" and "D_accs".
        """
        G_losses, G_val_losses, G_accs = [], [], []
        D_losses, D_val_losses, D_accs = [], [], []
        num_batches = len(self.train_loader)

        # The critic is only updated while its validation accuracy is below 0.8
        # (presumably so it does not overpower the generator).
        D_acc = 0
        for epoch in range(self.config["epochs"]):
            # evaluate_D / evaluate_G put both networks in eval mode. Switch back every
            # epoch so dropout and normalisation layers run in training mode.
            self.gen.train()
            self.dis.train()

            D_total_loss = 0
            G_total_loss = 0
            print(f'Epoch {epoch+1}/{self.config["epochs"]}')
            for x, y in tqdm(self.train_loader):
                if D_acc < 0.8:
                    D_total_loss += self.train_D(x, y)
                G_total_loss += self.train_G(x, y)

            D_epoch_loss = D_total_loss / num_batches
            G_epoch_loss = G_total_loss / num_batches
            print(f"Discriminator loss: {D_epoch_loss}")
            D_losses.append(D_epoch_loss)
            print(f"Generator loss: {G_epoch_loss}")
            G_losses.append(G_epoch_loss)

            D_acc = self.evaluate_D(D_val_losses, D_accs)
            self.evaluate_G(G_val_losses, G_accs)

            if self.early_count >= self.config["early_stop"]:
                break

        return {
            "G_losses": G_losses,
            "G_val_losses": G_val_losses,
            "G_accs": G_accs,
            "D_losses": D_losses,
            "D_val_losses": D_val_losses,
            "D_accs": D_accs,
        }

# Hyper-parameter search

In [37]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader

# Full search space (slow to explore). Swap it in for the smoke-test domain below
# to run a real hyper-parameter search.
# domain = dict(
#     input_dim=[2 * 19 * 19],
#     num_heads=[1, 2],
#     ffn_dim=[64, 128, 256, 512, 1024],
#     num_layers=[2, 4, 8],
#     depthwise_conv_kernel_size=[3, 5, 7],
#     dropout=[0, 0.1, 0.2, 0.3, 0.4],
#     use_group_norm=[True, False],
#     convolution_first=[True, False],
#     lr=[0.0001, 0.001, 0.01],
#     gen_path=["data/models/gen.pth"],
#     dis_path=["data/models/dis.pth"],
#     device=[torch.device("cuda" if torch.cuda.is_available() else "cpu")],
#     batch_size=[512],
#     clip_value=[1],
#     data_len=[4, 8, 16, 32],
#     epochs=[100],
#     early_stop=[5],
# )

# Smoke-test search space: exactly one option per hyper-parameter.
domain = dict(
    input_dim=[2 * 19 * 19],
    num_heads=[1],
    ffn_dim=[64],
    num_layers=[2],
    depthwise_conv_kernel_size=[3],
    dropout=[0],
    use_group_norm=[True],
    convolution_first=[True],
    lr=[0.0001],
    gen_path=["data/models/gen.pth"],
    dis_path=["data/models/dis.pth"],
    device=[torch.device("cuda" if torch.cuda.is_available() else "cpu")],
    batch_size=[512],
    clip_value=[1],
    data_len=[4],
    epochs=[100],
    early_stop=[5],
)

# Size of the search space: product of the number of options per key.
count = 1
for options in domain.values():
    count *= len(options)
print(f"Total combinations: {count}")


class ParmFinder:
    """
    Random search over a hyper-parameter ``domain`` for the generator.

    Each iteration:
    1. Samples one option per key.
    2. Builds fresh train/val loaders from ``GoDataset``.
    3. Trains and evaluates the generator alone (``normal_train_G`` /
       ``normal_evaluate_G``) for ``max_epoch`` epochs.
    4. Scores the run by ``train_loss / val_loss``.

    The lowest ratio seen so far is kept as the best configuration, together with the
    loaders it was evaluated on. Every run is written to a CSV history file.
    """

    def __init__(
        self,
        domain: dict,
        max_iter: int = 50,
        max_epoch: int = 5,
        G_history_path: str = "data/G_history.csv",
    ) -> None:
        """
        Args:
            domain: Maps each hyper-parameter name to a list of candidate values.
            max_iter: Number of random configurations to try.
            max_epoch: Generator-only training epochs per configuration.
            G_history_path: CSV file that receives the search history.
        """
        self.best_ratio = float("inf")
        self.best_params = None
        self.best_G = None
        self.best_D = None
        # Loaders that belong to `best_params`; data_len can differ between samples.
        self.best_train_loader = None
        self.best_val_loader = None
        # Loaders of the most recent sample (set by __random_sample).
        self.train_loader = None
        self.val_loader = None
        self.G_parms = []
        self.G_train_losses = []
        self.G_ratios = []
        self.D_parms = []
        self.domain = domain
        self.max_iter = max_iter
        self.max_epoch = max_epoch
        self.G_history_path = G_history_path

    def __random_sample(self):
        """Sample one option per domain key and build matching train/val loaders."""
        params = {}
        for key, value in self.domain.items():
            # Uses the same single randint draw that np.random.choice would make, but
            # returns the original Python object (int, bool, str, torch.device)
            # instead of a NumPy scalar. It also works for options that are sequences.
            params[key] = value[np.random.randint(len(value))]

        print(f"Current params: {params}")

        goDataset = GoDataset("data/train/dan_train.csv", params["data_len"])
        train_len = int(0.8 * len(goDataset))
        val_len = len(goDataset) - train_len
        train_dataset, val_dataset = torch.utils.data.random_split(
            goDataset, [train_len, val_len]
        )
        self.train_loader = DataLoader(
            train_dataset, batch_size=int(params["batch_size"]), shuffle=True, pin_memory=True
        )
        self.val_loader = DataLoader(
            val_dataset, batch_size=int(params["batch_size"]), shuffle=False, pin_memory=True
        )

        return params

    def __save_G(self):
        """Write every evaluated configuration with its train loss and loss ratio to CSV."""
        import os

        header = list(self.domain.keys()) + ["train_loss", "loss_ratio"]
        df = pd.DataFrame(self.G_parms, columns=header)
        df["train_loss"] = self.G_train_losses
        df["loss_ratio"] = self.G_ratios
        df = df.sort_values(by="loss_ratio", ascending=False)

        # to_csv does not create missing directories.
        history_dir = os.path.dirname(self.G_history_path)
        if history_dir:
            os.makedirs(history_dir, exist_ok=True)
        df.to_csv(self.G_history_path, index=False)

    def __evaluate_G(self, trainer: "Trainer"):
        """Train/evaluate the generator for `max_epoch` epochs and record the result."""
        train_loss = 0
        loss_ratio = 0
        for epoch in range(self.max_epoch):
            G_losses = []
            G_val_losses = []
            trainer.normal_train_G(G_losses)
            trainer.normal_evaluate_G(G_val_losses)

            train_loss = np.mean(G_losses)
            val_loss = np.mean(G_val_losses)
            loss_ratio = train_loss / val_loss

            # NOTE(review): the smallest train/val ratio wins, i.e. the run whose
            # validation loss is largest relative to its training loss. Confirm this
            # is intended; a ratio close to 1 would favour runs that generalise best.
            if loss_ratio < self.best_ratio:
                self.best_ratio = loss_ratio
                self.best_params = trainer.config
                self.best_G = trainer.gen
                self.best_D = trainer.dis
                self.best_train_loader = trainer.train_loader
                self.best_val_loader = trainer.val_loader

            print(
                f"Epoch {epoch+1}/{self.max_epoch}: Train Loss: {train_loss}, Val Loss: {val_loss}, Loss Ratio: {loss_ratio}"
            )

        self.G_parms.append(trainer.config)
        self.G_train_losses.append(train_loss)
        self.G_ratios.append(loss_ratio)
        self.__save_G()

    def find(self):
        """
        Run the random search.

        Returns:
            tuple: ``(best_params, train_loader, val_loader)``. The loaders are the ones
            the best configuration was evaluated on. If no run ever improved on the
            initial ratio (e.g. all NaN), the most recent loaders are returned instead.
        """
        for _ in range(self.max_iter):
            params = self.__random_sample()
            trainer = Trainer(params, self.train_loader, self.val_loader)
            self.__evaluate_G(trainer)

        train_loader = (
            self.best_train_loader
            if self.best_train_loader is not None
            else self.train_loader
        )
        val_loader = (
            self.best_val_loader
            if self.best_val_loader is not None
            else self.val_loader
        )
        return self.best_params, train_loader, val_loader

Total combinations: 1


In [38]:
# Random hyper-parameter search; `find` returns the best parameters together with a
# train/validation loader pair for the full training run below.
parmFinder = ParmFinder(domain)
(parms, train_loader, val_loader) = parmFinder.find()
print(f"Best params: {parms}")

Current params: {'input_dim': 722, 'num_heads': 1, 'ffn_dim': 64, 'num_layers': 2, 'depthwise_conv_kernel_size': 3, 'dropout': 0, 'use_group_norm': True, 'convolution_first': True, 'lr': 0.0001, 'gen_path': 'data/models/gen.pth', 'dis_path': 'data/models/dis.pth', 'device': device(type='cuda'), 'batch_size': 512, 'clip_value': 1, 'data_len': 4, 'epochs': 100, 'early_stop': 5}


  4%|▍         | 6/157 [00:53<22:18,  8.87s/it]


KeyboardInterrupt: 

# Main training run

In [None]:
# Full adversarial training with the parameters found by the search.
trainer = Trainer(parms, train_loader, val_loader)
statistic = trainer.train()

# `train` returns a dict of per-epoch loss / accuracy histories.
print(statistic)