Setup: clone the repository

In [1]:
!git clone https://github.com/AndreasHammerKU/CardiacCTAnalysis.git
%cd CardiacCTAnalysis

Cloning into 'CardiacCTAnalysis'...
remote: Enumerating objects: 263, done.[K
remote: Counting objects: 100% (40/40), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 263 (delta 9), reused 15 (delta 6), pack-reused 223 (from 1)[K
Receiving objects: 100% (263/263), 1.66 MiB | 4.45 MiB/s, done.
Resolving deltas: 100% (135/135), done.
/content/CardiacCTAnalysis


Imports from the GitHub repository

In [2]:
!pip install dash
import numpy as np

# Custom Imports
import utils.io_utils as io
import utils.logger as logs
from baseline.BaseEnvironment import MedicalImageEnvironment
from utils.io_utils import DataLoader
from baseline.BaseAgent import DQNAgent

# Mount Drive
from google.colab import drive
drive.mount('/content/drive')

dataset_folder = '/content/drive/MyDrive/Data'

Collecting dash
  Downloading dash-2.18.2-py3-none-any.whl.metadata (10 kB)
Collecting Flask<3.1,>=1.0.4 (from dash)
  Downloading flask-3.0.3-py3-none-any.whl.metadata (3.2 kB)
Collecting Werkzeug<3.1 (from dash)
  Downloading werkzeug-3.0.6-py3-none-any.whl.metadata (3.7 kB)
Collecting dash-html-components==2.0.0 (from dash)
  Downloading dash_html_components-2.0.0-py3-none-any.whl.metadata (3.8 kB)
Collecting dash-core-components==2.0.0 (from dash)
  Downloading dash_core_components-2.0.0-py3-none-any.whl.metadata (2.9 kB)
Collecting dash-table==5.0.0 (from dash)
  Downloading dash_table-5.0.0-py3-none-any.whl.metadata (2.4 kB)
Collecting retrying (from dash)
  Downloading retrying-1.3.4-py3-none-any.whl.metadata (6.9 kB)
Downloading dash-2.18.2-py3-none-any.whl (7.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m69.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dash_core_components-2.0.0-py3-none-any.whl (3.8 kB)
Downloading dash_html_compo

Hyperparameters

In [6]:
# Tunable settings for the run. Each constant is forwarded by keyword to the
# MedicalImageEnvironment / DQNAgent constructors in the cells below.
MAX_STEPS = 500  # passed to DQNAgent as max_steps
EPISODES = 40  # passed to DQNAgent as episodes
IMAGE_INTERVAL = 2  # passed to the training DQNAgent as image_interval
AGENTS = 6  # passed as agents to every environment and to DQNAgent
N_SAMPLE_POINTS = 5  # passed to every environment as n_sample_points
EVALUATION_STEPS = 30  # passed to the training DQNAgent as evaluation_steps
# Passed to the training DQNAgent as decay. Presumably the time constant of the
# exploration schedule:
#   EPSILON = MIN_EPS + (MAX_EPS - MIN_EPS) * e^(-1 * current_step / decay)
# NOTE(review): written with '+' after MIN_EPS; a '-' there would start epsilon
# below MIN_EPS. Confirm against the DQNAgent implementation.
DECAY = 200

In [7]:
# Logger shared by the environments and the agent; `debug` is the flag handed
# to the project's logger factory.
debug = False
logger = logs.setup_logger(debug)

# Do not preload images: Colab does not have enough RAM.
preload_images = False

# Loader pointed at the dataset folder on the mounted Drive.
dataLoader = DataLoader(dataset_folder)

Training

In [8]:
# Training environment: images n1 .. n30.
train_env = MedicalImageEnvironment(
    logger=logger,
    dataLoader=dataLoader,
    image_list=[f'n{i}' for i in range(1, 31)],
    agents=AGENTS,
    n_sample_points=N_SAMPLE_POINTS,
    preload_images=preload_images,
)

# Evaluation environment used during training: images n31 .. n40.
eval_env = MedicalImageEnvironment(
    logger=logger,
    task="eval",
    dataLoader=dataLoader,
    image_list=[f'n{i}' for i in range(31, 41)],
    agents=AGENTS,
    n_sample_points=N_SAMPLE_POINTS,
)

# Agent in "train" mode; state/action sizes are read from the training
# environment, everything else comes from the hyperparameter cell.
agent = DQNAgent(
    train_environment=train_env,
    eval_environment=eval_env,
    task="train",
    logger=logger,
    state_dim=train_env.state_size,
    action_dim=train_env.n_actions,
    agents=AGENTS,
    max_steps=MAX_STEPS,
    episodes=EPISODES,
    decay=DECAY,
    image_interval=IMAGE_INTERVAL,
    evaluation_steps=EVALUATION_STEPS,
)

agent.train_dqn()

INFO:Logger:Episode 1: Total Reward = -37.18 | Final Avg Distance 67.05 | All Reached Goal False | Avg Closest Point = 15.57 | Avg Furthest Point = 68.24
INFO:Logger:Episode 2: Total Reward = -13.37 | Final Avg Distance 43.25 | All Reached Goal False | Avg Closest Point = 7.41 | Avg Furthest Point = 56.27
INFO:Logger:Episode 3: Total Reward = -48.58 | Final Avg Distance 75.42 | All Reached Goal False | Avg Closest Point = 7.20 | Avg Furthest Point = 82.29
INFO:Logger:Episode 4: Total Reward = 11.27 | Final Avg Distance 15.57 | All Reached Goal False | Avg Closest Point = 4.70 | Avg Furthest Point = 36.54
INFO:Logger:Episode 5: Total Reward = 9.36 | Final Avg Distance 20.52 | All Reached Goal False | Avg Closest Point = 5.43 | Avg Furthest Point = 34.11
INFO:Logger:Episode 6: Total Reward = 25.52 | Final Avg Distance 4.36 | All Reached Goal False | Avg Closest Point = 0.50 | Avg Furthest Point = 32.13
INFO:Logger:Episode 7: Total Reward = 1.26 | Final Avg Distance 26.76 | All Reached Go

In [None]:
# Show the training environment's current state using its own visualisation helper.
train_env.visualize_current_state()

In [None]:
# Show the evaluation environment's current state using its own visualisation helper.
eval_env.visualize_current_state()

Evaluate

In [None]:
# Held-out test environment: images n41 .. n50.
test_env = MedicalImageEnvironment(
    logger=logger,
    task="test",
    dataLoader=dataLoader,
    image_list=[f'n{i}' for i in range(41, 51)],
    agents=AGENTS,
    n_sample_points=N_SAMPLE_POINTS,
)

# Agent in "test" mode, given the model file "latest-model.pt". It replaces the
# training agent bound to `agent` above; the train/eval environments from the
# training cell are passed along unchanged.
# NOTE(review): evaluation_steps is 25 here while EVALUATION_STEPS is 30 in the
# hyperparameter cell — confirm the difference is intentional.
agent = DQNAgent(
    train_environment=train_env,
    eval_environment=eval_env,
    test_environment=test_env,
    task="test",
    logger=logger,
    state_dim=test_env.state_size,
    action_dim=test_env.n_actions,
    agents=AGENTS,
    model_path="latest-model.pt",
    max_steps=MAX_STEPS,
    episodes=EPISODES,
    evaluation_steps=25,
)

agent.test_dqn()

In [None]:
# Show the test environment's current state using its own visualisation helper.
test_env.visualize_current_state()