In [12]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torchvision.models import alexnet
from tqdm import tqdm

# Seed torch's global RNG so weight initialisation, DataLoader shuffling and the
# random augmentations repeat across runs (some cuDNN kernels may still be
# nondeterministic on GPU).
SEED = 42
torch.manual_seed(SEED)

# Set device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [13]:

# Load CIFAR-10 dataset
# CIFAR-10 per-channel normalization constants, shared by both pipelines so the
# train and test normalization cannot drift apart.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
DATA_ROOT = './data'
BATCH_SIZE = 64
NUM_WORKERS = 2
# Pinned host memory speeds up host->GPU copies; it only helps when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()

transform_train = transforms.Compose([
    transforms.Resize(224),  # images are square, so this yields 224x224 inputs for AlexNet
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=16),  # random shift of up to 16 px at the 224-px scale
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

transform_test = transforms.Compose([
    transforms.Resize(224),  # no augmentation at test time, only resize + normalize
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])


trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
)


Files already downloaded and verified
Files already downloaded and verified


In [14]:
# Build an untrained AlexNet with a 10-way output layer for CIFAR-10.
# Passing num_classes builds the correct head directly instead of creating a
# 1000-class head and replacing classifier[6]. It also avoids the `pretrained=`
# flag, which torchvision deprecated in 0.13; untrained weights are the default
# in both the old and new APIs.
num_classes = 10
LEARNING_RATE = 1e-3
model = alexnet(num_classes=num_classes).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)




In [15]:
# Training / evaluation helpers
def _notebook_default(name, value):
    """Return ``value``, or the notebook-global called ``name`` when ``value`` is None.

    Lets `train` and `evaluate` keep their original call signatures, which
    relied on the globals ``optimizer``, ``criterion`` and ``device``, while
    also accepting those objects explicitly.
    """
    return globals()[name] if value is None else value


