# Copyright

<PRE>
This notebook was created as part of the "Deep learning / VITMMA19" class at
Budapest University of Technology and Economics, Hungary,
https://portal.vik.bme.hu/kepzes/targyak/VITMMA19.

Any re-use or publication of any part of the notebook is only allowed with the
written consent of the authors.

2024 (c) Mohammed Salah Al-Radhi (malradhi@tmit.bme.hu)
</PRE>

In [21]:
!pip install -q pytorch-lightning

## Data
See https://archive.ics.uci.edu/, a popular machine learning repository maintained by University of California, Irvine. We are working with their 2nd most popular dataset.

In [22]:
!wget https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/insurance.csv

--2024-10-29 11:21:45--  https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/insurance.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 54288 (53K) [text/plain]
Saving to: ‘insurance.csv.1’


2024-10-29 11:21:45 (6.59 MB/s) - ‘insurance.csv.1’ saved [54288/54288]



In [23]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error, r2_score

In [24]:
df = pd.read_csv("insurance.csv")
df

Unnamed: 0,age,sex,bmi,children,smoker,region,charges
0,19,female,27.900,0,yes,southwest,16884.92400
1,18,male,33.770,1,no,southeast,1725.55230
2,28,male,33.000,3,no,southeast,4449.46200
3,33,male,22.705,0,no,northwest,21984.47061
4,32,male,28.880,0,no,northwest,3866.85520
...,...,...,...,...,...,...,...
1333,50,male,30.970,3,no,northwest,10600.54830
1334,18,female,31.920,0,no,northeast,2205.98080
1335,18,female,36.850,0,no,southeast,1629.83350
1336,21,female,25.800,0,no,southwest,2007.94500


In [25]:
df = pd.get_dummies(df, drop_first=True) # we have to talk about drop_first -> avoid the dummy variable trap
df

Unnamed: 0,age,bmi,children,charges,sex_male,smoker_yes,region_northwest,region_southeast,region_southwest
0,19,27.900,0,16884.92400,False,True,False,False,True
1,18,33.770,1,1725.55230,True,False,False,True,False
2,28,33.000,3,4449.46200,True,False,False,True,False
3,33,22.705,0,21984.47061,True,False,True,False,False
4,32,28.880,0,3866.85520,True,False,True,False,False
...,...,...,...,...,...,...,...,...,...
1333,50,30.970,3,10600.54830,True,False,True,False,False
1334,18,31.920,0,2205.98080,False,False,False,False,False
1335,18,36.850,0,1629.83350,False,False,False,True,False
1336,21,25.800,0,2007.94500,False,False,False,False,True


In [26]:
df.dtypes

Unnamed: 0,0
age,int64
bmi,float64
children,int64
charges,float64
sex_male,bool
smoker_yes,bool
region_northwest,bool
region_southeast,bool
region_southwest,bool


In [27]:
df.iloc[:, :4] = df.iloc[:, :4].astype('float32')

1       18.0
2       28.0
3       33.0
4       32.0
        ... 
1333    50.0
1334    18.0
1335    18.0
1336    21.0
1337    61.0
Name: age, Length: 1338, dtype: float32' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.
  df.iloc[:, :4] = df.iloc[:, :4].astype('float32')
1       1.0
2       3.0
3       0.0
4       0.0
       ... 
1333    3.0
1334    0.0
1335    0.0
1336    0.0
1337    0.0
Name: children, Length: 1338, dtype: float32' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.
  df.iloc[:, :4] = df.iloc[:, :4].astype('float32')


In [28]:
df.dtypes

Unnamed: 0,0
age,float32
bmi,float64
children,float32
charges,float64
sex_male,bool
smoker_yes,bool
region_northwest,bool
region_southeast,bool
region_southwest,bool


In [29]:
X = df.drop(columns=['charges']).values # features
y = df['charges'].values # label

In [30]:
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import pytorch_lightning as pl
import torchmetrics

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=42)


