### Import most of necessary packages

In [1]:
import cv2
from tqdm import tqdm
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import json

from typing import List, Dict

cudnn.benchmark = True

### Create a function for drawing mask polygons from list of x-y pairs

In [2]:
def draw_mask_polygon(mask: np.ndarray, points: List, mask_color) -> np.ndarray:
	"""
		Draws a polygon of a segmented part specified by set of x-y coord pairs

	Args:
		mask (np.ndarray): current mask to draw polygon on
		points (List): list of x-y coord pairs which define the polygon
		mask_color (int, Tuple): per-channel value for the polygon

	Returns:
		np.ndarray: new mask with the polygon drawn
	"""
	# Parse sequence of x-y pairs
	obj_mask_points = []
	for point in points:
		obj_mask_points.append([point["x"], point["y"]])

	# Draw polygon with specified color
	return cv2.fillPoly(mask.copy(), pts=[np.array(obj_mask_points, dtype=np.int32)], color=mask_color)

### Create a dataset for storing images with the corresponding masks and some additional data

In [3]:
class BoltDataset(Dataset):
	def __init__(self, root_dir: str, data_dir: str, annot_file_name: str, transforms = None):
		"""
			BoltDataset default constructor

		Args:
			root_dir (str): root directory where dataset is located 
			data_dir (str): directory with images and annnotations file
			annot_file_name (str): file name of an annotations file
			transforms (_type_, optional): _description_. Defaults to None.
		"""
		# Pathes to all necessary data
		self.root_dir = root_dir
		self.data_dir = os.path.join(self.root_dir, data_dir)
		self.annot_file_path = os.path.join(self.data_dir, annot_file_name)
		self.transforms = transforms
		
		# Buffers for storing image and class info and image-mask pairs
		self.images_info = []
		self.images_with_masks = []
		self.id2class_categories = {0: "background"}
		self.class_categories2id = {"background": 0}

		# Load json annotations file
		with open(self.annot_file_path) as data_json:
			self.annot_data = json.load(data_json)
		
		# Parse class info
		for category_info in self.annot_data["categories"]:
			if(category_info["supercategory"] is not None):
				self.id2class_categories[category_info["id"]] = category_info["name"]
				self.class_categories2id[category_info["name"]] = category_info["id"]

		# Parse image info
		for image_info in self.annot_data["images"]:
			self.images_info.append({"file_name": image_info["file_name"], "image_size": (image_info["height"], image_info["width"])})

		# Parse annotations and apply transforms(if any)
		curr_image_id = 0
		curr_mask = np.zeros(self.images_info[curr_image_id]["image_size"], dtype=np.float32)
		bar = tqdm(enumerate(self.annot_data["annotations"]), total=len(self.annot_data["annotations"]))
		for id, annot_info in bar:
			if annot_info["image_id"] != curr_image_id or id == len(self.annot_data["annotations"]) - 1:
				image = cv2.imread(os.path.join(self.data_dir, self.images_info[curr_image_id]["file_name"]))
				image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
				
				if(transforms is not None):
					transformed = transforms(image=image, mask=curr_mask)
					self.images_with_masks.append((transformed["image"], transformed["mask"]))
				else:
					self.images_with_masks.append((torch.from_numpy(image).to(torch.float32).unsqueeze(0), torch.from_numpy(curr_mask).unsqueeze(0)))

				curr_image_id = (curr_image_id + 1) % len(self.images_info)
				curr_mask = np.zeros(self.images_info[curr_image_id]["image_size"], dtype=np.float32)

				bar.set_postfix_str(f"Total processed images: {len(self.images_with_masks)}")
			
			points = []
			for point_id in range(0, len(annot_info["segmentation"][0]), 2):
				points.append({	"x": annot_info["segmentation"][0][point_id], 
								"y":annot_info["segmentation"][0][point_id + 1]})
			
			curr_mask = draw_mask_polygon(curr_mask, points, annot_info["category_id"])
		
		### Algorithm for processing color images for further usage in dataset creation

		# 	mask = cv2.imread(os.path.join(self.masks_dir, image_name))
		# 	if(mask is None):
		# 		print(f"Warning:{image_name} does not have associated mask. Skipping...")
		# 		continue

		# 	self.image_names.append(image_name)

		# 	image = cv2.bitwise_not(image)
		# 	image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

		# 	image = cv2.multiply(image[:,:,0], (1,1), scale=0.7)

		# 	cv2.imwrite(os.path.join(self.masks_dir, image_name), image)
		
	def __len__(self):
		""" 
			Getting number of dataset elements
		"""
		return len(self.images_with_masks)

	def __getitem__(self, idx: int):
		""" 
			Getting particular element of a dataset via index
		
		Arguments:
			idx (int): index of a needed element
		"""
		return self.images_with_masks[idx]

