In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from google.colab import drive

In [2]:
# Google Drive mount point; the training cell writes its best checkpoint under MyDrive/ here.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [3]:
# Define the CNN model for CIFAR-10 classification
class CNNClassifier(nn.Module):
    """Small CNN: three conv/ReLU/max-pool stages followed by a two-layer MLP head.

    Every stage is ``Conv2d(k=3, stride=1, padding=1) -> ReLU -> MaxPool2d(2, 2)``.
    The convolutions keep the spatial size and each pool floor-halves it, so a
    ``(B, 3, S, S)`` batch becomes ``(B, 128, S // 8, S // 8)`` before flattening.

    Args:
        embed_dim: width of the hidden fully connected layer.
        num_classes: number of output logits (10 for CIFAR-10).
        input_size: side length ``S`` of the square input images. The default of
            224 reproduces the previously hard-coded ``128 * 28 * 28`` flatten
            size, so the parameter names and shapes are unchanged. Pass
            ``input_size=32`` to feed native-resolution CIFAR-10 images
            without upsampling them.

    Raises:
        ValueError: if ``input_size`` is smaller than 8, because nothing is left
            after three 2x poolings.
    """

    def __init__(self, embed_dim=512, num_classes=10, input_size=224):
        super().__init__()
        # Spatial side length after three 2x2 max-pools.
        # floor(floor(floor(S/2)/2)/2) == S // 8.
        feature_size = input_size // 8
        if feature_size < 1:
            raise ValueError(f"input_size must be >= 8, got {input_size}")

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),    # (B, 3, S, S) -> (B, 32, S, S)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                                      # -> (B, 32, S/2, S/2)

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),   # -> (B, 64, S/2, S/2)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                                      # -> (B, 64, S/4, S/4)

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # -> (B, 128, S/4, S/4)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                                      # -> (B, 128, S/8, S/8)
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * feature_size * feature_size, embed_dim),  # embedding layer
            nn.ReLU(),
            nn.Linear(embed_dim, num_classes),                        # classification head
        )

    def forward(self, x):
        """Map a ``(B, 3, input_size, input_size)`` batch to ``(B, num_classes)`` logits."""
        x = self.conv_layers(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten (C, H, W)
        return self.fc(x)


In [4]:
# Training function for CNN classifier
def train_cnn_classifier(model, train_loader, val_loader, epochs=100, lr=0.001,
                         best_model_path='/content/drive/MyDrive/best_standard_cnn_classifier_model.pth',
                         device=None, patience=None):
    """Train ``model`` with Adam + cross-entropy, checkpointing the best epoch.

    After every training epoch the model is evaluated on ``val_loader``.
    Whenever the per-sample mean validation loss improves, ``model.state_dict()``
    is saved to ``best_model_path``.

    Args:
        model: classifier producing ``(B, num_classes)`` logits. It is assumed to
            live on a single device.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        epochs: maximum number of passes over ``train_loader``.
        lr: Adam learning rate.
        best_model_path: file the best-so-far state dict is written to.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, so no notebook-global ``device`` is needed.
        patience: if set, stop after this many consecutive epochs without a
            validation-loss improvement. ``None`` (default) disables early
            stopping and always runs ``epochs`` epochs.

    Raises:
        ValueError: if either loader yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    best_loss = float('inf')
    epochs_since_improvement = 0

    for epoch in range(epochs):
        model.train()
        avg_loss, accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

        model.eval()
        with torch.no_grad():
            val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)
        print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

        if val_loss < best_loss:
            best_loss = val_loss
            epochs_since_improvement = 0
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved at epoch {epoch + 1} with validation loss {best_loss:.4f}")
        else:
            epochs_since_improvement += 1
            if patience is not None and epochs_since_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1}: "
                      f"no validation improvement for {patience} epoch(s)")
                break

    print("CNN classifier training complete!")


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; return (per-sample mean loss, accuracy in %).

    When ``optimizer`` is given, a backward pass and optimizer step are taken
    for every batch. Otherwise the pass only evaluates, and the caller is
    responsible for ``model.eval()`` / ``torch.no_grad()``.
    """
    total_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_size = labels.size(0)
        # The criterion (default reduction) returns the batch mean. Re-weight by
        # batch size so a short final batch does not count as much as a full one.
        total_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / total, 100 * correct / total


In [5]:

# Preprocessing shared by the CIFAR-10 train and validation splits.
# CIFAR-10 images are 32x32. They are upsampled to 224x224 here because
# CNNClassifier's default flatten size (128 * 28 * 28) assumes 224x224 inputs.
# Normalize maps each ToTensor channel from [0, 1] to [-1, 1].
resize_step = transforms.Resize((224, 224))
normalize_step = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([resize_step, transforms.ToTensor(), normalize_step])

In [6]:

DATA_ROOT = './data'
BATCH_SIZE = 64

# CIFAR-10 ships only train/test splits. The official test split (train=False)
# serves as the validation set here, so it is also what picks the "best" model.
trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
valset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13145206.54it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [7]:

# Initialize the CNN classifier
SEED = 42
# Seed torch's global RNG (CPU and all CUDA devices) before building the model so
# weight initialisation is reproducible. The DataLoaders were created without an
# explicit generator, so their shuffle order also draws from this global RNG
# when iteration starts. cuDNN kernels may still be nondeterministic on GPU.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cnn_classifier = CNNClassifier().to(device)

In [8]:
# Train the CNN classifier on CIFAR-10: up to 100 epochs, Adam with lr = 1e-3.
# In the recorded run below, validation loss bottomed out at epoch 3 (1.0529).
# After that it rose while training accuracy passed 97%, i.e. the model
# overfits. The run was interrupted manually during epoch 9.
train_cnn_classifier(
    cnn_classifier,
    train_loader,
    val_loader,
    epochs=100,
    lr=1e-3,
)

Epoch 1/100, Loss: 1.4578, Accuracy: 47.87%
Validation Loss: 1.2105, Validation Accuracy: 56.23%
Best model saved at epoch 1 with validation loss 1.2105
Epoch 2/100, Loss: 1.0308, Accuracy: 63.47%
Validation Loss: 1.0968, Validation Accuracy: 60.61%
Best model saved at epoch 2 with validation loss 1.0968
Epoch 3/100, Loss: 0.7189, Accuracy: 74.67%
Validation Loss: 1.0529, Validation Accuracy: 65.17%
Best model saved at epoch 3 with validation loss 1.0529
Epoch 4/100, Loss: 0.3726, Accuracy: 87.06%
Validation Loss: 1.3173, Validation Accuracy: 64.23%
Epoch 5/100, Loss: 0.1539, Accuracy: 94.77%
Validation Loss: 1.7391, Validation Accuracy: 64.03%
Epoch 6/100, Loss: 0.0985, Accuracy: 96.65%
Validation Loss: 2.0545, Validation Accuracy: 64.40%
Epoch 7/100, Loss: 0.0754, Accuracy: 97.47%
Validation Loss: 2.2525, Validation Accuracy: 64.55%
Epoch 8/100, Loss: 0.0647, Accuracy: 97.81%
Validation Loss: 2.5047, Validation Accuracy: 63.01%
Epoch 9/100, Loss: 0.0675, Accuracy: 97.76%
Validation L

KeyboardInterrupt: 