In [109]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math
import random
import sklearn

In [3]:
sns.set_theme()

## Load and preprocess data

In [16]:
train_data = pd.read_csv('train.csv')
train_data.head()

Unnamed: 0,Id,MSSubClass,MSZoning,LotFrontage,LotArea,Street,Alley,LotShape,LandContour,Utilities,...,PoolArea,PoolQC,Fence,MiscFeature,MiscVal,MoSold,YrSold,SaleType,SaleCondition,SalePrice
0,1,60,RL,65.0,8450,Pave,,Reg,Lvl,AllPub,...,0,,,,0,2,2008,WD,Normal,208500
1,2,20,RL,80.0,9600,Pave,,Reg,Lvl,AllPub,...,0,,,,0,5,2007,WD,Normal,181500
2,3,60,RL,68.0,11250,Pave,,IR1,Lvl,AllPub,...,0,,,,0,9,2008,WD,Normal,223500
3,4,70,RL,60.0,9550,Pave,,IR1,Lvl,AllPub,...,0,,,,0,2,2006,WD,Abnorml,140000
4,5,60,RL,84.0,14260,Pave,,IR1,Lvl,AllPub,...,0,,,,0,12,2008,WD,Normal,250000


In [20]:
print(train_data.columns)

Index(['Id', 'MSSubClass', 'MSZoning', 'LotFrontage', 'LotArea', 'Street',
       'Alley', 'LotShape', 'LandContour', 'Utilities', 'LotConfig',
       'LandSlope', 'Neighborhood', 'Condition1', 'Condition2', 'BldgType',
       'HouseStyle', 'OverallQual', 'OverallCond', 'YearBuilt', 'YearRemodAdd',
       'RoofStyle', 'RoofMatl', 'Exterior1st', 'Exterior2nd', 'MasVnrType',
       'MasVnrArea', 'ExterQual', 'ExterCond', 'Foundation', 'BsmtQual',
       'BsmtCond', 'BsmtExposure', 'BsmtFinType1', 'BsmtFinSF1',
       'BsmtFinType2', 'BsmtFinSF2', 'BsmtUnfSF', 'TotalBsmtSF', 'Heating',
       'HeatingQC', 'CentralAir', 'Electrical', '1stFlrSF', '2ndFlrSF',
       'LowQualFinSF', 'GrLivArea', 'BsmtFullBath', 'BsmtHalfBath', 'FullBath',
       'HalfBath', 'BedroomAbvGr', 'KitchenAbvGr', 'KitchenQual',
       'TotRmsAbvGrd', 'Functional', 'Fireplaces', 'FireplaceQu', 'GarageType',
       'GarageYrBlt', 'GarageFinish', 'GarageCars', 'GarageArea', 'GarageQual',
       'GarageCond', 'PavedDrive

In [17]:
train_data.shape

(1460, 81)

In [60]:
train_data = pd.read_csv('train.csv')

In [94]:
def preprocess(df : pd.DataFrame) -> pd.DataFrame:
    '''
    Process raw dataframe into one ready for a tensorflow tensor
    '''
    df.drop('Id', inplace=True, axis=1)
    static_cols = df.columns
    for c in static_cols:
        if pd.api.types.is_numeric_dtype(df[c]):
            fill_value = df[c].mean()
            df[c] = df[c].fillna(fill_value)
        elif pd.api.types.is_string_dtype(df[c]):
            one_hot = pd.get_dummies(df[c], prefix=c)
            df.drop(c, inplace=True, axis=1)
            df = df.join(one_hot)

        else:
            raise NotImplementedError('Failed to find contained value of this column')
    return df

In [95]:
train_data = pd.read_csv('train.csv')
train_data = preprocess(train_data)

In [96]:
train_data.head()

Unnamed: 0,MSSubClass,LotFrontage,LotArea,OverallQual,OverallCond,YearBuilt,YearRemodAdd,MasVnrArea,BsmtFinSF1,BsmtFinSF2,...,SaleType_ConLw,SaleType_New,SaleType_Oth,SaleType_WD,SaleCondition_Abnorml,SaleCondition_AdjLand,SaleCondition_Alloca,SaleCondition_Family,SaleCondition_Normal,SaleCondition_Partial
0,60,65.0,8450,7,5,2003,2003,196.0,706,0,...,0,0,0,1,0,0,0,0,1,0
1,20,80.0,9600,6,8,1976,1976,0.0,978,0,...,0,0,0,1,0,0,0,0,1,0
2,60,68.0,11250,7,5,2001,2002,162.0,486,0,...,0,0,0,1,0,0,0,0,1,0
3,70,60.0,9550,7,5,1915,1970,0.0,216,0,...,0,0,0,1,1,0,0,0,0,0
4,60,84.0,14260,8,5,2000,2000,350.0,655,0,...,0,0,0,1,0,0,0,0,1,0


In [97]:
bad_data = train_data[train_data.isna().any(axis=1)]

In [105]:
assert(bad_data.shape[0] == 0) # catch bad preprocess

In [106]:
train_data.shape

(1460, 289)

## Convert to Tensorflow Dataset (and split validation set)

In [110]:

y_df_train = train_data['SalePrice']
x_df_train = train_data.drop('SalePrice', inplace=False, axis=1)

In [112]:
from sklearn.model_selection import train_test_split

X_all_tensor = tf.convert_to_tensor(x_df_train)
Y_all_tensor = tf.convert_to_tensor(y_df_train)



In [113]:
X_train, X_test, y_train, y_test = train_test_split(X_all_tensor,
                                                    Y_all_tensor,
                                                    test_size=0.33,
                                                    random_state=42)

TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got array([ 615,  613, 1303,  486,  561,  308,  461, 1142,  730, 1155, 1203,
        700,  849, 1260,  787,  352,  710,  124,  178,  287, 1407, 1208,
        294,  327, 1456,  841, 1121,  931,  236,   88,  886,  552,  630,
       1352,  665,  900,  290, 1382,  570,  348,  544, 1376,  660, 1111,
        458, 1005,  333,  721, 1289,  678, 1354,  328,  318, 1438,  908,
         12,  171,  260,  778,  818,  759,  138, 1209,  885,  741,  139,
       1276, 1081,  224, 1063,  409,  325,  299, 1242,  705, 1235, 1355,
        903,  829,  503, 1336,    3, 1313,  467,  321,    5, 1239, 1093,
        715,   39,  542,  242,  136,  714,  737, 1301, 1043, 1330, 1103,
        752, 1353,  376,  314,  442,  319,  767,  756,  424,  553, 1035,
        953, 1185,  596,  227, 1106,  771,  933, 1417,   66,  999,  579,
        942,  211, 1148,  745,  536, 1286, 1295,   85, 1314,  974,  976,
        631, 1315, 1383, 1401,  949,  273, 1034,  362,  622,    2,    6,
        311,  668, 1037, 1439,  215, 1119,   27, 1055,  824, 1415, 1198,
       1280,  624,  807,  256,   25,  336,  572,   47,  106,  349,   55,
       1443,  213,  120, 1206, 1388,   72,  743, 1188,   45, 1230,  523,
        826,  867, 1199,  757,  533,  603,  465, 1391, 1244, 1080,  904,
         60,   92,  666,  727,  194,  280,  359,  519, 1455,  923,  182,
        987,  110,   42, 1236,  204,  983,   52,  545,  945,  852, 1413,
       1375,  183,  137,  926,   80, 1364,  105,  541,  706,  480,  165,
        248,  307,  446,  712,  334,   97,  724,  905,  899,  102,  434,
       1007, 1305,  834, 1324,  302, 1229, 1191,  388,  235, 1212,  267,
        305,  117, 1179,  673,  249, 1065, 1052,   71,   94, 1440,   33,
       1167,  448,  445,  716,  844, 1112,  858,  997,   77,   84,  792,
         82,  457,  970,  676,  938, 1050,  436,  644,  780,  862, 1357,
        404, 1270,  941,  944,  104,  609, 1068,  993,  888,   62,  281,
        964,  238,  223,  145, 1362,  373,  118, 1096,  250, 1406,    9,
        449,  708,  255,  593,  389, 1202, 1168,  557, 1298, 1424,  460,
        723, 1346,  144,  501, 1070,  739, 1349,  731,  539,  177, 1408,
        697, 1402, 1022, 1116, 1429, 1453,    7,  820,  667, 1389,  357,
        783,  985,  329, 1444, 1117, 1087,  258, 1144,  372, 1075,  882,
        142,  525, 1107, 1335, 1091, 1387,  440,  531, 1169,  226, 1000,
        228, 1246, 1319,  212,   79,  148,  814, 1067,  869,  671, 1124,
        555,  864,  133, 1079,  823,  408,  982,  507, 1150,  504,  986,
        347,  386,    0,  360,  828, 1379,  475, 1247,   57,  637, 1273,
       1026,  909,  172,  450,  125,  934,  530,  857,  713,   90, 1412,
        181,  875, 1151,  786,  414,  251,   69,  803,  131,  300,  978,
        326, 1426, 1083,  499,  968,  832,  364,  495, 1359,  338,  421,
        164,   28,  516,  193, 1250, 1326, 1325,  169,  167,  652,  173,
        518, 1181, 1149,  912,  491,   73, 1219,  587,  750,  657,  842,
        234,  732,  214,  162,  132, 1098,  222, 1312,  185,   41,  692,
        108,   38,  568,  947, 1233, 1403,  483,  468,  890,   24,  962,
       1100,   68,  884, 1231,  366,  505,  850, 1281,  383,  734,  717,
       1300,  454,  264,   75,  914,  232,  444,  395,  611,  176,   18,
       1277, 1029,  341, 1097,  879,  595,   61, 1227,  272, 1176,  872,
       1090, 1370,  278, 1128,  694,  368, 1345,   36,  735, 1411,  662,
       1285, 1316,  977,  547, 1161,  688,  738,  191,  919,  456,  868,
        393,  675,  760, 1436,  728,  419,  375,  412,   74, 1366, 1141,
        500,  825,  195, 1459,  616,  114,  417,  601,  969,  921,  604,
       1448,   89,  848, 1114, 1395,  753,  498,   11,  396,  284,  690,
       1197, 1367,  689,  851,  159,  895, 1192,  257,  335, 1371, 1039,
        515, 1015,  768, 1292,  761, 1074,  521,   22,  356, 1451,  340,
        431,  473,  217,  911, 1172, 1011, 1420,   93, 1008,  696,  996,
        153, 1307, 1458,  789,  580,  703, 1225,  487,  559,  801,  512,
        633, 1302,  790,  116,  740,  830,  119,  656,  635,  369,  268,
        655,  935,  822,  866,   46, 1287,  991,  876,  773,    4, 1003,
       1268,  263, 1360, 1334,  443, 1002, 1077,  837, 1304,  304, 1211,
        313,  149,  574,   50, 1140,  726,  833, 1115,  470,  971,  653,
        399,  511,  320,   19,  684,   35,  902,  827, 1309,  407,  920,
        537,  758, 1062,  245, 1127,  784,  489,  154,  961,  853,  625,
        685,  569, 1399,   17,  127,  927,  980, 1431,  190, 1139,  992,
       1425, 1218,  606,  180,  301,  496, 1258, 1222, 1204,  517, 1069,
        476,  157,   16, 1072,  546,  658, 1193, 1019, 1223,  959, 1373,
        283,  797,  225,   26,  437, 1419,  229,   37,  749, 1045,  469,
       1245, 1434,  687, 1343,  639, 1261, 1158, 1217, 1328,  800,  160,
        956, 1186,  821, 1237,  877,  913,  152, 1338,  509, 1135, 1251,
        103,  586, 1380, 1143,   53,  151,  403, 1416,  207, 1381, 1122,
          8, 1014, 1308, 1060,  452,  253,  896, 1378,  880, 1377,  623,
        345, 1253,  262,  150,  472,  640, 1446, 1374,  550,  928,  488,
       1321, 1171,  146,  402,  954, 1404,  659,  463,  186, 1147, 1224,
        608,  143,  751,  981,  197,  883,  279,  293, 1329,  400,  122,
       1207,  202,  835,  246, 1386, 1372, 1153,  384, 1288,  854,  219,
        641,    1,  112,  698, 1393,  951, 1283, 1318,  441, 1136, 1256,
        663, 1433, 1257, 1296,  317,  648,  709, 1282,  972,  627,  632,
       1248, 1423, 1254, 1348,  795,  645, 1012,  556,  681,  577, 1109,
       1266, 1183,  524, 1059,  540, 1194,  748, 1351,  484,   95, 1020,
        563, 1264,  742,  863,  891, 1038,  206,  392,  794,  870,  397,
        766, 1341, 1241, 1028,  642,  612,  960,  725,  683,   98,  804,
        406,  502, 1071, 1056,  929,  779,  200,  134, 1051,   40, 1017,
        230,  378,  288,  418,  391,  592, 1162, 1086,  647, 1152,  520,
         64,   14, 1180, 1064,  492,  379,  187,  763,  216,  791, 1076,
        878,  337,  719,  295, 1016, 1275,  455, 1457,  815,  269,  995,
        201,  161,  729,  401,  702, 1129,  565, 1021, 1025, 1104,  205,
         34,  775,  508, 1441, 1390,   91, 1363,  897,  564, 1369,  776,
        241,   13,  315,  600,  387, 1297,  166,  840,   20,  646, 1154,
        831, 1267,  562, 1422,  686,  957,  189,  975,  699,  510, 1082,
        474,  856,  747,  252,   21, 1337,  459, 1184,  276,  955, 1215,
        385,  805, 1437,  343,  769, 1332,  130,  871, 1123, 1396,   87,
        330, 1238,  466,  121, 1044, 1095, 1130, 1294,  860, 1126])

## Create a model

In [None]:
def create_non_seq_model():
