In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from pprint import pprint

In [None]:
from torchvision.utils import draw_segmentation_masks
import cv2
from torchvision.transforms import functional as F
import torch
import torchvision
from torchvision.io import read_image
from torchvision.ops import masks_to_boxes
from torchvision.utils import draw_bounding_boxes
import pandas as pd
import MyUtils.Dataset
from MyUtils.visualize import visualize_masks
import albumentations as A
from torch.utils.data import DataLoader
from MyUtils import transforms, utils, engine, train as transforms, utils, engine, train
from MyUtils.utils import collate_fn
from MyUtils.engine import train_one_epoch, evaluate
from MyUtils.plot_statistic import plot_stats

from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, maskrcnn_resnet50_fpn, MaskRCNN
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights, MaskRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from tqdm import tqdm

import random

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from collections import defaultdict

In [None]:
# Seed every RNG the pipeline may draw from: `random` and numpy (albumentations uses one or
# both depending on its version) and torch (weight init, DataLoader shuffling).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

path = 'C:/Users/User/Petr/Net_4/Dataset/'
visualise_path = 'C:/Users/User/Petr/Net_4/Visualize/'

# Overlay colour per class name, passed to visualize_masks.
obj_colors = {
    'Country road': (255, 136, 0),
    'Asphalt road': (0, 0, 0),
    'Water': (0, 0, 255)
}

# The class name of a mask file is the second-to-last '-'-separated field of its name.
# Count masks per class in a single pass over the directory, then derive the sorted class
# list from the same counts instead of listing the directory a second time.
dataset_stats = defaultdict(int)
for mask in os.listdir(path + 'masks/'):
    if mask.endswith('.npy'):
        dataset_stats[mask.split('-')[-2]] += 1

classes = sorted(dataset_stats)

# NOTE(review): indices start at 0, but torchvision detection models reserve label 0 for
# background. Unless MyUtils.Dataset shifts labels by +1, the first class is treated as
# background (and dropped from predictions) -- confirm against the dataset code.
class_index = {cls: index for index, cls in enumerate(classes)}
inv_classes = {v: k for k, v in class_index.items()}

# Default pretrained weights for the two Mask R-CNN variants built further down.
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
weights_v2 = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT

print(f'Dataset classes is {classes}')
print(f'Class index is {class_index}')
print(f'Class index inverted is {inv_classes}')

df = pd.DataFrame(data=list(dataset_stats.items()), columns=['Classes', 'Counts'])

ax = sns.barplot(data=df, x='Classes', y='Counts')
ax.set(title='Number of mask files per class', xlabel='Class', ylabel='Number of masks');

In [None]:
# Training-time augmentation pipeline: a horizontal flip and a bbox-safe crop, followed by
# colour and weather effects. Boxes are given in 'pascal_voc' format, with their labels in
# the 'class_labels' field. Experiments currently switched off: RandomBrightnessContrast,
# ToSepia, Solarize, CropNonEmptyMaskIfExists.
# NOTE(review): RandomSnow uses p=1.0, so it runs on every training sample while the other
# transforms run with p=0.5. Confirm this is intended.
transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.RandomSizedBBoxSafeCrop(
            p=0.5,
            always_apply=False,
            height=3000,
            width=3000,
            erosion_rate=0.0,
            interpolation=1,
        ),
        A.RGBShift(
            p=0.5,
            always_apply=False,
            r_shift_limit=(-20, 20),
            g_shift_limit=(-20, 20),
            b_shift_limit=(-20, 20),
        ),
        A.RandomGravel(p=0.5, always_apply=False, gravel_roi=(0.1, 0.1, 0.9, 0.9), number_of_patches=2),
        A.RandomShadow(p=0.5, always_apply=False, shadow_roi=(0, 0, 1, 1)),
        A.RandomSnow(
            p=1.0,
            always_apply=False,
            snow_point_lower=0.1,
            snow_point_upper=0.3,
            brightness_coeff=2.5,
        ),
    ],
    bbox_params=A.BboxParams(label_fields=['class_labels'], format='pascal_voc'),
)

