In [1]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
from tqdm import tqdm
import numpy as np
from sklearn.metrics import r2_score
import numpy as np

In [2]:
# Load the enhanced train/test tables, hold out part of the training rows for
# validation, and wrap the remaining training rows in a shuffled DataLoader.
new_train_csv = pd.read_csv('data/enhanced_train.csv')
new_test_csv = pd.read_csv('data/enhanced_test.csv')

# Hard-coded per-target mean / std used to standardise the six targets.
# NOTE(review): presumably computed offline from the training targets -- confirm.
output_mean = np.array(
    [1.03624107e+00, 1.48317376e+02, 1.97016450e+04, 3.48191181e+03, 1.51120666e+01, 3.99120598e+05]
)
output_std = np.array(
    [1.37329381e-01, 6.91740145e+00, 4.31037489e+00, 6.70979751e+01, 5.93192463e-01, 2.25494269e+03]
)

# Training table: first column is the id, the last six columns are the targets,
# everything in between is treated as features.
IDs_train = new_train_csv.iloc[:, 0].values
X_train = new_train_csv.iloc[:, 1:-6].values
Y_train = new_train_csv.iloc[:, -6:].values

# Test table: first column is the id, all remaining columns are features.
IDs_test = new_test_csv.iloc[:, 0].values
X_test = new_test_csv.iloc[:, 1:].values

# Hold out 1/11 of the training rows for validation; the fixed random_state
# keeps the split identical between runs.
X_train, X_val, Y_train, Y_val, IDs_train, IDs_val = train_test_split(
    X_train,
    Y_train,
    IDs_train,
    random_state=42,
    test_size=1/11,
)

# Convert every array the model consumes to a float32 tensor.
X_train_tensor, Y_train_tensor, X_test_tensor, X_val_tensor, Y_val_tensor = (
    torch.tensor(array, dtype=torch.float32)
    for array in (X_train, Y_train, X_test, X_val, Y_val)
)

# Mini-batches of 32 rows, reshuffled every epoch.
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)

In [3]:
class RegressionModel(nn.Module):
    """Fully connected regressor: input_dim -> 64 -> 32 -> output_dim.

    ReLU is applied after each of the two hidden layers; the output layer is
    linear (no activation).
    """

    def __init__(self, input_dim, output_dim):
        """Create the three linear layers.

        Args:
            input_dim: Number of input features per sample.
            output_dim: Number of values predicted per sample.
        """
        super().__init__()
        self.layer1 = nn.Linear(input_dim, 64)
        self.layer2 = nn.Linear(64, 32)
        self.output = nn.Linear(32, output_dim)

    def forward(self, x):
        """Map a batch of feature rows to a batch of predictions."""
        activations = x
        for hidden_layer in (self.layer1, self.layer2):
            activations = torch.relu(hidden_layer(activations))
        return self.output(activations)


In [4]:
# Seed PyTorch's global RNG before any weights are created. Without this the
# weight initialisation and the DataLoader's shuffle order change on every
# run, so the logged losses / R2 values cannot be reproduced after a kernel
# restart (the train/validation split is seeded separately via random_state).
SEED = 42
torch.manual_seed(SEED)

# Layer sizes come from the data: one input per feature column and one output
# per target column.
input_dim = X_train_tensor.shape[1]
output_dim = Y_train_tensor.shape[1]

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RegressionModel(input_dim, output_dim).to(device)

# Mean-squared-error loss, optimised with Adam (lr 1e-3, weight decay 1e-2).
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)

