In [1]:
# !pip install lightning

In [2]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import albumentations as A
import torchvision.transforms as T
import albumentations.pytorch as pytorch
import albumentations as albu
from typing import Union

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics import Accuracy, JaccardIndex, FBetaScore
from typing import Any, Union

In [4]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl
import torch
import segmentation_models_pytorch as smp
import warnings
from typing import Union, Any

In [5]:

from torch.utils.data import Dataset

import pandas as pd
import numpy as np

import os
from typing import Any

from PIL import Image
import segmentation_models_pytorch as smp

from segmentation_models_pytorch.utils import metrics

from segmentation_models_pytorch.losses import FocalLoss, DiceLoss, JaccardLoss

In [6]:
class ThermalDataset(Dataset):
    def __init__(self,
                 stage: str,
                 images_path: str,
                 augmentation: Any,
                 preprocessing: Any,
                 shuffle: bool = True,
                 random_state: int = 42):

        self.__attribute_checking(images_path,
                                  stage, shuffle, random_state)

        self.images_path = images_path

        self.augmentation = augmentation
        self.preprocessing = preprocessing

        self.stage = stage
        self.shuffle = shuffle
        self.random_state = random_state
        self.total_len = None
        self._images, self._masks = self.__create_dataset()

    @staticmethod
    def __type_checking(images_path: str,
                        stage: str, shuffle: bool,
                        random_state: int) -> None:
        
        assert isinstance(images_path, str)
        assert isinstance(stage, str)
        assert isinstance(shuffle, bool)
        assert isinstance(random_state, int)


    @staticmethod
    def __path_checking(images_path: str) -> None:
        assert os.path.isdir(images_path)

    @staticmethod
    def __stage_checking(stage: str) -> None:
        assert stage in ["train", "test", "val"]

    @classmethod
    def __attribute_checking(cls, images_path: str,
                             stage: str,
                             shuffle: bool,
                             random_state: int) -> None:

        cls.__type_checking(images_path=images_path,
                            stage=stage,
                            shuffle=shuffle,
                            random_state=random_state)

        cls.__path_checking(images_path=images_path)

        cls.__stage_checking(stage=stage)

    def __create_dataset(self) -> dict:
        dict_paths = {
            "image": [],
            "mask": []
        }

        images_path = self.__split_data(self.stage)

        for image_name in os.listdir(images_path):
            dict_paths["image"].append(os.path.join(images_path,image_name))
            dict_paths["mask"].append(os.path.join(os.path.dirname(images_path),'masks',image_name.replace('_NIR_SWIR','_mask')))

        dataframe = pd.DataFrame(
            data=dict_paths,
            index=np.arange(0, len(dict_paths["image"]))
        )
        self.total_len = len(dataframe)
        data_dict = {self.stage: (dataframe["image"].values,dataframe["mask"].values)}

        return data_dict[self.stage]

    def __split_data(self, stage: str) -> str:
        return os.path.join(self.images_path,stage,'images')

    def __len__(self) -> int:
        return self.total_len

    def __getitem__(self, idx) -> tuple:

        image = Image.open(self._images[idx])
        mask = Image.open(self._masks[idx])
        
        image = np.array(image)

        ### FOR FOCAL LOSS
        mask = mask.convert('L') # This ensures that the label only have 1 band, which is necessary for binary classification
        mask = np.array(mask)[:,:,np.newaxis]
        
        mask = np.divide(mask,255).astype('float32') #Masks need to be 0-1 values
        
        # # apply augmentation
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask


