Skip to content

Commit

Permalink
Add save and load
Browse files Browse the repository at this point in the history
  • Loading branch information
wannaphong committed Apr 26, 2024
1 parent 13a6c7c commit 5d46987
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 5 deletions.
58 changes: 56 additions & 2 deletions notebooks/test_gzip_classify.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,60 @@
"source": [
"model.predict(\"ฉันดีใจ\", k=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5a97f0d3",
"metadata": {},
"outputs": [],
"source": [
"model.save(\"d.model\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6e183243",
"metadata": {},
"outputs": [],
"source": [
"model2 = pythainlp.classify.param_free.GzipModel(model_path=\"d.model\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b30af6f0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Positive'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model2.predict(x1=\"ฉันดีใจ\", k=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e72a33b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.8.13 ('base')",
"language": "python",
"name": "python3"
},
Expand All @@ -78,7 +127,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.8.13"
},
"vscode": {
"interpreter": {
"hash": "a1d6ff38954a1cdba4cf61ffa51e42f4658fc35985cd256cd89123cae8466a39"
}
}
},
"nbformat": 4,
Expand Down
22 changes: 19 additions & 3 deletions pythainlp/classify/param_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import gzip
from typing import List, Tuple
import numpy as np
import json


class GzipModel:
Expand All @@ -16,9 +17,12 @@ class GzipModel:
:param list training_data: list [(text_sample,label)]
"""

def __init__(self, training_data: List[Tuple[str, str]]):
self.training_data = np.array(training_data)
self.Cx2_list = self.train()
def __init__(
    self,
    training_data: List[Tuple[str, str]] = None,
    model_path: str = None,
):
    """
    Build the classifier from training samples or from a saved model file.

    :param list training_data: list [(text_sample,label)]; only used when
        model_path is not given
    :param str model_path: path of a model file to restore with load();
        when given, training_data is ignored and no training is run
    :raises ValueError: if neither training_data nor model_path is given
    """
    # Identity check: "!= None" would invoke __ne__ on arbitrary objects.
    if model_path is not None:
        self.load(model_path)
        return
    if training_data is None:
        # Fail early with a clear message instead of wrapping None in an
        # array and handing it to train().
        raise ValueError("Either training_data or model_path must be given.")
    self.training_data = np.array(training_data)
    self.Cx2_list = self.train()

def train(self):
Cx2_list = []
Expand Down Expand Up @@ -72,3 +76,15 @@ def predict(self, x1: str, k: int = 1) -> str:
predict_class = top_k_class[counts.argmax()]

return predict_class

def save(self, path: str) -> None:
    """
    Save the model to a JSON file.

    The file holds two keys: ``training_data`` (the training array
    converted to nested lists) and ``Cx2_list``.

    :param str path: path of the file to write (overwritten if it exists)
    """
    # Build the payload first so a conversion error cannot leave a
    # truncated file behind.
    payload = {
        "training_data": self.training_data.tolist(),
        "Cx2_list": self.Cx2_list,
    }
    # ensure_ascii=False writes non-ASCII text verbatim, so the file must be
    # opened as UTF-8 explicitly instead of relying on the platform default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False)

def load(self, path: str) -> None:
    """
    Load the model from a JSON file, replacing the current state.

    The file must contain the keys ``training_data`` and ``Cx2_list``.

    :param str path: path of the JSON file to read
    :raises KeyError: if one of the expected keys is missing from the file
    """
    # Read as UTF-8 explicitly; the platform default encoding may not be
    # able to decode non-ASCII text samples stored in the file.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    self.Cx2_list = data["Cx2_list"]
    self.training_data = np.array(data["training_data"])

0 comments on commit 5d46987

Please sign in to comment.