### Display particular sample from the created dataset 

In [4]:
def mask_sample(dataset: BoltDataset, idx: int, class_colors: Dict, pred_masks_dir: str = None, pred_mask: torch.Tensor = None, alpha: float = 0.2) -> List[np.ndarray]:
	""" Merges provided by index from dataset image with either predicted or loaded mask

	Args:
		dataset (BoltDataset): dataset to take image by index from
		idx (int): index of image to merge with mask
		class_colors (Dict): dictionary of RGB colors(tuples) for each category on a mask
		pred_masks_dir (str, optional): path to a predicted mask. Defaults to None.
		pred_mask (torch.Tensor, optional): tensor of predicted mask. Defaults to None.
		alpha (float, optional): transparency of a mask. Defaults to 0.2.

	Returns:
		List[np.ndarray]: original image with applied mask on top and the mask itself
	"""

	mask = None
	if(pred_masks_dir is not None):
		# When the mask is taken on specifed path
		try:
			mask = cv2.imread(os.path.join(os.path.join(dataset.root_dir, pred_masks_dir), dataset.images_info[idx]["file_name"]))
		except IndexError:
			print("There is no image with such index in the dataset.")
			return None
	
		assert mask is not None, "There is no mask for this image. Check mask name(it has to be the same with image) or perform segmentation to obtain the it."
	else:
		if(pred_mask is None):
			# When the mask is taken from dataset
			mask = np.repeat(dataset[idx][1].permute(1, 2, 0).numpy(), 3, axis=2).astype(np.uint8)
		else:
			# When the mask is provided(usually it is postprocessed mask)
			mask = np.repeat(pred_mask.permute(1, 2, 0).numpy(), 3, axis=2).astype(np.uint8)
		
		# Coloring the mask
		for class_id in class_colors.keys():
			for channel_id in range(3):
				mask[:, :, channel_id] = np.where(mask[:, :, channel_id] == class_id, class_colors[class_id][channel_id], mask[:, :, channel_id])

	# Image and mask postprocessing
	image = dataset[idx][0].permute(1, 2, 0).numpy().astype(np.uint8)
	image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
	mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
	return cv2.addWeighted(image, 1.0, mask, alpha, 0), mask

### Define pathes to [the downloaded dataset](https://universe.roboflow.com/arios-workspace/segmtest/dataset/1) parts

In [5]:
root_dir = "roboflow_dataset"
train_images_dir, val_images_dir, test_images_dir = "train", "valid", "test" 
annot_file_path = "_annotations.coco.json"

train_dataset = BoltDataset(root_dir=root_dir, data_dir=train_images_dir, annot_file_name=annot_file_path)
val_dataset = BoltDataset(root_dir=root_dir, data_dir=val_images_dir, annot_file_name=annot_file_path)
test_dataset = BoltDataset(root_dir=root_dir, data_dir=test_images_dir, annot_file_name=annot_file_path)

100%|██████████| 1068/1068 [00:01<00:00, 537.26it/s, Total processed images: 285]
100%|██████████| 108/108 [00:00<00:00, 504.61it/s, Total processed images: 27]
100%|██████████| 57/57 [00:00<00:00, 599.17it/s, Total processed images: 13]


### Display a particular sample from a particular dataset

In [6]:
colors = {1: [0, 250, 125], 2: [250, 0, 125]}
image_IDX = 10

