In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# from itertools import starmap
from utils import save_model, save_plots
from CNN_execution import plot_roc_curve, ect_train_validate, report_trained_model, find_numpy_files

In [2]:
# Parameters required to define the model.
# They remain the same throughout the exercise.

NUM_EPOCHS = 50  # number of epochs to train the network for; type=int
LEARNING_RATE = 1e-3  # learning rate for training; type=float
# loss function
lossfcn = nn.CrossEntropyLoss()

# Number of DataLoader workers: on SLURM use the CPUs allocated to the task, else 0.
num_workers = int(os.environ.get('SLURM_CPUS_PER_TASK', 0))

# Device for training / evaluation (same expression the library functions document).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global RNGs so the per-class file sampling (np.random.choice) and
# torch-side randomness (e.g. weight initialisation) are reproducible across runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# ECT resolution for single-resolution experiments (the sweep cell trains its own range).
num_dirs = 4
num_thresh = 4

data_dir = '../../data'
# Class directories deliberately left out of this experiment.
EXCLUDED_CLASSES = {'Transect', 'Leafsnap'}

# One class per sub-directory of data_dir. Sorted so the class order, and hence the
# order in which the random draws below consume the RNG, does not depend on the
# file-system listing order.
classes = sorted(
    name
    for name in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, name))
)
# Skip excluded classes before scanning them, instead of scanning then popping
# (which also raised KeyError whenever one of them was missing).
class_items = {
    name: find_numpy_files(os.path.join(data_dir, name))
    for name in classes
    if name not in EXCLUDED_CLASSES
}
if not class_items:
    raise ValueError(f"No usable class directories found in {data_dir!r}")

# Balance the classes: draw as many files per class as the smallest class holds.
num_data_to_use_for_training = min(len(file_paths) for file_paths in class_items.values())
# For a quick run, cap it instead, e.g. num_data_to_use_for_training = 30
print(f"Using {num_data_to_use_for_training} data for training")