In [None]:
# Two views of the same dataset root. The training view applies the random `transform`;
# the evaluation view is left un-augmented.
dataset_train_initial = MyUtils.Dataset.RoadsDataset(transform=transform, class_index=class_index, root=path)
dataset_test_initial = MyUtils.Dataset.RoadsDataset(transform=None, class_index=class_index, root=path)

print(len(dataset_train_initial))

In [None]:
# Draw one sample for inspection. The dataset was built with the random `transform`, so this
# is presumably an augmented view (TODO confirm that RoadsDataset applies it in __getitem__).
# `image` is reused by the demo inference cells at the end of the notebook.
image, target = dataset_train_initial[20]

In [None]:
# Overlay the sample's ground-truth target on its image, with alpha=0.5, and show the figure.
visualize_masks(image, target, inv_classes, alpha=0.5, show=True, obj_colors=obj_colors)

In [None]:
def _adapt_mask_rcnn(model, weights_path, num_classes, name):
    """Fit a torchvision Mask R-CNN to this dataset's classes and optionally restore a checkpoint.

    The original box and mask predictors are replaced with freshly initialised ones that have
    ``num_classes`` outputs and the same input sizes as the originals. If ``weights_path`` is
    truthy, a ``state_dict`` saved by this notebook is then loaded into the model.

    Args:
        model: Mask R-CNN instance built by one of the torchvision constructors.
        weights_path: path to a saved ``state_dict``. None or '' keeps the new heads.
        num_classes: output size of the replacement predictors.
        name: identifier stored on ``model.name`` for later reference.

    Returns:
        ``model`` itself, modified in place.
    """
    # Read the existing predictors' sizes so the replacements fit the unchanged heads.
    in_features_box = model.roi_heads.box_predictor.cls_score.in_features
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    dim_reduced = model.roi_heads.mask_predictor.conv5_mask.out_channels

    model.roi_heads.box_predictor = FastRCNNPredictor(in_channels=in_features_box, num_classes=num_classes)
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_channels=in_features_mask, dim_reduced=dim_reduced, num_classes=num_classes
    )

    if weights_path:
        # Use map_location='cpu' because checkpoints are saved from a model that may live on
        # CUDA; loading one without it fails on CPU-only machines. The model has not been
        # moved off the CPU at this point, and callers call .to(device) afterwards.
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        # Newer PyTorch versions also accept weights_only=True.
        state_dict = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state_dict)

    model.name = name
    return model


def get_mask_model(weights_path=None, num_classes=None):
    """Build a Mask R-CNN ResNet-50-FPN (v1) from the pretrained ``weights`` and adapt its heads.

    Args:
        weights_path: optional path to a fine-tuned ``state_dict`` to load.
        num_classes: predictor output size. Defaults to ``len(classes)``.
            NOTE(review): torchvision detection models treat label 0 as background, so this
            usually needs to be ``len(classes) + 1`` unless the dataset already shifts labels.
            Confirm against MyUtils.Dataset. Changing it invalidates existing checkpoints.

    Returns:
        The model (not moved to any device), with ``model.name == 'maskrcnn_resnet50_fpn'``.
    """
    model = maskrcnn_resnet50_fpn(weights=weights)
    if num_classes is None:
        num_classes = len(classes)
    # This model used to be labelled 'maskrcnn_resnet50_fpn_v2', a copy-paste slip from the v2 builder.
    return _adapt_mask_rcnn(model, weights_path, num_classes, name='maskrcnn_resnet50_fpn')


def get_mask_model_v2(weights_path=None, num_classes=None):
    """Build a Mask R-CNN ResNet-50-FPN v2 from the pretrained ``weights_v2`` and adapt its heads.

    Args:
        weights_path: optional path to a fine-tuned ``state_dict`` to load.
        num_classes: predictor output size. Defaults to ``len(classes)``. The same
            background-label caveat as ``get_mask_model`` applies.

    Returns:
        The model (not moved to any device), with ``model.name == 'maskrcnn_resnet50_fpn_v2'``.
    """
    model = maskrcnn_resnet50_fpn_v2(weights=weights_v2)
    if num_classes is None:
        num_classes = len(classes)
    return _adapt_mask_rcnn(model, weights_path, num_classes, name='maskrcnn_resnet50_fpn_v2')

