In [14]:
import torch
import os
import numpy as np
import pandas as pd
from warnings import simplefilter
from torch.utils.data import DataLoader

simplefilter(action='ignore', category=UserWarning)
simplefilter(action='ignore', category=FutureWarning)

In [15]:
MAIN_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))
DATA_DIR = os.path.join(MAIN_DIR, "Solid_droplet", "Data")

In [16]:
from DataLoader import *
from Trainer import *
from Model import *

In [17]:
dtype = torch.float32
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [18]:
# Split the data into train, validation and test set
dataspliter = DataSpliter(DATA_DIR, device, train_frac=0.6, val_frac=0.2, test_frac=0.2, use_seed=True, seed_val=42)

# Load Data from DataLoader
train_set = dataspliter.train
val_set = dataspliter.val
test_set = dataspliter.test

train_dataset = CustomImageDataset(DATA_DIR, train_set)
val_dataset = CustomImageDataset(DATA_DIR, val_set)
test_dataset = CustomImageDataset(DATA_DIR, test_set)

In [19]:
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)

In [20]:
# Get the model
model = CNNModel()
trainer = Trainer(model, train_dataloader, val_dataloader, test_dataloader)

The device that will be used in training is Quadro P1000


In [21]:
trainer.fit(epochs=10, batch_size= 8)
trainer.save_model("model.pt", "savefolderpytorch")

Epoch 1


Training: 576it [00:59,  9.67it/s, loss=45]                            
Validation: 192it [00:17, 10.76it/s, loss=9.17]                       

Validation loss is 0.46771186755763156
Epoch 2



Training: 576it [00:59,  9.67it/s, loss=19]                         
Validation: 192it [00:17, 10.79it/s, loss=9.1]                        

Validation loss is 0.3757900082402759
Epoch 3



Training: 576it [01:00,  9.51it/s, loss=31.4]                       
Validation: 192it [00:17, 10.80it/s, loss=21.9]                       

Validation loss is 0.5889465808868408
Epoch 4



Training: 576it [00:59,  9.62it/s, loss=6.43]                       
Validation: 192it [00:17, 10.80it/s, loss=3.74]                       

Validation loss is 0.17274966256486046
Epoch 5



Training: 576it [01:00,  9.58it/s, loss=3.81]                       
Validation: 192it [00:17, 10.87it/s, loss=5.29]                       

Validation loss is 0.19382429785198638
Epoch 6



Training: 576it [00:59,  9.67it/s, loss=2.46]                       
Validation: 192it [00:17, 10.86it/s, loss=0.928]                       

Validation loss is 0.054397722292277545
Epoch 7



Training: 576it [00:59,  9.68it/s, loss=5.74]                       
Validation: 192it [00:17, 10.89it/s, loss=7.2]                        

Validation loss is 0.2611794231666459
Epoch 8



Training: 576it [00:59,  9.70it/s, loss=12.1]                        
Validation: 192it [00:17, 10.86it/s, loss=9.65]                       

Validation loss is 0.7283290442493228
Epoch 9



Training: 576it [00:59,  9.72it/s, loss=0.955]                      
Validation: 192it [00:17, 10.85it/s, loss=7.01]                       

Validation loss is 0.321987868183189
Epoch 10



Training: 576it [00:59,  9.67it/s, loss=8.49]                        
Validation: 192it [00:17, 10.87it/s, loss=6.75]                       

Validation loss is 0.25695282303624684





In [34]:
_, output, labels = trainer.val_epoch(test_dataloader)

Validation: 192it [00:17, 10.71it/s, loss=5.97]                       

Validation loss is 0.23512407640616098





In [35]:
dfresults = pd.DataFrame(np.append(output, labels, axis=1),columns=["sigmapred", "volumepred", "sigmatrue", "volumetrue"])
dfresults["RMSE"] = np.sqrt((dfresults["sigmapred"]-dfresults["sigmatrue"])**2+(dfresults["volumepred"]-dfresults["volumetrue"])**2)
dfresults

Unnamed: 0,sigmapred,volumepred,sigmatrue,volumetrue,RMSE
0,73.542747,34.973106,72.5,32.0,3.150664
1,69.208298,33.956120,69.0,32.0,1.967179
2,75.364967,34.689098,74.0,32.0,3.015690
3,71.678352,34.499477,72.5,32.0,2.631063
4,74.059593,35.681126,72.5,32.0,3.997877
...,...,...,...,...,...
187,69.740074,34.929283,68.5,32.0,3.180956
188,75.782646,35.826004,74.0,32.0,4.220916
189,72.793373,35.659393,71.5,32.0,3.881233
190,74.548027,33.973171,74.5,32.0,1.973756