# Batch size scales with the per-class sample count; clamp so it can never be 0.
batch_size = max(1, num_data_to_use_for_training // 11)

class_items = {
    class_name: np.random.choice(file_paths, num_data_to_use_for_training, replace=False)
    for class_name, file_paths in class_items.items()
}

Using 865 data for training


In [4]:
# Signature and usage of the training entry point; the sweep cell further down
# calls it once per ECT resolution.
help(ect_train_validate)

Help on function ect_train_validate in module CNN_execution:

ect_train_validate(num_dirs, num_thresh, input_path=None, output_ect_path='example_data/ect_output', in_memory=False, output_model_path='outputs/best_model.pth', num_epochs=50, learning_rate=0.001, lossfcn=CrossEntropyLoss(), batch_size=4, valid_split=0.2, num_workers=0, device=device(type='cuda'), recompute_ect=True, log_level='INFO')
    Function to train and validate the CNN model using the ECT dataset.
    Usage:
        ect_train_validate(
            num_dirs, num_thresh, input_path=None,
            output_ect_path="example_data/ect_output", in_memory=False,
            output_model_path="outputs/best_model.pth",
            num_epochs=50, learning_rate=1e-3, lossfcn=nn.CrossEntropyLoss(),
            batch_size=4, valid_split=0.2, num_workers=0,
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            recompute_ect=True, log_level='INFO'
        )
    Parameters:
        num_dirs: 

In [5]:
# Checkpoint helper (epochs, model, optimizer, criterion -> output_model_path).
help(save_model)


Help on function save_model in module utils:

save_model(epochs, model, optimizer, criterion, output_model_path='outputs/best_model.pth')
    Function to save the trained model.
    Adapted from https://debuggercafe.com/saving-and-loading-the-best-model-in-pytorch/



In [6]:
# Helper that saves the training/validation accuracy and loss curves; pass
# matplotlib axes via `accuracy=` / `loss=` to draw them on an existing figure.
help(save_plots)


Help on function save_plots in module utils:

save_plots(train_acc, valid_acc, train_loss, valid_loss, accuracy=None, loss=None, fig_size=(10, 7), dpi=300, accuracy_path='outputs/accuracy.png', loss_path='outputs/loss.png')
    Function to save the loss and accuracy plots.
    Usage:
        save_plots(
            train_acc, valid_acc, train_loss,valid_loss,
            accuracy = None, loss = None,
            fig_size=(10, 7), dpi=300,
            accuracy_path = 'outputs/accuracy.png', loss_path = 'outputs/loss.png'
        )
    Parameters:
        train_acc: list of training accuracy values
        valid_acc: list of validation accuracy values
        train_loss: list of training loss values
        valid_loss: list of validation loss values
        accuracy: matplotlib axis to plot accuracy. If None, a new figure is created.
        loss: matplotlib axis to plot loss. If None, a new figure is created.
        fig_size: tuple, size of the figure. Default is (10, 7)
        dpi: i

In [7]:
# Evaluation helper: loads the checkpoint at `model_path` and writes a
# confusion-matrix figure (`output_cf`) plus a report file (`output_report`).
help(report_trained_model)

Help on function report_trained_model in module CNN_execution:

report_trained_model(num_dirs, num_thresh, train_dataset, train_loader, test_loader, test_dataset, device=device(type='cuda'), model_path='outputs/best_model.pth', ax=None, output_cf='outputs/confusion_matrix.png', output_report='outputs/outputCLFreport.csv', log_level='INFO')
    Function to report the trained model.
    Usage:
        report_trained_model(
            num_dirs, num_thresh,
            train_dataset, train_loader, test_loader, test_dataset,
            device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
            model_path= 'outputs/best_model.pth',
            output_cf='outputs/confusion_matrix.png',
            output_report='outputs/outputCLFreport.csv',
            log_level='INFO'
        )
    Parameters:
        num_dirs: int, number of directions for ECT calculation.
        num_thresh: int, number of thresholds for ECT calculation.
        train_dataset: torch.utils.data.Data

In [8]:
# ROC-curve helper for a trained model evaluated on the test loader.
help(plot_roc_curve)

Help on function plot_roc_curve in module CNN_execution:

plot_roc_curve(model, test_loader, test_dataset, device=device(type='cuda'), axis=None, output_path='outputs/roc_curve.png')
    Function to plot the ROC curve for the trained model.
    Usage:
        plot_roc_curve(model, test_loader, test_dataset, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    Parameters:
        model: torch.nn model, trained model.
        test_loader: torch.utils.data.DataLoader, test data loader.
        test_dataset: torch.utils.data.Dataset, test dataset.
        device: torch.device, device to run the model. Optional, default is 'cuda' if available else 'cpu'.



In [None]:
# ECT resolution sweep over square grids of 4, 8, ..., 2048 directions x thresholds
# (2**2 ... 2**11, the same values as int(2 ** np.linspace(2, 11, 10))).
# Each setting writes into its own directory: ECT cache, best and final
# checkpoints, training curves, confusion matrix + report, and ROC curve.
#
# NOTE(review): in the logged runs the losses sit at ~2.639 ≈ ln(14), i.e. chance
# level for the 14 classes reported by the training code, with accuracies near
# 1/14 — the network does not appear to learn yet; investigate before comparing
# resolutions.

SWEEP_EXPONENTS = range(2, 12)
SWEEP_NUM_EPOCHS = 10  # deliberately shorter than NUM_EPOCHS to keep the sweep affordable


def _run_resolution(directions, thresholds):
    """Train, checkpoint, plot and evaluate one ECT resolution.

    Parameters:
        directions: int, number of ECT directions.
        thresholds: int, number of ECT thresholds.
    Returns:
        The dict returned by ``ect_train_validate`` (model, optimizer, loaders,
        datasets, accuracy/loss histories, ...).
    """
    out_dir = f'outputs/output_{directions}_{thresholds}'
    os.makedirs(out_dir, exist_ok=True)
    # ect_train_validate writes its best-validation checkpoint here
    # ("Saving best model for epoch: N" in the training log).
    best_model_path = f'{out_dir}/best_model.pth'

    outputs = ect_train_validate(
        num_dirs=directions,
        num_thresh=thresholds,
        input_path=class_items,
        output_ect_path=f'{out_dir}/ect',
        output_model_path=best_model_path,
        num_epochs=SWEEP_NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        lossfcn=lossfcn,
        batch_size=batch_size,
        num_workers=num_workers,
        device=device,
        recompute_ect=False,
        log_level='INFO'
    )

    # Save the last-epoch weights to their own file. Writing them to
    # best_model_path overwrote the best-validation checkpoint, so the report
    # below silently evaluated the final model instead of the best one.
    save_model(
        epochs=outputs["num_epochs"],
        model=outputs["model"],
        optimizer=outputs["optimizer"],
        criterion=outputs["lossfcn"],
        output_model_path=f'{out_dir}/final_model.pth',
    )

    # Loss and accuracy share one figure, so both are saved to the same file.
    loss, acc = plt.figure(figsize=(9, 5)).subplots(1, 2)
    curves_path = f'{out_dir}/accuracy_loss.png'
    save_plots(
        train_acc=outputs["train_acc"],
        valid_acc=outputs["valid_acc"],
        train_loss=outputs["train_loss"],
        valid_loss=outputs["valid_loss"],
        loss=loss,
        accuracy=acc,
        accuracy_path=curves_path,
        loss_path=curves_path
    )

    cf_ax = plt.figure(figsize=(24, 24), dpi=300).add_subplot(111)
    report_trained_model(
        num_dirs=directions,
        num_thresh=thresholds,
        train_dataset=outputs["train_dataset"],
        train_loader=outputs["train_loader"],
        test_loader=outputs["test_loader"],
        test_dataset=outputs["test_dataset"],
        device=device,
        ax=cf_ax,
        model_path=best_model_path,
        output_cf=f'{out_dir}/confusion_matrix.png',
        output_report=f'{out_dir}/accuracy.txt',
        # NOTE(review): the string 'None', not None — confirm what report_trained_model expects.
        log_level='None'
    )

    # NOTE(review): the ROC curve uses the in-memory last-epoch model, whereas the
    # report above loads the best-validation checkpoint from disk.
    plot_roc_curve(
        model=outputs["model"],
        test_loader=outputs["test_loader"],
        test_dataset=outputs["test_dataset"],
        device=device,
        output_path=f'{out_dir}/roc_curve.png'
    )

    # Release every figure from this iteration, including any the helpers opened
    # themselves. Otherwise matplotlib warns about too many open figures and memory
    # grows with each resolution. (Everything has already been saved to disk.)
    plt.close('all')
    return outputs


for exponent in SWEEP_EXPONENTS:
    directions = thresholds = 2 ** exponent
    trained_model = _run_resolution(directions, thresholds)
    print(f"Completed training for {directions} directions and {thresholds} thresholds")

100%|██████████| 205/205 [00:03<00:00, 63.98it/s] 

Validation





Training loss: 2.623, training acc: 8.387
Validation loss: 2.620, validation acc: 8.549

Best validation loss: 2.6197767441089335

Saving best model for epoch: 2

--------------------------------------------------
[INFO]: Epoch 3 of 10
Training


100%|██████████| 205/205 [00:03<00:00, 56.36it/s] 

Validation





Training loss: 2.623, training acc: 8.469
Validation loss: 2.620, validation acc: 8.549
--------------------------------------------------
[INFO]: Epoch 4 of 10
Training


100%|██████████| 205/205 [00:03<00:00, 63.44it/s] 

Validation





Training loss: 2.622, training acc: 8.343
Validation loss: 2.620, validation acc: 8.549

Best validation loss: 2.6196577961628256

Saving best model for epoch: 4

--------------------------------------------------
[INFO]: Epoch 5 of 10
Training


100%|██████████| 205/205 [00:03<00:00, 63.83it/s] 

Validation





Training loss: 2.622, training acc: 8.262
Validation loss: 2.621, validation acc: 8.549
--------------------------------------------------
[INFO]: Epoch 6 of 10
Training


100%|██████████| 205/205 [00:03<00:00, 65.27it/s] 

Validation





Training loss: 2.623, training acc: 8.644
Validation loss: 2.622, validation acc: 8.925
--------------------------------------------------
[INFO]: Epoch 7 of 10
Training


100%|██████████| 205/205 [00:03<00:00, 63.96it/s] 

Validation





Training loss: 2.622, training acc: 8.531
Validation loss: 2.619, validation acc: 8.925

Best validation loss: 2.619488014624669

Saving best model for epoch: 7

--------------------------------------------------
[INFO]: Epoch 8 of 10
Training


100%|██████████| 205/205 [00:03<00:00, 61.14it/s] 

Validation





Training loss: 2.622, training acc: 8.770
Validation loss: 2.621, validation acc: 8.549
--------------------------------------------------
[INFO]: Epoch 9 of 10
Training


100%|██████████| 205/205 [00:03<00:00, 64.39it/s] 

Validation





Training loss: 2.622, training acc: 8.556
Validation loss: 2.620, validation acc: 8.047
--------------------------------------------------
[INFO]: Epoch 10 of 10
Training


100%|██████████| 205/205 [00:03<00:00, 64.25it/s] 

Validation





Training loss: 2.622, training acc: 8.538
Validation loss: 2.621, validation acc: 8.549
--------------------------------------------------
Saving final model...


  state_dict = torch.load(model_path)['model_state_dict']


Using validation to compute ROC curve


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
100%|██████████| 52/52 [00:00<00:00, 170.07it/s]


Completed training for 32 directions and 32 thresholds
num_classes= 14
ECT data; using only normalize, rotation transforms on training data
[INFO]: Epoch 1 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 44.31it/s] 

Validation





Training loss: 2.619, training acc: 11.357
Validation loss: 2.609, validation acc: 12.136

Best validation loss: 2.6089653453311406

Saving best model for epoch: 1

--------------------------------------------------
[INFO]: Epoch 2 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 47.05it/s] 

