<a href="https://colab.research.google.com/github/OmarElbanna/Power-Window-Control-System/blob/main/Learning_Rate_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from Dataset_Loader import load_leafs_dataset
from Training import train_cnn
from Model_Evaluation import evaluate_model_with_outputs

from torch import nn, optim, manual_seed, flatten
from Util import apply_conv, apply_pool
import matplotlib.pyplot as plt
from time import time

In [None]:
# Data-loading settings reused by every experiment cell below.
test_split_size = 0.2
batch_size = 20

# load_leafs_dataset is a project helper; it is called with the test split
# size first and the batch size second and yields four values, unpacked here
# as train/test images followed by train/test labels.
# NOTE(review): the training loop indexes images_train[b] / labels_train[b],
# so these are presumably sequences of pre-built batches - confirm in
# Dataset_Loader.
(
    images_train,
    images_test,
    labels_train,
    labels_test,
) = load_leafs_dataset(test_split_size, batch_size)

In [None]:
class CNN(nn.Module):
    """Convolutional classifier that emits 99 log-probabilities per sample.

    Layer stack, in the order applied by ``forward``: three pooling stages,
    then two (convolution -> ReLU -> pooling) stages, then two fully
    connected layers. The final activation is ``log_softmax`` over dim 1.

    NOTE(review): the experiment cells train this model with
    ``nn.CrossEntropyLoss``, which applies log-softmax internally; the
    explicit ``log_softmax`` here therefore looks redundant - confirm before
    changing, since evaluation code may rely on the current outputs.
    """

    def __init__(self):
        super().__init__()
        # Starting spatial size fed to the size-tracking helpers.
        # NOTE(review): presumably the pixel dimensions of the input images.
        w, h = 1633, 1089

        # apply_pool / apply_conv (project helpers from Util) each hand back
        # a layer plus two values that are threaded into the next call as the
        # new width/height.
        self.pool1, w, h = apply_pool(w, h, kernel_size=2, stride=2)
        self.pool2, w, h = apply_pool(w, h, kernel_size=2, stride=2)
        self.pool3, w, h = apply_pool(w, h, kernel_size=2, stride=2)

        # NOTE(review): the positional 1/6 and 6/16 arguments are presumably
        # input/output channel counts - the 16 matches the fc1 input size.
        self.conv1, w, h = apply_conv(w, h, 1, 6, kernel_size=3, stride=1, padding=1)
        self.pool4, w, h = apply_pool(w, h, kernel_size=2, stride=2)

        self.conv2, w, h = apply_conv(w, h, 6, 16, kernel_size=3, stride=1, padding=1)
        self.pool5, w, h = apply_pool(w, h, kernel_size=2, stride=2)

        # Classifier head: flattened 16-channel feature map -> 1000 -> 99.
        flat_features = 16 * w * h
        self.fc1 = nn.Linear(flat_features, 1000)
        self.fc2 = nn.Linear(1000, 99)

    def forward(self, x):
        """Run the layer stack on ``x`` and return log-softmax scores (dim 1)."""
        relu = nn.functional.relu

        # Pooling-only front end.
        for downsample in (self.pool1, self.pool2, self.pool3):
            x = downsample(x)

        # Two convolution -> ReLU -> pooling stages.
        x = self.pool4(relu(self.conv1(x)))
        x = self.pool5(relu(self.conv2(x)))

        # Flatten everything except dim 0 before the fully connected layers.
        hidden = relu(self.fc1(flatten(x, 1)))
        return nn.functional.log_softmax(self.fc2(hidden), dim=1)

# Static Learning Rate Analysis

In [None]:
# Static learning-rate sweep: train a freshly seeded CNN once per learning
# rate and draw its train/test accuracy curves in one subplot each.
fig = plt.figure(figsize=(19, 15))
lr_list = [0.0001, 0.001, 0.01, 0.1]

# Identical for every run, so computed once up front.
epochs = 10
rng_x = range(epochs + 1)   # epoch 0 is the untrained baseline
rng_y = range(0, 100, 5)

for i, lr_value in enumerate(lr_list):
    print(f'Analysing Learning Rate {lr_value}...')

    # Re-seed before building the model so each run starts from the same
    # random state.
    manual_seed(1)
    model = CNN()

    # Accuracy of the untrained model, used as the epoch-0 point.
    # NOTE(review): if train_cnn reports accuracies as percentages, these
    # baseline values are on a different scale - confirm in Training.
    initial_train_accuracy = evaluate_model_with_outputs(
        model, images_train, labels_train, batch_size, 0
    )
    initial_test_accuracy = evaluate_model_with_outputs(
        model, images_test, labels_test, batch_size, 1
    )

    trained_model, losses_train, accuracies_train, accuracies_test = train_cnn(
        model=model,
        images_train=images_train,
        labels_train=labels_train,
        images_test=images_test,
        labels_test=labels_test,
        epochs=epochs,
        batch_size=batch_size,
        lossFunction=nn.CrossEntropyLoss(),
        optimizer=optim.Adam(model.parameters(), lr=lr_value, weight_decay=0.1),
        print_loss=True,
        calc_accuracy=True,
    )

    # Prepend the baseline so each curve has epochs + 1 points.
    accuracies_train.insert(0, initial_train_accuracy)
    accuracies_test.insert(0, initial_test_accuracy)

    # 2 x 3 grid; panel index is 1-based.
    ax = fig.add_subplot(2, 3, i + 1)
    ax.set_title(f'Learning: {lr_value}')
    plt.tight_layout()
    plt.plot(rng_x, accuracies_train, label='Training')
    plt.plot(rng_x, accuracies_test, label='Testing')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.xticks(rng_x, rng_x)
    plt.yticks(rng_y, rng_y)
    plt.grid()
    plt.legend()

