In [1]:
import os

import numpy as np
import torch
import torchvision
from torch.nn.parallel import DataParallel
from torchvision.transforms import ToTensor, Compose
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt

from segment_anything.utils.transforms import ResizeLongestSide

from datasets import Embedding_Dataset
from utils import SAMPreprocess, PILToNumpy, NumpyToTensor, SamplePoint, embedding_collate, is_valid_file
from models import SAM_Baseline
from test import evaluate_model

In [2]:
# Evaluation configuration.
batch_size = 8
# Dataset root. Override with the FSSAM_DATASET_DIR environment variable when
# running outside the original cluster workspace; the default is the original path.
folder_path = os.environ.get('FSSAM_DATASET_DIR', '/pfs/work7/workspace/scratch/ul_xto11-FSSAM/Liebherr/dataset')
# Fine-tuned SAM_Baseline state dict to evaluate (path relative to the notebook).
load_weights = 'fine_tune_liebherr_02.pth'

In [3]:
# Load the fine-tuned SAM baseline and move it to the best available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SAM_Baseline()
# map_location remaps the checkpoint's tensors onto `device`. Without it,
# torch.load restores tensors to the device they were saved from, which fails
# on a CPU-only host when the checkpoint was written from a GPU.
model.load_state_dict(torch.load(load_weights, map_location=device))
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # NOTE: DataParallel does not forward custom attributes of the wrapped
    # model (e.g. `img_size`); reach them through `model.module`.
    model = DataParallel(model)
model.to(device)
# This notebook only evaluates: switch every submodule to inference mode
# (relevant for layers such as dropout / batch norm, if the model has any).
model.eval()
print("SAM loaded")

SAM loaded


In [4]:
# Build the preprocessing pipelines and the held-out test split.
# DataParallel does not forward custom attributes such as `img_size`, so read
# it from the wrapped module when the previous cell wrapped the model.
sam_model = model.module if isinstance(model, DataParallel) else model
img_size = sam_model.img_size

sam_transform = ResizeLongestSide(img_size)
# Target pipeline: resize the longest side, pad without normalization, then
# SamplePoint (presumably derives a prompt point from the mask — TODO confirm).
# NOTE(review): if SamplePoint draws points at random, seed the RNG it uses so
# the reported IoU is reproducible across runs.
target_transform = Compose([
    sam_transform.apply_image_torch, # rescale
    SAMPreprocess(img_size, normalize=False), # padding
    SamplePoint(),
])
# Image pipeline: PIL -> numpy, resize the longest side, numpy -> tensor, then
# pad via SAMPreprocess with its default `normalize` setting.
transform = Compose([
    PILToNumpy(),
    sam_transform.apply_image, # rescale
    NumpyToTensor(),
    SAMPreprocess(img_size) # padding
])

dataset = Embedding_Dataset(root=folder_path, transform=transform, target_transform=target_transform, is_valid_file=is_valid_file)

# 70/15/15 split with a fixed seed; only the test subset is used here.
# Presumably this must reproduce the split used during fine-tuning so the test
# images were not seen in training — TODO confirm against the training script.
dataset_size = len(dataset)
train_size = int(0.7 * dataset_size)
val_size = int(0.15 * dataset_size)
test_size = dataset_size - train_size - val_size
generator = torch.Generator().manual_seed(42)
_, _, test_set = random_split(dataset, [train_size, val_size, test_size], generator=generator)

# shuffle defaults to False, so test batches come in a fixed order.
test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=embedding_collate)

In [5]:
# Score the model on the held-out test split. evaluate_model returns two
# scores: a plain IoU and an "oracle" IoU (meaning defined by evaluate_model).
iou_score, oracle_iou_score = evaluate_model(model, test_loader, device)
summary = 'IoU: {:.4f} Oracle IoU: {:.4f}'.format(iou_score, oracle_iou_score)
print(summary, flush=True)

100%|██████████| 16/16 [00:59<00:00,  3.70s/it, IoU=0.627, oracle IoU=0.763]

IoU: 0.6274 Oracle IoU: 0.7628



