In [1]:
import sys

# Make the parent directory importable so the project-local `models`, `datasets`
# and `results` imports in the next cell can resolve (assumes they live one level up).
# Guarded so that re-running this cell does not keep appending duplicate entries.
if "../" not in sys.path:
    sys.path.append("../")

**Import Modules**

In [2]:
import torch
import torch.nn as nn
from torchvision import models
import timm

from models.trainer import Trainer
from datasets.data_manager import DataManager
from models.transform_manager import TransformManager
from models.model_manager import ModelManager
from models.training_manager import TrainingManager
from models.test_manager import TestManager

from results.metrics_visualizer import MetricsVisualizer
from results.metrics_calculator import MetricsCalculator
from results.plot_visualizer import PlotVisualizer
from results.heatmap_generator import HeatmapGenerator

**Define Model Class**

In [3]:
class EfficientNetB0MultiLabel(nn.Module):
    """torchvision EfficientNet-B0 with its classifier head replaced for multi-label output.

    The backbone is loaded with ImageNet weights and its ``classifier`` is swapped
    for ``Dropout(dropout) -> Linear(num_ftrs, num_classes)``. ``forward`` returns
    the raw outputs of that final linear layer; no sigmoid is applied here
    (presumably the loss / metric code applies it — TODO confirm).

    Args:
        num_classes: Number of output units (one per label column).
        dropout: Dropout probability before the final linear layer. Defaults to
            0.2, the value previously hard-coded, so existing callers and saved
            checkpoints (whose state-dict keys are unchanged) keep working.
    """

    def __init__(self, num_classes, dropout=0.2):
        super().__init__()
        self.model = models.efficientnet_b0(**self._imagenet_weight_kwargs())

        # Reuse the input width of the backbone's existing final layer for the new head.
        num_ftrs = self.model.classifier[1].in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_ftrs, num_classes),
        )

    @staticmethod
    def _imagenet_weight_kwargs():
        """Return builder kwargs that load ImageNet weights on this torchvision version."""
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return {"pretrained": True}
        # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is the
        # checkpoint that the legacy flag resolves to, so the loaded weights are unchanged.
        return {"weights": weights_enum.IMAGENET1K_V1}

    def forward(self, x):
        """Run the wrapped EfficientNet on ``x`` and return its raw (pre-activation) outputs."""
        return self.model(x)

**Define Model Name & Select Model Class** (the class itself is passed to the managers below; it is not instantiated here)

In [4]:
# The model *class* (not an instance) is handed to the managers below,
# which presumably instantiate it themselves.
MODEL = EfficientNetB0MultiLabel
# Used to name the checkpoint and results directories in the constants cell.
modelName = "EfficientNetB0"

**Constants**

In [5]:
# --- Training hyperparameters ---
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
SIZE = 224  # image size passed to the transform / test / heatmap managers
FOLD_SPLITS = 10  # presumably the number of cross-validation folds — TODO confirm
EPOCHS = 30
WEIGHT_DECAY = 1e-5
OPTIMIZER_TYPE = "RAdam"

# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's RNGs (CPU and CUDA) so head initialisation and any torch-based
# shuffling are reproducible under Restart & Run All.
# NOTE(review): numpy / `random` / any sklearn splitting inside the managers is not seeded here.
SEED = 42
torch.manual_seed(SEED)

# --- Data locations ---
MAIN_CSV_FILE = "../final_label.csv"
LABEL_COLS = ["Fracture", "CalvarialFracture", "OtherFracture"]
DATASET_PATH = "../Dataset_PNG"
TEST_CSV_FILE = "../test_fold.csv"

# --- Output locations (all keyed by modelName) ---
MODEL_SAVE_PATH = f"./models/{modelName}"
RESULTS_DIR = f"./results/{modelName}"
TRAINING_VALIDATION_PLOT_SAVE_PATH = RESULTS_DIR
PLOT_SAVE_PATH = RESULTS_DIR
GRADCAM_HEATMAP_SAVE_PATH = RESULTS_DIR
CAM_OUTPUT_SIZE = 512  # passed to HeatmapGenerator; presumably the rendered CAM resolution

***Initialize the Data, Transform, Model, and Training Managers***

In [6]:
# Presumably assigns the rows of MAIN_CSV_FILE to FOLD_SPLITS folds over LABEL_COLS — TODO confirm.
data_manager = DataManager(MAIN_CSV_FILE, LABEL_COLS, FOLD_SPLITS)

