In [None]:
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam, RMSprop
import numpy as np
import pandas as pd

# Load the dataset
train_data = pd.read_csv('train_data.csv')
test_data = pd.read_csv('test_data.csv')

# Data preparation (Assuming the last column is the target variable)
X = train_data.iloc[:, :-1]
y = train_data.iloc[:, -1]
X_test = test_data.iloc[:, :-1]
y_test = test_data.iloc[:, -1]


# Split train data for validation
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Standardize the data
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)


In [31]:
# Define the model creation function
def create_model():
    model = Sequential([
        Dense(128, activation='relu', input_shape=(X_train.shape[1],)),
        BatchNormalization(),
        Dropout(0.3),
        Dense(64, activation='relu'),
        BatchNormalization(),
        Dropout(0.3),
        Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer=Adam(), loss='binary_crossentropy', metrics=['accuracy'])
    return model

# KFold Cross Validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)
fold_accuracies = []

for train_index, val_index in kf.split(np.arange(len(X_train))):
    # Split data
    X_train_fold, X_val_fold = X_train[train_index], X_train[val_index]
    y_train_fold, y_val_fold = y_train[train_index], y_train[val_index]

    # Create and train the model
    model = create_model()
    model.fit(X_train_fold, y_train_fold, epochs=30, batch_size=32, verbose=0)

    # Evaluate on the validation fold
    y_val_pred = (model.predict(X_val_fold) > 0.5).astype(int).flatten()
    fold_accuracy = accuracy_score(y_val_fold, y_val_pred)
    fold_accuracies.append(fold_accuracy)

# Average accuracy across folds
avg_accuracy = np.mean(fold_accuracies)
print(f"Average Cross-Validation Accuracy: {avg_accuracy}")


KeyError: '[5, 6, 9, 27, 32, 40, 54, 97, 113, 114, 127, 135, 138, 139, 147, 158, 161, 178, 191, 197, 226, 227, 229, 237, 243, 253, 255, 258, 259, 260, 270, 280, 286, 289, 293, 299, 301, 302, 313, 322, 324, 331, 338, 341, 358, 370, 377, 380, 382, 383, 385, 395, 406, 407, 430, 436, 444, 449, 452, 456, 464, 467, 477, 478, 484, 489, 490, 492, 493, 494, 502, 505, 514, 520, 523, 526, 548, 549, 552, 553, 564, 566, 577, 595, 596, 597, 606, 608, 621, 629, 636, 644, 653, 656, 658, 660, 662, 671, 684, 685, 688, 701, 702, 716, 721, 738, 745, 749, 754, 757, 760, 771, 774, 777, 778, 784, 798, 803, 822, 825, 826, 831, 839, 845, 854, 869, 875, 888, 905, 909, 922, 939, 951, 952, 956, 961, 966, 971, 974, 978, 979, 985, 987, 988, 1003, 1004, 1013, 1014, 1031, 1036, 1051, 1054, 1063, 1064, 1067, 1069, 1071, 1091, 1106, 1107, 1109, 1118, 1154, 1156, 1159, 1161, 1162, 1165, 1173, 1186, 1200, 1205, 1231, 1237, 1243, 1246, 1251, 1256, 1260, 1268, 1278, 1279, 1292, 1300, 1319, 1324, 1329, 1333, 1335, 1337, 1338, 1355, 1364, 1371, 1378, 1380, 1387, 1389, 1396, 1406, 1411, 1415, 1422, 1423, 1442, 1443, 1453, 1459, 1464, 1469, 1470, 1474, 1484, 1486, 1498, 1504, 1516, 1524, 1532, 1541, 1545, 1552, 1563, 1564, 1565, 1570, 1573, 1574, 1582, 1584, 1596, 1598, 1605, 1613, 1620, 1633, 1634, 1652, 1657, 1658, 1661, 1673, 1675, 1680, 1683, 1688, 1689, 1694, 1696, 1704, 1708, 1730, 1734, 1735, 1736, 1738, 1742, 1748, 1751, 1752, 1761, 1764, 1771, 1777, 1778, 1780, 1784, 1786, 1795, 1796, 1797, 1801, 1810, 1812, 1818, 1822, 1824, 1826, 1827, 1828, 1833, 1838, 1851, 1855, 1856, 1866, 1871, 1872, 1874, 1875, 1879, 1885, 1902, 1906, 1908, 1920, 1921, 1925, 1932, 1935, 1936, 1939, 1944, 1951, 1958, 1969, 1976, 1981, 1982, 1999, 2004, 2007, 2010, 2011, 2013, 2016, 2021, 2031, 2033, 2035, 2041, 2051, 2054, 2062, 2073, 2075, 2078, 2084, 2103, 2104, 2134, 2135, 2137, 2140, 2141, 2144, 2147, 2156, 2162, 2170, 2171, 2174, 2179, 2180, 2181, 2200, 2241, 2247, 2254, 2257, 2264, 2278, 2281, 2282, 2296, 2297, 2307, 2310, 2311, 2323, 2327, 2328, 2330, 2331, 2335, 2344, 2359, 2370, 2371, 2375, 2377, 2379, 2392, 2395, 2403, 2421, 2422, 2425, 2426, 2428, 2431, 2444, 2445, 2466, 2474, 2476, 2480, 2490, 2493, 2497, 2500, 2503, 2506, 2520, 2534, 2535, 2537, 2541, 2545, 2550, 2551, 2552, 2553, 2566, 2576, 2583, 2588, 2595, 2601, 2607, 2610, 2613, 2615, 2617, 2624, 2630, 2635, 2636, 2638, 2645, 2662, 2674, 2682, 2705, 2709, 2711, 2712, 2732, 2737, 2738, 2744, 2747, 2749, 2757, 2759, 2766, 2769, 2770, 2787, 2788, 2799, 2810, 2819, 2832, 2838, 2845, 2849, 2866, 2872, 2895, 2896, 2898, 2914, 2917, 2919, 2920, 2923, 2938, 2942, 2946, 2949, 2957, 2961, 2967, 2974, 2975, 2978, 2996, 3014, 3041, 3042, 3043, 3051, 3054, 3067, 3086, 3093, 3104, 3108, 3111, 3114, 3117, 3120, 3125, 3132, 3133, 3138, 3143, 3157, 3164, 3169, 3173, 3175, 3177, 3195, 3198, 3203, 3205, 3208, 3210, 3218, 3222, 3225, 3226, 3227, 3242, 3243, 3249, 3255, 3278, 3281, 3302, 3308, 3310, 3311, 3313, 3317, 3320, 3324, 3325, 3327, 3332, 3342, 3348, 3351, 3356, 3364, 3366, 3375, 3387, 3393, 3397, 3400, 3404, 3413, 3415, 3416, 3437, 3446, 3448, 3454, 3459, 3472, 3474, 3483, 3503, 3505, 3507, 3526, 3529, 3533, 3542, 3550, 3560, 3564, 3570, 3571, 3574, 3576, 3585, 3586, 3591, 3601, 3602, 3606, 3608, 3617, 3618, 3621, 3623, 3625, 3627, 3629, 3636, 3638, 3645, 3648, 3650, 3663, 3664, 3667, 3669, 3679, 3681, 3683, 3684, 3692, 3697, 3713, 3715, 3725, 3727, 3738, 3744, 3756, 3761, 3765, 3773, 3776, 3781, 3791, 3796, 3801, 3802, 3803, 3804, 3806, 3825, 3826, 3828, 3829, 3833, 3848, 3849, 3850, 3854, 3865, 3870, 3872, 3893, 3894, 3905, 3912, 3914, 3919, 3923, 3928, 3934, 3940, 3943, 3950, 3966, 3973, 3981, 3987, 3990, 3998, 4006, 4009, 4011, 4014, 4018, 4022, 4026, 4027, 4029, 4040, 4043, 4045, 4049, 4050, 4051, 4053, 4059, 4068, 4083, 4085, 4086, 4106, 4114, 4131, 4134, 4142, 4153, 4156, 4159, 4169, 4173, 4174, 4177, 4181, 4188, 4192, 4193, 4199, 4202, 4203, 4206, 4209, 4215, 4225, 4226, 4227, 4232, 4235, 4237, 4240, 4245, 4260, 4265, 4278, 4279, 4289, 4305, 4313, 4316, 4321, 4335, 4337, 4345, 4358, 4365, 4382, 4386, 4388, 4393, 4406, 4409, 4411, 4415, 4422, 4423, 4431, 4450, 4457, 4467, 4472, 4476, 4483, 4496, 4499, 4531, 4535, 4540, 4541, 4549, 4565, 4570, 4574, 4577, 4589, 4590, 4592, 4599, 4607, 4652, 4655, 4660, 4661, 4663, 4666, 4676, 4678, 4685, 4695, 4698, 4716, 4718, 4727, 4728, 4729, 4743, 4750, 4758, 4761, 4784, 4785, 4787, 4796, 4802, 4805, 4809, 4814, 4818, 4830, 4831, 4833, 4841, 4849, 4859, 4860, 4869, 4882, 4890, 4892, 4909, 4910, 4927, 4928, 4929, 4941, 4947, 4952, 4958, 4963, 4986, 4994, 5005, 5031, 5033, 5034, 5045, 5051, 5061, 5069, 5071, 5072, 5077, 5081, 5085, 5088, 5096, 5105, 5115, 5122, 5128, 5134, 5145, 5172, 5175, 5179, 5182, 5188, 5191, 5192, 5198, 5201, 5202, 5205, 5210, 5211, 5227, 5233, 5240, 5241, 5277, 5289, 5312, 5317, 5318, 5323, 5324, 5325, 5326, 5328, 5330, 5339, 5341, 5345, 5348, 5350, 5351, 5361, 5365, 5369, 5372, 5373, 5379, 5391, 5392, 5396, 5397, 5402, 5404, 5417, 5418, 5426, 5427, 5428, 5433, 5435, 5436, 5440, 5442, 5453, 5455, 5457, 5459, 5462, 5467, 5477, 5478, 5504, 5506, 5507, 5514, 5515, 5517, 5520, 5523, 5526, 5529, 5530, 5535, 5539, 5546, 5548, 5565, 5582, 5586, 5588, 5591, 5593, 5601, 5617, 5624, 5625, 5631, 5637, 5638, 5640, 5647, 5654, 5657, 5659, 5668, 5675, 5688, 5702, 5709, 5710, 5716, 5717, 5718, 5721, 5725, 5746, 5749, 5752, 5766, 5769, 5772, 5781, 5790, 5794, 5795, 5800, 5805, 5806, 5809, 5815, 5818, 5822, 5824, 5825, 5827, 5836, 5842, 5844, 5851, 5860, 5867, 5871, 5872, 5879, 5883, 5888, 5889, 5892, 5893, 5897, 5899, 5901, 5902, 5909, 5914, 5915, 5917, 5918, 5927, 5943, 5944, 5970, 5972, 5977, 5978, 5986, 5989, 6004, 6005, 6011, 6021, 6026, 6027, 6039, 6047, 6052, 6059, 6066, 6068, 6069, 6070, 6071, 6078, 6081, 6090, 6092, 6103, 6113, 6116, 6120, 6128, 6130, 6138, 6152, 6157, 6182, 6191, 6192, 6195, 6197, 6201, 6220, 6224, 6229, 6233, 6237, 6242, 6245, 6252, 6258, 6259, 6265, 6266, 6288, 6294, 6295, 6299, 6318, 6327, 6328, 6334, 6343, 6349, 6352, 6369, 6375, 6378, 6379, 6384, 6389, 6393, 6394, 6400, 6409, 6411, 6412, 6415, 6417, 6422, 6424, 6426, 6435, 6446, 6450, 6451, 6456, 6458, 6464, 6474, 6482, 6487, 6492, 6493, 6507, 6509, 6510, 6529, 6530, 6533, 6539, 6541, 6543, 6547, 6549, 6554, 6558, 6559, 6566, 6569, 6574, 6580, 6581, 6589, 6593, 6601, 6609, 6619, 6620, 6623, 6625, 6635, 6640, 6643, 6661, 6662, 6667, 6676, 6680, 6700, 6711, 6713, 6714, 6717, 6723, 6728, 6734, 6738, 6741, 6750, 6775, 6779, 6780, 6784, 6785, 6798, 6801, 6818, 6823, 6825, 6826, 6834, 6842, 6846, 6851, 6852, 6855, 6859, 6861, 6865, 6866, 6871, 6872, 6884, 6886, 6890, 6892, 6900, 6913, 6918, 6924, 6936, 6949, 6958, 6968, 6974, 6979, 6980, 6995, 7003, 7005, 7008, 7010, 7012, 7014, 7023, 7025, 7029, 7035, 7036, 7043, 7045, 7046, 7052, 7056, 7074, 7092, 7094, 7100, 7126, 7131, 7138, 7140, 7143, 7146, 7147, 7149, 7150, 7153, 7169, 7172, 7174, 7175, 7179, 7182, 7186, 7197, 7198, 7204, 7232, 7236, 7246, 7251, 7260, 7264, 7265, 7266, 7268, 7275, 7279, 7286, 7294, 7299, 7323, 7331, 7333, 7344, 7357, 7371, 7384, 7405, 7418, 7420, 7427, 7441, 7448, 7449, 7453, 7463, 7465, 7468, 7470, 7472, 7485, 7487, 7491, 7494, 7495, 7496, 7503, 7505, 7508, 7513, 7514, 7519, 7521, 7525, 7554, 7561, 7571, 7574, 7575, 7577, 7578, 7590, 7609, 7611, 7612, 7618, 7625, 7634, 7645, 7652, 7657, 7659, 7665, 7669, 7671, 7672, 7676, 7678, 7696, 7700, 7717, 7728, 7731, 7744, 7745, 7747, 7779, 7797, 7806, 7814, 7822, 7824, 7828, 7853, 7855, 7865, 7875, 7881, 7882, 7898, 7903, 7911, 7921, 7930, 7933, 7941, 7943, 7947, 7952, 7955, 7966, 7968, 7969, 7973, 7982, 7991, 7992, 7997, 7999, 8000, 8002, 8006, 8012, 8013, 8017, 8021, 8027, 8035, 8043, 8046, 8056, 8073, 8080, 8086, 8093, 8099, 8112, 8120, 8124, 8125, 8134, 8136, 8143, 8153, 8161, 8164, 8168, 8169, 8171, 8184, 8197, 8199, 8200, 8206, 8213, 8214, 8215, 8220, 8221, 8223, 8230, 8238, 8240, 8250, 8254, 8273, 8277, 8285, 8295, 8301, 8302, 8310, 8317, 8318, 8329, 8357, 8361, 8363, 8373, 8379, 8394, 8400, 8402, 8407, 8409, 8424, 8439, 8441, 8443] not in index'