plt.show()

# Learning Rate Scheduler Analysis

In [None]:
def train_cnn(
        model,
        images_train,
        labels_train,
        images_test,
        labels_test,
        epochs,
        batch_size,
        lossFunction,
        optimizer,
        print_loss=False,
        calc_accuracy=True,
        scheduler_step_size=2,
        scheduler_gamma=0.1,
):
    """Train ``model`` batch by batch while decaying the learning rate with StepLR.

    Args:
        model: Callable network whose parameters ``optimizer`` updates.
        images_train: Sequence of training batches; iterated in lockstep
            with ``labels_train``.
        labels_train: Sequence of target batches matching ``images_train``.
        images_test: Test batches, forwarded to the accuracy helper.
        labels_test: Test targets, forwarded to the accuracy helper.
        epochs: Number of passes over the training batches.
        batch_size: Forwarded unchanged to ``evaluate_model_with_outputs``.
        lossFunction: Callable ``(prediction, target) -> loss tensor``.
        optimizer: Optimizer stepped once per batch; the scheduler wraps it.
        print_loss: When true, print the last batch loss of every epoch.
        calc_accuracy: When true, record train/test accuracy after every
            epoch via ``evaluate_model_with_outputs`` (result multiplied
            by 100).
        scheduler_step_size: Epochs between learning-rate decays.
        scheduler_gamma: Multiplicative decay factor applied at each decay.

    Returns:
        Tuple ``(model, losses_train, accuracies_train, accuracies_test)``.
        ``losses_train`` holds one list per epoch with one numpy loss value
        per batch; the accuracy lists hold one entry per epoch and stay
        empty when ``calc_accuracy`` is false.
    """
    # Imported here because the notebook's import cell never brings StepLR
    # into scope; without this the function fails with NameError.
    from torch.optim.lr_scheduler import StepLR

    losses_train = []
    accuracies_train = []
    accuracies_test = []
    scheduler = StepLR(optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma)

    for i in range(epochs):
        losses_batch = []
        # Stays None when there are no batches, so the print below cannot
        # hit an unbound name.
        loss_train = None
        for batch_images, batch_labels in zip(images_train, labels_train):
            # Call the module rather than .forward() so registered hooks run.
            y_pred = model(batch_images)

            loss_train = lossFunction(y_pred, batch_labels)
            # .cpu() keeps the conversion valid for tensors on other devices.
            losses_batch.append(loss_train.detach().cpu().numpy())

            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()
        losses_train.append(losses_batch)

        if print_loss:
            print(f'Epoch: {i} / loss: {loss_train}')

        if calc_accuracy:
            # NOTE(review): the trailing 0 / 1 argument is opaque from here -
            # confirm its meaning in Model_Evaluation.
            accuracies_train.append(
                evaluate_model_with_outputs(model, images_train, labels_train, batch_size, 0) * 100
            )
            accuracies_test.append(
                evaluate_model_with_outputs(model, images_test, labels_test, batch_size, 1) * 100
            )

        # One scheduler step per epoch, after that epoch's optimizer steps.
        scheduler.step()

    return model, losses_train, accuracies_train, accuracies_test

In [None]:
# Scheduler sweep: same experiment as the static sweep, but using the
# train_cnn defined in this notebook, which decays the learning rate with
# StepLR. One subplot per initial learning rate.
lr_list = [0.0001, 0.001, 0.01, 0.1]
epochs = 10
fig = plt.figure(figsize=(19, 15))

# Shared axis ticks: epoch 0 is the untrained baseline; accuracy is shown in
# percent, and the tick range now includes 100.
rng_x = range(epochs + 1)
rng_y = range(0, 101, 5)

for i, learning_rate in enumerate(lr_list):
    print(f'Analysing Learning Rate {learning_rate}...')

    # Re-seed before building the model so each run starts from the same
    # random state.
    manual_seed(1)
    model = CNN()

    # train_cnn records accuracies multiplied by 100, so the baseline must
    # be scaled the same way; otherwise the epoch-0 point sits on a 0-1
    # scale next to 0-100 values.
    initial_train_accuracy = evaluate_model_with_outputs(
        model, images_train, labels_train, batch_size, 0
    ) * 100
    initial_test_accuracy = evaluate_model_with_outputs(
        model, images_test, labels_test, batch_size, 1
    ) * 100

    trained_model, losses_train, accuracies_train, accuracies_test = train_cnn(
        model=model,
        images_train=images_train,
        labels_train=labels_train,
        images_test=images_test,
        labels_test=labels_test,
        epochs=epochs,
        batch_size=batch_size,
        lossFunction=nn.CrossEntropyLoss(),
        optimizer=optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.1),
        print_loss=True,
        calc_accuracy=True,
    )

    # Prepend the baseline so each curve has epochs + 1 points.
    accuracies_train.insert(0, initial_train_accuracy)
    accuracies_test.insert(0, initial_test_accuracy)

    # 2 x 3 grid; panel index is 1-based. Draw through the Axes object so
    # the target subplot is explicit rather than "whatever is current".
    ax = fig.add_subplot(2, 3, i + 1)
    ax.set_title(f'Learning: {learning_rate}')
    ax.plot(rng_x, accuracies_train, label='Training')
    ax.plot(rng_x, accuracies_test, label='Testing')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_xticks(rng_x)
    ax.set_xticklabels(rng_x)
    ax.set_yticks(rng_y)
    ax.set_yticklabels(rng_y)
    ax.grid()
    ax.legend()

# Lay out once, after every subplot has its labels and legend.
fig.tight_layout()
plt.show()