In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.init as init
from torchsummary import summary

# Data

In [None]:
# Single download/extract location shared by BOTH splits. The training split
# previously pointed at a misspelled directory ('Data_CIFAF10'), which made
# torchvision fetch and unpack the CIFAR-10 archive a second time.
DATA_ROOT = 'Data_CIFAR10'
BATCH_SIZE = 1024
NUM_WORKERS = 10

# ToTensor yields float tensors in [0, 1]; normalizing every one of the three
# RGB channels with mean 0.5 / std 0.5 maps them to [-1, 1]. The per-channel
# values are spelled out instead of relying on a 1-element tuple broadcasting.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT,
                                        train=True,
                                        download=True,
                                        transform=transform)
# drop_last=True discards the final incomplete batch, so every training batch
# holds exactly BATCH_SIZE samples.
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=BATCH_SIZE,
                                          num_workers=NUM_WORKERS,
                                          shuffle=True,
                                          drop_last=True)

testset = torchvision.datasets.CIFAR10(root=DATA_ROOT,
                                       train=False,
                                       download=True,
                                       transform=transform)
# The test loader keeps its (smaller) last batch and is never shuffled.
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=BATCH_SIZE,
                                         num_workers=NUM_WORKERS,
                                         shuffle=False)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Helper used to preview image tensors inline
def imshow(img):
    """Plot a channel-first image tensor with matplotlib.

    The tensor is first rescaled with ``img * 0.5 + 0.5`` (the inverse of a
    mean-0.5 / std-0.5 normalization), then its axes are reordered from
    (C, H, W) to (H, W, C) before it is handed to ``plt.imshow``.

    Args:
        img: tensor with the channel axis first; it must support ``.numpy()``.
    """
    unnormalized = img * 0.5 + 0.5
    # matplotlib expects the channel axis last
    pixels = np.transpose(unnormalized.numpy(), (1, 2, 0))
    plt.imshow(pixels)
    plt.show()

# Preview: show the first 8 images of the first training batch, then stop.
# The loop form (rather than next(iter(...))) keeps an empty loader a no-op;
# the previously unused enumerate index is dropped so no stray `i` is left
# behind in the kernel namespace.
for images, labels in trainloader:
    preview_grid = torchvision.utils.make_grid(images[:8])  # tile 8 images into one
    imshow(preview_grid)
    break

# Model

In [None]:
# One-hidden-layer MLP: flatten each 3x32x32 image into a 3072-vector,
# project to 256 hidden units with ReLU, then to 10 output scores.
model = nn.Sequential(nn.Flatten(),
                      nn.Linear(3*32*32, 256),
                      nn.ReLU(),
                      nn.Linear(256, 10))

# The model is never moved off the CPU in this notebook. torchsummary defaults
# to device="cuda" and would feed a CUDA tensor into the CPU model whenever a
# GPU is visible, so the device is stated explicitly.
summary(model, (3, 32, 32), device="cpu")

# Loss, Optimizer and Evaluate Function

In [None]:
# Cross-entropy objective on the model's output scores, optimized with Adam
# over all model parameters at a learning rate of 0.01.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

In [None]:
def evalute(model, testloader, loss_fn):
    """Evaluate ``model`` on every batch of ``testloader``.

    The model is switched to evaluation mode (and left in it) and gradient
    tracking is disabled for the whole pass.

    Args:
        model: module mapping a batch of inputs to per-class scores, with the
            class axis at dim 1.
        testloader: iterable yielding ``(images, labels)`` batches.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor; it is assumed to be the *mean* loss over the batch (the
            default reduction of ``nn.CrossEntropyLoss``).

    Returns:
        tuple ``(loss, accuracy)``: the per-sample mean loss and the accuracy
        in percent (0-100).

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    running_correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = model(images)
            batch_size = labels.size(0)

            # Weight each batch-mean loss by its batch size: averaging the
            # batch means directly would over-weight a smaller final batch.
            loss_sum += loss_fn(outputs, labels).item() * batch_size

            predicted = outputs.argmax(dim=1)
            running_correct += (predicted == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("testloader yielded no samples; cannot compute loss/accuracy")

    accuracy = running_correct * 100 / total
    loss = loss_sum / total
    return loss, accuracy

# Train

In [None]:
# Per-epoch metric histories; each list receives one value per epoch.
train_losses, train_accuracies = [], []
test_losses, test_accuracies = [], []

# Number of full passes over the training data
max_epoch = 20

In [None]:
# Main training loop: one optimization pass over `trainloader` per epoch,
# followed by a full evaluation on `testloader`. Metrics are appended to the
# history lists, one entry per epoch.
for epoch in range(max_epoch):
    model.train()  # evalute() leaves the model in eval mode; switch back each epoch
    running_loss = 0.0
    running_correct = 0
    total = 0

    for images, labels in trainloader:
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        # Accumulate per-sample totals so the epoch loss does not depend on a
        # leaked loop index and stays correct even if batches differ in size.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size

        # argmax is non-differentiable, so no `.data` access is needed here.
        predicted = outputs.argmax(dim=1)
        running_correct += (predicted == labels).sum().item()
        total += batch_size

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Training metrics are measured on the fly, i.e. with the weights as they
    # were *before* each batch's update.
    epoch_accuracy = running_correct * 100 / total
    epoch_loss = running_loss / total
    train_accuracies.append(epoch_accuracy)
    train_losses.append(epoch_loss)

    test_loss, test_accuracy = evalute(model, testloader, loss_fn)
    test_accuracies.append(test_accuracy)
    test_losses.append(test_loss)
    print(f"Epoch [{epoch + 1}/{max_epoch}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

In [None]:
import matplotlib.pyplot as plt

# Loss curves: one point per epoch (x is the 0-based epoch index).
plt.plot(train_losses, label='train_losses')
plt.plot(test_losses, label='test_losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training vs. test loss')
plt.legend()
plt.show()  # render the figure without echoing the Legend repr

In [None]:
import matplotlib.pyplot as plt

# Accuracy curves: one point per epoch (x is the 0-based epoch index),
# values are percentages.
plt.plot(train_accuracies, label='train_accuracy')
plt.plot(test_accuracies, label='test_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training vs. test accuracy')
plt.legend()
plt.show()  # render the figure without echoing the Legend repr