### Introduction to Naive Bayes

The Naive Bayes classifier is a machine learning model that uses Bayes' theorem to assign each data point one of several labels. For each candidate label, the model combines the label's prior probability with the conditional probability of each observed feature value given that label. It repeats this for every feature present and assigns the label with the highest resulting probability.

The Naive Bayes model is called "naive" because it assumes that the features are independent of one another given the label.
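
As a minimal sketch of the calculation described above, the snippet below applies Bayes' theorem with the independence assumption to a toy problem. The priors, the feature names (`high_shot_share`, `high_goal_rate`), and the likelihood values are made up for illustration and are not taken from the hockey dataset.

```python
# Toy naive Bayes posterior computed by hand.
# All numbers below are hypothetical and chosen only for illustration.
priors = {"playoff": 0.5, "non_playoff": 0.5}

# P(feature is "high" | label), one entry per feature (assumed values).
likelihoods = {
    "playoff": {"high_shot_share": 0.8, "high_goal_rate": 0.7},
    "non_playoff": {"high_shot_share": 0.3, "high_goal_rate": 0.4},
}


def posterior(observed_features):
    """Return P(label | features), assuming features are independent given the label."""
    unnormalized = {}
    for label, prior in priors.items():
        score = prior
        for feature in observed_features:
            # Naive independence assumption: multiply per-feature likelihoods
            score *= likelihoods[label][feature]
        unnormalized[label] = score
    total = sum(unnormalized.values())
    # Normalize so the probabilities over all labels sum to 1
    return {label: score / total for label, score in unnormalized.items()}


probs = posterior(["high_shot_share", "high_goal_rate"])
predicted = max(probs, key=probs.get)  # label with the highest posterior
print(probs, predicted)
```

The predicted label is simply the one with the largest normalized product of prior and likelihoods, which is the decision rule the model uses.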

Given my prior hypothesis that certain team evaluation metrics during power plays may be more predictive of regular season team success, I use Naive Bayes classification to determine which features are most predictive. I then want to determine whether these features, measured during power play situations, can train an ML model that predicts playoff berths more accurately than the same features from regular 5-on-5 play.

### Prepare Data for Naive Bayes

I prepared the team power play data for Naive Bayes classification in an R file linked here. Team data from the 2018-2019, 2020-2021, 2021-2022, and 2022-2023 seasons were combined, and the desired features were selected from the larger dataset. An additional variable was added to indicate whether the team made the playoffs that season, and all numerical metrics were normalized. Finally, all non-numerical variables except the playoff variable were removed.

In [9]:
import numpy as np
import pandas as pd

In [22]:
df=pd.read_csv("/Users/wan/Desktop/Assignment/standard/numerical_data.csv")
print(df.shape)

label = df['price'].copy()
df=df.drop(columns=['price'])
features=df.columns
feature_matrix = df[features].copy()

(1000, 4)


In [23]:
df.head()

Unnamed: 0,area,neibor_density,consumption
0,45,92.224562,135.603231
1,48,99.102488,144.654674
2,65,136.684707,196.140664
3,68,128.723201,208.26317
4,68,121.308618,205.22558


In [27]:
from sklearn.model_selection import train_test_split

X = feature_matrix
y= label

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=100
)


After loading the pre-prepared dataset, I used sklearn to split it into a training set and a test set with an 80-20 split. It is important to keep a portion of the data aside so that the model's accuracy can be tested on records it did not see during training.
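
The split can be illustrated on a small synthetic dataset. The sketch below uses made-up data (`X_demo` and `y_demo` are hypothetical names) and additionally passes `stratify`, which the cell above does not use, to keep the class proportions the same in both parts.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data: 100 rows, 3 features, balanced binary labels.
rng = np.random.default_rng(100)
X_demo = rng.normal(size=(100, 3))
y_demo = np.array([0, 1] * 50)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo,
    test_size=0.2,      # 80-20 split, as in the cell above
    random_state=100,   # fixed seed so the split is reproducible
    stratify=y_demo,    # assumption: keep the class balance in both parts
)

print(X_tr.shape, X_te.shape)
print(y_tr.mean(), y_te.mean())
```

Stratifying is optional, but with a small number of team-seasons it can help avoid a test set that happens to contain very few examples of one class.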

### Feature Selection

Feature selection is the process of training a model on only the most predictive and least mutually correlated features in order to maximize accuracy.

To find the subset of features that results in the highest accuracy score, I enumerate every possible combination of the seven features and calculate, for each, a metric based on the Spearman correlation coefficients among the features and between the features and the label.

In [29]:
from itertools import chain, combinations

#list all possible feature subsets
feature_subsets = list(features)
feature_subset=chain.from_iterable(combinations(feature_subsets,r) for r in range(len(feature_subsets)+1))
feature_subset=list(feature_subset)



In [30]:
X_train_df=pd.DataFrame(X_train, columns=features)
X_test_df=pd.DataFrame(X_test, columns=features)

from scipy.stats import spearmanr
import itertools

#calculate spearman correlation coefficients for each subset
def mean_xx_corr(x_df):
    df_colnames=x_df.columns
    xx_corrs=[]

    df_colname_pairs=itertools.combinations(df_colnames, 2)
    for colname1, colname2 in df_colname_pairs:
        col1=x_df[colname1]
        col2=x_df[colname2]
        xx_pair_corr=spearmanr(col1, col2).statistic
        xx_corrs.append(xx_pair_corr)

    return np.mean(xx_corrs)


def compute_mean_xy_corr(x_df, y_vec):
    df_colnames=x_df.columns
    xy_corrs=[]
    for colname in df_colnames:
        x_col = x_df[colname]
        xy_pair_corr = spearmanr(x_col, y_vec).statistic
        xy_corrs.append(xy_pair_corr)

    return np.mean(xy_corrs)
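
The notebook does not show how the two mean correlations are combined into a single score per subset. One common choice, used here purely as an assumption, is the merit score from correlation-based feature selection (CFS), which rewards correlation with the label and penalizes redundancy among the features. The function name `merit_score` and the example values are hypothetical.

```python
import math


def merit_score(mean_xy_corr, mean_xx_corr, n_features):
    """CFS-style merit: high feature-label correlation, low feature-feature correlation."""
    k = n_features
    numerator = k * abs(mean_xy_corr)
    denominator = math.sqrt(k + k * (k - 1) * abs(mean_xx_corr))
    return numerator / denominator


# Hypothetical correlation summaries for two candidate subsets of 3 features.
redundant = merit_score(mean_xy_corr=0.5, mean_xx_corr=0.9, n_features=3)
complementary = merit_score(mean_xy_corr=0.5, mean_xx_corr=0.1, n_features=3)
print(redundant, complementary)
```

With equal feature-label correlation, the subset whose features are less correlated with each other gets the higher merit. Note that a one-feature subset has no feature pairs, so the mean feature-feature correlation would need to be set to 0 (rather than the mean of an empty list) before calling a function like this.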


### Naive Bayes

The first NB model I used incorporated the data from all seven features. The GaussianNB model from the sklearn package was trained on the previously partitioned train data.

In [33]:
from sklearn.naive_bayes import GaussianNB

model=GaussianNB()
model.fit(X_train, y_train)

from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay


y_pred=model.predict(X_test)
# sklearn metrics expect the true labels first, then the predictions
accuracy=accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average="weighted")

print(accuracy)
print(f1)

cm=confusion_matrix(y_test, y_pred)
disp=ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Non-playoff team","Playoff team"])
disp.plot()



ValueError: Unknown label type: (array([ -9.50969297,  -8.10198704,  -4.8184365 , ..., 174.76804204, 178.17809327]),)
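
An `Unknown label type` error like the one above is raised when a scikit-learn classifier is given continuous target values rather than discrete class labels. As a minimal sketch, assuming the goal is a two-class problem, the snippet below turns a continuous target into binary labels by thresholding at the median before fitting `GaussianNB`. The synthetic data, the median threshold, and the names `y_continuous` and `y_binary` are illustrative assumptions, not part of the original analysis.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

rng = np.random.default_rng(0)

# Hypothetical data: one informative feature and a continuous target.
X_demo = rng.normal(size=(200, 1))
y_continuous = 50 + 30 * X_demo[:, 0] + rng.normal(scale=5, size=200)

# Assumption: discretize the target at its median to obtain two classes.
y_binary = (y_continuous > np.median(y_continuous)).astype(int)

clf = GaussianNB()
clf.fit(X_demo, y_binary)  # fits without error because the labels are discrete
train_accuracy = clf.score(X_demo, y_binary)
print(clf.classes_, train_accuracy)
```

If the intended label is already categorical (for example a playoff indicator column), the simpler fix is to pass that column as `y` instead of a continuous one.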

After training on the training data, I had the model predict whether the records in the test data were playoff teams based on their features. The accuracy and F1 scores were both around 80%, which is fairly high, and the confusion matrix shows that the model was as proficient at predicting teams that make the playoffs as at predicting teams that miss them.

Next, I would like to test other subsets of features to find a more accurate model, and then carry out the same process to evaluate models based on regular 5-on-5 play and compare the accuracy of the two models.

For now, the results suggest that player evaluation metrics during power plays can be used to predict regular season success.

In [8]:
df_train = pd.read_csv('Train.csv', delimiter=',', quotechar='"')
df_test = pd.read_csv('Test.csv', delimiter=',', quotechar='"')
df_sample = pd.read_csv('Sample_submission.csv', delimiter=',', quotechar='"')

In [9]:
df_train.head(12)

Unnamed: 0,review,label
0,mature intelligent and highly charged melodram...,pos
1,http://video.google.com/videoplay?docid=211772...,pos
2,Title: Opera (1987) Director: Dario Argento Ca...,pos
3,I think a lot of people just wrote this off as...,pos
4,This is a story of two dogs and a cat looking ...,pos
5,Steve Carell comes into his own in his first s...,pos
6,I'm only going to write more because it's requ...,neg
7,"OK, it was a ""risky"" move to rent this flick, ...",neg
8,"Cannibalism, a pair of cinematic references to...",pos
9,This is one of the great modern kung fu films....,pos


In [10]:
df_sample.tail(3)

Unnamed: 0,Id,label
9997,9997,pos
9998,9998,pos
9999,9999,pos


In [11]:
df_test.head(3)

Unnamed: 0,review
0,Remember those old kung fu movies we used to w...
1,This movie is another one on my List of Movies...
2,How in the world does a thing like this get in...


In [12]:
train_reviews = df_train.review
test_reviews = df_test.review
labels = df_train.label

In [13]:
train_reviews[0]

"mature intelligent and highly charged melodrama unbelivebly filmed in China in 1948. wei wei's stunning performance as the catylast in a love triangle is simply stunning if you have the oppurunity to see this magnificent film take it"

In [14]:
!pip install nltk



In [15]:
import nltk
from nltk.tokenize import RegexpTokenizer
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords

# Ensure NLTK resources are available
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [16]:
def clean_view(text):
    """Lowercase a review, keep alphabetic tokens, drop stopwords, and stem."""
    ps = PorterStemmer()
    tokenizer = RegexpTokenizer('[a-zA-Z]+')  # keep runs of letters only
    stopword = set(stopwords.words('english'))
    text = text.lower()
    tokens = tokenizer.tokenize(text)
    # Remove stopwords, then stem the remaining tokens
    new_token = [ps.stem(token) for token in tokens if token not in stopword]
    return ' '.join(new_token)

In [17]:
clean_train = [clean_view(each) for each in train_reviews]

In [18]:
clean_test = [clean_view(each) for each in test_reviews]

In [19]:
from sklearn.feature_extraction.text import TfidfVectorizer
tf = TfidfVectorizer(ngram_range=(2, 2))
tf.fit(clean_train)

In [20]:
x_train = tf.transform(clean_train)
x_test = tf.transform(clean_test)

In [21]:
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y = le.fit_transform(labels)

### Naive Bayes for textual data

In [22]:
from sklearn.naive_bayes import MultinomialNB

model = MultinomialNB()
model.fit(x_train, y)

In [24]:
model.score(x_train, y)

0.996125

In [25]:
pred = model.predict(x_test)

### Evaluation metrics

- **Accuracy**: Measures the correct prediction ratio. Effective for balanced classes, less so for imbalanced ones.
- **Precision**: Proportion of true positives in positive predictions. Key when false positives have high costs.
- **Recall (Sensitivity)**: Proportion of actual positives correctly identified. Critical when false negatives carry significant risks.
- **F1 Score**: Harmonic mean of precision and recall. Ideal for balancing these metrics, particularly in imbalanced datasets.
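
The four definitions above can be written out directly in terms of confusion-matrix counts. The sketch below is for the binary case only; the function name `binary_metrics` and the counts are hypothetical and chosen just to make the arithmetic easy to follow.

```python
def binary_metrics(tp, fp, fn, tn):
    """Compute accuracy, precision, recall, and F1 from confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)   # share of all predictions that are correct
    precision = tp / (tp + fp)                   # share of positive predictions that are correct
    recall = tp / (tp + fn)                      # share of actual positives that are found
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two
    return accuracy, precision, recall, f1


# Hypothetical counts: 40 true positives, 10 false positives,
# 20 false negatives, 30 true negatives.
acc, prec, rec, f1 = binary_metrics(tp=40, fp=10, fn=20, tn=30)
print(acc, prec, rec, f1)
```

In this example precision (0.8) is higher than recall (about 0.67), and the F1 score falls between the two, closer to the lower value.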

In [None]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score

# sklearn metrics expect the true labels first, then the predictions
accuracy = accuracy_score(y_test, y_pred)

f1 = f1_score(y_test, y_pred, average="weighted")

precision = precision_score(y_test, y_pred, average="weighted")

recall = recall_score(y_test, y_pred, average="weighted")

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)

In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

cm=confusion_matrix(y_test, y_pred)
disp=ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Non-playoff team","Playoff team"])
disp.plot()

### Overfitting and underfitting

Overfitting and underfitting are the two biggest causes of poor performance in machine learning algorithms.

Overfitting refers to a model that fits the training data too well. It happens when a model learns the detail and noise in the training data to the extent that it negatively impacts the model's performance on new data. In other words, the noise or random fluctuations in the training data are picked up and learned as concepts by the model. The problem is that these concepts do not apply to new data, so they harm the model's ability to generalize.

Underfitting refers to a model that can neither fit the training data nor generalize to new data. An underfit model is not suitable, and it is easy to spot because it performs poorly even on the training data.

Naive Bayes classifiers tend to be less prone to overfitting than more flexible models, especially when dealing with high-dimensional data. If the model underfits, improving the feature set, for example by using a different text representation such as TF-IDF for text classification, can help it capture more information.
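
One simple way to check for either problem is to compare the score on the training data with the score on held-out data. The sketch below does this on synthetic two-class data; the dataset, the 75-25 split, and all variable names are assumptions made for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

rng = np.random.default_rng(0)

# Hypothetical two-class data: class 1 is shifted by 2 in every feature.
X_demo = np.vstack([
    rng.normal(0, 1, size=(200, 3)),
    rng.normal(2, 1, size=(200, 3)),
])
y_demo = np.array([0] * 200 + [1] * 200)

X_tr, X_val, y_tr, y_val = train_test_split(
    X_demo, y_demo, test_size=0.25, random_state=0, stratify=y_demo
)

clf = GaussianNB().fit(X_tr, y_tr)
train_score = clf.score(X_tr, y_tr)
val_score = clf.score(X_val, y_val)

# A large positive gap suggests overfitting;
# low scores on both sets suggest underfitting.
gap = train_score - val_score
print(train_score, val_score, gap)
```

A training score alone cannot distinguish a good model from an overfit one, which is why the held-out score is needed for the comparison.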


### Model’s performance (need revision).

- **Accuracy (0.504)**: Only marginally better than random guessing for a two-class problem.
- **Precision (0.5039)**: About half of the positive predictions are correct, so a positive prediction is not reliable.
- **Recall (0.504)**: The model identifies 50.4% of actual positives, so it misses about half of the true cases.
- **F1 Score (0.5042)**: Combines precision and recall; a value near 0.5 confirms that the model performs close to chance on both.
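
To judge whether an accuracy of 0.504 is meaningful, it helps to compare it against a trivial baseline that always predicts the most frequent label. The sketch below assumes a perfectly balanced label set, which is an assumption since the actual label distribution is not shown; the function name `majority_baseline_accuracy` is hypothetical.

```python
from collections import Counter


def majority_baseline_accuracy(labels):
    """Accuracy obtained by always predicting the most frequent label."""
    counts = Counter(labels)
    most_common_count = counts.most_common(1)[0][1]
    return most_common_count / len(labels)


# Hypothetical balanced labels; the real distribution may differ.
labels_demo = ["pos"] * 50 + ["neg"] * 50
baseline = majority_baseline_accuracy(labels_demo)

model_accuracy = 0.504  # accuracy reported in the list above
improvement = model_accuracy - baseline
print(baseline, improvement)
```

Under this assumption the model improves on the baseline by less than half a percentage point, which is consistent with reading the scores as close to chance.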

In [6]:
import pandas as pd
import numpy as np

np.random.seed(0)
n_samples_per_class = 200

class1 = {
    'no_feedback': np.zeros(n_samples_per_class),
    'var1': np.random.normal(0, 1, n_samples_per_class),
    'var2': np.random.normal(0, 1, n_samples_per_class),
    'var3': np.random.normal(0, 1, n_samples_per_class)
}

class2 = {
    'with_feedback': np.ones(n_samples_per_class),
    'var1': np.random.normal(3, 1, n_samples_per_class),
    'var2': np.random.normal(3, 1, n_samples_per_class),
    'var3': np.random.normal(3, 1, n_samples_per_class)
}

data_class1 = pd.DataFrame(class1)
data_class2 = pd.DataFrame(class2)

# Combine the data from the two classes
data = pd.concat([data_class1, data_class2])

# Print the first few rows to inspect the result
print(data.head())

# Save the data as a CSV file
data.to_csv('cluster_data.csv', index=False)


# Save the DataFrame as a CSV file
df.to_csv('/Users/wan/Desktop/Assignment/standard/team_pp_data_clean.csv', index=False)




In [8]:
import pandas as pd
df=pd.read_csv("/Users/wan/Desktop/Assignment/standard/team_pp_data_clean.csv")
print(df.shape)

(1000, 4)
