## Dependencies

In [1]:
import os
import json
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np


import torch
from torch.utils.data import DataLoader,random_split
from torchvision import transforms


import importlib

# Bind the project modules under explicit aliases. Below, `config` becomes a Config
# instance and `train` becomes the training function; aliasing keeps the module
# objects reachable for reloading and avoids shadowing them with same-named values.
import dataloader as dataloader_module
import config as config_module
import model as model_module
import train as train_module

# Re-execute the project modules so edits to their .py files take effect without
# restarting the kernel. The original order is kept; reloading `train` last is
# presumably needed if it imports from the other project modules — TODO confirm.
for _project_module in (dataloader_module, config_module, model_module, train_module):
    importlib.reload(_project_module)

from train import train, validate, inference
from dataloader import CustomDataloader, train_test_split
from config import Config
from model import ModelV1, initialize_weights

# Seed torch's global RNG early so stochastic steps that draw from it are reproducible
# across Restart & Run All. NOTE(review): presumably train_test_split and
# initialize_weights use torch's RNG — confirm in dataloader.py / model.py.
SEED = 42
torch.manual_seed(SEED)

config = Config()

## Prepare Dataset

In [2]:

dataset = CustomDataloader(config.data_dir, config.transform)
train_dataset, test_dataset = train_test_split(dataset, config.test_split)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=config.shuffle)
# Never shuffle the evaluation split: a fixed order keeps the validation batches
# (and their printed per-batch losses) reproducible from run to run.
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)


## Create Model

In [7]:
model_obj = ModelV1(freeze_backbone=True, unfreeze_from_layer='layer4')
if config.load_model:
    # map_location="cpu": a checkpoint saved from an accelerator (the training log
    # reports device "mps") would otherwise be restored onto that device and fail
    # where it is unavailable. load_state_dict copies into the model's own params.
    # weights_only=True: only tensors/containers are unpickled, not arbitrary objects.
    state_dict = torch.load(config.model_path, map_location="cpu", weights_only=True)
    model_obj.load_state_dict(state_dict)
else:
    # NOTE(review): freeze_backbone=True suggests a pretrained backbone. If
    # initialize_weights re-initializes every layer, the frozen backbone becomes
    # random features — consistent with the constant "Predicted HB: 12.05" output
    # and R2 ≈ 0 below. Confirm it only touches the trainable head.
    initialize_weights(model_obj)

## Train

In [8]:
# Fit the model on the training split. Training settings are read from `config`
# (the log below reports the device and the epoch count).
train(
    model=model_obj,
    dataloader=train_loader,
    config=config,
)

Training on Device: mps

[Train] Epoch [1/20], Batch [208/208], Loss: 1.4463  ------> Epoch [1/20] Completed. Avg Loss: 0.9892
[Train] Epoch [2/20], Batch [208/208], Loss: 0.4800  ------> Epoch [2/20] Completed. Avg Loss: 0.9892
[Train] Epoch [3/20], Batch [106/208], Loss: 0.9859

KeyboardInterrupt: 

## Validate Model

In [5]:
# Reuse the `config` instance the model was trained with (not a fresh Config()) so any
# in-notebook changes to it also apply to validation. Store the result under a real name
# rather than `_`, which IPython reserves for the last cell output.
val_results = validate(model=model_obj, dataloader=test_loader, config=config)

[Validation] Batch [23/23] - Batch Loss: 0.4656
[Validation] Completed. Avg Loss: 1.0976

[Validation Metrics]
MSE: 1.0976
RMSE: 1.0477
MAE: 0.8158
R2: -0.0016
MAPE (%): 99.9948


## Inference

In [6]:
N_INFERENCE_SAMPLES = 9  # spot-check sample folders 0001 .. 0009

for sample_idx in range(1, N_INFERENCE_SAMPLES + 1):
    # Sample folders are zero-padded 4-digit ids under the dataset root.
    # NOTE(review): assumes config.data_dir is the ".../HemoScan/Dataset" directory that
    # was previously hard-coded here (it is what CustomDataloader is built from) — confirm.
    # NOTE(review): these samples may belong to the training split, so this is not a
    # held-out check; dataset.hb_mean / hb_std presumably come from the full dataset.
    sample_dir = os.path.join(config.data_dir, f"{sample_idx:04d}")
    inference(model=model_obj, path=sample_dir, mean=dataset.hb_mean, std=dataset.hb_std)

Predicted HB: 12.05       | Actual HB: 13.20
Predicted HB: 12.05       | Actual HB: 13.00
Predicted HB: 12.05       | Actual HB: 13.20
Predicted HB: 12.05       | Actual HB: 11.60
Predicted HB: 12.05       | Actual HB: 11.30
Predicted HB: 12.05       | Actual HB: 10.70
Predicted HB: 12.05       | Actual HB: 12.20
Predicted HB: 12.05       | Actual HB: 12.30
Predicted HB: 12.05       | Actual HB: 12.80
