In [1]:
import os

import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Dense
from sklearn.model_selection import train_test_split
import wandb
from wandb.keras import WandbCallback

from google_Drive import download_from_url

In [None]:
# Start a Weights & Biases run in the "AI-Snake" project; later cells store the
# learning rate in wandb.config and log training metrics via WandbCallback.
wandb.init(project="AI-Snake")

### Download the dataset from Google Drive

In [None]:
# Fetch the recorded feature CSV once. Later runs of the notebook reuse the local
# copy instead of downloading it again on every Restart & Run All.
# NOTE(review): a partial/corrupt file left by a failed download is also reused;
# delete it manually to force a fresh download.
DATASET_URL = "https://drive.google.com/file/d/1ErVU2ncSEvqiOtL1KHx-Y_M41UmJ9A27/view?usp=sharing"
DATASET_PATH = "data/features_data.csv"

if not os.path.exists(DATASET_PATH):
    # Make sure the target directory exists before writing into it
    # (NOTE(review): download_from_url may already do this — harmless if so).
    os.makedirs(os.path.dirname(DATASET_PATH), exist_ok=True)
    download_from_url(url=DATASET_URL, output_path=DATASET_PATH)

In [3]:
# Load the recorded features: snake and apple coordinates, their x/y distance,
# and the direction label in the last column.
data = pd.read_csv("data/features_data.csv")
# If the CSV was written together with its pandas index, read_csv exposes that
# index as an "Unnamed: 0" column. The X/Y split below would then pick it up as
# an extra feature, although the model declares input_dim=6. Drop any such
# auto-named columns; this is a no-op when the file has none.
data = data.loc[:, ~data.columns.str.startswith("Unnamed")]
data

Unnamed: 0,x_snake,y_snake,x_apple,y_apple,x_S-A_distance,y_S-A_distance,direction
0,401,299,443,259,-42,40,3
1,402,298,443,259,-41,39,3
2,403,297,443,259,-40,38,3
3,404,296,443,259,-39,37,3
4,405,295,443,259,-38,36,3
...,...,...,...,...,...,...,...
12700,762,358,788,358,-26,0,2
12701,763,358,788,358,-25,0,2
12702,764,358,788,358,-24,0,2
12703,765,358,788,358,-23,0,2


In [4]:
# Every column except the last is an input feature; the last column holds the
# direction label the model learns to predict.
_feature_frame = data.iloc[:, :-1]
_label_series = data.iloc[:, -1]
X, Y = _feature_frame.values, _label_series.values  # features, directions

In [5]:
# Turn the label vector into an explicit column of shape (n_samples, 1).
Y = Y.reshape((-1, 1))

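### Split into train, validation, and test sets

20% of the rows are held out for testing; 20% of the remaining rows are then used for validation (≈ 64% / 16% / 20% train / validation / test overall).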

In [6]:
# Hold out 20% of the shuffled rows as the final test set (fixed seed for reproducibility).
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=0.2,
    shuffle=True,
    random_state=24,
)

In [7]:
# Carve a validation set out of the training split: 20% of the remaining rows,
# i.e. roughly 64% / 16% / 20% train / validation / test overall.
# This cell overwrites X_train / Y_train with a subset of themselves, so running
# it a second time would silently re-split (and shrink) the training data.
# Guard against that: before this split, train + test must still cover every row of X.
assert len(X_train) + len(X_test) == len(X), (
    "X_train was already split into train/validation; re-run the train/test split cell first"
)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, shuffle=True, test_size=0.2, random_state=24) # Create Validation Data

### Model

In [8]:
# Seed TensorFlow's global RNG so weight initialisation (and TF random ops run
# afterwards) is reproducible across Restart & Run All, matching the fixed
# random_state=24 used for the data splits.
tf.random.set_seed(24)

# Fully connected classifier: 6 input features -> four ReLU hidden layers
# (64 -> 128 -> 256 -> 512 units) -> softmax over 8 output units.
# NOTE(review): the sample output only shows direction labels 2 and 3; 8 units
# assumes labels fall in 0..7 — confirm against the game's direction encoding.
model = tf.keras.models.Sequential([
    Dense(64, input_dim=6, activation="relu"),
    Dense(128, activation="relu"),
    Dense(256, activation="relu"),
    Dense(512, activation="relu"),
    Dense(8, activation="softmax")
])

In [9]:
# Store the optimiser learning rate in the W&B run config; the compile cell reads it back from here.
LEARNING_RATE = 0.001
config = wandb.config
config.learning_rate = LEARNING_RATE

In [10]:
# Adam at the configured learning rate. The labels are integer direction codes,
# hence the sparse variant of categorical cross-entropy.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)
model.compile(
    optimizer=adam_optimizer,
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=["accuracy"],
)

In [None]:
# Train for a fixed number of epochs, validating on the held-out validation split
# and streaming metrics to W&B. Keep the returned History so per-epoch
# loss/accuracy can be inspected later (history.history) without retraining;
# assigning it also keeps its bare repr out of the cell output.
EPOCHS = 30
history = model.fit(
    X_train,
    Y_train,
    validation_data=(X_val, Y_val),
    epochs=EPOCHS,
    callbacks=[WandbCallback()],
)

In [15]:
# Final score on the untouched test split: displays [loss, accuracy].
model.evaluate(x=X_test, y=Y_test)



[0.11398351192474365, 0.9704840779304504]

### Save the model

In [16]:
# Persist the trained network; the .h5 extension selects Keras' HDF5 format.
MODEL_PATH = "Model/Snake_AI.h5"
model.save(MODEL_PATH)