In [1]:
import torch
import os
import numpy as np
import pandas as pd
from warnings import simplefilter
from torch.utils.data import DataLoader

simplefilter(action='ignore', category=UserWarning)
simplefilter(action='ignore', category=FutureWarning)

In [2]:
MAIN_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))
DATA_DIR = os.path.join(MAIN_DIR, "Solid_droplet", "Data")

In [3]:
from DataLoader import *
from Trainer import *
from Model import *

In [4]:
dtype = torch.float32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Split the data into train, validation and test set
dataspliter = DataSpliter(DATA_DIR, device, train_frac=0.6, val_frac=0.2, test_frac=0.2, use_seed=True, seed_val=42)

# Load Data from DataLoader
train_set = dataspliter.train
val_set = dataspliter.val
test_set = dataspliter.test

train_dataset = CustomImageDataset(DATA_DIR, train_set)
val_dataset = CustomImageDataset(DATA_DIR, val_set)
test_dataset = CustomImageDataset(DATA_DIR, test_set)

In [6]:
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)

In [7]:
# Get the model
model = CNNModel()
trainer = Trainer(model, train_dataloader, val_dataloader, test_dataloader)

The device that will be used in training is Quadro P1000


In [8]:
trainer.fit(epochs=10, batch_size= 8)
trainer.save_model("model.pt", "savefolderpytorch")

Epoch 1


Training: 24it [00:05,  4.33it/s, loss=672]                  
Validation: 8it [00:01,  7.90it/s, loss=932]               

Validation loss is 931.510986328125
Epoch 2



Training: 24it [00:03,  7.99it/s, loss=75.6]             
Validation: 8it [00:01,  7.19it/s, loss=67.6]               

Validation loss is 67.64094543457031
Epoch 3



Training: 24it [00:03,  7.84it/s, loss=180]               
Validation: 8it [00:01,  7.94it/s, loss=60.9]               

Validation loss is 60.8654899597168
Epoch 4



Training: 24it [00:03,  7.56it/s, loss=68.8]              
Validation: 8it [00:00,  8.03it/s, loss=53]               

Validation loss is 53.010345458984375
Epoch 5



Training: 24it [00:03,  7.63it/s, loss=48]                
Validation: 8it [00:01,  7.80it/s, loss=86.1]               

Validation loss is 86.098388671875
Epoch 6



Training: 24it [00:03,  7.78it/s, loss=19.9]            
Validation: 8it [00:01,  7.32it/s, loss=36.5]               

Validation loss is 36.51206970214844
Epoch 7



Training: 24it [00:03,  7.98it/s, loss=53.6]              
Validation: 8it [00:01,  7.58it/s, loss=40]               

Validation loss is 40.02277755737305
Epoch 8



Training: 24it [00:03,  7.69it/s, loss=32.9]              
Validation: 8it [00:01,  7.88it/s, loss=22.3]               

Validation loss is 22.254901885986328
Epoch 9



Training: 24it [00:03,  7.07it/s, loss=63.3]              
Validation: 8it [00:01,  7.15it/s, loss=27.4]               

Validation loss is 27.360485076904297
Epoch 10



Training: 24it [00:03,  7.18it/s, loss=24.7]              
Validation: 8it [00:01,  7.45it/s, loss=25.2]               

Validation loss is 25.226699829101562





In [9]:
_, output, labels = trainer.val_epoch(test_dataloader)

Validation: 9it [00:01,  6.89it/s, loss=18.5]               

Validation loss is 9.581824779510498





In [12]:
dfresults = pd.DataFrame(np.append(output, labels, axis=1),columns=["sigmapred", "volumepred", "radiuspred", "sigmatrue", "volumetrue", "radiustrue"])
dfresults["RMSE"] = np.sqrt((dfresults["sigmapred"]-dfresults["sigmatrue"])**2+(dfresults["volumepred"]-dfresults["volumetrue"])**2+(dfresults["radiuspred"]-dfresults["radiustrue"])**2)
dfresults

Unnamed: 0,sigmapred,volumepred,radiuspred,sigmatrue,volumetrue,radiustrue,RMSE
0,70.641296,25.195654,2.235899,71.0,32.0,1.0,6.924973
1,70.733902,25.114252,2.263872,71.75,32.0,1.0,7.074133
2,70.669304,25.162207,2.233022,71.25,32.0,1.0,6.9723
3,71.367516,25.170641,2.695787,79.75,32.0,1.0,10.944491
4,70.915543,25.118048,2.323453,73.25,32.0,1.0,7.386642
5,70.701576,25.144682,2.246611,71.5,32.0,1.0,7.013338
6,71.079597,25.119123,2.396783,74.0,32.0,1.0,7.604355
7,70.57148,25.215719,2.206904,70.25,32.0,1.0,6.898292
8,70.959908,25.12689,2.358584,73.5,32.0,1.0,7.452346