cv2.imshow('Masked', mask_sample(test_dataset, image_IDX, colors)[0])
cv2.waitKey(0)

cv2.destroyAllWindows()

## Custom model with `segmentation_models_pytorch`

### Create functions for training and validating the model(can be merged in one function)

In [None]:
import segmentation_models_pytorch as smp

def train(train_loader, model, criterion, optimizer, epoch, device: str):
	# Buffer for calculating epoch metrics and total loss
	metrics = {"tp": [], "fp": [], "tn": [], "fn": []}
	total_loss = 0.0

	model.train()
	bar = tqdm(train_loader)
	for _, (images, target) in enumerate(bar, start=1):
		images = images.to(device, non_blocking=True)
		target = target.to(device, non_blocking=True).squeeze(1).long() - 1
		output = model(images)
		
		# Treat low probabilities as 0(value - hyperparameter)
		output = torch.where(output.ge(0.4), output, .0)
		
		loss = criterion(output, target)
		total_loss += loss.item()

		with torch.no_grad():
			tp, fp, fn, tn = smp.metrics.get_stats(output.argmax(dim=1).long(), target, mode='multiclass', num_classes=3, ignore_index=-1)
			metrics["tp"].append(tp)
			metrics["fp"].append(fp)
			metrics["tn"].append(tn)
			metrics["fn"].append(fn)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		bar.set_postfix_str(f"Train / Epoch: {epoch}. Loss {total_loss / _}")

		torch.cuda.empty_cache()

	return 	torch.cat(metrics["tp"], dim=0),\
			torch.cat(metrics["fp"], dim=0),\
			torch.cat(metrics["fn"], dim=0),\
			torch.cat(metrics["tn"], dim=0)
		
def validate(val_loader, model, criterion, epoch, device: str):
	# Buffer for calculating epoch metrics and total loss
	metrics = {"tp": [], "fp": [], "tn": [], "fn": []}
	total_loss = 0.0

	model.eval()
	bar = tqdm(val_loader)
	with torch.no_grad():
		for _, (images, target) in enumerate(bar, start=1):
			images = images.to(device, non_blocking=True)
			target = target.to(device, non_blocking=True).squeeze(1).long() - 1
			output = model(images)

			# Treat low probabilities as 0(value - hyperparameter)
			output = torch.where(output.ge(0.4), output, .0)
			
			loss = criterion(output, target)
			total_loss += loss.item()

			bar.set_postfix_str(f"Validation / Epoch: {epoch}. Loss{total_loss / _}")

			tp, fp, fn, tn = smp.metrics.get_stats(output.argmax(dim=1).long(), target, mode='multiclass', num_classes=3, ignore_index=-1)
			metrics["tp"].append(tp)
			metrics["fp"].append(fp)
			metrics["tn"].append(tn)
			metrics["fn"].append(fn)
			
			torch.cuda.empty_cache()

	return 	torch.cat(metrics["tp"], dim=0),\
			torch.cat(metrics["fp"], dim=0),\
			torch.cat(metrics["fn"], dim=0),\
			torch.cat(metrics["tn"], dim=0)


### Creating and training segmentation model
`Best found backbone options are:` `se_resnext50_32x4d`, `resnetN`

In [8]:
# In case of error about expired ssl certificate
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# Set random seed for reproducibility
torch.manual_seed(1)

# Creating dataloaders
train_loader = DataLoader(
		train_dataset,
		batch_size = 4,
		shuffle = True,
		pin_memory=True,
	)

val_loader = DataLoader(
	val_dataset,
	batch_size = 4,
	shuffle = False,
	pin_memory = True,
)

# Specify "cuda" as device to train model on(if available)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Some hyperparameters
lr=2e-4
num_epochs = 10

# Creating model with paticular architecture and backbone
model = smp.create_model(
			arch="unet",
			encoder_name="se_resnext50_32x4d",
			in_channels=1,
			classes=2,
            activation="sigmoid"
		).to(device)

# Loss criterion and optimizer
criterion = smp.losses.DiceLoss(smp.losses.MULTICLASS_MODE, from_logits=False, log_loss=True, ignore_index=-1)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

