# Training a CNN model to classify tree species from bark images
This notebook develops a CNN that classifies bark images by tree species. The CNN is trained on the original images. After training, it is analyzed with post-hoc model analysis methods such as LIME and SHAP.

## Organizing the data structure (only done once, after downloading the dataset)

First, download the dataset (https://www.kaggle.com/datasets/saurabhshahane/barkvn50) to the directory "./data/BarkVN-50/" and unzip it. You should then have the following structure:
- data
    - BarkVN-50
        - BarkVN-50_mendeley
            - Acacia
            - Adenanthera microsperma
            - Adenieum species
            - Anacardium occidentale
            - ...

Since this layout is not ideal for training the CNN, a subset of the classes is selected and split into training and test data using the code in the next cell:

In [None]:
# import helpers.split
# helpers.split.train_test_split()

Note: this cell only needs to be executed once (this is why it is commented out by default).

After execution the new data structure looks like this:
- data
    - BarkVN-50
        - BarkVN-50_mendeley
            - Acacia
            - Adenanthera microsperma
            - Adenieum species
            - Anacardium occidentale
            - ...
        - Test
            - Adenanthera microsperma
            - Cananga odorata
            - Cedrus
            - Cocos nucifera
            - Dalvergia oliveri
        - Train
            - Adenanthera microsperma
            - Cananga odorata
            - Cedrus
            - Cocos nucifera
            - Dalvergia oliveri

Note: the directory "./data/BarkVN-50/BarkVN-50_mendeley" may be deleted after this step.
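
The internals of `helpers.split.train_test_split()` are not shown in this notebook. As a rough illustration of what such a split could look like, here is a minimal sketch using only the standard library. The function name `split_classes`, the `test_fraction` of 0.2 and the fixed `seed` are assumptions made for this example, not the helper's actual behaviour.

```python
import random
import shutil
from pathlib import Path


def split_classes(source_dir, target_dir, classes, test_fraction=0.2, seed=0):
    """Copy the images of the selected classes into Train/ and Test/ folders.

    Hypothetical stand-in for helpers.split.train_test_split; the real
    helper may choose classes and split ratios differently.
    """
    source_dir, target_dir = Path(source_dir), Path(target_dir)
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    counts = {}
    for class_name in classes:
        files = sorted(p for p in (source_dir / class_name).iterdir() if p.is_file())
        rng.shuffle(files)
        n_test = max(1, int(len(files) * test_fraction))
        parts = {"Test": files[:n_test], "Train": files[n_test:]}
        for part, members in parts.items():
            out_dir = target_dir / part / class_name
            out_dir.mkdir(parents=True, exist_ok=True)
            for path in members:
                # copy instead of move, so the original folder stays intact
                shutil.copy2(path, out_dir / path.name)
        counts[class_name] = (len(parts["Train"]), len(parts["Test"]))
    return counts
```

With the layout above, it could be called as `split_classes("./data/BarkVN-50/BarkVN-50_mendeley", "./data/BarkVN-50", ["Cedrus", "Cocos nucifera"])`. It returns the number of training and test images per class.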

## Loading the Dataset and creating DataLoaders
Because the dataset is a custom one, we first need a custom Dataset class that loads, transforms and delivers the data points.

In [None]:
from helpers.dataset import BarkVN50Dataset
from torch.utils.data import DataLoader
from torch import device
from torch.cuda import is_available

# use the GPU if one is available, otherwise fall back to the CPU
DEVICE = device("cuda" if is_available() else "cpu")

# load the train and test datasets
train_dataset = BarkVN50Dataset(train=True, device=DEVICE)
test_dataset = BarkVN50Dataset(train=False, device=DEVICE)

# DataLoaders create the minibatches; only the training data is shuffled
train_dataloader = DataLoader(train_dataset, batch_size=39, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
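
To illustrate what a custom Dataset has to provide and how the DataLoader turns it into minibatches, here is a small sketch. `ToyBarkDataset` is a hypothetical stand-in backed by random tensors, not the real `BarkVN50Dataset`, whose implementation in `helpers.dataset` is not shown here. The sample count, image size and number of classes are assumptions.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyBarkDataset(Dataset):
    """Hypothetical stand-in for BarkVN50Dataset, filled with random tensors."""

    def __init__(self, num_samples=100, num_classes=5, image_size=32, device="cpu"):
        generator = torch.Generator().manual_seed(0)
        # Fake RGB "images" and integer class labels, moved to the device once.
        self.images = torch.rand(
            num_samples, 3, image_size, image_size, generator=generator
        ).to(device)
        self.labels = torch.randint(
            0, num_classes, (num_samples,), generator=generator
        ).to(device)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # A real dataset would read and transform an image file here.
        return self.images[index], self.labels[index]


toy_dataset = ToyBarkDataset()
toy_loader = DataLoader(toy_dataset, batch_size=39, shuffle=True)

# Collect the size of every minibatch; the last one holds the leftover samples.
batch_sizes = [images.shape[0] for images, _ in toy_loader]
print(batch_sizes)
```

A map-style Dataset only needs `__len__` and `__getitem__`. Batching and shuffling are handled by the DataLoader. With 100 samples and `batch_size=39`, the final batch is smaller than the others, because `drop_last` defaults to `False`.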

## Training the CNN model

### Initialization
Now that the data is ready, we can create the CNN model and move it to the selected device (to resume an earlier training run instead, see the optional checkpoint cell further down):

In [None]:
from helpers.cnn import ConvolutionalNeuralNetwork

model = ConvolutionalNeuralNetwork()
model.to(device=DEVICE)

Initialize the hyperparameters, the optimizer and the criterion (loss function):

In [None]:
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

# hyperparameters
num_epochs = 50
learning_rate = 3e-4

# optimizer and loss function
optimizer = Adam(model.parameters(), lr=learning_rate)
criterion = CrossEntropyLoss()

# put the model into training mode
model.train()

To continue training an existing model, load its checkpoint from disk (optional, which is why the cell is commented out by default):

In [None]:
# from torch import load

# checkpoint = load("models/checkpoint-2024-11-04-18-14-59.tar", map_location=DEVICE, weights_only=True)
# model.load_state_dict(checkpoint["model_state_dict"])
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# epoch = checkpoint["epoch"]
# loss = checkpoint["loss"]

### Training
Finally, train the model:

In [None]:
from helpers.train import train_cnn

train_cnn(
    num_epochs=num_epochs,
    model=model,
    criterion=criterion,
    dataloader=train_dataloader,
    optimizer=optimizer,
)
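
The body of `train_cnn` lives in `helpers.train` and is not reproduced in this notebook. The following is a minimal sketch of what a loop with the same arguments typically does. The name `train_sketch`, the returned per-epoch losses and the tiny synthetic model and data are assumptions for illustration only.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(num_epochs, model, criterion, dataloader, optimizer):
    """Hypothetical minimal training loop; helpers.train.train_cnn may differ."""
    model.train()  # enable training behaviour (dropout, batch-norm updates)
    epoch_losses = []
    for _ in range(num_epochs):
        running_loss = 0.0
        for images, labels in dataloader:
            optimizer.zero_grad()  # clear gradients from the previous step
            loss = criterion(model(images), labels)
            loss.backward()  # backpropagate
            optimizer.step()  # update the weights
            running_loss += loss.item() * images.shape[0]
        # mean loss per sample over the whole epoch
        epoch_losses.append(running_loss / len(dataloader.dataset))
    return epoch_losses


# Tiny synthetic problem so the sketch runs without the bark images.
torch.manual_seed(0)
toy_data = TensorDataset(torch.rand(64, 3, 8, 8), torch.randint(0, 5, (64,)))
toy_loader = DataLoader(toy_data, batch_size=16, shuffle=True)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=3e-4)
losses = train_sketch(3, toy_model, nn.CrossEntropyLoss(), toy_loader, toy_optimizer)
print(losses)
```

Returning the per-epoch loss is one possible design. It would also provide a real value to store in a checkpoint.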

The trained model can be evaluated on the training and test data before it is saved:

In [None]:
from helpers.evaluate import evaluate_cnn

evaluate_cnn(
    criterion=criterion,
    test_dataloader=test_dataloader,
    train_dataloader=train_dataloader,
    model=model,
)
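
What `evaluate_cnn` reports is defined in `helpers.evaluate` and is not shown here. As an illustration, the sketch below computes a mean loss and an accuracy over one dataloader. The function name `evaluate_sketch`, its return values and the toy model with hand-set weights are assumptions.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_sketch(model, dataloader, criterion):
    """Hypothetical evaluation: return (mean loss, accuracy) over a dataloader."""
    model.eval()  # disable dropout and freeze batch-norm statistics
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in dataloader:
            logits = model(images)
            total_loss += criterion(logits, labels).item() * labels.shape[0]
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.shape[0]
    return total_loss / seen, correct / seen


# Toy model whose weights are set by hand so that it always predicts class 0.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(4, 2))
with torch.no_grad():
    toy_model[1].weight.zero_()
    toy_model[1].bias.copy_(torch.tensor([1.0, 0.0]))

toy_data = TensorDataset(torch.rand(4, 1, 2, 2), torch.tensor([0, 0, 1, 1]))
toy_loader = DataLoader(toy_data, batch_size=1, shuffle=False)

mean_loss, accuracy = evaluate_sketch(toy_model, toy_loader, nn.CrossEntropyLoss())
print(accuracy)  # 0.5, because half of the labels are class 0
```

Running the same function on both the train and the test dataloader makes it possible to compare the two accuracies. A large gap between them hints at overfitting.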

If the model should be trained further later on, save a full checkpoint (model and optimizer state) with the .tar extension, the PyTorch convention for general checkpoints:

In [None]:
from torch import save
from datetime import datetime

time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
save(
    {
        "epoch": num_epochs,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": 123,  # placeholder value; replace with the final training loss
    },
    f"models/checkpoint-{num_epochs}ep-{time}.tar",
)

If the model only needs to be evaluated and not trained further, it is enough to save the model's state_dict with the .pt extension (a common PyTorch convention for saved models):

In [None]:
from torch import save
from datetime import datetime

time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
save(model.state_dict(), f"models/eval-model-{num_epochs}ep-{time}.pt")
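
For completeness, here is a sketch of how such a .pt file could be loaded again for evaluation. `build_model` is a hypothetical stand-in for `ConvolutionalNeuralNetwork()`, and the temporary file path is only used to keep the example self-contained.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def build_model():
    # Hypothetical stand-in for ConvolutionalNeuralNetwork()
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))


torch.manual_seed(0)
trained = build_model()

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "eval-model.pt"
    torch.save(trained.state_dict(), path)

    # Recreate the same architecture, then load the saved weights into it.
    restored = build_model()
    restored.load_state_dict(torch.load(path, weights_only=True))
    restored.eval()  # switch to inference mode before evaluating

sample = torch.rand(1, 3, 8, 8)
with torch.no_grad():
    same_output = torch.equal(trained.eval()(sample), restored(sample))
print(same_output)
```

A state_dict contains only the weights, so the model class has to be instantiated first. The saved file does not store the architecture itself.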