In [31]:
from sklearn.preprocessing import StandardScaler

def create_dataloader(X, y, batch_size, shuffle):
  scaler = StandardScaler()
  scaler.fit(X_train)
  X = scaler.transform(X).astype('float32')
  X = torch.from_numpy(X)
  y = y.astype('float32')
  y = torch.from_numpy(y)
  dataset = TensorDataset(X, y)
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
  return dataloader

batch_size = 64

train_loader = create_dataloader(X_train, y_train, batch_size, True)
val_loader = create_dataloader(X_val, y_val, batch_size, False)
test_loader = create_dataloader(X_test, y_test, batch_size, False)

In [32]:
from sklearn.metrics import r2_score

class MyFeedForwardNet(pl.LightningModule):
    def __init__(self, input_dim, output_dim, lr):
      super(MyFeedForwardNet, self).__init__()
      self.save_hyperparameters()
      self.lr = lr
      self.layers = nn.Sequential(
           nn.Linear(input_dim, 512),
           nn.ReLU(),
           nn.Dropout(0.2),
           nn.BatchNorm1d(512),  # Batch Normalization
           nn.Linear(512, 256),
           nn.ReLU(),
           nn.Dropout(0.2),
           nn.BatchNorm1d(256),  # Batch Normalization
           nn.Linear(256, 64),
           nn.ReLU(),
           nn.Dropout(0.2),
           nn.BatchNorm1d(64),  # Batch Normalization
           nn.Linear(64, output_dim)
      )
      self.loss_fn = nn.L1Loss()

    # forward propagation
    def forward(self, x):
      return self.layers(x)

    def on_epoch_start(self):
      print(f"Epoch {self.current_epoch} started. Training...")

    # one step of training
    def training_step(self, batch, batch_idx):
      inputs, targets = batch
      outputs = self(inputs).squeeze()
      mae = self.loss_fn(outputs, targets)
      self.log('train_mae', mae)
      return mae

    # one step of validation
    def validation_step(self, batch, batch_idx):
      inputs, targets = batch
      outputs = self(inputs).squeeze()
      mae = self.loss_fn(outputs, targets)
      self.log('val_mae', mae)
      print("val-mae at epoch {} : {}".format(self.current_epoch, mae.item()))
      return mae

    def configure_optimizers(self):
      return torch.optim.Adam(self.parameters(), lr=self.lr)

In [33]:
logger = pl.loggers.TensorBoardLogger("logs/", name="heart_disease_logs")

In [34]:
epochs = 120
lr = 0.01
output_dim = 1

# we instantiate our model
model = MyFeedForwardNet(X_train.shape[1], output_dim, lr)

# we use the ModelCheckpoint callback to save the best model
callback = pl.callbacks.ModelCheckpoint(
    monitor='val_mae',
    dirpath = '',
    filename = 'best_model',
    save_top_k=1,
    mode='min'
)

