Commit fb67279, parent 17a0f7e: 14 changed files with 1,242 additions and 0 deletions.
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
This module defines a custom callback, PrintAndSaveStats, to monitor and save various statistics during the training
of an MLP model. The callback logs information such as epoch timings, accuracy, loss, and metrics like precision and
recall. It also computes aggregates such as total training time and best accuracy achieved. Additionally, it writes
these statistics to a file and logs them for TensorBoard visualization. The get_callbacks function generates a list
of callbacks, including early stopping, model checkpointing, the custom PrintAndSaveStats, and TensorBoard logging,
tailored to a specific model with the given parameters.
"""

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
import tensorflow as tf
import datetime
import time
from parameters import get_tensorboard_path

class PrintAndSaveStats(tf.keras.callbacks.Callback):

    def __init__(self, model_name):
        super().__init__()
        self.epoch_time_start = None
        self.model_name = model_name
        self.total_time = 0
        self.last_epoch = 1
        self.best_acc = 0
        self.best_epoch = 1
        self.first_acc = 0
        self.last_acc = 0
        self.last_loss = 0
        self.last_f1_micro = 0
        self.last_f1_macro = 0
        self.last_precision = 0
        self.last_recall = 0

    def on_epoch_begin(self, epoch, logs=None):
        # Record the wall-clock start of the epoch so on_epoch_end can time it.
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs):
        epoch += 1  # Keras reports 0-based epochs; log them 1-based.
        if epoch == 1:
            self.first_acc = logs["val_accuracy"]
        print('Epoch {} finished at {}'.format(epoch, datetime.datetime.now().time()))
        print(f"Printing log object:\n{logs}")
        elapsed_time = int(time.time() - self.epoch_time_start)
        print(f"Elapsed time: {elapsed_time}")
        if logs["loss"] != 0:
            print("val/train loss: {:.2f}".format(logs["val_loss"] / logs["loss"]))
        if logs["accuracy"] != 0:
            print("val/train acc: {:.2f}".format(logs["val_accuracy"] / logs["accuracy"]))
        SEPARATOR = ";"
        with open(get_history_path(self.model_name), "a") as history_file:  # append mode
            history_file.write(str(epoch) + SEPARATOR + str(datetime.datetime.now().time()) + SEPARATOR +
                               str(elapsed_time) + SEPARATOR + str(logs["accuracy"]) + SEPARATOR +
                               str(logs["val_accuracy"]) + SEPARATOR + str(logs["loss"]) + SEPARATOR +
                               str(logs["val_loss"]) + "\n")
        self.compute_aggregates(elapsed_time, logs["val_accuracy"], epoch)

        self.last_acc = logs["val_accuracy"]
        self.last_loss = logs["val_loss"]
        # self.last_f1_micro = logs["val_f1_micro"]
        # self.last_f1_macro = logs["val_f1_macro"]
        self.last_precision = logs["val_precision"]
        self.last_recall = logs["val_recall"]
        with tf.summary.create_file_writer(get_tensorboard_path()).as_default():
            tf.summary.scalar("val_accuracy", logs["val_accuracy"], step=epoch)
            tf.summary.scalar("val_loss", logs["val_loss"], step=epoch)
            tf.summary.scalar("train_accuracy", logs["accuracy"], step=epoch)
            tf.summary.scalar("train_loss", logs["loss"], step=epoch)
            tf.summary.scalar("time", elapsed_time, step=epoch)
            tf.summary.scalar("precision", logs["val_precision"], step=epoch)
            tf.summary.scalar("recall", logs["val_recall"], step=epoch)
            # tf.summary.scalar("f1_macro", logs["val_f1_macro"], step=epoch)
            # tf.summary.scalar("f1_micro", logs["val_f1_micro"], step=epoch)

    def compute_aggregates(self, elapsed_time: int, val_acc, epoch: int):
        self.total_time += elapsed_time
        self.last_epoch = epoch
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.best_epoch = epoch

    def get_stats(self):
        return [int(self.total_time / self.last_epoch), self.first_acc, self.best_acc,
                self.best_epoch, self.last_epoch]


def get_history_path(model_name: str):
    return model_name + "_history.csv"


def get_best_model_path(model_name: str):
    return model_name + "_checkpoint.h5"


def get_callbacks(model_name: str, early_patience: int) -> list:
    early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=early_patience,
                                   restore_best_weights=True, verbose=1)
    save_best_model = ModelCheckpoint(get_best_model_path(model_name), save_best_only=True,
                                      monitor="val_loss", verbose=1)
    save_model_stats = PrintAndSaveStats(model_name)
    tensorboard = TensorBoard(get_tensorboard_path())
    return [save_best_model, save_model_stats, early_stopping, tensorboard]
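
For orientation, here is a minimal sketch of how these callbacks could be attached to a training run. The model, the synthetic data, and the hyperparameters are illustrative assumptions, not part of this commit; the one real constraint PrintAndSaveStats imposes is that the compiled metrics produce the accuracy, val_precision, and val_recall entries it reads from logs.

# Hypothetical usage sketch: model and data are placeholders; only the
# callback wiring and metric naming reflect this module's requirements.
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation="relu"),
    tf.keras.layers.Dense(21, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy",
              metrics=["accuracy",
                       tf.keras.metrics.Precision(name="precision"),
                       tf.keras.metrics.Recall(name="recall")])

# Random stand-in data shaped like fixed-width integer-vector lines.
x = np.random.rand(512, 100).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(0, 21, 512), 21)

callbacks = get_callbacks("mlp_demo", early_patience=5)
model.fit(x, y, validation_split=0.2, epochs=3, callbacks=callbacks)

stats = callbacks[1]  # the PrintAndSaveStats instance
print(stats.get_stats())  # [avg epoch time, first acc, best acc, best epoch, last epoch]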
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
This module calculates class weights for the MLP training by first extracting class labels from the training dataset.
It then computes class weights with scikit-learn's compute_class_weight function to address class imbalance, and
returns a dictionary mapping class indices to their respective weights.
"""

from lazy_load import load_ds_lazy
from parameters import *
from pickle_load import pickle_to_tensor
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from typing import Dict, Any

def get_class_weight(ds_train_y) -> Dict[int, Any]:
    # Recover integer labels from the one-hot vectors, then weight each class
    # inversely to its frequency ("balanced").
    class_labels = np.argmax(ds_train_y, axis=1)
    class_weights = compute_class_weight('balanced', classes=np.unique(class_labels), y=class_labels)
    cw_dict = {}
    for lang_index in range(class_weights.shape[0]):
        cw_dict[lang_index] = class_weights[lang_index]
    return cw_dict

if __name__ == '__main__':
    for config in process_args():
        ARG_MAP = config
        if ARG_MAP[LAZY_LOAD]:
            train_ds, val_ds = load_ds_lazy(ARG_MAP[BATCH_SIZE], ARG_MAP[N_LABELS], ARG_MAP[EPOCHS])
        else:
            train_ds, val_ds = get_pickle_paths()
        train_y = pickle_to_tensor(train_ds + "_labels")
        print(f"{get_class_weight(train_y)} vector for {ARG_MAP[N_SAMPLES]} samples")
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
This module transforms text files into integer vectors, mapping each character to its UTF-8 code point, and stores
them as CSV files. It reads the texts in the corpus, splits them into lines, concatenates each line with its
corresponding label, and then batches the lines for efficient processing. It subsequently iterates through the
resulting dataset, writing batches of lines into separate CSV files. It also handles dataset stratification.
"""

import tensorflow as tf
from typing import List
import os

N_CORES = tf.data.AUTOTUNE

def eff_write(ds, folder, file_counter=0, lines_per_file=1000):
    if not os.path.exists(folder):
        os.mkdir(folder)
    # Append the label to each integer vector and serialize the result as one CSV row.
    ds = ds.map(lambda line, label: tf.strings.reduce_join(
        tf.strings.as_string(tf.concat([line, [label]], axis=0)), separator=","))
    ds = ds.batch(lines_per_file)
    FILE_NAME = "f"
    for batch in ds:
        complete_name = FILE_NAME + str(file_counter) + ".csv"
        file_content = str(tf.strings.reduce_join(batch, separator="\n").numpy(), encoding="ascii")
        with open(folder + "/" + complete_name, 'w') as csvfile:
            csvfile.write(file_content)
        file_counter += 1
    return file_counter

def line_to_raw_int_ds(ds: tf.data.Dataset):
    ds = ds.map(lambda file, label: (file, tf.cast(label, tf.dtypes.int32)), num_parallel_calls=N_CORES)
    # Re-encode the file contents from UTF-16LE to UTF-8.
    ds = ds.map(lambda file, label: (tf.strings.unicode_decode(file, "UTF-16LE"), label), num_parallel_calls=N_CORES)
    ds = ds.map(lambda file, label: (tf.strings.unicode_encode(file, "UTF-8"), label), num_parallel_calls=N_CORES)
    # Split each file into lines, pair every line with the file's label, and
    # interleave the per-file datasets so lines from many files are mixed.
    ds = ds.interleave(
        lambda file, label: tf.data.Dataset.from_tensor_slices(
            tf.map_fn(lambda line: (line, label),
                      tf.strings.split(tf.strings.regex_replace(file, "\r", ""), "\n"),
                      fn_output_signature=(tf.dtypes.string, tf.int32))).shuffle(1000),
        num_parallel_calls=N_CORES, deterministic=True, block_length=1, cycle_length=20_000)
    # Drop lines that are too short to carry a useful signal.
    LENGTH_LIMIT: int = 10
    ds = ds.filter(lambda line, _: tf.strings.length(line) >= LENGTH_LIMIT)
    return ds.map(lambda line, label: (tf.strings.unicode_decode(line, "UTF-8"), label), num_parallel_calls=N_CORES)

def get_raw_lines_ds(source_folder: str) -> tf.data.Dataset:
    snippet_ds: tf.data.Dataset = tf.keras.utils.text_dataset_from_directory(
        source_folder, label_mode='int', batch_size=None, shuffle=True)
    return line_to_raw_int_ds(snippet_ds)


def stratify_ds(ds: tf.data.Dataset, weights: List[float]):
    # Build one filtered dataset per class and sample from them with the given
    # weights so the resulting stream is balanced across classes.
    datasets = [ds.filter(lambda _, label: label == i) for i in range(len(weights))]
    return tf.data.Dataset.sample_from_datasets(datasets, weights, stop_on_empty_dataset=True)

if __name__ == '__main__':
    SNIPPET_SOURCE_FOLDER = ".\\comments_V2_TXT_test"
    line_ds: tf.data.Dataset = get_raw_lines_ds(SNIPPET_SOURCE_FOLDER)
    DS_SIZE: int = 1_000_000
    N_LANGS: int = 21
    WEIGHTS = [DS_SIZE / N_LANGS] * N_LANGS
    line_ds = stratify_ds(line_ds, WEIGHTS).take(DS_SIZE)
    DEST_FOLDER: str = ".\\raw_int_lines_test_balanced_1M"
    eff_write(line_ds, DEST_FOLDER)
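
Reading the generated files back is not shown in this commit. The following sketch assumes, per eff_write above, that each CSV row holds a variable-length run of UTF-8 code points with the label in the last field, so rows are parsed manually rather than with a fixed-column CSV reader; the padding length and batch size are hypothetical choices.

# Hedged read-back sketch: pad/truncate each line to a fixed width so the
# integer vectors can be batched for an MLP. Width and batch size assumed.
import tensorflow as tf

def parse_row(row: tf.Tensor):
    values = tf.strings.to_number(tf.strings.split(row, ","), out_type=tf.int32)
    return values[:-1], values[-1]  # (code points, label)

def load_int_lines(folder: str, batch_size: int = 32, pad_to: int = 100):
    files = tf.data.Dataset.list_files(folder + "/f*.csv")
    ds = files.interleave(tf.data.TextLineDataset, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(parse_row, num_parallel_calls=tf.data.AUTOTUNE)
    # Truncate long lines, then zero-pad short ones to a fixed width.
    ds = ds.map(lambda line, label: (line[:pad_to], label))
    return ds.padded_batch(batch_size, padded_shapes=([pad_to], []))

# Example over the balanced split written above:
# for lines, labels in load_int_lines(".\\raw_int_lines_test_balanced_1M").take(1):
#     print(lines.shape, labels.shape)  # (32, 100) (32,)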