# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Run on the GPU when one is present, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training hyperparameters
num_epochs, batch_size, learning_rate = 5, 4, 0.001

# Preprocessing: ToTensor scales pixels to [0, 1], then each of the three
# channels is mapped through (x - 0.5) / 0.5, giving values in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# CIFAR-10 train and test splits, downloaded into ./data if not already present
train_data = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
test_data = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)

# Only the training stream is shuffled; test order stays fixed
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False,
)

# Human-readable names, indexed by CIFAR-10 class id
class_labels = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Function to display images
def show_images(img):
    """Display a normalized channel-first image tensor with matplotlib.

    The dataset transform normalizes each channel with mean 0.5 and std 0.5,
    so the inverse mapping applied here is ``x * 0.5 + 0.5``.

    Args:
        img: CPU tensor laid out as (channels, height, width).
    """
    denormalized = img * 0.5 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    # matplotlib expects (height, width, channels)
    hwc_pixels = denormalized.numpy().transpose(1, 2, 0)
    plt.imshow(hwc_pixels)
    plt.show()

# Pull a single shuffled mini-batch from the training loader for a preview
data_iter = iter(train_loader)
sample_images, sample_labels = next(data_iter)

# Tile the mini-batch into one image grid and render it
show_images(
    torchvision.utils.make_grid(sample_images)
)

# Define Convolutional Neural Network (CNN)
class CNNModel(nn.Module):
    """LeNet-style convolutional network for 3-channel 32x32 images (CIFAR-10).

    Two stages of 5x5 convolution + ReLU + 2x2 max-pooling are followed by
    three fully connected layers. For a 32x32 input the feature map entering
    ``fc1`` is 16 x 5 x 5.

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10).
            Parameter names are unchanged, so existing state_dicts still load.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape (batch, num_classes).

        Expects ``x`` shaped (batch, 3, 32, 32); other spatial sizes raise a
        RuntimeError in ``fc1`` rather than being silently reshaped.
        """
        x = self.pool(F.relu(self.conv1(x)))  # (B, 6, 14, 14) for 32x32 input
        x = self.pool(F.relu(self.conv2(x)))  # (B, 16, 5, 5)
        # Flatten everything except the batch dimension. Unlike view(-1, 400),
        # a mis-sized input cannot be folded into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))               # First fully connected layer
        x = F.relu(self.fc2(x))               # Second fully connected layer
        x = self.fc3(x)                       # Output layer (logits)
        return x

# Initialize model, loss function, and optimizer
model = CNNModel().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Training the model
log_every = 2000  # report the mean loss over this many mini-batches
model.train()  # explicit training mode (matters if dropout/batch-norm are ever added)
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        predictions = model(images)
        loss = loss_function(predictions, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate as a Python float so no autograd graph is kept alive
        running_loss += loss.item()

        # Print the mean loss of the last `log_every` steps; a single
        # mini-batch of batch_size samples is too noisy to track progress.
        if (batch_idx + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss / log_every:.4f}')
            running_loss = 0.0

print('Training complete.')

# Save the trained model
model_path = './cnn_model.pth'
torch.save(model.state_dict(), model_path)

# Evaluate the model on test data
model.eval()
num_classes = len(class_labels)
with torch.no_grad():
    total_correct = 0
    total_samples = 0
    correct_per_class = [0] * num_classes
    samples_per_class = [0] * num_classes

    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)

        # Total correct predictions
        total_samples += labels.size(0)
        total_correct += (predictions == labels).sum().item()

        # Per-class accuracy. Iterate over the actual batch contents rather than
        # range(batch_size): the final batch can be smaller than batch_size.
        # tolist() moves each batch to Python ints once, instead of indexing
        # lists with 0-d tensors (one device sync per element).
        for label, pred in zip(labels.tolist(), predictions.tolist()):
            samples_per_class[label] += 1
            if label == pred:
                correct_per_class[label] += 1

    # Overall accuracy (guarded against an empty test set)
    if total_samples:
        overall_accuracy = 100.0 * total_correct / total_samples
        print(f'Overall model accuracy: {overall_accuracy:.2f}%')
    else:
        print('Overall model accuracy: N/A (no test samples)')

    # Per-class accuracy (guarded against classes absent from the test set)
    for i, class_name in enumerate(class_labels):
        if samples_per_class[i] == 0:
            print(f'Accuracy for {class_name}: N/A (no samples)')
            continue
        class_accuracy = 100.0 * correct_per_class[i] / samples_per_class[i]
        print(f'Accuracy for {class_name}: {class_accuracy:.2f}%')
