In [None]:
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.transforms.functional as F
import skimage.io as skio
import numpy as np
import skimage
import os
import torch
import torchvision
import torchvision.transforms as T
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
from torch.utils.data import DataLoader, Dataset
from pycocotools.coco import COCO
import numpy
from PIL import Image

# Cap PyTorch's intra-op and inter-op CPU thread pools at the same size.
NUM_TORCH_THREADS = 4
torch.set_num_threads(NUM_TORCH_THREADS)
torch.set_num_interop_threads(NUM_TORCH_THREADS)

# Root directory of the project: the current working directory
ROOT_DIR = os.getcwd()
# Directory of images to run detection on
DATA_DIR = os.path.join(ROOT_DIR, "Reinforcement")
# Directory holding training logs and saved model checkpoints
MODEL_DIR = os.path.join(DATA_DIR, "logs")

# Echo the resolved directories so a reader can check where files are read from.
for configured_dir in (ROOT_DIR, DATA_DIR, MODEL_DIR):
    print(configured_dir)

In [None]:
# Automatically reload imported modules (e.g. the project helpers inspect_model,
# coco_json, train_pytorch imported below) before executing code, so edits to
# those .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Load Trained Model

In [None]:
from inspect_model import load_trained_model
from coco_json import process_masks
from train_pytorch import ForamPoreDataset, Compose, RandomHorizontalFlip
from train_pytorch import get_transform
from inspect_model import visualize_prediction, visualize_dataset

In [None]:
# Load the model from a specific training checkpoint
loop = 3  # training loop whose checkpoint to load; adjust if needed
epoch = 10  # epoch checkpoint within that loop; adjust if needed

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pass each path component to os.path.join separately (rather than embedding
# "/" in the f-string) so the path uses the platform's separator.
model_path = os.path.join(MODEL_DIR, f"loop_{loop}", f"model_epoch_{epoch}.pt")

# optimizer and lr_scheduler are returned by load_trained_model but are not
# used anywhere in this notebook (inference only).
model, optimizer, lr_scheduler = load_trained_model(model_path, device)
# Switch to inference mode. The trailing semicolon suppresses the notebook's
# display of the returned module, whose repr is the entire network architecture.
model.eval();

### Load the Train and Test Datasets

In [None]:
# Both splits use the evaluation transform (train=False) because these datasets
# are only used for inspection. With train=True, any random augmentation in
# get_transform (presumably RandomHorizontalFlip, which train_pytorch exports —
# TODO confirm) would make dataset_train[idx] differ between the prediction cell
# and visualize_dataset, so the side-by-side comparison would not match.
# NOTE(review): train/test are read from ROOT_DIR, not DATA_DIR ("Directory of
# images to run detection on") — confirm this is intended.
train_root = os.path.join(ROOT_DIR, 'train')
train_annotation = os.path.join(train_root, 'via_region_data.json')
dataset_train = ForamPoreDataset(train_root, train_annotation, get_transform(train=False))

test_root = os.path.join(ROOT_DIR, 'test')
test_annotation = os.path.join(test_root, 'via_region_data.json')
dataset_test = ForamPoreDataset(test_root, test_annotation, get_transform(train=False))

### Visualize Predictions on Train Images

In [None]:
idx = 3  # index of the training sample to inspect
# Load an image (the annotation target is not needed for prediction)
image, _ = dataset_train[idx]  # torch.Tensor, float32, [C, H, W]

# Min-max rescale the image to [0, 1] float32 for prediction. The denominator is
# clamped to float32 eps so a constant image yields zeros instead of NaNs (0 / 0);
# for any image whose intensity range exceeds eps this is the plain min-max scale.
image = image.float()
image_min, image_max = image.min(), image.max()
image = (image - image_min) / (image_max - image_min).clamp_min(torch.finfo(torch.float32).eps)

# Add a batch dimension and move the input to the model's device
image_tensor = image.unsqueeze(0).to(device)

# Re-assert inference mode in case this cell is run after the model was
# switched to train mode elsewhere; no_grad skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    predictions = model(image_tensor)

visualize_prediction(image, predictions)

#### Compare with the ground truth (original good-pores mask)

In [None]:
# Display the same training sample via visualize_dataset for comparison with the prediction above.
visualize_dataset(dataset_train, idx=idx)

### Visualize Predictions on Test Images

In [None]:
idx = 1  # index of the test sample to inspect
# Load an image (the annotation target is not needed for prediction)
image, _ = dataset_test[idx]  # torch.Tensor, float32, [C, H, W]

# Min-max rescale the image to [0, 1] float32 for prediction. The denominator is
# clamped to float32 eps so a constant image yields zeros instead of NaNs (0 / 0);
# for any image whose intensity range exceeds eps this is the plain min-max scale.
image = image.float()
image_min, image_max = image.min(), image.max()
image = (image - image_min) / (image_max - image_min).clamp_min(torch.finfo(torch.float32).eps)

# Add a batch dimension and move the input to the model's device
image_tensor = image.unsqueeze(0).to(device)

# Re-assert inference mode in case this cell is run after the model was
# switched to train mode elsewhere; no_grad skips autograd bookkeeping.
model.eval()
with torch.no_grad():
    predictions = model(image_tensor)

visualize_prediction(image, predictions)

#### Compare with the ground truth

In [None]:
# visualize_dataset is already imported with the other inspect_model helpers at
# the top of the notebook, so it is not re-imported here.
visualize_dataset(dataset_test, idx=idx)