<a href="https://colab.research.google.com/github/JaeHyeok98/-/blob/main/iap_2024_Lab5_AlexNet_CIFAR10_with_TensorBoard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import numpy as np
import time

# Load Model

In [None]:
# Build an AlexNet; no pretrained-weights argument is passed to the constructor.
# torchvision.models.resnet18() or torchvision.models.mobilenet_v2() can be tried instead,
# but their output layer lives under a different attribute (e.g. `fc` for ResNet),
# which the head-replacement step below has to account for.
MyModel = torchvision.models.alexnet()

# Print the layer-by-layer structure of the network
print(MyModel)

In [None]:
# Replace the final fully connected layer so the network outputs one score per
# CIFAR-10 class. This also works for the alternative models suggested above.
NUM_CLASSES = 10

if hasattr(MyModel, 'fc'):
    # ResNet-style models keep their output layer in `.fc`
    num_ftrs = MyModel.fc.in_features
    MyModel.fc = torch.nn.Linear(num_ftrs, NUM_CLASSES)
else:
    # AlexNet (classifier[6]) and MobileNetV2 (classifier[1]) both end their
    # `classifier` Sequential with the output Linear layer, so index it from the end
    num_ftrs = MyModel.classifier[-1].in_features
    MyModel.classifier[-1] = torch.nn.Linear(num_ftrs, NUM_CLASSES)

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU
MyDevice = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Move the model's parameters and buffers to that device (the returned module is displayed)
MyModel.to(MyDevice)

# Load Dataset