def train(model, trainloader, testloader, epochs=10, *, optimizer=None,
          criterion=None, device=None, progress=True):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``testloader`` after each.

    Args:
        model: network to optimize; it must already live on ``device``.
        trainloader: iterable of ``(inputs, labels)`` training batches.
        testloader: loader passed to `evaluate` after every epoch.
        epochs: number of passes over ``trainloader``.
        optimizer: optimizer over ``model``'s parameters
            (default: the notebook-global ``optimizer``).
        criterion: loss function, assumed to average over the batch like the
            default ``nn.CrossEntropyLoss()`` (default: notebook-global ``criterion``).
        device: device each batch is moved to (default: notebook-global ``device``).
        progress: wrap the training batches in a tqdm progress bar.

    Returns:
        A list with one dict per epoch, holding the keys ``"epoch"`` (1-based),
        ``"train_loss"``, ``"test_loss"`` and ``"test_accuracy"`` (percent).
    """
    optimizer = _notebook_default("optimizer", optimizer)
    criterion = _notebook_default("criterion", criterion)
    device = _notebook_default("device", device)

    history = []
    for epoch in range(epochs):
        model.train()
        loss_sum, n_seen = 0.0, 0
        batches = tqdm(trainloader, desc=f"Epoch {epoch + 1}/{epochs}") if progress else trainloader
        for inputs, labels in batches:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns the batch *mean*. Weight it by batch size so
            # a smaller final batch does not skew the epoch average.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            n_seen += batch_size

        # Running average over the epoch (weights change while it accumulates).
        epoch_loss = loss_sum / n_seen if n_seen else float("nan")
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")

        # Evaluate on the test set after each epoch
        test_loss, test_accuracy = evaluate(model, testloader, criterion=criterion, device=device)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

        history.append({
            "epoch": epoch + 1,
            "train_loss": epoch_loss,
            "test_loss": test_loss,
            "test_accuracy": test_accuracy,
        })
    return history


# Evaluation function for validation loop
def evaluate(model, testloader, criterion=None, device=None):
    """Compute the mean per-sample loss and the accuracy (percent) of ``model`` on ``testloader``.

    Puts ``model`` in eval mode and runs without gradient tracking.
    ``criterion`` and ``device`` default to the notebook globals of the same
    name, and ``criterion`` is assumed to average over the batch. Both results
    are NaN when the loader yields no samples.
    """
    criterion = _notebook_default("criterion", criterion)
    device = _notebook_default("device", device)

    model.eval()
    loss_sum, correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            # Weight the batch-mean loss by batch size so the ragged last batch counts correctly.
            loss_sum += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / n_seen, 100.0 * correct / n_seen



In [16]:
# Run the full training schedule; `train` prints train/test metrics every epoch.
epochs = 50
train(model, trainloader, testloader, epochs=epochs)

# Persist only the learned weights (state_dict).
# Reload later with: model.load_state_dict(torch.load('alexnet_cifar10.pth'))
torch.save(model.state_dict(), 'alexnet_cifar10.pth')
print("Model checkpoint saved.")


100%|██████████| 782/782 [00:51<00:00, 15.18it/s]

Epoch [1/50], Loss: 1.7563





Test Loss: 1.4225, Test Accuracy: 46.81%


100%|██████████| 782/782 [00:50<00:00, 15.33it/s]

Epoch [2/50], Loss: 1.4799





Test Loss: 1.3179, Test Accuracy: 53.01%


100%|██████████| 782/782 [00:51<00:00, 15.23it/s]

Epoch [3/50], Loss: 1.3559





Test Loss: 1.2370, Test Accuracy: 54.56%


100%|██████████| 782/782 [00:51<00:00, 15.18it/s]

Epoch [4/50], Loss: 1.2735





Test Loss: 1.1759, Test Accuracy: 58.47%


100%|██████████| 782/782 [00:51<00:00, 15.17it/s]


Epoch [5/50], Loss: 1.2310
Test Loss: 1.1468, Test Accuracy: 59.92%


100%|██████████| 782/782 [00:51<00:00, 15.17it/s]

Epoch [6/50], Loss: 1.1845





Test Loss: 1.1119, Test Accuracy: 60.46%


100%|██████████| 782/782 [00:51<00:00, 15.31it/s]

Epoch [7/50], Loss: 1.1463





Test Loss: 1.0416, Test Accuracy: 63.89%


100%|██████████| 782/782 [00:51<00:00, 15.19it/s]


Epoch [8/50], Loss: 1.1241
Test Loss: 0.9557, Test Accuracy: 66.80%


100%|██████████| 782/782 [00:51<00:00, 15.14it/s]


Epoch [9/50], Loss: 1.1021
Test Loss: 0.9800, Test Accuracy: 66.03%


100%|██████████| 782/782 [00:51<00:00, 15.18it/s]

Epoch [10/50], Loss: 1.0836





Test Loss: 0.9799, Test Accuracy: 66.09%


100%|██████████| 782/782 [00:51<00:00, 15.19it/s]

Epoch [11/50], Loss: 1.0670





Test Loss: 0.9315, Test Accuracy: 68.01%


100%|██████████| 782/782 [00:51<00:00, 15.24it/s]

Epoch [12/50], Loss: 1.0546





Test Loss: 0.9419, Test Accuracy: 68.22%


100%|██████████| 782/782 [00:51<00:00, 15.05it/s]

Epoch [13/50], Loss: 1.0397





Test Loss: 0.9553, Test Accuracy: 67.01%


100%|██████████| 782/782 [00:51<00:00, 15.14it/s]


Epoch [14/50], Loss: 1.0246
Test Loss: 0.9441, Test Accuracy: 67.22%


100%|██████████| 782/782 [00:51<00:00, 15.17it/s]

Epoch [15/50], Loss: 1.0115





Test Loss: 0.8750, Test Accuracy: 69.68%


100%|██████████| 782/782 [00:51<00:00, 15.08it/s]

Epoch [16/50], Loss: 0.9983





Test Loss: 0.8711, Test Accuracy: 70.13%


100%|██████████| 782/782 [00:51<00:00, 15.19it/s]

Epoch [17/50], Loss: 0.9935





Test Loss: 0.8869, Test Accuracy: 69.33%


100%|██████████| 782/782 [00:51<00:00, 15.07it/s]

Epoch [18/50], Loss: 0.9768





Test Loss: 0.8708, Test Accuracy: 70.02%


100%|██████████| 782/782 [00:52<00:00, 15.02it/s]

Epoch [19/50], Loss: 0.9660





Test Loss: 0.8907, Test Accuracy: 69.05%


100%|██████████| 782/782 [00:51<00:00, 15.13it/s]

Epoch [20/50], Loss: 0.9524





Test Loss: 0.8731, Test Accuracy: 70.16%


100%|██████████| 782/782 [00:51<00:00, 15.07it/s]


Epoch [21/50], Loss: 0.9434
Test Loss: 0.8712, Test Accuracy: 70.46%


100%|██████████| 782/782 [00:51<00:00, 15.12it/s]

Epoch [22/50], Loss: 0.9402





Test Loss: 0.8408, Test Accuracy: 71.65%


100%|██████████| 782/782 [00:51<00:00, 15.07it/s]

Epoch [23/50], Loss: 0.9337





Test Loss: 0.8238, Test Accuracy: 71.81%


100%|██████████| 782/782 [00:51<00:00, 15.22it/s]

Epoch [24/50], Loss: 0.9164





Test Loss: 0.8346, Test Accuracy: 71.44%


100%|██████████| 782/782 [00:51<00:00, 15.22it/s]

Epoch [25/50], Loss: 0.9167





Test Loss: 0.8041, Test Accuracy: 72.14%


100%|██████████| 782/782 [00:51<00:00, 15.19it/s]

Epoch [26/50], Loss: 0.9065





Test Loss: 0.7918, Test Accuracy: 72.97%


100%|██████████| 782/782 [00:51<00:00, 15.26it/s]

Epoch [27/50], Loss: 0.8884





Test Loss: 0.8218, Test Accuracy: 71.73%


100%|██████████| 782/782 [00:51<00:00, 15.33it/s]

Epoch [28/50], Loss: 0.9009





Test Loss: 0.8016, Test Accuracy: 72.56%


100%|██████████| 782/782 [00:51<00:00, 15.26it/s]

Epoch [29/50], Loss: 0.8955





Test Loss: 0.8299, Test Accuracy: 71.17%


100%|██████████| 782/782 [00:51<00:00, 15.26it/s]

Epoch [30/50], Loss: 0.8758





Test Loss: 0.7812, Test Accuracy: 73.90%


100%|██████████| 782/782 [00:51<00:00, 15.15it/s]

Epoch [31/50], Loss: 0.8751





Test Loss: 0.7819, Test Accuracy: 72.94%


100%|██████████| 782/782 [00:51<00:00, 15.18it/s]

Epoch [32/50], Loss: 0.8644





Test Loss: 0.7643, Test Accuracy: 74.07%


100%|██████████| 782/782 [00:51<00:00, 15.24it/s]

Epoch [33/50], Loss: 0.8718





Test Loss: 0.7858, Test Accuracy: 73.65%


100%|██████████| 782/782 [00:51<00:00, 15.27it/s]

Epoch [34/50], Loss: 0.8601





Test Loss: 0.7958, Test Accuracy: 72.99%


100%|██████████| 782/782 [00:51<00:00, 15.25it/s]


Epoch [35/50], Loss: 0.8557
Test Loss: 0.7674, Test Accuracy: 74.58%


100%|██████████| 782/782 [00:51<00:00, 15.32it/s]

Epoch [36/50], Loss: 0.8450





Test Loss: 0.7579, Test Accuracy: 74.39%


100%|██████████| 782/782 [00:51<00:00, 15.20it/s]

Epoch [37/50], Loss: 0.8865





Test Loss: 0.8035, Test Accuracy: 72.97%


100%|██████████| 782/782 [00:51<00:00, 15.16it/s]

Epoch [38/50], Loss: 0.8418





Test Loss: 0.7396, Test Accuracy: 75.15%


100%|██████████| 782/782 [00:51<00:00, 15.13it/s]

Epoch [39/50], Loss: 0.8313





Test Loss: 0.7609, Test Accuracy: 74.75%


100%|██████████| 782/782 [00:51<00:00, 15.12it/s]

Epoch [40/50], Loss: 0.8299





Test Loss: 0.8620, Test Accuracy: 70.80%


100%|██████████| 782/782 [00:51<00:00, 15.22it/s]

Epoch [41/50], Loss: 0.8277





Test Loss: 0.7540, Test Accuracy: 74.59%


100%|██████████| 782/782 [00:51<00:00, 15.05it/s]

Epoch [42/50], Loss: 0.8221





Test Loss: 0.7859, Test Accuracy: 73.53%


100%|██████████| 782/782 [00:51<00:00, 15.19it/s]

Epoch [43/50], Loss: 0.8174





Test Loss: 0.7661, Test Accuracy: 73.97%


100%|██████████| 782/782 [00:51<00:00, 15.26it/s]

Epoch [44/50], Loss: 0.8134





Test Loss: 0.7270, Test Accuracy: 75.97%


100%|██████████| 782/782 [00:51<00:00, 15.24it/s]

Epoch [45/50], Loss: 0.8192





Test Loss: 0.7314, Test Accuracy: 75.28%


100%|██████████| 782/782 [00:51<00:00, 15.31it/s]

Epoch [46/50], Loss: 0.8156





Test Loss: 0.7273, Test Accuracy: 75.23%


100%|██████████| 782/782 [00:51<00:00, 15.26it/s]

Epoch [47/50], Loss: 0.8035





Test Loss: 0.7204, Test Accuracy: 76.23%


100%|██████████| 782/782 [00:51<00:00, 15.27it/s]

Epoch [48/50], Loss: 0.7974





Test Loss: 0.7227, Test Accuracy: 75.75%


100%|██████████| 782/782 [00:51<00:00, 15.09it/s]

Epoch [49/50], Loss: 0.7926





Test Loss: 0.7499, Test Accuracy: 74.37%


100%|██████████| 782/782 [00:51<00:00, 15.27it/s]

Epoch [50/50], Loss: 0.8008





Test Loss: 0.7191, Test Accuracy: 76.13%
Model checkpoint saved.