# Receives the image root, target image size and batch size.
transform_manager = TransformManager(DATASET_PATH, SIZE, BATCH_SIZE)

# Bundles the model class with its device, checkpoint directory and optimizer settings.
model_manager = ModelManager(
    LABEL_COLS,
    DEVICE,
    MODEL_SAVE_PATH,
    MODEL,
    LEARNING_RATE,
    WEIGHT_DECAY,
    OPTIMIZER_TYPE,
)

# Ties the three managers above together for EPOCHS of training.
training_manager = TrainingManager(
    data_manager,
    transform_manager,
    model_manager,
    EPOCHS,
)

***Begin training***

In [None]:
# Long-running: trains for EPOCHS epochs (presumably once per fold — TODO confirm).
# The loss/accuracy histories plotted in the next cell are read from training_manager.
training_manager.run_training()

***Show Training vs Validation Plot***

In [None]:
# Plot the loss/accuracy histories recorded on training_manager by run_training().
visualizer = MetricsVisualizer(
    training_manager.train_losses,
    training_manager.val_losses,
    training_manager.train_accuracies,
    training_manager.val_accuracies,
    TRAINING_VALIDATION_PLOT_SAVE_PATH,
)
visualizer.plot_metrics()


***Begin Inference***

In [None]:
# Presumably loads the checkpoints saved under MODEL_SAVE_PATH and scores the
# images listed in TEST_CSV_FILE — TODO confirm.
model_predictor = TestManager(
    DATASET_PATH,
    TEST_CSV_FILE,
    LABEL_COLS,
    SIZE,
    MODEL,
    DEVICE,
    MODEL_SAVE_PATH,
    FOLD_SPLITS,
)

# NOTE(review): these results are indexed with [0] in the metrics cell, so they
# appear to hold one entry per fold model — confirm.
predictions, true_labels, probabilities = model_predictor.make_predictions()

***Compute and Display Metrics***

In [None]:
metrics_calculator = MetricsCalculator(LABEL_COLS)

# NOTE(review): only element 0 of predictions/true_labels is evaluated here,
# whereas the plotting cell below receives the full lists — confirm this is intended.
(
    aggregated_metrics,
    metrics_per_class,
    classification_reports,
    multi_label_report,
) = metrics_calculator.compute_metrics(predictions[0], true_labels[0])

metrics_calculator.display_metrics(
    aggregated_metrics,
    metrics_per_class,
    classification_reports,
    multi_label_report,
)


***Display Plots***

In [None]:
plot_visualizer = PlotVisualizer(LABEL_COLS, PLOT_SAVE_PATH)

# The confusion matrix is built from `predictions` ...
plot_visualizer.plot_multilabel_confusion_matrix(true_labels, predictions)

# ... while the ROC and precision-recall curves use the scores in `probabilities`.
plot_visualizer.plot_roc_curve(true_labels, probabilities)
plot_visualizer.plot_precision_recall_curve(true_labels, probabilities)

***Grad-CAM Heatmap Generation***

In [44]:
# # Sample reload example (development only; kept commented out so Run All skips it)
# After editing results/heatmap_generator.py, uncomment and run this cell to reload
# that module and rebuild the generator without restarting the kernel.
# `%load_ext autoreload` + `%autoreload 2` in the setup cells is an alternative.
# The next cell performs the same construction using the originally imported class.


# import results.heatmap_generator
# import importlib
# importlib.reload(results.heatmap_generator)
# from results.heatmap_generator import HeatmapGenerator
# generator = HeatmapGenerator(model_predictor, TEST_CSV_FILE, MODEL, SIZE, DATASET_PATH, CAM_OUTPUT_SIZE, GRADCAM_HEATMAP_SAVE_PATH, modelName)

# generator.plot_heatmaps()


In [None]:
# Build Grad-CAM heatmaps for the test images from the predictor created in the
# inference cell; output presumably goes under GRADCAM_HEATMAP_SAVE_PATH.
generator = HeatmapGenerator(
    model_predictor,
    TEST_CSV_FILE,
    MODEL,
    SIZE,
    DATASET_PATH,
    CAM_OUTPUT_SIZE,
    GRADCAM_HEATMAP_SAVE_PATH,
    modelName,
)

generator.plot_heatmaps()