In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset
import time
import gc
import matplotlib.pyplot as plt
import matplotlib.gridspec as grid

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters read by the data loaders, the training loop and the optimizer.
BATCH_SIZE = 32
EPOCHS = 60
LEARNING_RATE = 5e-5

# Tail shared by both pipelines: convert to a tensor, then normalise each
# channel with the fixed mean/std below.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: resize with a single int size (256), then take a random
# 224x224 crop, so every pass over the data sees a slightly different window.
train_transform = transforms.Compose([
    transforms.Resize(256, antialias=True),
    transforms.RandomCrop(224),
    _to_tensor,
    _normalize,
])

# Validation/test pipeline: deterministic. Resize to exactly 256x256 and keep
# the central 224x224 crop.
val_test_transform = transforms.Compose([
    transforms.Resize((256, 256), antialias=True),
    transforms.CenterCrop(224),
    _to_tensor,
    _normalize,
])

# Flowers102 splits, stored under ./data (downloaded on demand).
train_dataset = datasets.Flowers102(
    root="./data", split="train", transform=train_transform, download=True
)
val_dataset = datasets.Flowers102(
    root="./data", split="val", transform=val_test_transform, download=True
)
test_dataset = datasets.Flowers102(
    root="./data", split="test", transform=val_test_transform, download=True
)

# Loaders share batch size and worker count; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

# Per-epoch accuracy history (percent), appended to by the training loop and
# read by the plotting code.
train_acc = []
val_acc = []

