# Distributed training of classical models

Let's talk about how to solve a machine learning problem when there are so many observations that they do not all fit on a single machine.

The central idea in all of these algorithms is to compute, in parallel on several machines, the partial quantities needed to make a decision, send them to a central machine, and perform one step of the algorithm there.

To train linear models, each machine computes a gradient on its own data, and the main machine performs the gradient descent step.
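A minimal sketch of this scheme for least-squares regression, simulated in a single process. The function names, the synthetic data, and the choice of loss are assumptions made for illustration; this is not how any particular library implements it.

```python
import numpy as np


def shard_gradient(w, X, y):
    """Worker side: sum of squared-error gradients over one shard."""
    residual = X @ w - y
    return X.T @ residual


def distributed_gd_step(w, shards, learning_rate):
    """Master side: add up the partial gradients and take one descent step."""
    n_total = sum(len(y) for _, y in shards)
    total_grad = sum(shard_gradient(w, X, y) for X, y in shards)
    return w - learning_rate * total_grad / n_total


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w

# Split the rows between two simulated workers.
shards = [(X[:500], y[:500]), (X[500:], y[500:])]

w = np.zeros(3)
for _ in range(200):
    w = distributed_gd_step(w, shards, learning_rate=0.1)

print(w)
```

Because the gradient of a sum of per-example losses is the sum of per-example gradients, the master gets exactly the same step it would get from the full dataset, while each worker only touches its own shard.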

For decision trees, each machine computes the distribution of its data over bins, and the main machine uses the combined histogram to choose the split threshold for a given feature.
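The histogram idea can be sketched as follows for one feature and a regression target. All names here are hypothetical, and the split criterion (minimizing the squared error of a two-leaf split) is an assumption chosen to keep the example short.

```python
import numpy as np


def local_histogram(values, targets, bin_edges):
    """Worker side: per-bin example count and target sum for one shard."""
    bins = np.digitize(values, bin_edges[1:-1])
    n_bins = len(bin_edges) - 1
    counts = np.bincount(bins, minlength=n_bins).astype(float)
    sums = np.bincount(bins, weights=targets, minlength=n_bins)
    return counts, sums


def best_threshold(counts, sums, bin_edges):
    """Master side: pick the interior bin edge that gives the best two-leaf split."""
    left_n, left_s = np.cumsum(counts)[:-1], np.cumsum(sums)[:-1]
    right_n, right_s = counts.sum() - left_n, sums.sum() - left_s
    valid = (left_n > 0) & (right_n > 0)
    # Maximizing this score is equivalent to minimizing the total squared error.
    score = left_s ** 2 / np.maximum(left_n, 1) + right_s ** 2 / np.maximum(right_n, 1)
    score = np.where(valid, score, -np.inf)
    return bin_edges[1:-1][np.argmax(score)]


rng = np.random.default_rng(1)
x = rng.uniform(0.0, 1.0, size=2000)
t = (x >= 0.5).astype(float)
bin_edges = np.linspace(0.0, 1.0, 11)

# Two simulated workers, each with half of the rows.
hist_a = local_histogram(x[:1000], t[:1000], bin_edges)
hist_b = local_histogram(x[1000:], t[1000:], bin_edges)

# Histograms are additive, so the master just sums them.
counts = hist_a[0] + hist_b[0]
sums = hist_a[1] + hist_b[1]
threshold = best_threshold(counts, sums, bin_edges)
print(threshold)
```

Only the fixed-size histograms travel over the network, so the traffic does not grow with the number of rows on each worker.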

### Distributed training with VW

Vowpal Wabbit can also run in distributed mode, which makes it a versatile tool for training linear models on large data. For this it relies on an additional component, `spanning_tree`: a separate process that coordinates the workers.

You can also think of it as the root node in the "Tree Allreduce" algorithm, which is used to make efficient use of the network during training.
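To make the allreduce idea concrete, here is a toy single-process simulation: every node starts with its own vector and ends up with the sum over all nodes. The heap-style tree layout (node `i` has children `2i+1` and `2i+2`) and the function name are assumptions for illustration only; the real VW implementation works over TCP sockets and is not shown here.

```python
import numpy as np


def tree_allreduce(values):
    """Simulate a sum-allreduce over a binary tree stored in array order."""
    n = len(values)
    partial = [np.array(v, dtype=float) for v in values]

    # Reduce phase: walk from the last node to the first, so every child
    # has already absorbed its own subtree before it reports to its parent.
    for node in range(n - 1, 0, -1):
        parent = (node - 1) // 2
        partial[parent] = partial[parent] + partial[node]

    # Broadcast phase: the root holds the total and passes it down the tree.
    result = [None] * n
    result[0] = partial[0]
    for node in range(1, n):
        result[node] = result[(node - 1) // 2]
    return result


worker_gradients = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0], [2.0, 2.0], [0.0, -3.0]]
reduced = tree_allreduce(worker_gradients)
print(reduced[0])
```

In this layout each node exchanges messages only with its parent and at most two children, so no single machine has to receive data from every worker.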

To be able to use `spanning_tree`, VW has to be built by hand.


Let's build VW. This has to be done as the superuser, so it is most convenient to run it from a terminal.

```bash
apt update && \
apt install git psmisc -y && \
apt install libboost-dev libboost-program-options-dev libboost-system-dev libboost-thread-dev libboost-math-dev libboost-test-dev zlib1g-dev cmake g++ -y 


wget https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz && \
tar -xzf v1.12.0.tar.gz && \
cd flatbuffers-1.12.0 && \
mkdir build_dir && \
cd build_dir && \
cmake -G "Unix Makefiles" -DFLATBUFFERS_BUILD_TESTS=Off -DFLATBUFFERS_INSTALL=On -DCMAKE_BUILD_TYPE=Release -DFLATBUFFERS_BUILD_FLATHASH=Off .. && \
make install -j$(nproc) && \
cd ../..

git clone --recursive https://github.com/VowpalWabbit/vowpal_wabbit.git && \
cd vowpal_wabbit && \
sudo make && \
cd build && \
sudo make install -j$(nproc)
```

**Handy tip.** To get root access on an Azure cluster from Jupyter, open a terminal and connect over ssh as the user `azureuser`. The current user, `spark`, unfortunately has very few privileges.

```bash
ssh azureuser@localhost
sudo su
```

In [1]:
%%writefile install_vw.sh

sudo apt update -y
sudo apt install git psmisc -y 
sudo apt install libboost-dev libboost-program-options-dev libboost-system-dev libboost-thread-dev libboost-math-dev libboost-test-dev zlib1g-dev cmake g++ -y 

wget https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz && \
    tar -xzf v1.12.0.tar.gz && \
    cd flatbuffers-1.12.0 && \
    mkdir build_dir && \
    cd build_dir && \
    cmake -G "Unix Makefiles" -DFLATBUFFERS_BUILD_TESTS=Off -DFLATBUFFERS_INSTALL=On -DCMAKE_BUILD_TYPE=Release -DFLATBUFFERS_BUILD_FLATHASH=Off .. && \
    make install -j$(nproc) && \
    cd ../..
    
git clone --recursive https://github.com/VowpalWabbit/vowpal_wabbit.git && \
    cd vowpal_wabbit && \
    git checkout d1ead9a0a9afd56d2ee11a72e0c1aaa7702ee281 && \
    sudo make && \
    cd build && \
    sudo make install -j$(nproc)

Writing install_vw.sh


In [2]:
! bash install_vw.sh

Hit:1 http://mirror.yandex.ru/ubuntu focal InRelease
Get:2 http://mirror.yandex.ru/ubuntu focal-updates InRelease [114 kB]
Hit:3 http://mirror.yandex.ru/ubuntu focal-backports InRelease
[...]
Setting up libboost-date-time1.71-dev:amd64 (1.71.0-6ubuntu6) ...
Setting up libboost-thread1.71-dev:amd64 (1.71.0-6ubuntu6) ...
Setting up libboost-system-dev:amd64 (1.71.0.0ubuntu2) ...
[...]
Cloning into '/home/ubuntu/lsml-2024/flatbuffers-1.12.0/build_dir/vowpal_wabbit/ext_libs/boost_math'...
Cloning into '/home/ubuntu/lsml-2024/flatbuffers-1.12.0/build_dir/vowpal_wabbit/ext_libs/eigen'...
[...]
-- Found Git: /usr/bin/git (found version "2.25.1")
-- Git Version: d1ead9a0a
-- Number of processors: 4
-- Found Threads: TRUE
-- Found Boost: /usr/lib/x86_64-linux-gnu/cmake/Boost-1.71.0/BoostConfig.cmake (found version "1.71.0") found components: program_options system
-- Found ZLIB: /usr/lib/x86_64-linux-gnu/libz.so (found version "1.2.11")
-- help2man not found, please install it to generate manpages
[...]
[ 97%] Building CXX object vowpalwabbit/CMakeFiles/vw.dir/vw_validate.cc.o
[ 97%] Building CXX object vowpalwabbit/CMakeFiles/vw.dir/warm_cb.cc.o
[...]
[100%] Building CXX object test/unit_test/CMakeFiles/vw-unit-test.out.dir/vwdll_test.cc.o
[100%] Linking CXX executable vw-unit-test.out
[100%] Built target vw-unit-test.out
Install the project...
-- Install configuration: "Release"
-- Installing: /usr/local/include/rapidjson
-- Installing: /usr/local/include/rapidjson/stream.h
[...]

In [22]:
! which vw

/usr/local/bin/vw


In [23]:
! which spanning_tree

/usr/local/bin/spanning_tree


In [3]:
! wget https://archive.ics.uci.edu/ml/machine-learning-databases/00462/drugsCom_raw.zip
! unzip drugsCom_raw.zip

--2024-03-13 20:21:31--  https://archive.ics.uci.edu/ml/machine-learning-databases/00462/drugsCom_raw.zip
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified
Saving to: ‘drugsCom_raw.zip.1’

drugsCom_raw.zip.1      [               <=>  ]  41.00M  10.8MB/s    in 3.8s    

2024-03-13 20:21:36 (10.8 MB/s) - ‘drugsCom_raw.zip.1’ saved [42989872]

Archive:  drugsCom_raw.zip
replace drugsComTest_raw.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: ^C


In [4]:
! hdfs dfs -ls /user

Found 6 items
drwxr-xr-x   - ubuntu hadoop          0 2024-02-25 12:26 /user/airbnb
drwxr-xr-x   - hive   hadoop          0 2024-02-08 06:57 /user/hive
drwxr-xr-x   - ubuntu hadoop          0 2024-02-08 07:22 /user/pokemons
drwxr-xr-x   - ubuntu hadoop          0 2024-02-22 08:25 /user/spark-example
drwxr-xr-x   - ubuntu hadoop          0 2024-02-22 07:31 /user/tweets
drwxr-xr-x   - ubuntu hadoop          0 2024-02-22 07:21 /user/ubuntu


In [24]:
! hdfs dfs -rm -r /user/drugs/data || true
! hdfs dfs -mkdir -p /user/drugs/data

Deleted /user/drugs/data


In [25]:
! hdfs dfs -ls /user/drugs

Found 4 items
drwxr-xr-x   - ubuntu hadoop          0 2024-03-13 22:58 /user/drugs/data
drwxr-xr-x   - ubuntu hadoop          0 2024-03-13 20:25 /user/drugs/part1.vw
drwxr-xr-x   - ubuntu hadoop          0 2024-03-13 20:25 /user/drugs/part2.vw
drwxr-xr-x   - ubuntu hadoop          0 2024-03-13 20:25 /user/drugs/test.vw


Upload the drug review dataset to HDFS.

In [26]:
%%bash

cat drugsComTrain_raw.tsv <(tail -n +2 drugsComTest_raw.tsv) | hdfs dfs -put - /user/drugs/data/drugs.tsv

In [27]:
! hdfs dfs -ls  -h /user/drugs/data

Found 1 items
-rw-r--r--   1 ubuntu hadoop    107.2 M 2024-03-13 22:59 /user/drugs/data/drugs.tsv


In [1]:
import findspark
findspark.init()

In [2]:
import pyspark
sc = pyspark.SparkContext(appName="lsml-app-1")

SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/usr/lib/spark/jars/slf4j-log4j12-1.7.30.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/usr/lib/hadoop/lib/slf4j-log4j12-1.7.25.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
2024-03-13 22:59:54,925 WARN util.Utils: spark.executor.instances less than spark.dynamicAllocation.minExecutors is invalid, ignoring its setting, please update your configs.
2024-03-13 22:59:59,813 WARN util.Utils: spark.executor.instances less than spark.dynamicAllocation.minExecutors is invalid, ignoring its setting, please update your configs.
2024-03-13 22:59:59,823 WARN cluster.YarnSchedulerBackend$YarnSchedulerEndpo

In [3]:
from pyspark.sql import SparkSession, Row
se = SparkSession(sc)

In [4]:
from pyspark.sql import functions as F
from datetime import datetime
import re

In [5]:
data = se.read.option("delimiter", "\t").csv('/user/drugs/data/*', header=True, inferSchema=True)

                                                                                

In [6]:
data.limit(10).toPandas()

Unnamed: 0,_c0,drugName,condition,review,rating,date,usefulCount
0,206461,Valsartan,Left Ventricular Dysfunction,"""""""It has no side effect, I take it in combina...",9.0,"May 20, 2012",27.0
1,95260,Guanfacine,ADHD,"""""""My son is halfway through his fourth week o...",,,
2,We have tried many different medications and s...,8.0,"April 27, 2010",192,,,
3,92703,Lybrel,Birth Control,"""""""I used to take another oral contraceptive, ...",,,
4,The positive side is that I didn&#039;t have a...,5.0,"December 14, 2009",17,,,
5,138000,Ortho Evra,Birth Control,"""""""This is my first time using any form of bir...",8.0,"November 3, 2015",10.0
6,35696,Buprenorphine / naloxone,Opiate Dependence,"""""""Suboxone has completely turned my life arou...",9.0,"November 27, 2016",37.0
7,155963,Cialis,Benign Prostatic Hyperplasia,"""""""2nd day on 5mg started to work with rock ha...",2.0,"November 28, 2015",43.0
8,165907,Levonorgestrel,Emergency Contraception,"""""""He pulled out, but he cummed a bit in me. I...",1.0,"March 7, 2017",5.0
9,102654,Aripiprazole,Bipolar Disorde,"""""""Abilify changed my life. There is hope. I w...",10.0,"March 14, 2015",32.0


We are going to run 2 workers, so we split the whole dataset into 3 parts: 2 equal parts for the workers and 1 small part for testing.

In [7]:
part1, part2, test = (
    data
    .na.drop('any')
    .randomSplit([0.45, 0.45, 0.1], 422)
)

Build the dataset in VW format with Spark.

In [9]:
def convert_to_vw(data):
    target = data['usefulCount']
    
    drug_name = data['drugName'].lower().replace(' ', '_')
    condition = data['condition'].lower().replace(' ', '_')
    
    raw_text = data['review'].lower()
    word_pattern = re.compile(r"[a-zA-Z0-9_]+")
    words = [match.group(0) for match in re.finditer(word_pattern, raw_text)]
    review = ' '.join(words)
    
    rating = data['rating']
    
    weekday = datetime.strptime(data['date'], '%B %d, %Y').weekday()
    
    template = "{target} |d {drug_name} |c {condition} |r {review} |w {weekday} |s rating:{rating}"
    return template.format(
        target=target,
        drug_name=drug_name,
        condition=condition,
        review=review,
        weekday=weekday,
        rating=rating
    )
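In the VW text format `|` opens a namespace and `:` separates a feature name from its value, so a categorical value that contains either character would produce a malformed record. The function above only replaces spaces. A hedged sketch of a stricter helper is shown below; `sanitize_vw_token` is a hypothetical name and is not part of the original notebook.

```python
import re

# Characters that have a special meaning in VW input: '|', ':' and whitespace.
_VW_UNSAFE = re.compile(r"[|:\s]+")


def sanitize_vw_token(value):
    """Lowercase a categorical value and replace VW-special characters with '_'."""
    return _VW_UNSAFE.sub("_", value.strip().lower())


print(sanitize_vw_token("Buprenorphine / naloxone"))
print(sanitize_vw_token("a|b: c"))
```

It could be used for `drug_name` and `condition` in place of the `.lower().replace(' ', '_')` chain, assuming both fields are plain strings.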

In [10]:
! hdfs dfs -rm -r /user/drugs/*.vw

Deleted /user/drugs/part1.vw
Deleted /user/drugs/part2.vw
Deleted /user/drugs/test.vw


In [14]:
part1.rdd.map(convert_to_vw).saveAsTextFile('/user/drugs/part1.vw')
part2.rdd.map(convert_to_vw).saveAsTextFile('/user/drugs/part2.vw')
test.rdd.map(convert_to_vw).saveAsTextFile('/user/drugs/test.vw')

In [15]:

! hdfs dfs -cat /user/drugs/part1.vw/* > train.part1.vw
! hdfs dfs -cat /user/drugs/part2.vw/* > train.part2.vw
! hdfs dfs -cat /user/drugs/test.vw/* > test.vw

Let's see what results we get if we simply run VW on the whole file.

In [16]:
! cat train.*.vw > train.full.vw

In [17]:
import numpy as np
from sklearn.metrics import r2_score


def calc_r2(predictions_filename, answers_filename):
    def read_target_from_vw(vw_record):
        return float(vw_record.split(' ')[0])
    
    with open(predictions_filename, 'r') as f:
        y_pred = np.array([float(value) for value in f.readlines()])
        
    with open(answers_filename, 'r') as f:
        y_expected = np.array([read_target_from_vw(value) for value in f.readlines()])
        
    return r2_score(y_expected, y_pred)
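For reference, the metric computed by `r2_score` for a single target is the coefficient of determination, R² = 1 − SS_res / SS_tot. A small NumPy-only sketch of the same quantity follows; `r2_from_arrays` is a hypothetical helper and the toy arrays are made up for the example.

```python
import numpy as np


def r2_from_arrays(y_true, y_pred):
    """Coefficient of determination: 1 - residual sum of squares / total sum of squares."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


y_true = [1.0, 2.0, 3.0, 4.0]
print(r2_from_arrays(y_true, y_true))      # perfect predictions
print(r2_from_arrays(y_true, [2.5] * 4))   # always predicting the mean
```

Perfect predictions give 1.0, and a model that always predicts the mean of the targets gives 0.0, which is a useful baseline when reading the scores below.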

In [18]:
! vw --help | head

Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = 
num sources = 1
driver:
  --onethread           Disable parse thread
VW options:
  --ring_size arg (=256, ) size of example ring
  --strict_parse           throw on malformed examples
Update options:
  -l [ --learning_rate ] arg Set learning rate
  --power_t arg              t power value
  --decay_learning_rate arg  Set Decay factor for learning_rate between passes
  --initial_t arg            initial t value


Train VW on the whole file at once.

In [19]:
%%time

! vw --final_regressor drugs.model.bin train.full.vw \
    --onethread \
    --learning_rate 20.0 \
    --bit_precision 23 \
    --passes 40 \
    --ngram r2 \
    --interactions dc \
    --cache -k

Generating 2-grams for r namespaces.
creating features for following interactions: dc 
final_regressor = drugs.model.bin
Num weight bits = 23
learning rate = 20
initial_t = 0
power_t = 0.5
decay_learning_rate = 1
creating cache_file = train.full.vw.cache
Reading datafile = train.full.vw
num sources = 1
Enabled reductions: gd, scorer
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
16.000000 16.000000            1            1.0   4.0000   0.0000      155
8.284978 0.569956            2            2.0   1.0000   1.7550      207
5.384596 2.484213            4            4.0   0.0000   1.6859      249
3.444507 1.504419            8            8.0   1.0000   0.1637       45
110.780388 218.116268           16           16.0   3.0000   7.7153      139
75.235660 39.690933           32           32.0   2.0000   4.0961      151
51.536644 27.837627           64           64.0   2.0000   1.9416       7

In [20]:
! vw --testonly --initial_regressor drugs.model.bin --predictions drugs.preductions.txt test.vw

Generating 2-grams for r namespaces.
creating features for following interactions: dc 
only testing
predictions = drugs.preductions.txt
Num weight bits = 23
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = test.vw
num sources = 1
Enabled reductions: gd, scorer
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
0.447677 0.447677            1            1.0   2.0000   2.6691      317
2.223839 4.000000            2            2.0   2.0000   0.0000      237
4.300955 6.378071            4            4.0  10.0000  11.9381      163
25.360267 46.419579            8            8.0   1.0000  14.3987       59
17.799924 10.239580           16           16.0   2.0000   9.0071      285
24.017554 30.235185           32           32.0   7.0000   3.3096      125
158.437527 292.857500           64           64.0  10.0000  44.5076      175
336.991100 515.544673          128     

In [21]:
calc_r2('drugs.preductions.txt', 'test.vw')

0.6526252786437012

The model reaches an R² of **0.65** after **30** seconds of training.

Let's see what happens if we train the model on only part of the data.

In [22]:
%%time

! vw --final_regressor drugs.model.bin train.part1.vw \
    --onethread \
    --learning_rate 20.0 \
    --bit_precision 23 \
    --passes 40 \
    --ngram r2 \
    --interactions dc \
    --cache -k

Generating 2-grams for r namespaces.
creating features for following interactions: dc 
final_regressor = drugs.model.bin
Num weight bits = 23
learning rate = 20
initial_t = 0
power_t = 0.5
decay_learning_rate = 1
creating cache_file = train.part1.vw.cache
Reading datafile = train.part1.vw
num sources = 1
Enabled reductions: gd, scorer
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
16.000000 16.000000            1            1.0   4.0000   0.0000      155
8.284978 0.569956            2            2.0   1.0000   1.7550      207
5.384596 2.484213            4            4.0   0.0000   1.6859      249
3.444507 1.504419            8            8.0   1.0000   0.1637       45
110.780388 218.116268           16           16.0   3.0000   7.7153      139
75.235660 39.690933           32           32.0   2.0000   4.0961      151
51.536644 27.837627           64           64.0   2.0000   1.9416      

In [23]:
! vw --testonly --initial_regressor drugs.model.bin --predictions drugs.preductions.txt test.vw

Generating 2-grams for r namespaces.
creating features for following interactions: dc 
only testing
predictions = drugs.preductions.txt
Num weight bits = 23
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = test.vw
num sources = 1
Enabled reductions: gd, scorer
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
1054.792969 1054.792969            1            1.0   2.0000  34.4776      317
529.396484 4.000000            2            2.0   2.0000   0.0000      237
351.250557 173.104630            4            4.0  10.0000  28.3633      163
214.487823 77.725090            8            8.0   1.0000  16.5387       59
140.830627 67.173430           16           16.0   2.0000  21.5064      285
101.366076 61.901526           32           32.0   7.0000   3.7109      125
202.674485 303.982894           64           64.0  10.0000  32.8632      175
677.614774 1152.555063  

In [24]:

calc_r2('drugs.preductions.txt', 'test.vw')

0.38178390321607647

Training was much faster, but we lost quality.

The model reaches an R² of **0.38** in **6** seconds.

**Moral:** sampling is not the best approach. To get good quality, the model needs to see all of the data.

Let's start `spanning_tree` in the background and check that it is really running.

The workers will then connect to it over TCP.

In [26]:
%%bash --bg --out OUT --err ERR
spanning_tree --nondaemon

In [27]:
! ps aux | grep spanning_tree

ubuntu     15643  0.0  0.0   6068  1496 ?        S    23:05   0:00 spanning_tree --nondaemon
ubuntu     15644  0.0  0.0   9492  3216 pts/0    Ss+  23:05   0:00 /bin/bash -c  ps aux | grep spanning_tree
ubuntu     15646  0.0  0.0   9032   656 pts/0    S+   23:05   0:00 grep spanning_tree


Time to start the workers. This is done with the already familiar `vw` command, with a few special parameters added:

* `--span_server` - the address of the manager (`spanning_tree`). In our case it is localhost. In real life it could be the IP address of another machine.
* `--unique_id` - a single `spanning_tree` can serve many different training jobs at once, so they have to be told apart. `unique_id` is a number that must be the same for all of your workers so that they are not confused with anyone else's. For example, a colleague may be training VW for a different task and connect their workers to the same `spanning_tree` with unique_id = 0. In that case you would start your workers with, say, unique_id = 5 so that they do not mix with your colleague's workers.
* `--total` - the number of workers you plan to connect in the current training session.
* `--node` - the identifier of the current worker. Numbering starts at zero, so to run 3 workers you give them `--node` values 0, 1 and 2.
* `-d` - the data to be processed by the current worker.

All other training parameters must be identical for all workers.

To save the coefficients of the resulting model, pass the output file via `-f` or `--final_regressor` to exactly one of the workers, just as we did in the previous lab.
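Since every worker must receive the same training options and differ only in `--node`, `-d` and (for one of them) `-f`, it can help to generate the command lines programmatically. The helper below is a hypothetical sketch that only assembles argument lists from the flags described above; it does not start any processes.

```python
import shlex


def build_worker_command(data_path, node, total, span_server="localhost",
                         unique_id=1, extra_args=(), model_path=None):
    """Assemble the argument list for one distributed VW worker."""
    cmd = ["vw", "-d", data_path,
           "--span_server", span_server,
           "--total", str(total),
           "--node", str(node),
           "--unique_id", str(unique_id)]
    # Training options shared by all workers.
    cmd += list(extra_args)
    # Only one worker should write the final model.
    if model_path is not None:
        cmd += ["-f", model_path]
    return cmd


shared = ["--learning_rate", "20.0", "--bit_precision", "23", "--passes", "40",
          "--ngram", "r2", "--interactions", "dc", "--cache", "-k"]

commands = [
    build_worker_command(
        f"train.part{i + 1}.vw", node=i, total=2, extra_args=shared,
        model_path="drugs.model.bin" if i == 1 else None,
    )
    for i in range(2)
]

for cmd in commands:
    print(shlex.join(cmd))
```

The resulting lists can be passed to `subprocess.Popen` or printed and pasted into a shell; keeping the shared options in one place makes it harder to start workers with mismatched settings.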

Let's start two workers. The first one runs in the background, as before, while the second runs right in the notebook so that we can follow the training process.

In [28]:
%%bash --bg --out OUT --err ERR

vw -d train.part1.vw \
    --span_server localhost \
    --total 2 \
    --node 0 \
    --unique_id 1 \
    --learning_rate 20.0 \
    --bit_precision 23 \
    --passes 40 \
    --ngram r2 \
    --interactions dc \
    --cache -k

In [29]:
%%time

! vw -d train.part2.vw \
    --span_server localhost \
    --total 2 \
    --node 1 \
    --unique_id 1 \
    --learning_rate 20.0 \
    --bit_precision 23 \
    --passes 40 \
    --ngram r2 \
    --interactions dc \
    --cache -k \
    -f drugs.model.bin

Generating 2-grams for r namespaces.
creating features for following interactions: dc 
final_regressor = drugs.model.bin
Num weight bits = 23
learning rate = 20
initial_t = 0
power_t = 0.5
decay_learning_rate = 1
creating cache_file = train.part2.vw.cache
Reading datafile = train.part2.vw
num sources = 1
Enabled reductions: gd, scorer
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
4.000000 4.000000            1            1.0   2.0000   0.0000      207
9.166723 14.333446            2            2.0   4.0000   0.2140       61
42.704962 76.243202            4            4.0   0.0000   6.1127      235
53.751544 64.798126            8            8.0  18.0000   2.1770      203
53.708574 53.665604           16           16.0  24.0000   6.9774      155
40.367511 27.026447           32           32.0   3.0000   6.5818      273
45.832557 51.297603           64           64.0   2.0000   8.2354     

In [30]:
! vw --testonly --initial_regressor drugs.model.bin --predictions drugs.preductions.txt test.vw

Generating 2-grams for r namespaces.
creating features for following interactions: dc 
only testing
predictions = drugs.preductions.txt
Num weight bits = 23
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = test.vw
num sources = 1
Enabled reductions: gd, scorer
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
2.849555 2.849555            1            1.0   2.0000   3.6881      317
3.424778 4.000000            2            2.0   2.0000   0.0000      237
6.356657 9.288537            4            4.0  10.0000  13.0947      163
34.073244 61.789832            8            8.0   1.0000  16.2863       59
32.563332 31.053420           16           16.0   2.0000  16.5415      285
31.656058 30.748784           32           32.0   7.0000   3.2724      125
82.980499 134.304941           64           64.0  10.0000  43.8365      175
333.861668 584.742837          128      

In [31]:
calc_r2('drugs.preductions.txt', 'test.vw')

0.6496397151076883

The quality is practically the same as in the single-process run (R² of 0.650 versus 0.653).

We did not see a big speedup because everything runs on one machine. However, if the workers run on different machines and on larger volumes of data, training can speed up considerably.

The main achievement of this approach is that the data can now be spread over several machines, which in theory lets us process a dataset of arbitrary size.

### Trying to run on several machines by hand

In [98]:
%%writefile run_install_tmux.sh
sudo apt install -y tmux

Overwriting run_install_tmux.sh


In [32]:
%%writefile run_node_prepare.sh

# sudo apt update -y
# sudo apt install git psmisc -y 
# sudo apt install libboost-dev libboost-program-options-dev libboost-system-dev libboost-thread-dev libboost-math-dev libboost-test-dev zlib1g-dev cmake g++ -y 

# wget https://github.com/google/flatbuffers/archive/v1.12.0.tar.gz && \
#     tar -xzf v1.12.0.tar.gz && \
#     cd flatbuffers-1.12.0 && \
#     mkdir build_dir && \
#     cd build_dir && \
#     cmake -G "Unix Makefiles" -DFLATBUFFERS_BUILD_TESTS=Off -DFLATBUFFERS_INSTALL=On -DCMAKE_BUILD_TYPE=Release -DFLATBUFFERS_BUILD_FLATHASH=Off .. && \
#     make install -j$(nproc) && \
#     cd ../..
    
# git clone --recursive https://github.com/VowpalWabbit/vowpal_wabbit.git && \
#     cd vowpal_wabbit && \
#     git checkout d1ead9a0a9afd56d2ee11a72e0c1aaa7702ee281 && \
#     sudo make && \
#     cd build && \
#     sudo make install -j$(nproc)

hdfs dfs -cat /user/drugs/part${NODE_NUMBER}.vw/* > train.vw

Overwriting run_node_prepare.sh


In [33]:
%%writefile run_node_train.sh

vw -d train.vw \
    --span_server ${MASTER_NODE} \
    --total 2 \
    --node ${NODE_NUMBER} \
    --unique_id 1 \
    --learning_rate 20.0 \
    --bit_precision 23 \
    --passes 40 \
    --ngram r2 \
    --interactions dc \
    --cache -k

Overwriting run_node_train.sh


In [34]:
!chmod 777 run_install_tmux.sh
!chmod 777 run_node_prepare.sh
!chmod 777 run_node_train.sh

In [37]:
!pip install plumbum

Collecting plumbum
  Downloading plumbum-1.8.2-py3-none-any.whl (127 kB)
Installing collected packages: plumbum
Successfully installed plumbum-1.8.2


In [92]:
%%writefile distribute.py

import sys
import plumbum

masternode = 'rc1a-dataproc-m-zkmve6o6f689a4so.mdb.yandexcloud.net'
datanodes = ['rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net']

file_to_distribute = sys.argv[1]
plumbum.local['scp'][file_to_distribute][f'{datanodes[0]}:~/.'] & plumbum.FG
plumbum.local['ssh'][datanodes[0]][f'MASTER_NODE={masternode} NODE_NUMBER={0} ./{file_to_distribute}'] & plumbum.FG

Overwriting distribute.py


In [35]:
%%writefile distribute_tmux.py

import sys
import plumbum

masternode = 'rc1a-dataproc-m-zkmve6o6f689a4so.mdb.yandexcloud.net'
datanodes = ['rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net']

file_to_distribute = sys.argv[1]
plumbum.local['scp'][file_to_distribute][f'{datanodes[0]}:~/.'] & plumbum.FG
cmd = plumbum.local['ssh']['-A'][datanodes[0]]\
[f'tmux new-session -d -s remote_run "MASTER_NODE={masternode} NODE_NUMBER={1} ./{file_to_distribute}"']\
& plumbum.FG

Overwriting distribute_tmux.py
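
The two scripts above hard-code a single data node. As a hedged sketch, the remote command line could be generated per node instead. The helpers `remote_command` and `tmux_command` and the host names are hypothetical; the sketch only builds the strings and does not call `ssh` or `plumbum`.

```python
import shlex
from typing import List


def remote_command(master_node: str, node_number: int, script: str) -> str:
    """Command that runs `script` with the environment variables the node scripts read."""
    return (
        f"MASTER_NODE={shlex.quote(master_node)} "
        f"NODE_NUMBER={node_number} "
        f"./{shlex.quote(script)}"
    )


def tmux_command(master_node: str, node_number: int, script: str, session: str = "remote_run") -> str:
    """Wrap the remote command in a detached tmux session so it outlives the SSH connection."""
    inner = remote_command(master_node, node_number, script)
    return f"tmux new-session -d -s {shlex.quote(session)} {shlex.quote(inner)}"


# Hypothetical host names; node 0 is assumed to be the machine running the notebook.
master = "master.example.net"
datanodes: List[str] = ["worker-1.example.net", "worker-2.example.net"]

commands = {
    host: tmux_command(master, node_number, "run_node_train.sh")
    for node_number, host in enumerate(datanodes, start=1)
}
for host, command in commands.items():
    print(f"ssh -A {host} {shlex.quote(command)}")
```

Note that `run_node_train.sh` passes `--total 2`, so with more workers that value would have to match the real number of nodes.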


In [36]:
! vw -d train.part2.vw \
    --span_server rc1a-dataproc-m-zkmve6o6f689a4so.mdb.yandexcloud.net \
    --total 2 \
    --node 0 \
    --unique_id 1 \
    --learning_rate 20.0 \
    --bit_precision 23 \
    --passes 40 \
    --ngram r2 \
    --interactions dc \
    --cache -k \
    -f drugs.model.bin

Generating 2-grams for r namespaces.
creating features for following interactions: dc 
final_regressor = drugs.model.bin
Num weight bits = 23
learning rate = 20
initial_t = 0
power_t = 0.5
decay_learning_rate = 1
creating cache_file = train.part2.vw.cache
Reading datafile = train.part2.vw
num sources = 1
Enabled reductions: gd, scorer
average  since         example        example  current  current  current
loss     last          counter         weight    label  predict features
4.000000 4.000000            1            1.0   2.0000   0.0000      207
9.166723 14.333446            2            2.0   4.0000   0.2140       61
42.704962 76.243202            4            4.0   0.0000   6.1127      235
53.751544 64.798126            8            8.0  18.0000   2.1770      203
53.708574 53.665604           16           16.0  24.0000   6.9774      155
40.367511 27.026447           32           32.0   3.0000   6.5818      273
45.832557 51.297603           64           64.0   2.0000   8.2354     

### VW on Hadoop

VW is fairly easy to run as an ordinary MapReduce job. The authors of the tool even provide a ready-made script for this.

You can read about running it on Hadoop here: https://github.com/VowpalWabbit/vowpal_wabbit/tree/master/cluster .

We will instead take a closer look at a more convenient interface for distributed VW training on a cluster.

### SynapseML

Microsoft maintains a whole set of libraries for Spark that make it convenient to run distributed algorithms on a Spark cluster. All the features are described on the official GitHub page: https://github.com/microsoft/SynapseML

We will use two tools from it: VW and LightGBM (gradient boosting).


To install SynapseML in the environment with lyvi, it is enough to reconfigure the Spark session.

In [1]:
import findspark
findspark.init()

In [2]:
import pyspark
se = pyspark.sql.SparkSession.builder.appName("MyApp2") \
            .config("spark.jars.packages", "com.microsoft.azure:synapseml_2.12:0.9.5") \
            .config("spark.dynamicAllocation.enabled", False) \
            .config("spark.locality.wait", 0) \
            .getOrCreate()


SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/usr/lib/spark/jars/slf4j-log4j12-1.7.30.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/usr/lib/hadoop/lib/slf4j-log4j12-1.7.25.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
Ivy Default Cache set to: /home/ubuntu/.ivy2/cache
The jars for the packages stored in: /home/ubuntu/.ivy2/jars
:: loading settings :: url = jar:file:/usr/lib/spark/jars/ivy-2.4.0.jar!/org/apache/ivy/core/settings/ivysettings.xml
com.microsoft.azure#synapseml_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-78573325-8ced-4f3a-92c7-3410d7d583db;1.0
	confs: [default]
	found com.microsoft.azure#synapseml_2.12;0.9.5 in central
	found com.microsoft.azure#synapseml-core_2.12;0.9.5 in central
	found org.scalactic#

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
2024-03-13 23:25:37,230 WARN yarn.Client: Same path resource file:///home/ubuntu/.ivy2/jars/com.microsoft.azure_synapseml_2.12-0.9.5.jar added multiple times to distributed cache.
2024-03-13 23:25:37,230 WARN yarn.Client: Same path resource file:///home/ubuntu/.ivy2/jars/com.microsoft.azure_synapseml-core_2.12-0.9.5.jar added multiple times to distributed cache.
2024-03-13 23:25:37,230 WARN yarn.Client: Same path resource file:///home/ubuntu/.ivy2/jars/com.microsoft.azure_synapseml-deep-learning_2.12-0.9.5.jar added multiple times to distributed cache.
2024-03-13 23:25:37,230 WARN yarn.Client: Same path resource file:///home/ubuntu/.ivy2/jars/com.microsoft.azure_synapseml-cognitive_2.12-0.9.5.jar added multiple times to distributed cache.
2024-03-13 23:25:37,230 WARN yarn.Client: Same path resource file:///home/ubuntu/.ivy2/jars/com.microsoft.azure_synapseml

In [3]:
! cd /home/ubuntu/.ivy2/jars && \
    cp io.netty_netty-transport-native-epoll-4.1.68.Final-linux-x86_64.jar io.netty_netty-transport-native-epoll-4.1.68.Final.jar && \
    cp io.netty_netty-transport-native-kqueue-4.1.68.Final-osx-x86_64.jar io.netty_netty-transport-native-kqueue-4.1.68.Final.jar && \
    cp io.netty_netty-resolver-dns-native-macos-4.1.68.Final-osx-x86_64.jar io.netty_netty-resolver-dns-native-macos-4.1.68.Final.jar

In [3]:
from pyspark.sql.functions import when, col
from pyspark.ml import Pipeline
from synapse.ml.vw import VowpalWabbitFeaturizer, VowpalWabbitRegressor

In [4]:
data = se.read.option("delimiter", "\t").csv('/user/drugs/data/*', header=True, inferSchema=True)

                                                                                

In [5]:
data.limit(10).toPandas()

Unnamed: 0,_c0,drugName,condition,review,rating,date,usefulCount
0,206461,Valsartan,Left Ventricular Dysfunction,"""""""It has no side effect, I take it in combina...",9.0,"May 20, 2012",27.0
1,95260,Guanfacine,ADHD,"""""""My son is halfway through his fourth week o...",,,
2,We have tried many different medications and s...,8.0,"April 27, 2010",192,,,
3,92703,Lybrel,Birth Control,"""""""I used to take another oral contraceptive, ...",,,
4,The positive side is that I didn&#039;t have a...,5.0,"December 14, 2009",17,,,
5,138000,Ortho Evra,Birth Control,"""""""This is my first time using any form of bir...",8.0,"November 3, 2015",10.0
6,35696,Buprenorphine / naloxone,Opiate Dependence,"""""""Suboxone has completely turned my life arou...",9.0,"November 27, 2016",37.0
7,155963,Cialis,Benign Prostatic Hyperplasia,"""""""2nd day on 5mg started to work with rock ha...",2.0,"November 28, 2015",43.0
8,165907,Levonorgestrel,Emergency Contraception,"""""""He pulled out, but he cummed a bit in me. I...",1.0,"March 7, 2017",5.0
9,102654,Aripiprazole,Bipolar Disorde,"""""""Abilify changed my life. There is hope. I w...",10.0,"March 14, 2015",32.0


In [6]:
data.columns

['_c0', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount']

In [7]:
columns = [
    '_c0',
    'd',
    'c',
    'r',
    'rating',
    'data',
    'target',
]
df = data.toDF(*columns)
df.printSchema()

root
 |-- _c0: string (nullable = true)
 |-- d: string (nullable = true)
 |-- c: string (nullable = true)
 |-- r: string (nullable = true)
 |-- rating: string (nullable = true)
 |-- data: string (nullable = true)
 |-- target: integer (nullable = true)



In [8]:
train, test = (
    df
    .na.drop('any')
    .randomSplit([0.9, 0.1], 422)
)

In [9]:
train.limit(20).toPandas()

                                                                                

Unnamed: 0,_c0,d,c,r,rating,data,target
0,10,Medroxyprogesterone,Abnormal Uterine Bleeding,"""""""I&#039;m 17 years old and I got shot in Aug...",7.0,"October 20, 2015",2
1,1000,Everolimus,Breast Cance,"""""""Although the medication did effectively tre...",2.0,"March 15, 2016",4
2,10000,Lo Loestrin Fe,Birth Control,"""""""I was on this birth control for 8 months. T...",7.0,"April 10, 2013",4
3,100004,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been taking Azurette for 3 years now...",8.0,"July 11, 2017",1
4,100007,Desogestrel / ethinyl estradiol,Birth Control,"""""""At the beginning, Kariva seemed to be worki...",5.0,"June 29, 2017",0
5,100008,Desogestrel / ethinyl estradiol,Birth Control,"""""""This is yet another update. Just finished m...",1.0,"June 28, 2017",0
6,100009,Desogestrel / ethinyl estradiol,Birth Control,"""""""I was on reclipsen for less than 2 weeks an...",1.0,"June 28, 2017",2
7,100011,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been on Apri for 4 months, I&#039;m ...",8.0,"June 19, 2017",3
8,100012,Desogestrel / ethinyl estradiol,Birth Control,"""""""I&#039;ve been taking Velivet for about a y...",9.0,"June 17, 2017",2
9,100013,Desogestrel / ethinyl estradiol,Birth Control,"""""""Gives me heartburn and indigestion. Also ma...",2.0,"June 8, 2017",1


Create an object that builds features in VW format. It takes a dataframe and returns a dataframe with a new column that holds these features.

In [10]:
vw_featurizer = VowpalWabbitFeaturizer(
    inputCols=["rating"], 
    stringSplitInputCols=["d", "c", "r"],
    outputCol="features",
    numBits=24
)

In [11]:
x = vw_featurizer.transform(train).rdd.first()
x['features']

                                                                                

SparseVector(16777216, {380313: 1.0, 615218: 2.0, 1162472: 5.0, 1279788: 1.0, 1349603: 1.0, 1944935: 1.0, 2310638: 1.0, 2555950: 1.0, 2562337: 1.0, 2601746: 1.0, 3107187: 1.0, 3346639: 11.0, 3416701: 1.0, 3446374: 1.0, 3615423: 1.0, 4099010: 1.0, 4121377: 1.0, 4415074: 1.0, 5317406: 1.0, 5728618: 1.0, 5881332: 1.0, 6051161: 1.0, 6696771: 2.0, 6866455: 1.0, 7608613: 1.0, 7636861: 2.0, 8202109: 1.0, 8315460: 2.0, 8336163: 3.0, 8717045: 1.0, 8943791: 1.0, 9261333: 1.0, 9523050: 3.0, 9787552: 1.0, 9845063: 2.0, 9878045: 1.0, 9999245: 1.0, 10090473: 1.0, 10189708: 2.0, 10204651: 1.0, 10410939: 1.0, 11169916: 1.0, 11269461: 1.0, 11318998: 1.0, 11658255: 1.0, 12043885: 1.0, 12082180: 1.0, 12343490: 1.0, 12501165: 3.0, 12730453: 2.0, 12741825: 3.0, 12892703: 1.0, 12951244: 1.0, 13357553: 1.0, 13515522: 1.0, 13601167: 1.0, 13735132: 7.0, 13956072: 1.0, 14355936: 1.0, 14380379: 1.0, 14465432: 1.0, 14768116: 1.0, 14866460: 1.0, 15200139: 1.0, 15384876: 1.0, 15571315: 1.0, 15639492: 1.0, 16385598:
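
The vector size 16777216 above equals 2**24, which matches `numBits=24`: each feature is hashed into one of 2**24 slots. Below is a minimal sketch of that hashing trick. It uses `zlib.crc32` as a stand-in hash, which is an assumption and not the hash the featurizer uses, so the indices will not match the output above; the function name `hash_features` and the `namespace^token` key format are also made up for illustration.

```python
import zlib
from collections import Counter
from typing import Dict


def hash_features(text: str, namespace: str, num_bits: int = 24) -> Dict[int, float]:
    """Map whitespace-separated tokens to {slot index: count} with the hashing trick."""
    mask = (1 << num_bits) - 1  # keep only the lowest num_bits bits of the hash
    counts: Counter = Counter()
    for token in text.split():
        key = f"{namespace}^{token}".encode("utf-8")
        counts[zlib.crc32(key) & mask] += 1
    return {index: float(count) for index, count in sorted(counts.items())}


features = hash_features("birth control birth", namespace="c", num_bits=24)
print(features)
```

Repeated tokens land in the same slot and add up, which is why some values in the sparse vector are larger than 1.0. Different tokens can also collide; more bits make that less likely.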

Create an object for training the regressor. It works the same way: it is fitted on one dataframe and can then transform another dataframe by adding predictions to it.

In [13]:
args = "--learning_rate 20.0 --bit_precision 24 --ngram r2 --interactions dc"
vw_model = VowpalWabbitRegressor(
    featuresCol="features",
    labelCol="target",
    args=args,
    numPasses=40
)

Combine them into a single pipeline.

In [15]:
vw_pipeline = Pipeline(stages=[vw_featurizer, vw_model])

In [16]:
vw_trained = vw_pipeline.fit(train)

inbound connection from 10.128.0.12(rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net:12190) serv=53228
10.128.0.12(rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net:12190): nonce=2051155050
10.128.0.12(rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net:12190): total=2
10.128.0.12(rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net:12190): node id=1


nonce 2051155050 still waiting for 1 nodes out of 2 for example node 0


inbound connection from 10.128.0.12(rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net:12190) serv=53236
10.128.0.12(rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net:12190): nonce=2051155050
10.128.0.12(rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net:12190): total=2
10.128.0.12(rc1a-dataproc-d-t299z6uw2n0sko0q.mdb.yandexcloud.net:12190): node id=0
                                                                                

In [17]:
prediction = vw_trained.transform(test)

Generating 2-grams for r namespaces.
creating features for following interactions: dc 
only testing
Num weight bits = 24
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = 
num sources = 1


In [18]:
prediction.limit(10).toPandas()

                                                                                

Unnamed: 0,_c0,d,c,r,rating,data,target,features,rawPrediction,prediction
0,100,Medroxyprogesterone,Birth Control,"""""""Depo was not for me, but that does not mean...",5.0,"August 17, 2015",2,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",12.627239,12.627239
1,100002,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been taking this birth control for f...",4.0,"July 12, 2017",2,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,0.0
2,100017,Desogestrel / ethinyl estradiol,Birth Control,"""""""Love it! Been continuously dosing without b...",9.0,"May 30, 2017",3,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,0.0
3,10002,Lo Loestrin Fe,Birth Control,"""""""Well, I&#039;ve been on this right now for ...",6.0,"April 4, 2013",10,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",20.083971,20.083971
4,100021,Desogestrel / ethinyl estradiol,Birth Control,"""""""So I&#039;ve been on this birth control a l...",5.0,"May 18, 2017",3,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.882582,2.882582
5,100107,Desogestrel / ethinyl estradiol,Birth Control,"""""""completely lost every bit of sex drive i ha...",5.0,"November 23, 2016",1,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,0.0
6,100123,Desogestrel / ethinyl estradiol,Birth Control,"""""""I haven&#039;t taken any BC in 5 years sinc...",1.0,"November 2, 2016",2,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",10.435595,10.435595
7,100126,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have used Cerelle for 9 days I feel very...",10.0,"October 28, 2016",1,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",20.759893,20.759893
8,100165,Desogestrel / ethinyl estradiol,Birth Control,"""""""I&#039;ve had a pretty good experience with...",8.0,"March 15, 2016",1,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,0.0
9,100167,Desogestrel / ethinyl estradiol,Birth Control,"""""""I had previously been on Microgestin which ...",10.0,"March 11, 2016",10,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",15.034412,15.034412


In [19]:
from synapse.ml.train import ComputeModelStatistics
metrics = ComputeModelStatistics(
    evaluationMetric='regression',
    labelCol='target',
    scoresCol='prediction'
).transform(prediction)

                                                                                

In [20]:
metrics.toPandas()

Unnamed: 0,mean_squared_error,root_mean_squared_error,R^2,mean_absolute_error
0,745.794131,27.309232,0.451299,16.695151
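
For reference, here is a small sketch of how these four numbers are usually defined. It follows the textbook formulas, which I assume are what `ComputeModelStatistics` reports; the helper `regression_metrics` and the toy data are illustrative and not taken from the dataset.

```python
import math
from typing import Dict, Sequence


def regression_metrics(labels: Sequence[float], predictions: Sequence[float]) -> Dict[str, float]:
    """Textbook MSE, RMSE, R^2 and MAE for paired labels and predictions."""
    n = len(labels)
    errors = [p - y for y, p in zip(labels, predictions)]
    squared_error_sum = sum(e * e for e in errors)
    mean_label = sum(labels) / n
    total_variance = sum((y - mean_label) ** 2 for y in labels)
    mse = squared_error_sum / n
    return {
        "mean_squared_error": mse,
        "root_mean_squared_error": math.sqrt(mse),
        "R^2": 1.0 - squared_error_sum / total_variance,
        "mean_absolute_error": sum(abs(e) for e in errors) / n,
    }


# Toy labels and predictions.
metrics_example = regression_metrics([2.0, 3.0, 10.0, 1.0], [4.0, 3.0, 7.0, 0.0])
print(metrics_example)
```

As a sanity check on the table above, the reported RMSE is the square root of the reported MSE: the square root of 745.794131 is about 27.309.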


### SparkML

Note that the standard Spark library includes a machine learning module.

**HOWEVER**, it works extremely poorly. The best you can do with it is run it once and realize that you will never use it again.

This really matters, because the claim that a standard ML library works that badly does not sound very convincing, and surely there are cases where it works well, right? The answer is: quite possibly. To find out for yourself whether such cases exist, try training something with SparkML on your own and get a feel for the limits of its applicability :)

In [1]:
import findspark
findspark.init()

In [2]:
import pyspark
sc = pyspark.SparkContext(appName="lsml-app-1")

SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/usr/lib/spark/jars/slf4j-log4j12-1.7.30.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/usr/lib/hadoop/lib/slf4j-log4j12-1.7.25.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
2024-03-13 23:34:22,015 WARN util.Utils: spark.executor.instances less than spark.dynamicAllocation.minExecutors is invalid, ignoring its setting, please update your configs.
2024-03-13 23:34:27,096 WARN util.Utils: spark.executor.instances less than spark.dynamicAllocation.minExecutors is invalid, ignoring its setting, please update your configs.
2024-03-13 23:34:27,107 WARN cluster.YarnSchedulerBackend$YarnSchedulerEndpo

In [3]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import HashingTF, IDF, Tokenizer
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler

In [4]:
from pyspark.sql import SparkSession, Row

se = SparkSession(sc)

In [5]:
data = se.read.option("delimiter", "\t").csv('/user/drugs/data/*', header=True, inferSchema=True)

                                                                                

In [6]:
data = (
    data
    .na.drop('any')
    .withColumn('ratingNum', data.rating.cast('integer'))
)


train, test = data.randomSplit([0.9, 0.1], 422)
train, test = train.cache(), test.cache()

In [7]:
tokenizer = Tokenizer(inputCol="review", outputCol="words")
wordsData = tokenizer.transform(train)

hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=2**23)
featurizedData = hashingTF.transform(wordsData)
idf = IDF(inputCol="rawFeatures", outputCol="features")
idfModel = idf.fit(featurizedData)

rescaledData = idfModel.transform(featurizedData)

                                                                                

In [8]:
rescaledData.limit(5).toPandas()

2024-03-13 23:37:29,262 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 128.1 MiB
                                                                                

Unnamed: 0,_c0,drugName,condition,review,rating,date,usefulCount,ratingNum,words,rawFeatures,features
0,10,Medroxyprogesterone,Abnormal Uterine Bleeding,"""""""I&#039;m 17 years old and I got shot in Aug...",7.0,"October 20, 2015",2,7,"[""""""i&#039;m, 17, years, old, and, i, got, sho...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,1000,Everolimus,Breast Cance,"""""""Although the medication did effectively tre...",2.0,"March 15, 2016",4,2,"[""""""although, the, medication, did, effectivel...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,10000,Lo Loestrin Fe,Birth Control,"""""""I was on this birth control for 8 months. T...",7.0,"April 10, 2013",4,7,"[""""""i, was, on, this, birth, control, for, 8, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,100004,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been taking Azurette for 3 years now...",8.0,"July 11, 2017",1,8,"[""""""i, have, been, taking, azurette, for, 3, y...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,100007,Desogestrel / ethinyl estradiol,Birth Control,"""""""At the beginning, Kariva seemed to be worki...",5.0,"June 29, 2017",0,5,"[""""""at, the, beginning,, kariva, seemed, to, b...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
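
To show what the `IDF` stage adds on top of raw term counts, here is a plain-Python sketch. It uses the smoothed formula log((m + 1) / (df + 1)) that Spark documents for its IDF, but it keys weights by the word itself rather than by a hashed index; the helper names and the three toy documents are assumptions for illustration.

```python
import math
from collections import Counter
from typing import Dict, List


def fit_idf(documents: List[List[str]]) -> Dict[str, float]:
    """Smoothed inverse document frequency: log((m + 1) / (df + 1))."""
    m = len(documents)
    doc_freq = Counter(term for doc in documents for term in set(doc))
    return {term: math.log((m + 1) / (df + 1)) for term, df in doc_freq.items()}


def tf_idf(doc: List[str], idf_weights: Dict[str, float]) -> Dict[str, float]:
    """Raw term counts scaled by the fitted IDF weights."""
    return {term: count * idf_weights.get(term, 0.0) for term, count in Counter(doc).items()}


docs = [
    "i love this drug".split(),
    "this drug gave me headaches".split(),
    "i feel fine".split(),
]
idf_weights = fit_idf(docs)
doc_weights = tf_idf(docs[0], idf_weights)
print(doc_weights)
```

Words that appear in many documents get a weight close to zero, while rare words keep most of their count.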


In [9]:
stringIndexer = StringIndexer(inputCol='drugName', outputCol = "drugIndex").setHandleInvalid("skip")
encoder = OneHotEncoder(inputCol="drugIndex", outputCol="drugVec")

pipeline = Pipeline(stages=[stringIndexer, encoder])
ohe = pipeline.fit(rescaledData).transform(rescaledData)

                                                                                

In [10]:
x = ohe.limit(1).rdd.first()
x

2024-03-13 23:38:59,391 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 128.4 MiB
2024-03-13 23:39:00,594 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 128.3 MiB
                                                                                

Row(_c0='10', drugName='Medroxyprogesterone', condition='Abnormal Uterine Bleeding', review='"""I&#039;m 17 years old and I got shot in August 2015, personally. I don&#039;t mind it. I mean, I bleed little bits and random times, but I&#039;d rather have the blood that&#039;s supposed to come out, come out and not worry about where it&#039;s going or staying in my body. I have my other injection in November on the 2nd, and I&#039;m still wondering if I could take it again. The only downside to the injection is that I gained access weight and I&#039;m kind of moody."""', rating='7.0', date='October 20, 2015', usefulCount=2, ratingNum=7, words=['"""i&#039;m', '17', 'years', 'old', 'and', 'i', 'got', 'shot', 'in', 'august', '2015,', 'personally.', 'i', 'don&#039;t', 'mind', 'it.', 'i', 'mean,', 'i', 'bleed', 'little', 'bits', 'and', 'random', 'times,', 'but', 'i&#039;d', 'rather', 'have', 'the', 'blood', 'that&#039;s', 'supposed', 'to', 'come', 'out,', 'come', 'out', 'and', 'not', 'worry',

In [11]:
x['drugVec']

SparseVector(3492, {13: 1.0})
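
The two stages behind this vector can be sketched in plain Python. The sketch assumes the Spark defaults as I understand them: `StringIndexer` orders labels by descending frequency, and `OneHotEncoder` drops the last category, so the vector has one slot fewer than there are labels. The helper names and the toy drug list are illustrative only.

```python
from collections import Counter
from typing import Dict, List


def fit_string_index(values: List[str]) -> Dict[str, int]:
    """Assign index 0 to the most frequent label, 1 to the next, and so on."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: index for index, label in enumerate(ordered)}


def one_hot_sparse(index: int, num_labels: int, drop_last: bool = True) -> Dict[str, object]:
    """Sparse one-hot vector as {"size": ..., "values": {position: 1.0}}."""
    size = num_labels - 1 if drop_last else num_labels
    values = {index: 1.0} if index < size else {}  # the dropped label becomes all zeros
    return {"size": size, "values": values}


drugs = ["Cialis", "Lybrel", "Cialis", "Valsartan", "Lybrel", "Cialis"]
index = fit_string_index(drugs)
encoded = one_hot_sparse(index["Lybrel"], num_labels=len(index))
print(index, encoded)
```

Under these assumptions, a vector of size 3492 would correspond to 3493 distinct drug names in the fitted data.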

Prepare the features

In [17]:
tokenizer = Tokenizer(inputCol="review", outputCol="words")
wordsData = tokenizer.transform(train)  # quick check of the tokenizer; not used by the pipeline below
hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=2**10)
idf = IDF(inputCol="rawFeatures", outputCol="revviewFeatures")

stringIndexerCondition = StringIndexer(inputCol='condition', outputCol = "conditionIndex").setHandleInvalid("skip")
encoderCondition = OneHotEncoder(inputCol="conditionIndex", outputCol="conditionVec")

stringIndexerDrug = StringIndexer(inputCol='drugName', outputCol = "drugIndex").setHandleInvalid("skip")
encoderDrug = OneHotEncoder(inputCol="drugIndex", outputCol="drugVec")

assembler = VectorAssembler(inputCols=["drugVec", "conditionVec", "revviewFeatures", 'ratingNum'], outputCol="features")

preproc = Pipeline(stages=[
    tokenizer,
    hashingTF,
    idf,
    stringIndexerCondition,
    encoderCondition,
    stringIndexerDrug,
    encoderDrug,
    assembler
])

In [14]:
train_proc = preproc.fit(train).transform(train).cache()

                                                                                

In [16]:
train_proc.limit(10).toPandas()

Unnamed: 0,_c0,drugName,condition,review,rating,date,usefulCount,ratingNum,words,rawFeatures,revviewFeatures,conditionIndex,conditionVec,drugIndex,drugVec,features
0,10,Medroxyprogesterone,Abnormal Uterine Bleeding,"""""""I&#039;m 17 years old and I got shot in Aug...",7.0,"October 20, 2015",2,7,"[""""""i&#039;m, 17, years, old, and, i, got, sho...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",14.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",13.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,1000,Everolimus,Breast Cance,"""""""Although the medication did effectively tre...",2.0,"March 15, 2016",4,2,"[""""""although, the, medication, did, effectivel...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",83.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1322.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,10000,Lo Loestrin Fe,Birth Control,"""""""I was on this birth control for 8 months. T...",7.0,"April 10, 2013",4,7,"[""""""i, was, on, this, birth, control, for, 8, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",35.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,100004,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been taking Azurette for 3 years now...",8.0,"July 11, 2017",1,8,"[""""""i, have, been, taking, azurette, for, 3, y...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",57.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,100007,Desogestrel / ethinyl estradiol,Birth Control,"""""""At the beginning, Kariva seemed to be worki...",5.0,"June 29, 2017",0,5,"[""""""at, the, beginning,, kariva, seemed, to, b...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",57.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
5,100008,Desogestrel / ethinyl estradiol,Birth Control,"""""""This is yet another update. Just finished m...",1.0,"June 28, 2017",0,1,"[""""""this, is, yet, another, update., just, fin...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",57.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
6,100009,Desogestrel / ethinyl estradiol,Birth Control,"""""""I was on reclipsen for less than 2 weeks an...",1.0,"June 28, 2017",2,1,"[""""""i, was, on, reclipsen, for, less, than, 2,...","(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(3.3244338839699816, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",57.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
7,100011,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been on Apri for 4 months, I&#039;m ...",8.0,"June 19, 2017",3,8,"[""""""i, have, been, on, apri, for, 4, months,, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",57.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
8,100012,Desogestrel / ethinyl estradiol,Birth Control,"""""""I&#039;ve been taking Velivet for about a y...",9.0,"June 17, 2017",2,9,"[""""""i&#039;ve, been, taking, velivet, for, abo...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",57.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
9,100013,Desogestrel / ethinyl estradiol,Birth Control,"""""""Gives me heartburn and indigestion. Also ma...",2.0,"June 8, 2017",1,2,"[""""""gives, me, heartburn, and, indigestion., a...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",57.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [None]:
tokenizer = Tokenizer(inputCol="review", outputCol="words")
hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=2**23)
idf = IDF(inputCol="rawFeatures", outputCol="revviewFeatures")

preproc = Pipeline(stages=[
    tokenizer,
    hashingTF,
    idf
])

The feature pipeline written above will most likely fail to run. So let's try to cut down the amount of computation; maybe that will work.
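
A rough back-of-the-envelope sketch of why the number of hashed features matters: anything that has to be stored densely, one 64-bit value per slot, grows linearly with `numFeatures`. Whether this is the actual bottleneck here is an assumption, not a measurement; the helper `dense_vector_mib` is illustrative.

```python
def dense_vector_mib(num_features: int, bytes_per_value: int = 8) -> float:
    """Size of one dense vector of 64-bit values, in MiB."""
    return num_features * bytes_per_value / 2**20


for bits in (10, 23):
    print(f"numFeatures=2**{bits}: {dense_vector_mib(2**bits):.4f} MiB per dense vector")
```

One such vector takes 64 MiB at 2**23 slots and under 0.01 MiB at 2**10. The earlier warning about broadcasting a task binary of about 128 MiB is of the same order of magnitude, though a direct link between the two is only a guess.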

In [14]:
stringIndexerCondition = StringIndexer(inputCol='condition', outputCol = "conditionIndex").setHandleInvalid("skip")
encoderCondition = OneHotEncoder(inputCol="conditionIndex", outputCol="conditionVec")

stringIndexerDrug = StringIndexer(inputCol='drugName', outputCol = "drugIndex").setHandleInvalid("skip")
encoderDrug = OneHotEncoder(inputCol="drugIndex", outputCol="drugVec")

assembler = VectorAssembler(inputCols=["drugVec", "conditionVec", 'ratingNum'], outputCol="features")

preproc = Pipeline(stages=[
    stringIndexerCondition,
    encoderCondition,
    stringIndexerDrug,
    encoderDrug,
    assembler
])

In [15]:
preproc = preproc.fit(data)

                                                                                

In [16]:
train_proc = preproc.transform(train).cache()
test_proc = preproc.transform(test).cache()

In [17]:
train_proc.limit(10).toPandas()

                                                                                

Unnamed: 0,_c0,drugName,condition,review,rating,date,usefulCount,ratingNum,conditionIndex,conditionVec,drugIndex,drugVec,features
0,10,Medroxyprogesterone,Abnormal Uterine Bleeding,"""""""I&#039;m 17 years old and I got shot in Aug...",7.0,"October 20, 2015",2,7,14.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",14.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,1000,Everolimus,Breast Cance,"""""""Although the medication did effectively tre...",2.0,"March 15, 2016",4,2,83.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1379.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,10000,Lo Loestrin Fe,Birth Control,"""""""I was on this birth control for 8 months. T...",7.0,"April 10, 2013",4,7,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",35.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,100004,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been taking Azurette for 3 years now...",8.0,"July 11, 2017",1,8,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,100007,Desogestrel / ethinyl estradiol,Birth Control,"""""""At the beginning, Kariva seemed to be worki...",5.0,"June 29, 2017",0,5,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
5,100008,Desogestrel / ethinyl estradiol,Birth Control,"""""""This is yet another update. Just finished m...",1.0,"June 28, 2017",0,1,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
6,100009,Desogestrel / ethinyl estradiol,Birth Control,"""""""I was on reclipsen for less than 2 weeks an...",1.0,"June 28, 2017",2,1,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
7,100011,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been on Apri for 4 months, I&#039;m ...",8.0,"June 19, 2017",3,8,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
8,100012,Desogestrel / ethinyl estradiol,Birth Control,"""""""I&#039;ve been taking Velivet for about a y...",9.0,"June 17, 2017",2,9,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
9,100013,Desogestrel / ethinyl estradiol,Birth Control,"""""""Gives me heartburn and indigestion. Also ma...",2.0,"June 8, 2017",1,2,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [18]:
train_proc.rdd.first()

Row(_c0='10', drugName='Medroxyprogesterone', condition='Abnormal Uterine Bleeding', review='"""I&#039;m 17 years old and I got shot in August 2015, personally. I don&#039;t mind it. I mean, I bleed little bits and random times, but I&#039;d rather have the blood that&#039;s supposed to come out, come out and not worry about where it&#039;s going or staying in my body. I have my other injection in November on the 2nd, and I&#039;m still wondering if I could take it again. The only downside to the injection is that I gained access weight and I&#039;m kind of moody."""', rating='7.0', date='October 20, 2015', usefulCount=2, ratingNum=7, conditionIndex=14.0, conditionVec=SparseVector(896, {14: 1.0}), drugIndex=14.0, drugVec=SparseVector(3572, {14: 1.0}), features=SparseVector(4469, {14: 1.0, 3586: 1.0, 4468: 7.0}))

If the dataset was assembled successfully, we can fit a linear regression.

In [19]:
lr = LinearRegression(featuresCol='features', labelCol='usefulCount', maxIter=10, regParam=0.3, elasticNetParam=0.8)
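To make the two regularization parameters concrete, here is a minimal pure-Python sketch of the elastic-net penalty in the form given in the Spark ML documentation: `regParam * (elasticNetParam * ||w||_1 + (1 - elasticNetParam) / 2 * ||w||_2^2)`. The function name `elastic_net_penalty` and the toy weights are illustrative and not part of the notebook.

```python
def elastic_net_penalty(weights, reg_param=0.3, elastic_net_param=0.8):
    """Elastic-net penalty: reg_param * (alpha * L1 + (1 - alpha) / 2 * L2^2)."""
    l1 = sum(abs(w) for w in weights)
    l2_squared = sum(w * w for w in weights)
    return reg_param * (
        elastic_net_param * l1 + (1.0 - elastic_net_param) / 2.0 * l2_squared
    )


# Hypothetical toy weights, only to show the arithmetic.
penalty = elastic_net_penalty([1.0, -2.0, 0.0])
print(penalty)
```

With `elasticNetParam=0.8` most of the penalty comes from the L1 term, which tends to push individual coefficients to exactly zero.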

In [20]:
lrModel = lr.fit(train_proc)

2024-03-13 23:41:56,448 WARN netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
2024-03-13 23:41:56,450 WARN netlib.BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS


In [21]:
lrModel.coefficients

SparseVector(4469, {0: -1.5396, 7: 0.6016, 8: -2.2478, 11: 18.5525, 12: -2.6516, 13: 6.3993, 15: -4.2273, 17: 2.2721, 18: 2.7929, 19: -4.6926, 20: -1.8025, 22: -1.4182, 24: 1.5974, 25: 5.6574, 27: -3.6162, 28: 14.7229, 30: -2.1567, 31: -6.25, 32: 4.1213, 33: -0.733, 34: -1.3478, 36: 16.8292, 37: -2.0554, 42: -1.3523, 43: 8.9463, 44: 16.5407, 45: 7.6053, 46: -4.8076, 47: 19.9402, 49: 1.5438, 51: 21.3003, 54: -3.1937, 57: -0.3751, 59: 3.3786, 61: -7.7252, 63: 17.8873, 64: 3.3672, 68: -7.1302, 69: 4.5331, 73: 1.5355, 74: 22.0911, 75: 10.1474, 77: 12.5425, 82: 0.837, 83: 10.9366, 85: 0.745, 88: -1.6474, 89: -1.2763, 91: 3.3217, 94: -4.1866, 97: 8.6183, 101: 5.3111, 104: 7.3622, 106: 15.4524, 107: -1.2783, 109: 12.6343, 110: 21.9048, 111: 8.4787, 113: -5.7946, 116: 1.2026, 118: 15.4499, 119: 3.2374, 120: 6.8017, 122: 7.03, 123: 10.7318, 124: 26.3091, 126: -0.7743, 132: 1.4863, 139: -0.3653, 140: 3.7497, 142: -0.4101, 143: 34.2557, 145: 14.5784, 146: 14.928, 149: 5.9699, 152: 5.5148, 153: 6.
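The coefficient indices refer to positions in the assembled `features` vector. Assuming the assembler concatenates `drugVec` (3572 slots), `conditionVec` (896 slots) and `ratingNum` in that order, which is what the sizes and indices in the `Row` output above suggest, an index can be mapped back to its source column. The helper `locate_feature` is a hypothetical name introduced for this sketch.

```python
# Block sizes taken from the Row output above: drugVec has 3572 slots,
# conditionVec has 896, and ratingNum is a single numeric column.
# The order of the blocks is an assumption about the assembler.
BLOCKS = [("drugVec", 3572), ("conditionVec", 896), ("ratingNum", 1)]


def locate_feature(index, blocks=BLOCKS):
    """Return (column name, offset inside that column) for a feature index."""
    if index < 0:
        raise IndexError(f"feature index {index} is negative")
    start = 0
    for name, size in blocks:
        if index < start + size:
            return name, index - start
        start += size
    raise IndexError(f"feature index {index} is out of range (size {start})")


# The three non-zero positions of the first training row.
for idx in (14, 3586, 4468):
    print(idx, locate_feature(idx))
```

Under this assumption, index 3586 decodes to offset 14 inside `conditionVec`, matching `conditionIndex=14.0` in the row shown earlier.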

In [24]:
from pyspark.ml.evaluation import RegressionEvaluator

predictions = lrModel.transform(test_proc)

In [25]:
predictions.limit(10).toPandas()

Unnamed: 0,_c0,drugName,condition,review,rating,date,usefulCount,ratingNum,conditionIndex,conditionVec,drugIndex,drugVec,features,prediction
0,100,Medroxyprogesterone,Birth Control,"""""""Depo was not for me, but that does not mean...",5.0,"August 17, 2015",2,5,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",14.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",6.099772
1,100002,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been taking this birth control for f...",4.0,"July 12, 2017",2,4,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",3.934818
2,100017,Desogestrel / ethinyl estradiol,Birth Control,"""""""Love it! Been continuously dosing without b...",9.0,"May 30, 2017",3,9,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",14.759589
3,10002,Lo Loestrin Fe,Birth Control,"""""""Well, I&#039;ve been on this right now for ...",6.0,"April 4, 2013",10,6,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",35.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",8.264726
4,100021,Desogestrel / ethinyl estradiol,Birth Control,"""""""So I&#039;ve been on this birth control a l...",5.0,"May 18, 2017",3,5,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",6.099772
5,100107,Desogestrel / ethinyl estradiol,Birth Control,"""""""completely lost every bit of sex drive i ha...",5.0,"November 23, 2016",1,5,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",6.099772
6,100123,Desogestrel / ethinyl estradiol,Birth Control,"""""""I haven&#039;t taken any BC in 5 years sinc...",1.0,"November 2, 2016",2,1,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",-2.560044
7,100126,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have used Cerelle for 9 days I feel very...",10.0,"October 28, 2016",1,10,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",16.924543
8,100165,Desogestrel / ethinyl estradiol,Birth Control,"""""""I&#039;ve had a pretty good experience with...",8.0,"March 15, 2016",1,8,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",12.594635
9,100167,Desogestrel / ethinyl estradiol,Birth Control,"""""""I had previously been on Microgestin which ...",10.0,"March 11, 2016",10,10,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",16.924543


In [26]:
lr_evaluator = RegressionEvaluator(predictionCol="prediction", labelCol="usefulCount", metricName="r2")
lr_evaluator.evaluate(predictions)

0.27444947176597745
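For reference, the `r2` metric is the coefficient of determination, `1 - SS_res / SS_tot`. Below is a minimal pure-Python sketch of that formula on toy numbers; the function name `r2_score` and the sample values are illustrative and are not taken from the notebook.

```python
def r2_score(labels, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_label = sum(labels) / len(labels)
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    ss_tot = sum((y - mean_label) ** 2 for y in labels)
    return 1.0 - ss_res / ss_tot


# Toy data: a perfect fit and a constant "predict the mean" baseline.
labels = [1.0, 2.0, 3.0]
print(r2_score(labels, [1.0, 2.0, 3.0]))
print(r2_score(labels, [2.0, 2.0, 2.0]))
```

A value near 0.27 therefore means the model explains roughly a quarter of the variance of `usefulCount` on the test split.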

CatBoost has recently become available on Spark. Let's experiment with it.

In [1]:
import findspark
findspark.init()

In [2]:
import pyspark
se = pyspark.sql.SparkSession.builder.appName("MyApp2") \
            .config("spark.jars.packages", "ai.catboost:catboost-spark_3.0_2.12:1.1.1") \
            .config("spark.dynamicAllocation.enabled", False) \
            .config("spark.locality.wait", 0) \
            .getOrCreate()


SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/usr/lib/spark/jars/slf4j-log4j12-1.7.30.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/usr/lib/hadoop/lib/slf4j-log4j12-1.7.25.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
Ivy Default Cache set to: /home/ubuntu/.ivy2/cache
The jars for the packages stored in: /home/ubuntu/.ivy2/jars
:: loading settings :: url = jar:file:/usr/lib/spark/jars/ivy-2.4.0.jar!/org/apache/ivy/core/settings/ivysettings.xml
ai.catboost#catboost-spark_3.0_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-11a8d042-9f19-480d-bbb1-35fcae4e1778;1.0
	confs: [default]
	found ai.catboost#catboost-spark_3.0_2.12;1.1.1 in central
	found org.scala-lang.modules#scala-collection-compat_2.12;2.6.0 in central
	found 
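The package coordinate passed to `spark.jars.packages` appears to encode the Spark compatibility version and the Scala binary version in the artifact name. The helper below is a hypothetical convenience that rebuilds the same string from its parts; it assumes this naming pattern holds for other versions, which should be checked against the CatBoost documentation for the cluster in use.

```python
def catboost_spark_coordinate(spark_compat, scala_compat, version):
    """Build a Maven coordinate like the one passed to spark.jars.packages."""
    return f"ai.catboost:catboost-spark_{spark_compat}_{scala_compat}:{version}"


# Reproduces the coordinate used in the session builder above.
coordinate = catboost_spark_coordinate("3.0", "2.12", "1.1.1")
print(coordinate)
```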

In [3]:
import catboost_spark

In [4]:
data = se.read.option("delimiter", "\t").csv('/user/drugs/data/*', header=True, inferSchema=True)

data = (
    data
    .na.drop('any')
    .withColumn('ratingNum', data.rating.cast('integer'))
)


train, test = data.randomSplit([0.9, 0.1], 422)
train, test = train.cache(), test.cache()
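Conceptually, `randomSplit([0.9, 0.1], 422)` assigns every row to one of the splits at random, so the resulting sizes are only approximately 90% and 10%. The sketch below illustrates the idea in plain Python; it is not Spark's implementation (which samples per partition), and `random_split` is a hypothetical name.

```python
import random


def random_split(rows, weights, seed):
    """Assign each row to one split by drawing against normalized weights."""
    total = float(sum(weights))
    bounds = []
    cumulative = 0.0
    for weight in weights:
        cumulative += weight / total
        bounds.append(cumulative)
    bounds[-1] = 1.0  # guard against floating-point rounding

    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        draw = rng.random()  # uniform in [0, 1)
        for i, bound in enumerate(bounds):
            if draw < bound:
                splits[i].append(row)
                break
    return splits


train_rows, test_rows = random_split(range(1000), [0.9, 0.1], seed=422)
print(len(train_rows), len(test_rows))
```

Fixing the seed makes the split reproducible, which is why the notebook passes `422` explicitly.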

In [5]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import HashingTF, IDF, Tokenizer
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler

In [6]:
stringIndexerCondition = StringIndexer(inputCol='condition', outputCol = "conditionIndex").setHandleInvalid("skip")
encoderCondition = OneHotEncoder(inputCol="conditionIndex", outputCol="conditionVec")

stringIndexerDrug = StringIndexer(inputCol='drugName', outputCol = "drugIndex").setHandleInvalid("skip")
encoderDrug = OneHotEncoder(inputCol="drugIndex", outputCol="drugVec")

assembler = VectorAssembler(inputCols=["drugVec", "conditionVec", 'ratingNum'], outputCol="features")


preproc = Pipeline(stages=[
    stringIndexerCondition,
    encoderCondition,
    stringIndexerDrug,
    encoderDrug,
    assembler
])
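What the indexer and encoder stages compute can be sketched without Spark. This is a simplified illustration that assumes the Spark defaults (labels ordered by descending frequency, last category dropped by the one-hot encoder); the alphabetical tie-break, the function names and the toy labels are assumptions made for the example.

```python
from collections import Counter


def fit_string_index(values):
    """Map each label to an index, most frequent first (ties alphabetical)."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}


def one_hot_sparse(index, num_categories, drop_last=True):
    """Return (size, {position: 1.0}); a dropped last category is all zeros."""
    size = num_categories - 1 if drop_last else num_categories
    position = int(index)
    return size, ({position: 1.0} if position < size else {})


# Hypothetical toy column, not the real dataset.
conditions = ["Birth Control", "Acne", "Birth Control", "Pain", "Acne", "Birth Control"]
index = fit_string_index(conditions)
print(index)
print(one_hot_sparse(index["Birth Control"], len(index)))
print(one_hot_sparse(index["Pain"], len(index)))
```

This matches what the tables in this notebook show: the most frequent condition, `Birth Control`, gets `conditionIndex` 0.0.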

In [7]:
preproc = preproc.fit(data)

In [8]:
train_proc = preproc.transform(train).cache()
test_proc = preproc.transform(test).cache()

In [9]:
train_proc.limit(10).toPandas()

Unnamed: 0,_c0,drugName,condition,review,rating,date,usefulCount,ratingNum,conditionIndex,conditionVec,drugIndex,drugVec,features
0,10,Medroxyprogesterone,Abnormal Uterine Bleeding,"""""""I&#039;m 17 years old and I got shot in Aug...",7.0,"October 20, 2015",2,7,14.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",14.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,1000,Everolimus,Breast Cance,"""""""Although the medication did effectively tre...",2.0,"March 15, 2016",4,2,83.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1379.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,10000,Lo Loestrin Fe,Birth Control,"""""""I was on this birth control for 8 months. T...",7.0,"April 10, 2013",4,7,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",35.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,100004,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been taking Azurette for 3 years now...",8.0,"July 11, 2017",1,8,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,100007,Desogestrel / ethinyl estradiol,Birth Control,"""""""At the beginning, Kariva seemed to be worki...",5.0,"June 29, 2017",0,5,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
5,100008,Desogestrel / ethinyl estradiol,Birth Control,"""""""This is yet another update. Just finished m...",1.0,"June 28, 2017",0,1,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
6,100009,Desogestrel / ethinyl estradiol,Birth Control,"""""""I was on reclipsen for less than 2 weeks an...",1.0,"June 28, 2017",2,1,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
7,100011,Desogestrel / ethinyl estradiol,Birth Control,"""""""I have been on Apri for 4 months, I&#039;m ...",8.0,"June 19, 2017",3,8,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
8,100012,Desogestrel / ethinyl estradiol,Birth Control,"""""""I&#039;ve been taking Velivet for about a y...",9.0,"June 17, 2017",2,9,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
9,100013,Desogestrel / ethinyl estradiol,Birth Control,"""""""Gives me heartburn and indigestion. Also ma...",2.0,"June 8, 2017",1,2,0.0,"(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",60.0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [10]:
from pyspark.sql.types import DoubleType, StructField, StructType
from pyspark.ml.linalg import VectorUDT
from pyspark.sql import Row

In [11]:
srcDataSchema = [
    StructField("features", VectorUDT()),
    StructField("label", DoubleType())
]

In [12]:
train_proc.rdd.map(lambda x: Row(x.features, float(x.usefulCount))).take(1)

[<Row(SparseVector(4469, {14: 1.0, 3586: 1.0, 4468: 7.0}), 2.0)>]

In [13]:
trainData = train_proc.rdd.map(lambda x: Row(x.features, float(x.usefulCount)))

In [14]:
trainDf = se.createDataFrame(trainData, StructType(srcDataSchema))

In [15]:
trainDf.limit(10).toPandas()

Unnamed: 0,features,label
0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.0
1,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",4.0
2,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",4.0
3,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.0
4,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0
5,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.0
6,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.0
7,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",3.0
8,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.0
9,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.0


In [16]:
evalData = test_proc.rdd.map(lambda x: Row(x.features, float(x.usefulCount)))

In [17]:
evalDf = se.createDataFrame(evalData, StructType(srcDataSchema))

In [18]:
evalDf.limit(10).toPandas()

Unnamed: 0,features,label
0,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.0
1,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.0
2,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",3.0
3,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",10.0
4,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",3.0
5,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.0
6,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.0
7,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.0
8,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.0
9,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",10.0


In [19]:
trainPool = catboost_spark.Pool(trainDf)
evalPool = catboost_spark.Pool(evalDf)

In [20]:
regressor = catboost_spark.CatBoostRegressor()

In [21]:
model = regressor.fit(trainPool, evalDatasets=[evalPool])

2024-03-13 22:38:05,891 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 1130.2 KiB
2024-03-13 22:38:05,982 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 1130.2 KiB
2024-03-13 22:38:47,566 WARN storage.BlockManagerMasterEndpoint: No more replicas available for rdd_29_1 !
2024-03-13 22:38:47,566 WARN storage.BlockManagerMasterEndpoint: No more replicas available for rdd_29_0 !
2024-03-13 22:38:47,566 WARN storage.BlockManagerMasterEndpoint: No more replicas available for rdd_33_0 !
2024-03-13 22:38:47,566 WARN storage.BlockManagerMasterEndpoint: No more replicas available for rdd_52_0 !
2024-03-13 22:38:47,566 WARN storage.BlockManagerMasterEndpoint: No more replicas available for rdd_33_1 !
2024-03-13 22:38:47,567 WARN storage.BlockManagerMasterEndpoint: No more replicas available for rdd_56_1 !
2024-03-13 22:38:47,567 WARN storage.BlockManagerMasterEndpoint: No more replicas available for rdd_56_0 !
2024-03-13 22:38:47,567 WARN storage.Bloc

KeyboardInterrupt: 

In [None]:
predictions = model.transform(evalPool.data)
predictions.show()

In [None]:
from pyspark.ml.evaluation import RegressionEvaluator

In [None]:
r2_evaluator = RegressionEvaluator(predictionCol="prediction", labelCol="label", metricName="r2")
r2_evaluator.evaluate(predictions)

In [None]:
rmse_evaluator = RegressionEvaluator(predictionCol="prediction", labelCol="label", metricName="rmse")
rmse_evaluator.evaluate(predictions)
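The training cell above was interrupted, so these evaluation cells have no output. For reference, the `rmse` metric is the square root of the mean squared error; the sketch below computes it in plain Python on toy numbers (the function name `rmse` and the values are illustrative only).

```python
import math


def rmse(labels, predictions):
    """Root mean squared error between labels and predictions."""
    squared_errors = [(y - p) ** 2 for y, p in zip(labels, predictions)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Toy data: errors of 3 and 4 give sqrt((9 + 16) / 2).
print(rmse([0.0, 0.0], [3.0, 4.0]))
```

Unlike `r2`, this value is expressed in the units of the label, here the number of "useful" votes.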