In [5]:
def R2eval(model, X_val_tensor, Y_val_tensor):
    """Return the R^2 score of ``model`` on the given features/targets.

    The model's raw outputs are treated as standardised predictions: they are
    mapped back to original target units with the module-level ``output_std``
    and ``output_mean`` before being scored.

    Args:
        model: Network to evaluate. It is switched to eval mode and left there;
            callers that keep training must call ``model.train()`` again.
        X_val_tensor: Feature tensor to predict on (moved to the module-level
            ``device`` for the forward pass).
        Y_val_tensor: Ground-truth targets in original (un-standardised) units,
            row-aligned with ``X_val_tensor``.

    Returns:
        The value of ``sklearn.metrics.r2_score(y_true, y_pred)`` with its
        default multi-output handling.
    """
    model.eval()
    with torch.no_grad():
        val_predictions = model(X_val_tensor.to(device)).cpu().numpy()

    # Undo the target standardisation so the score is in original units.
    val_predictions = val_predictions * output_std + output_mean

    # Score against the targets that were passed in. The previous version
    # ignored this argument and always read the global ``Y_val``, so the
    # function silently returned the wrong score for any other dataset.
    y_true = Y_val_tensor.detach().cpu().numpy()
    return r2_score(y_true, val_predictions)

In [6]:
# Train for a fixed number of epochs, scoring validation R2 after each one and
# remembering the weights of the best-scoring epoch.
epochs = 100
# `best_model` is the name the inference cell reads. It refers to the same
# network object as `model`; the best-scoring weights are loaded back into it
# once training finishes.
best_model = model
# Snapshot (a real copy) of the best weights seen so far. Simply assigning
# `best_model = model` only creates an alias, so it would always end up holding
# the last epoch's weights instead of the best epoch's.
best_state = None
bestR2 = float('-inf')

# Standardisation constants as tensors, built once instead of relying on
# implicit tensor/ndarray arithmetic inside every batch.
target_mean = torch.as_tensor(output_mean)
target_std = torch.as_tensor(output_std)

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    for batch_X, batch_Y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        # The network is trained on standardised targets.
        batch_Y = (batch_Y - target_mean) / target_std
        batch_X, batch_Y = batch_X.to(device).float(), batch_Y.to(device).float()
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_Y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # R2eval puts the model in eval mode; model.train() above re-enables
    # training mode at the start of the next epoch.
    currentR2 = R2eval(model, X_val_tensor, Y_val_tensor)
    print(f"Current Estimate R2 is {currentR2}")
    if (currentR2 > bestR2):
        # Clone every tensor so later optimizer steps cannot modify the snapshot.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        bestR2 = currentR2
        print("Saved model")
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(train_loader)}")

# Restore the best validation checkpoint. If no epoch ever improved on -inf
# (e.g. the score was NaN throughout) the final weights are kept unchanged.
if best_state is not None:
    best_model.load_state_dict(best_state)
best_model.eval()

Epoch 1/100: 100%|██████████| 1232/1232 [00:01<00:00, 712.53it/s]


Current Estimate R2 is 0.4430721541059537
Saved model
Epoch 1/100, Loss: 0.619124738838185


Epoch 2/100: 100%|██████████| 1232/1232 [00:01<00:00, 742.55it/s]


Current Estimate R2 is 0.4453653059801737
Saved model
Epoch 2/100, Loss: 0.5788445248474161


Epoch 3/100: 100%|██████████| 1232/1232 [00:01<00:00, 718.18it/s]


Current Estimate R2 is 0.45016754215271026
Saved model
Epoch 3/100, Loss: 0.5745156091432293


Epoch 4/100: 100%|██████████| 1232/1232 [00:01<00:00, 688.66it/s]


Current Estimate R2 is 0.45636508138689075
Saved model
Epoch 4/100, Loss: 0.5726242355228245


Epoch 5/100: 100%|██████████| 1232/1232 [00:01<00:00, 732.07it/s]


Current Estimate R2 is 0.450359277275171
Epoch 5/100, Loss: 0.5714918358565925


Epoch 6/100: 100%|██████████| 1232/1232 [00:01<00:00, 759.88it/s]


Current Estimate R2 is 0.4542401802508797
Epoch 6/100, Loss: 0.5708833052595328


Epoch 7/100: 100%|██████████| 1232/1232 [00:01<00:00, 737.57it/s]


Current Estimate R2 is 0.4509965924724381
Epoch 7/100, Loss: 0.5705143152312799


Epoch 8/100: 100%|██████████| 1232/1232 [00:01<00:00, 746.52it/s]


Current Estimate R2 is 0.4583567700420858
Saved model
Epoch 8/100, Loss: 0.568626776627906


