Skip to content

Commit

Permalink
added resnet model
Browse files Browse the repository at this point in the history
  • Loading branch information
adnaniazi committed Jul 10, 2024
1 parent fc35d4c commit 6780429
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 26 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ repos:
language: system
types: [python]
pass_filenames: false
- id: ruff
name: ruff
entry: poetry run ruff check
language: system
types: [python]
- id: flake8
name: flake8
entry: poetry run flake8
Expand Down
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Added support for a ResNet model

## [0.1.7] - 2024-07-08
### Fixed
- Fixed encoder model hogging all available GPU memory and crashing

Expand Down Expand Up @@ -38,7 +42,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added
- Basic skeleton of the package and tested it

[Unreleased]: https://github.com/adnaniazi/capfinder/compare/0.1.6...master
[Unreleased]: https://github.com/adnaniazi/capfinder/compare/0.1.7...master
[0.1.7]: https://github.com/adnaniazi/capfinder/compare/0.1.6...0.1.7
[0.1.6]: https://github.com/adnaniazi/capfinder/compare/0.1.5...0.1.6
[0.1.5]: https://github.com/adnaniazi/capfinder/compare/0.1.4...0.1.5
[0.1.4]: https://github.com/adnaniazi/capfinder/compare/0.1.3...0.1.4
Expand Down
29 changes: 28 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ python-kacl = "*"
pyupgrade = "*"
tryceratops = "*"

[tool.poetry.group.dev.dependencies]
ruff = "^0.5.1"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
Expand Down
9 changes: 5 additions & 4 deletions src/capfinder/cnn_lstm_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import math

from keras.layers import (
from ml_libs import (
LSTM,
Adam,
BatchNormalization,
Conv1D,
Dense,
Dropout,
HyperModel,
HyperParameters,
Input,
MaxPooling1D,
Model,
)
from keras.models import Model
from keras.optimizers import Adam
from keras_tuner import HyperModel, HyperParameters


class CapfinderHyperModel(HyperModel):
Expand Down
2 changes: 1 addition & 1 deletion src/capfinder/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Tuple

import tensorflow as tf
from ml_libs import tf


def parse_features(line: tf.Tensor, num_timesteps: int) -> tf.Tensor:
Expand Down
4 changes: 1 addition & 3 deletions src/capfinder/encoder_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import List, Optional, Tuple

import keras
from keras import Model, layers
from keras_tuner import HyperModel, HyperParameters
from ml_libs import HyperModel, HyperParameters, Model, keras, layers


def transformer_encoder(
Expand Down
20 changes: 20 additions & 0 deletions src/capfinder/ml_libs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Single import point for the machine-learning stack.

This module configures the process environment (Keras backend selection,
TensorFlow log level, warning filters) and only then imports the ML
libraries, re-exporting the names that other modules pull in through
``from ml_libs import ...``. Importing the libraries anywhere else first
would bypass the configuration done here.
"""

import logging
import os
import warnings

# Must be set before the first ``import keras`` (keras_tuner imports keras
# too): the backend is read once at import time.
os.environ["KERAS_BACKEND"] = "jax"

# Disable TensorFlow logging
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # FATAL
logging.getLogger("tensorflow").setLevel(logging.FATAL)

# Filter out specific warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*cuFFT.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*cuDNN.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*cuBLAS.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*TensorRT.*")

# Now import the ML libraries. These imports are deliberately placed after the
# environment setup above and exist only to be re-exported, hence the noqa
# markers (E402: import not at top of file, F401: imported but unused).
import jax  # noqa: E402, F401
import keras  # noqa: E402, F401
import tensorflow as tf  # noqa: E402, F401
from keras import layers  # noqa: E402, F401
from keras.layers import (  # noqa: E402, F401
    LSTM,
    BatchNormalization,
    Conv1D,
    Dense,
    Dropout,
    Input,
    MaxPooling1D,
)
from keras.models import Model  # noqa: E402, F401
from keras.optimizers import Adam  # noqa: E402, F401
from keras_tuner import (  # noqa: E402, F401
    BayesianOptimization,
    HyperModel,
    HyperParameters,
    Hyperband,
    Objective,
    RandomSearch,
)

# Explicit public API so linters/auto-fixers do not strip the re-exports.
__all__ = [
    "LSTM",
    "Adam",
    "BatchNormalization",
    "BayesianOptimization",
    "Conv1D",
    "Dense",
    "Dropout",
    "HyperModel",
    "HyperParameters",
    "Hyperband",
    "Input",
    "MaxPooling1D",
    "Model",
    "Objective",
    "RandomSearch",
    "jax",
    "keras",
    "layers",
    "tf",
]
127 changes: 127 additions & 0 deletions src/capfinder/resnet_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Tuple

from ml_libs import HyperModel, HyperParameters, Model, keras, layers, tf


class ResNetBlockHyper(layers.Layer):
    """Residual block for 1D sequences with a projection shortcut.

    Main path: Conv1D -> BatchNormalization -> ReLU -> Conv1D ->
    BatchNormalization. Shortcut path: a kernel-size-1 Conv1D using the same
    strides as the first main-path convolution, followed by
    BatchNormalization. The two paths are added and passed through ReLU.

    Args:
        filters: Number of output channels of every convolution in the block.
        kernel_size: Kernel width of the two main-path convolutions.
        strides: Stride of the first main-path convolution and of the
            shortcut convolution.
        **kwargs: Standard layer keyword arguments (e.g. ``name``, ``dtype``)
            forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, filters: int, kernel_size: int, strides: int = 1, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() and compute_output_shape() do not have to
        # reach into the sub-layers.
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides

        self.conv1 = layers.Conv1D(
            filters, kernel_size, strides=strides, padding="same"
        )
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv1D(filters, kernel_size, padding="same")
        self.bn2 = layers.BatchNormalization()

        # Projection shortcut: matches the main path's channel count and
        # temporal down-sampling so the two tensors can be added.
        self.shortcut = layers.Conv1D(filters, 1, strides=strides, padding="same")
        self.bn_shortcut = layers.BatchNormalization()

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """Apply the residual block to ``inputs`` and return the activation."""
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = keras.activations.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)

        shortcut = self.shortcut(inputs)
        shortcut = self.bn_shortcut(shortcut)

        x = layers.add([x, shortcut])
        return keras.activations.relu(x)

    def compute_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        """Return the output shape for ``input_shape``.

        The first two entries of ``input_shape`` are passed through as the
        leading dimensions and the last output dimension is ``filters``.
        With ``padding="same"`` a strided convolution yields
        ``ceil(length / strides)`` steps, so the second dimension is reduced
        accordingly; an unknown (``None``) length stays unknown.
        """
        steps = input_shape[1]
        if steps is not None:
            steps = -(-steps // self.strides)  # ceiling division
        return input_shape[0], steps, self.filters

    def get_config(self) -> dict:
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "kernel_size": self.kernel_size,
                "strides": self.strides,
            }
        )
        return config


