In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Training-time preprocessing: random flips/crops for data augmentation,
# followed by tensor conversion and per-channel normalization.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

# Evaluation-time preprocessing: same tensor conversion and normalization as
# training, but WITHOUT random augmentation, so test accuracy is deterministic
# and measured on unmodified images.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize to [-1, 1]
])

# Load CIFAR-10 dataset
batch_size = 64
trainset_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset_cifar10 = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
trainloader_cifar10 = DataLoader(trainset_cifar10, batch_size=batch_size, shuffle=True)
testloader_cifar10 = DataLoader(testset_cifar10, batch_size=batch_size, shuffle=False)

# Classes in CIFAR-10
classes_cifar10 = trainset_cifar10.classes
print(f"Classes: {classes_cifar10}")

# Define the CNN model
class CNN(nn.Module):
    """Small convolutional classifier for 3-channel 32x32 images.

    Architecture: two 3x3 convolutions (32 then 64 channels), each followed
    by ReLU and 2x2 max-pooling, then three fully connected layers
    (4096 -> 256 -> 128 -> ``num_classes``). Dropout (p=0.5) is applied to
    the activated output of the second fully connected layer.

    Args:
        num_classes: Width of the final layer, i.e. the number of logits
            produced per sample. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Two 2x2 poolings shrink a 32x32 input to 8x8 with 64 channels.
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``.

        Args:
            x: Image batch of shape ``(batch, 3, 32, 32)``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to the
                4096 features ``fc1`` expects.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 4096)``, this never folds extra features into the batch
        # axis, so wrongly sized inputs fail loudly at ``fc1``.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.dropout(torch.relu(self.fc2(x)))
        return self.fc3(x)

# Pick the compute device (GPU when available), then build the network,
# the classification loss, and the Adam optimizer over all model weights.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = CNN()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training function
def train_model(model, trainloader, criterion, optimizer, epochs=10):
    """Train ``model`` for ``epochs`` passes over ``trainloader``.

    Each batch is moved to the device the model's parameters live on, so the
    function does not depend on any module-level ``device`` variable.

    Args:
        model: The ``nn.Module`` to optimize (updated in place).
        trainloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of full passes over ``trainloader``.

    Returns:
        list[float]: The mean per-batch loss of every epoch, in order.

    Raises:
        ValueError: If ``trainloader`` yields no batches during an epoch.
    """
    # Follow the model's own device; a parameter-less model leaves the
    # batches wherever the loader produced them.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    epoch_losses = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0  # counted manually so loaders without __len__ work
        for inputs, labels in trainloader:
            if target_device is not None:
                inputs, labels = inputs.to(target_device), labels.to(target_device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            raise ValueError("trainloader yielded no batches; cannot compute an average loss")
        mean_loss = running_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {mean_loss:.4f}")
    return epoch_losses

# Fit the network to the CIFAR-10 training split for ten epochs
train_model(
    model=model,
    trainloader=trainloader_cifar10,
    criterion=criterion,
    optimizer=optimizer,
    epochs=10,
)

# Testing function
def evaluate_model(model, testloader):
    """Return the top-1 accuracy of ``model`` over ``testloader``.

    The model is switched to evaluation mode (and left in it), and gradients
    are disabled for the forward passes. Each batch is moved to the device
    the model's parameters live on, so the function does not depend on any
    module-level ``device`` variable.

    Args:
        model: The ``nn.Module`` to evaluate; must output per-class scores of
            shape ``(batch, num_classes)``.
        testloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        float: Fraction of samples whose highest-scoring class equals the label.

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    # Follow the model's own device; a parameter-less model leaves the
    # batches wherever the loader produced them.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            if target_device is not None:
                inputs, labels = inputs.to(target_device), labels.to(target_device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("testloader yielded no samples; accuracy is undefined")
    return correct / total

# Score the CIFAR-10 model on the CIFAR-10 test loader and report it as a percentage
accuracy_cifar10 = evaluate_model(model=model, testloader=testloader_cifar10)
print(f"CIFAR-10 Test Accuracy: {accuracy_cifar10 * 100:.2f}%")

# Evaluation-time preprocessing for CIFAR-100: tensor conversion and the same
# normalization as training, but no random flips/crops, so the reported test
# accuracy is deterministic and measured on unmodified images.
test_transform_cifar100 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize to [-1, 1]
])