Epoch 9/100: 100%|██████████| 1232/1232 [00:01<00:00, 665.89it/s]


Current Estimate R2 is 0.45469688156267846
Epoch 9/100, Loss: 0.5683532797961266


Epoch 10/100: 100%|██████████| 1232/1232 [00:01<00:00, 672.84it/s]


Current Estimate R2 is 0.4595577612374431
Saved model
Epoch 10/100, Loss: 0.5685673487408982


Epoch 11/100: 100%|██████████| 1232/1232 [00:01<00:00, 667.15it/s]


Current Estimate R2 is 0.45640864193634806
Epoch 11/100, Loss: 0.5691204416480932


Epoch 12/100: 100%|██████████| 1232/1232 [00:01<00:00, 718.79it/s]


Current Estimate R2 is 0.4460566890610811
Epoch 12/100, Loss: 0.5682724435608109


Epoch 13/100: 100%|██████████| 1232/1232 [00:01<00:00, 758.12it/s]


Current Estimate R2 is 0.45897964925526064
Epoch 13/100, Loss: 0.5683096885536011


Epoch 14/100: 100%|██████████| 1232/1232 [00:01<00:00, 774.82it/s]


Current Estimate R2 is 0.46278935504104646
Saved model
Epoch 14/100, Loss: 0.5683834829984548


Epoch 15/100: 100%|██████████| 1232/1232 [00:01<00:00, 753.74it/s]


Current Estimate R2 is 0.4605845257863006
Epoch 15/100, Loss: 0.5677604883328661


Epoch 16/100: 100%|██████████| 1232/1232 [00:01<00:00, 741.71it/s]


Current Estimate R2 is 0.45419217447646637
Epoch 16/100, Loss: 0.5667778193321708


Epoch 17/100: 100%|██████████| 1232/1232 [00:01<00:00, 739.50it/s]


Current Estimate R2 is 0.45881128734606946
Epoch 17/100, Loss: 0.5678733940925691


Epoch 18/100: 100%|██████████| 1232/1232 [00:01<00:00, 742.30it/s]


Current Estimate R2 is 0.45678451875449505
Epoch 18/100, Loss: 0.5675280194536045


Epoch 19/100: 100%|██████████| 1232/1232 [00:01<00:00, 746.99it/s]


Current Estimate R2 is 0.4630556964742368
Saved model
Epoch 19/100, Loss: 0.5669070548051364


Epoch 20/100: 100%|██████████| 1232/1232 [00:01<00:00, 717.44it/s]


Current Estimate R2 is 0.4595431391398941
Epoch 20/100, Loss: 0.5676598927465739


Epoch 21/100: 100%|██████████| 1232/1232 [00:01<00:00, 717.40it/s]


Current Estimate R2 is 0.4593902466861744
Epoch 21/100, Loss: 0.5674637471652263


Epoch 22/100: 100%|██████████| 1232/1232 [00:01<00:00, 723.83it/s]


Current Estimate R2 is 0.4578946481652451
Epoch 22/100, Loss: 0.567376295491666


Epoch 23/100: 100%|██████████| 1232/1232 [00:01<00:00, 688.18it/s]


Current Estimate R2 is 0.4604260327015246
Epoch 23/100, Loss: 0.5675034197410206


Epoch 24/100: 100%|██████████| 1232/1232 [00:01<00:00, 688.60it/s]


Current Estimate R2 is 0.45891366183141896
Epoch 24/100, Loss: 0.5674735096513064


Epoch 25/100: 100%|██████████| 1232/1232 [00:01<00:00, 685.38it/s]


Current Estimate R2 is 0.4501414496792984
Epoch 25/100, Loss: 0.5671542447737672


Epoch 26/100: 100%|██████████| 1232/1232 [00:01<00:00, 697.17it/s]


Current Estimate R2 is 0.4512805377554594
Epoch 26/100, Loss: 0.56699269456039


Epoch 27/100: 100%|██████████| 1232/1232 [00:01<00:00, 772.09it/s]


Current Estimate R2 is 0.45888138684503516
Epoch 27/100, Loss: 0.5671600945680947


