In [1]:
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer


pd.set_option("display.max_rows", None, "display.max_columns", None)

In [2]:
train=pd.read_csv("original datasets/train.csv")
test=pd.read_csv("datasets/combined_interpol_knn3/test.csv")

In [3]:
galaxies_train=train.galaxy.unique().tolist()
galaxies_train.remove("NGC 5253")
galaxies_test=test.galaxy.unique().tolist()
common_galaxies=list(set(galaxies_train) & set(galaxies_test))

In [4]:
param = {'max_depth': 15, 
         'n_estimators': 1000,
         'gamma':0,
         'eta': 0.01, 
         'objective': 'reg:squarederror', 
         'min_child_weight':0.1,
         'colsample_bytree': 1,
         'colsample_bylevel':1,
         'importance_type': 'weight',
         'subsample':1,
         'lambda':1,
         'num_parallel_tree':1,
        }

In [5]:
def ml_loop(param, galaxy_data, test_data):    
    
    labels=galaxy_data['y']
    galaxy_data=galaxy_data.drop('y', axis=1)
    galaxy_data=galaxy_data.drop('galaxy', axis=1)
    test_data=test_data.drop('galaxy', axis=1)
    
    data = train_test_split(galaxy_data, labels, test_size=0.1, shuffle=True)
    X_train, X_valid, Y_train, Y_valid = data
      
    model = xgb.XGBRegressor(**param)
    model.fit(X_train, Y_train, eval_set=[(X_valid, Y_valid)], eval_metric='rmse', early_stopping_rounds=1, verbose=False)
    progress=model.evals_result()
    
    y_pred = model.predict(test_data, ntree_limit=model.best_ntree_limit)
    test_data['y']=y_pred
    
    return test_data, progress
    

In [6]:
def win_prohack(param, common_galaxies, train, test):
    predicted_galaxies=[]
    average_error=[]
    for galaxy in common_galaxies:
        train_subset=train.loc[train.galaxy==galaxy, :]
        test_subset=test.loc[test.galaxy==galaxy,:]
        
        test_subset_predicted, progress=ml_loop(param, train_subset, test_subset)
        
        eval_performance=progress['validation_0']['rmse'][-1]
        average_error.append(eval_performance)
        
        test_subset_predicted.insert(1, 'galaxy',galaxy)
        predicted_galaxies.append(test_subset_predicted)
        
    test_predicted=pd.concat(predicted_galaxies)
    average_error=sum(average_error)/len(average_error)
    
    return test_predicted.sort_index(), average_error 
        
    

In [7]:
test_pr, error=win_prohack(param, common_galaxies, train, test)

In [8]:
print(error)

0.0034921220930232578


In [9]:
test_pr.head()

