In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# from itertools import starmap
from utils import save_model, save_plots
from CNN_execution import plot_roc_curve, ect_train_validate, report_trained_model, find_numpy_files

In [2]:
# Parameters required to define the model.
# These remain the same throughout the exercise.

NUM_EPOCHS = 10  # number of epochs to train the network for; type=int
LEARNING_RATE = 1e-3  # learning rate for training; type=float
# loss function
lossfcn = nn.CrossEntropyLoss()

# Number of workers for the dataloader.
# Inside a SLURM job, use the CPUs allocated to this task minus two
# (NOTE(review): the "- 2" is presumably headroom for the main process --
# confirm). Outside SLURM, fall back to a fixed default. The value is clamped
# to at least 1 so a small allocation can never yield a zero or negative count.
DEFAULT_NUM_WORKERS = 16
slurm_cpus_per_task = os.environ.get('SLURM_CPUS_PER_TASK')
if slurm_cpus_per_task:
    num_workers = max(1, int(slurm_cpus_per_task) - 2)
else:
    num_workers = DEFAULT_NUM_WORKERS

# device: prefer the GPU when one is visible to torch, otherwise the CPU
device = ('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
data_dir = '../../data'
# Dataset folders that may exist under data_dir but are left out of training.
excluded_classes = {'Transect', 'Leafsnap'}
# Upper bound on how many files are drawn from each class.
max_files_per_class = 30
# Fixed seed so the same subset is drawn on every run, given the same file
# lists (NOTE(review): assumes find_numpy_files returns files in a stable
# order -- confirm).
sampling_seed = 42

# Every sub-directory of data_dir is treated as one class. Sorted because
# os.listdir returns entries in arbitrary order.
classes = sorted(
    entry
    for entry in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, entry))
)

# Collect the files of each class that is kept. Excluded classes are filtered
# out here, so they are never scanned and a missing folder is not an error.
all_class_items = {
    class_name: find_numpy_files(os.path.join(data_dir, class_name))
    for class_name in classes
    if class_name not in excluded_classes
}
if not all_class_items:
    raise ValueError(f"No usable class directories found in {data_dir!r}")

# Draw the same number of files from every class. The count is capped by the
# smallest class because sampling without replacement cannot return more
# items than are available.
smallest_class_size = min(len(file_paths) for file_paths in all_class_items.values())
num_data_to_use_for_training = min(max_files_per_class, smallest_class_size)
print(f"Using {num_data_to_use_for_training} data for training")

sampling_rng = np.random.default_rng(sampling_seed)
class_items = {
    class_name: sampling_rng.choice(
        file_paths, num_data_to_use_for_training, replace=False
    )
    for class_name, file_paths in all_class_items.items()
}

In [None]:
# Sweep over square resolutions 4, 8, 16, 32, 64 and 128: each run uses the
# same number of directions and thresholds.
for resolution in (2**exponent for exponent in range(2, 8)):
    directions = thresholds = resolution

    # Everything produced by this run lives under one per-resolution folder.
    run_dir = f'outputs/output_{directions}_{thresholds}'
    model_path = f'{run_dir}/best_model.pth'
    curves_path = f'{run_dir}/accuracy_loss.png'

    # Train and validate; the returned mapping is indexed by key below.
    trained_model = ect_train_validate(
        num_dirs=directions,
        num_thresh=thresholds,
        input_path=class_items,
        output_ect_path=f'{run_dir}/ect',
        output_model_path=model_path,
        num_workers=num_workers,
        parallel=True,
        batch_size=4,
        num_epochs=NUM_EPOCHS,
        log_level='None',
    )

    # Save the model together with its optimizer and loss function.
    # This writes to the same path that was handed to ect_train_validate.
    save_model(
        epochs=trained_model["num_epochs"],
        model=trained_model["model"],
        optimizer=trained_model["optimizer"],
        criterion=trained_model["lossfcn"],
        output_model_path=model_path,
    )

    # One figure with two side-by-side axes, passed as the loss and accuracy axes.
    # NOTE(review): accuracy_path and loss_path are the same file, presumably
    # because both axes belong to this single figure -- confirm.
    curves_fig = plt.figure(figsize=(9, 5))
    loss, acc = curves_fig.subplots(1, 2)
    save_plots(
        train_acc=trained_model["train_acc"],
        valid_acc=trained_model["valid_acc"],
        train_loss=trained_model["train_loss"],
        valid_loss=trained_model["valid_loss"],
        loss=loss,
        accuracy=acc,
        accuracy_path=curves_path,
        loss_path=curves_path,
    )

    # Large, high-resolution single-axes figure handed to the report helper.
    report_fig = plt.figure(figsize=(24, 24), dpi=300)
    ax = report_fig.add_subplot(111)
    report_trained_model(
        num_dirs=directions,
        num_thresh=thresholds,
        train_dataset=trained_model["train_dataset"],
        train_loader=trained_model["train_loader"],
        test_loader=trained_model["test_loader"],
        test_dataset=trained_model["test_dataset"],
        ax=ax,
        model_path=model_path,
        output_cf=f'{run_dir}/confusion_matrix.png',
        output_report=f'{run_dir}/accuracy.txt',
        log_level='None',
    )

    plot_roc_curve(
        model=trained_model["model"],
        test_loader=trained_model["test_loader"],
        test_dataset=trained_model["test_dataset"],
        output_path=f'{run_dir}/roc_curve.png',
    )

    # Close every open figure so they do not accumulate across iterations.
    plt.close('all')
    print(f"Completed training for {directions} directions and {thresholds} thresholds")