Epoch 28/100: 100%|██████████| 1232/1232 [00:01<00:00, 771.16it/s]


Current Estimate R2 is 0.4587633585030635
Epoch 28/100, Loss: 0.5679633228180857


Epoch 29/100: 100%|██████████| 1232/1232 [00:01<00:00, 767.43it/s]


Current Estimate R2 is 0.4539220529567422
Epoch 29/100, Loss: 0.566972979138811


Epoch 30/100: 100%|██████████| 1232/1232 [00:01<00:00, 764.87it/s]


Current Estimate R2 is 0.460344444732413
Epoch 30/100, Loss: 0.5672716779706927


Epoch 31/100: 100%|██████████| 1232/1232 [00:01<00:00, 717.21it/s]


Current Estimate R2 is 0.45656895470392694
Epoch 31/100, Loss: 0.5672698228003142


Epoch 32/100: 100%|██████████| 1232/1232 [00:01<00:00, 721.48it/s]


Current Estimate R2 is 0.4604142383784182
Epoch 32/100, Loss: 0.5673439565852478


Epoch 33/100: 100%|██████████| 1232/1232 [00:01<00:00, 719.78it/s]


Current Estimate R2 is 0.4616436320469539
Epoch 33/100, Loss: 0.5674577076955662


Epoch 34/100: 100%|██████████| 1232/1232 [00:01<00:00, 778.48it/s]


Current Estimate R2 is 0.4594929049381646
Epoch 34/100, Loss: 0.5672567317673525


Epoch 35/100: 100%|██████████| 1232/1232 [00:01<00:00, 776.89it/s]


Current Estimate R2 is 0.45795660290717716
Epoch 35/100, Loss: 0.5668010718685079


Epoch 36/100: 100%|██████████| 1232/1232 [00:01<00:00, 773.20it/s]


Current Estimate R2 is 0.4559219671798697
Epoch 36/100, Loss: 0.5667121493100346


Epoch 37/100: 100%|██████████| 1232/1232 [00:01<00:00, 728.52it/s]


Current Estimate R2 is 0.46077810250006057
Epoch 37/100, Loss: 0.5673409917379741


Epoch 38/100: 100%|██████████| 1232/1232 [00:01<00:00, 669.32it/s]


Current Estimate R2 is 0.4587487089126297
Epoch 38/100, Loss: 0.5669780691916292


Epoch 39/100: 100%|██████████| 1232/1232 [00:01<00:00, 678.69it/s]


Current Estimate R2 is 0.45945508045578604
Epoch 39/100, Loss: 0.5660841277522313


Epoch 40/100: 100%|██████████| 1232/1232 [00:01<00:00, 693.02it/s]


Current Estimate R2 is 0.4603649291044014
Epoch 40/100, Loss: 0.5665869181129065


Epoch 41/100: 100%|██████████| 1232/1232 [00:01<00:00, 699.08it/s]


Current Estimate R2 is 0.4588104209479938
Epoch 41/100, Loss: 0.5663575561257539


Epoch 42/100: 100%|██████████| 1232/1232 [00:01<00:00, 716.00it/s]


Current Estimate R2 is 0.46116924011071764
Epoch 42/100, Loss: 0.566055090856049


Epoch 43/100: 100%|██████████| 1232/1232 [00:01<00:00, 672.37it/s]


Current Estimate R2 is 0.46261111881316097
Epoch 43/100, Loss: 0.5671633778793084


Epoch 44/100: 100%|██████████| 1232/1232 [00:01<00:00, 712.36it/s]


Current Estimate R2 is 0.45979182674574437
Epoch 44/100, Loss: 0.566393848823069


Epoch 45/100: 100%|██████████| 1232/1232 [00:01<00:00, 746.87it/s]


Current Estimate R2 is 0.45709085454791065
Epoch 45/100, Loss: 0.566182469054089


Epoch 46/100: 100%|██████████| 1232/1232 [00:01<00:00, 730.77it/s]


Current Estimate R2 is 0.4577097668528926
Epoch 46/100, Loss: 0.5662911887002455


Epoch 47/100: 100%|██████████| 1232/1232 [00:01<00:00, 731.43it/s]


