In [None]:
!pip install segmentation_models_pytorch

In [None]:
from torchvision.transforms import transforms
from torch.utils.data import Dataset
from PIL import Image
import os

class CustomDataset(Dataset):
    """Paired image / mask dataset for binary segmentation.

    Every regular, non-hidden file in ``data_folder`` is treated as an input
    image. Its mask is looked up in ``mask_path`` under the same base name with
    a ``.png`` extension (e.g. ``img_001.jpg`` -> ``img_001.png``).

    Args:
        data_folder: Directory containing the input images.
        mask_path: Directory containing the mask images.
        input_size: Size passed to ``transforms.Resize`` for images
            (torchvision reads a pair as ``(height, width)``).
        mask_size: Size passed to ``transforms.Resize`` for masks.

    Each item is ``(image, mask)``:
        image -- 3-channel float tensor from ``ToTensor`` (values in [0, 1]).
        mask  -- 1-channel float tensor from ``ToTensor`` on the grayscale mask.
            NOTE(review): ``ToTensor`` divides 8-bit values by 255, so masks
            are assumed to be stored as 0/255; masks stored as 0/1 would give
            targets of 0 and 1/255 -- confirm against the data.
    """

    def __init__(self, data_folder, mask_path, input_size=(960, 640), mask_size=(960, 640)):
        self.input_size = input_size
        self.image_dir = data_folder
        self.mask_size = mask_size
        self.mask_path = mask_path
        # Keep only regular, non-hidden files (skips e.g. `.ipynb_checkpoints`)
        # and sort so the item order is reproducible -- os.listdir order is
        # arbitrary.
        self.images = sorted(
            name
            for name in os.listdir(self.image_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(self.image_dir, name))
        )

        self.transform_image = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
        ])

        # Masks must be resized with nearest-neighbour interpolation: the
        # default (bilinear) blends label values at object boundaries and
        # yields fractional, non-binary targets.
        self.transform_mask = transforms.Compose([
            transforms.Resize(self.mask_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_name = self.images[idx]
        img_path = os.path.join(self.image_dir, image_name)
        # Swap only the real extension (not any ".jpg" substring) for ".png".
        mask_name = os.path.splitext(image_name)[0] + ".png"
        mask_path = os.path.join(self.mask_path, mask_name)

        # `convert` forces the pixel data to load, so the files can be closed
        # right away. "RGB" guarantees 3 channels for the image even for
        # grayscale/RGBA/palette files; "L" matches what Grayscale produces.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")

        image = self.transform_image(image)
        mask = self.transform_mask(mask)

        return image, mask


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import transforms
import segmentation_models_pytorch as smp
import os
import glob

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_classes = 1
batch_size = 20

model_name = 'unet'
encoder_name = 'resnet18'
# No output activation: the model returns raw logits, which is what both
# BCEWithLogitsLoss and the explicit sigmoid in the IoU computation expect.
activation_func = None

model = smp.Unet(encoder_name, in_channels=3, classes=num_classes, activation=activation_func).to(device)

criterion = nn.BCEWithLogitsLoss()


def _resolve_checkpoint(path):
    """Return a loadable checkpoint file for ``path``.

    If ``path`` is a directory, the most recently modified ``*.pt*`` file in
    it (``.pt``, ``.pth``, ``.pth.tar``, ...) is returned; otherwise ``path``
    is returned unchanged.

    Raises:
        FileNotFoundError: ``path`` is a directory with no ``*.pt*`` files.
    """
    if not os.path.isdir(path):
        return path
    candidates = glob.glob(os.path.join(path, "*.pt*"))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint files (*.pt*) found in {path!r}")
    return max(candidates, key=os.path.getmtime)


checkpoint_path = "/model/checkpoints/"
checkpoint_file = _resolve_checkpoint(checkpoint_path)
checkpoint = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded checkpoint: {checkpoint_file}")

val_image_path = "/data/validation/images"
val_mask_path = "/data/validation/masks"

val_dataset = CustomDataset(val_image_path, val_mask_path)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model.eval()
val_loss = 0.0     # sum of per-sample losses over the whole validation set
total_iou = 0.0    # sum of per-sample IoUs over the whole validation set
num_samples = 0
counter = 0        # number of batches processed

with torch.no_grad():
    with tqdm(val_loader, desc='Validation', leave=True) as val_bar:
        for images, masks in val_bar:
            images, masks = images.to(device), masks.to(device)
            n = images.size(0)

            outputs = model(images)
            loss = criterion(outputs, masks)  # mean over the batch
            # Weight by batch size so a smaller final batch is not over-counted.
            val_loss += loss.item() * n

            predicted_labels = torch.round(torch.sigmoid(outputs))

            # logical_* treat any non-zero mask value as foreground.
            intersection = torch.logical_and(predicted_labels, masks).sum(dim=(1, 2, 3)).float()
            union = torch.logical_or(predicted_labels, masks).sum(dim=(1, 2, 3)).float()
            # An empty prediction on an empty mask is a perfect match: score
            # it 1 rather than 0/eps = 0.
            iou_per_sample = torch.where(
                union > 0, intersection / union.clamp(min=1), torch.ones_like(union)
            )
            iou_batch = iou_per_sample.mean()

            total_iou += iou_per_sample.sum().item()
            num_samples += n
            counter += 1

            val_bar.set_postfix(loss=f'{loss.item():.4f}', IoU=f'{iou_batch.item():.4f}')

if num_samples == 0:
    raise RuntimeError(f"No validation samples found in {val_image_path!r}")

average_iou = total_iou / num_samples
average_val_loss = val_loss / num_samples
print()
print(f"Validation Loss: {average_val_loss:.4f}")
print(f"Average IoU: {average_iou:.4f}")