Validation





Training loss: 2.608, training acc: 12.155
Validation loss: 2.603, validation acc: 12.136

Best validation loss: 2.6027658758936703

Saving best model for epoch: 2

--------------------------------------------------
[INFO]: Epoch 3 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 46.09it/s] 

Validation





Training loss: 2.606, training acc: 12.155
Validation loss: 2.604, validation acc: 12.136
--------------------------------------------------
[INFO]: Epoch 4 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 42.92it/s] 

Validation





Training loss: 2.607, training acc: 12.155
Validation loss: 2.605, validation acc: 12.136
--------------------------------------------------
[INFO]: Epoch 5 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 48.21it/s] 

Validation





Training loss: 2.606, training acc: 12.155
Validation loss: 2.609, validation acc: 12.136
--------------------------------------------------
[INFO]: Epoch 6 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 46.98it/s]

Validation





Training loss: 2.606, training acc: 12.155
Validation loss: 2.605, validation acc: 12.136
--------------------------------------------------
[INFO]: Epoch 7 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 42.46it/s] 

Validation





Training loss: 2.606, training acc: 12.155
Validation loss: 2.606, validation acc: 12.136
--------------------------------------------------
[INFO]: Epoch 8 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 41.52it/s]

Validation





Training loss: 2.606, training acc: 12.155
Validation loss: 2.606, validation acc: 12.136
--------------------------------------------------
[INFO]: Epoch 9 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 44.46it/s] 