class CNN(nn.Module):
    """Seven-stage convolutional classifier that emits 102 logits per image.

    Every stage is ``Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, 2)``: the convolution keeps the spatial size and the pooling
    halves it. Stage widths are 16, 32, 64, 128, 256, 512 and 1024 channels.
    A 224x224 input is therefore reduced to 1x1 after the seventh stage, which
    is the size the 1024-wide fully connected head expects.

    Attribute names (``conv1`` ... ``conv7``, ``fc1`` ... ``fc3``) and the
    layer order inside each ``nn.Sequential`` determine the ``state_dict``
    keys, so they must not be changed if saved weights are to stay loadable.
    """

    def __init__(self):
        super().__init__()
        # Spatial size before -> after each stage for a 224x224 input.
        self.conv1 = self._conv_stage(3, 16)       # 224 -> 112
        self.conv2 = self._conv_stage(16, 32)      # 112 -> 56
        self.conv3 = self._conv_stage(32, 64)      # 56 -> 28
        self.conv4 = self._conv_stage(64, 128)     # 28 -> 14
        self.conv5 = self._conv_stage(128, 256)    # 14 -> 7
        self.conv6 = self._conv_stage(256, 512)    # 7 -> 3
        self.conv7 = self._conv_stage(512, 1024)   # 3 -> 1

        # Classifier head: 1024 -> 512 -> 256 -> 102 logits (no final activation).
        self.fc1 = nn.Sequential(nn.Linear(1024, 512),
                                 nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(512, 256),
                                 nn.ReLU())
        self.fc3 = nn.Sequential(nn.Linear(256, 102))

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one conv -> batch-norm -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, 102)`` for a batch of images.

        Args:
            x: tensor of shape ``(batch, 3, H, W)``. The head requires the
                feature map to be 1x1 after the seventh stage, which holds for
                224x224 inputs.

        Raises:
            RuntimeError: from ``fc1`` when the flattened feature size is not
                1024, i.e. when the input resolution does not reduce to 1x1.
        """
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5, self.conv6, self.conv7):
            x = stage(x)
        # Flatten everything except the batch dimension. Unlike
        # ``x.view(-1, 1024)``, this never folds leftover spatial positions
        # into the batch dimension: a wrong input size fails loudly at fc1
        # instead of silently returning more rows than there are images.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


# Release unused memory held by the CUDA caching allocator before the model is built.
torch.cuda.empty_cache()

# Instantiate the network, then place it on the selected device.
cnn = CNN()
cnn = cnn.to(device)

# Cross-entropy over the network's logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=LEARNING_RATE)

# Scale the learning rate by 0.1 when the value passed to scheduler.step()
# stops improving, with a patience of 5 epochs.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
# the installed version still accepts it.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,
    patience=5,
    verbose=True,
)

import os
def save_network(network, epoch_label):
    """Save ``network``'s weights to ``./savedModels/net_<epoch_label>.pth``.

    The file holds CPU copies of the tensors in ``network.state_dict()``, so
    it does not depend on the device the network lives on. The network itself
    is not moved: it stays on whatever device it was on before the call.

    Args:
        network: module whose ``state_dict()`` is written.
        epoch_label: value formatted into the file name with ``%s``.
    """
    save_dir = './savedModels'
    # torch.save does not create missing parent directories.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'net_%s.pth' % epoch_label)

    # state_dict() returns a new mapping on each call, so replacing its values
    # with CPU copies leaves the live parameters (and any optimizer holding
    # them) untouched, and avoids a CPU/GPU round trip of the whole model.
    state = network.state_dict()
    for key in list(state):
        value = state[key]
        if torch.is_tensor(value):
            state[key] = value.cpu()
    torch.save(state, save_path)

start_time = time.time()
best_epoch = 0
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. With 0 here, a run whose validation accuracy never exceeded 0%
# would leave 'best-model-parameters.pt' missing (or left over from an earlier
# run) for the code that loads it after training.
best_acc = -1.0

# Train the CNN
for epoch in range(EPOCHS):
    print(f'--------- Epoch {epoch + 1} ----------')

    # ---- Training pass: one optimizer step per batch ----
    cnn.train()
    train_loss = 0
    train_correct = 0
    train_total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = cnn(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    # Mean of the per-batch losses for this epoch.
    train_loss /= len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    print(f'  Training Accuracy: {train_accuracy:.2f}%')
    train_acc.append(train_accuracy)

    # ---- Validation pass: eval mode, no gradient tracking ----
    cnn.eval()
    val_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = cnn(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    accuracy = 100 * correct / total
    print(f'  Validation Accuracy: {accuracy:.2f}%')
    val_acc.append(accuracy)

    # The plateau scheduler monitors the mean validation loss.
    scheduler.step(val_loss)

    # Checkpoint whenever validation accuracy strictly improves.
    if accuracy > best_acc:
        best_acc = accuracy
        torch.save(cnn.state_dict(), 'best-model-parameters.pt')
        best_epoch = epoch + 1
        print(" **Better model found. Updated best model.")

end_time = time.time()
print("Training complete in: " + time.strftime("%Hh %Mm %Ss", time.gmtime(end_time - start_time)))

# Test the CNN on the test set, using the checkpoint with the best validation accuracy.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written from a CUDA model can still be read when only the CPU is
# available.
best_model = torch.load("best-model-parameters.pt", map_location=device)
cnn.load_state_dict(best_model)
print("---------- Best Model ----------")
print(f" Epoch: {best_epoch}")
print(f" Validation accuracy: {best_acc:.3f}%")

# Eval mode and no gradient tracking for inference.
cnn.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = cnn(images)
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_accuracy = 100 * correct / total
print('----------- Testing ------------')
print(f'  Test Accuracy: {test_accuracy:.2f}%')

# Plot the per-epoch training and validation accuracy on one set of axes.
fig = plt.figure(tight_layout=True)
# NOTE(review): the grid has two rows but only the top cell is populated, so
# the lower half of the figure stays empty — confirm whether that is intended.
gs = grid.GridSpec(nrows=2, ncols=1)
ax = fig.add_subplot(gs[0, 0])

for curve, curve_label, curve_color in (
    (train_acc, "Training", "green"),
    (val_acc, "Validation", "red"),
):
    ax.plot(curve, label=curve_label, color=curve_color)

ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")

fig.align_labels()
plt.legend()
plt.show()