# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

# Step 1: Build the synthetic dataset
image_size = 8  # side length of each square image (kept small for the demo)
num_images = 1000  # how many training samples to draw

# Two independent stacks of random images, stored as float32.
# X1 is drawn before X2 so the random stream is consumed in a fixed order.
X1 = np.asarray(np.random.rand(num_images, image_size, image_size), dtype=np.float32)
X2 = np.asarray(np.random.rand(num_images, image_size, image_size), dtype=np.float32)

# Regression target: the pixel-wise product of each (X1, X2) image pair
y = np.multiply(X1, X2)

# Copy each array into a tensor and insert a channel axis:
# (N, H, W) -> (N, 1, H, W)
X1_tensor, X2_tensor, y_tensor = (
    torch.unsqueeze(torch.tensor(array), dim=1) for array in (X1, X2, y)
)

# Number of channels used by the conv layers (deliberately small)
hs = 16

# Step 2: Define CNN Architecture
class SimpleCNN(nn.Module):
    """Two-input CNN that regresses one single-channel image from an image pair.

    Both inputs pass through the same (weight-shared) conv stack. The two
    resulting feature vectors are concatenated, run through two dense layers,
    and reshaped into an image of the same size as the inputs.

    Args:
        hidden_channels: Number of filters in each conv layer. ``None``
            (default) falls back to the module-level ``hs``.
        input_size: Side length S of the square input/output images. ``None``
            (default) falls back to the module-level ``image_size``.

    Raises:
        ValueError: If ``input_size`` is smaller than 8, in which case the
            feature map would vanish before the second convolution.
    """

    def __init__(self, hidden_channels=None, input_size=None):
        super().__init__()

        # Resolve defaults lazily so `SimpleCNN()` keeps using the script's
        # module-level configuration exactly as before.
        if hidden_channels is None:
            hidden_channels = hs
        if input_size is None:
            input_size = image_size

        # Spatial side after conv1 (3x3, no padding): S - 2
        # after pool1 (2x2, stride 2):                (S - 2) // 2
        # after conv2 (3x3, no padding):              (S - 2) // 2 - 2
        # For S = 8 this is 6 -> 3 -> 1.
        feature_side = (input_size - 2) // 2 - 2
        if feature_side < 1:
            raise ValueError(
                f"input_size must be at least 8 for this architecture, got {input_size}"
            )

        self.input_size = input_size

        # First convolution: (B, 1, S, S) -> (B, C, S-2, S-2)
        self.conv1 = nn.Conv2d(1, hidden_channels, kernel_size=3)

        # First pooling: halves the spatial size (floor division)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second convolution: shrinks each spatial side by another 2
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3)

        # Second pooling: 1x1 window with stride 1, i.e. an identity op.
        # Kept so the module layout stays the same as before.
        self.pool2 = nn.MaxPool2d(1, 1)

        # Dense head. Each encoded input contributes C * feature_side^2
        # features; the two inputs are concatenated before fc1.
        flat_features = hidden_channels * feature_side * feature_side
        self.fc1 = nn.Linear(flat_features * 2, hidden_channels * 2)

        # Output layer predicts the flattened S x S image.
        self.fc2 = nn.Linear(hidden_channels * 2, input_size * input_size)

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run one input through conv -> relu -> pool twice and flatten it."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        return torch.flatten(x, 1)  # keep the batch axis, flatten the rest

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Predict an output image from two input images.

        Args:
            x1: First input batch, shape (B, 1, S, S).
            x2: Second input batch, shape (B, 1, S, S).

        Returns:
            Tensor of shape (B, 1, S, S).
        """
        # The same conv weights encode both inputs.
        features = torch.cat((self._encode(x1), self._encode(x2)), dim=1)

        hidden = F.relu(self.fc1(features))
        flat_image = self.fc2(hidden)

        # Reshape using the configured size rather than a hard-coded 8x8.
        return torch.reshape(flat_image, (-1, 1, self.input_size, self.input_size))
    
# Step 3: Wrap the tensors in a shuffling DataLoader
batch_size = 32  # mini-batch size
dataset = TensorDataset(X1_tensor, X2_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Step 4: Set up the model, loss and optimizer, then train
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
model = SimpleCNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10

# Training loop
for epoch in range(num_epochs):
    running_loss = 0.0  # loss accumulated since the last progress report
    for step, batch in enumerate(dataloader, start=1):
        # Move the (input1, input2, target) triple onto the training device.
        inputs1, inputs2, labels = (tensor.to(device) for tensor in batch)

        # Standard step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs1, inputs2)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Report the mean loss over every window of 10 steps, then reset.
        if step % 10 != 0:
            continue
        print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step}/{len(dataloader)}], Loss: {running_loss/10:.4f}')
        running_loss = 0.0

print('Finished Training')
