In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import os
from pathlib import Path
import matplotlib.pyplot as plt
import torch.nn.functional as F
# import wandb
from torchvision.models import VisionTransformer

In [None]:
# Prefer the GPU when one is available and fall back to the CPU otherwise.
# (The previous expression returned "cpu" on both branches, so CUDA was never used.)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [None]:
SEED = 20  # fixed torch RNG seed so repeated runs are reproducible
torch.manual_seed(SEED)

In [None]:
# Root of the preprocessed volumes; the training and testing splits are
# stored as .npy files in sub-directories of it.
path = "../../preprocessing/icdas_preprocessed"
dir_training = f"{path}/training"
dir_testing = f"{path}/testing"
# Placeholder; reassigned per sample in the saliency loop further down.
file_name = ''
import numpy

import gc
from torch.utils.data import Dataset, DataLoader


## Dataset for the 4-class problem: class index = label digit in the file name minus 3

class ToothDataset(Dataset):
    """Map-style dataset of tooth volumes stored as ``.npy`` files in one directory.

    Items are loaded from disk on demand. The class label is the digit at
    position ``-5`` of the file name (just before a four-character extension
    such as ``.npy``) shifted down by 3. Each volume is trimmed to its last 70
    entries along axis 0 and reshaped to ``(1, 70, 70, 70)``.

    Args:
        img_dir: Directory containing the ``.npy`` files (every entry in it is
            treated as a sample).
        transform: Optional callable applied to the volume tensor.
    """

    def __init__(self, img_dir, transform=None):
        self.dataset_path = img_dir
        self.transform = transform
        # List the directory once instead of on every __len__/__getitem__ call.
        # The order is whatever os.listdir returns; code that pairs indices with
        # file names (e.g. the saliency loop) must use the same listing order.
        self.file_names = os.listdir(self.dataset_path)

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(volume, label)``: a float32 tensor (shape ``(1, 70, 70, 70)``
            unless ``transform`` changes it) and an int64 tensor of shape
            ``(1,)`` holding the class index.

        Raises:
            IndexError: if ``idx`` is past the end. Raising (rather than
                returning ``None``) is what the sequence protocol expects, so
                ``for sample in dataset`` terminates.
        """
        if idx >= len(self.file_names):
            raise IndexError("No datafile/image at index: " + str(idx))

        npy_filename = self.file_names[idx]
        # Label digit sits just before the extension; shift it to a 0-based class.
        label = int(npy_filename[-5]) - 3

        numpy_arr = numpy.load(os.path.join(self.dataset_path, npy_filename))
        # Keep only the last 70 entries along axis 0 (a no-op if there are
        # fewer). Same result as deleting row 0 repeatedly, without the
        # quadratic copying.
        numpy_arr = numpy_arr[-70:]
        tensor_arr = torch.from_numpy(numpy_arr.reshape(1, 70, 70, 70)).to(torch.float32)

        if self.transform:
            tensor_arr = self.transform(tensor_arr)  # Apply transformations
        # Log which file was loaded; helps verify index/file-name pairing.
        print(npy_filename, label)
        return tensor_arr.to(torch.float32), torch.tensor([label], dtype=torch.long)
        

In [None]:
# Hyperparameters. Only `batch_size` is consumed by the cells below (the
# DataLoaders); the optimiser settings are presumably carried over from training.
epochs = 500
batch_size = 1
learning_rate = 1e-3
weight_decay = 1e-10
momentum = 0.9
# One dataset per split (training first, then testing); no transforms applied.
training_data, validation_data = (
    ToothDataset(img_dir=split_dir, transform=None)
    for split_dir in (dir_training, dir_testing)
)


In [None]:
# Shuffle only the training split; the validation order must stay fixed so
# that batch i lines up with the i-th file name in the saliency loop.
training_data_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
validation_data_loader = DataLoader(validation_data, batch_size=batch_size, shuffle=False)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Basic3DCNN(nn.Module):
    """Small 3-D CNN classifier: three conv/ReLU/max-pool stages, then two FC layers.

    Args:
        num_classes: Number of output logits.
        input_size: Edge length of the cubic input volume (D == H == W).
            Defaults to 70, the size produced by ``ToothDataset``. The convs
            (kernel 3, padding 1, stride 1) keep the spatial size and each of
            the three ``MaxPool3d(2, 2)`` stages floors it by 2, so the
            flattened feature length is ``64 * (input_size // 8) ** 3``
            (``64 * 8 ** 3`` for the default, matching the original layout).
    """

    def __init__(self, num_classes=4, input_size=70):
        super().__init__()
        # Convolutional layers (attribute names must stay stable: they are the
        # state_dict keys of saved checkpoints)
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        # Shared max pooling layer (halves each spatial dimension, floored)
        self.pool = nn.MaxPool3d(2, 2)
        # Three floor-halvings equal a single floor division by 8.
        side = input_size // 8
        self.flat_features = 64 * side ** 3
        # Fully connected layers
        self.fc1 = nn.Linear(self.flat_features, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return logits ``(batch, num_classes)`` for ``x`` of shape
        ``(batch, 1, input_size, input_size, input_size)``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample. Unlike view(-1, N), this never folds a mis-sized
        # input into a different batch size; fc1 raises on a mismatch instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
# Instantiate the 3-D CNN (4 output classes) on the selected device.
model = Basic3DCNN(num_classes=4).to(device)

In [None]:
# Restore the trained weights. Tensors are deserialised onto the CPU and then
# copied into the model's existing parameters by load_state_dict.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
checkpoint_path = '../best_model_epoch_63_accuracy_94.44.pt'
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)

In [None]:
import gc
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from skimage import measure
import torch
from captum.attr import Saliency
import meshio

threshold = 0.5  # updated by generateAndSaveSaliencyMap to the last iso-level used


def _saliency_output_dir(stem, correct):
    """Return ``<stem>[correct]/`` or ``<stem>/`` and create it if missing."""
    directory = stem + ("[correct]/" if correct else "/")
    os.makedirs(directory, exist_ok=True)
    return directory


def _plot_isosurfaces(verts, faces, original_verts, original_faces):
    """Draw the saliency (orange) and input (blue) isosurfaces on one 3-D axis; return the figure."""
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    layers = (
        (verts, faces, [1.0, 0.647, 0.0]),                # saliency surface
        (original_verts, original_faces, [0.5, 0.5, 1]),  # input-volume surface
    )
    for layer_verts, layer_faces, face_color in layers:
        mesh = Poly3DCollection(layer_verts[layer_faces], alpha=0.4)
        mesh.set_facecolor(face_color)
        ax.add_collection3d(mesh)

    # Axis limits start at the origin and span both surfaces.
    max_dim = np.max(np.vstack([original_verts, verts]), axis=0)
    ax.set_xlim(0, max_dim[0])
    ax.set_ylim(0, max_dim[1])
    ax.set_zlim(0, max_dim[2])
    ax.set_xlabel("X-axis")
    ax.set_ylabel("Y-axis")
    ax.set_zlabel("Z-axis")
    return fig


def generateAndSaveSaliencyMap(model, data, target, file_name, percentile=99.5, data_level=0.5):
    """Visualise and export a gradient saliency map for one input volume.

    Runs ``model`` on ``data`` and computes a captum ``Saliency`` attribution
    for ``target``. It then extracts marching-cubes isosurfaces of the
    saliency map and of the input volume, shows them in a 3-D matplotlib
    plot, and writes three STL meshes into a per-sample directory:

    * ``<stem>_saliency_map.stl``: input and saliency surfaces combined
    * ``<stem>.stl``: saliency surface only
    * ``<stem>_original.stl``: input surface only

    ``<stem>`` is ``file_name`` without its last four characters (presumably
    the ``.npy`` extension). The directory is ``<stem>[correct]/`` when the
    prediction matches ``target`` and ``<stem>/`` otherwise (relative to the
    current working directory).

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        data: Input tensor. It must hold a single sample: the correctness check
            and ``predicted.item()`` require a batch size of 1.
        target: Integer class index to attribute.
        file_name: Source file name used to name the outputs.
        percentile: Percentile of the saliency values used as the iso-level
            (default 99.5, i.e. roughly the top 0.5% most salient voxels).
        data_level: Iso-level for the input-volume surface (default 0.5).

    Side effects:
        Sets the module-level ``threshold`` to the saliency iso-level used.
    """
    global threshold

    # Is the model's prediction for this sample correct?
    output = model(data)
    _, predicted = torch.max(output, 1)
    correct = bool(predicted == target)
    print(f'Predicted: {predicted.item()} Target: {target} Correct: {correct}')

    stem = file_name[:-4]
    directory = _saliency_output_dir(stem, correct)
    combined_path = directory + stem + "_saliency_map.stl"
    saliency_path = directory + stem + ".stl"
    original_path = directory + stem + "_original.stl"

    # Gradient saliency w.r.t. the input, moved to a numpy volume.
    saliency_map = Saliency(model).attribute(data, target=target)
    print(saliency_map.shape)
    saliency_map = saliency_map.cpu().detach().numpy().squeeze()

    # Iso-level = the requested percentile of the saliency values.
    threshold = np.percentile(saliency_map, percentile)
    verts, faces, _, _ = measure.marching_cubes(saliency_map, threshold)

    volume = data.cpu().detach().numpy().squeeze()
    original_verts, original_faces, _, _ = measure.marching_cubes(volume, data_level)

    fig = _plot_isosurfaces(verts, faces, original_verts, original_faces)

    # Combined mesh: saliency face indices are offset past the input vertices.
    meshio.write(combined_path, meshio.Mesh(
        points=np.vstack([original_verts, verts]),
        cells=[("triangle", np.vstack([original_faces, faces + len(original_verts)]))]))
    meshio.write(saliency_path, meshio.Mesh(
        points=verts, cells=[("triangle", faces)]))
    meshio.write(original_path, meshio.Mesh(
        points=original_verts, cells=[("triangle", original_faces)]))

    plt.show()
    # Close explicitly so figures do not accumulate when called once per sample.
    plt.close(fig)
    gc.collect()
    
import os

# Export a saliency map for every validation sample.
# The validation loader is not shuffled, so batch i corresponds to the i-th
# entry of os.listdir() on the validation directory. This assumes listdir
# returns the same order the dataset indexes into. Read the directory from the
# dataset itself instead of a second hard-coded copy of the path, so the two
# cannot drift apart.
dir_path = validation_data.dataset_path
file_names = os.listdir(dir_path)
assert len(file_names) == len(validation_data), (
    "file listing and validation dataset disagree on the number of samples")
# generateAndSaveSaliencyMap and the file_names[i] pairing need one sample per batch.
assert validation_data_loader.batch_size == 1, "saliency export expects batch_size == 1"

model.eval()
for i, (data, target) in enumerate(validation_data_loader):
    # Move data and target to the same device as the model
    data = data.to(device)
    target = target.to(device)

    data_min = torch.min(data)
    data_max = torch.max(data)
    print(f'Data range: {data_min} to {data_max}')

    file_name = file_names[i]
    generateAndSaveSaliencyMap(model, data, target.item(), file_name)
    