Device: cuda


### Start training loop for the defined number of epochs

In [9]:
for epoch in range(1, num_epochs + 1):
	train_tp, train_fp, train_fn, train_tn = train(train_loader, model, criterion, optimizer, epoch, device)
	val_tp, val_fp, val_fn, val_tn = validate(val_loader, model, criterion, epoch, device)

	train_imagewise_iou = smp.metrics.iou_score(train_tp, train_fp, train_fn, train_tn, reduction="micro-imagewise")
	train_dataset_iou = smp.metrics.iou_score(train_tp, train_fp, train_fn, train_tn, reduction="micro")
	train_imagewise_prec = smp.metrics.precision(train_tp, train_fp, train_fn, train_tn, reduction="micro-imagewise")
	train_dataset_prec = smp.metrics.precision(train_tp, train_fp, train_fn, train_tn, reduction="micro")
	train_imagewise_rec = smp.metrics.recall(train_tp, train_fp, train_fn, train_tn, reduction="micro-imagewise")
	train_dataset_rec = smp.metrics.recall(train_tp, train_fp, train_fn, train_tn, reduction="micro")

	val_imagewise_iou = smp.metrics.iou_score(val_tp, val_fp, val_fn, val_tn, reduction="micro-imagewise")
	val_dataset_iou = smp.metrics.iou_score(val_tp, val_fp, val_fn, val_tn, reduction="micro")
	val_imagewise_prec = smp.metrics.precision(val_tp, val_fp, val_fn, val_tn, reduction="micro-imagewise")
	val_dataset_prec = smp.metrics.precision(val_tp, val_fp, val_fn, val_tn, reduction="micro")
	val_imagewise_rec = smp.metrics.recall(val_tp, val_fp, val_fn, val_tn, reduction="micro-imagewise")
	val_dataset_rec = smp.metrics.recall(val_tp, val_fp, val_fn, val_tn, reduction="micro")

	print(f"Train IoU(imagewise): {train_imagewise_iou:.3f}, IoU(dataset): {train_dataset_iou:.3f}, Precision(imagewise): {train_imagewise_prec:.2f}, Precision(dataset): {train_dataset_prec:.2f}, Recall(imagewise): {train_imagewise_rec:.2f}, Recall(dataset): {train_dataset_rec:.2f}")
	print(f"Validation IoU(imagewise): {val_imagewise_iou:.3f}, IoU(dataset): {val_dataset_iou:.3f}, Precision(imagewise): {val_imagewise_prec:.2f}, Precision(dataset): {val_dataset_prec:.2f}, Recall(imagewise): {val_imagewise_rec:.2f}, Recall(dataset): {val_dataset_rec:.2f}")

100%|██████████| 72/72 [01:01<00:00,  1.17it/s, Train / Epoch: 1. Loss 1.07827548806866]  
100%|██████████| 7/7 [00:03<00:00,  1.91it/s, Validation / Epoch: 1. Loss0.5456941212926593]


Train IoU(imagewise): 0.866, IoU(dataset): 0.852, Precision(imagewise): 0.92, Precision(dataset): 0.92, Recall(imagewise): 0.92, Recall(dataset): 0.92
Validation IoU(imagewise): 0.971, IoU(dataset): 0.973, Precision(imagewise): 0.98, Precision(dataset): 0.99, Recall(imagewise): 0.98, Recall(dataset): 0.99


100%|██████████| 72/72 [00:38<00:00,  1.85it/s, Train / Epoch: 2. Loss 0.3926733715666665] 
100%|██████████| 7/7 [00:01<00:00,  6.83it/s, Validation / Epoch: 2. Loss0.5621372929641179]


Train IoU(imagewise): 0.972, IoU(dataset): 0.973, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99
Validation IoU(imagewise): 0.962, IoU(dataset): 0.965, Precision(imagewise): 0.98, Precision(dataset): 0.98, Recall(imagewise): 0.98, Recall(dataset): 0.98