# Load CIFAR-100 dataset (augmented transform for training only)
trainset_cifar100 = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
testset_cifar100 = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform_cifar100)
trainloader_cifar100 = DataLoader(trainset_cifar100, batch_size=batch_size, shuffle=True)
testloader_cifar100 = DataLoader(testset_cifar100, batch_size=batch_size, shuffle=False)

# Classes in CIFAR-100
classes_cifar100 = trainset_cifar100.classes
print(f"Number of classes in CIFAR-100: {len(classes_cifar100)}")

# Replace the final layer with a freshly initialized head sized for CIFAR-100.
# The input width is read from the existing head and the output width from the
# dataset's class list, so neither number is hard-coded here.
model.fc3 = nn.Linear(model.fc3.in_features, len(classes_cifar100)).to(device)

# Reinitialize the optimizer so it tracks the new head's parameters. All
# layers are passed in, so the CIFAR-10-trained layers are fine-tuned too.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model on CIFAR-100
train_model(model, trainloader_cifar100, criterion, optimizer, epochs=10)

# Evaluate CIFAR-100 model
accuracy_cifar100 = evaluate_model(model, testloader_cifar100)
print(f"CIFAR-100 Test Accuracy: {accuracy_cifar100 * 100:.2f}%")

# Visualize some predictions
def visualize_predictions(model, testloader, classes, true_classes=None):
    """Plot up to eight images from the first batch with predicted and true labels.

    Args:
        model: The ``nn.Module`` used to predict; switched to evaluation mode.
        testloader: Iterable yielding ``(inputs, labels)`` batches; only the
            first batch is used.
        classes: Sequence of class names indexed by the model's predicted
            class index.
        true_classes: Sequence of class names indexed by the dataset's label
            index. Defaults to ``classes``; pass it explicitly when the
            model's label space differs from the dataset's.
    """
    if true_classes is None:
        true_classes = classes

    model.eval()
    inputs, labels = next(iter(testloader))

    # Run the forward pass on the model's own device, without building an
    # autograd graph (nothing is back-propagated here).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        inputs = inputs.to(first_param.device)
    with torch.no_grad():
        predicted = model(inputs).argmax(dim=1)

    # Display images and predictions
    images = inputs.cpu().numpy()
    num_shown = min(8, len(images))  # the batch may hold fewer than 8 images
    plt.figure(figsize=(12, 6))
    for i in range(num_shown):
        plt.subplot(2, 4, i + 1)
        # Move channels last for imshow and denormalize; this assumes inputs
        # were normalized with mean 0.5 and std 0.5 per channel.
        image = np.transpose(images[i], (1, 2, 0)) / 2 + 0.5
        # Clip to imshow's valid float range to absorb rounding error.
        plt.imshow(np.clip(image, 0.0, 1.0))
        predicted_name = classes[predicted[i].item()]
        true_name = true_classes[labels[i].item()]
        plt.title(f"Predicted: {predicted_name}\nTrue: {true_name}")
        plt.axis('off')
    plt.show()

# Visualize predictions on CIFAR-10 images. The model's final layer was
# replaced with a CIFAR-100 head earlier in this script, so its predictions are
# CIFAR-100 class indices while the dataset labels are CIFAR-10 indices; each
# is therefore named from its own class list.
visualize_predictions(model, testloader_cifar10, classes_cifar100, true_classes=classes_cifar10)

# Visualize predictions for CIFAR-100
visualize_predictions(model, testloader_cifar100, classes_cifar100)