# we use the Trainer class to train our model
trainer = pl.Trainer(
    logger=logger,
    max_epochs=epochs,
    log_every_n_steps=1,
    callbacks=[callback]
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [35]:
trainer.fit(model, train_loader, val_loader) # train the model

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory  exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | layers  | Sequential | 154 K  | train
1 | loss_fn | L1Loss     | 0      | train
-----------------------------------------------
154 K     Trainable params
0         Non-trainable params
154 K     Total params
0.616     Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 0 : 14276.291015625
val-mae at epoch 0 : 13872.3955078125


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 0 : 14273.3359375
val-mae at epoch 0 : 13869.3056640625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 1 : 14263.947265625
val-mae at epoch 1 : 13859.5830078125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 2 : 14251.126953125
val-mae at epoch 2 : 13846.36328125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 3 : 14228.0263671875
val-mae at epoch 3 : 13822.4326171875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 4 : 14201.97265625
val-mae at epoch 4 : 13796.2060546875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 5 : 14165.154296875
val-mae at epoch 5 : 13760.4560546875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 6 : 14119.88671875
val-mae at epoch 6 : 13707.6513671875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 7 : 14051.99609375
val-mae at epoch 7 : 13646.357421875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 8 : 13981.73046875
val-mae at epoch 8 : 13576.50390625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 9 : 13908.9033203125
val-mae at epoch 9 : 13504.7763671875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 10 : 13829.7919921875
val-mae at epoch 10 : 13431.1337890625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 11 : 13734.96484375
val-mae at epoch 11 : 13332.125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 12 : 13635.86328125
val-mae at epoch 12 : 13232.6826171875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 13 : 13511.0966796875
val-mae at epoch 13 : 13101.345703125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 14 : 13386.611328125
val-mae at epoch 14 : 12986.4677734375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 15 : 13254.736328125
val-mae at epoch 15 : 12850.6064453125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 16 : 13138.21875
val-mae at epoch 16 : 12727.4443359375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 17 : 12947.93359375
val-mae at epoch 17 : 12560.556640625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 18 : 12756.6796875
val-mae at epoch 18 : 12362.5146484375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 19 : 12607.318359375
val-mae at epoch 19 : 12204.5927734375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 20 : 12446.0244140625
val-mae at epoch 20 : 12047.078125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 21 : 12218.818359375
val-mae at epoch 21 : 11824.853515625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 22 : 12037.7890625
val-mae at epoch 22 : 11662.7373046875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 23 : 11912.24609375
val-mae at epoch 23 : 11539.744140625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 24 : 11647.5400390625
val-mae at epoch 24 : 11306.5693359375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 25 : 11519.517578125
val-mae at epoch 25 : 11218.30859375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 26 : 11197.443359375
val-mae at epoch 26 : 10898.3681640625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 27 : 11102.9619140625
val-mae at epoch 27 : 10830.970703125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 28 : 10805.98828125
val-mae at epoch 28 : 10515.513671875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 29 : 10608.611328125
val-mae at epoch 29 : 10337.15234375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 30 : 10371.0224609375
val-mae at epoch 30 : 10119.47265625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 31 : 10197.9892578125
val-mae at epoch 31 : 9980.3466796875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 32 : 9977.693359375
val-mae at epoch 32 : 9771.0322265625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 33 : 9821.470703125
val-mae at epoch 33 : 9645.9716796875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 34 : 9339.3857421875
val-mae at epoch 34 : 9220.625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 35 : 9433.18359375
val-mae at epoch 35 : 9338.87890625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 36 : 8938.8623046875
val-mae at epoch 36 : 8835.12890625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 37 : 8589.6611328125
val-mae at epoch 37 : 8544.6455078125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 38 : 8297.029296875
val-mae at epoch 38 : 8283.30859375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 39 : 8167.23388671875
val-mae at epoch 39 : 8172.970703125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 40 : 8106.0517578125
val-mae at epoch 40 : 8163.51513671875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 41 : 7788.3623046875
val-mae at epoch 41 : 7869.87353515625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 42 : 7621.580078125
val-mae at epoch 42 : 7782.76318359375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 43 : 7269.5166015625
val-mae at epoch 43 : 7420.1875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 44 : 7371.39013671875
val-mae at epoch 44 : 7520.630859375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 45 : 7151.2822265625
val-mae at epoch 45 : 7290.04931640625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 46 : 6945.6005859375
val-mae at epoch 46 : 7108.5146484375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 47 : 6604.40283203125
val-mae at epoch 47 : 6835.994140625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 48 : 6760.02490234375
val-mae at epoch 48 : 6936.19775390625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 49 : 6321.70166015625
val-mae at epoch 49 : 6558.49560546875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 50 : 6112.0888671875
val-mae at epoch 50 : 6366.16845703125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 51 : 6508.81005859375
val-mae at epoch 51 : 6692.70654296875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 52 : 6139.58203125
val-mae at epoch 52 : 6404.11328125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 53 : 5948.59033203125
val-mae at epoch 53 : 6174.88134765625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 54 : 5861.7734375
val-mae at epoch 54 : 6079.61328125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 55 : 6016.48388671875
val-mae at epoch 55 : 6228.9150390625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 56 : 5623.50048828125
val-mae at epoch 56 : 5857.16162109375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 57 : 5532.978515625
val-mae at epoch 57 : 5777.6318359375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 58 : 5622.8544921875
val-mae at epoch 58 : 5876.42626953125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 59 : 5360.076171875
val-mae at epoch 59 : 5635.470703125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 60 : 5452.119140625
val-mae at epoch 60 : 5803.1982421875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 61 : 5375.671875
val-mae at epoch 61 : 5664.88330078125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 62 : 5188.0830078125
val-mae at epoch 62 : 5504.82421875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 63 : 5210.5068359375
val-mae at epoch 63 : 5589.49560546875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 64 : 5268.998046875
val-mae at epoch 64 : 5626.064453125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 65 : 4928.5
val-mae at epoch 65 : 5328.24853515625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 66 : 4860.556640625
val-mae at epoch 66 : 5345.1953125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 67 : 4984.4287109375
val-mae at epoch 67 : 5342.11767578125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 68 : 4803.20703125
val-mae at epoch 68 : 5205.53466796875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 69 : 4626.96923828125
val-mae at epoch 69 : 4975.4453125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 70 : 4678.931640625
val-mae at epoch 70 : 5099.6171875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 71 : 4526.83984375
val-mae at epoch 71 : 5031.17724609375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 72 : 4531.60009765625
val-mae at epoch 72 : 5041.44189453125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 73 : 4540.59130859375
val-mae at epoch 73 : 5044.7373046875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 74 : 4826.99853515625
val-mae at epoch 74 : 5170.93115234375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 75 : 4023.09423828125
val-mae at epoch 75 : 4757.14208984375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 76 : 4049.67578125
val-mae at epoch 76 : 4611.837890625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 77 : 4504.65625
val-mae at epoch 77 : 4938.26904296875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 78 : 4564.85546875
val-mae at epoch 78 : 5074.73681640625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 79 : 4158.58056640625
val-mae at epoch 79 : 4803.7724609375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 80 : 4099.46142578125
val-mae at epoch 80 : 4678.1455078125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 81 : 4424.2294921875
val-mae at epoch 81 : 4962.1689453125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 82 : 4059.94775390625
val-mae at epoch 82 : 4707.75


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 83 : 3659.3212890625
val-mae at epoch 83 : 4530.27294921875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 84 : 3439.43798828125
val-mae at epoch 84 : 4397.00830078125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 85 : 3451.06494140625
val-mae at epoch 85 : 4316.8515625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 86 : 3835.5185546875
val-mae at epoch 86 : 4547.466796875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 87 : 3223.80029296875
val-mae at epoch 87 : 4219.8408203125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 88 : 3404.86083984375
val-mae at epoch 88 : 4343.2421875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 89 : 3799.08056640625
val-mae at epoch 89 : 4550.05712890625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 90 : 3989.3330078125
val-mae at epoch 90 : 4666.765625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 91 : 3172.36083984375
val-mae at epoch 91 : 4169.32275390625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 92 : 3900.006103515625
val-mae at epoch 92 : 4587.154296875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 93 : 3010.64013671875
val-mae at epoch 93 : 3949.6904296875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 94 : 2912.546875
val-mae at epoch 94 : 4052.052978515625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 95 : 3149.322021484375
val-mae at epoch 95 : 4076.97607421875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 96 : 3002.21728515625
val-mae at epoch 96 : 4064.472412109375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 97 : 3038.87109375
val-mae at epoch 97 : 4005.896728515625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 98 : 2989.25
val-mae at epoch 98 : 3989.024658203125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 99 : 2788.03759765625
val-mae at epoch 99 : 3901.63525390625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 100 : 2836.55126953125
val-mae at epoch 100 : 3780.0166015625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 101 : 2806.97607421875
val-mae at epoch 101 : 3883.24462890625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 102 : 2988.3662109375
val-mae at epoch 102 : 3952.8291015625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 103 : 2452.6875
val-mae at epoch 103 : 3523.775390625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 104 : 2781.4521484375
val-mae at epoch 104 : 3736.293212890625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 105 : 2310.04541015625
val-mae at epoch 105 : 3531.811279296875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 106 : 2555.9296875
val-mae at epoch 106 : 4298.62060546875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 107 : 2236.7412109375
val-mae at epoch 107 : 3404.061767578125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 108 : 2777.12939453125
val-mae at epoch 108 : 3776.787353515625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 109 : 2304.181396484375
val-mae at epoch 109 : 3449.281982421875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 110 : 1870.279541015625
val-mae at epoch 110 : 3031.98681640625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 111 : 2190.947021484375
val-mae at epoch 111 : 3258.12060546875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 112 : 2570.313720703125
val-mae at epoch 112 : 3568.199462890625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 113 : 2032.293701171875
val-mae at epoch 113 : 3412.054443359375


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 114 : 1821.071533203125
val-mae at epoch 114 : 3095.844482421875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 115 : 2262.660400390625
val-mae at epoch 115 : 3423.47900390625


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 116 : 2261.531005859375
val-mae at epoch 116 : 3408.921142578125


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 117 : 1798.49072265625
val-mae at epoch 117 : 3092.338623046875


Validation: |          | 0/? [00:00<?, ?it/s]

val-mae at epoch 118 : 2653.831787109375
val-mae at epoch 118 : 3606.47900390625


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=120` reached.


val-mae at epoch 119 : 1967.339111328125
val-mae at epoch 119 : 3187.923583984375


In [39]:
!pip install scikit_learn

import pandas as pd
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error

best_model = MyFeedForwardNet.load_from_checkpoint(callback.best_model_path)

if torch.cuda.is_available():
    best_model.cuda()

predicts_list = []
targets_list = []
for batch in test_loader:
    inputs, targets = batch

    if torch.cuda.is_available():
        inputs = inputs.cuda()
        targets = targets.cuda()

    predicts_batch = best_model(inputs).squeeze().tolist()
    targets_batch = targets.tolist()
    predicts_list.extend(predicts_batch)
    targets_list.extend(targets_batch)

mae = mean_absolute_error(targets_list, predicts_list)
r2 = r2_score(targets_list, predicts_list)

results = pd.DataFrame({
    'MAE': [mae],
    'R2 Score': [r2]
})
results.to_csv('mlp_results.csv', index=False)

print(f'MAE: {mae}')
print(f'R2 Score: {r2}')

MAE: 3092.535039190036
R2 Score: 0.7986198624436849


In [40]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler

y_pred_ensemble = []
scaler = StandardScaler()
scaler.fit(X_train)
rf_model = RandomForestRegressor(n_estimators=100, random_state=42)
rf_model.fit(scaler.transform(X_train), y_train)
y_pred_RandomForestRegressor = rf_model.predict(scaler.transform(X_test))
y_pred_ensemble = [0.5 * nn_pred + 0.5 * rf_pred for nn_pred, rf_pred in zip(predicts_list, y_pred_RandomForestRegressor)]

mae = mean_absolute_error(targets_list, y_pred_ensemble)
r2 = r2_score(targets_list, y_pred_ensemble)

results = pd.DataFrame({
    'MAE': [mae],
    'R2 Score': [r2]
})
results.to_csv('ensemble_results.csv', index=False)

print(f'MAE: {mae}')
print(f'R2 Score: {r2}')

MAE: 2595.0477385130926
R2 Score: 0.8580683019264124
