In [1]:
import os
import onnx
import si4onnx
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset

# set seed
torch.manual_seed(0)

# set number of threads to 1
torch.set_num_threads(1)
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

In [2]:
device = torch.device("cpu")

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        features = nn.functional.relu(self.conv3(x))
        x = self.gap(features)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = torch.sigmoid(x)
        return features, x


class CAM(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = CNN()

    def forward(self, x):
        original_size = x.size()[2:]
        features, output = self.cnn(x)
        cam = torch.sum(
            features * self.cnn.fc.weight.data.view(1, -1, 1, 1), dim=1, keepdim=True
        )
        cam = nn.functional.interpolate(
            cam, size=original_size, mode="bilinear", align_corners=False
        )

        return cam, output # Multi-Output


n_samples = 100
shape = (1, 16, 16)
batch_size = 16
epochs = 16

normal_dataset = si4onnx.data.SyntheticDataset(
    n_samples=n_samples,
    shape=shape,
    local_signal=0,
    seed=42,
)
abnormal_dataset = si4onnx.data.SyntheticDataset(
    n_samples=n_samples,
    shape=shape,
    local_signal=1,
    seed=43,
)
dataset = ConcatDataset([normal_dataset, abnormal_dataset])
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

model = CAM().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# train
for epoch in range(epochs):
    model.train()
    for images, _, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        _, outputs = model(images)
        loss = criterion(outputs.flatten(), labels.flatten())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

Epoch [1/16], Loss: 0.7050
Epoch [2/16], Loss: 0.6916
Epoch [3/16], Loss: 0.6930
Epoch [4/16], Loss: 0.7005
Epoch [5/16], Loss: 0.6823
Epoch [6/16], Loss: 0.6719
Epoch [7/16], Loss: 0.7333
Epoch [8/16], Loss: 0.6615
Epoch [9/16], Loss: 0.6613
Epoch [10/16], Loss: 0.6206
Epoch [11/16], Loss: 0.6781
Epoch [12/16], Loss: 0.6369
Epoch [13/16], Loss: 0.6442
Epoch [14/16], Loss: 0.4964
Epoch [15/16], Loss: 0.4718
Epoch [16/16], Loss: 0.4582


In [3]:
# export onnx
model.eval()
dummy_input = torch.randn(1, *shape).to(device)
model_path = "./models/cam.onnx"
torch.onnx.export(model, dummy_input, model_path)

In [4]:
from si4onnx.operators import GaussianFilter
from si4onnx.utils import thresholding
from sicore import SelectiveInferenceNorm

class CustomSIModel(si4onnx.SIModel):
    def __init__(self, model, threshold):
        super().__init__()
        self.si_model = si4onnx.NN(model)
        self.threshold = torch.tensor(threshold, dtype=torch.float64)

    def construct_hypothesis(self, X, var):
        self.shape = X.shape
        input_x = X

        output_x = self.si_model.forward(input_x)
        saliency_map = output_x[0] # saliency map is the first output

        # Apply Gaussian filter
        saliency_map = GaussianFilter().forward(saliency_map)

        # min max norm
        saliency_map = (saliency_map - torch.min(saliency_map)) \
            / (torch.max(saliency_map) - torch.min(saliency_map))
        
        roi = saliency_map > self.threshold
        
        roi_vec = roi.reshape(-1).int()

        input_vec = input_x.reshape(-1).double()
        eta = (
            roi_vec / torch.sum(roi_vec)
            - (1 - roi_vec) / torch.sum(1 - roi_vec)
        ).double()

        self.roi_vec = roi_vec
        self.si_calculator = SelectiveInferenceNorm(input_vec, var, eta, use_torch=True)
        assert not np.isnan(self.si_calculator.stat) # If No Hypothesis


    def algorithm(self, a, b, z):
        x = a + b * z
        input_x = x.reshape(self.shape).double()
        input_a = a.reshape(self.shape)
        input_b = b.reshape(self.shape)
        INF = torch.tensor(torch.inf).double()
        l, u = -INF, INF

        output_x, output_a, output_b, l, u = self.si_model.forward_si(
            input_x, input_a, input_b, l, u, z
        )

        output_x, output_a, output_b, l, u = GaussianFilter().forward_si(
            output_x[0], output_a[0], output_b[0], l[0], u[0], z
        )

        roi_vec, l, u = thresholding(
            self.threshold, output_x, output_a, output_b, l, u, z, use_norm=True
        )

        return roi_vec, [l, u]

    def model_selector(self, roi_vec):
        return torch.all(torch.eq(self.roi_vec, roi_vec))


In [5]:
onnx_model = onnx.load(model_path)
si_model = CustomSIModel(model=onnx_model, threshold=0.8)

x = torch.randn(1, *shape)

p_value = si_model.inference(x, var=1.0).p_value
print(f"p-value: {p_value}")

p-value: 0.9726479099176948
