**1. Import Libraries and Define Functions**

In [None]:
# basic
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# PyTorch
import torch
from torch import nn
from torch.utils.data import Dataset
import timm

# Pyro: pip install pyro-ppl==1.8.4
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.infer import TraceEnum_ELBO, config_enumerate

# others
import os
import pathlib
from PIL import Image
from typing import Tuple, Dict, List
import albumentations
from albumentations.pytorch.transforms import ToTensorV2

# data augmentation
def get_transforms(image_size):
    """Build the albumentations pipelines used for training and validation.

    Args:
        image_size: side length (pixels) of the square image both pipelines
            resize to.

    Returns:
        (transforms_train, transforms_val): two ``albumentations.Compose``
        objects. The training pipeline applies random augmentation before
        resizing; the validation pipeline only resizes, normalizes and
        converts to a tensor.
    """
    A = albumentations

    # the cutout hole may cover up to 37.5% of the resized side
    max_hole = int(image_size*0.375)

    # exactly one blur/noise transform is picked when this group fires
    blur_or_noise = A.OneOf([
        A.MotionBlur(blur_limit=5),
        A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ], p=0.7)

    # exactly one geometric distortion is picked when this group fires
    distortion = A.OneOf([
        A.OpticalDistortion(distort_limit=1.0),
        A.GridDistortion(num_steps=5, distort_limit=1.),
        A.ElasticTransform(alpha=3),
    ], p=0.7)

    train_steps = [
        # flips / transposition
        A.Transpose(p=0.5),
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5),
        # photometric jitter
        A.RandomBrightness(limit=0.2, p=0.75),
        A.RandomContrast(limit=0.2, p=0.75),
        blur_or_noise,
        distortion,
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
        # fixed-size output; the cutout is applied on the resized image
        A.Resize(image_size, image_size),
        A.Cutout(max_h_size=max_hole, max_w_size=max_hole, num_holes=1, p=0.7),
        A.Normalize(),
        ToTensorV2(),
    ]
    val_steps = [
        A.Resize(image_size, image_size),
        A.Normalize(),
        ToTensorV2(),
    ]

    transforms_train = A.Compose(train_steps)
    transforms_val = A.Compose(val_steps)
    return transforms_train, transforms_val

# customized 'torchvision.datasets.ImageFolder()'