Current Estimate R2 is 0.45694501789511793
Epoch 47/100, Loss: 0.5662868802568742


Epoch 48/100: 100%|██████████| 1232/1232 [00:01<00:00, 755.89it/s]


Current Estimate R2 is 0.4590533138014538
Epoch 48/100, Loss: 0.5665656610742792


Epoch 49/100: 100%|██████████| 1232/1232 [00:01<00:00, 749.68it/s]


Current Estimate R2 is 0.462419417998928
Epoch 49/100, Loss: 0.5665301458740776


Epoch 50/100: 100%|██████████| 1232/1232 [00:01<00:00, 759.85it/s]


Current Estimate R2 is 0.4597566030674831
Epoch 50/100, Loss: 0.5660328908593624


Epoch 51/100: 100%|██████████| 1232/1232 [00:01<00:00, 751.94it/s]


Current Estimate R2 is 0.457927203837959
Epoch 51/100, Loss: 0.5664637175376539


Epoch 52/100: 100%|██████████| 1232/1232 [00:01<00:00, 628.39it/s]


Current Estimate R2 is 0.45708260098863834
Epoch 52/100, Loss: 0.5658260076851039


Epoch 53/100: 100%|██████████| 1232/1232 [00:01<00:00, 623.99it/s]


Current Estimate R2 is 0.4544680686609531
Epoch 53/100, Loss: 0.5664853852290612


Epoch 54/100: 100%|██████████| 1232/1232 [00:01<00:00, 616.60it/s]


Current Estimate R2 is 0.4578281667713559
Epoch 54/100, Loss: 0.5663298418043883


Epoch 55/100: 100%|██████████| 1232/1232 [00:01<00:00, 698.50it/s]


Current Estimate R2 is 0.4585091331892493
Epoch 55/100, Loss: 0.5662547370491476


Epoch 56/100: 100%|██████████| 1232/1232 [00:01<00:00, 762.27it/s]


Current Estimate R2 is 0.4595823777110659
Epoch 56/100, Loss: 0.5661218412123717


Epoch 57/100: 100%|██████████| 1232/1232 [00:01<00:00, 747.94it/s]


Current Estimate R2 is 0.4629706549070273
Epoch 57/100, Loss: 0.566694685584539


Epoch 58/100: 100%|██████████| 1232/1232 [00:01<00:00, 752.28it/s]


Current Estimate R2 is 0.45792008444972443
Epoch 58/100, Loss: 0.5661455641512747


Epoch 59/100: 100%|██████████| 1232/1232 [00:01<00:00, 777.41it/s]


Current Estimate R2 is 0.4541681501337764
Epoch 59/100, Loss: 0.5663839315681102


Epoch 60/100: 100%|██████████| 1232/1232 [00:01<00:00, 789.55it/s]


Current Estimate R2 is 0.4590407803795436
Epoch 60/100, Loss: 0.5670513217518856


Epoch 61/100: 100%|██████████| 1232/1232 [00:01<00:00, 762.40it/s]


Current Estimate R2 is 0.4549737044412235
Epoch 61/100, Loss: 0.5658507364281973


Epoch 62/100: 100%|██████████| 1232/1232 [00:01<00:00, 718.41it/s]


Current Estimate R2 is 0.45901120057869454
Epoch 62/100, Loss: 0.5663973435928876


Epoch 63/100: 100%|██████████| 1232/1232 [00:01<00:00, 718.14it/s]


Current Estimate R2 is 0.4559106210263173
Epoch 63/100, Loss: 0.5659335024122681


Epoch 64/100: 100%|██████████| 1232/1232 [00:01<00:00, 733.43it/s]


Current Estimate R2 is 0.4549253641268873
Epoch 64/100, Loss: 0.5661821836339576


Epoch 65/100: 100%|██████████| 1232/1232 [00:01<00:00, 731.98it/s]


Current Estimate R2 is 0.4623944169681922
Epoch 65/100, Loss: 0.5662906506376995


Epoch 66/100: 100%|██████████| 1232/1232 [00:01<00:00, 721.47it/s]


