In [8]:
import os

_cwd = os.getcwd()
print("Current working directory:", _cwd)

# Local directory of the 'training' input channel. SageMaker sets SM_CHANNEL_<NAME>
# for each input channel of a training job; where it is unset this lookup raises KeyError.
IMAGES_DIR = os.environ["SM_CHANNEL_TRAINING"]

Current working directory: /home/ec2-user/SageMaker/histology-image-analysis


In [None]:
# From Perplexity.ai
# Configure train.py as a SageMaker training job.
# NOTE(review): `PyTorch` (sagemaker.pytorch) and `role` are not defined in the cells shown;
# they must come from an earlier cell (e.g. sagemaker.get_execution_role()).
estimator = PyTorch(
    entry_point='train.py',
    role=role,
    instance_count=1,
    instance_type='ml.m5.xlarge',  # SageMaker instance types carry the 'ml.' prefix
    framework_version='1.8.1',     # required unless image_uri is given
    py_version='py3',              # required unless image_uri is given
    input_mode='File',
)

# The S3 data location is not a constructor argument: it is passed to fit(), keyed by
# channel name. Inside the container the 'training' channel is exposed through the
# SM_CHANNEL_TRAINING environment variable.
TRAINING_INPUTS = {'training': 's3://mhist-streamlit-app/images/original/'}
# estimator.fit(TRAINING_INPUTS)  # launches the (billable) training job

## Data

In [9]:
# Load the MHIST split metadata (one row per image)
import pandas as pd

# Columns (as shown in the output below):
#   'name'    : image filename, MHIST_<code>.png, where <code> is a 3-letter image code
#   'experts' : (int) 0 through 7 -- appears to be the number of experts who voted SSA
#               (in the rows shown, label == 1 exactly when experts >= 4); TODO confirm
#   'label'   : binary class as an int, 0 = HP, 1 = SSA
#
# Training set samples: 2175
# Test set samples: 977
train_df = pd.read_csv('training/trainset_info.csv')
test_df = pd.read_csv('training/testset_info.csv')

# MHIST_dataset reads 'name' and 'label'; fail fast if the CSV schema differs
REQUIRED_COLUMNS = {'name', 'label'}
for split_name, split_df in (('train_df', train_df), ('test_df', test_df)):
    missing_columns = REQUIRED_COLUMNS - set(split_df.columns)
    assert not missing_columns, f'{split_name} is missing columns {sorted(missing_columns)}'

print('train_df.shape', train_df.shape)
print('test_df.shape', test_df.shape)
train_df.head(), test_df.head()

train_df.shape (2175, 3)
test_df.shape (977, 3)


(            name  experts  label
 0  MHIST_aaa.png        6      1
 1  MHIST_aab.png        0      0
 2  MHIST_aac.png        5      1
 3  MHIST_aae.png        1      0
 4  MHIST_aaf.png        5      1,
             name  experts  label
 0  MHIST_aag.png        2      0
 1  MHIST_aah.png        2      0
 2  MHIST_aaq.png        5      1
 3  MHIST_aar.png        0      0
 4  MHIST_aay.png        1      0)

In [10]:
import torch
from torchvision import transforms

# No resizing or cropping: these are medical images and we don't want to alter them.
# The images are expected to already be 224x224 pixels, the input size ViT expects
# (SimpleFC's 150528 input features = 3 * 224 * 224).

# ToTensor: converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a
# torch.FloatTensor with shape (C x H x W) in the range [0.0, 1.0]

# For Normalize: per-channel mean/std calculated from the training data
train_mean = [0.738, 0.649, 0.775]
train_std =  [0.197, 0.244, 0.17]

# FC model: normalize, then flatten (C, H, W) -> (C*H*W,).
# nn.Flatten(start_dim=0) replaces a lambda so the transform can be pickled, which
# DataLoader requires when worker processes are started with the 'spawn' method.
DEFAULT_FC_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_mean, train_std),
    torch.nn.Flatten(start_dim=0),
])

# ViT consumes the 3-channel 2-D image directly, so no flatten
DEFAULT_VIT_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_mean, train_std),
])

# The same transforms are used for the test set: they only pre-process images and
# add no augmentation (synthetic data).


In [2]:
import os

from PIL import Image
from torch.utils.data import Dataset

