## Dependencies

In [1]:
import os
import json
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np


import torch
from torch.utils.data import DataLoader,random_split
from torchvision import transforms


import importlib
import dataloader
import config
import model
import train


importlib.reload(dataloader)
importlib.reload(config)
importlib.reload(model)
importlib.reload(train)

from train import train,validate,inference
from dataloader import create_dataloaders
from config import Config
from model import ModelV1,initialize_weights


config = Config()

## Prepare Dataset

In [2]:
train_loader, val_loader, hb_mean, hb_std = create_dataloaders(
    data_dir=config.data_dir,
    batch_size=config.batch_size,
    test_split=config.test_split,
)

## Create Model

In [3]:
model_obj = ModelV1(freeze_backbone=True, unfreeze_from_layer='layer4')
# if config.load_model:
#     model_obj.load_state_dict(torch.load(config.model_path))
# else:
#     initialize_weights(model_obj)

## Train

In [5]:
train(model = model_obj,dataloader = train_loader,val_loader = val_loader, config = config)

Training on Device: mps

[Validation] Batch [23/23] - Batch Loss: 0.66137303  ------> Epoch [1/20] Completed. Avg Loss: 0.9874   ------> Validation Loss: 1.2052, R2: -0.1015, MAPE: 151.5483%
[Validation] Batch [23/23] - Batch Loss: 0.21993788  ------> Epoch [2/20] Completed. Avg Loss: 0.9548   ------> Validation Loss: 1.0337, R2: 0.0553, MAPE: 127.1283%
[Validation] Batch [23/23] - Batch Loss: 0.19671712  ------> Epoch [3/20] Completed. Avg Loss: 0.9286   ------> Validation Loss: 1.0008, R2: 0.0854, MAPE: 124.9063%
[Validation] Batch [23/23] - Batch Loss: 0.39188112  ------> Epoch [4/20] Completed. Avg Loss: 0.8783   ------> Validation Loss: 1.0708, R2: 0.0214, MAPE: 138.5790%
[Validation] Batch [23/23] - Batch Loss: 0.15390042  ------> Epoch [5/20] Completed. Avg Loss: 0.8717   ------> Validation Loss: 1.0284, R2: 0.0601, MAPE: 131.2871%
[Validation] Batch [23/23] - Batch Loss: 0.23995470  ------> Epoch [6/20] Completed. Avg Loss: 0.7981   ------> Validation Loss: 1.0035, R2: 0.0829, 

## Validate Model

In [7]:
_ = validate(model = model_obj, dataloader = train_loader, config = Config())

[Validation] Batch [208/208] - Batch Loss: 0.0693
[Validation] Completed. Avg Loss: 0.2804

[Validation Metrics]
MSE: 0.2804
RMSE: 0.5296
MAE: 0.3905
R2: 0.7151
MAPE (%): 127.6786


## inference

In [8]:
for i in range(1,10):
    paths = f"/Users/maheshsaravanan/Documents/HemoScan/Dataset/{i:04d}"
    inference(model = model_obj, path = paths, mean = hb_mean, std = hb_std)

Predicted HB: 12.12       | Actual HB: 13.20
Predicted HB: 12.09       | Actual HB: 13.00
Predicted HB: 12.60       | Actual HB: 13.20
Predicted HB: 11.67       | Actual HB: 11.60
Predicted HB: 11.42       | Actual HB: 11.30
Predicted HB: 10.51       | Actual HB: 10.70
Predicted HB: 11.70       | Actual HB: 12.20
Predicted HB: 11.67       | Actual HB: 12.30
Predicted HB: 11.59       | Actual HB: 12.80


In [4]:
def print_module_params(model):
    print(f"{'Module':<50} {'Trainable':<10} {'Params':>10}")
    print("-" * 80)
    total_params = 0
    trainable_params = 0
    
    for name, param in model.named_parameters():
        param_count = param.numel()
        total_params += param_count
        if param.requires_grad:
            trainable_params += param_count
        print(f"{name:<50} {str(param.requires_grad):<10} {param_count:>10,}")

    print("-" * 80)
    print(f"{'Total':<50} {'-':<10} {total_params:,}")
    print(f"{'Trainable':<50} {'-':<10} {trainable_params:,}")
print_module_params(model_obj)

Module                                             Trainable      Params
--------------------------------------------------------------------------------
resnet.conv1.weight                                False           9,408
resnet.bn1.weight                                  False              64
resnet.bn1.bias                                    False              64
resnet.layer1.0.conv1.weight                       False          36,864
resnet.layer1.0.bn1.weight                         False              64
resnet.layer1.0.bn1.bias                           False              64
resnet.layer1.0.conv2.weight                       False          36,864
resnet.layer1.0.bn2.weight                         False              64
resnet.layer1.0.bn2.bias                           False              64
resnet.layer1.1.conv1.weight                       False          36,864
resnet.layer1.1.bn1.weight                         False              64
resnet.layer1.1.bn1.bias                   