In [54]:
from typing import Optional, List
from sklearn.feature_extraction.text import TfidfVectorizer
import pickle
import os
import codecs
from sklearn.linear_model import LogisticRegression
import numpy as np

In [78]:
class ZamokClassifier:
    """Binary TF-IDF + logistic-regression classifier for two groups of texts.

    Texts passed as ``texts1`` to :meth:`train` get label 1 and texts passed as
    ``texts2`` get label 0; :meth:`predict` answers ``'YES'`` for label 1 and
    ``'NO'`` for label 0.
    """

    # Suffixes appended to the base path by save() and by __init__(load_path=...);
    # kept in one place so the two can never drift apart.
    _VECTORIZER_SUFFIX = '.vectorizer.pk'
    _MODEL_SUFFIX = '.model.pk'

    def __init__(self, load_path: Optional[str] = None,
                 max_features: int = 100, random_state: int = 0) -> None:
        """Create a fresh, untrained classifier or load a previously saved one.

        Args:
            load_path: base path previously passed to :meth:`save`. When truthy,
                the fitted vectorizer and model are unpickled from
                ``load_path + '.vectorizer.pk'`` and ``load_path + '.model.pk'``
                and the remaining arguments are ignored.
            max_features: vocabulary size cap for a new ``TfidfVectorizer``.
            random_state: seed for a new ``LogisticRegression``.

        Security: ``pickle.load`` can execute arbitrary code -- only load files
        you created yourself, never untrusted ones.
        """
        if load_path:
            with open(load_path + self._VECTORIZER_SUFFIX, 'rb') as fin:
                self.vectorizer = pickle.load(fin)
            with open(load_path + self._MODEL_SUFFIX, 'rb') as fin:
                self.log_model = pickle.load(fin)
        else:
            self.vectorizer = TfidfVectorizer(max_features=max_features)
            self.log_model = LogisticRegression(random_state=random_state)

    def train(self, texts1: List[str], texts2: List[str]) -> None:
        """Fit the vectorizer and the model on both groups of texts.

        Args:
            texts1: texts of the positive class (label 1).
            texts2: texts of the negative class (label 0).

        Raises:
            ValueError: if either group is empty -- logistic regression needs
                examples of both classes.
        """
        if len(texts1) == 0 or len(texts2) == 0:
            raise ValueError('train() needs at least one text in each of texts1 and texts2')
        all_texts = list(texts1) + list(texts2)
        # fit_transform on the concatenation yields the same rows (texts1 first,
        # then texts2) as fitting and then transforming each group separately,
        # and keeps the matrix sparse instead of densifying it with toarray().
        X = self.vectorizer.fit_transform(all_texts)
        y = np.array([1] * len(texts1) + [0] * len(texts2))
        self.log_model.fit(X, y)

    def save(self, path: str) -> None:
        """Pickle the vectorizer and the model to ``path + '.vectorizer.pk'`` and
        ``path + '.model.pk'``; reload with ``ZamokClassifier(path)``."""
        with open(path + self._VECTORIZER_SUFFIX, 'wb') as fout:
            pickle.dump(self.vectorizer, fout)
        with open(path + self._MODEL_SUFFIX, 'wb') as fout:
            pickle.dump(self.log_model, fout)

    def predict(self, text: str) -> str:
        """Classify one text: ``'YES'`` for label 1 (texts1 group), else ``'NO'``."""
        # predict() returns one label per input row; we pass exactly one row.
        label = self.log_model.predict(self.vectorizer.transform([text]))[0]
        return 'YES' if label == 1 else 'NO'
        

In [79]:
# A fresh, untrained classifier (nothing is loaded from disk).
model = ZamokClassifier(load_path=None)

In [80]:
def load_lines(directory: str) -> List[str]:
    """Return every line of every ``.txt`` file in *directory*.

    Files are read in sorted filename order so the resulting list (and any
    index into it, e.g. ``locks[101]``) is reproducible across machines;
    ``os.listdir`` order is arbitrary. ``utf-8-sig`` strips a leading BOM.
    """
    lines: List[str] = []
    for filename in sorted(os.listdir(directory)):
        if filename.endswith('.txt'):
            with open(os.path.join(directory, filename), encoding='utf-8-sig') as fh:
                lines.extend(fh.readlines())
    return lines


castles = load_lines("castles/")
locks = load_lines("locks/")

In [81]:
# Fail early if nothing was loaded: training needs examples of both classes.
assert castles, "no castle texts were loaded from castles/"
len(castles)

329

In [82]:
# Fail early if nothing was loaded: training needs examples of both classes.
assert locks, "no lock texts were loaded from locks/"
len(locks)

284

In [83]:
# Castle texts become label 1 ('YES'), lock texts label 0 ('NO').
model.train(texts1=castles, texts2=locks)

In [84]:
# Feature matrix for all training texts: one row per text, one column per
# vocabulary term. (The previous version used x1/x2, which were never defined
# in this notebook and only existed in stale kernel state.)
features_shape = model.vectorizer.transform(castles + locks).shape
assert features_shape[0] == len(castles) + len(locks)
features_shape

(613, 100)

In [85]:
# Label vector built the same way as in ZamokClassifier.train:
# castles -> 1, locks -> 0 (the previous version had the labels inverted and
# relied on undefined x1/x2).
labels = np.concatenate([[1] * len(castles), [0] * len(locks)])
labels.shape

(613,)

In [86]:
# Writes test_model.vectorizer.pk and test_model.model.pk.
model.save(path='test_model')

In [87]:
# Reload the classifier from the two pickle files written above.
model1 = ZamokClassifier(load_path='test_model')

In [88]:
# Save/load round-trip check: the reloaded model must agree with the one it was
# saved from. NOTE: locks[101] is part of the training data, so this is a smoke
# test, not an evaluation of generalization.
sample_text = locks[101]
assert model1.predict(sample_text) == model.predict(sample_text)
model1.predict(sample_text)

'NO'