In [36]:
# Download dataset from Kaggle using KaggleHub
import kagglehub

# Download latest version
path = kagglehub.dataset_download("pauloarayasantiago/pokmon-stats-across-generations-and-typings")

print("Path to dataset files:", path)

Path to dataset files: /Users/tongluyangyu/.cache/kagglehub/datasets/pauloarayasantiago/pokmon-stats-across-generations-and-typings/versions/1


In [37]:
# Read dataset
import pandas as pd
dataset = pd.read_csv(path + "/pokemon_dataset.csv")
print(dataset.head())

   pokemon_id        name primary_type secondary_type first_appreance  \
0           1   bulbasaur        grass         poison        red/blue   
1           2     ivysaur        grass         poison        red/blue   
2           3    venusaur        grass         poison        red/blue   
3           4  charmander         fire            NaN        red/blue   
4           5  charmeleon         fire            NaN        red/blue   

  generation category  total_base_stats  hp  attack  defense  special_attack  \
0      gen 1  regular               318  45      49       49              65   
1      gen 1  regular               405  60      62       63              80   
2      gen 1  regular               525  80      82       83             100   
3      gen 1  regular               309  39      52       43              60   
4      gen 1  regular               405  58      64       58              80   

   special_defense  speed  
0               65     45  
1               80     6

In [38]:
dataset.columns


Index(['pokemon_id', 'name', 'primary_type', 'secondary_type',
       'first_appreance', 'generation', 'category', 'total_base_stats', 'hp',
       'attack', 'defense', 'special_attack', 'special_defense', 'speed'],
      dtype='object')

In [39]:
# See the pokemon types
primary_types = dataset['primary_type'].unique()
primary_types

array(['grass', 'fire', 'water', 'bug', 'normal', 'poison', 'electric',
       'ground', 'fairy', 'fighting', 'psychic', 'rock', 'ghost', 'ice',
       'dragon', 'dark', 'steel', 'flying'], dtype=object)

In [40]:
# Get all rows where the pokemons without secondary type
nonSecondary = dataset[dataset['secondary_type'].isna()]
nonSecondary

Unnamed: 0,pokemon_id,name,primary_type,secondary_type,first_appreance,generation,category,total_base_stats,hp,attack,defense,special_attack,special_defense,speed
3,4,charmander,fire,,red/blue,gen 1,regular,309,39,52,43,60,50,65
4,5,charmeleon,fire,,red/blue,gen 1,regular,405,58,64,58,80,65,80
6,7,squirtle,water,,red/blue,gen 1,regular,314,44,48,65,50,64,43
7,8,wartortle,water,,red/blue,gen 1,regular,405,59,63,80,65,80,58
8,9,blastoise,water,,red/blue,gen 1,regular,530,79,83,100,85,105,78
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
976,977,dondozo,water,,scarlet/violet,gen 9,regular,530,150,100,115,65,65,35
981,982,dudunsparce,normal,,scarlet/violet,gen 9,regular,520,125,100,80,85,75,55
998,999,gimmighoul,ghost,,scarlet/violet,gen 9,regular,300,45,30,70,75,70,10
1016,1017,ogerpon,grass,,scarlet/violet,gen 9,legendary,550,80,120,84,60,96,110


In [41]:
# Get all rows where the pokemons with secondary type
Secondary = dataset[dataset['secondary_type'].notna()]
Secondary

Unnamed: 0,pokemon_id,name,primary_type,secondary_type,first_appreance,generation,category,total_base_stats,hp,attack,defense,special_attack,special_defense,speed
0,1,bulbasaur,grass,poison,red/blue,gen 1,regular,318,45,49,49,65,65,45
1,2,ivysaur,grass,poison,red/blue,gen 1,regular,405,60,62,63,80,80,60
2,3,venusaur,grass,poison,red/blue,gen 1,regular,525,80,82,83,100,100,80
5,6,charizard,fire,flying,red/blue,gen 1,regular,534,78,84,78,109,85,100
11,12,butterfree,bug,flying,red/blue,gen 1,regular,395,60,45,50,90,80,70
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1019,1020,gouging-fire,fire,dragon,scarlet/violet,gen 9,paradox,590,105,115,121,65,93,91
1020,1021,raging-bolt,electric,dragon,scarlet/violet,gen 9,paradox,590,125,73,91,137,89,75
1021,1022,iron-boulder,rock,psychic,scarlet/violet,gen 9,regular,590,90,120,80,68,108,124
1022,1023,iron-crown,steel,psychic,scarlet/violet,gen 9,paradox,590,90,72,100,122,108,98


In [42]:
# Which pokemon has the highest attack?
max_attack = dataset['attack'].max()
max_attack_pokemon = dataset[dataset['attack'] == max_attack]
max_attack_pokemon

Unnamed: 0,pokemon_id,name,primary_type,secondary_type,first_appreance,generation,category,total_base_stats,hp,attack,defense,special_attack,special_defense,speed
797,798,kartana,grass,steel,sun/moon,gen 7,ultra beast,570,59,181,131,59,31,109


### Task 1 - Try to classify the pokemon type based on the total base stats

In [43]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import classification_report, accuracy_score

# Preprocess the data
X = nonSecondary[['total_base_stats']]
y = nonSecondary['primary_type']

# Encode the labels
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y_encoded, test_size=0.3, random_state=42)

# Standardize the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train the MLP model
mlp = MLPClassifier(hidden_layer_sizes=(10, 10), max_iter=1000, random_state=42)
mlp.fit(X_train_scaled, y_train)

# Predict and evaluate the model
y_pred = mlp.predict(X_test_scaled)
print("Accuracy:", accuracy_score(y_test, y_pred))
print(classification_report(y_test, y_pred, target_names=label_encoder.classes_[:len(set(y_test))]))


Accuracy: 0.12666666666666668
              precision    recall  f1-score   support

         bug       0.14      0.17      0.15         6
        dark       0.00      0.00      0.00         4
      dragon       0.00      0.00      0.00        11
    electric       0.00      0.00      0.00         7
       fairy       0.00      0.00      0.00        13
    fighting       0.00      0.00      0.00        13
        fire       0.00      0.00      0.00         1
      flying       0.00      0.00      0.00         4
       ghost       0.00      0.00      0.00        12
       grass       0.00      0.00      0.00         2
      ground       0.00      0.00      0.00         5
         ice       0.11      0.21      0.14        29
      normal       0.00      0.00      0.00         7
      poison       0.00      0.00      0.00        11
     psychic       0.00      0.00      0.00         2
        rock       0.00      0.00      0.00         5
       steel       0.14      0.67      0.23        

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