100%|██████████| 72/72 [00:38<00:00,  1.85it/s, Train / Epoch: 3. Loss 0.3095161273247666] 
100%|██████████| 7/7 [00:01<00:00,  6.89it/s, Validation / Epoch: 3. Loss0.36857698219163076]


Train IoU(imagewise): 0.973, IoU(dataset): 0.974, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99
Validation IoU(imagewise): 0.973, IoU(dataset): 0.976, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99


100%|██████████| 72/72 [00:39<00:00,  1.84it/s, Train / Epoch: 4. Loss 0.24172176989830202]
100%|██████████| 7/7 [00:01<00:00,  6.80it/s, Validation / Epoch: 4. Loss0.3562841628279005] 


Train IoU(imagewise): 0.977, IoU(dataset): 0.979, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99
Validation IoU(imagewise): 0.974, IoU(dataset): 0.976, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99


100%|██████████| 72/72 [00:39<00:00,  1.83it/s, Train / Epoch: 5. Loss 0.18873227566170195]
100%|██████████| 7/7 [00:01<00:00,  6.75it/s, Validation / Epoch: 5. Loss0.33380840931619915]


Train IoU(imagewise): 0.979, IoU(dataset): 0.981, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99
Validation IoU(imagewise): 0.975, IoU(dataset): 0.976, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99


100%|██████████| 72/72 [00:39<00:00,  1.82it/s, Train / Epoch: 6. Loss 0.16455889493227005]
100%|██████████| 7/7 [00:01<00:00,  6.72it/s, Validation / Epoch: 6. Loss0.3440100465502058] 


Train IoU(imagewise): 0.982, IoU(dataset): 0.983, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99
Validation IoU(imagewise): 0.975, IoU(dataset): 0.977, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99


100%|██████████| 72/72 [00:39<00:00,  1.83it/s, Train / Epoch: 7. Loss 0.16276814591967398]
100%|██████████| 7/7 [00:01<00:00,  6.72it/s, Validation / Epoch: 7. Loss0.3044664242437908] 


Train IoU(imagewise): 0.983, IoU(dataset): 0.984, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99
Validation IoU(imagewise): 0.976, IoU(dataset): 0.978, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99


100%|██████████| 72/72 [00:39<00:00,  1.81it/s, Train / Epoch: 8. Loss 0.13869249375743997]
100%|██████████| 7/7 [00:01<00:00,  6.72it/s, Validation / Epoch: 8. Loss0.3531170295817511] 


Train IoU(imagewise): 0.985, IoU(dataset): 0.986, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99
Validation IoU(imagewise): 0.975, IoU(dataset): 0.977, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99


100%|██████████| 72/72 [00:39<00:00,  1.83it/s, Train / Epoch: 9. Loss 0.13497281726449728]
100%|██████████| 7/7 [00:01<00:00,  6.84it/s, Validation / Epoch: 9. Loss0.3164951375552586] 


Train IoU(imagewise): 0.985, IoU(dataset): 0.986, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99
Validation IoU(imagewise): 0.975, IoU(dataset): 0.978, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99


100%|██████████| 72/72 [00:39<00:00,  1.82it/s, Train / Epoch: 10. Loss 0.11836361600500014]
100%|██████████| 7/7 [00:01<00:00,  6.73it/s, Validation / Epoch: 10. Loss0.30814724947725025]

Train IoU(imagewise): 0.986, IoU(dataset): 0.987, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99
Validation IoU(imagewise): 0.976, IoU(dataset): 0.978, Precision(imagewise): 0.99, Precision(dataset): 0.99, Recall(imagewise): 0.99, Recall(dataset): 0.99





### Segment a particular image from the test dataset with a curtain level of confidence for the classes 

In [10]:
def mask_postprocess_with_conf(mask: torch.Tensor, confidence: float = 0.9) -> np.ndarray:
	""" Postprocess a mask with per class confidence level. If the confidence for both classes of a pixel are less
		than specified level then it is classified as "background" 

	Args:
		mask (torch.Tensor): predicted mask
		confidence (float, optional): confidence level. Defaults to 0.9.

	Returns:
		np.ndarray: postprocessed mask
	"""
	neg_conf_mask = torch.where(mask.max(dim=1)[0].lt(confidence), -1, 0)
	conf_class_mask = torch.where(mask.ge(confidence), mask, 0).argmax(dim=1)
	conf_mask = conf_class_mask + neg_conf_mask + 1

	return conf_mask

