-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update base.py #55
Update base.py #55
Conversation
Привет! Тесты не проходят по той причине, что при создании 'pyitlib==0.2.2' авторы неправильно залили пакет, забыв поставить запятую. Попробуйте в requirements заменить версию 0.2.2 на 0.2.3 и проверить тесты еще раз. |
bamt/networks/base.py
Outdated
future = pool.submit(worker, node) | ||
self.distributions[node.name] = future.result() | ||
|
||
results = Parallel(n_jobs=-1)(delayed(worker)(node) for node in self.nodes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
В идеале, конечно, хотелось бы, чтобы n_jobs был не захардкоженный на -1, а передавался в метод при вызове, чтобы пользователь сам мог задать количество потоков, которое он может выделить на эту задачу, к примеру:
bn.fit_parameters(..., n_jobs = n)
Для этого можно просто отредактировать входы метода:
def fit_parameters(self, data: pd.DataFrame, dropna: bool = True, n_jobs: int = -1):
а в этом месте сделать так
results = Parallel(n_jobs=-1)(delayed(worker)(node) for node in self.nodes) | |
results = Parallel(n_jobs=n_jobs)(delayed(worker)(node) for node in self.nodes) |
bamt/networks/base.py
Outdated
@@ -7,6 +7,7 @@ | |||
import numpy as np | |||
import json | |||
import os | |||
from joblib import Parallel, delayed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Импорты желательно отредактировать в соответствии с pep8, быстрее всего использовать autopep8, чтобы импорты были упорядочены и смотрелись стройнее
bamt/networks/base.py
Outdated
@@ -7,6 +7,7 @@ | |||
import numpy as np | |||
import json | |||
import os | |||
from joblib import Parallel, delayed | |||
|
|||
from tqdm import tqdm | |||
from concurrent.futures import ThreadPoolExecutor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Этот импорт было бы хорошо выпилить, потому что он больше не используется, раз Вы используете joblib теперь
from concurrent.futures import ThreadPoolExecutor |
bamt/networks/base.py
Outdated
# pool = ThreadPoolExecutor(3) | ||
# for node in self.nodes: | ||
# future = pool.submit(worker, node) | ||
# self.distributions[node.name] = future.result() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Закоменченный код лучше убрать, когда все тесты пройдут нормально.
# pool = ThreadPoolExecutor(3) | |
# for node in self.nodes: | |
# future = pool.submit(worker, node) | |
# self.distributions[node.name] = future.result() |
1. версия pyitlib==0.2.2 -> pyitlib==0.2.3 в requirements.txt 2. импорт ThreadPoolExecutor заменен на joblib 3. удалено старое распараллеливание 4. в networks/base.py fit_parameters добавлен параметр n_jobs 5. код отредактирован с помощью autopep8
Изменено распараллеливание обучения параметров БС. Вместо ThreadPoolExecutor использован joblib Parallel.