In [None]:
save_path = 'C:/Users/User/Petr/Net_4/save_model'
log_path = 'C:/Users/User/Petr/Net_4/Metric_log'

# --- Train / test split ----------------------------------------------------------------------
# Hold out 10% of the samples for evaluation. The permutation comes from a dedicated generator
# seeded with 42 (the config cell's seed), so the split no longer depends on how much of the
# global torch RNG other cells used before this one. This matters because training resumes
# from a checkpoint produced in an earlier session, and a different split would leak
# evaluation images into training.
# Assumes both RoadsDataset instances enumerate the same files in the same order.
test_fraction = 0.10
n_total = len(dataset_train_initial)
assert n_total >= 2, f'need at least 2 samples to split, got {n_total}'
split_generator = torch.Generator().manual_seed(42)
indices = torch.randperm(n_total, generator=split_generator).tolist()
# Use max(1, ...) because a 0-sized hold-out made the old `indices[:-0]` slice produce an EMPTY
# training set and put every sample into the test set.
n_test = max(1, int(n_total * test_fraction))
n_train = n_total - n_test
dataset_train = torch.utils.data.Subset(dataset_train_initial, indices[:n_train])
dataset_test = torch.utils.data.Subset(dataset_test_initial, indices[n_train:])

data_loader_train = DataLoader(dataset_train, batch_size=1, shuffle=True, collate_fn=collate_fn)
data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Model / optimiser -----------------------------------------------------------------------
# Resume fine-tuning from the checkpoint written after epoch start_from - 1, and run epochs
# start_from .. num_epochs - 1.
start_from = 201
num_epochs = 301
model = get_mask_model_v2(weights_path=f'{save_path}/weights_{start_from - 1}.pth')
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=1.2e-7, momentum=0.90)
# Multiply the LR by 0.8 every 5 epochs. Only model weights are checkpointed, so SGD momentum
# and scheduler progress restart from scratch on every resume.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)

# Per-epoch history for plot_stats: the first 6 COCO-eval stats and the average losses.
bbox_stats = []
masks_stats = []
loss_bb = []
loss_masks = []
loss = []

# Aim for about ten progress lines per epoch. max(1, ...) keeps the frequency non-zero for
# training sets smaller than 10 samples; it is presumably used as a modulus in train_one_epoch.
print_freq = max(1, len(dataset_train) // 10)

for epoch in range(start_from, num_epochs):
    logger = train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=print_freq)
    lr_scheduler.step()
    evaluator = evaluate(model, data_loader_test, device)

    bbox_stats.append(evaluator.coco_eval['bbox'].stats[:6])
    masks_stats.append(evaluator.coco_eval['segm'].stats[:6])

    loss_masks.append(logger.meters['loss_mask'].global_avg)
    loss_bb.append(logger.meters['loss_box_reg'].global_avg)
    loss.append(logger.meters['loss'].global_avg)

    # Checkpoint and refresh the saved plots every 10 epochs.
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f'{save_path}/weights_{epoch + 1}.pth')
        plot_stats(epoch - start_from + 1, bbox_stats, loss_bb, loss, masks_stats=masks_stats,
                   loss_masks=loss_masks, num=epoch + 1, log_path=log_path)

# num=num_epochs instead of epoch + 1. The two are identical after a full run, but this form
# does not raise NameError when the loop body never executes (start_from >= num_epochs).
plot_stats(num_epochs - start_from, bbox_stats, loss_bb, loss, masks_stats=masks_stats,
           loss_masks=loss_masks, num=num_epochs, show=True, log_path=log_path)

In [None]:
# Manually checkpoint the current in-memory model weights.
# NOTE(review): the file name is fixed to "weights_50" even though the training cell uses
# num_epochs = 301, so this overwrites any existing weights_50.pth with much later weights.
# Consider naming the file after the last completed epoch.
torch.save(obj=model.state_dict(), f=f'{save_path}/weights_50.pth')

