# In [None]:
# Import Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Load and Preprocess the Data
# Pixel pipeline: PIL image -> float tensor in [0, 1], then (x - 0.5) / 0.5,
# which maps the pixel range onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Both MNIST splits share one root directory and the same preprocessing.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Define the PixelCNN Model Architecture
class _MaskedConv2d(nn.Conv2d):
    """2-D convolution whose kernel only sees pixels preceding the centre in raster order.

    A PixelCNN models p(x_ij | x_<ij), so the prediction for a pixel must not
    depend on that pixel or on any later pixel (row-major order).

    mask_type "A" also hides the centre tap. It is meant for the first layer,
    which reads the raw image. mask_type "B" keeps the centre tap. It is meant
    for later layers, whose centre activation already summarises only
    strictly-earlier pixels.
    """

    def __init__(self, *args, mask_type="B", **kwargs):
        super().__init__(*args, **kwargs)
        if mask_type not in ("A", "B"):
            raise ValueError(f"mask_type must be 'A' or 'B', got {mask_type!r}")
        if self.padding_mode != "zeros":
            # forward() calls F.conv2d directly, which only implements zero padding.
            raise ValueError("_MaskedConv2d only supports padding_mode='zeros'")
        _, _, k_h, k_w = self.weight.shape
        c_h, c_w = k_h // 2, k_w // 2
        mask = torch.ones_like(self.weight)
        # Centre row: zero from the centre (type A) or just right of it (type B) ...
        first_hidden_col = c_w + 1 if mask_type == "B" else c_w
        mask[:, :, c_h, first_hidden_col:] = 0
        # ... and every row below the centre.
        mask[:, :, c_h + 1:, :] = 0
        # A buffer moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Re-apply the mask every call so optimizer updates can never "unmask" a tap.
        return nn.functional.conv2d(
            x, self.weight * self.mask, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )


class PixelCNN(nn.Module):
    """Small autoregressive PixelCNN for single-channel images.

    Maps a (batch, 1, H, W) tensor to a same-shape tensor of per-pixel values
    in [0, 1] (sigmoid output). Thanks to the type-A first layer and type-B
    second layer, the output at (i, j) never depends on input pixel (i, j) or
    on any pixel after it in raster order. This makes it usable for
    pixel-by-pixel generation.
    """

    def __init__(self):
        super(PixelCNN, self).__init__()
        self.conv1 = _MaskedConv2d(1, 64, kernel_size=7, padding=3, mask_type="A")  # hides centre pixel
        self.conv2 = _MaskedConv2d(64, 64, kernel_size=7, padding=3, mask_type="B")
        # 1x1 conv mixes features and projects back to ONE channel so the output
        # matches the (batch, 1, H, W) target shape required by BCELoss.
        self.conv3 = nn.Conv2d(64, 1, kernel_size=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        return self.sigmoid(self.conv3(out))  # Ensure output is in [0, 1]

model = PixelCNN()

# Define the Loss Function and Optimizer
# BCELoss expects probabilities as input (the model ends in a sigmoid) and targets in [0, 1].
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the Model
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    train_loss_sum = 0.0
    for images, _ in train_loader:
        # The loader normalizes pixels to [-1, 1]; map them back to [0, 1]
        # so they are valid BCE targets.
        images = (images + 1) / 2

        # Forward pass: each output pixel is scored against the corresponding input pixel.
        outputs = model(images)
        loss = criterion(outputs, images)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # BCELoss averages within a batch; weighting by batch size turns the sum
        # into a true dataset-wide mean even when the last batch is short.
        train_loss_sum += loss.item() * images.size(0)

    train_loss = train_loss_sum / max(len(train_loader.dataset), 1)

    # Held-out loss on the test split (eval mode, no gradient tracking).
    model.eval()
    test_loss_sum = 0.0
    with torch.no_grad():
        for images, _ in test_loader:
            images = (images + 1) / 2
            test_loss_sum += criterion(model(images), images).item() * images.size(0)
    test_loss = test_loss_sum / max(len(test_loader.dataset), 1)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

# Generate Samples
def generate_image(model, *, sample=True, size=28, show=True):
    """Generate one single-channel image pixel by pixel with ``model``.

    Pixels are filled in raster order (row by row, left to right). For each
    position the whole current canvas is passed through ``model``, and the
    output at that position is read as the probability that the pixel is on.

    Args:
        model: module mapping a (1, 1, size, size) tensor to a same-shape
            tensor of probabilities in [0, 1].
        sample: if True (default), draw each pixel from Bernoulli(p), which is
            the PixelCNN sampling procedure. If False, write p itself,
            producing a deterministic "mean" image.
        size: height and width of the generated square image.
        show: if True, display the result with matplotlib.

    Returns:
        The generated image as a (size, size) NumPy array.

    Note: leaves ``model`` in eval mode.
    """
    model.eval()
    # Put the canvas on the same device as the model's weights (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    with torch.no_grad():
        canvas = torch.zeros(1, 1, size, size, device=device)  # Start with a blank image
        for i in range(size):
            for j in range(size):
                # Clamp guards torch.bernoulli against tiny numerical excursions outside [0, 1].
                prob = model(canvas)[0, 0, i, j].clamp(0.0, 1.0)
                canvas[0, 0, i, j] = torch.bernoulli(prob) if sample else prob
    # Index instead of squeeze() so a 1x1 image still comes back 2-D.
    image = canvas[0, 0].cpu().numpy()
    if show:
        plt.imshow(image, cmap='gray')
        plt.show()
    return image

# Draw one image from the trained model and display it.
generate_image(model=model)
