In [1]:
import re
import sys

sys.path.append("..")
from rocket_rag.utils import *
from rocket_rag.node import Node
from rocket_rag.node_indexing import NodeIndexer

from typing import List, Any, Dict, Tuple

from pyts.transformation import ROCKET
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.linear_model import RidgeClassifier
from sklearn.metrics import accuracy_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class BaseVectorStore:
    """Minimal vector-store interface.

    Every method defined here is a placeholder that does nothing and
    returns ``None``; subclasses override the operations they support.
    """

    stores_text: bool = True

    def get(self, text_id: str) -> List[float]:
        """Look up the entry stored for ``text_id`` (no-op in this base class)."""

    def add(self, nodes: List[Node]) -> List[str]:
        """Insert ``nodes`` into the store (no-op in this base class)."""

    def delete(self, node_id: str, **delete_kwargs: Any) -> None:
        """Remove the entry identified by ``node_id`` (no-op in this base class)."""

    def query(self, query: str, **kwargs: Any):
        """Retrieve nodes for ``query`` (no-op in this base class)."""

In [3]:
class VectorStore(BaseVectorStore):
    """In-memory vector store that keeps nodes in a dict keyed by node id."""

    stores_text: bool = True

    def __init__(self) -> None:
        """Create an empty store."""
        self.node_dict: Dict[str, Node] = {}
        # Live dict view: it reflects later add()/delete() calls automatically.
        self.nodes = self.node_dict.values()
        super().__init__()

    def get(self, text_id: str) -> Node:
        """Return the node stored under ``text_id``.

        Raises:
            KeyError: if no node with that id is stored.
        """
        return self.node_dict[text_id]

    def add(self, nodes: List[Node]) -> List[str]:
        """Store ``nodes`` keyed by ``node.node_id`` and return those ids.

        A node whose id is already present replaces the stored one.

        Returns:
            The ids of the nodes that were passed in, in input order.
        """
        added_ids: List[str] = []
        for node in nodes:
            self.node_dict[node.node_id] = node
            added_ids.append(node.node_id)
        return added_ids

    def delete(self, node_id: str, **delete_kwargs: Any) -> None:
        """Remove the node stored under ``node_id``.

        ``delete_kwargs`` is accepted but not used.

        Raises:
            KeyError: if no node with that id is stored.
        """
        del self.node_dict[node_id]

In [4]:
loguru.logger.debug('Testing on vector store module...')
load = '-40kg'  # load condition; selects the node file here and the inference files later
node_indexer = NodeIndexer()
# NOTE(review): the store file is a .pkl; if it is unpickled, load only trusted files.
nodes = node_indexer.load_node_indexing(f'../store/nodes_{load}.pkl')

