# Problem: Implement a CNN for CIFAR-10 in PyTorch

### Problem Statement
You are tasked with implementing a **Convolutional Neural Network (CNN)** for image classification on the **CIFAR-10** dataset using PyTorch. The model should contain convolutional layers for feature extraction, pooling layers for downsampling, and fully connected layers for classification. Your goal is to complete the CNN model by defining the necessary layers and implementing the forward pass.

### Requirements
1. **Define the CNN Model**:
   - Add **convolutional layers** for feature extraction.
   - Add **pooling layers** to reduce the spatial dimensions.
   - Add **fully connected layers** to output class predictions.
   - The model should be capable of processing CIFAR-10 inputs: 32x32 RGB images, which PyTorch represents as channels-first tensors of shape `(3, 32, 32)`.

### Constraints
- The CNN should be designed with multiple convolutional and pooling layers followed by fully connected layers.
- Ensure the model is compatible with the CIFAR-10 dataset, which contains 10 classes.


<details>
  <summary>💡 Hint</summary>
  Add the convolutional (conv1, conv2), pooling (pool), and fully connected layers (fc1, fc2) in CNNModel.__init__.
  <br>
  Implement the forward pass to process inputs through these layers.
</details>
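
When sizing the first fully connected layer, it helps to trace how each layer changes the spatial size. The following is a minimal sketch of that arithmetic, using the standard output-size formula `floor((size + 2 * padding - kernel) / stride) + 1`. The helper names `conv_out` and `pool_out` are hypothetical, and the three-stage layout (3x3 convolution with padding 1, then 2x2 max pooling) is an assumption chosen for illustration.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Spatial size after a convolution (hypothetical helper)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Spatial size after non-overlapping pooling (stride equal to kernel size)."""
    return (size - kernel) // kernel + 1


size = 32  # CIFAR-10 images are 32x32
sizes = []
for _ in range(3):  # assumed: three conv + pool stages
    size = conv_out(size, kernel=3, padding=1)  # padding=1 keeps the size unchanged
    size = pool_out(size, kernel=2)             # 2x2 pooling halves it
    sizes.append(size)

print(sizes)  # [16, 8, 4]
```

Under these assumptions, a final stage with 24 output channels yields a flattened feature vector of 24 * 4 * 4 = 384 values, which is the input size the first linear layer needs.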

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [None]:
# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    # Scale each channel from [0, 1] to [-1, 1]: (x - 0.5) / 0.5
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# Define the CNN Model
# TODO: Add convolutional, pooling, and fully connected layers
class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 6, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(6, 12, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(12, 24, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Three 2x2 poolings reduce 32x32 to 4x4; the last block has 24 channels
        self.linear = nn.Linear(24 * 4 * 4, 10)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = torch.flatten(x, 1)  # (N, 24, 4, 4) -> (N, 384)
        x = self.linear(x)
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally
        return x
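
A quick way to confirm that the convolutional stack produces the shape the linear layer expects is to push a dummy batch through it. The sketch below rebuilds the three blocks as a single `nn.Sequential` named `features` (a stand-in assumed to mirror the layer sizes in `CNNModel`, not the class itself) and feeds it random data.

```python
import torch
import torch.nn as nn

# Stand-in for the three conv blocks (assumed to mirror the layer sizes used in CNNModel)
features = nn.Sequential(
    nn.Conv2d(3, 6, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(6, 12, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(12, 24, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)

dummy_batch = torch.randn(8, 3, 32, 32)  # 8 random "images", channels first
with torch.no_grad():
    feature_maps = features(dummy_batch)

print(feature_maps.shape)                    # torch.Size([8, 24, 4, 4])
print(torch.flatten(feature_maps, 1).shape)  # torch.Size([8, 384])
```

The flattened width of 384 matches the `24*4*4` input size of the linear layer.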

In [None]:
# Initialize the model, loss function, and optimizer
model = CNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 10
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean loss over all batches, not just the last batch
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")
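
The training loop passes the model's outputs straight to `nn.CrossEntropyLoss`. That loss applies log-softmax internally, so it expects raw logits; feeding it probabilities that have already been through softmax squashes the loss and weakens the gradients. A small sketch with made-up logits for a 3-class problem illustrates the difference.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Hypothetical logits: a confident, correct prediction for class 0
logits = torch.tensor([[10.0, 0.0, 0.0]])
target = torch.tensor([0])

loss_from_logits = criterion(logits, target)
# Softmax before the loss means softmax is effectively applied twice
loss_from_probs = criterion(torch.softmax(logits, dim=-1), target)

print(f"{loss_from_logits.item():.4f}")  # close to 0
print(f"{loss_from_probs.item():.4f}")   # stays well above 0
```

Even for a near-perfect prediction, the loss computed from probabilities cannot approach zero, because the inputs to the internal softmax are confined to `[0, 1]`.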

In [None]:
# Evaluate on the test set
model.eval()  # switch to inference mode (matters once dropout or batch norm is added)
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")
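
The accuracy computation relies on `torch.max(outputs, 1)` returning a `(values, indices)` pair, where the indices are the predicted classes. The sketch below repeats the same steps on a tiny hand-written batch (the logits and labels are invented for illustration) and uses `argmax`, which returns those indices directly.

```python
import torch

# Hypothetical logits for 3 samples and 3 classes
logits = torch.tensor([
    [2.0, 0.1, -1.0],
    [0.3, 0.2, 1.5],
    [0.0, 3.0, 0.5],
])
labels = torch.tensor([0, 2, 0])

predicted = logits.argmax(dim=1)               # index of the largest logit per row
correct = (predicted == labels).sum().item()   # number of matching predictions
accuracy = 100 * correct / labels.size(0)

print(predicted.tolist(), f"{accuracy:.2f}%")  # [0, 2, 1] 66.67%
```

Because softmax preserves the ordering of the logits, taking the argmax of raw logits or of probabilities gives the same predictions.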

In [None]:
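# Baseline: evaluate an untrained, randomly initialized model on the test set.
# With 10 balanced classes, its accuracy should be near chance (about 10%).
baseline_model = CNNModel()
baseline_model.eval()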
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = baseline_model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Baseline Test Accuracy: {100 * correct / total:.2f}%")