### Segment a particular image from the test dataset

In [12]:
image_IDX = 0	# Index of image from test dataset to segment
conf = 0.95		# Confidence level

# Processing image
model.eval()
with torch.no_grad():
    mask = model(test_dataset[image_IDX][0].unsqueeze(0).to(device)).cpu()

# Preprocessing the mask
conf_mask = mask_postprocess_with_conf(mask, conf)

# Debug info(can be commented out)
unique, counts = np.unique(conf_mask.numpy(), return_counts=True)
print(dict(zip(unique, counts)))

# Show masked image
cv2.imshow('Masked', mask_sample(test_dataset, image_IDX, colors, pred_mask=conf_mask)[0])
cv2.waitKey(0)

cv2.destroyAllWindows()

{0: 169573, 1: 90279, 2: 2292}


### Save segmentation masks and masked images for all samples from test dataset

In [192]:
masked_images_dir = "masked"	# Directory where to save masked images
pred_masks_dir = "pred_masks"	# Directory where to save masks
conf = 0.95						# Confidence level for classes

model.eval()
with torch.no_grad():
	bar = tqdm(range(len(test_dataset)))
	for image_ID in bar:
		# Segmenting each image from test dataset
		pred_mask = model(test_dataset[image_ID][0].unsqueeze(0).to(device)).cpu()
		
		# Preprocessing each mask
		conf_mask = mask_postprocess_with_conf(pred_mask, conf)

		# Getting masked image and colored mask
		masked_image, conf_mask = mask_sample(test_dataset, image_ID, colors, pred_mask=conf_mask)	

		# And finally save everything
		cv2.imwrite(os.path.join(os.path.join(root_dir, masked_images_dir), test_dataset.images_info[image_ID]["file_name"]), masked_image)
		cv2.imwrite(os.path.join(os.path.join(root_dir, pred_masks_dir), test_dataset.images_info[image_ID]["file_name"]), conf_mask)

100%|██████████| 13/13 [00:01<00:00,  8.08it/s]


## Inference with Roboflow API 

### Connect to a curtain project version 

In [144]:
from roboflow import Roboflow

# Private API key
API_KEY = # YOUR_API_KEY

# Connecting to a particular version of the project
rf = Roboflow(api_key=API_KEY)
project = rf.workspace().project("segmtest")
rf_model = project.version(1).model

loading Roboflow workspace...
loading Roboflow project...


### Segment all images from the test dataset and save them with the results

In [191]:
masked_images_dir = "masked"	# Directory where to save masked images
pred_masks_dir = "pred_masks"	# Directory where to save masks

bar = tqdm(range(len(test_dataset)))
for image_ID in bar:
	# Segmenting each image from test dataset with curtain level of confidence
	pred_results = rf_model.predict(os.path.join(os.path.join(test_dataset.data_dir, test_dataset.images_info[image_ID]["file_name"])), confidence=40).json()["predictions"]

	# Creating empty mask with the shape of [IMAGE_HEIGHT, IMAGE_WIDTH, 3]
	image_mask = np.zeros(test_dataset[image_ID][0].permute(1, 2, 0).numpy().shape, dtype=np.uint8)
	
	for pred in pred_results:
		image_mask = draw_mask_polygon(image_mask, pred["points"], test_dataset.class_categories2id[pred["class"]])

	masked_image, mask = mask_sample(test_dataset, image_ID, colors, pred_mask=torch.from_numpy(image_mask).permute(2, 0, 1))

	cv2.imwrite(os.path.join(os.path.join(root_dir, masked_images_dir), test_dataset.images_info[image_ID]["file_name"]), masked_image)
	cv2.imwrite(os.path.join(os.path.join(root_dir, pred_masks_dir), test_dataset.images_info[image_ID]["file_name"]), mask)

100%|██████████| 13/13 [00:12<00:00,  1.05it/s]