# Custom Dataset class:
class MHIST_dataset(Dataset):
    """MHIST images and labels listed in a metadata DataFrame.

    Args:
        df: DataFrame with at least a 'name' column (image filename, e.g. MHIST_abc.png)
            and a 'label' column (0 = HP, 1 = SSA).
        images_dir: directory containing the image files (defaults to IMAGES_DIR).
        transform: callable applied to each RGB PIL image; required.

    Raises:
        ValueError: if transform is None.
    """

    def __init__(self, df, images_dir=IMAGES_DIR, transform=None):
        if transform is None:
            raise ValueError("Error: missing transform for MHIST_dataset")
        self.df = df
        self.images_dir = images_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return {'image': transformed image, 'label': row label, 'filename': image filename}.

        Raises FileNotFoundError (from PIL) if the image file does not exist.
        """
        row = self.df.iloc[idx]
        full_path = os.path.join(self.images_dir, row['name'])
        # `with` closes the file handle; convert() returns an in-memory copy, so the
        # converted image stays usable after the file is closed.
        with Image.open(full_path) as img:
            image_PIL = img.convert('RGB')
        return {
            'image': self.transform(image_PIL),  # the transform includes ToTensor
            'label': row['label'],
            'filename': row['name'],
        }

## Models

### Simple FC model

In [None]:
from torch.nn import Module, Linear
from torch.nn.functional import relu

# Number of features: 224*224*3= 150528

class SimpleFC(Module):
    """Fully connected network: three ReLU hidden layers followed by a linear output layer.

    Args: D_in input features, H1/H2/H3 hidden widths, D_out outputs (raw logits).
    The attribute names (layer1..layer3, outlayer) are the state_dict keys used when
    checkpoints are saved and loaded, so they must not change.
    """

    def __init__(self, D_in, H1, H2, H3, D_out):
        super().__init__()
        # Registration order fixes the order of children() and state_dict()
        self.layer1 = Linear(D_in, H1)
        self.layer2 = Linear(H1, H2)
        self.layer3 = Linear(H2, H3)
        self.outlayer = Linear(H3, D_out)

    def forward(self, x):
        hidden = x
        for layer in (self.layer1, self.layer2, self.layer3):
            hidden = relu(layer(hidden))
        return self.outlayer(hidden)

# ~360M parameters (~1.4 GB as float32), dominated by the 150528 x 2352 first layer
fc = SimpleFC(D_in=150528, H1=2352, H2=2352, H3=294, D_out=1)
fc

### **ViT**
I used timm to download the ViT model and replaced its classification head with a single-output (binary) head. I set `pretrained = True` and froze all layers except the head, so fine-tuning starts from the pretrained weights. This is ViT-Base with 16 x 16-pixel patches: each 224 x 224 image is split into a 14 x 14 grid of 196 patches.

In [None]:
!pip install -q timm # CPU-only version is: timm[torch-cpu]
import timm
VIT_MODEL_TYPE = 'vit_base_patch16_224.augreg2_in21k_ft_in1k'

## Training

In [None]:
# For MLflow tracking server

!pip install -q mlflow
!pip install -q boto3
!pip install -q awscli
!aws configure
!aws --version # aws-cli/1.33.15 Python/3.10.12 Linux/6.1.85+ botocore/1.34.133
import mlflow
'mlflow', mlflow.__version__ # 2.14.1

In [None]:
import os
import math
import time
import logging

# For data
import numpy as np  # np.load in TrainingSession.__get_loaders
import pandas as pd
from torch.utils.data import DataLoader, WeightedRandomSampler  # sampler for the training loader

# Select the data for this session (comment/uncomment one pair):
# # For training
# TRAIN_DF = pd.read_csv('artifacts/old_train_df.csv')
# TEST_DF = pd.read_csv('artifacts/old_test_df.csv')

# For inference (TrainingSession builds only a validation loader when TRAIN_DF is None)
TRAIN_DF = None
TEST_DF = pd.read_csv('artifacts/old_test_df.csv')

# For training
import torch
print('torch.version', torch.__version__)
print('torch.version.cuda', torch.version.cuda)
print('torch.backends.cudnn.version', torch.backends.cudnn.version())

import copy
from tqdm import tqdm
from torch.nn import Conv2d, BCEWithLogitsLoss, init # for init.xavier_uniform_
from torch.optim import lr_scheduler, Adam
from sklearn.metrics import (precision_score, recall_score, f1_score, accuracy_score,
                             balanced_accuracy_score, roc_auc_score, average_precision_score,
                             confusion_matrix, classification_report)

# For tracking
# import mlflow # imported above to get version number
# The server can be overridden with the standard MLFLOW_TRACKING_URI environment variable
# instead of editing this hard-coded address.
MLFLOW_SERVER = os.environ.get('MLFLOW_TRACKING_URI', "http://13.52.243.246:5000")
MLFLOW_MODEL_PATH = 'onnx_artifacts' #1.4G
MLFLOW_DEFAULT_EXPERIMENT = 'MHIST FC (binary classification)'

# Checkpoints are saved in PyTorch format (torch.save of model.state_dict()).
# Uses the current device (CPU or GPU).
# The model should have a single output of logits (positive_prob = sigmoid(logit)).
DEFAULT_MODEL_PATH = 'MHIST_model.pt' # Relative path for saving and loading checkpoints
DEFAULT_LR = 1e-6 #1e-4 #2.5e4
DROP_LAST_BATCH = False # keep the final partial batch: the dataset size might not be divisible by the batch size

# To save confusion matrix as a heatmap
import seaborn as sns
import matplotlib.pyplot as plt

class TrainingSession:
  def __init__(self,
               train_transform=None,
               val_transform=None,
               path_for_resuming = None, # resume from local path if not None (not used for saving best model)
               resume_from_object = False, # resume from model object
               model = None, # if not resume_from_object, FC model is initialized to random
               model_type = 'FC', # 'FC' or 'VIT'
               batch_size = None,
               eval_on='loss',
               enable_tracking=False,
               logger=None,
               ):
    """Set up transforms, data loaders, the model and checkpoint bookkeeping.

    Args:
      train_transform, val_transform: image transforms; None selects the defaults
        for model_type.
      path_for_resuming: local state_dict checkpoint to load (ignored when
        resume_from_object is True).
      resume_from_object: use `model` as-is instead of allocating/loading one.
      model: model object (see resume_from_object).
      model_type: 'FC' (SimpleFC on flattened images) or 'VIT' (timm ViT).
      batch_size: None selects 32 when CUDA is available, else 4.
      eval_on: metric used to select the best checkpoint ('loss' is minimized,
        other metric names are maximized). Cannot be changed later.
      enable_tracking: log params/metrics/artifacts to the MLflow server.
      logger: logging.Logger; None uses logging.getLogger(DEFAULT_LOGGER_NAME).

    Raises:
      ValueError: if model_type is not 'FC' or 'VIT'.
    """
    # Validate first: an unknown type would otherwise surface later as missing
    # transforms or a KeyError when logging hyperparameters.
    if model_type not in ('FC', 'VIT'):
        raise ValueError(f"model_type must be 'FC' or 'VIT', got {model_type!r}")
    self.model_type = model_type
    # NOTE(review): DEFAULT_LOGGER_NAME is not defined in the cells shown; it must exist
    # before instantiation whenever logger is None.
    self.logger = logging.getLogger(DEFAULT_LOGGER_NAME) if logger is None else logger
    self.enable_tracking = enable_tracking
    self.run_id = None  # set once an MLflow run is started by __track_training

    # Training Setup
    self.random = 42  # logged as an MLflow param; not used to seed any RNG in the code shown
    self.label_names = ['HP', 'SSA'] # HP = 0, SSA = 1
    self.learning_rate = None # set by train_and_evaluate
    self.optimizer = None     # set by train_and_evaluate
    self.scheduler = None     # created by __train_eval_loop
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.batch_size = self.__get_batch_size(batch_size)
    self.train_transform, self.val_transform = self.__get_transforms(train_transform, val_transform) # None -> default for model_type
    self.loader_dict = self.__get_loaders(TRAIN_DF, TEST_DF) # None without TEST_DF; loader_dict['train'] is None without TRAIN_DF
    self.model, self.best_model_wts = self.__model_init(path_for_resuming, resume_from_object, model) # resume from path if it is not None

    # For saving the best model
    self.eval_on = eval_on      # metric used to select the best model (can't change this later)
    self.best_metric = None     # set by train_and_evaluate or the first validation epoch
    self.best_model_dest = None # set by train_and_evaluate
    # Validation goals: a checkpoint is saved only when one of them is beaten (see __met_goal)
    self.precision_goal = 0.8 # 0.8152
    self.recall_goal = 0.8 # 0.8039
    self.f1_goal = 0.8 # 0.8178
    self.accuracy_goal = 0.8 # 0.8192


  def __get_batch_size(self, batch_size):
    """Return batch_size, or a device-dependent default when it is None."""
    if batch_size is not None:
        return batch_size
    print('Using default batch size')
    # BATCH_SIZE = 256 # for A100
    return 32 if torch.cuda.is_available() else 4


  def __get_transforms(self, train_transform, val_transform):
    """Replace None transforms with the defaults for self.model_type.

    Returns (train_transform, val_transform). For an unrecognized model_type both
    arguments are returned unchanged (possibly None).
    """
    print('Setting up data:')
    if self.model_type == 'FC':
        message, default = 'Using FC transforms', DEFAULT_FC_TRANSFORMS
    elif self.model_type == 'VIT':
        message, default = 'Using ViT transforms', DEFAULT_VIT_TRANSFORMS
    else:
        return train_transform, val_transform
    print(message)
    return (default if train_transform is None else train_transform,
            default if val_transform is None else val_transform)


  def __get_loaders(self, train_df=None, test_df=None):
    """Build the DataLoaders for this session.

    Returns:
      None if test_df is None (nothing to evaluate on);
      {'train': None, 'val': loader} if train_df is None (inference only);
      otherwise {'train': loader, 'val': loader}. The training loader draws samples
      through a WeightedRandomSampler using artifacts/samples_weights.npy.

    Raises:
      ValueError: if the number of sample weights differs from len(train_df).
    """
    # Local imports keep this method self-contained
    import numpy as np
    from torch.utils.data import WeightedRandomSampler

    if test_df is None:
        print("Error: can't evaluate or run inference without test_df.")
        return None

    # Validation order does not affect the epoch metrics, so it is never shuffled
    val_loader = DataLoader(MHIST_dataset(test_df, transform=self.val_transform),
                            batch_size=self.batch_size, shuffle=False, drop_last=DROP_LAST_BATCH)

    if train_df is None:
        print(f"No training set. Model is ready for inference on {len(test_df)} samples.")
        # len(loader) is the number of batches, including a final partial batch
        print('batch_size', self.batch_size, 'val_num_batches', len(val_loader))
        return {'train': None, 'val': val_loader}

    print('train_df len', len(train_df), 'test_df len', len(test_df))
    # Sample weights are only needed for training, so inference doesn't require this file
    samples_weights = np.load('artifacts/samples_weights.npy')
    if len(samples_weights) != len(train_df):
        raise ValueError(f'samples_weights has {len(samples_weights)} entries '
                         f'but train_df has {len(train_df)} rows')
    sampler = WeightedRandomSampler(samples_weights, len(samples_weights))
    train_loader = DataLoader(MHIST_dataset(train_df, transform=self.train_transform),
                              batch_size=self.batch_size, sampler=sampler, drop_last=DROP_LAST_BATCH)
    loader_dict = {'train': train_loader, 'val': val_loader}

    print(f"Training with {len(loader_dict['train'])} batches, validating with {len(loader_dict['val'])} batches with batch_size {self.batch_size}")
    print('train_num_batches', math.ceil(len(train_df)/self.batch_size), 'and val_num_batches', math.ceil(len(test_df)/self.batch_size))
    return loader_dict


  @staticmethod
  def initialize_weights(module):
    """Xavier-uniform init of the weight of an exactly-Linear or exactly-Conv2d module.

    Meant for `model.apply(...)`. Biases are left untouched, and the exact type test
    (not isinstance) leaves subclasses of Linear/Conv2d alone.
    """
    if type(module) in (Linear, Conv2d):
        init.xavier_uniform_(module.weight)


  # Log gradient summary before training
  def __grad_summary(self):
    """Print, for each top-level child module, whether its parameters are trainable.

    Printed only when self.logger's effective level is WARNING or lower.
    """
    # `mhist_logger` is not defined in the cells shown (NameError); use the session's logger
    if self.logger.getEffectiveLevel() > logging.WARNING:
        return
    layers = list(self.model.children())
    total_layers_training = 0
    # Walk every top-level child module
    for i, layer in enumerate(layers, 1):
        print(f"Layer {i}:")
        params = list(layer.parameters())  # e.g. weight and bias for a Linear layer
        total_tensors = len(params)
        num_training = sum(1 for param in params if param.requires_grad)
        if num_training == 0 and total_tensors > 0:  # fully frozen
            print("  frozen")
        elif total_tensors > 0:
            total_layers_training += 1
            print(f"  tensors that require grad: {num_training}/{total_tensors}")
        else:
            print("  no tensors")
    print('---------------------------------')
    print('Training', total_layers_training, 'layers of', len(layers), 'total')


  def __freeze_all_but_head(self):
    """Make only the classification head trainable.

    Requires the model to have a `head` submodule (timm ViT models have one). If
    it doesn't (e.g. SimpleFC, whose last layer is `outlayer`), requires_grad is
    left unchanged: freezing every parameter would leave the optimizer with an
    empty parameter list.
    """
    if not hasattr(self.model, 'head'):
        print("Warning: model has no 'head' attribute; leaving requires_grad unchanged")
        return

    # Freeze all params, then unfreeze the head
    for param in self.model.parameters():
        param.requires_grad = False
    for param in self.model.head.parameters():
        param.requires_grad = True


  def __alloc_new_model(self, model_obj):
    """Return model_obj unchanged, or allocate a new model for self.model_type when it is None.

    FC: a SimpleFC (weights are not Xavier-initialized here).
    VIT: a timm VIT_MODEL_TYPE model with pretrained weights and a 1-output head.
    Unknown model_type with model_obj None: returns None.
    """
    if model_obj is not None:
        return model_obj
    if self.model_type == 'FC':
        print('Allocating a new SimpleFC')
        return SimpleFC(150528, 2352, 2352, 294, 1)
    if self.model_type == 'VIT':
        vit = timm.create_model(
            model_name=VIT_MODEL_TYPE,
            pretrained=True,
            num_classes=1,  # change number of outputs in classification head
        )
        cfg = vit.default_cfg
        print(f"Allocating a new model from timm: {cfg['architecture']} with pretrained weight tag: {cfg['tag']}")
        return vit
    return None

  def __model_init(self, path_for_resuming, resume_from_object, model_obj):
    """Pick the model to train and snapshot its initial weights.

    Priority: resume_from_object (use model_obj as-is) > path_for_resuming
    (allocate via __alloc_new_model, then load the state_dict from disk) >
    model_obj if given (used as-is) > a newly allocated model (FC weights are
    Xavier-initialized; ViT keeps its pretrained weights).

    Returns:
      (model, deep copy of its state_dict), or (None, None) when
      resume_from_object is True but model_obj is None.
    """
    # Resuming from the object wins over a path so in-memory weights are never
    # overwritten by accident; a path can always be retried because it's on disk.
    if resume_from_object:
      if model_obj is None:
          print("Warning: Can't resume because model = None")
          return None, None
      print('Info: Resuming from model object')

    elif path_for_resuming is not None: # Load weights from disk
      model_obj = self.__alloc_new_model(model_obj) # based on self.model_type
      print('Info: Resuming from saved state dict', path_for_resuming)
      # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
      # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
      model_obj.load_state_dict(torch.load(path_for_resuming, map_location=self.device))

    # Create default model (not resuming from object nor path)
    elif model_obj is None:
      model_obj = self.__alloc_new_model(model_obj) # based on self.model_type
      if self.model_type == 'FC': # don't re-init the pretrained ViT
          print('Info: Initializing model weights with Xavier Uniform')
          model_obj.apply(self.initialize_weights)

    # Keep a copy in memory, in case this session makes the weights worse
    print('Loading a copy of initial model weights into memory')
    best_model_wts = copy.deepcopy(model_obj.state_dict())
    return model_obj, best_model_wts



                    ### TRAIN/EVAL LOOP ###

  # Wrapper for setting params before running train/eval loop
  def train_and_evaluate(self, model = None, # (optional) pass a model to freeze/unfreeze, check, or change anything between runs
                         epochs = 10,
                         freeze_all_but_head = True,
                         learning_rate = None,
                         best_metric = None,
                         best_model_dest = DEFAULT_MODEL_PATH,
                         mlflow_experiment = MLFLOW_DEFAULT_EXPERIMENT,
                         mlflow_run = None):
    """Configure optimizer, LR and checkpointing, then run the train/eval loop.

    Args:
      model: optional replacement for self.model (e.g. the model returned by a
        previous call after freezing/unfreezing layers).
      epochs: number of train+val epochs.
      freeze_all_but_head: if True, make only the classification head trainable.
      learning_rate: new LR. If None, an existing optimizer keeps its current LR
        and a new optimizer starts from the previous LR or DEFAULT_LR.
      best_metric: score a checkpoint must beat (overrides the stored one).
      best_model_dest: path the best state_dict is saved to.
      mlflow_experiment, mlflow_run: MLflow experiment name and run id (None
        starts a new run); used only when tracking is enabled.

    Returns:
      self.model, so the caller can adjust it before another call.

    Raises:
      ValueError: if no parameter requires grad.
    """
    if model is not None:
        self.model = model
    if freeze_all_but_head:
        self.__freeze_all_but_head()
    self.__grad_summary()

    if best_metric is not None:
        self.best_metric = best_metric
    elif self.best_metric is None and self.eval_on != 'loss':
        self.best_metric = 0.
    # else: if eval_on == 'loss' then best_metric might be None until the first validation epoch
    self.best_model_dest = best_model_dest
    print('Evaluating the model on', self.eval_on, 'best_metric =', self.best_metric)

    # Only trainable tensors go into the optimizer (for efficiency)
    trainable_tensors = [p for p in self.model.parameters() if p.requires_grad]
    if not trainable_tensors:
        raise ValueError('No trainable parameters: every parameter has requires_grad=False')

    optimizer_tensor_ids = set() if self.optimizer is None else {
        id(p) for group in self.optimizer.param_groups for p in group['params']}
    if self.optimizer is None or optimizer_tensor_ids != {id(p) for p in trainable_tensors}:
        # First call, or the trainable tensors changed (new model object or different
        # frozen layers): an optimizer over the old tensors would silently skip the new ones.
        if learning_rate is not None:
            lr = learning_rate
        elif self.optimizer is not None:
            lr = self.optimizer.param_groups[0]['lr']  # carry over the current (possibly decayed) LR
        else:
            lr = DEFAULT_LR
        self.optimizer = Adam(params=trainable_tensors, lr=lr)
        # The old scheduler is bound to the old optimizer; __train_eval_loop creates a new one
        self.scheduler = None
        self.learning_rate = lr
    elif learning_rate is not None:
        # Same tensors: update the LR in place and keep Adam's running state
        for group in self.optimizer.param_groups:
            group['lr'] = learning_rate
        self.learning_rate = learning_rate
    # else: keep the existing optimizer and its current LR
    print('Initial learning rate', self.learning_rate)

    # Train the model
    if self.enable_tracking:
        # __track_training will call __train_eval_loop()
        self.__track_training(epochs, mlflow_experiment, mlflow_run)
    else: # Don't use MLflow tracking server
        print('Info: Starting training without tracking to MLflow server.')
        self.__train_eval_loop(epochs)

    return self.model # in case we want to freeze/unfreeze, check, or change anything


  def __get_model_hyperparams(self):
    """Return the session and model settings that are logged as MLflow params."""
    # NOTE(review): the [0.0, 1.0] range describes ToTensor's output, before Normalize.
    input_descriptions = {
        'FC': "flattened 3 x 224 x 224 (RGB mode) in range [0.0, 1.0]",
        'VIT': "3 x 224 x 224 (RGB mode) in range [0.0, 1.0]",
    }
    if self.model_type == 'FC':
        num_in_features = 150528  # 3 * 224 * 224, flattened
    else:
        num_in_features = [3, 224, 224]  # (C x H x W)
    # NOTE(review): str(self.model) is long for a ViT; MLflow limits param value length — confirm it fits.
    return {
        'model_class': str(type(self.model).__name__),
        'model_info': str(self.model),
        'input_image': input_descriptions[self.model_type],
        'num_in_features': num_in_features,
        'in_feature_dtype': 'torch.float32',
        'num_outputs': 1,
        'label_names': self.label_names,
        'test_size': 'is set with dataframes',
        'batch_size': self.batch_size,
        'random': self.random,
        'device': str(self.device),
        'eval_on': self.eval_on,
        'best_metric': self.best_metric,
        'best_model_dest': self.best_model_dest,
    }


  def __track_training(self, epochs, experiment, run_id):
    """Run __train_eval_loop inside an MLflow run on MLFLOW_SERVER.

    run_id=None starts a new run; otherwise the given run is resumed. The `with`
    block ends this run on normal exit or on error.
    NOTE(review): per MLflow docs, start_run raises if another run is already
    active; it does not end that earlier run.
    """
    mlflow.set_tracking_uri(MLFLOW_SERVER)
    mlflow.set_experiment(experiment)

    with mlflow.start_run(run_id) as active_run:
        info = active_run.info
        print('\nLogging metrics with server:', mlflow.get_tracking_uri())
        print('MLflow: run_name =', info.run_name, 'run_id =', info.run_id)  # run_id is a UUID
        self.run_id = info.run_id  # kept so the model can be logged to this run later
        print('MLflow Experiment name:', experiment, '\n')
        mlflow.log_params(self.__get_model_hyperparams())
        self.__train_eval_loop(epochs)


  def __train_eval_loop(self, epochs):
    """Train for `epochs` epochs; each epoch is a 'train' phase then a 'val' phase.

    Per phase: accumulates loss, correct count, labels and predictions; computes
    metrics via __compute_metrics (which may checkpoint the model on 'val'); logs
    them to MLflow when tracking is enabled; after 'train', steps the LR scheduler.
    At the end: prints the last validation confusion matrix and classification
    report, logs summary metrics/artifacts when tracking, and reloads
    self.best_model_wts into self.model.
    """
    if epochs < 1:
        # Nothing would be collected for the final report below
        print('Warning: epochs < 1; nothing to train.')
        return

    if self.scheduler is None:
        # StepLR: multiply the LR by gamma=0.1 every 5 scheduler steps (stepped once per epoch)
        self.scheduler = lr_scheduler.StepLR(self.optimizer, step_size=5, gamma=0.1, last_epoch=-1)

    self.model.to(self.device)
    print('Device:', next(self.model.parameters()).device, 'torch.cuda.device_count:', torch.cuda.device_count())
    criterion = BCEWithLogitsLoss()
    steps = 0            # optimizer steps so far; used as the MLflow step
    previous_lr = None   # to report LR changes made by the scheduler
    start = time.time()

    for epoch in range(epochs):
        # Each epoch has two phases: training and validation
        for phase in ['train', 'val']:
            if phase == 'train':
                self.model.train()
            else:
                self.model.eval()

            epoch_loss    = 0.
            epoch_correct = 0
            epoch_samples = 0
            y = []       # true labels
            y_pred = []  # predicted labels

            for batch in tqdm(self.loader_dict[phase]):
                # batch is a dict of 'image' (flattened or [batch, 3, 224, 224] images),
                # 'label' (integer labels) and 'filename'
                images = batch['image'].to(self.device)
                labels = batch['label'].float().to(self.device)  # BCEWithLogitsLoss needs float targets

                self.optimizer.zero_grad()

                # Track gradient history only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    batch_size = len(labels)
                    # The model outputs (batch, 1). reshape(-1) gives (batch,) for every batch
                    # size; squeeze() would give a 0-d tensor for a final batch of 1, which
                    # BCEWithLogitsLoss rejects against the (1,)-shaped labels.
                    logits = self.model(images).reshape(-1)
                    loss = criterion(logits, labels)
                    preds = (logits > 0).float()  # logit > 0  <=>  sigmoid(logit) > 0.5
                    num_correct = (preds == labels).sum().item()

                    if phase == 'train':
                        loss.backward()
                        self.optimizer.step()
                        steps += 1

                # criterion returns the batch mean, so weight it by the batch size;
                # __compute_metrics divides by epoch_samples for a per-sample average
                epoch_loss    += loss.item() * batch_size
                epoch_correct += num_correct
                epoch_samples += batch_size  # the last batch might be smaller
                y.extend(labels.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())

            # After each phase, compute (and for 'val', possibly checkpoint on) metrics
            epoch_metrics = self.__compute_metrics(phase, epoch_loss, epoch_correct, epoch_samples, y, y_pred)
            prefix = '' if phase == 'train' else 'val_'  # val metric names come back prefixed
            eval_on = prefix + self.eval_on
            loss_name = prefix + 'loss'
            f1_name = prefix + 'f1'
            accuracy_name = prefix + 'accuracy'
            if self.enable_tracking:
                mlflow.log_metrics(metrics=epoch_metrics, step=steps)
            print(f"Epoch {epoch+1} ({steps} steps) {phase} Loss: {epoch_metrics[loss_name]:.3f}, {eval_on}: {epoch_metrics[eval_on]}, F1: {epoch_metrics[f1_name]} Accuracy: {epoch_metrics[accuracy_name]:.3f}")

            if phase == 'train':
                self.scheduler.step()
                # Report when the scheduler changes the learning rate
                current_lr = self.scheduler.get_last_lr()[0]
                if previous_lr is None:
                    previous_lr = current_lr
                elif current_lr != previous_lr:
                    print(f"Epoch {epoch+1} {phase}: Learning rate changed to {current_lr:.6f}")
                    previous_lr = current_lr

        # End of an epoch (train and val)
        if epoch == 0:
            epoch_time = time.time() - start
            print(f'Epoch 0 completed in {epoch_time // 60:.0f}m {epoch_time % 60:.0f}s')

    time_elapsed = time.time() - start
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best {self.eval_on} score: {self.best_metric}')
    initial_lr = self.optimizer.param_groups[0]['initial_lr']
    final_lr = self.scheduler.get_last_lr()[0]
    print(f'Learning rate: start={initial_lr} final={final_lr}')
    # y / y_pred hold the last validation phase. labels=[0, 1] keeps the matrix 2x2
    # even when only one class appears, which __log_dict's indexing requires.
    cm = confusion_matrix(y, y_pred, labels=[0, 1])
    print(cm)
    print(classification_report(y, y_pred))

    if self.enable_tracking:
        mlflow.log_metric("time_elapsed", time_elapsed, step=steps)
        self.__log_dict(cm, step=steps)
        self.__save_heatmap(cm) # save image to local and mlflow artifacts
        self.__log_hyperparams(self.optimizer, self.scheduler, criterion)

    # Reload self.best_model_wts, in case this run ended worse than it started
    self.model.load_state_dict(self.best_model_wts)


  def __log_hyperparams(self, optimizer, scheduler, criterion):
    """Log optimizer (Adam), LR scheduler (StepLR) and loss-function settings as MLflow params.

    Reads only the arguments, so the logged values always describe the objects passed in.
    """
    group = optimizer.param_groups[0]
    mlflow.log_params(dict(
        optimizer = optimizer.__class__.__name__,
        initial_lr = group['initial_lr'],  # set on the param group by the LR scheduler
        final_lr = scheduler.get_last_lr()[0],
        weight_decay = group['weight_decay'],
    ))

    mlflow.log_params(dict(
        scheduler = scheduler.__class__.__name__,
        step_size = scheduler.step_size,
        gamma = scheduler.gamma,
        last_epoch = scheduler.last_epoch,
        # `verbose` is deprecated in recent PyTorch releases and may be absent
        verbose = getattr(scheduler, 'verbose', None),
    ))

    mlflow.log_param('loss_function', type(criterion).__name__)


  def __log_dict(self, cm, step):
    """Log the four cells of a 2x2 confusion matrix as MLflow metrics at `step`.

    cm rows are actual classes and columns predicted classes (index 1 = positive).
    """
    cells = (
        ("true_negative", (0, 0)),
        ("false_positive", (0, 1)),
        ("false_negative", (1, 0)),
        ("true_positive", (1, 1)),
    )
    for metric_name, (row, col) in cells:
        mlflow.log_metric(metric_name, cm[row][col], step=step)


  def __save_heatmap(self, cm_numpy):
    """Plot a confusion-matrix heatmap, save it under artifacts/ and upload it to MLflow.

    The file is named after best_model_dest (<model name>_confusion_matrix.png). It
    is written even if no best-model checkpoint was saved in this run.
    """
    fig, ax = plt.subplots(figsize=(4, 3))
    sns.heatmap(cm_numpy, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')

    # Save and upload image
    model_name = os.path.splitext(os.path.basename(self.best_model_dest))[0] # filename without extension
    image_path = f'artifacts/{model_name}_confusion_matrix.png'
    # savefig fails if the target directory does not exist
    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    fig.savefig(image_path, bbox_inches='tight')
    mlflow.log_artifact(image_path)


  def __compute_metrics(self, phase, epoch_loss, epoch_correct, epoch_samples, y, y_pred):
    """Compute epoch metrics; for 'val', also checkpoint the model when it improves.

    Args:
      phase: 'train' or 'val'.
      epoch_loss: sum of per-sample losses over the phase.
      epoch_correct: number of correct predictions (not used by the metrics below).
      epoch_samples: number of samples in the phase.
      y, y_pred: true and predicted labels (0/1).

    Returns:
      dict of metrics; for 'val' every key is prefixed with 'val_'.

    Side effects ('val' only): when self.eval_on improves ('loss': lower; any other
    metric: higher) and __met_goal reports a beaten goal, the state_dict is saved
    to self.best_model_dest, copied into self.best_model_wts, and
    self.best_metric is updated.
    NOTE(review): false_negative_rate is lower-is-better but would be maximized if
    used as eval_on.
    """
    # labels=[0, 1] keeps the matrix 2x2 even when a phase contains only one class,
    # so the 4-way unpack cannot fail
    tn, fp, fn, tp = confusion_matrix(y, y_pred, labels=[0, 1]).ravel()
    metrics = dict(
        loss = epoch_loss/epoch_samples, # average loss per sample (not per batch)
        weighted_precision = precision_score(y, y_pred, average='weighted'),
        weighted_recall = recall_score(y, y_pred, average='weighted'),
        f1 = f1_score(y, y_pred, average='weighted'),
        accuracy = accuracy_score(y, y_pred),
        false_negative_rate = fn / (fn + tp),
        true_positive_rate = tp / (tp + fn), # recall for positive class
    )
    if phase != 'val':
        return metrics

    # Check the model for improvement (end of a validation epoch)
    current = metrics[self.eval_on]
    improved = False
    if self.eval_on == 'loss':  # lower is better
        if self.best_metric is None:
            # First validation epoch: nothing to beat yet, so it counts as an improvement
            self.best_metric = current
            improved = True
        elif current < self.best_metric:
            print(f"Improved validation loss: {metrics['loss']}")
            improved = True
    elif current > self.best_metric:  # higher is better for every other metric
        print(f'Improved validation {self.eval_on}: {metrics[self.eval_on]}')
        improved = True

    # Short-circuit on purpose: __met_goal raises the goals it beats, so it is only
    # consulted for an improved model
    if improved and self.__met_goal(metrics):
        print(f'Saving model weights to {self.best_model_dest}')
        torch.save(self.model.state_dict(), self.best_model_dest) # to disk
        # Store on the instance: __train_eval_loop reloads self.best_model_wts at the end
        self.best_model_wts = copy.deepcopy(self.model.state_dict()) # to memory
        self.best_metric = current

    return {'val_' + key: value for key, value in metrics.items()} # add prefix to val metric names


  def __met_goal(self, metrics_dict):
    """Return True if any validation metric beats its goal; each beaten goal is raised to the new value.

    All four metrics are checked (no short-circuit), so several goals can move at once.
    """
    checks = (
        ('weighted_precision', 'precision_goal', 'weighted_precision'),
        ('weighted_recall', 'recall_goal', 'weighted_recall'),
        ('f1', 'f1_goal', 'F1-score'),
        ('accuracy', 'accuracy_goal', 'accuracy'),
    )
    any_goal_beaten = False
    for metric_key, goal_attr, label in checks:
        value = metrics_dict[metric_key]
        if value > getattr(self, goal_attr):
            print(f"Improved validation {label}: {value}")
            setattr(self, goal_attr, value)
            any_goal_beaten = True
    return any_goal_beaten


  # Log from path or log self.model
  # Pass in run_id or use self.run_id

  # Implement logging to MLflow based on best_metric here.
  # Compare best_metric against the runs in the MLflow experiment,
  # in case a better model is already stored on the remote server.
  # If this model beats the best model in the experiment, log the artifact.
  # The user needs to manage the number of models stored in S3.
  # MLflow can store metrics regardless of whether the artifacts are present.
  # Serialize the model as ONNX.
  # This won't save the model in the MLflow model registry.
  def log_model(self, path=None, run_id=None):
      """Placeholder for logging a model to MLflow; not implemented yet.

      Intended to log either the model at ``path`` or ``self.model``, to the
      run ``run_id`` or ``self.run_id``. Currently does nothing and returns None.
      """
      return None


  # Download onnx model from MLflow server to local PyTorch checkpoint
  # for resuming training or running inference
  # Regular method (no @staticmethod): it takes ``self``, so instance.download_model(run_id)
  # must bind the instance to ``self`` and the argument to ``run_id``.
  def download_model(self, run_id, local_dest=DEFAULT_MODEL_PATH):
      """Download an ONNX model from the MLflow server to a local PyTorch checkpoint.

      Not implemented yet: the body is a stub that does nothing and returns None.
      Intended for resuming training or running inference from the model stored
      in MLflow run ``run_id``.

      Args:
          run_id: MLflow run ID to download the model from.
          local_dest: local checkpoint path to write; defaults to DEFAULT_MODEL_PATH.
      """
      return None

  @staticmethod
  def __sigmoid(np_outs):
      np_outs = np.clip(np_outs, -50, 50) # prevent np.exp overflow for large values (in case of an issue with preprocessing)
      return 1 / (1 + np.exp(-np_outs))


  # Batch inference: initialize a session with a model, then run this method (not train)
  # input df path is hard-coded, above, in global var: TEST_DF
  # The following uses numpy, not Pytorch, but it would be faster and more concise to use Pytorch
  def batch_predict(self, threshold=0.5):
      """Run inference over ``self.loader_dict['val']`` and score the predictions.

      Each batch is read as a dict with 'image', 'label' and 'filename' entries.
      The model output is treated as one logit per image; a sample is predicted
      positive (1) when sigmoid(logit) > ``threshold``.

      Args:
          threshold: decision threshold on the positive-class probability.
              Defaults to 0.5, the previously hard-coded value.

      Returns:
          (results_df, metrics):
            results_df -- one row per sample with columns 'filename' (string),
              'label' (Int64), 'logit' (float64), 'prediction' (Int64),
              'positive_prob' (float64) and 'correct' (Int64).
            metrics -- dict of weighted precision/recall/F1, accuracy, balanced
              accuracy, ROC-AUC, PR-AUC, false negative rate and true positive rate.

      Raises:
          ValueError: if the loader yields no batches.
      """
      self.model.eval()
      self.model.to(self.device)

      # Accumulate per-batch arrays and concatenate once. This avoids re-copying the
      # whole result with np.vstack on every batch. It also keeps the numbers numeric
      # instead of stacking them into one string array alongside the filenames.
      filenames, labels, logits = [], [], []
      with torch.no_grad():
          for batch in tqdm(self.loader_dict['val']):
              images = batch['image'].to(self.device)
              # reshape(-1) instead of squeeze(): stays 1-D even for a batch of one image
              logits.append(self.model(images).cpu().numpy().reshape(-1))
              labels.append(batch['label'].cpu().numpy().reshape(-1))
              filenames.extend(batch['filename'])

      if not logits:
          raise ValueError("batch_predict: the 'val' loader yielded no batches")

      logits_np = np.concatenate(logits)
      labels_np = np.concatenate(labels).astype(int)
      positive_probs_np = self.__sigmoid(logits_np)
      preds_np = (positive_probs_np > threshold).astype(int)
      correct_np = (preds_np == labels_np).astype(int)

      results_df = pd.DataFrame({
          'filename': pd.Series(filenames, dtype='string'),
          'label': pd.Series(labels_np, dtype='Int64'),
          'logit': logits_np.astype(np.float64),
          'prediction': pd.Series(preds_np, dtype='Int64'),
          'positive_prob': positive_probs_np.astype(np.float64),
          'correct': pd.Series(correct_np, dtype='Int64'),
      })

      # calculate metrics
      y_true = labels_np
      y_pred = preds_np
      # Probabilities must stay floats for the ranking metrics (ROC-AUC, PR-AUC);
      # casting them to int collapses almost every score to 0.
      y_prob = positive_probs_np.astype(np.float64)
      # labels=[0, 1] guarantees a 2x2 matrix even if one class is missing from both arrays
      tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
      metrics = {
          'weighted_precision': precision_score(y_true, y_pred, average='weighted'),
          'weighted_recall': recall_score(y_true, y_pred, average='weighted'),
          'weighted_f1': f1_score(y_true, y_pred, average='weighted'),
          'accuracy': accuracy_score(y_true, y_pred),
          'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
          'roc_auc': roc_auc_score(y_true, y_prob),
          'pr_auc': average_precision_score(y_true, y_prob),
          'false_negative_rate': fn / (fn + tp),
          'true_positive_rate': tp / (tp + fn) # recall for positive class
      }

      return results_df, metrics


### Train FC model

In [None]:
# Fine-tune the small fully-connected (FC) baseline.
# Previous best: eval_on='f1', best_metric=0.7669682977041087.
# Best validation recall seen so far: 0.5482954545454546.
# NOTE(review): confirm FinetuneFC's validation metrics include a 'recall' key for eval_on.
fc_session = FinetuneFC(
    batch_size=64 if torch.cuda.is_available() else 4,  # much smaller batches when running on CPU
    eval_on='recall',  # recall of the positive (minority) class
    enable_tracking=True,
)

best_model = fc_session.train_and_evaluate(
    epochs=10,
    learning_rate=1e-6,
    best_metric=0.,
    best_model_dest="artifacts/MHIST_SmallFC_v1.pt",  # DEFAULT_MODEL_PATH = 'MHIST_model.pt'
    mlflow_experiment="MHIST FC (binary classification)",  # a new MLflow run is created
)

### Train ViT

In [None]:
# Fine-tune the ViT classifier, keeping the checkpoint with the best positive-class
# recall (true_positive_rate). Target: > 0.95, i.e. a false negative rate below 0.05.
# Previous best: true_positive_rate = 0.9
session = TrainingSession(
    model_type='VIT',
    batch_size=512,
    eval_on='true_positive_rate',  # recall of the positive (minority) class
    enable_tracking=True,
)

best_ViT = session.train_and_evaluate(
    epochs=8,
    # NOTE(review): 1e-2 is unusually high for fine-tuning a pretrained ViT; confirm it is intentional
    learning_rate=1e-2,
    best_metric=0.,
    best_model_dest="artifacts/MHIST_ViT_v14.pt",  # DEFAULT_MODEL_PATH = 'MHIST_model.pt'
    mlflow_experiment="MHIST ViT (binary classification)",  # a new MLflow run is created
)

### Run batch inference on best model
After training, I ran batch inference on the test set with the best ViT model to inspect the results more closely, looking for misclassifications and comparing the model's predictions with the expert annotations.

In [None]:
# Reload the best ViT checkpoint and run batch inference over the evaluation loader
session = TrainingSession(
    path_for_resuming='artifacts/MHIST_ViT_v13.pt',  # resume from local path if not None (not used for saving best model)
    model_type='VIT',
    batch_size=960,
)
results_df, metrics = session.batch_predict()

# # Save
# results_df.to_csv('artifacts/MHIST_ViT_v13_results.csv', index=False)
# pd.Series(metrics).to_json('artifacts/MHIST_ViT_v13_metrics.json')

# 'prob' = confidence in the predicted class: positive_prob for SSA predictions,
# 1 - positive_prob for HP predictions (prediction == 0)
corrected_results_df = results_df.copy()
corrected_results_df['prob'] = corrected_results_df['positive_prob']
neg_pred_samples = corrected_results_df['prediction'] == 0
corrected_results_df.loc[neg_pred_samples, 'prob'] = 1 - corrected_results_df.loc[neg_pred_samples, 'prob']

# False negatives: SSA samples (label 1) that the model predicted as HP
SSA_VOTES_COL = 'Number of Annotators who Selected SSA (Out of 7)'
incorrect_df = corrected_results_df[
    (corrected_results_df['correct'] == 0)
    & (corrected_results_df['prob'] > 0.5)
    & (corrected_results_df['label'] == 1)
]
incorrect_annotations = incorrect_df.merge(annotations, left_on='filename', right_on='Image Name', how='left')
cols = ['filename', 'label', 'prediction', 'positive_prob', 'prob', SSA_VOTES_COL]
# False negatives that at least 6 of the 7 expert annotators labelled SSA
expert_comparison_df = incorrect_annotations.loc[incorrect_annotations[SSA_VOTES_COL] > 5, cols]

# Display model results and false-negative details. Each object is rendered separately:
# .info() prints its summary and returns None, so it is called on its own line.
display(metrics)
corrected_results_df.info()
display(corrected_results_df.head(), incorrect_annotations[cols], expert_comparison_df)

In [None]:
# Get metrics on different thresholds for a given dataset
# The dataframe to analyse is set in RESULTS_DF, below
# Code is from Perplexity.ai

import numpy as np
import pandas as pd
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, roc_auc_score, average_precision_score, balanced_accuracy_score, confusion_matrix
from tqdm import tqdm
import matplotlib.pyplot as plt

# Per-sample results to sweep; find_best_thresholds reads its 'label' (0/1) and 'positive_prob' columns
RESULTS_DF = corrected_results_df

def evaluate_threshold(y, y_prob, threshold):
    """Score binary predictions made by thresholding positive-class probabilities.

    Args:
        y: true labels (0/1); anything supporting ``.astype(int)``.
        y_prob: positive-class probabilities, same length as ``y``.
        threshold: a sample is predicted positive when ``y_prob >= threshold``.

    Returns:
        dict with the threshold plus weighted precision/recall/F1, accuracy,
        balanced accuracy, ROC-AUC and PR-AUC (both computed from the
        probabilities, so they do not depend on the threshold), the false
        negative rate and the true positive rate.
    """
    y_true = y.astype(int)
    y_pred = (y_prob >= threshold).astype(int)
    # labels=[0, 1] keeps the matrix 2x2 (so the 4-way unpack works) even when
    # one class is absent from both y_true and y_pred
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    fnr = fn / (fn + tp)
    tpr = tp / (tp + fn)
    # zero_division=0 yields the same 0.0 sklearn falls back to by default. It avoids an
    # UndefinedMetricWarning at extreme thresholds where one class is never predicted.
    return {
        'threshold': threshold,
        'weighted_precision': precision_score(y_true, y_pred, average='weighted', zero_division=0),
        'weighted_recall': recall_score(y_true, y_pred, average='weighted', zero_division=0),
        'weighted_f1': f1_score(y_true, y_pred, average='weighted', zero_division=0),
        'accuracy': accuracy_score(y_true, y_pred),
        'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
        'roc_auc': roc_auc_score(y_true, y_prob),
        'pr_auc': average_precision_score(y_true, y_prob),
        'false_negative_rate': fnr,
        'true_positive_rate': tpr
    }

def find_best_thresholds(df, thresholds=None):
    """Sweep decision thresholds over ``df`` and pick the best one per metric.

    Args:
        df: results frame with a 0/1 'label' column and a 'positive_prob' column.
        thresholds: iterable of thresholds to evaluate; defaults to
            np.arange(0.01, 1.00, 0.01), i.e. 0.01, 0.02, ..., 0.99.

    Returns:
        (best_thresholds, metrics_df):
          best_thresholds -- maps each metric name to the metrics_df row (a Series)
            at its best threshold. That is the maximum for every metric except
            'false_negative_rate', which is minimised.
          metrics_df -- one row per threshold with the columns returned by
            evaluate_threshold.

    Note: 'roc_auc' and 'pr_auc' are computed from the probabilities alone, so they
    are identical at every threshold and their "best" row is simply the first one.
    """
    y_true = df['label'].values
    y_prob = df['positive_prob'].values

    if thresholds is None:
        thresholds = np.arange(0.01, 1.00, 0.01)

    metrics_df = pd.DataFrame([
        evaluate_threshold(y_true, y_prob, threshold)
        for threshold in tqdm(thresholds, desc="Evaluating thresholds")
    ])

    maximised = ['weighted_precision', 'weighted_recall', 'weighted_f1', 'accuracy',
                 'balanced_accuracy', 'roc_auc', 'pr_auc']
    best_thresholds = {name: metrics_df.loc[metrics_df[name].idxmax()] for name in maximised}
    best_thresholds['false_negative_rate'] = metrics_df.loc[metrics_df['false_negative_rate'].idxmin()]

    return best_thresholds, metrics_df

# Sweep thresholds on the selected results
best_thresholds, metrics_df = find_best_thresholds(RESULTS_DF)

# For each metric: the threshold that optimises it, and every metric at that threshold
for metric, row in best_thresholds.items():
    report_lines = [
        f"\nBest threshold for {metric}:",
        f"Threshold: {row['threshold']:.2f}",
    ]
    report_lines += [
        f"{title}: {row[key]:.4f}"
        for title, key in (
            ('Weighted Precision', 'weighted_precision'),
            ('Weighted Recall', 'weighted_recall'),
            ('Weighted F1', 'weighted_f1'),
            ('Accuracy', 'accuracy'),
            ('Balanced Accuracy', 'balanced_accuracy'),
            ('ROC AUC', 'roc_auc'),
            ('PR AUC', 'pr_auc'),
            ('False Negative Rate', 'false_negative_rate'),
            ('True Positive Rate', 'true_positive_rate'),
        )
    ]
    print('\n'.join(report_lines))

# Every metric against the decision threshold
fig, ax = plt.subplots(figsize=(12, 8))
for column in metrics_df.columns.drop('threshold'):
    ax.plot(metrics_df['threshold'], metrics_df[column], label=column)
ax.set(xlabel='Threshold', ylabel='Metric Value', title='Metrics vs Threshold')
ax.legend()
ax.grid(True)
plt.show()

# Close-up on the two rates that matter most for missed SSA cases
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(metrics_df['threshold'], metrics_df['false_negative_rate'], label='False Negative Rate')
ax.plot(metrics_df['threshold'], metrics_df['true_positive_rate'], label='True Positive Rate')
ax.set(xlabel='Threshold', ylabel='Rate', title='False Negative Rate and True Positive Rate vs Threshold')
ax.legend()
ax.grid(True)
plt.show()

# Threshold with the lowest False Negative Rate, and all metrics at that threshold
fnr_best_idx = metrics_df['false_negative_rate'].idxmin()
min_fnr_threshold = metrics_df.loc[fnr_best_idx, 'threshold']
print(f"\nThreshold that minimizes False Negative Rate: {min_fnr_threshold:.2f}")
print("Metrics at this threshold:")
print(metrics_df.loc[fnr_best_idx].to_string())


In [None]:
# Save metrics for a new threshold for a given dataset
# Set RESULTS_DF and THRESHOLD below
# Code is from Perplexity.ai

import pandas as pd
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, accuracy_score, balanced_accuracy_score, roc_auc_score, average_precision_score

RESULTS_DF = corrected_results_df
THRESHOLD = 0.3

# Re-threshold a copy of RESULTS_DF (so changing RESULTS_DF above changes this analysis)
adjusted_threshold_df = RESULTS_DF.copy()
adjusted_threshold_df['prediction'] = (adjusted_threshold_df['positive_prob'] >= THRESHOLD).astype(int)
adjusted_threshold_df['correct'] = (adjusted_threshold_df['prediction'] == adjusted_threshold_df['label']).astype(int)
# Confidence in the predicted class
adjusted_threshold_df['prob'] = np.where(adjusted_threshold_df['prediction'] == 0,
                                         1 - adjusted_threshold_df['positive_prob'],
                                         adjusted_threshold_df['positive_prob'])

# Calculate metrics
y_true = adjusted_threshold_df['label'].astype(int)
y_pred = adjusted_threshold_df['prediction']
y_prob = adjusted_threshold_df['positive_prob']

# labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted at this threshold
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

metrics = {
    'Weighted Precision': precision_score(y_true, y_pred, average='weighted'),
    'Weighted Recall': recall_score(y_true, y_pred, average='weighted'),
    'False Negative Rate': fn / (fn + tp),
    'True Positive Rate': tp / (tp + fn),
    'Weighted F1-Score': f1_score(y_true, y_pred, average='weighted'),
    'Accuracy': accuracy_score(y_true, y_pred),
    'Balanced Accuracy': balanced_accuracy_score(y_true, y_pred),
    'ROC-AUC': roc_auc_score(y_true, y_prob),
    'PR-AUC': average_precision_score(y_true, y_prob)
}

# # Save the adjusted DataFrame to CSV
# adjusted_threshold_df.to_csv('artifacts/MHIST_ViT_v13_adjusted_threshold_results.csv', index=False)

# # Save metrics to a separate file
# pd.Series(metrics).to_json('artifacts/MHIST_ViT_v13_adjusted_threshold_metrics.json')

# Show metrics and new results
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")
adjusted_threshold_df

## Export to ONNX
PyTorch does much of the work of converting models to ONNX. The newer Dynamo-based export integrates with ONNX to make conversion simple, efficient and accurate. I log models to MLflow only selectively, to avoid unnecessary AWS data-transfer and storage fees, which add up quickly with large models.

In [None]:
# Test ViT inference with the best model (PyTorch):
!pip install -q timm # CPU-only version: pip install timm[torch-cpu]
import timm
vit = timm.create_model(
    model_name='vit_base_patch16_224',
    pretrained=False,
    num_classes=1, # change number of outputs in classification head
    # img_size=224
)
print('vit.head:\n', vit.head) # Linear(in_features=768, out_features=1, bias=True)

import os
import pandas as pd
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
PYTORCH_MODEL_PATH = 'artifacts/MHIST_ViT_v13.pt'
IMAGES_DIR = 'images/'
PATH = 'MHIST_aah.png'
test_df = pd.read_csv('artifacts/testset_info.csv')

label = 'SSA' if test_df.loc[test_df['name'] == PATH, 'label'].item() == 1 else 'HP'

image_path = os.path.join(IMAGES_DIR, PATH)
image_PIL = Image.open(image_path).convert('RGB') # PIL Image size (224, 224)
print('\nimage_PIL dimensions', image_PIL.size)

# Mean and std values were calculated from the training data, to normalize the colors (per channel):
# Model expects the shape to be [BATCH, 3, 224, 224]
TRAIN_MEAN = [0.738, 0.649, 0.775]
TRAIN_STD =  [0.197, 0.244, 0.17]
MHIST_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(TRAIN_MEAN, TRAIN_STD),
])

preprocessed_image = MHIST_transforms(image_PIL)
print('preprocessed_image shape', preprocessed_image.shape)

# Load model weights from PyTorch model checkpoint
vit.load_state_dict(torch.load(PYTORCH_MODEL_PATH, map_location=torch.device('cpu'))) # load model from disk, specify the location for mapping (loading) the model's params
vit.eval()

with torch.no_grad():
    logit = vit(preprocessed_image.unsqueeze(0)) # torch.Size([1, 3, 224, 224]) with dtype = torch.float32
    print('4D image.shape', preprocessed_image.unsqueeze(0).shape)
    print('\nlogit =', logit.item())
    pred = logit.item() > 0 # Python bool
    prob = torch.sigmoid(logit).item()
    print('pred =', 'SSA' if pred else 'HP')
    print('probability', prob if pred else 1-prob)
print('label =', label)

In [None]:
# Export PyTorch model to ONNX
!pip install --upgrade onnx onnxscript
import torch.onnx

# Generate an example input (image)
example_input = torch.randn(224, 224, 3) # Tensor size (224, 224, 3)
example_preprocessed = MHIST_transforms(example_input.numpy()) # includes ToTensor
print('example_preprocessed shape', example_preprocessed.shape) # torch.Size([3, 224, 224])

# Export to ONNX (with Dynamo)
onnx_program = torch.onnx.dynamo_export(vit, example_preprocessed.unsqueeze(0))
onnx_program.save("artifacts/MHIST_ViT_v13_dynamo_model.onnx")
onnx_program.model_proto.graph.input[0]


In [None]:
# Test inference with local (not EFS) ONNX Dynamo model and S3
# !pip install torch onnx
!pip install -q onnxruntime
import time
import json
from io import BytesIO
from PIL import Image
import numpy as np

# PYTORCH_MODEL_PATH = 'artifacts/MHIST_ViT_v13.pt'
# IMAGES_DIR = 'images/'
# PATH = 'MHIST_aah.png'

import boto3
S3_BUCKET = "mhist-streamlit-app"
S3_ORIGINALS_DIR = "images/original/"
image_filename = 'MHIST_aah.png'

# For inference
from onnxruntime import InferenceSession
EFS_ACCESS_POINT = os.getcwd()
MODEL_PATH = "artifacts/MHIST_ViT_v13_dynamo_model.onnx"

def standardize_image(np_image, mean=None, std=None):
    """Standardize a channel-first image per channel: (x - mean) / std.

    Args:
        np_image: float array shaped (C, H, W), or any shape whose trailing three
            axes are (C, H, W), already scaled to [0., 1.].
        mean: per-channel means (length C); defaults to TRAIN_MEAN.
        std: per-channel standard deviations (length C); defaults to TRAIN_STD.

    Returns:
        The standardized array (float32 for float32 input).
    """
    # Resolved at call time so the notebook-level TRAIN_MEAN / TRAIN_STD are used by default
    mean = TRAIN_MEAN if mean is None else mean
    std = TRAIN_STD if std is None else std

    # Shape (C, 1, 1) so NumPy broadcasting applies one value per channel
    np_mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    np_std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)

    return (np_image - np_mean) / np_std


# Images are scaled to the range [0., 1.] and then standardized per channel
def preprocess(image_filename):
    """Download an image from S3 and turn it into a model-ready batch of one.

    Args:
        image_filename: object name under S3_ORIGINALS_DIR in S3_BUCKET, e.g. 'MHIST_aah.png'.

    Returns:
        float32 ndarray shaped (1, 3, H, W): RGB pixels scaled to [0, 1], then
        standardized per channel with standardize_image.
    """
    import posixpath

    # S3 keys always use '/', so build the key with posixpath rather than os.path
    # (os.path.join would insert '\\' on Windows)
    image_s3key = posixpath.join(S3_ORIGINALS_DIR, image_filename)
    s3 = boto3.client('s3')
    file_obj = s3.get_object(Bucket=S3_BUCKET, Key=image_s3key)
    image_bytes = BytesIO(file_obj['Body'].read())

    # Decode the PNG bytes to 3-channel RGB, then to a float32 array
    pil_image = Image.open(image_bytes).convert('RGB')
    np_image = np.array(pil_image, dtype=np.float32)  # (H, W, 3), values in [0, 255]
    transposed_np = np.transpose(np_image, (2, 0, 1))  # channel-first: (3, H, W)
    normalized_np = transposed_np / 255.0  # scale to [0., 1.]
    standardized_np = standardize_image(normalized_np)  # per-channel standardization
    return np.expand_dims(standardized_np, axis=0)  # add the batch dimension


def sigmoid(np_outs):
    """Element-wise logistic function 1 / (1 + exp(-x)).

    Inputs are clipped to [-50, 50] first so np.exp cannot overflow.
    """
    bounded = np.clip(np_outs, -50, 50)
    return 1 / (np.exp(-bounded) + 1)


def predict(image_filename, threshold=0.3):
    """Classify one S3 image with the exported ONNX ViT model.

    Args:
        image_filename: image name under S3_ORIGINALS_DIR in S3_BUCKET (a string,
            e.g. 'MHIST_aah.png'), passed to preprocess.
        threshold: the image is classified as SSA when its positive-class
            probability is >= threshold. The default 0.3 matches THRESHOLD in the
            threshold-analysis cell.

    Returns:
        dict with 'logit', 'predicted_class' ('SSA' or 'HP'), 'probability' (of the
        predicted class), 'preprocess_time' and 'inference_time' (seconds).
    """
    print('EFS_ACCESS_POINT contents:', os.listdir(EFS_ACCESS_POINT))

    # Run inference with optimized ONNX model
    # It only uses 3.3 GB CPU memory, and 1.4 GB space (for artifacts)
    onnx_path = os.path.join(EFS_ACCESS_POINT, MODEL_PATH)
    ort_session = InferenceSession(os.path.abspath(onnx_path))  # , providers=['CPUExecutionProvider'])

    start_inference = time.monotonic()
    preprocessed_image = preprocess(image_filename)
    preprocess_time = time.monotonic()
    input_name = ort_session.get_inputs()[0].name
    ort_outs = ort_session.run(None, {input_name: preprocessed_image})  # e.g. [array([-1.2028292], dtype=float32)]
    inference_time = time.monotonic()

    logit = ort_outs[0].item()  # <class 'numpy.ndarray'> shape (1,) dtype=float32 -> Python float
    positive_prob = sigmoid(logit).item()
    # >= (not >) so the decision agrees with the threshold-analysis cells (y_prob >= threshold)
    pred = positive_prob >= threshold
    return {
        'logit': logit,
        'predicted_class': 'SSA' if pred else 'HP',
        'probability': positive_prob if pred else 1 - positive_prob,
        'preprocess_time': preprocess_time - start_inference,
        'inference_time': inference_time - preprocess_time,
    }


post_start = time.monotonic()
r_dict = predict(image_filename)
onnx_runtime = time.monotonic() - post_start

print('Completed inference on image_filename', image_filename, 'logit', r_dict['logit'])
print('inference_info:\n', r_dict)

# Ground-truth class from test_df (label 1 = SSA, 0 = HP), so the true/false
# wording below is right for whichever image_filename is chosen
true_class = 'SSA' if test_df.loc[test_df['name'] == image_filename, 'label'].item() == 1 else 'HP'
correct = r_dict['predicted_class'] == true_class
class_type = 'positive' if r_dict['predicted_class'] == 'SSA' else 'negative'
print(f"Prediction: {r_dict['predicted_class']}, which is a {str(correct).lower()} {class_type}")
print(f"Model's predicted probability: {r_dict['probability']*100:.2f}%")
print(f"Preprocessed image in {r_dict['preprocess_time']:.2f} seconds")
print(f"Classified image in {r_dict['inference_time']:.2f} seconds")
print(f"Total: {int(round(onnx_runtime))} seconds")


In [None]:
# Configure MLflow for the ViT experiment (no logging happens in this cell)
import os
import boto3
import mlflow
print('Using:', 'boto3', boto3.__version__, 'mlflow', mlflow.__version__)#, 'onnx', onnx.__version__)

# The tracking server can be overridden with the standard MLFLOW_TRACKING_URI environment variable
MLFLOW_SERVER = os.environ.get('MLFLOW_TRACKING_URI', "http://13.52.243.246:5000")
MLFLOW_EXPERIMENT = 'MHIST ViT (binary classification)'
DYNAMO_MODEL_PATH = "artifacts/MHIST_ViT_v13_dynamo_model.onnx"
MLFLOW_MODEL_PATH = 'onnx_artifacts'

mlflow.set_tracking_uri(MLFLOW_SERVER)
mlflow.set_experiment(MLFLOW_EXPERIMENT)
run_id = '84074e5ab58749f1b609ef5ef90c499f'  # run_name = masked-sheep-165


In [None]:
# Log best ViT model (as a generic artifact) to MLflow server
import os
import boto3
import mlflow
print('Using:', 'boto3', boto3.__version__, 'mlflow', mlflow.__version__)#, 'onnx', onnx.__version__)

# The tracking server can be overridden with the standard MLFLOW_TRACKING_URI environment variable
MLFLOW_SERVER = os.environ.get('MLFLOW_TRACKING_URI', "http://13.52.243.246:5000")
MLFLOW_EXPERIMENT = 'MHIST ViT (binary classification)'
DYNAMO_MODEL_PATH = "artifacts/MHIST_ViT_v13_dynamo_model.onnx"
MLFLOW_MODEL_PATH = 'onnx_artifacts'

mlflow.set_tracking_uri(MLFLOW_SERVER)
mlflow.set_experiment(MLFLOW_EXPERIMENT)
run_id = '84074e5ab58749f1b609ef5ef90c499f'  # run_name = masked-sheep-165

# Resume the existing run and attach the ONNX file as a generic artifact;
# the run is ended when the with-block exits
with mlflow.start_run(run_id) as run:
    print('\nLogging metrics and best model with MLflow: run_name =', run.info.run_name, 'run_id =', run.info.run_id) # run_id is a UUID
    print('MLflow server:', mlflow.get_tracking_uri())
    mlflow.log_artifact(DYNAMO_MODEL_PATH, artifact_path=MLFLOW_MODEL_PATH)
    print('Artifacts stored at:', mlflow.get_artifact_uri())