In [None]:
# Preprocessing applied, in order, to every image loaded with MyTransform:
#   1. resize so the shorter side is 224 px
#   2. random left-right flip (data augmentation)
#   3. convert the PIL image to a tensor
#   4. normalise each channel with mean 0.5 and std 0.5
_transform_steps = [
    torchvision.transforms.Resize(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
MyTransform = torchvision.transforms.Compose(_transform_steps)

In [None]:
# Load the CIFAR-10 dataset.
# The training split uses MyTransform, which includes random horizontal flips.
CIFAR10_TrainSet = torchvision.datasets.CIFAR10(root='./data',
                                                train=True,
                                                download=True,
                                                transform=MyTransform)

# Evaluation must be deterministic, so the test split gets the same pipeline with
# the random augmentation step removed.
MyEvalTransform = torchvision.transforms.Compose(
    [step for step in MyTransform.transforms
     if not isinstance(step, torchvision.transforms.RandomHorizontalFlip)])
CIFAR10_TestSet = torchvision.datasets.CIFAR10(root='./data',
                                               train=False,
                                               download=True,
                                               transform=MyEvalTransform)

In [None]:
# Seed for the train/validation split, so the same split is produced after every kernel restart
SPLIT_SEED = 0

# Number of images in the original CIFAR-10 training split
num_train_samples = 50000

# Generate a list of indices ranging from 0 to num_train_samples - 1
indices = list(range(num_train_samples))

# Shuffle the indices with a seeded generator (the global np.random state was unseeded)
shuffled_indices = np.random.default_rng(SPLIT_SEED).permutation(indices)

In [None]:
# Always split the original full training set. Without this guard, re-running the cell
# would wrap the 45k-sample Subset in another Subset and fail with IndexError.
if isinstance(CIFAR10_TrainSet, torch.utils.data.Subset):
    CIFAR10_TrainSet = CIFAR10_TrainSet.dataset

num_val_samples = 5000
val_indices = shuffled_indices[:num_val_samples].tolist()
train_indices = shuffled_indices[num_val_samples:].tolist()

# Validation images must not be randomly flipped. Read them through a second view of the
# same training split, whose pipeline is MyTransform minus RandomHorizontalFlip.
# NOTE(review): assumes CIFAR10_TrainSet.transform is a Compose (it is MyTransform here).
CIFAR10_TrainSet_NoAug = torchvision.datasets.CIFAR10(
    root=CIFAR10_TrainSet.root,
    train=True,
    download=False,
    transform=torchvision.transforms.Compose(
        [step for step in CIFAR10_TrainSet.transform.transforms
         if not isinstance(step, torchvision.transforms.RandomHorizontalFlip)]))

# Create a validation set from the first num_val_samples shuffled indices
CIFAR10_ValSet = torch.utils.data.Subset(CIFAR10_TrainSet_NoAug, val_indices)

# Update the CIFAR10_TrainSet to exclude the samples used for validation
CIFAR10_TrainSet = torch.utils.data.Subset(CIFAR10_TrainSet, train_indices)

# Prepare for training

In [None]:
# Hyperparameters: Adam learning rate, mini-batch size, number of passes over the training split
lr, bs, num_epochs = 0.001, 128, 5

In [None]:
# Training and validation loaders share batch size and worker count.
# Only the training loader reshuffles every epoch; validation order stays fixed.
_loader_kwargs = dict(batch_size=bs, num_workers=2)

CIFAR10_TrainLoader = torch.utils.data.DataLoader(dataset=CIFAR10_TrainSet, shuffle=True, **_loader_kwargs)
CIFAR10_ValLoader = torch.utils.data.DataLoader(dataset=CIFAR10_ValSet, shuffle=False, **_loader_kwargs)

In [None]:
# Adam over all of the model's parameters, using the learning rate defined above
optimizer = torch.optim.Adam(params=MyModel.parameters(), lr=lr)

In [None]:
# Cross-entropy over raw class scores, averaged over the batch (PyTorch's default reduction)
criterion = torch.nn.CrossEntropyLoss(reduction='mean')

In [None]:
import torch.utils.tensorboard  # NOTE(review): consider moving this into the top import cell

# Create an instance of the SummaryWriter class, which writes event files under
# 'runs/AlexNet_CIFAR10' for visualization in TensorBoard
writer = torch.utils.tensorboard.SummaryWriter('runs/AlexNet_CIFAR10')

# Load the TensorBoard notebook extension, which provides the %tensorboard magic command
%load_ext tensorboard

# Launch TensorBoard within the Jupyter Notebook interface, reading the same directory the writer uses
%tensorboard --logdir runs/AlexNet_CIFAR10 --port=6006

# Train

In [None]:
# Define the 'train' function
def train(model, device, train_set, train_loader, optimizer, criterion, writer, epoch):
    """Run one training epoch over ``train_loader`` and log per-batch metrics.

    Args:
        model: network to optimise; switched to training mode here.
        device: device each batch is moved to before the forward pass.
        train_set: the training dataset. Kept for backward compatibility only. The
            iterations-per-epoch count now comes from ``len(train_loader)``, so it
            always matches the loader's real batch size rather than the global ``bs``.
        train_loader: iterable of ``(data, target_label)`` batches.
        optimizer: optimiser whose ``step()`` is called once per batch.
        criterion: loss called as ``criterion(predicted_scores, target_label)``.
        writer: object with ``add_scalar(tag, value, step)`` (e.g. a TensorBoard SummaryWriter).
        epoch: zero-based epoch index, used to compute the global logging step.

    Returns:
        ``(training_loss, training_accuracy)``: the sample-weighted mean loss and the
        accuracy in percent over the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()

    total_loss = 0.0
    num_total_preds = 0
    num_total_correct_preds = 0

    # Steps per epoch taken from the loader itself, so logging steps never overlap between epochs
    num_iterations_per_epoch = len(train_loader)

    for batch_idx, (data, target_label) in enumerate(train_loader):

        data, target_label = data.to(device), target_label.to(device)

        # Clear gradients before backprop so any gradients accumulated outside this loop are discarded
        optimizer.zero_grad()

        predicted_scores = model(data)

        loss = criterion(predicted_scores, target_label)
        loss.backward()
        optimizer.step()

        num_preds = target_label.size(0)

        _, predicted_label = predicted_scores.max(1)

        num_correct_preds = predicted_label.eq(target_label).sum().item()

        print(f"\ttrain  |  batch: {batch_idx + 1:>3}  |  num_preds: {num_preds:>3}  |  num_correct_preds: {num_correct_preds:>3}  |  batch_accuracy: {(num_correct_preds / num_preds) * 100:>4.1f}")

        # loss is a batch mean (assuming reduction='mean'), so weight it by batch size
        total_loss += loss.item() * num_preds
        num_total_preds += num_preds
        num_total_correct_preds += num_correct_preds

        # Log the per-batch training results on tensorboard at a global, 1-based step
        iteration = epoch * num_iterations_per_epoch + (batch_idx + 1)
        writer.add_scalar('Training Loss', loss.item(), iteration)
        writer.add_scalar('Training Acc', (num_correct_preds / num_preds) * 100, iteration)

    if num_total_preds == 0:
        raise ValueError("train_loader yielded no samples")

    training_loss = total_loss / num_total_preds
    training_accuracy = (num_total_correct_preds / num_total_preds) * 100

    print(f"\tTraining Loss: {training_loss:.3f}")
    print(f"\tTraining Accuracy: {training_accuracy:.1f}")

    return training_loss, training_accuracy

In [None]:
# Define the 'val' function
def val(model, device, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: network to evaluate; switched to eval mode here.
        device: device each batch is moved to before the forward pass.
        val_loader: iterable of ``(data, target_label)`` batches.
        criterion: loss called as ``criterion(predicted_scores, target_label)``.

    Returns:
        ``(val_loss, val_accuracy)``: the sample-weighted mean loss and the accuracy
        in percent.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()

    total_loss = 0.0
    num_total_preds = 0
    num_total_correct_preds = 0

    with torch.no_grad():

        for data, target_label in val_loader:

            data, target_label = data.to(device), target_label.to(device)

            predicted_scores = model(data)

            loss = criterion(predicted_scores, target_label)

            num_preds = target_label.size(0)

            _, predicted_label = predicted_scores.max(1)
            num_correct_preds = predicted_label.eq(target_label).sum().item()

            # loss is a batch mean (assuming reduction='mean'), so weight it by batch size
            total_loss += loss.item() * num_preds
            num_total_preds += num_preds
            num_total_correct_preds += num_correct_preds

    if num_total_preds == 0:
        raise ValueError("val_loader yielded no samples")

    val_loss = total_loss / num_total_preds
    val_accuracy = (num_total_correct_preds / num_total_preds) * 100

    print(f"\tVal Loss: {val_loss:.3f}")
    print(f"\tVal Accuracy: {val_accuracy:.1f}")

    return val_loss, val_accuracy

In [None]:
# Train the model, validating after every epoch
for epoch in range(num_epochs):

    print('Epoch:', epoch + 1)

    # perf_counter is monotonic, so wall-clock adjustments cannot skew the measured duration
    start = time.perf_counter()

    train(model=MyModel,
          device=MyDevice,
          train_set=CIFAR10_TrainSet,
          train_loader=CIFAR10_TrainLoader,
          optimizer=optimizer,
          criterion=criterion,
          writer=writer,
          epoch=epoch)

    val(model=MyModel,
        device=MyDevice,
        val_loader=CIFAR10_ValLoader,
        criterion=criterion)

    end = time.perf_counter()

    # SummaryWriter buffers events; flush so the embedded TensorBoard shows this epoch right away
    writer.flush()

    print(f"\tTime spent training this epoch: {int(end - start)} sec\n")

print('Training End!')

# Evaluate

In [None]:
# DataLoader over the held-out CIFAR-10 test split; order is kept fixed (no shuffling)
CIFAR10_TestLoader = torch.utils.data.DataLoader(
    CIFAR10_TestSet, batch_size=bs, shuffle=False, num_workers=2)

In [None]:
# Define the 'test' function
def test(model, device, test_loader):

    model.eval()

    num_total_preds = 0
    num_total_correct_preds = 0

    with torch.no_grad():

        for batch_idx, (data, target_label) in enumerate(test_loader):

            data, target_label = data.to(device), target_label.to(device)

            predicted_scores = model(data)

            num_preds = target_label.size(0)

            _, predicted_label= predicted_scores.max(1)
            num_correct_preds = predicted_label.eq(target_label).sum().item()

            num_total_preds += num_preds
            num_total_correct_preds += num_correct_preds

        test_accuracy = (num_total_correct_preds / num_total_preds) * 100

        print(f"Model Accuracy: {test_accuracy:.1f}")

In [None]:
# Report accuracy on the test split (the trailing ';' keeps the cell from echoing a return value)
test(MyModel, MyDevice, CIFAR10_TestLoader);

In [None]:
'''Measure FLOPs (floating-point operations) and the number of parameters
!pip install thop
import thop

random_input = torch.randn(1, 3, 224, 224).to(MyDevice)
FLOPS, params = thop.profile(MyModel, inputs=(random_input, ))

print(f"FLOPS: {int(FLOPS)} | params: {int(params)}")