# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Pick the compute device: CUDA when torch reports it, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# CIFAR-10 input pipeline: convert to tensors, then normalise every channel
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Training split and held-out split, both stored under ./data.
train_data = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
val_data = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Only the training batches are shuffled; validation order stays fixed.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_data,
    batch_size=64,
    shuffle=False,
)

# Updated CNN Model
class UpdatedCNN(nn.Module):
    """Four-stage convolutional classifier that outputs 10 class logits.

    Every stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, 2)``; the channel count grows 3 -> 32 -> 64 -> 128 -> 256
    while each pooling step halves the spatial size. ``fc1`` is sized for a
    2x2 final feature map, which is what a 32x32 input produces after the
    four pooling steps.
    """

    def __init__(self):
        super().__init__()
        # Convolutional stages, each followed by batch normalization
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Stateless layers shared by all stages. ReLU holds no parameters or
        # buffers, so registering it does not change the state_dict keys.
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.4)

        # Fully connected layers
        self.fc1 = nn.Linear(256 * 2 * 2, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for image batch ``x``.

        Args:
            x: Tensor of shape ``(batch, 3, H, W)``. ``H`` and ``W`` must
                reduce to 2x2 after four 2x2 poolings (e.g. 32x32).

        Raises:
            RuntimeError: If the flattened feature size does not match
                ``fc1`` (raised by the linear layer).
        """
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = self.pool(self.relu(self.bn3(self.conv3(x))))
        x = self.pool(self.relu(self.bn4(self.conv4(x))))
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 1024), this never folds extra features into the batch
        # axis when the input size is wrong; fc1 raises instead.
        x = x.flatten(start_dim=1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)

# Build the network and move it to the selected device before the optimizer
# captures its parameters; then create the loss and the Adam optimizer.
model = UpdatedCNN()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training and Evaluation Functions
def train_one_epoch(dataloader, model, loss_fn, optimizer):
    """Run one optimization pass over ``dataloader`` and report its metrics.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        model: Network to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            loss. The epoch loss assumes it is averaged over the batch
            (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer that updates the model parameters.

    Returns:
        Tuple ``(accuracy, mean_loss)`` computed over the samples that were
        actually processed during this pass.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for images, labels in tqdm(dataloader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a smaller final batch does not
        # skew the epoch mean, and count the samples really seen so the
        # denominators stay right with drop_last or a custom sampler.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size
    return correct / seen, loss_sum / seen

def evaluate(dataloader, model, loss_fn):
    """Measure accuracy and mean loss of ``model`` without updating it.

    Batches are moved to the module-level ``device``; gradients are disabled
    for the whole pass.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches.
        model: Network to evaluate; it is switched to evaluation mode.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            loss. The reported loss assumes it is averaged over the batch
            (the default reduction of ``nn.CrossEntropyLoss``).

    Returns:
        Tuple ``(accuracy, mean_loss)`` computed over the samples that were
        actually processed.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validation", leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            # Weight by batch size and count processed samples so a smaller
            # final batch, drop_last, or a sampler cannot skew the averages.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size
    return correct / seen, loss_sum / seen

# Training Loop: one training pass and one validation pass per epoch, with
# the four metrics stored for the plots that follow.
epochs = 30
training_losses, training_accuracies = [], []
validation_losses, validation_accuracies = [], []

for epoch in range(1, epochs + 1):
    train_acc, train_loss = train_one_epoch(train_dataloader, model, loss_fn, optimizer)
    val_acc, val_loss = evaluate(val_dataloader, model, loss_fn)

    training_accuracies.append(train_acc)
    training_losses.append(train_loss)
    validation_accuracies.append(val_acc)
    validation_losses.append(val_loss)

    print(
        f"Epoch {epoch}/{epochs} -> "
        f"Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.3f}, "
        f"Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.3f}"
    )

# Plotting Results: loss curves in the left panel, accuracy curves in the
# right panel, both against the 1-based epoch number.
epoch_axis = range(1, epochs + 1)
plt.figure(figsize=(12, 6))

# Left panel: loss
plt.subplot(1, 2, 1)
plt.title("Loss Over Epochs")
plt.plot(epoch_axis, training_losses, label="Training Loss")
plt.plot(epoch_axis, validation_losses, label="Validation Loss")
plt.legend()

# Right panel: accuracy
plt.subplot(1, 2, 2)
plt.title("Accuracy Over Epochs")
plt.plot(epoch_axis, training_accuracies, label="Training Accuracy")
plt.plot(epoch_axis, validation_accuracies, label="Validation Accuracy")
plt.legend()

plt.show()
