In [None]:
import logging
import time
import json
import matplotlib
import timm
import torch
from fastai.vision.all import (
    F1Score,
    accuracy,
    Precision,
    Recall,
    DataLoaders,
    Learner,
    SaveModelCallback,
    valley,
    slide,
    ClassificationInterpretation)
from data import get_dls_from_images, get_dls_from_dataset
import io
import os

from datasets import Dataset, concatenate_datasets, load_dataset
from concurrent.futures import ThreadPoolExecutor
from fastai.vision.all import (
    Path,
    DataBlock,
    ImageBlock,
    CategoryBlock,
    Resize,
    aug_transforms,
    Normalize,
    get_image_files,
    parent_label,
    imagenet_stats,
)
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from fastcore.foundation import L, range_of

In [4]:
import json
with open("config.json", "r") as config_file:
    config = json.load(config_file)
model_name = config["model"]["timm_model_name"]
model_name

'efficientnet_b0'

In [26]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import pandas as pd
def get_dls_from_images():
    #path = "C:/BitBucketRepo/datascienceprojects/ChartExt/Images/" #Path(config["data"]["image_dir"])
    #img_size = config["data"]["transformations"]["img_size"]

    data = pd.read_excel("C:/BitBucketRepo/datascienceprojects/ChartExt/res.xlsx")
    data['label'] = data['label'].apply(str)
    labels = data['label']
    X_train, X_temp = train_test_split(data, test_size=0.2, stratify=labels, random_state = 42)
    label_test_val = X_temp['label']
    X_test, X_val = train_test_split(X_temp, test_size=0.5, stratify=label_test_val, random_state = 42)

    image_size = 128
    image_channel = 3
    bat_size = 32

    train_datagen = ImageDataGenerator(rescale=1./255,
                                    rotation_range = 15,
                                    horizontal_flip = True,
                                    zoom_range = 0.2,
                                    shear_range = 0.1,
                                    fill_mode = 'reflect',
                                    width_shift_range = 0.1,
                                    height_shift_range = 0.1)

    test_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_dataframe(X_train,
                                                    directory = 'C:/BitBucketRepo/datascienceprojects/ChartExt/Images/',
                                                    x_col= 'filename',
                                                    y_col= 'label',
                                                    batch_size = bat_size,
                                                    target_size = (image_size,image_size)
                                                   )
    val_generator = test_datagen.flow_from_dataframe(X_val, 
                                                    directory = 'C:/BitBucketRepo/datascienceprojects/ChartExt/Images/',
                                                    x_col= 'filename',
                                                    y_col= 'label',
                                                    batch_size = bat_size,
                                                    target_size = (image_size,image_size),
                                                    shuffle=False
                                                    )

    test_generator = test_datagen.flow_from_dataframe(X_test, 
                                                    directory = 'C:/BitBucketRepo/datascienceprojects/ChartExt/Images/',
                                                    x_col= 'filename',
                                                    y_col= 'label',
                                                    batch_size = bat_size,
                                                    target_size = (image_size,image_size),
                                                    shuffle=False
                                                    )
    
    return train_generator, val_generator, test_generator
#train_dl, val_dl, test_dl = get_dls_from_images(config=config)

In [17]:
image_size = 128
image_channel = 3
bat_size = 32

In [42]:
from tensorflow.keras.callbacks import ReduceLROnPlateau,EarlyStopping
from keras.layers import Dense
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from efficientnet.keras import EfficientNetB3
from keras.models import Sequential
class ChartRecognizer:
    def __init__(self):
        # Read model args from config.json
        with open("config.json", "r") as config_file:
            self.config = json.load(config_file)

        model_name = self.config["model"]["timm_model_name"]

        # Change format of timm_models
        '''timm_models = [
            model.split(".")[0] for model in timm.list_models(pretrained=True)
        ]
        if model_name not in timm_models:
            raise ValueError(
                f"Model {model_name} not found in timm.list_models(pretrained=True)"
            )

        self.model = timm.create_model(model_name, pretrained=True, num_classes=2)'''
        efficient_net = EfficientNetB3(
        weights='imagenet',
        input_shape=(128,128,3),
        include_top=False,
        pooling='max')

        self.model = Sequential()
        self.model.add(efficient_net)
        self.model.add(Dense(units = 120, activation='relu'))
        self.model.add(Dense(units = 120, activation = 'relu'))
        self.model.add(Dense(units = 1, activation='sigmoid'))

        # For converting config.json to function
        self.metrics_dict = {
            "f1_score": F1Score(),
            "precision": Precision(),
            "recall": Recall(),
            "accuracy": accuracy,
        }

    def train(self):
        train_generator, val_generator, test_generator = get_dls_from_images()

        

        # Find an appropriate learning rate
        # learn.lr_find()
        #suggested_lr = learn.lr_find(suggest_funcs=(valley, slide))[0]
        
        learning_rate_reduction = ReduceLROnPlateau(monitor = 'val_accuracy',
                                            patience=2,
                                            factor=0.5,
                                            min_lr = 0.00001,
                                            verbose = 1)
        early_stoping = EarlyStopping(monitor='val_loss',patience= 3,restore_best_weights=True,verbose=0)

        self.model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])

        self.model.fit(train_generator,
                    validation_data = val_generator, 
                    callbacks=[early_stoping,learning_rate_reduction],
                    epochs = 15,
                    # steps_per_epoch = len(train_generator),
                    # validation_steps = len(val_generaotor),
                   )
    def predictt(self):
        train_generator, val_generator, test_generator = get_dls_from_images()
        self.model.predict(test_generator)

In [43]:
if __name__ == "__main__":
    model = ChartRecognizer()
    model.train()

Found 116 validated image filenames belonging to 2 classes.
Found 15 validated image filenames belonging to 2 classes.
Found 15 validated image filenames belonging to 2 classes.
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 3: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.
Epoch 4/15
Epoch 5/15
Epoch 5: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
Epoch 6/15
Epoch 7/15
Epoch 7: ReduceLROnPlateau reducing learning rate to 0.0001250000059371814.
Epoch 8/15
Epoch 9/15
Epoch 9: ReduceLROnPlateau reducing learning rate to 6.25000029685907e-05.
Epoch 10/15
Epoch 11/15
Epoch 11: ReduceLROnPlateau reducing learning rate to 3.125000148429535e-05.
Epoch 12/15
Epoch 13/15
Epoch 13: ReduceLROnPlateau reducing learning rate to 1.5625000742147677e-05.
Epoch 14/15
Epoch 15/15
Epoch 15: ReduceLROnPlateau reducing learning rate to 1e-05.


In [45]:
from timm.data import resolve_data_config, create_transform
train_generator, val_generator, test_generator = get_dls_from_images()
model.predictt()

Found 116 validated image filenames belonging to 2 classes.
Found 15 validated image filenames belonging to 2 classes.
Found 15 validated image filenames belonging to 2 classes.
Found 116 validated image filenames belonging to 2 classes.
Found 15 validated image filenames belonging to 2 classes.
Found 15 validated image filenames belonging to 2 classes.
