**1. Import Libraries and Define Functions**

In [None]:
# basic
import pandas as pd

# PyTorch
import torch
import torchmetrics
from torch.utils.data import Dataset, DataLoader

# Pyro: pip install pyro-ppl==1.8.4
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.infer import TraceEnum_ELBO, config_enumerate

# others
import os
import albumentations
from albumentations.pytorch.transforms import ToTensorV2
import pathlib
from typing import Tuple, Dict, List

# data augmentation
def get_transforms(image_size):
    """Build the albumentations pipelines used for training and validation.

    Args:
        image_size: side length passed to ``Resize`` for both height and
            width; the training cutout hole is sized relative to it.

    Returns:
        Tuple ``(transforms_train, transforms_val)`` of two
        ``albumentations.Compose`` objects. The training pipeline applies
        random flips, photometric changes, blur/noise, geometric distortions
        and a cutout before normalizing; the validation pipeline only
        resizes, normalizes and converts to a tensor.
    """
    A = albumentations  # short local alias, keeps the step lists readable

    # exactly one of these blur/noise operations is applied (with p=0.7)
    blur_or_noise = A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ], p=0.7)

    # exactly one of these geometric distortions is applied (with p=0.7)
    distortion = A.OneOf([
        A.OpticalDistortion(distort_limit=1.0),
        A.GridDistortion(num_steps=5, distort_limit=1.),
        A.ElasticTransform(alpha=3),
    ], p=0.7)

    # the cutout hole covers at most 37.5% of the image side in each direction
    hole_size = int(image_size * 0.375)

    train_steps = [
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightness(limit=0.2, p=0.75),
        A.RandomContrast(limit=0.2, p=0.75),
        blur_or_noise,
        distortion,
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
        A.Resize(image_size, image_size),
        A.Cutout(max_h_size=hole_size, max_w_size=hole_size, num_holes=1, p=0.7),
        A.Normalize(),
        ToTensorV2(),
    ]
    val_steps = [
        A.Resize(image_size, image_size),
        A.Normalize(),
        ToTensorV2(),
    ]
    return A.Compose(train_steps), A.Compose(val_steps)

# dataset modeled on 'torchvision.datasets.ImageFolder()', customized to return each image's metadata vector (read from a CSV file) and its class label instead of the image itself