class ResNetTimeSeriesHyper(HyperModel):
    """
    A HyperModel class for building a ResNet-style neural network for time series classification.

    This class defines a tunable ResNet architecture that can be optimized using Keras Tuner.
    It creates a model with an initial convolutional layer, followed by a variable number of
    ResNet blocks, and ends with global average pooling and dense layers.

    Attributes:
        input_shape (Tuple[int, int]): The shape of the input data, passed to ``keras.Input``.
        n_classes (int): The number of units of the final softmax layer.
        encoder_model: Always initialised to ``None``; not used by this class.

    Methods:
        build(hp): Builds and returns a compiled Keras model based on the provided hyperparameters.
    """

    def __init__(self, input_shape: Tuple[int, int], n_classes: int):
        # Initialise the HyperModel base class so its own state is set up
        # before the tuner uses this object.
        super().__init__()
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.encoder_model = None

    def build(self, hp: HyperParameters) -> Model:
        """
        Build and compile a ResNet model based on the provided hyperparameters.

        This method constructs a ResNet architecture with tunable hyperparameters including
        the number of filters, kernel sizes, number of ResNet blocks, dense layer units,
        dropout rate, and learning rate.

        Args:
            hp (HyperParameters): A HyperParameters object used to define the search space.

        Returns:
            Model: A compiled Keras model ready for training.
        """
        inputs = keras.Input(shape=self.input_shape)

        # Initial convolution
        initial_filters = hp.Int(
            "initial_filters", min_value=32, max_value=128, step=32
        )
        x = layers.Conv1D(
            initial_filters,
            kernel_size=hp.Choice("initial_kernel", values=[3, 5, 7]),
            padding="same",
        )(inputs)
        x = layers.BatchNormalization()(x)
        x = keras.activations.relu(x)
        x = layers.MaxPooling1D(pool_size=3, strides=2, padding="same")(x)

        # ResNet blocks
        num_blocks_per_stage = hp.Int("num_blocks_per_stage", min_value=2, max_value=4)
        num_stages = hp.Int("num_stages", min_value=2, max_value=4)

        for stage in range(num_stages):
            filters = hp.Int(
                f"filters_stage_{stage}", min_value=64, max_value=256, step=64
            )
            for block in range(num_blocks_per_stage):
                kernel_size = hp.Choice(
                    f"kernel_stage_{stage}_block_{block}", values=[3, 5, 7]
                )
                # Down-sample once at the entry of every stage except the first.
                strides = 2 if block == 0 and stage > 0 else 1
                x = ResNetBlockHyper(filters, kernel_size, strides)(x)

        # Global pooling and output
        x = layers.GlobalAveragePooling1D()(x)
        x = layers.Dense(
            hp.Int("dense_units", min_value=32, max_value=256, step=32),
            activation="relu",
        )(x)
        x = layers.Dropout(hp.Float("dropout", min_value=0.0, max_value=0.5, step=0.1))(
            x
        )
        outputs = layers.Dense(self.n_classes, activation="softmax")(x)

        model = Model(inputs, outputs)

        model.compile(
            optimizer=keras.optimizers.Adam(
                hp.Float(
                    "learning_rate", min_value=1e-4, max_value=1e-2, sampling="log"
                )
            ),
            loss="sparse_categorical_crossentropy",
            metrics=["sparse_categorical_accuracy"],
        )

        return model