Unnamed: 0,galactic year,galaxy,existence expectancy index,existence expectancy at birth,Gross income per capita,Income Index,Expected years of education (galactic years),Mean years of education (galactic years),Intergalactic Development Index (IDI),Education Index,"Intergalactic Development Index (IDI), Rank",Population using at least basic drinking-water services (%),Population using at least basic sanitation services (%),Gross capital formation (% of GGP),"Population, total (millions)","Population, urban (%)","Mortality rate, under-five (per 1,000 live births)","Mortality rate, infant (per 1,000 live births)",Old age dependency ratio (old age (65 and older) per 100 creatures (ages 15-64)),"Population, ages 15–64 (millions)","Population, ages 65 and older (millions)","Life expectancy at birth, male (galactic years)","Life expectancy at birth, female (galactic years)","Population, under age 5 (millions)",Young age (0-14) dependency ratio (per 100 creatures ages 15-64),"Adolescent birth rate (births per 1,000 female creatures ages 15-19)",Total unemployment rate (female to male ratio),Vulnerable employment (% of total employment),"Unemployment, total (% of labour force)",Employment in agriculture (% of total employment),Labour force participation rate (% ages 15 and older),"Labour force participation rate (% ages 15 and older), female",Employment in services (% of total employment),"Labour force participation rate (% ages 15 and older), male",Employment to population ratio (% ages 15 and older),Jungle area (% of total land area),"Share of employment in nonagriculture, female (% of total employment in nonagriculture)",Youth unemployment rate (female to male ratio),"Unemployment, youth (% ages 15–24)","Mortality rate, female grown up (per 1,000 people)","Mortality rate, male grown up (per 1,000 people)","Infants lacking immunization, red hot disease (% of one-galactic year-olds)","Infants lacking immunization, Combination Vaccine (% of one-galactic year-olds)",Gross galactic product (GGP) per capita,"Gross galactic product (GGP), total","Outer Galaxies direct investment, net inflows (% of GGP)",Exports and imports (% of GGP),Share of seats in senate (% held by female),Natural resource depletion,"Mean years of education, female (galactic years)","Mean years of education, male (galactic years)","Expected years of education, female (galactic years)","Expected years of education, male (galactic years)","Maternal mortality ratio (deaths per 100,000 live births)",Renewable energy consumption (% of total final energy consumption),"Estimated gross galactic income per capita, male","Estimated gross galactic income per capita, female",Rural population with access to electricity (%),Domestic credit provided by financial sector (% of GGP),"Population with at least some secondary education, female (% ages 25 and older)","Population with at least some secondary education, male (% ages 25 and older)",Gross fixed capital formation (% of GGP),"Remittances, inflows (% of GGP)",Population with at least some secondary education (% ages 25 and older),Intergalactic inbound tourists (thousands),"Gross enrolment ratio, primary (% of primary under-age population)","Respiratory disease incidence (per 100,000 people)",Interstellar phone subscriptions (per 100 people),"Interstellar Data Net users, total (% of population)",Current health expenditure (% of GGP),"Intergalactic Development Index (IDI), female","Intergalactic Development Index (IDI), male",Gender Development Index (GDI),"Intergalactic Development Index (IDI), female, Rank","Intergalactic Development Index (IDI), male, Rank",Adjusted net savings,"Creature Immunodeficiency Disease prevalence, adult (% ages 15-49), total",Private galaxy capital flows (% of GGP),Gender Inequality Index (GII),y
0,1007012,KK98 77,0.456086,51.562543,12236.576447,0.593325,10.414164,10.699072,0.547114,0.556267,232.621842,105.193088,64.241392,17.41835,614.545305,45.541882,141.62578,87.561677,16.489583,379.34258,36.905606,55.846395,56.973407,73.889103,102.006807,146.74639,2.613697,90.100628,10.450293,93.441177,90.384582,88.382248,38.396809,101.406522,92.22942,62.581224,50.156,2.091978,23.138172,695.536552,768.67401,44.400059,23.111652,28455.552974,6913.404068,16.428878,103.423517,30.664274,8.019367,9.03532,10.586921,12.566027,13.227783,839.549734,111.685818,26831.199177,12836.864061,43.030871,112.40313,60.440366,93.258754,16.136779,15.591102,82.029295,127635.926582,119.699911,668.366798,98.744231,24.506065,10.827875,0.611064,0.640265,1.022928,170.821497,191.378179,11.121316,22.120237,17.342915,0.837273,0.046099
1,1007012,Reticulum III,0.529835,57.228262,3431.883825,0.675407,7.239485,5.311122,0.497688,0.409969,247.580771,55.730638,46.21744,33.027886,1004.798842,38.905041,216.676542,123.809217,10.677828,447.713795,19.209037,57.405286,60.778432,46.282052,131.722182,246.975201,1.574817,118.449149,5.736326,119.20895,84.344276,74.196989,34.792701,91.134703,77.755019,23.357119,45.658505,1.479324,14.132189,518.175761,526.315531,80.804075,61.494987,18347.897946,6742.309382,8.252833,133.354236,14.176417,31.149227,5.113953,5.677881,8.691662,11.933653,1436.775868,118.32693,18961.52357,17398.153091,38.072606,93.942624,,,34.177478,,33.183049,62737.104647,93.968366,327.078299,79.5346,29.427321,7.702659,0.447227,0.600222,0.826083,206.544247,202.957033,18.696296,6.797387,,,0.039774
2,1008016,Reticulum III,0.560976,59.379539,27562.914252,0.594624,11.77489,5.937797,0.544744,0.486167,249.798771,58.7559,48.897493,31.613362,1114.097252,33.915995,210.162404,121.57387,11.78956,456.879801,11.797791,57.473996,62.150245,42.694004,130.833953,237.806999,1.539524,126.845313,6.043554,118.222701,82.713552,75.358106,33.150538,90.476994,76.075316,26.307907,46.195764,1.608119,15.738669,528.220861,521.586294,74.23386,60.16755,16955.650954,6637.038972,9.145436,138.536576,13.534015,27.801479,5.314399,5.596202,9.038558,12.121717,1380.259435,120.962272,15361.037297,19183.293723,44.031099,96.534667,,,36.576861,,33.480188,75731.87799,97.00362,326.111778,83.6709,31.009044,7.447764,0.466019,0.598493,0.841232,206.586735,209.415311,20.504065,7.019498,,,0.039774
3,1007012,Segue 1,0.56591,59.95239,20352.232905,0.8377,11.613621,10.067882,0.691641,0.523441,211.50506,70.176431,61.254104,34.399007,1338.624144,78.792405,205.881454,113.395653,12.862473,548.656356,47.398329,63.160476,62.275221,63.635181,119.40815,243.963638,2.544024,92.465174,22.253423,56.34265,86.938983,86.546345,62.695167,86.746511,81.508127,86.613197,51.978783,2.195129,55.26805,377.953258,449.075268,77.390541,43.614081,15915.122483,7596.49845,7.530585,177.928141,35.889526,40.853322,6.833242,8.634465,11.423343,11.835024,1040.278371,87.610253,25804.592233,21979.750358,26.735139,45.769801,60.001242,68.654345,35.97924,9.625473,71.222724,84572.5432,111.15618,605.630123,66.070021,37.137096,5.502806,0.620664,0.6653,0.983942,,,-6.801181,5.470829,26.937888,0.831416,0.042171
4,1013042,Virgo I,0.588274,55.42832,23959.704016,0.520579,10.392416,6.374637,0.530676,0.580418,234.721069,69.768692,51.31694,15.657091,1106.554194,64.382217,219.588961,129.974418,14.013421,474.877713,21.305665,59.144562,63.990109,80.409057,111.084652,162.451451,2.421466,126.050516,13.370766,89.542879,83.677526,87.84634,28.815475,88.782958,76.50274,56.338491,45.923296,2.043633,23.401352,546.437704,552.698573,92.815852,76.093251,28890.168075,5995.978901,8.248885,87.376942,23.595837,8.776004,5.628603,7.518383,8.975221,8.919711,1034.065346,96.028122,21309.50697,18103.974021,50.233592,86.296925,29.065459,58.562225,13.313253,4.107116,42.980727,99276.25673,104.28831,458.186555,79.177012,41.585873,7.357729,0.583373,0.600445,0.856158,206.674424,224.104054,13.486288,7.687626,,0.969146,0.031109