Validation





Training loss: 2.606, training acc: 12.155
Validation loss: 2.606, validation acc: 12.136
--------------------------------------------------
[INFO]: Epoch 10 of 10
Training


100%|██████████| 145/145 [00:03<00:00, 45.32it/s] 

Validation





Training loss: 2.606, training acc: 12.155
Validation loss: 2.605, validation acc: 12.136
--------------------------------------------------
Saving final model...


  state_dict = torch.load(model_path)['model_state_dict']


Using validation to compute ROC curve


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
100%|██████████| 37/37 [00:00<00:00, 127.78it/s]


Completed training for 64 directions and 64 thresholds
num_classes= 14
ECT data; using only normalize, rotation transforms on training data
[INFO]: Epoch 1 of 10
Training


100%|██████████| 125/125 [00:03<00:00, 33.04it/s]

Validation





Training loss: 2.755, training acc: 7.265
Validation loss: 2.640, validation acc: 6.074

Best validation loss: 2.6400466188788414

Saving best model for epoch: 1

--------------------------------------------------
[INFO]: Epoch 2 of 10
Training


100%|██████████| 125/125 [00:03<00:00, 32.25it/s]

Validation





Training loss: 2.639, training acc: 7.254
Validation loss: 2.640, validation acc: 6.074

Best validation loss: 2.639979511499405

Saving best model for epoch: 2

--------------------------------------------------
[INFO]: Epoch 3 of 10
Training


100%|██████████| 125/125 [00:04<00:00, 30.95it/s]

Validation





Training loss: 2.639, training acc: 7.420
Validation loss: 2.640, validation acc: 6.074

Best validation loss: 2.6397384107112885

Saving best model for epoch: 3

--------------------------------------------------
[INFO]: Epoch 4 of 10
Training


100%|██████████| 125/125 [00:04<00:00, 30.01it/s]

Validation





Training loss: 2.639, training acc: 7.389
Validation loss: 2.640, validation acc: 6.074
--------------------------------------------------
[INFO]: Epoch 5 of 10
Training


100%|██████████| 125/125 [00:03<00:00, 33.93it/s]

Validation





Training loss: 2.639, training acc: 7.420
Validation loss: 2.641, validation acc: 6.074
--------------------------------------------------
[INFO]: Epoch 6 of 10
Training


100%|██████████| 125/125 [00:03<00:00, 33.31it/s]