# make function to find classes in target directory
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Collect the class names found as sub-directories of ``directory``.

    Every immediate sub-directory is treated as one class; plain files are
    ignored. Class indices follow the alphabetical order of the names.

    Args:
        directory (str): folder whose sub-directories are the classes.

    Returns:
        Tuple[List[str], Dict[str, int]]: the sorted class names and a
        ``{class_name: index}`` mapping.

    Raises:
        FileNotFoundError: if ``directory`` contains no sub-directories.

    Example:
        find_classes("food_images/train")
        >>> (["class_1", "class_2"], {"class_1": 0, ...})
    """
    # keep only directory entries, then order them alphabetically
    with os.scandir(directory) as entries:
        classes = [entry.name for entry in entries if entry.is_dir()]
    classes.sort()

    # a folder without class sub-directories cannot be used as a dataset
    if len(classes) == 0:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    # numeric label for each class name, in sorted order
    class_to_idx = dict(zip(classes, range(len(classes))))
    return classes, class_to_idx

# write a customized dataset class (inherits from torch.utils.data.Dataset)
# 1. subclass torch.utils.data.Dataset
class CustomizedImageFolder(Dataset):
    """Image-folder dataset that also returns one row of tabular metadata per image.

    Images are expected at ``img_dir/class_name/image.jpg``. The metadata CSV
    must have an ``img_id`` column whose values are the image file names
    without their extension; all remaining columns of the matching row are
    returned as the metadata vector.
    """

    # 2. initialize with a targ_dir, metadata_file and a transform parameter
    def __init__(self, img_dir: str, metadata_file, transform) -> None:
        """Index the images under ``img_dir`` and remember where the metadata lives.

        Args:
            img_dir: root folder containing one sub-folder per class.
            metadata_file: path (or anything ``pd.read_csv`` accepts) of the
                metadata CSV; it is read lazily on first item access.
            transform: callable invoked as ``transform(image=ndarray)`` and
                returning a mapping with an ``'image'`` entry.
        """
        # 3. create class attributes
        # get all image paths; sorted because Path.glob yields them in
        # arbitrary order and sample indices should be reproducible
        self.paths = sorted(pathlib.Path(img_dir).glob("*/*.jpg")) # .png, .jpeg
        # get metadata file path
        self.metadata_file = metadata_file
        # setup transforms
        self.transform = transform
        # create classes and class_to_idx attributes
        self.classes, self.class_to_idx = find_classes(img_dir)
        # parsed metadata table, filled in on first use
        self._metadata = None

    def _metadata_table(self) -> pd.DataFrame:
        "Returns the metadata table, parsing the CSV only on the first call."
        table = getattr(self, '_metadata', None)
        if table is None:
            table = pd.read_csv(self.metadata_file, index_col='img_id') # or other header names
            self._metadata = table
        return table

    # 4. make function to load images
    def load_image(self, index: int) -> Image.Image:
        "Opens an image via a path and returns it."
        image_path = self.paths[index]
        return Image.open(image_path)

    # 5. overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)
    def __len__(self) -> int:
        "Returns the total number of samples."
        return len(self.paths)

    # 6. overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        "Returns one sample of data, (img, metadata, label): (X, fv, y)."
        image_path = self.paths[index]
        # load image and label
        img = self.load_image(index)
        class_name = image_path.parent.name # expects path in data_folder/class_name/image.jpg(.png, .jpeg)
        label = self.class_to_idx[class_name]
        # load corresponding metadata; .stem drops only the extension,
        # whereas str.replace would strip every '.jpg' occurring in the name
        img_metadata = self._metadata_table().loc[image_path.stem]
        metadata = torch.Tensor(img_metadata.values)
        # force 3-channel RGB so grayscale/RGBA/CMYK files produce the same
        # array layout as ordinary colour JPEGs, then hand a numpy array to
        # the albumentations-style transform
        img = np.array(img.convert('RGB'))
        return self.transform(image=img)['image'], metadata, label # return (img, metadata, label): (X, fv, y)

**2. Build the Multimodal Bayesian Network**

In [None]:
# input resolution fed to the CNN and number of diagnosis classes
img_size = 300
num_classes = 6
# only the validation-style (non-augmenting) pipeline is kept for inference
_, transform_normal = get_transforms(img_size)

# 1. DEFINE THE STRUCTURE AND LOAD THE PARAMETERS FOR CONVOLUTIONAL NEURAL NETWORK
enb3 = timm.create_model(
    'efficientnet_b3',
    features_only=False,
    pretrained=False,
    num_classes=num_classes,
)
enb3.load_state_dict(
    torch.load('CNN_params.pt', map_location=torch.device('cpu'))
)
# inference mode: disables dropout and freezes batch-norm statistics
enb3.eval()

# 2. DEFINE THE STRUCTURE AND LOAD THE PARAMETERS FOR BAYESIAN NETWORK
pyro.clear_param_store()
best_params = torch.load('BN_params.pt', map_location=torch.device('cpu'))
# first layer
### CNN latent variable: use the Softmax-ed output of CNN directly as the parameter of the distribution ###
# second layer: diagnosis (stored at index 6)
# third layer: every other entry; the names below follow the storage order
(
    itch_probs_param,       # best_params[0]
    grew_probs_param,       # best_params[1]
    hurt_probs_param,       # best_params[2]
    changed_probs_param,    # best_params[3]
    bleed_probs_param,      # best_params[4]
    elevation_probs_param,  # best_params[5]
    diagnosis_probs_param,  # best_params[6]
    site_probs_param,       # best_params[7]
    diameter_probs_param,   # best_params[8]
    age_probs_param,        # best_params[9]
) = (best_params[k] for k in range(10))

@config_enumerate
def model(img_obs, itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs, diagnosis_obs=None):
    """Three-layer Bayesian network on top of the CNN output.

    Layer 1 is a categorical latent variable whose probabilities are the
    softmax of ``enb3(img_obs)``. Layer 2 is the diagnosis, conditioned on
    layer 1. Layer 3 holds the nine metadata variables, each conditioned on
    the (layer 1, diagnosis) pair through a flattened row index.

    Args:
        img_obs: image batch passed to ``enb3``.
            NOTE(review): callers in this file pass shape (1, 3, H, W) - confirm.
        itch_obs ... age_obs: observed value of each metadata variable, or
            None to leave that variable latent.
        diagnosis_obs: observed diagnosis, or None when it is to be inferred.

    Returns:
        The value of the ``'diagnosis'`` sample site.
    """
    # (site name, initial probability table, observation) for the third
    # layer, in the order the sites are registered
    third_layer = (
        ('itch', itch_probs_param, itch_obs),
        ('grew', grew_probs_param, grew_obs),
        ('hurt', hurt_probs_param, hurt_obs),
        ('changed', changed_probs_param, changed_obs),
        ('bleed', bleed_probs_param, bleed_obs),
        ('elevation', elevation_probs_param, elevation_obs),
        ('site', site_probs_param, site_obs),
        ('diameter', diameter_probs_param, diameter_obs),
        ('age', age_probs_param, age_obs),
    )

    # parameters
    # first layer
    ### CNN latent variable: use the Softmax-ed output of CNN directly as the parameter of the distribution ###
    # second layer
    diagnosis_probs = pyro.param('diagnosis_probs', diagnosis_probs_param, constraint=constraints.simplex)
    # third layer: parameters are named '<site>_probs'
    third_layer_probs = [
        pyro.param(f'{name}_probs', init, constraint=constraints.simplex)
        for name, init, _ in third_layer
    ]

    # distributions
    # first layer: the CNN is not a Pyro parameter, so its forward pass needs
    # no autograd graph; dim=1 is the class axis of the (batch, classes) logits
    with torch.no_grad():
        cnn_probs = torch.softmax(enb3(img_obs), dim=1)
    CNN = pyro.sample('img_latent_variable', dist.Categorical(probs=cnn_probs))
    # second layer
    diagnosis = pyro.sample('diagnosis', dist.Categorical(probs=diagnosis_probs[(CNN).long()]), obs=diagnosis_obs)
    # third layer: each table is indexed by the flattened (CNN, diagnosis)
    # pair, so the stride is the number of diagnosis categories
    num_diagnoses = diagnosis_probs.shape[-1]
    parent_idx = (CNN * num_diagnoses + diagnosis).long()
    for (name, _, obs), probs in zip(third_layer, third_layer_probs):
        pyro.sample(name, dist.Categorical(probs=probs[parent_idx]), obs=obs)
    return diagnosis

def guide(img_obs, itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs, diagnosis_obs=None):
    """Intentionally empty guide with the same signature as ``model``.

    It defines no sample sites, ignores every argument and returns None.
    """
    return None

# 3. DEFINE THE INFERENCE PROCESS
def predict(img_obs, itch_obs=None, grew_obs=None, hurt_obs=None, changed_obs=None, bleed_obs=None, elevation_obs=None, site_obs=None, diameter_obs=None, age_obs=None):
    """Infer the diagnosis distribution for one sample.

    Args:
        img_obs: image tensor forwarded to ``model``.
        itch_obs ... age_obs: observed metadata values; None leaves the
            corresponding variable unobserved.

    Returns:
        (diagnosis_preds, diagnosis_probs): the arg-max class index along
        dim 1 and the probabilities of classes 0-5, arranged as one row of
        six values.
    """
    evidence = dict(
        img_obs=img_obs,
        itch_obs=itch_obs,
        grew_obs=grew_obs,
        hurt_obs=hurt_obs,
        changed_obs=changed_obs,
        bleed_obs=bleed_obs,
        elevation_obs=elevation_obs,
        site_obs=site_obs,
        diameter_obs=diameter_obs,
        age_obs=age_obs,
    )
    conditional_marginals = TraceEnum_ELBO().compute_marginals(model, guide, **evidence)
    diagnosis_marginal = conditional_marginals['diagnosis']

    # probability of each of the six diagnosis classes, one (1, 1) cell per class
    per_class = [
        diagnosis_marginal.log_prob(torch.tensor(class_idx)).exp().reshape(1, 1)
        for class_idx in range(6)
    ]
    # stack the cells into a column, then transpose into a single row
    diagnosis_probs = torch.cat(per_class, dim=0).T
    diagnosis_preds = torch.argmax(diagnosis_probs, dim=1)
    return diagnosis_preds, diagnosis_probs

# 4. EXAMINE THE MODEL ARCHITECTURE
# dummy inputs: one all-ones image and a one-element all-ones tensor per variable
img_obs_test = torch.ones(1, 3, img_size, img_size)
(
    itch_obs_test,
    grew_obs_test,
    hurt_obs_test,
    changed_obs_test,
    bleed_obs_test,
    elevation_obs_test,
    diagnosis_obs_test,
    site_obs_test,
    diameter_obs_test,
    age_obs_test,
) = (torch.ones(1) for _ in range(10))
# positional order follows the signature of `model` (diagnosis_obs comes last)
pyro.render_model(
    model=model,
    model_args=(
        img_obs_test,
        itch_obs_test,
        grew_obs_test,
        hurt_obs_test,
        changed_obs_test,
        bleed_obs_test,
        elevation_obs_test,
        site_obs_test,
        diameter_obs_test,
        age_obs_test,
        diagnosis_obs_test,
    ),
    render_distributions=True,
    render_params=False,
)

**3. Test Your Own Samples**

In [None]:
img_sample_dir = '/img_sample/'
metadata_sample_file = '/metadata_sample/metadata_sample.csv'
ds_sample = CustomizedImageFolder(img_sample_dir, metadata_sample_file, transform=transform_normal)

# class index -> display label, built once instead of inside every loop
disease_names = {0:'ACK - Actinic Keratosis', 1:'BCC - Basal Cell Carcinoma', 2:'MEL - Malignant Melanoma',
                 3:'NEV - Benign Melanocytic Nevus', 4:'SCC - Squamous Cell Carcinoma', 5:'SEK - Seborrheic Keratosis'}
# chance-level probability (in %) for a uniform guess over the known classes
baseline_percent = round(100/len(disease_names), 1)


def _print_top_predictions(title, top_probs, top_preds):
    "Prints the ranked classes of one (1, k) top-k result under the given title."
    print(f'\n<<<<< {title} >>>>>')
    for rank in range(top_probs.size(1)):
        disease = disease_names.get(top_preds[0, rank].item(), 'UNK - Unknown Disease')
        probability = round(100*top_probs[0, rank].item(), 1)
        print(f'Top {rank+1}: {disease} (Probability: {probability}%, Baseline: {baseline_percent}%)')


for i, img_path in enumerate(ds_sample.paths):
    # get one datapoint
    features, embeddings, _ = ds_sample[i]
    # assign data for MBN: add the batch axis and split the metadata vector
    img_obs = features.reshape(1, 3, img_size, img_size)
    (itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs,
     elevation_obs, site_obs, diameter_obs, age_obs) = embeddings[:9]
    # calculate the predictions
    _, probs = predict(img_obs, itch_obs, grew_obs, hurt_obs, changed_obs, bleed_obs, elevation_obs, site_obs, diameter_obs, age_obs)
    # get top-3 predictions and their probabilities
    # image-only baseline: plain CNN forward pass, no gradients needed;
    # dim=1 is the class axis of the (1, classes) logits
    with torch.no_grad():
        cnn_probs = torch.softmax(enb3(img_obs), dim=1)
    top_probs_ImageOnly, top_preds_ImageOnly = torch.topk(cnn_probs, 3)
    top_probs_ImageMetadata, top_preds_ImageMetadata = torch.topk(probs, 3)
    # display the diagnostic results
    print(f'Predicted Diagnosis of "{img_path.name}":')
    # CNN output
    _print_top_predictions('Image Only', top_probs_ImageOnly, top_preds_ImageOnly)
    # BN output
    _print_top_predictions('Image + Metadata', top_probs_ImageMetadata, top_preds_ImageMetadata)
    # display the thumbnail image; the context manager closes the file handle
    with Image.open(img_path) as img_thumbnail:
        plt.figure(figsize=(2.5, 2.5))
        plt.imshow(img_thumbnail)
        plt.axis('off')
        plt.show()