# In [2]:
import torch
import torch.nn as nn

class Conv2D(nn.Module):
    """2-D convolution layer written with explicit sliding-window loops.

    The layer zero-pads its input, slides a learned kernel over it and adds a
    per-output-channel bias.  It also carries ``pooling`` and ``mse_loss``
    helpers that do not use the learned parameters.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of kernels (channels in the output tensor).
        kernel_size: Kernel height/width; an int is expanded to a square
            ``(k, k)`` tuple.
        stride: Step between consecutive kernel positions (both axes).
        padding: Number of zero rows/columns added on every side.
        pool_size: Side length of the square window used by ``pooling``.
        pool_stride: Step between consecutive pooling windows.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=1, pool_size=2, pool_stride=2):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
        self.stride = stride
        self.padding = padding
        self.pool_size = pool_size
        self.pool_stride = pool_stride

        # nn.Parameter registers the tensors with the module so autograd
        # tracks them.  Weights start as small Gaussian noise, biases at zero.
        self.weights = nn.Parameter(torch.randn(out_channels, in_channels, *self.kernel_size) * 0.01)
        self.biases = nn.Parameter(torch.zeros(out_channels, 1))

    def add_padding(self, input_matrices: torch.Tensor) -> torch.Tensor:
        """Return a copy of a 4-D tensor with ``self.padding`` zeros on each spatial side.

        Args:
            input_matrices: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, channels, height + 2p, width + 2p)`` on
            the same device and with the same dtype as the input.
        """
        batch_size, in_channels, height, width = input_matrices.shape
        pad = self.padding
        # new_zeros inherits device/dtype from the input, so the layer also
        # works for inputs that do not live on the default device.
        padded = input_matrices.new_zeros(batch_size, in_channels, height + 2 * pad, width + 2 * pad)
        padded[:, :, pad:pad + height, pad:pad + width] = input_matrices
        return padded

    def forward(self, input_matrix):
        """Convolve ``input_matrix`` with the layer's kernels.

        Args:
            input_matrix: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, out_channels, OH, OW)`` where
            ``OH = (height + 2*padding - kernel_height) // stride + 1`` and
            ``OW = (width + 2*padding - kernel_width) // stride + 1``.
        """
        _, _, height, width = input_matrix.shape
        kernel_height, kernel_width = self.kernel_size
        stride = self.stride
        padding = self.padding

        # Floor division counts only the windows that fit completely inside
        # the padded input; the height uses H and the width uses W.
        out_height = (height + 2 * padding - kernel_height) // stride + 1
        out_width = (width + 2 * padding - kernel_width) // stride + 1

        padded = self.add_padding(input_matrix)
        output = torch.zeros(
            padded.shape[0], self.out_channels, out_height, out_width,
            dtype=self.weights.dtype, device=padded.device,
        )

        # Add a leading batch axis to the kernels and a channel axis to each
        # window so one broadcasted multiply covers every output channel.
        kernels = self.weights.unsqueeze(0)       # (1, out_channels, in_channels, kh, kw)
        bias_row = self.biases.reshape(1, -1)     # (1, out_channels)

        for i in range(out_height):
            top = i * stride
            for j in range(out_width):
                left = j * stride
                region = padded[:, :, top:top + kernel_height, left:left + kernel_width]
                output[:, :, i, j] = (region.unsqueeze(1) * kernels).sum(dim=(2, 3, 4)) + bias_row

        return output

    def pooling(self, input_matrix, pool_type='max'):
        """Downsample a 4-D tensor with square max or average pooling windows.

        Args:
            input_matrix: Tensor of shape ``(batch, channels, height, width)``.
            pool_type: ``'max'`` or ``'average'``.

        Returns:
            Tensor of shape ``(batch, channels, OH, OW)`` with
            ``OH = (height - pool_size) // pool_stride + 1`` and
            ``OW = (width - pool_size) // pool_stride + 1``.

        Raises:
            ValueError: If ``pool_type`` is neither ``'max'`` nor ``'average'``.
        """
        # Validate once up front so a bad pool_type is reported even when the
        # output would be empty.
        if pool_type not in ('max', 'average'):
            raise ValueError(f"Unsupported pool_type: {pool_type}")

        B, C, H, W = input_matrix.shape
        OH = (H - self.pool_size) // self.pool_stride + 1
        OW = (W - self.pool_size) // self.pool_stride + 1

        output = input_matrix.new_zeros(B, C, OH, OW)

        for i in range(OH):
            top = i * self.pool_stride
            for j in range(OW):
                left = j * self.pool_stride
                # One window for every (batch, channel) pair at once.
                region = input_matrix[:, :, top:top + self.pool_size, left:left + self.pool_size]
                if pool_type == 'max':
                    output[:, :, i, j] = region.amax(dim=(2, 3))
                else:
                    output[:, :, i, j] = region.mean(dim=(2, 3))

        return output

    def mse_loss(self, output, target):
        """Return the mean of the element-wise squared difference ``(output - target) ** 2``."""
        return torch.mean((output - target) ** 2)

class CNN(nn.Module):
    """Small classifier: two Conv2D + max-pool stages followed by two linear layers.

    ``fc1`` expects ``64 * 7 * 7`` flattened features, so the network is sized
    for inputs whose spatial size is 7x7 after the two pooling stages (e.g.
    single-channel 28x28 images with the default pooling of 2).  No activation
    is applied between the convolution stages; only ``fc1`` is followed by a
    ReLU.
    """

    def __init__(self):
        super().__init__()

        # Layers
        self.conv1 = Conv2D(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = Conv2D(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.fc1 = torch.nn.Linear(64 * 7 * 7, 128)  # Fully connected layer after flattening
        self.fc2 = torch.nn.Linear(128, 10)  # Output layer (10 classes)

        # CrossEntropyLoss for classification
        self.criterion = torch.nn.CrossEntropyLoss()

        # Explicit list of every trainable tensor, used by the manual
        # gradient-descent step in update_weights and by zero_grad.
        self.weights_list = [
            self.conv1.weights, self.conv1.biases,
            self.conv2.weights, self.conv2.biases,
            self.fc1.weight, self.fc1.bias,
            self.fc2.weight, self.fc2.bias,
        ]

    def forward(self, x):
        """Return the class scores (shape ``(batch, 10)``) for a batch of inputs."""
        x = self.conv1(x)
        x = self.conv1.pooling(x, pool_type='max')

        x = self.conv2(x)
        x = self.conv2.pooling(x, pool_type='max')

        x = x.view(x.size(0), -1)  # Flatten before the fully connected layers

        x = torch.relu(self.fc1(x))
        x = self.fc2(x)

        return x

    def backward(self, output, target):
        """Compute the cross-entropy loss, backpropagate it, and return the loss tensor.

        Args:
            output: Scores produced by ``forward``.
            target: Target class indices passed to ``CrossEntropyLoss``.
        """
        loss = self.criterion(output, target)
        loss.backward()
        return loss

    def update_weights(self, learning_rate=0.001):
        """Apply one plain gradient-descent step, then reset the gradients.

        Parameters whose gradient is ``None`` (no backward pass reached them
        yet) are left untouched instead of raising.
        """
        with torch.no_grad():
            for param in self.weights_list:
                if param.grad is not None:
                    param -= learning_rate * param.grad

        self.zero_grad()

    def zero_grad(self, set_to_none=False):
        """Reset the gradient of every tensor in ``weights_list``.

        Args:
            set_to_none: When false (the default, matching this class's
                original behaviour) gradients are zeroed in place; when true
                they are replaced by ``None``.  The parameter exists so calls
                written against ``nn.Module.zero_grad`` keep working.
        """
        for param in self.weights_list:
            if param.grad is None:
                continue
            if set_to_none:
                param.grad = None
            else:
                param.grad.zero_()

# Example usage: one forward/backward/update step on a single random input.
model = CNN()

# An Adam optimizer is constructed here, but it is never stepped in the lines
# below; the parameters are updated by model.update_weights instead.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Forward pass
input_matrix = torch.randn(1, 1, 28, 28)  # Random tensor of shape (1, 1, 28, 28): batch of one single-channel 28x28 input
output = model(input_matrix)

# Dummy target label for the example (batch size of 1)
target = torch.tensor([3])  # Suppose the target class is '3'

# Compute the loss and backpropagate it (model.backward calls loss.backward())
loss = model.backward(output, target)

# Update the weights using the manual gradient-descent step
model.update_weights(learning_rate=0.001)

# Zero the gradients after updating (update_weights already calls zero_grad,
# so this call repeats that reset)
model.zero_grad()