In [10]:
test.head()

Unnamed: 0,galactic year,galaxy,existence expectancy index,existence expectancy at birth,Gross income per capita,Income Index,Expected years of education (galactic years),Mean years of education (galactic years),Intergalactic Development Index (IDI),Education Index,"Intergalactic Development Index (IDI), Rank",Population using at least basic drinking-water services (%),Population using at least basic sanitation services (%),Gross capital formation (% of GGP),"Population, total (millions)","Population, urban (%)","Mortality rate, under-five (per 1,000 live births)","Mortality rate, infant (per 1,000 live births)",Old age dependency ratio (old age (65 and older) per 100 creatures (ages 15-64)),"Population, ages 15–64 (millions)","Population, ages 65 and older (millions)","Life expectancy at birth, male (galactic years)","Life expectancy at birth, female (galactic years)","Population, under age 5 (millions)",Young age (0-14) dependency ratio (per 100 creatures ages 15-64),"Adolescent birth rate (births per 1,000 female creatures ages 15-19)",Total unemployment rate (female to male ratio),Vulnerable employment (% of total employment),"Unemployment, total (% of labour force)",Employment in agriculture (% of total employment),Labour force participation rate (% ages 15 and older),"Labour force participation rate (% ages 15 and older), female",Employment in services (% of total employment),"Labour force participation rate (% ages 15 and older), male",Employment to population ratio (% ages 15 and older),Jungle area (% of total land area),"Share of employment in nonagriculture, female (% of total employment in nonagriculture)",Youth unemployment rate (female to male ratio),"Unemployment, youth (% ages 15–24)","Mortality rate, female grown up (per 1,000 people)","Mortality rate, male grown up (per 1,000 people)","Infants lacking immunization, red hot disease (% of one-galactic year-olds)","Infants lacking immunization, Combination Vaccine (% of one-galactic year-olds)",Gross galactic product (GGP) per capita,"Gross galactic product (GGP), total","Outer Galaxies direct investment, net inflows (% of GGP)",Exports and imports (% of GGP),Share of seats in senate (% held by female),Natural resource depletion,"Mean years of education, female (galactic years)","Mean years of education, male (galactic years)","Expected years of education, female (galactic years)","Expected years of education, male (galactic years)","Maternal mortality ratio (deaths per 100,000 live births)",Renewable energy consumption (% of total final energy consumption),"Estimated gross galactic income per capita, male","Estimated gross galactic income per capita, female",Rural population with access to electricity (%),Domestic credit provided by financial sector (% of GGP),"Population with at least some secondary education, female (% ages 25 and older)","Population with at least some secondary education, male (% ages 25 and older)",Gross fixed capital formation (% of GGP),"Remittances, inflows (% of GGP)",Population with at least some secondary education (% ages 25 and older),Intergalactic inbound tourists (thousands),"Gross enrolment ratio, primary (% of primary under-age population)","Respiratory disease incidence (per 100,000 people)",Interstellar phone subscriptions (per 100 people),"Interstellar Data Net users, total (% of population)",Current health expenditure (% of GGP),"Intergalactic Development Index (IDI), female","Intergalactic Development Index (IDI), male",Gender Development Index (GDI),"Intergalactic Development Index (IDI), female, Rank","Intergalactic Development Index (IDI), male, Rank",Adjusted net savings,"Creature Immunodeficiency Disease prevalence, adult (% ages 15-49), total",Private galaxy capital flows (% of GGP),Gender Inequality Index (GII)
0,1007012,KK98 77,0.456086,51.562543,12236.576447,0.593325,10.414164,10.699072,0.547114,0.556267,232.621842,105.193088,64.241392,17.41835,614.545305,45.541882,141.62578,87.561677,16.489583,379.34258,36.905606,55.846395,56.973407,73.889103,102.006807,146.74639,2.613697,90.100628,10.450293,93.441177,90.384582,88.382248,38.396809,101.406522,92.22942,62.581224,50.156,2.091978,23.138172,695.536552,768.67401,44.400059,23.111652,28455.552974,6913.404068,16.428878,103.423517,30.664274,8.019367,9.03532,10.586921,12.566027,13.227783,839.549734,111.685818,26831.199177,12836.864061,43.030871,112.40313,60.440366,93.258754,16.136779,15.591102,82.029295,127635.926582,119.699911,668.366798,98.744231,24.506065,10.827875,0.611064,0.640265,1.022928,170.821497,191.378179,11.121316,22.120237,17.342915,0.837273
1,1007012,Reticulum III,0.529835,57.228262,3431.883825,0.675407,7.239485,5.311122,0.497688,0.409969,247.580771,55.730638,46.21744,33.027886,1004.798842,38.905041,216.676542,123.809217,10.677828,447.713795,19.209037,57.405286,60.778432,46.282052,131.722182,246.975201,1.574817,118.449149,5.736326,119.20895,84.344276,74.196989,34.792701,91.134703,77.755019,23.357119,45.658505,1.479324,14.132189,518.175761,526.315531,80.804075,61.494987,18347.897946,6742.309382,8.252833,133.354236,14.176417,31.149227,5.113953,5.677881,8.691662,11.933653,1436.775868,118.32693,18961.52357,17398.153091,38.072606,93.942624,37.798116,37.985973,34.177478,9.971183,33.183049,62737.104647,93.968366,327.078299,79.5346,29.427321,7.702659,0.447227,0.600222,0.826083,206.544247,202.957033,18.696296,6.797387,21.766588,0.942323
2,1008016,Reticulum III,0.560976,59.379539,27562.914252,0.594624,11.77489,5.937797,0.544744,0.486167,249.798771,58.7559,48.897493,31.613362,1114.097252,33.915995,210.162404,121.57387,11.78956,456.879801,11.797791,57.473996,62.150245,42.694004,130.833953,237.806999,1.539524,126.845313,6.043554,118.222701,82.713552,75.358106,33.150538,90.476994,76.075316,26.307907,46.195764,1.608119,15.738669,528.220861,521.586294,74.23386,60.16755,16955.650954,6637.038972,9.145436,138.536576,13.534015,27.801479,5.314399,5.596202,9.038558,12.121717,1380.259435,120.962272,15361.037297,19183.293723,44.031099,96.534667,38.387369,45.725975,36.576861,10.290635,33.480188,75731.87799,97.00362,326.111778,83.6709,31.009044,7.447764,0.466019,0.598493,0.841232,206.586735,209.415311,20.504065,7.019498,20.612698,0.936722
3,1007012,Segue 1,0.56591,59.95239,20352.232905,0.8377,11.613621,10.067882,0.691641,0.523441,211.50506,70.176431,61.254104,34.399007,1338.624144,78.792405,205.881454,113.395653,12.862473,548.656356,47.398329,63.160476,62.275221,63.635181,119.40815,243.963638,2.544024,92.465174,22.253423,56.34265,86.938983,86.546345,62.695167,86.746511,81.508127,86.613197,51.978783,2.195129,55.26805,377.953258,449.075268,77.390541,43.614081,15915.122483,7596.49845,7.530585,177.928141,35.889526,40.853322,6.833242,8.634465,11.423343,11.835024,1040.278371,87.610253,25804.592233,21979.750358,26.735139,45.769801,60.001242,68.654345,35.97924,9.625473,71.222724,84572.5432,111.15618,605.630123,66.070021,37.137096,5.502806,0.620664,0.6653,0.983942,138.542646,145.050309,-6.801181,5.470829,26.937888,0.831416
4,1013042,Virgo I,0.588274,55.42832,23959.704016,0.520579,10.392416,6.374637,0.530676,0.580418,234.721069,69.768692,51.31694,15.657091,1106.554194,64.382217,219.588961,129.974418,14.013421,474.877713,21.305665,59.144562,63.990109,80.409057,111.084652,162.451451,2.421466,126.050516,13.370766,89.542879,83.677526,87.84634,28.815475,88.782958,76.50274,56.338491,45.923296,2.043633,23.401352,546.437704,552.698573,92.815852,76.093251,28890.168075,5995.978901,8.248885,87.376942,23.595837,8.776004,5.628603,7.518383,8.975221,8.919711,1034.065346,96.028122,21309.50697,18103.974021,50.233592,86.296925,29.065459,58.562225,13.313253,4.107116,42.980727,99276.25673,104.28831,458.186555,79.177012,41.585873,7.357729,0.583373,0.600445,0.856158,206.674424,224.104054,13.486288,7.687626,30.936726,0.969146


In [12]:
test_pr.to_csv("test_predicted.csv", index=False)