In [None]:


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

# CIFAR-10 Dataset
# Per-channel (x - 0.5) / 0.5, i.e. values in [0, 1] are mapped to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)


def _cifar10_split(train):
    """Load (downloading if needed) one CIFAR-10 split under ./data."""
    return datasets.CIFAR10(root="./data", train=train, download=True, transform=transform)


train_data = _cifar10_split(train=True)
val_data = _cifar10_split(train=False)
# Only the training loader is shuffled; validation order stays fixed.
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=64, shuffle=False)

# Modified CNN Model
class ModifiedCNN(nn.Module):
    """Three-stage convolutional classifier for 3-channel 32x32 images.

    Each stage is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool, so a 32x32
    input shrinks to a 4x4 map after three stages, which is what the
    ``256 * 4 * 4`` input size of ``fc1`` is sized for.  ``forward`` returns
    raw logits for 10 classes (no softmax); pair it with a loss that expects
    logits, such as ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        # 32x32 input -> three 2x2 pools -> 4x4 spatial map with 256 channels.
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) image tensor to (batch, 10) logits."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        # Flatten per sample (keep the batch dim).  Unlike view(-1, 4096), a
        # wrongly sized input now fails in fc1 instead of being silently
        # re-split into extra "samples".
        x = torch.flatten(x, 1)
        x = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(x)

model = ModifiedCNN()

# Loss Function and Optimizer
# CrossEntropyLoss consumes the raw logits ModifiedCNN.forward returns; Adam
# updates every model parameter with a 1e-3 learning rate.
loss_fn, optimizer = (
    nn.CrossEntropyLoss(),
    optim.Adam(params=model.parameters(), lr=1e-3),
)

# Training and Validation Functions
def train_one_epoch(dataloader, model, loss_fn, optimizer, progress_bar):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: iterable of ``(images, labels)`` batches.
        model: module to train; it is switched to training mode first.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        optimizer: optimizer over the parameters to update.
        progress_bar: object with an ``update(n)`` method (e.g. tqdm),
            advanced by one per batch.

    Returns:
        The mean of the per-batch losses seen during the pass (each measured
        before that batch's optimizer step), or ``nan`` if the dataloader
        yielded no batches.  Callers that ignore the return value are
        unaffected.
    """
    model.train()
    # Send each batch to wherever the model's parameters live so the same
    # function works unchanged on CPU or GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    total_loss, n_batches = 0.0, 0
    for images, labels in dataloader:
        if device is not None:
            images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
        progress_bar.update(1)
    return total_loss / n_batches if n_batches else float("nan")

def evaluate(dataloader, model, loss_fn, progress_bar):
    """Score ``model`` on every batch of ``dataloader`` without gradients.

    Args:
        dataloader: iterable of ``(images, labels)`` batches.
        model: module to evaluate; it is switched to eval mode (and left
            there).
        loss_fn: callable ``loss_fn(outputs, labels)`` returning the *mean*
            loss over the batch (the default ``reduction="mean"``); it is
            re-weighted by batch size here.
        progress_bar: object with an ``update(n)`` method (e.g. tqdm),
            advanced by one per batch.

    Returns:
        ``(accuracy, mean_loss)``: the fraction of samples whose arg-max
        output equals the label, and the loss averaged per sample.  Both are
        ``nan`` if the dataloader yielded no samples.
    """
    model.eval()
    # Evaluate on the model's own device so this works on CPU or GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    total_loss, correct, n_samples = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # loss_fn averages within the batch; weight by batch size so a
            # short final batch does not count as much as a full one.
            total_loss += loss_fn(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_samples += batch_size
            progress_bar.update(1)
    # Count samples actually seen rather than len(dataset), which is wrong
    # under custom samplers or drop_last.
    if n_samples == 0:
        return float("nan"), float("nan")
    return correct / n_samples, total_loss / n_samples

# Training Loop
epochs = 15
training_losses, validation_losses = [], []
training_accuracies, validation_accuracies = [], []


def _epoch_bar(loader, desc):
    """Return a tqdm bar sized to one full pass over ``loader``."""
    return tqdm(total=len(loader), position=0, leave=True, desc=desc)


for j in range(epochs):
    with _epoch_bar(train_dataloader, f"Train Epoch {j}") as train_bar:
        train_one_epoch(train_dataloader, model, loss_fn, optimizer, train_bar)

    # Re-scoring the whole training set costs a full extra pass, so it is
    # done only on every 5th epoch (0, 5, 10, ...).
    if not j % 5:
        with _epoch_bar(train_dataloader, f"Validate (train) Epoch {j}") as train_eval:
            acc, loss = evaluate(train_dataloader, model, loss_fn, train_eval)
            print(f"Epoch {j}: training loss: {loss:.3f}, accuracy: {acc:.3f}")
            training_accuracies.append(acc)
            training_losses.append(loss)

    # The held-out split is scored after every epoch.
    with _epoch_bar(val_dataloader, f"Validate Epoch {j}") as val_bar:
        acc_val, loss_val = evaluate(val_dataloader, model, loss_fn, val_bar)
        print(f"Epoch {j}: validation loss: {loss_val:.3f}, accuracy: {acc_val:.3f}")
        validation_accuracies.append(acc_val)
        validation_losses.append(loss_val)

# Plotting Results
import matplotlib.pyplot as plt

epochs_range = range(epochs)
# The training set is re-scored only every 5th epoch in the training loop,
# so its curve has points at epochs 0, 5, 10, ... only.
train_epochs = epochs_range[::5]

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 6))

# Markers make the sparse training series readable next to the dense one.
loss_ax.plot(train_epochs, training_losses, marker="o", label="Training Loss")
loss_ax.plot(epochs_range, validation_losses, label="Validation Loss")
loss_ax.set_title("Loss Over Epochs")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

acc_ax.plot(train_epochs, training_accuracies, marker="o", label="Training Accuracy")
acc_ax.plot(epochs_range, validation_accuracies, label="Validation Accuracy")
acc_ax.set_title("Accuracy Over Epochs")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.legend()

fig.tight_layout()
plt.show()


Files already downloaded and verified
Files already downloaded and verified


Train Epoch 0:   0%|          | 0/782 [00:00<?, ?it/s]

Validate (train) Epoch 0:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 0: training loss: 0.958, accuracy: 0.657


Validate Epoch 0:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 0: validation loss: 1.007, accuracy: 0.638


Train Epoch 1:   0%|          | 0/782 [00:00<?, ?it/s]

Validate Epoch 1:   0%|          | 0/157 [00:00<?, ?it/s]

Epoch 1: validation loss: 0.796, accuracy: 0.722


Train Epoch 2:   0%|          | 0/782 [00:00<?, ?it/s]