In [7]:
def get_training_augmentation():
    train_transform = [

        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),

    ]
    return albu.Compose(train_transform)

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """
    
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

In [8]:
class ThermalDataModule(pl.LightningDataModule):
    def __init__(self,images_path: str,
                 augmentation: Union[T.Compose, A.Compose],
                 preprocessing: Any,
                 batch_size: int = 5,
                 num_workers: int = os.cpu_count(),
                 seed: int = 42):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.images_path = images_path
        self.data_train = None
        self.data_val = None
        self.data_test = None
        self.data_predict = None
        self.seed = seed

        self.train_augmentation = augmentation
        self.eval_augmentation = augmentation
        self.preprocessing = preprocessing


    def setup(self, stage: str = None) -> None:
        self.data_train = ThermalDataset(
            images_path=self.images_path,
            augmentation=self.train_augmentation,
            preprocessing=self.preprocessing,
            stage="train",
            shuffle=True,
            random_state=self.seed
            )

        self.data_val = ThermalDataset(
            images_path=self.images_path,
            augmentation=self.eval_augmentation,
            preprocessing=self.preprocessing,
            stage="val",
            shuffle=True,
            random_state=self.seed
            )

        self.data_test = ThermalDataset(
            images_path=self.images_path,
            augmentation=self.eval_augmentation,
            preprocessing=self.preprocessing,
            stage="test",
            shuffle=True,
            random_state=self.seed
            )

        self.data_predict = ThermalDataset(
            images_path=self.images_path,
            augmentation=self.eval_augmentation,
            preprocessing=self.preprocessing,
            stage="test",
            shuffle=True,
            random_state=self.seed
            )

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False
        )

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(
            dataset=self.data_predict,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False
        )

In [9]:
class ThermalModel(pl.LightningModule):
    def __init__(self,
                 model: nn.Module,
                 loss_fn: Any,
                 optim_dict: dict = None,
                 lr: float = None,
                 num_classes: int = 1):
        super().__init__()
        self.save_hyperparameters(ignore=['model','loss_fn'])

        self.num_classes = num_classes
        self.model = model
        # self.criterion = nn.CrossEntropyLoss()
        self.criterion = loss_fn
        self.optim_dict = optim_dict
        self._device = "cuda" if torch.cuda.is_available else "cpu"

        self.step_outputs = {
            "loss": [],
            "accuracy": [],
            "jaccard_index": [],
            "fbeta_score": [],
            "IoU": []
        }

        self.metrics = {
            "accuracy": Accuracy(task="binary",
                                 threshold=0.5,
                                 num_classes=num_classes,
                                 validate_args=True,
                                 ignore_index=None,
                                 average="micro").to(self._device),

            "jaccard_index": JaccardIndex(task="binary",
                                          threshold=0.5,
                                          num_classes=num_classes,
                                          validate_args=True,
                                          ignore_index=None,
                                          average="macro").to(self._device),

            "fbeta_score": FBetaScore(task="binary",
                                      beta=1.0,
                                      threshold=0.5,
                                      num_classes=num_classes,
                                      average="micro",
                                      ignore_index=None,
                                      validate_args=True).to(self._device),

            "IoU": metrics.IoU()
        }

    def forward(self, x):
        return self.model(x)

    def shared_step(self, batch, stage: str) -> torch.Tensor:
        x, y = batch
        x, y = x.to(self._device),y.to(self._device)

        assert x.ndim == 4
        assert x.max() <= 3 and x.min() >= -3 
        assert y.ndim == 4
        assert y.max() <= 1 and y.min() >= 0

        logits = self.forward(x.to(torch.float32))
        

        # activated = F.softmax(input=logits, dim=1)
        # predictions = torch.argmax(activated, dim=1)

        predictions = torch.round(torch.sigmoid(logits))
        # predictions = logits
        
        loss = self.criterion(logits, y)

        accuracy = self.metrics["accuracy"](predictions, y)
        jaccard_index = self.metrics["jaccard_index"](predictions, y)
        fbeta_score = self.metrics["fbeta_score"](predictions, y)
        IoU_score = self.metrics["IoU"](predictions, y)

        self.step_outputs["loss"].append(loss)
        self.step_outputs["accuracy"].append(accuracy)
        self.step_outputs["jaccard_index"].append(jaccard_index)
        self.step_outputs["fbeta_score"].append(fbeta_score)
        self.step_outputs["IoU"].append(IoU_score)

        self.log(f'{stage}_loss'   , loss          , prog_bar=True , on_step=False , on_epoch=True)
        # self.log(f'{stage}_acc'    , accuracy      , prog_bar=True , on_step=False , on_epoch=True)
        # self.log(f'{stage}_jaccard', jaccard_index , prog_bar=True , on_step=False , on_epoch=True)
        self.log(f'{stage}_fbeta'  , fbeta_score   , prog_bar=True , on_step=False , on_epoch=True)
        self.log(f'{stage}_IoU'    , IoU_score     , prog_bar=True , on_step=False , on_epoch=True)
        
        return loss

    # def shared_epoch_end(self, stage: Any):
    #     loss = torch.mean(torch.tensor([
    #         loss for loss in self.step_outputs["loss"]
    #     ]))

    #     accuracy = torch.mean(torch.tensor([
    #         accuracy for accuracy in self.step_outputs["accuracy"]
    #     ]))

    #     jaccard_index = torch.mean(torch.tensor([
    #         jaccard_index for jaccard_index in self.step_outputs["jaccard_index"]
    #     ]))

    #     fbeta_score = torch.mean(torch.tensor(
    #         [fbeta_score for fbeta_score in self.step_outputs["fbeta_score"]
    #          ]))

    #     for key in self.step_outputs.keys():
    #         self.step_outputs[key].clear()

    #     metrics = {
    #         f"{stage}_loss": loss,
    #         f"{stage}_accuracy": accuracy,
    #         f"{stage}_jaccard_index": jaccard_index,
    #         f"{stage}_fbeta_score": fbeta_score
    #     }
    #     self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch: Any, batch_idx: Any):
        return self.shared_step(batch=batch, stage="train")

    # def on_train_epoch_end(self) -> None:
    #     return self.shared_epoch_end(stage="train")

    def validation_step(self, batch: Any, batch_idx: Any):
        return self.shared_step(batch=batch, stage="val")

    # def on_validation_epoch_end(self) -> None:
    #     return self.shared_epoch_end(stage="val")

    def test_step(self, batch: Any, batch_idx: Any):
        return self.shared_step(batch=batch, stage="test")

    # def on_test_epoch_end(self) -> None:
    #     return self.shared_epoch_end(stage="test")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        x, y = batch

        assert x.ndim == 4
        assert x.max() <= 3 and x.min() >= -3
        assert y.ndim == 4
        assert y.max() <= 1 and y.min() >= 0

        logits = self.forward(x.to(torch.float32))
        # predictions = logits
        predictions = torch.round(torch.sigmoid(logits))

        # activated = F.softmax(input=logits, dim=1)
        # predictions = torch.argmax(activated, dim=1)

        return predictions

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            params=self.parameters(),
            lr=self.hparams.lr
        )

        scheduler_dict = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer,
                patience=5
            ),
            "interval": "epoch",
            "monitor": "val_loss"
        }

        optimization_dictionary = {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
        return self.optim_dict if self.optim_dict else optimization_dictionary


In [10]:
def main(callbacks: list,
         model: Union[list, tuple],
         loss_fn: Any,
         augmentation: Any,
         preprocessing: Any,
         logger: TensorBoardLogger,
         images_path: str,
         optim_dict: dict,
         ) -> None:

    # Trainer
    trainer = pl.Trainer(
        fast_dev_run=False,
        accelerator="auto",
        strategy="auto",
        devices="auto",
        num_nodes=1,
        logger=logger,
        callbacks=callbacks,
        max_epochs=200,
        min_epochs=150
    )

    # Datamodule
    datamodule = ThermalDataModule(
        images_path=images_path,
        augmentation=augmentation,
        preprocessing=preprocessing,
        batch_size=5,
        num_workers=os.cpu_count()
    )

    # LightningModule
    lightning_model = ThermalModel(
        model=model,
        loss_fn=loss_fn,
        optim_dict=optim_dict,
        lr=3e-4
    )

    # Start training
    trainer.fit(model=lightning_model, datamodule=datamodule)

In [11]:
# Run Constants
SEED: int = 42
ACTION: str = "ignore"
DATA_PATH: str = os.path.join(os.getcwd(),'train_dataset')
CHECKPOINT: Any = None
    
# Model Constants
CLASSES = 1
IN_CHANNELS = 3

optim_dict = None

# ENCODER = 'se_resnext50_32x4d'
ENCODER = 'mobilenet_v2'
ENCODER_WEIGHTS = 'imagenet'
    
CLASSES = ['Background','Thermal Event']
#### Removed the activation function for testing
ACTIVATION = None
# ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'

n_cpu = os.cpu_count()

# model_name = 'Unet'
model_name = 'DeepLabV3Plus'


model = smp.DeepLabV3Plus(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=1, 
    activation=ACTIVATION,
)

# model = smp.Unet(
#     encoder_name=ENCODER, 
#     encoder_weights=ENCODER_WEIGHTS, 
#     in_channels = 3,
#     classes=1, 
#     activation=ACTIVATION,
# )

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)



loss = FocalLoss(mode= 'binary')
loss.__name__ = 'focal_loss'

# # loss = DiceLoss(mode= 'binary')
# # loss.__name__ = 'dice_loss'

# # loss = JaccardLoss(mode= 'binary')
# # loss.__name__ = 'jaccard_loss'

# # loss = losses.DiceLoss()
# # loss = losses.JaccardLoss()

# metrics = [
#     metrics.IoU(),
# ]

# optimizer = torch.optim.Adam([ 
#     dict(params=model.parameters(), lr=1e-3),
# ])

augmentation=get_training_augmentation()
preprocessing=get_preprocessing(preprocessing_fn)

In [12]:
# Callbacks
callbacks = [
    ModelCheckpoint(
        dirpath=f"models/{model_name}",
        filename="Unet_MobileNetV2_{epoch}_{val_loss:.2f}_{val_accuracy:.2f}",
        save_top_k=10,
        monitor="val_loss",
        mode="min"
    ),

    EarlyStopping(
        monitor="val_loss",
        min_delta=2e-4,
        patience=8,
        verbose=False,
        mode="min"
    ),

    LearningRateMonitor(
        logging_interval="step"
    )
]

In [13]:
from lightning.pytorch.loggers import TensorBoardLogger
logger = TensorBoardLogger(save_dir="./logs", name=model_name)

In [None]:
main(
    callbacks=callbacks,
    model=model,
    loss_fn=loss,
    augmentation=augmentation,
    preprocessing=preprocessing,
    logger=logger,
    images_path=DATA_PATH,
    optim_dict=optim_dict
)

In [23]:
%load_ext tensorboard
# %tensorboard --logdir models/current_best_model/version_33

# %tensorboard --logdir logs/DeepLabV3Plus/version_0

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [19]:

# checkpoint_path = os.path.join(os.getcwd(),'models','current_best_model','Unet_MobileNetV2_epoch=146.ckpt')
checkpoint_path = os.path.join(os.getcwd(),'models',model_name,'DeepLabv3PLus_MobileNetV2_epoch=144_val_loss=0.00_val_accuracy=0.00.ckpt')
trained_model = ThermalModel.load_from_checkpoint(checkpoint_path=checkpoint_path,model=model,loss_fn=loss)
trained_model.eval();



In [20]:
### Perform the testing, c  NEED TO create a function

trainer = pl.Trainer(
    fast_dev_run=False,
    accelerator="auto",
    strategy="auto",
    devices="auto",
    num_nodes=1,
    logger=logger,
    callbacks=callbacks,
    max_epochs=200,
    min_epochs=150
)

# Datamodule
datamodule = ThermalDataModule(
    images_path=DATA_PATH,
    augmentation=augmentation,
    preprocessing=preprocessing,
    batch_size=5,
    num_workers=os.cpu_count()
)

loss_2 = FocalLoss(mode= 'binary')
loss_2.__name__ = 'focal_loss'

# LightningModule
lightning_model = ThermalModel(
    model=model,
    loss_fn=loss_2,
    optim_dict=optim_dict,
    lr=3e-4
)

trainer.test(model=trained_model,datamodule=datamodule)
# trainer.predict(model=trained_model,datamodule=datamodule)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_IoU             0.652067244052887
       test_fbeta           0.5769277215003967
        test_loss         0.00044822978088632226
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.00044822978088632226,
  'test_fbeta': 0.5769277215003967,
  'test_IoU': 0.652067244052887}]

In [None]:
# torch.save(trained_model,'trained_model')

In [None]:
import pandas as pd
%matplotlib inline
# %matplotlib widget
import os

from matplotlib.ticker import AutoMinorLocator, MultipleLocator

import matplotlib.pyplot as plt

model_path = os.path.join(os.getcwd(),'models','current_best_model')

metric_id_list = []

for file_name in os.listdir(model_path):
    if file_name[-4:]=='.csv':
        fig, ax = plt.subplots(figsize=(7,5))
        
        csv_file = pd.read_csv(os.path.join(model_path,file_name))
        metric_name = file_name.replace('_evolution.csv','')
        

        ax.yaxis.set_minor_locator(AutoMinorLocator(4))

        if metric_name[:5]=='train':
            ax.set_xlabel('# epochs')

            metric_id = metric_name.replace('train_','')    
            if metric_id not in metric_id_list:
                for inner_file_name in os.listdir(model_path):
                    if inner_file_name[-4:]=='.csv' and inner_file_name.replace('_evolution.csv','')==f'val_{metric_id}':
                        csv_file_inner = pd.read_csv(os.path.join(model_path,inner_file_name))
                        
                        csv_values = csv_file['Value']
                        csv_inner_values = csv_file_inner['Value']
                        
                        ax.plot(csv_values,color='r',label='Training')
                        ax.plot(csv_inner_values,color='b',label='Validation')
                        ax.set_title(f'{metric_id} evolution')
                        ax.set_ylabel(f'{metric_id}')
                        plt.annotate('%0.2f' % csv_values[len(csv_values)-1], xy=(len(csv_values)-1, csv_values[len(csv_values)-1]))
                        plt.annotate('%0.2f' % csv_inner_values[len(csv_inner_values)-1], xy=(len(csv_inner_values)-1, csv_inner_values[len(csv_inner_values)-1]))
                        metric_id_list.append(metric_id)
                        break
            else:
                metric_id = None
                plt.close(fig)
        elif metric_name[:3]=='val':
            ax.set_xlabel('# epochs')
            
            metric_id = metric_name.replace('val_','')
            if metric_id not in metric_id_list:
                for inner_file_name in os.listdir(model_path):
                    if inner_file_name[-4:]=='.csv' and inner_file_name.replace('_evolution.csv','')==f'train_{metric_id}':
                        csv_file_inner = pd.read_csv(os.path.join(model_path,inner_file_name))
                        
                        csv_values = csv_file['Value']
                        csv_inner_values = csv_file_inner['Value']

                        ax.plot(csv_values,color='b',label='Validation')
                        ax.plot(csv_inner_values,color='r',label='Training')
                        ax.set_title(f'{metric_id} evolution')
                        ax.set_ylabel(f'{metric_id}')

                        plt.annotate('%0.2f' % csv_values[len(csv_values)-1], xy=(len(csv_values)-1, csv_values[len(csv_values)-1]))
                        plt.annotate('%0.2f' % csv_inner_values[len(csv_inner_values)-1], xy=(len(csv_inner_values)-1, csv_inner_values[len(csv_inner_values)-1]))
                        
                        metric_id_list.append(metric_id)
            else:
                metric_id = None
                plt.close(fig)

        else:
            csv_values = csv_file['Value']
            ax.set_title(f'{metric_name} evolution')
            ax.set_ylabel(f'{metric_name}')
            ax.set_xlabel('Time')
            ax.plot(csv_values,color='g',label='learning rate')
            plt.annotate('%.0E' % csv_values[0], xy=(0, csv_values[0]))
            plt.annotate('%.0E' % csv_values[len(csv_values)-1], xy=(len(csv_values)-1, csv_values[len(csv_values)-1]))
            metric_id = 'lr'
        # ax.plot(csv_file['Value'])
        if metric_id:
            ax.legend()
            plt.savefig(os.path.join(model_path,metric_id+'_evolution.png'))
        # break
        # break
        # plt.show()
        # break
        # plt.savefig(os.path.join(model_path,file_name.replace('.csv','.png')))
        


In [None]:
x = torch.randn(5, 3, 256, 256).cpu()
model_onnx = trained_model.cpu()
model_onnx.eval()
torch_out = model_onnx(x)

# Export the model
torch.onnx.export(model_onnx,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "Unet_MobileNetV2_epoch146.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

  if h % output_stride != 0 or w % output_stride != 0:


In [None]:
import onnx

onnx_model = onnx.load("Unet_MobileNetV2_epoch146.onnx")
onnx.checker.check_model(onnx_model)

In [None]:
# !pip install onnxruntime
import onnxruntime

ort_session = onnxruntime.InferenceSession("Unet_MobileNetV2_epoch146.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
print(np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05))

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

None
Exported model has been tested with ONNXRuntime, and the result looks good!
