**1. Import Libraries and Define Functions**

In [None]:
# basic
import pandas as pd

# PyTorch
import torch
import torchmetrics
from torch.utils.data import Dataset, DataLoader

# Pyro: pip install pyro-ppl==1.8.4
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

# others
import os
import albumentations
from albumentations.pytorch.transforms import ToTensorV2
import pathlib
from typing import Tuple, Dict, List

# data augmentation
def get_transforms(image_size):
    """Create the albumentations pipelines for training and validation images.

    Args:
        image_size: target height and width (pixels) passed to ``Resize``;
            it also sets the cutout hole size (37.5 % of ``image_size``).

    Returns:
        ``(transforms_train, transforms_val)``. Both pipelines finish with
        ``Normalize`` and ``ToTensorV2``. The training pipeline first applies
        random transposes/flips, brightness and contrast jitter, one of
        blur/noise, one geometric distortion, CLAHE, hue/saturation/value
        shifts and a small shift-scale-rotate, then resizes and applies one
        cutout hole. The validation pipeline only resizes.

    NOTE(review): ``RandomBrightness``, ``RandomContrast`` and ``Cutout`` may
    not exist in newer albumentations releases — confirm the pinned version.
    """
    A = albumentations
    hole = int(image_size * 0.375)

    # one blur-or-noise transform, block applied with p=0.7
    blur_or_noise = A.OneOf(
        [
            A.MotionBlur(blur_limit=5),
            A.MedianBlur(blur_limit=5),
            A.GaussianBlur(blur_limit=5),
            A.GaussNoise(var_limit=(5.0, 30.0)),
        ],
        p=0.7,
    )
    # one geometric distortion, block applied with p=0.7
    distortion = A.OneOf(
        [
            A.OpticalDistortion(distort_limit=1.0),
            A.GridDistortion(num_steps=5, distort_limit=1.0),
            A.ElasticTransform(alpha=3),
        ],
        p=0.7,
    )

    train_pipeline = A.Compose([
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightness(limit=0.2, p=0.75),
        A.RandomContrast(limit=0.2, p=0.75),
        blur_or_noise,
        distortion,
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
        A.Resize(image_size, image_size),
        A.Cutout(max_h_size=hole, max_w_size=hole, num_holes=1, p=0.7),
        A.Normalize(),
        ToTensorV2(),
    ])
    val_pipeline = A.Compose([
        A.Resize(image_size, image_size),
        A.Normalize(),
        ToTensorV2(),
    ])
    return train_pipeline, val_pipeline

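# Dataset modelled on 'torchvision.datasets.ImageFolder()': labels come from the class sub-folders, but each sample is the image's row from a metadata CSV (images are not loaded, so no albumentations augmentation is applied here).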

# Discover the class folders (and their integer labels) inside a dataset split.
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Return the class names found under *directory* and their integer indices.

    Every immediate sub-directory of *directory* is treated as one class
    (standard ``root/<class_name>/<file>`` image-classification layout).
    Class names are sorted alphabetically and numbered from 0 in that order.

    Args:
        directory: folder whose sub-directories are the classes.

    Returns:
        ``(class_names, class_to_idx)`` — e.g. for sub-folders ``b`` and ``a``
        the result is ``(["a", "b"], {"a": 0, "b": 1})``.

    Raises:
        FileNotFoundError: if *directory* contains no sub-directories.
    """
    with os.scandir(directory) as entries:
        class_names = sorted(entry.name for entry in entries if entry.is_dir())

    if len(class_names) == 0:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    # numeric labels, numbered in the same order as the sorted names
    return class_names, dict(zip(class_names, range(len(class_names))))

# Dataset of per-image metadata rows, labelled by the class folder each image lives in.
class CustomizedTextFolder(Dataset):
    """Dataset yielding ``(metadata, label)`` pairs for images stored in class folders.

    Images are discovered as ``img_dir/<class_name>/*.jpg``. The class folder
    name gives the label (through ``find_classes``). The file name without its
    extension is looked up in the ``img_id`` index of the metadata CSV. The
    image pixels themselves are never loaded.

    Args:
        img_dir: split directory containing one sub-folder per class.
        metadata_file: path of a CSV with an ``img_id`` column; all other
            columns of the matching row are returned as the feature vector.
    """

    def __init__(self, img_dir: str, metadata_file) -> None:
        # Sorted so that sample order (and therefore ds[i]) does not depend on
        # the file system's directory listing order. Only .jpg files are used.
        self.paths = sorted(pathlib.Path(img_dir).glob("*/*.jpg"))
        self.metadata_file = metadata_file
        self.classes, self.class_to_idx = find_classes(img_dir)
        # Parsed on first access and then reused, instead of re-reading the
        # whole CSV for every single sample.
        self._metadata = None

    def __len__(self) -> int:
        """Return the number of images found."""
        return len(self.paths)

    def _metadata_table(self) -> pd.DataFrame:
        """Return the metadata CSV as a DataFrame indexed by ``img_id`` (cached)."""
        if self._metadata is None:
            self._metadata = pd.read_csv(self.metadata_file, index_col='img_id')  # or other header names
        return self._metadata

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return ``(metadata, label)``, i.e. ``(fv, y)``, for the ``index``-th image.

        ``metadata`` is a float tensor built from the image's CSV row (every
        column except ``img_id``); ``label`` is the integer index of the
        image's class folder.
        """
        path = self.paths[index]  # expects data_folder/class_name/image.jpg
        label = self.class_to_idx[path.parent.name]
        img_id = path.stem  # file name without the .jpg extension
        metadata = torch.Tensor(self._metadata_table().loc[img_id].values)
        return metadata, label

**2. Load and Pre-Process the Data**

In [None]:
num_classes = 4

# Stratified 5-fold cross-validation: choose the fold whose split folders and metadata to load.
fold = '_Fold1'
fold_root = f'/.../4Class{fold}'
img_train_dir = f'{fold_root}/train/'
img_val_dir = f'{fold_root}/val/'
img_test_dir = f'{fold_root}/test/'
metadata_file = f'/.../metadata/metadata{fold}.csv'

# All three splits read their feature rows from the same per-fold metadata CSV.
ds_train, ds_val, ds_test = (
    CustomizedTextFolder(split_dir, metadata_file)
    for split_dir in (img_train_dir, img_val_dir, img_test_dir)
)

# Only the training loader shuffles; validation/test keep dataset order.
batch_size = 64
dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
dl_val = DataLoader(ds_val, batch_size=batch_size, shuffle=False)
dl_test = DataLoader(ds_test, batch_size=batch_size, shuffle=False)

# Sanity check (uncomment to run):
# print('Examine Numerical Labels: ', ds_train.class_to_idx)
# for embeddings, labels in dl_train:
#     # embeddings: [batch_size, number of metadata columns]; labels: [batch_size]
#     print('Examine Batched Data Shapes: ', embeddings.shape, labels.shape)
#     break

**3. Define the Model and Metrics**

In [None]:
# Initial conditional probability tables (CPTs), every row a uniform distribution.
# First layer: no table — the softmax outputs of the CNN (C0-C4) are used
# directly as the Bernoulli parameters inside model().
def _uniform_cpt(num_rows, num_states):
    """Return a (num_rows, num_states) table whose rows are uniform distributions."""
    table = torch.ones(num_rows, num_states)
    return table / table.sum(dim=1, keepdim=True)


# Second layer: one 4-class diagnosis distribution per combination of the five
# binary concepts (2**5 = 32 rows).
diagnosis_probs_param = _uniform_cpt(32, 4)
# Third layer: one distribution per diagnosis (4 rows) for each metadata field.
itch_probs_param = _uniform_cpt(4, 2)
grew_probs_param = _uniform_cpt(4, 2)
hurt_probs_param = _uniform_cpt(4, 2)
changed_probs_param = _uniform_cpt(4, 2)
bleed_probs_param = _uniform_cpt(4, 2)
elevation_probs_param = _uniform_cpt(4, 2)
site_probs_param = _uniform_cpt(4, 14)
diameter_probs_param = _uniform_cpt(4, 8)
age_probs_param = _uniform_cpt(4, 10)

@config_enumerate
def model(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs, 
          C0_obs, C1_obs, C2_obs, C3_obs, C4_obs, 
          diagnosis_obs=None):
    """Batched three-layer Bayesian network used for training.

    Layer 1: five binary latent concepts C0-C4 whose Bernoulli probabilities
    are the CNN outputs passed in as ``C*_obs``. Layer 2: the diagnosis, drawn
    from the row of ``diagnosis_probs`` selected by the concept bit pattern.
    Layer 3: nine observed metadata fields, each drawn from the row of its
    table selected by the diagnosis.

    All arguments are expected to be 1-D tensors of equal length N (the
    training loop passes one metadata column per argument); samples are
    conditionally independent inside the ``'data'`` plate. ``diagnosis_obs``
    is given during training and left as ``None`` for prediction, in which
    case the diagnosis is a latent, enumerated site.

    Returns:
        The value at the ``'diagnosis'`` site.
    """
    # Parameters live on the same device as the observations. Their initial
    # values come from the module-level *_probs_param tensors; the callable
    # init is only evaluated the first time a name enters the param store,
    # so the host->device copy no longer happens on every call.
    device = C0_obs.device

    def cpt(name, init):
        return pyro.param(name, lambda: init.to(device), constraint=constraints.simplex)

    # first layer: no parameters — the CNN probabilities are used directly
    # second layer
    diagnosis_probs = cpt('diagnosis_probs', diagnosis_probs_param)
    # third layer
    itch_probs = cpt('itch_probs', itch_probs_param)
    grew_probs = cpt('grew_probs', grew_probs_param)
    hurt_probs = cpt('hurt_probs', hurt_probs_param)
    changed_probs = cpt('changed_probs', changed_probs_param)
    bleed_probs = cpt('bleed_probs', bleed_probs_param)
    elevation_probs = cpt('elevation_probs', elevation_probs_param)
    site_probs = cpt('site_probs', site_probs_param)
    diameter_probs = cpt('diameter_probs', diameter_probs_param)
    age_probs = cpt('age_probs', age_probs_param)

    with pyro.plate('data', len(itch_obs)):
        # first layer
        C0 = pyro.sample('Erythema', dist.Bernoulli(probs=C0_obs))
        C1 = pyro.sample('Brown', dist.Bernoulli(probs=C1_obs))
        C2 = pyro.sample('Crust', dist.Bernoulli(probs=C2_obs))
        C3 = pyro.sample('Telangiectasia', dist.Bernoulli(probs=C3_obs))
        # NOTE(review): the evaluation-time model names this site 'Nodule';
        # confirm which concept C4 really is.
        C4 = pyro.sample('Black', dist.Bernoulli(probs=C4_obs))
        # second layer: concept bits -> CPT row 0..31 (C0 is the lowest bit)
        concept_row = (C4*16 + C3*8 + C2*4 + C1*2 + C0).long()
        diagnosis = pyro.sample('diagnosis', dist.Categorical(probs=diagnosis_probs[concept_row]), obs=diagnosis_obs)
        # third layer: every metadata field is conditioned on the diagnosis only
        diagnosis_row = diagnosis.long()
        pyro.sample('itch', dist.Categorical(probs=itch_probs[diagnosis_row]), obs=itch_obs)
        pyro.sample('grew', dist.Categorical(probs=grew_probs[diagnosis_row]), obs=grew_obs)
        pyro.sample('hurt', dist.Categorical(probs=hurt_probs[diagnosis_row]), obs=hurt_obs)
        pyro.sample('changed', dist.Categorical(probs=changed_probs[diagnosis_row]), obs=changed_obs)
        pyro.sample('bleed', dist.Categorical(probs=bleed_probs[diagnosis_row]), obs=bleed_obs)
        pyro.sample('elevation', dist.Categorical(probs=elevation_probs[diagnosis_row]), obs=elevation_obs)
        pyro.sample('site', dist.Categorical(probs=site_probs[diagnosis_row]), obs=site_obs)
        pyro.sample('diameter', dist.Categorical(probs=diameter_probs[diagnosis_row]), obs=diameter_obs)
        pyro.sample('age', dist.Categorical(probs=age_probs[diagnosis_row]), obs=age_obs)
        return diagnosis

def guide(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs, 
          C0_obs, C1_obs, C2_obs, C3_obs, C4_obs, 
          diagnosis_obs=None):
    """Empty variational guide: samples nothing and returns ``None``.

    Every latent site of ``model`` is discrete and marked for enumeration by
    ``@config_enumerate``, so ``TraceEnum_ELBO`` sums those sites out and the
    guide needs no sites of its own. It still accepts the same arguments as
    ``model`` because SVI and ``compute_marginals`` call both with them.
    """
    return None

# itch_obs_test = torch.ones(10).cuda()
# grew_obs_test = torch.ones(10).cuda()
# hurt_obs_test = torch.ones(10).cuda()
# changed_obs_test = torch.ones(10).cuda()
# bleed_obs_test = torch.ones(10).cuda()
# elevation_obs_test = torch.ones(10).cuda()
# diagnosis_obs_test = torch.ones(10).cuda()
# site_obs_test = torch.ones(10).cuda()
# diameter_obs_test = torch.ones(10).cuda()
# age_obs_test = torch.ones(10).cuda()
# C0_obs_test = torch.ones(10).cuda()
# C1_obs_test = torch.ones(10).cuda()
# C2_obs_test = torch.ones(10).cuda()
# C3_obs_test = torch.ones(10).cuda()
# C4_obs_test = torch.ones(10).cuda()
# pyro.render_model(model=model, model_args=(itch_obs_test, grew_obs_test, hurt_obs_test, changed_obs_test, bleed_obs_test, elevation_obs_test, site_obs_test, diameter_obs_test, age_obs_test, 
#                                            C0_obs_test, C1_obs_test, C2_obs_test, C3_obs_test, C4_obs_test,  
#                                            diagnosis_obs_test), render_distributions=True, render_params=False)

**4. Train and Save the Model**

In [None]:
# SVI with an empty guide: TraceEnum_ELBO handles the enumerated discrete
# sites of model(), so only the pyro.param tables are optimised.
optimizer = pyro.optim.Adam({'lr': 2e-3})
svi = SVI(model=model, guide=guide, optim=optimizer, loss=TraceEnum_ELBO())
num_iterations = 200  # training epochs (full passes over dl_train)
# Balanced accuracy: macro-averaged per-class accuracy over the 4 diagnoses.
BACC = torchmetrics.Accuracy(multiclass=True, num_classes=4, average='macro').cuda()

def predict(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs, 
            C0_obs, C1_obs, C2_obs, C3_obs, C4_obs):
    """Infer the diagnosis posterior for a batch of samples.

    Runs ``TraceEnum_ELBO.compute_marginals`` on ``model`` with the diagnosis
    left unobserved.

    Args:
        The same per-sample tensors as ``model`` (expected 1-D, length N),
        without ``diagnosis_obs``.

    Returns:
        ``(diagnosis_preds, diagnosis_probs)``: the arg-max class per sample,
        shape (N,), and the posterior probability of each of the 4 classes,
        shape (N, 4).
    """
    conditional_marginals = TraceEnum_ELBO().compute_marginals(model, guide, itch_obs=itch_obs, grew_obs=grew_obs, hurt_obs=hurt_obs, changed_obs=changed_obs, bleed_obs=bleed_obs, elevation_obs=elevation_obs, site_obs=site_obs, diameter_obs=diameter_obs, age_obs=age_obs, 
                                                               C0_obs=C0_obs, C1_obs=C1_obs, C2_obs=C2_obs, C3_obs=C3_obs, C4_obs=C4_obs)
    diagnosis_marginal = conditional_marginals['diagnosis']
    # P(diagnosis = k | evidence) for each class k, evaluated on the inputs' device.
    device = itch_obs.device
    n = len(itch_obs)
    per_class = [
        diagnosis_marginal.log_prob(torch.tensor(k, device=device)).exp().reshape(1, n)
        for k in range(4)
    ]
    diagnosis_probs = torch.cat(per_class, dim=0).T
    diagnosis_preds = torch.argmax(diagnosis_probs, dim=1)
    return diagnosis_preds, diagnosis_probs

# Start from an empty param store so that re-running this cell (or running it
# after the evaluation cell loaded a checkpoint) does not silently resume from
# stale parameters: pyro.param() ignores its init value for existing names.
pyro.clear_param_store()

# Order in which the learned tables are saved; the loading cell reads
# best_params by these positions (diagnosis_probs is index 6).
saved_param_names = ['itch_probs', 'grew_probs', 'hurt_probs', 'changed_probs', 'bleed_probs',
                     'elevation_probs', 'diagnosis_probs', 'site_probs', 'diameter_probs', 'age_probs']

print('START TRAINING')
best_bacc = 0
for i in range(1, num_iterations+1):
    # CALCULATE TRAIN LOSS AND ACCURACY
    loss = 0
    # Batches are collected in lists and concatenated once per epoch; a
    # torch.cat inside the loop would re-copy the growing tensor each step.
    batch_preds, batch_labels = [], []
    for embeddings, labels in dl_train:
        embeddings = embeddings.cuda()
        labels = labels.cuda()
        # Columns 0-13 are, in order: itch, grew, hurt, changed, bleed,
        # elevation, site, diameter, age, C0, C1, C2, C3, C4 — exactly the
        # positional parameter order of model() and predict().
        bn_inputs = embeddings[:, :14].unbind(dim=1)
        # one gradient step; the labels are passed as diagnosis_obs
        loss += svi.step(*bn_inputs, labels)
        # train accuracy with the just-updated parameters (no gradients needed)
        with torch.no_grad():
            preds, _ = predict(*bn_inputs)
        batch_preds.append(preds)
        batch_labels.append(labels)
    bacc = BACC(torch.cat(batch_preds).long(), torch.cat(batch_labels).long())
    epoch_loss = loss/len(dl_train)  # mean SVI loss per batch
    print('\nEpoch ', i, '\nTrain Loss ', epoch_loss)
    print('Train BACC: '+str((100*bacc).item())+' %')

    # CALCULATE VALIDATION ACCURACY
    batch_preds, batch_labels = [], []
    with torch.no_grad():
        for embeddings, labels in dl_val:
            embeddings = embeddings.cuda()
            labels = labels.cuda()
            preds, _ = predict(*embeddings[:, :14].unbind(dim=1))
            batch_preds.append(preds)
            batch_labels.append(labels)
    bacc = BACC(torch.cat(batch_preds).long(), torch.cat(batch_labels).long())
    print('Validation BACC: '+str((100*bacc).item())+' %')

    # SAVE THE PARAMETERS OF THE BEST MODEL (a tie with the best so far also saves)
    if bacc >= best_bacc:
        best_bacc = bacc
        # detached copies, so the checkpoint holds plain tensors
        best_params = [pyro.param(name).detach().clone() for name in saved_param_names]
        # NOTE(review): the evaluation cell loads from '/.../checkpoints/';
        # confirm the file is moved there or align the two paths.
        torch.save(best_params, '4ClassBN'+str(fold)+'.pt')
        print('<<<<< Reached Best BACC: '+str((100*bacc).item())+' % >>>>>')

print('\nEND TRAINING')

**5. Load and Evaluate the Model**

In [None]:
# The param store must be emptied first: pyro.param() only uses its init value
# when the name is not stored yet, so leftover trained values would otherwise
# shadow the checkpoint loaded below.
pyro.clear_param_store()
# NOTE(review): training saves '4ClassBN<fold>.pt' in the working directory;
# this path assumes the file was moved into /.../checkpoints/.
best_params = torch.load(f'/.../checkpoints/4ClassBN{fold}.pt')

# Positions follow the save order used during training:
# 0 itch, 1 grew, 2 hurt, 3 changed, 4 bleed, 5 elevation,
# 6 diagnosis, 7 site, 8 diameter, 9 age.
# (The first layer has no table: the CNN softmax outputs are used directly.)
itch_probs_param = best_params[0]
grew_probs_param = best_params[1]
hurt_probs_param = best_params[2]
changed_probs_param = best_params[3]
bleed_probs_param = best_params[4]
elevation_probs_param = best_params[5]
diagnosis_probs_param = best_params[6]
site_probs_param = best_params[7]
diameter_probs_param = best_params[8]
age_probs_param = best_params[9]

@config_enumerate
def model(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs, 
          C0_obs, C1_obs, C2_obs, C3_obs, C4_obs, 
          diagnosis_obs=None):
    """Single-sample Bayesian network used for evaluation (no data plate).

    The structure and parameter names are the same as in the training model:
    - five binary concepts C0-C4 with CNN-given Bernoulli probabilities;
    - a diagnosis whose table row is chosen by the concept bit pattern;
    - nine metadata fields, each conditioned on the diagnosis.

    There is no ``pyro.plate``, so every argument should describe a single
    sample (the evaluation loop passes 0-d tensors). ``diagnosis_obs`` is left
    as ``None`` for prediction.

    Returns:
        The value at the ``'diagnosis'`` site.
    """
    # Parameters follow the observations' device. The callable init (the
    # loaded checkpoint tensors) is only evaluated when a name first enters
    # the param store, avoiding a host->device copy on every call.
    device = C0_obs.device

    def cpt(name, init):
        return pyro.param(name, lambda: init.to(device), constraint=constraints.simplex)

    # first layer: no parameters — the CNN probabilities are used directly
    # second layer
    diagnosis_probs = cpt('diagnosis_probs', diagnosis_probs_param)
    # third layer
    itch_probs = cpt('itch_probs', itch_probs_param)
    grew_probs = cpt('grew_probs', grew_probs_param)
    hurt_probs = cpt('hurt_probs', hurt_probs_param)
    changed_probs = cpt('changed_probs', changed_probs_param)
    bleed_probs = cpt('bleed_probs', bleed_probs_param)
    elevation_probs = cpt('elevation_probs', elevation_probs_param)
    site_probs = cpt('site_probs', site_probs_param)
    diameter_probs = cpt('diameter_probs', diameter_probs_param)
    age_probs = cpt('age_probs', age_probs_param)

    # first layer
    C0 = pyro.sample('Erythema', dist.Bernoulli(probs=C0_obs))
    C1 = pyro.sample('Brown', dist.Bernoulli(probs=C1_obs))
    C2 = pyro.sample('Crust', dist.Bernoulli(probs=C2_obs))
    C3 = pyro.sample('Telangiectasia', dist.Bernoulli(probs=C3_obs))
    # NOTE(review): the training-time model names this site 'Black';
    # confirm which concept C4 really is.
    C4 = pyro.sample('Nodule', dist.Bernoulli(probs=C4_obs))
    # second layer: concept bits -> CPT row 0..31 (C0 is the lowest bit)
    concept_row = (C4*16 + C3*8 + C2*4 + C1*2 + C0).long()
    diagnosis = pyro.sample('diagnosis', dist.Categorical(probs=diagnosis_probs[concept_row]), obs=diagnosis_obs)
    # third layer: every metadata field is conditioned on the diagnosis only
    diagnosis_row = diagnosis.long()
    pyro.sample('itch', dist.Categorical(probs=itch_probs[diagnosis_row]), obs=itch_obs)
    pyro.sample('grew', dist.Categorical(probs=grew_probs[diagnosis_row]), obs=grew_obs)
    pyro.sample('hurt', dist.Categorical(probs=hurt_probs[diagnosis_row]), obs=hurt_obs)
    pyro.sample('changed', dist.Categorical(probs=changed_probs[diagnosis_row]), obs=changed_obs)
    pyro.sample('bleed', dist.Categorical(probs=bleed_probs[diagnosis_row]), obs=bleed_obs)
    pyro.sample('elevation', dist.Categorical(probs=elevation_probs[diagnosis_row]), obs=elevation_obs)
    pyro.sample('site', dist.Categorical(probs=site_probs[diagnosis_row]), obs=site_obs)
    pyro.sample('diameter', dist.Categorical(probs=diameter_probs[diagnosis_row]), obs=diameter_obs)
    pyro.sample('age', dist.Categorical(probs=age_probs[diagnosis_row]), obs=age_obs)
    return diagnosis

def guide(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs, 
          C0_obs, C1_obs, C2_obs, C3_obs, C4_obs, 
          diagnosis_obs=None):
    """No-op guide for the evaluation model; always returns ``None``.

    ``model``'s latent sites are all discrete and enumerated via
    ``@config_enumerate``, so ``compute_marginals`` needs no guide sites. The
    signature mirrors ``model`` because both are called with the same keyword
    arguments.
    """
    return None

def predict(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs, 
            C0_obs, C1_obs, C2_obs, C3_obs, C4_obs):
    """Infer the diagnosis posterior with the evaluation (per-sample) model.

    Runs ``TraceEnum_ELBO.compute_marginals`` on ``model`` with the diagnosis
    left unobserved.

    Args:
        The same per-sample values as ``model`` (0-d tensors in the evaluation
        loop), without ``diagnosis_obs``.

    Returns:
        ``(diagnosis_preds, diagnosis_probs)``. For one sample these are the
        arg-max class, shape (1,), and the probabilities of the 4 classes,
        shape (1, 4). Because of ``reshape(1, -1)``, a 1-D marginal of
        length N instead gives shapes (N,) and (N, 4).
    """
    conditional_marginals = TraceEnum_ELBO().compute_marginals(model, guide, itch_obs=itch_obs, grew_obs=grew_obs, hurt_obs=hurt_obs, changed_obs=changed_obs, bleed_obs=bleed_obs, elevation_obs=elevation_obs, site_obs=site_obs, diameter_obs=diameter_obs, age_obs=age_obs, 
                                                               C0_obs=C0_obs, C1_obs=C1_obs, C2_obs=C2_obs, C3_obs=C3_obs, C4_obs=C4_obs)
    diagnosis_marginal = conditional_marginals['diagnosis']
    # P(diagnosis = k | evidence) for each class k, evaluated on the inputs' device.
    device = itch_obs.device
    per_class = [
        diagnosis_marginal.log_prob(torch.tensor(k, device=device)).exp().reshape(1, -1)
        for k in range(4)
    ]
    diagnosis_probs = torch.cat(per_class, dim=0).T
    diagnosis_preds = torch.argmax(diagnosis_probs, dim=1)
    return diagnosis_preds, diagnosis_probs

# calculate ACC, BACC, and AUROC
ACC = torchmetrics.Accuracy(multiclass=True, num_classes=4, average='micro').cuda()
BACC = torchmetrics.Accuracy(multiclass=True, num_classes=4, average='macro').cuda()
AUROC = torchmetrics.AUROC(num_classes=4, average='macro').cuda()

# Per-sample results are collected in lists and concatenated once; a torch.cat
# inside the loop would re-copy the growing tensors on every sample.
pred_list, prob_list, label_list = [], [], []
with torch.no_grad():
    # The evaluation model() has no data plate, so samples go through one at a time.
    for i in range(len(ds_test)):
        embeddings, label = ds_test[i]
        embeddings = embeddings.cuda()
        # Columns 0-13 map positionally onto predict()'s parameters:
        # itch, grew, hurt, changed, bleed, elevation, site, diameter, age, C0..C4.
        preds, probs = predict(*embeddings[:14].unbind(0))
        pred_list.append(preds)
        prob_list.append(probs)
        label_list.append(torch.tensor([label], dtype=torch.long).cuda())
preds_all = torch.cat(pred_list)
probs_all = torch.cat(prob_list)
labels_all = torch.cat(label_list)

acc = ACC(preds_all.long(), labels_all.long())
bacc = BACC(preds_all.long(), labels_all.long())
auroc = AUROC(probs_all, labels_all.long())
print('Test ACC: '+str((100*acc).item())+' %')
print('Test BACC: '+str((100*bacc).item())+' %')
print('Test AUROC: '+str((auroc).item()))