Validation





Training loss: 2.639, training acc: 7.182
Validation loss: 2.640, validation acc: 6.074
--------------------------------------------------
[INFO]: Epoch 7 of 10
Training


100%|██████████| 125/125 [00:04<00:00, 29.59it/s]

Validation





Training loss: 2.639, training acc: 7.420
Validation loss: 2.641, validation acc: 6.074
--------------------------------------------------
[INFO]: Epoch 8 of 10
Training


100%|██████████| 125/125 [00:03<00:00, 33.04it/s]

Validation





Training loss: 2.639, training acc: 7.420
Validation loss: 2.640, validation acc: 6.074
--------------------------------------------------
[INFO]: Epoch 9 of 10
Training


100%|██████████| 125/125 [00:03<00:00, 33.86it/s]

Validation





Training loss: 2.639, training acc: 7.079
Validation loss: 2.641, validation acc: 6.074
--------------------------------------------------
[INFO]: Epoch 10 of 10
Training


100%|██████████| 125/125 [00:03<00:00, 33.55it/s]

Validation





Training loss: 2.639, training acc: 7.420
Validation loss: 2.641, validation acc: 6.074
--------------------------------------------------
Saving final model...


  state_dict = torch.load(model_path)['model_state_dict']


Using validation to compute ROC curve


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
100%|██████████| 32/32 [00:00<00:00, 85.07it/s]


Completed training for 128 directions and 128 thresholds
num_classes= 14
ECT data; using only normalize, rotation transforms on training data
[INFO]: Epoch 1 of 10
Training


100%|██████████| 125/125 [00:05<00:00, 21.37it/s]

Validation





Training loss: 3.095, training acc: 6.996
Validation loss: 2.639, validation acc: 6.736

Best validation loss: 2.638704888522625

Saving best model for epoch: 1

--------------------------------------------------
[INFO]: Epoch 2 of 10
Training


100%|██████████| 125/125 [00:06<00:00, 19.67it/s]

Validation





Training loss: 2.640, training acc: 6.727
Validation loss: 2.639, validation acc: 7.438
--------------------------------------------------
[INFO]: Epoch 3 of 10
Training


100%|██████████| 125/125 [00:05<00:00, 22.11it/s]

Validation





Training loss: 2.640, training acc: 6.779
Validation loss: 2.639, validation acc: 6.736
--------------------------------------------------
[INFO]: Epoch 4 of 10
Training


100%|██████████| 125/125 [00:05<00:00, 22.31it/s]

Validation





Training loss: 2.639, training acc: 6.407
Validation loss: 2.639, validation acc: 7.438
--------------------------------------------------
[INFO]: Epoch 5 of 10
Training


100%|██████████| 125/125 [00:05<00:00, 21.72it/s]

Validation





Training loss: 2.639, training acc: 6.934
Validation loss: 2.639, validation acc: 6.736
--------------------------------------------------
[INFO]: Epoch 6 of 10
Training


100%|██████████| 125/125 [00:05<00:00, 21.65it/s]

Validation





Training loss: 2.639, training acc: 7.058
Validation loss: 2.640, validation acc: 6.281
--------------------------------------------------
[INFO]: Epoch 7 of 10
Training


100%|██████████| 125/125 [00:05<00:00, 21.09it/s]

Validation





Training loss: 2.639, training acc: 7.099
Validation loss: 2.640, validation acc: 6.281
--------------------------------------------------
[INFO]: Epoch 8 of 10
Training


100%|██████████| 125/125 [00:05<00:00, 22.32it/s]

Validation





Training loss: 2.639, training acc: 7.213
Validation loss: 2.639, validation acc: 6.281
--------------------------------------------------
[INFO]: Epoch 9 of 10
Training


100%|██████████| 125/125 [00:05<00:00, 22.32it/s]

Validation





Training loss: 2.639, training acc: 7.368
Validation loss: 2.640, validation acc: 6.281
--------------------------------------------------
[INFO]: Epoch 10 of 10
Training


100%|██████████| 125/125 [00:06<00:00, 20.11it/s]

Validation





Training loss: 2.639, training acc: 7.368
Validation loss: 2.640, validation acc: 6.281
--------------------------------------------------
Saving final model...


  state_dict = torch.load(model_path)['model_state_dict']


Using validation to compute ROC curve


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  ax = plt.figure().add_subplot(111)
100%|██████████| 32/32 [00:00<00:00, 57.99it/s] 


Completed training for 256 directions and 256 thresholds
num_classes= 14