Current Estimate R2 is 0.46231121376146217
Epoch 66/100, Loss: 0.565827016029265


Epoch 67/100: 100%|██████████| 1232/1232 [00:01<00:00, 705.02it/s]


Current Estimate R2 is 0.459239088014219
Epoch 67/100, Loss: 0.5662978110810766


Epoch 68/100: 100%|██████████| 1232/1232 [00:01<00:00, 718.97it/s]


Current Estimate R2 is 0.46178737926055674
Epoch 68/100, Loss: 0.5662386417630818


Epoch 69/100: 100%|██████████| 1232/1232 [00:01<00:00, 757.33it/s]


Current Estimate R2 is 0.45686345770062003
Epoch 69/100, Loss: 0.5661640584033418


Epoch 70/100: 100%|██████████| 1232/1232 [00:01<00:00, 810.09it/s]


Current Estimate R2 is 0.46045682348701233
Epoch 70/100, Loss: 0.5659127393858386


Epoch 71/100: 100%|██████████| 1232/1232 [00:01<00:00, 810.56it/s]


Current Estimate R2 is 0.4630651799557981
Saved model
Epoch 71/100, Loss: 0.5664965681393038


Epoch 72/100: 100%|██████████| 1232/1232 [00:01<00:00, 789.08it/s]


Current Estimate R2 is 0.45882980024765924
Epoch 72/100, Loss: 0.5662625008927924


Epoch 73/100: 100%|██████████| 1232/1232 [00:01<00:00, 705.53it/s]


Current Estimate R2 is 0.46020075671991184
Epoch 73/100, Loss: 0.5659719387722479


Epoch 74/100: 100%|██████████| 1232/1232 [00:01<00:00, 709.73it/s]


Current Estimate R2 is 0.4561704086129789
Epoch 74/100, Loss: 0.5656039392599812


Epoch 75/100: 100%|██████████| 1232/1232 [00:01<00:00, 721.10it/s]


Current Estimate R2 is 0.4610394758043881
Epoch 75/100, Loss: 0.5657486283256636


Epoch 76/100: 100%|██████████| 1232/1232 [00:01<00:00, 727.19it/s]


Current Estimate R2 is 0.4581775199007811
Epoch 76/100, Loss: 0.5658764708787203


Epoch 77/100: 100%|██████████| 1232/1232 [00:01<00:00, 708.14it/s]


Current Estimate R2 is 0.4548841386765316
Epoch 77/100, Loss: 0.5661538945777076


Epoch 78/100: 100%|██████████| 1232/1232 [00:01<00:00, 721.89it/s]


Current Estimate R2 is 0.45932702818518045
Epoch 78/100, Loss: 0.5662348033110429


Epoch 79/100: 100%|██████████| 1232/1232 [00:01<00:00, 735.28it/s]


Current Estimate R2 is 0.4580470333680892
Epoch 79/100, Loss: 0.5665599627347736


Epoch 80/100: 100%|██████████| 1232/1232 [00:01<00:00, 737.19it/s]


Current Estimate R2 is 0.46063927186918274
Epoch 80/100, Loss: 0.5652209769740895


Epoch 81/100: 100%|██████████| 1232/1232 [00:01<00:00, 711.48it/s]


Current Estimate R2 is 0.4589594696454764
Epoch 81/100, Loss: 0.5660606032938926


Epoch 82/100: 100%|██████████| 1232/1232 [00:01<00:00, 693.54it/s]


Current Estimate R2 is 0.46060944459357817
Epoch 82/100, Loss: 0.565895538777113


Epoch 83/100: 100%|██████████| 1232/1232 [00:01<00:00, 695.78it/s]


Current Estimate R2 is 0.4555124748989698
Epoch 83/100, Loss: 0.5657529679743888


Epoch 84/100: 100%|██████████| 1232/1232 [00:01<00:00, 723.68it/s]


Current Estimate R2 is 0.45885180369125567
Epoch 84/100, Loss: 0.5656173984435472


Epoch 85/100: 100%|██████████| 1232/1232 [00:01<00:00, 700.02it/s]


Current Estimate R2 is 0.46200375374362085
Epoch 85/100, Loss: 0.5657275777503654