[32m2024-05-25 13:31:52.785[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [34m[1mTesting on vector store module...[0m
[32m2024-05-25 13:31:52.786[0m | [34m[1mDEBUG   [0m | [36mrocket_rag.node_indexing[0m:[36mload_node_indexing[0m:[36m98[0m - [34m[1mLoading all nodes...[0m
[32m2024-05-25 13:31:53.198[0m | [1mINFO    [0m | [36mrocket_rag.node_indexing[0m:[36mload_node_indexing[0m:[36m102[0m - [1mAll nodes are loaded.[0m


In [5]:
loguru.logger.debug('Initializing vector store...')
vector_store = VectorStore()  # starts empty
vector_store.add(nodes)  # keyed by each node's node_id
loguru.logger.info('Loaded nodes into the vector store.')

[32m2024-05-25 13:31:54.817[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [34m[1mInitializing vector store...[0m
[32m2024-05-25 13:31:54.817[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mLoaded nodes into the vector store.[0m


In [6]:
# Materialise the store's nodes once so both arrays are built from the same ordering.
stored_nodes = list(vector_store.nodes)
rocket_features = np.array([stored.get_rocket_feature() for stored in stored_nodes])
# NOTE(review): reads `id_` while VectorStore.add keys on `node_id` - presumably the same value; confirm.
doc_ids = np.array([stored.id_ for stored in stored_nodes])

In [7]:
# Collect the inference files; presumably a mapping from load condition to file paths - verify in parse_files.
if_files_dict = parse_files(main_directory=INFERENCE_DIR)
# Inference time-series files for the load condition chosen earlier (`load`).
if_ts_files = if_files_dict[load]

In [43]:
# Seeded draw so the selected file, and every output derived from it, is
# reproducible across kernel restarts. Set RANDOM_SEED = None for a fresh draw.
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)
rand_idx = int(rng.integers(0, len(if_ts_files)))  # upper bound is exclusive
if_ts_filename = if_ts_files[rand_idx]
print(f'Randomly selected file: {if_ts_filename}')

# Feature for the single selected file (passed as a one-element list).
if_rocket_feature = fit_transform(ts_filename=[if_ts_filename],
                                  field='current',
                                  smooth=True,
                                  smooth_ws=15,
                                  tolist=False,
                                  verbo=True)

[32m2024-05-25 14:24:23.633[0m | [34m[1mDEBUG   [0m | [36mrocket_rag.utils[0m:[36mfit_transform[0m:[36m149[0m - [34m[1mExtract the time series data points[0m
[32m2024-05-25 14:24:23.657[0m | [1mINFO    [0m | [36mrocket_rag.utils[0m:[36mfit[0m:[36m125[0m - [1mtime series extracted SUCCESSFULLY.[0m
[32m2024-05-25 14:24:23.658[0m | [34m[1mDEBUG   [0m | [36mrocket_rag.utils[0m:[36mfit_transform[0m:[36m161[0m - [34m[1mRocket transforming...[0m


Randomly selected file: ../data/inference/-40kg\spalling5\spalling5_-40_3_2.csv


### Use sklearn to do the retrieval

In [23]:
def parse_result(res: Union[str, List[Any], np.ndarray]) -> str:
    """Extract the class label (text before the first underscore) from a doc id.

    Args:
        res: A doc-id string such as ``'spalling5_-40_3_4'``, or a
            list/tuple/ndarray of such strings, in which case only the
            first element is used.

    Returns:
        The part of the id that precedes the first ``'_'``.

    Raises:
        ValueError: if the id contains no underscore.
        IndexError: if ``res`` is an empty sequence or array.
    """
    if isinstance(res, np.ndarray):
        # ravel() copes with 0-d and multi-dimensional arrays alike.
        first = res.ravel()[0]
    elif isinstance(res, (list, tuple)):
        first = res[0]
    else:
        first = res
    matched = re.match(r'(.*?)_', first)
    if matched is None:
        raise ValueError(f'No "_"-separated label found in {first!r}')
    return matched.group(1)

In [95]:
# Nearest-neighbour classifier in which every stored doc id is its own class.
knn = KNeighborsClassifier(n_neighbors=1, weights='distance', metric='euclidean')
knn.fit(rocket_features, doc_ids)
knn_pred = knn.predict(if_rocket_feature)
# Five nearest stored features to the query, together with their distances.
dist, ids = knn.kneighbors(if_rocket_feature, n_neighbors=5, return_distance=True)
# Row index in `rocket_features` -> doc id.
ids_to_doc = dict(enumerate(doc_ids))
print(dist.squeeze().tolist())
print([ids_to_doc[neighbor_idx] for neighbor_idx in ids[0].tolist()])

[2.913731003851641, 3.2295925336847713, 3.265172098148915, 3.329911209755798, 3.5833409196557535]
['spalling5_-40_3_4', 'spalling6_-40_4_3', 'spalling4_-40_5_4', 'spalling4_-40_2_1', 'spalling4_-40_6_2']


In [92]:
# Row indices of the query's nearest neighbours, as plain Python ints.
ids[0].tolist()

[374, 418, 344, 328, 347]

In [20]:
# Ridge classifier fitted with each stored doc id as its own class label.
ridge = RidgeClassifier()
ridge.fit(rocket_features, doc_ids)
ridge_pred = ridge.predict(if_rocket_feature)
# parse_result reduces the predicted doc id to the text before its first underscore.
print(parse_result(ridge_pred))

backlash2


### Hand-coded KNN similarity computation

In [21]:
# One row per stored node; columns are that node's feature values.
rocket_features.shape

(520, 20000)

In [22]:
# Flatten the query once and reuse it (and its norm) instead of recomputing
# both for every stored feature vector.
query_feature = if_rocket_feature.squeeze()
query_norm = np.linalg.norm(query_feature)
# Euclidean distance from each stored feature to the query (smaller = closer).
euclidean = [np.linalg.norm(rf - query_feature) for rf in rocket_features]
# Cosine similarity of each stored feature with the query (larger = closer).
cosine = [np.dot(rf, query_feature) / (np.linalg.norm(rf) * query_norm) for rf in rocket_features]

In [23]:
# Three stored docs with the smallest Euclidean distance to the query.
sorted(zip(euclidean, doc_ids), key=lambda pair: pair[0])[:3]

[(7.6432402002574165, 'backlash2_-40_8_4'),
 (9.13494425676515, 'backlash2_-40_7_2'),
 (9.529327858850607, 'backlash2_-40_9_5')]

In [24]:
# Three stored docs with the highest cosine similarity to the query.
sorted(zip(cosine, doc_ids), key=lambda pair: pair[0], reverse=True)[:3]

[(0.9988390402632548, 'backlash2_-40_8_4'),
 (0.9986860760662098, 'backlash2_-40_9_5'),
 (0.9985502045686941, 'backlash2_-40_7_2')]

### Compute the accuracy of the hand-coded KNN results
#### Top-1 accuracy
20kg: 63.0769%

40kg: 75.1938%

-40kg: 46.1538%

#### Top-5 accuracy
20kg: 83.0769%

40kg: 91.4729%

-40kg: 81.5385%

In [27]:
# Features for every inference file of the selected load condition, with the
# same field and smoothing settings as the single-file call earlier.
# NOTE(review): presumably this must reuse the same ROCKET kernels that produced
# the stored node features for distances to be comparable - confirm in fit_transform.
if_rocket_features = fit_transform(ts_filename=if_ts_files,
                                   field='current',
                                   smooth=True,
                                   smooth_ws=15,
                                   tolist=False,
                                   verbo=False)

In [28]:
# euclideans[q] lists (distance, doc_id) for query q against every stored feature.
euclideans = [
    [(np.linalg.norm(stored_rf - if_rf), doc_id)
     for stored_rf, doc_id in zip(rocket_features, doc_ids)]
    for if_rf in if_rocket_features
]

In [459]:
# euclideans[:5]

In [29]:
# Top-1 label per query: take the smallest (distance, doc_id) tuple and
# reduce its doc id to the label before the first underscore.
retrieve_res = np.array([parse_result(min(pairs)[1]) for pairs in euclideans])
retrieve_res[:5]

array(['spalling4', 'normal', 'spalling4', 'spalling1', 'spalling4'],
      dtype='<U16')

In [30]:
k = 5
# For each query keep the labels of its k nearest stored docs, closest first.
top_k_retri_res = [
    [parse_result(doc_id) for _, doc_id in sorted(pairs, key=lambda pair: pair[0])[:k]]
    for pairs in euclideans
]
# The variable name is historical: it holds the top-k (here top-5) labels.
top_3_retri_res = np.array(top_k_retri_res)
top_3_retri_res[:5]

array([['spalling4', 'normal', 'spalling4', 'spalling4', 'normal'],
       ['normal', 'spalling1', 'normal', 'spalling5', 'spalling4'],
       ['spalling4', 'spalling4', 'spalling4', 'spalling6', 'spalling4'],
       ['spalling1', 'spalling4', 'spalling4', 'normal', 'normal'],
       ['spalling4', 'normal', 'spalling5', 'normal', 'spalling6']],
      dtype='<U16')

In [31]:
# Ground-truth label = file stem up to its first underscore,
# e.g. 'spalling5_-40_3_2.csv' -> 'spalling5'.
labels = []
for f in if_ts_files:
    # Split on both separators: the paths mix '/' and '\', and
    # os.path.basename only treats '\' as a separator when run on Windows.
    file_name = re.split(r'[\\/]', str(f))[-1]
    raw_label = os.path.splitext(file_name)[0]
    label = re.match(r'^(.*?)_', raw_label).group(1)
    labels.append(label)
labels = np.array(labels)
labels[:5]

array(['normal', 'normal', 'normal', 'normal', 'normal'], dtype='<U16')

In [32]:
# Top-1 accuracy: share of queries whose nearest stored doc carries the true label.
print(f'Accuracy: {accuracy_score(labels, retrieve_res)*100: .4f}%')

Accuracy:  46.1538%


In [33]:
# A query counts as a hit when its true label appears among its top-k retrieved labels.
is_in = [true_label in retrieved for true_label, retrieved in zip(labels, top_k_retri_res)]
print(f'Top {k} accuracy: {(sum(is_in) / len(labels))*100: .4f}%')

Top 5 accuracy:  81.5385%