37 changes: 21 additions & 16 deletions src/capfinder/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,45 @@
from typing import Dict, Literal, Optional, Tuple, Type, Union

import comet_ml
import jax
from comet_ml import Experiment # Import CometML before keras

os.environ["KERAS_BACKEND"] = (
"jax" # the placement is important, it cannot be after keras import
)

import keras
import numpy as np
import pandas as pd
import tensorflow as tf
from keras_tuner import BayesianOptimization, Hyperband, Objective, RandomSearch
from comet_ml import Experiment # Import CometML before keras
from loguru import logger
from ml_libs import (
BayesianOptimization,
Hyperband,
Objective,
RandomSearch,
jax,
keras,
tf,
)
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

from capfinder.cnn_lstm_model import CapfinderHyperModel as CNNLSTMModel
from capfinder.data_loader import load_datasets
from capfinder.encoder_model import CapfinderHyperModel as EncoderModel
from capfinder.resnet_model import ResNetTimeSeriesHyper as ResnetModel
from capfinder.train_etl import train_etl
from capfinder.utils import map_cap_int_to_name

# Declare and initialize global stop_training flag
global stop_training
stop_training = False

ModelType = Literal["cnn_lstm", "encoder"]
ModelType = Literal["cnn_lstm", "encoder", "resnet"]


def get_model(
    model_type: ModelType,
) -> Type[CNNLSTMModel] | Type[EncoderModel] | Type[ResnetModel]:
    """Return the hypermodel class that corresponds to ``model_type``.

    Args:
        model_type: One of ``"cnn_lstm"``, ``"encoder"`` or ``"resnet"``.

    Returns:
        The hypermodel class (not an instance) for the requested model type.

    Raises:
        ValueError: If ``model_type`` is not a supported name. Raising here
            avoids silently returning ``None`` to the caller.
    """
    registry = {
        "cnn_lstm": CNNLSTMModel,
        "encoder": EncoderModel,
        "resnet": ResnetModel,
    }
    try:
        return registry[model_type]
    except KeyError:
        raise ValueError(
            f"Invalid model type {model_type!r}. "
            "Expected 'cnn_lstm', 'encoder' or 'resnet'."
        ) from None


def handle_interrupt(
Expand Down Expand Up @@ -459,7 +464,7 @@ def run_training_pipeline(
project_name=tune_params["comet_project_name"] + "_" + model_type
)
tune_experiment_url = tune_experiment.url
if model_type not in ["cnn_lstm", "encoder"]:
if model_type not in ["cnn_lstm", "encoder", "resnet"]:
raise ValueError("Invalid model type. Expected 'cnn_lstm' or 'encoder'.")
model = get_model(model_type)

Expand Down Expand Up @@ -763,8 +768,8 @@ def run_training_pipeline(
"factor": 2,
"batch_size": 4,
"seed": 42,
"tuning_strategy": "hyperband", # "hyperband" or "random_search" or "bayesian_optimization"
"overwrite": True,
"tuning_strategy": "bayesian_optimization", # "hyperband" or "random_search" or "bayesian_optimization"
"overwrite": False,
}

train_params = {
Expand All @@ -777,7 +782,7 @@ def run_training_pipeline(
model_save_dir = (
"/export/valenfs/data/processed_data/MinION/9_madcap/5_trained_models_202405/"
)
model_type: ModelType = "encoder" # cnn_lstm
model_type: ModelType = "resnet" # cnn_lstm, resnet, encoder

# Run the training pipeline
run_training_pipeline(
Expand Down

0 comments on commit 6780429

Please sign in to comment.