Epoch 86/100: 100%|██████████| 1232/1232 [00:01<00:00, 699.10it/s]


Current Estimate R2 is 0.457868765722175
Epoch 86/100, Loss: 0.5659859093849535


Epoch 87/100: 100%|██████████| 1232/1232 [00:01<00:00, 694.31it/s]


Current Estimate R2 is 0.4620715706722312
Epoch 87/100, Loss: 0.5662061610779205


Epoch 88/100: 100%|██████████| 1232/1232 [00:01<00:00, 706.99it/s]


Current Estimate R2 is 0.4604754946589122
Epoch 88/100, Loss: 0.5658179781363382


Epoch 89/100: 100%|██████████| 1232/1232 [00:01<00:00, 693.41it/s]


Current Estimate R2 is 0.4570079070782274
Epoch 89/100, Loss: 0.5661337050887478


Epoch 90/100: 100%|██████████| 1232/1232 [00:01<00:00, 706.75it/s]


Current Estimate R2 is 0.45824739843025997
Epoch 90/100, Loss: 0.5663872589296722


Epoch 91/100: 100%|██████████| 1232/1232 [00:01<00:00, 689.52it/s]


Current Estimate R2 is 0.4621460646106131
Epoch 91/100, Loss: 0.5655152359372609


Epoch 92/100: 100%|██████████| 1232/1232 [00:01<00:00, 701.56it/s]


Current Estimate R2 is 0.46033951017853586
Epoch 92/100, Loss: 0.5657050329466145


Epoch 93/100: 100%|██████████| 1232/1232 [00:01<00:00, 733.35it/s]


Current Estimate R2 is 0.46371670597199793
Saved model
Epoch 93/100, Loss: 0.5657123055357438


Epoch 94/100: 100%|██████████| 1232/1232 [00:01<00:00, 698.73it/s]


Current Estimate R2 is 0.4581485636522949
Epoch 94/100, Loss: 0.5649944529808187


Epoch 95/100: 100%|██████████| 1232/1232 [00:01<00:00, 688.42it/s]


Current Estimate R2 is 0.4582085259847015
Epoch 95/100, Loss: 0.5665059067405663


Epoch 96/100: 100%|██████████| 1232/1232 [00:01<00:00, 645.08it/s]


Current Estimate R2 is 0.45672677476031326
Epoch 96/100, Loss: 0.5656529606694912


Epoch 97/100: 100%|██████████| 1232/1232 [00:01<00:00, 691.65it/s]


Current Estimate R2 is 0.4581048989444823
Epoch 97/100, Loss: 0.5662194475192915


Epoch 98/100: 100%|██████████| 1232/1232 [00:01<00:00, 767.48it/s]


Current Estimate R2 is 0.45788339328730765
Epoch 98/100, Loss: 0.5658043890272255


Epoch 99/100: 100%|██████████| 1232/1232 [00:01<00:00, 776.51it/s]


Current Estimate R2 is 0.4613438973172656
Epoch 99/100, Loss: 0.566162646179656


Epoch 100/100: 100%|██████████| 1232/1232 [00:01<00:00, 768.08it/s]

Current Estimate R2 is 0.45942402012931344
Epoch 100/100, Loss: 0.5658085811883211





In [8]:
# Predict the six targets for the test rows with `best_model` and write them,
# together with the row ids, to predictions.csv.
X_test_tensor = X_test_tensor.to(device)

# Make sure the network is in inference mode. Previously this relied on the
# last validation pass having left it in eval mode, which is not guaranteed
# if training was interrupted mid-epoch.
best_model.eval()
with torch.no_grad():
    test_predictions = best_model(X_test_tensor).cpu().numpy()

# Model outputs are standardised; map them back to original target units.
test_predictions = test_predictions * output_std + output_mean

predictions_df = pd.DataFrame(test_predictions, columns=['X4', 'X11', 'X18', 'X26', 'X50', 'X3112'])
predictions_df.insert(0, 'id', IDs_test)

# Save the predictions to a CSV file
predictions_df.to_csv('predictions.csv', index=False)