In [None]:
test_path = 'C:/Users/User/Petr/Net_4/test_model/images/'
out_path = 'C:/Users/User/Petr/Net_4/test_model/pred/'

save_path = 'C:/Users/User/Petr/Net_4/save_model'
model = get_mask_model_v2(weights_path=f'{save_path}/weights_70.pth')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model.to(device)
# Switch to inference mode once, outside the loop.
model.eval()
# Minimum detection score that is kept for visualisation.
threshold = 0.5

# visualize_masks writes its output files into out_path, so make sure the folder exists.
os.makedirs(out_path, exist_ok=True)

# The loop variable is `image_name`, not `image`: reusing `image` here replaced the sample
# tensor from dataset_train_initial[20] with a file-name string, which broke the demo cells
# below that call image.to(device).
for image_name in tqdm(os.listdir(test_path)):
    # np.fromfile + imdecode, rather than cv2.imread, so that non-ASCII Windows paths work.
    # IMREAD_COLOR always yields 3-channel BGR, so grayscale and RGBA files no longer break
    # the BGR->RGB conversion. IMREAD_IGNORE_ORIENTATION keeps ignoring EXIF rotation, as
    # the previous IMREAD_UNCHANGED did.
    raw_bytes = np.fromfile(os.path.join(test_path, image_name), dtype=np.uint8)
    img = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    if img is None:
        # imdecode returns None for files it cannot decode (e.g. Thumbs.db), so skip them.
        tqdm.write(f'Skipping {image_name}: not a decodable image')
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = F.to_tensor(img).to(device)

    with torch.no_grad():
        out = model([img])[0]

    # Keep only detections scoring above the threshold. Move them to the CPU to match the
    # CPU image handed to visualize_masks.
    scores_valid = out['scores'] > threshold
    target = {key: out[key][scores_valid].cpu() for key in ('masks', 'boxes', 'labels')}

    visualize_masks(img.detach().cpu(), target=target, inv_classes=inv_classes, obj_colors=obj_colors,
                    alpha=0.5, save_path=out_path + image_name)

In [None]:
# Sanity check: run a freshly built v1 model (pretrained `weights`, new predictor heads, no
# fine-tuned checkpoint) on the sample `image`. It is stored as `model_v1` so that the
# fine-tuned `model` from the cells above is not silently overwritten.
model_v1 = get_mask_model()
# Put the weights on the same device as the input; otherwise a CUDA input meets CPU weights.
model_v1.to(device)
model_v1.eval()
# Inference only, so skip building the autograd graph.
with torch.no_grad():
    pred = model_v1([image.to(device)])
# Move predictions to the CPU to match the CPU image passed to visualize_masks.
output = {key: value.cpu() for key, value in pred[0].items()}

visualize_masks(image.detach().cpu(), output, inv_classes, obj_colors=obj_colors, show=True)

In [None]:
# Sanity check: run a freshly built v2 model (pretrained `weights_v2`, new predictor heads, no
# fine-tuned checkpoint) on the sample `image`.
model_v2 = get_mask_model_v2()
# Put the weights on the same device as the input; otherwise a CUDA input meets CPU weights.
model_v2.to(device)
model_v2.eval()
# Inference only, so skip building the autograd graph.
with torch.no_grad():
    pred_v2 = model_v2([image.to(device)])
# Move predictions to the CPU to match the CPU image passed to visualize_masks.
output_v2 = {key: value.cpu() for key, value in pred_v2[0].items()}

visualize_masks(image.detach().cpu(), output_v2, inv_classes, obj_colors=obj_colors, show=True)

In [None]:
# Scratch check of a window-count rule: one base window plus one per non-empty list. This
# presumably mirrors how plot_stats sizes its figure. It uses throwaway names because
# reassigning `loss` here wiped the training loss history collected by the training cell.
loss_probe = []
loss_2_probe = [1]

windows = 1 + int(bool(loss_probe)) + int(bool(loss_2_probe))

In [None]:
# Display the window count computed in the previous cell.
windows