# make function to find classes in target directory
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Discover class names from the sub-directories of ``directory``.

    Every immediate sub-directory is treated as one class; plain files are
    ignored. Class names are sorted and numbered from zero in that order.

    Args:
        directory (str): folder whose sub-directories are named after classes.

    Returns:
        Tuple[List[str], Dict[str, int]]: (sorted_class_names, {class_name: index}).

    Raises:
        FileNotFoundError: if ``directory`` contains no sub-directory.

    Example:
        find_classes("food_images/train")
        >>> (["class_1", "class_2"], {"class_1": 0, ...})
    """
    # collect the names of all immediate sub-directories
    class_names = []
    for entry in os.scandir(directory):
        if entry.is_dir():
            class_names.append(entry.name)
    class_names.sort()

    # without at least one class folder there is nothing to label
    if len(class_names) == 0:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    # numeric label of a class = its position in the sorted name list
    index_of = dict(zip(class_names, range(len(class_names))))
    return class_names, index_of

# write a customized dataset class (inherits from torch.utils.data.Dataset)
# 1. subclass torch.utils.data.Dataset
class CustomizedTextFolder(Dataset):
    """Dataset yielding ``(metadata, label)`` for every ``.jpg`` under ``img_dir``.

    The images themselves are never opened: their paths only provide the class
    label (name of the parent folder) and the image id (file name without
    extension) used to look up a row in the metadata CSV.

    The CSV must have an ``img_id`` column, which is used as the row index.
    It is parsed once, on first access, and then reused for all samples.
    """

    # 1. initialize with an img_dir and a metadata_file
    def __init__(self, img_dir: str, metadata_file) -> None:
        # 2. create class attributes
        # get all image paths (expects img_dir/class_name/image.jpg)
        self.paths = list(pathlib.Path(img_dir).glob("*/*.jpg")) # .png, .jpeg
        # get metadata file path
        self.metadata_file = metadata_file
        # create classes and class_to_idx attributes
        self.classes, self.class_to_idx = find_classes(img_dir)
        # parsed metadata table, filled lazily by _load_metadata()
        self._metadata_table = None

    def _load_metadata(self):
        "Returns the metadata table, reading the CSV only on the first call."
        # getattr() keeps this working for instances created without __init__
        table = getattr(self, "_metadata_table", None)
        if table is None:
            table = pd.read_csv(self.metadata_file, index_col='img_id') # or other header names
            self._metadata_table = table
        return table

    # 3. overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)
    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.paths)

    # 4. overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        "Returns one sample of data, (metadata, label): (fv, y)."
        path = self.paths[index]
        # load label: the parent folder name is the class name
        label = self.class_to_idx[path.parent.name]
        # load corresponding metadata: the file name without its extension is the row key
        img_metadata = self._load_metadata().loc[path.stem]
        metadata = torch.Tensor(img_metadata.values)
        return metadata, label # return (metadata, label): (fv, y)

**2. Load and Pre-Process the Data**

In [None]:
num_classes = 6

# Stratified 5-Fold Cross Validation: this suffix selects the fold in every path below
fold = '_Fold1'
img_train_dir = f'/.../PAD-UFES-20_300x300_SoG_Split4-1-1{fold}/train/'
img_val_dir = f'/.../PAD-UFES-20_300x300_SoG_Split4-1-1{fold}/val/'
img_test_dir = f'/.../PAD-UFES-20_300x300_SoG_Split4-1-1{fold}/test/'
metadata_file = f'/.../metadata/metadata{fold}.csv'

# one dataset per split; all three look up their rows in the same metadata file
ds_train, ds_val, ds_test = (
    CustomizedTextFolder(split_dir, metadata_file)
    for split_dir in (img_train_dir, img_val_dir, img_test_dir)
)

# only the training split is shuffled
dl_train = DataLoader(ds_train, shuffle=True, batch_size=64)
dl_val = DataLoader(ds_val, shuffle=False, batch_size=64)
dl_test = DataLoader(ds_test, shuffle=False, batch_size=64)

# optional sanity check of the label mapping and of one batch:
# print('Examine Numerical Labels: ', ds_train.class_to_idx)
# for embeddings, labels in dl_train:
#     # presumably embeddings is [batch_size, number_of_metadata_columns] - TODO confirm
#     print('Examine Batched Data Shapes: ', embeddings.shape, labels.shape)
#     break

**3. Load and Fuse the Parameters**

In [None]:
# MUST CLEAR PARAM STORE BEFORE LOADING THE BEST PARAMS !!!
# NOTE(review): presumably needed because pyro.param keeps a value already
# registered under the same name and ignores the new initial tensor - confirm.
pyro.clear_param_store()

# Layout of a checkpoint (position -> BN node):
#   0 itch, 1 grew, 2 hurt, 3 changed, 4 bleed, 5 elevation,
#   6 diagnosis, 7 site, 8 diameter, 9 age

# load 2-class BN parameters
best_params_2Class = torch.load(f'/.../checkpoints/2ClassBN{fold}.pt')
(itch_probs_param_2Class,
 grew_probs_param_2Class,
 hurt_probs_param_2Class,
 changed_probs_param_2Class,
 bleed_probs_param_2Class,
 elevation_probs_param_2Class,
 diagnosis_probs_param_2Class,
 site_probs_param_2Class,
 diameter_probs_param_2Class,
 age_probs_param_2Class) = (best_params_2Class[k] for k in range(10))

# load 4-class BN parameters
best_params_4Class = torch.load(f'/.../checkpoints/4ClassBN{fold}.pt')
(itch_probs_param_4Class,
 grew_probs_param_4Class,
 hurt_probs_param_4Class,
 changed_probs_param_4Class,
 bleed_probs_param_4Class,
 elevation_probs_param_4Class,
 diagnosis_probs_param_4Class,
 site_probs_param_4Class,
 diameter_probs_param_4Class,
 age_probs_param_4Class) = (best_params_4Class[k] for k in range(10))

# fuse into a 6-class BN without additional training
# diagnosis table: the 4-class columns come first and are weighted 2/3, the
# 2-class columns follow and are weighted 1/3 (a fused row sums to one if
# the rows of both source tables do)
diagnosis_probs_param = torch.cat(
    [diagnosis_probs_param_4Class * 2 / 3, diagnosis_probs_param_2Class / 3], dim=1)
# child tables: rows of the 4-class BN first, then rows of the 2-class BN,
# i.e. the same class order as the fused diagnosis columns
itch_probs_param = torch.cat([itch_probs_param_4Class, itch_probs_param_2Class], dim=0)
grew_probs_param = torch.cat([grew_probs_param_4Class, grew_probs_param_2Class], dim=0)
hurt_probs_param = torch.cat([hurt_probs_param_4Class, hurt_probs_param_2Class], dim=0)
changed_probs_param = torch.cat([changed_probs_param_4Class, changed_probs_param_2Class], dim=0)
bleed_probs_param = torch.cat([bleed_probs_param_4Class, bleed_probs_param_2Class], dim=0)
elevation_probs_param = torch.cat([elevation_probs_param_4Class, elevation_probs_param_2Class], dim=0)
site_probs_param = torch.cat([site_probs_param_4Class, site_probs_param_2Class], dim=0)
diameter_probs_param = torch.cat([diameter_probs_param_4Class, diameter_probs_param_2Class], dim=0)
age_probs_param = torch.cat([age_probs_param_4Class, age_probs_param_2Class], dim=0)

# to inspect any table, print it, e.g.:
# print('<<<<< Fused Concept BN >>>>>')
# print('Diagnosis Probs: \n', diagnosis_probs_param)
# print('\nItch Probs: \n', itch_probs_param)

**4. Build and Evaluate the Model**

In [None]:
@config_enumerate
def model(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs, 
          C0_obs, C1_obs, C2_obs, C3_obs, C4_obs, 
          diagnosis_obs=None):
    """Three-layer discrete Bayesian network: concepts -> diagnosis -> metadata.

    Args:
        itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs,
        site_obs, diameter_obs, age_obs: observed values, passed as ``obs=``
            to the corresponding third-layer Categorical sample sites.
        C0_obs, C1_obs, C2_obs, C3_obs, C4_obs: used as the ``probs`` of the
            first-layer Bernoulli sites 'Erythema', 'Brown', 'Crust',
            'Telangiectasia' and 'Nodule' (they are not passed as ``obs``).
        diagnosis_obs: optional observed diagnosis; with the default ``None``
            the 'diagnosis' site has no observation.

    Returns:
        The value of the 'diagnosis' sample site.
    """
    # parameters
    # initial values are the module-level ``*_probs_param`` tensors moved to the GPU;
    # every table is constrained to the simplex
    # NOTE(review): pyro.param presumably ignores the initial tensor once the name
    # is already in the param store - confirm against the Pyro documentation
    # first layer 
    ### CNN latent variable: use the Softmax-ed output of CNN directly as the parameter of the distribution ###
    # second layer
    diagnosis_probs = pyro.param('diagnosis_probs', diagnosis_probs_param.cuda(), constraint=constraints.simplex)
    # third layer
    itch_probs = pyro.param('itch_probs', itch_probs_param.cuda(), constraint=constraints.simplex)
    grew_probs = pyro.param('grew_probs', grew_probs_param.cuda(), constraint=constraints.simplex)
    hurt_probs = pyro.param('hurt_probs', hurt_probs_param.cuda(), constraint=constraints.simplex)
    changed_probs = pyro.param('changed_probs', changed_probs_param.cuda(), constraint=constraints.simplex)
    bleed_probs = pyro.param('bleed_probs', bleed_probs_param.cuda(), constraint=constraints.simplex)
    elevation_probs = pyro.param('elevation_probs', elevation_probs_param.cuda(), constraint=constraints.simplex)
    site_probs = pyro.param('site_probs', site_probs_param.cuda(), constraint=constraints.simplex)
    diameter_probs = pyro.param('diameter_probs', diameter_probs_param.cuda(), constraint=constraints.simplex)
    age_probs = pyro.param('age_probs', age_probs_param.cuda(), constraint=constraints.simplex)

    # distributions
    # first layer: five binary concept variables
    C0 = pyro.sample('Erythema', dist.Bernoulli(probs=C0_obs))
    C1 = pyro.sample('Brown', dist.Bernoulli(probs=C1_obs))
    C2 = pyro.sample('Crust', dist.Bernoulli(probs=C2_obs))
    C3 = pyro.sample('Telangiectasia', dist.Bernoulli(probs=C3_obs))
    C4 = pyro.sample('Nodule', dist.Bernoulli(probs=C4_obs))
    # second layer: the five concept values are packed into one row index
    # (C0 is the least significant bit, C4 the most significant)
    # presumably diagnosis_probs has one row per concept combination (32 rows) - TODO confirm
    diagnosis = pyro.sample('diagnosis', dist.Categorical(probs=diagnosis_probs[(C4*16+C3*8+C2*4+C1*2+C0).long()]), obs=diagnosis_obs)
    # third layer: every attribute table is indexed by the diagnosis value
    itch = pyro.sample('itch', dist.Categorical(probs=itch_probs[(diagnosis).long()]), obs=itch_obs)
    grew = pyro.sample('grew', dist.Categorical(probs=grew_probs[(diagnosis).long()]), obs=grew_obs)
    hurt = pyro.sample('hurt', dist.Categorical(probs=hurt_probs[(diagnosis).long()]), obs=hurt_obs)
    changed = pyro.sample('changed', dist.Categorical(probs=changed_probs[(diagnosis).long()]), obs=changed_obs)
    bleed = pyro.sample('bleed', dist.Categorical(probs=bleed_probs[(diagnosis).long()]), obs=bleed_obs)
    elevation = pyro.sample('elevation', dist.Categorical(probs=elevation_probs[(diagnosis).long()]), obs=elevation_obs)
    site = pyro.sample('site', dist.Categorical(probs=site_probs[(diagnosis).long()]), obs=site_obs)
    diameter = pyro.sample('diameter', dist.Categorical(probs=diameter_probs[(diagnosis).long()]), obs=diameter_obs)
    age = pyro.sample('age', dist.Categorical(probs=age_probs[(diagnosis).long()]), obs=age_obs)
    return diagnosis

def guide(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs,
          C0_obs, C1_obs, C2_obs, C3_obs, C4_obs,
          diagnosis_obs=None):
    """Empty guide that accepts the same arguments as ``model``.

    It declares no sample sites and no parameters, ignores all of its
    arguments and returns ``None``.
    """
    return None

def predict(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs,
            C0_obs, C1_obs, C2_obs, C3_obs, C4_obs):
    """Predict the diagnosis of one sample from its observations.

    All arguments are forwarded by keyword to ``model`` (``diagnosis_obs`` is
    left at its default), and the marginal of the 'diagnosis' site is
    evaluated at each of the six class indices.

    Returns:
        Tuple ``(diagnosis_preds, diagnosis_probs)``: ``diagnosis_probs``
        holds the six class probabilities reshaped to one row of six columns,
        ``diagnosis_preds`` is the argmax over those columns.
    """
    observations = dict(
        itch_obs=itch_obs, grew_obs=grew_obs, hurt_obs=hurt_obs, changed_obs=changed_obs,
        bleed_obs=bleed_obs, elevation_obs=elevation_obs, site_obs=site_obs,
        diameter_obs=diameter_obs, age_obs=age_obs,
        C0_obs=C0_obs, C1_obs=C1_obs, C2_obs=C2_obs, C3_obs=C3_obs, C4_obs=C4_obs,
    )
    marginals = TraceEnum_ELBO().compute_marginals(model, guide, **observations)

    # probability of every class index 0..5, each as a (1, 1) tensor
    per_class = [
        marginals['diagnosis'].log_prob(torch.tensor(class_idx).cuda()).exp().reshape(1, 1)
        for class_idx in range(6)
    ]
    # stack to (6, 1), then transpose to (1, 6)
    diagnosis_probs = torch.cat(per_class, dim=0).T
    diagnosis_preds = torch.argmax(diagnosis_probs, dim=1)
    return diagnosis_preds, diagnosis_probs

# calculate ACC, BACC, and AUROC on the test split
# (num_classes is the class count defined together with the datasets)
ACC = torchmetrics.Accuracy(multiclass=True, num_classes=num_classes, average='micro').cuda()
BACC = torchmetrics.Accuracy(multiclass=True, num_classes=num_classes, average='macro').cuda()
AUROC = torchmetrics.AUROC(num_classes=num_classes, average='macro').cuda()

# per-sample results are collected in lists and concatenated once after the
# loop, which avoids re-copying all previous results on every iteration
preds_chunks = []
probs_chunks = []
labels_list = []
for i in range(len(ds_test)):
    # get one datapoint
    embeddings, labels = ds_test[i]
    embeddings = embeddings.cuda()
    labels_list.append(labels)
    # assign data for BN: the first 14 entries are, in this order, the nine
    # metadata attributes followed by the five concept probabilities
    (itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs,
     site_obs, diameter_obs, age_obs,
     C0_obs, C1_obs, C2_obs, C3_obs, C4_obs) = embeddings[:14]
    # run inference for this sample
    preds, probs = predict(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs,
                           C0_obs, C1_obs, C2_obs, C3_obs, C4_obs)
    preds_chunks.append(preds)
    probs_chunks.append(probs)

preds_all = torch.cat(preds_chunks)   # predicted class index per sample
probs_all = torch.cat(probs_chunks)   # one row of class probabilities per sample
labels_all = torch.tensor(labels_list, dtype=torch.long).cuda()

acc = ACC(preds_all.long(), labels_all)
bacc = BACC(preds_all.long(), labels_all)
auroc = AUROC(probs_all, labels_all)
print(f'Test ACC: {(100*acc).item()} %')
print(f'Test BACC: {(100*bacc).item()} %')
print(f'Test AUROC: {(auroc).item()}')

**5. Save Classification Results**

In [None]:
# run the BN on every test sample and save the class probabilities
# (bn_test.csv) together with the matching image file names (id_test.csv);
# row k of both files belongs to the same test sample
bn_test = []
id_test = []
for i in range(len(ds_test)):
    # get one datapoint
    embeddings, _ = ds_test[i]
    embeddings = embeddings.cuda()
    # assign data for BN: the first 14 entries are, in this order, the nine
    # metadata attributes followed by the five concept probabilities
    (itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs,
     site_obs, diameter_obs, age_obs,
     C0_obs, C1_obs, C2_obs, C3_obs, C4_obs) = embeddings[:14]
    _, probs = predict(itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs,
                       C0_obs, C1_obs, C2_obs, C3_obs, C4_obs)
    # probs has a single row; keep that row as a flat array
    bn_test.append(probs.cpu().detach().numpy()[0])
    id_test.append(ds_test.paths[i].name)

# to_csv() returns None when given a path, so its result is deliberately not
# assigned: bn_test and id_test keep the collected data after saving
pd.DataFrame(data=bn_test).to_csv('bn_test.csv')
pd.DataFrame(data=id_test).to_csv('id_test.csv')