# Fine-tune Mask R-CNN


In [None]:
from torchvision.datasets import CocoDetection
import torchvision
import torchvision.transforms as transforms
import torch
import matplotlib.pyplot as plt
import tqdm
from tqdm import tqdm
import cv2
# from pycocotools.mask import frPyObjects, decode
import numpy as np

# Pin this process to GPU 7 when CUDA is present. The guard keeps this cell
# from raising on CPU-only hosts, matching the CPU fallback used when the
# training `device` is chosen later in the notebook.
if torch.cuda.is_available():
    torch.cuda.set_device(7)  # ❗️❗️❗️ hard-coded GPU index

In [None]:
from importlib import reload

import predictor_utils
import dataset_utils

# Reload BEFORE the star-imports: `reload` re-executes the module but does not
# rebind names that were already copied into this namespace, so importing
# first would leave the notebook using the stale, pre-reload objects.
reload(dataset_utils)
reload(predictor_utils)

# Same star-import order as before (dataset_utils last, so it wins on any
# name defined by both modules).
from predictor_utils import *
from dataset_utils import *

In [None]:
# Training hyper-parameters.
batch_size = 32  # number of samples per batch in both DataLoaders
lr = 3e-4  # learning rate handed to the Adam optimizer

In [None]:
# Each split directory holds the images together with its own
# `_annotations.coco.json` file, so the annotation path is derived from the
# image root of the split.
DATASET_ROOT = '/workspace/raid/OM_DeepLearning/XMM_OM_code_git/dog-2'

train_dir = f'{DATASET_ROOT}/train/'
val_dir = f'{DATASET_ROOT}/valid/'

train_dataset = CocoDetection(
    root=train_dir, annFile=f'{train_dir}_annotations.coco.json')
val_dataset = CocoDetection(
    root=val_dir, annFile=f'{val_dir}_annotations.coco.json')

In [None]:
def collate_fn(batch):
    """Collate ``(image, target)`` samples without stacking them.

    Args:
        batch: Sequence of ``(image, target)`` pairs produced by the dataset.

    Returns:
        A ``(images, targets)`` pair: ``images`` is a list holding the
        ``ToTensor`` conversion of every image, ``targets`` is the tuple of
        the untouched per-sample targets.
    """
    to_tensor = transforms.ToTensor()
    raw_images, targets = zip(*batch)
    return [to_tensor(raw_image) for raw_image in raw_images], targets


def create_mask(points, image_size):
    """Rasterise one flat polygon into a binary mask.

    Args:
        points: Flat coordinate sequence ``[x0, y0, x1, y1, ...]``; every
            consecutive pair is handed to ``cv2.fillPoly`` as one vertex.
        image_size: Shape of the mask to allocate (passed to ``np.zeros``);
            the caller in this file passes ``(height, width)``.

    Returns:
        ``np.uint8`` array of shape ``image_size`` holding 1 inside the
        polygon and 0 elsewhere. Empty ``points`` yields an all-zero mask.

    Raises:
        ValueError: If ``points`` holds an odd number of values, i.e. the
            last vertex is missing its second coordinate.
    """
    mask = np.zeros(image_size, dtype=np.uint8)

    # Nothing to draw: return the empty mask instead of passing an empty
    # vertex array to OpenCV.
    if len(points) == 0:
        return mask
    if len(points) % 2:
        raise ValueError(
            f"polygon needs an even number of coordinates, got {len(points)}")

    # Pair the flat list up into vertices; the int32 cast truncates
    # fractional coordinates exactly like the previous list-based version.
    polygon = np.asarray(points, dtype=np.float64).reshape(-1, 2).astype(np.int32)

    cv2.fillPoly(mask, [polygon], 1)
    return mask

# Settings shared by both loaders; only the training split is shuffled.
loader_options = dict(
    batch_size=batch_size, num_workers=4, collate_fn=collate_fn)

train_data_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, **loader_options)
val_data_loader = torch.utils.data.DataLoader(
    val_dataset, shuffle=False, **loader_options)

In [None]:
def get_model_instance_segmentation(num_classes):
    """Build a pretrained Mask R-CNN (ResNet-50 FPN) with fresh prediction heads.

    The box and mask predictors of the pretrained network are replaced by new
    ones sized for ``num_classes``, so only those two heads start untrained.

    Args:
        num_classes: Number of output classes for both replaced heads
            (presumably including the background class — confirm against
            the dataset's category ids).

    Returns:
        The modified torchvision Mask R-CNN model.
    """
    # "DEFAULT" requests the pretrained checkpoint explicitly; the previous
    # `weights=True` was only accepted through torchvision's deprecated
    # legacy-argument handling and emitted a warning.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    # Replace the box classification/regression head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes)

    # Replace the mask head, keeping its input width and using a 512-channel
    # hidden layer.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 512
    model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes)
    return model

# 14 output classes for both replaced heads.
# NOTE(review): presumably 13 object categories plus background — confirm.
model = get_model_instance_segmentation(num_classes=14)

# Adam with its standard beta pair and the notebook-wide learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

# Train on GPU 7 when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:7")
else:
    device = torch.device("cpu")
model.to(device)

In [None]:
def _annotation_to_target(annotation, images):
    """Turn one COCO annotation dict into a dict with 'boxes', 'labels' and 'masks'."""
    # Shallow copy so the caller's annotation is never modified; tensor values
    # are moved to the training device on the way.
    adjusted = {key: value.to(device) if torch.is_tensor(value) else value
                for key, value in annotation.items()}

    # 'bbox' is [x, y, width, height]; build [x_min, y_min, x_max, y_max] as a
    # NEW tensor instead of updating in place, so no input data is altered.
    xywh = torch.as_tensor(
        adjusted.pop('bbox'), dtype=torch.float32, device=device).reshape(-1, 4)
    adjusted['boxes'] = torch.cat(
        (xywh[:, :2], xywh[:, :2] + xywh[:, 2:]), dim=1)

    if 'category_id' in adjusted:
        adjusted['labels'] = torch.tensor(
            adjusted.pop('category_id'), dtype=torch.int64, device=device).unsqueeze(0)

    if 'segmentation' in adjusted:
        segmentation = adjusted['segmentation']
        if isinstance(segmentation, dict):
            raise TypeError(
                "dict-style (e.g. RLE) segmentation is not supported; "
                "expected a list of polygons")
        # All images of the batch share one size, so the first one defines
        # the mask shape (height, width).
        mask_size = (images[0].shape[1], images[0].shape[2])
        # An object may be described by several polygons; merge all of them
        # rather than keeping only the first.
        mask = np.zeros(mask_size, dtype=np.uint8)
        for polygon in segmentation:
            mask = np.maximum(mask, create_mask(polygon, mask_size))
        # Shape (1, 1, H, W); created without an explicit device, as before.
        adjusted['masks'] = torch.tensor(
            mask, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    return adjusted


def adjust_targets(images, targets):
    """Convert raw COCO annotations into the target dicts passed to the model.

    Args:
        images: Batch of image tensors; ``images[0].shape[1:]`` is used as the
            (height, width) of every generated mask.
        targets: One entry per image, each a list (or tuple) of COCO
            annotation dicts. Every annotation must contain 'bbox';
            'category_id' and 'segmentation' are converted when present.

    Returns:
        A list with one entry per image; each entry is a list of new dicts
        (one per annotation) where 'bbox' became 'boxes' (float32, shape
        (1, 4), XYXY, on ``device``), 'category_id' became 'labels' (int64,
        shape (1,), on ``device``) and 'masks' (float32, shape (1, 1, H, W))
        was rasterised from 'segmentation'. All other keys are carried over.
        The input annotations are left untouched.

    Raises:
        TypeError: If a target is not a list/tuple of dicts, or a
            segmentation is a dict instead of a list of polygons.
        KeyError: If an annotation has no 'bbox'.
    """
    targets_adjusted = []
    for target in targets:
        is_annotation_list = isinstance(target, (list, tuple)) and all(
            isinstance(annotation, dict) for annotation in target)
        if not is_annotation_list:
            # Fail here with a clear message instead of later with an
            # unrelated AttributeError/TypeError.
            raise TypeError(
                "each target must be a list of annotation dictionaries, "
                f"got {type(target).__name__}")
        targets_adjusted.append(
            [_annotation_to_target(annotation, images) for annotation in target])
    return targets_adjusted

**target** is the list of annotations for one image. Each annotation is a dictionary containing the fields:
* id
* image_id
* area
* segmentation : (N, 1, H, W)
* iscrowd
* bbox (N, 4)
* category_id : int

## Fine-tuning

In [None]:
num_epochs = 2

# Per-epoch loss histories. `epoch_val_losses` stays empty while the
# validation pass below is disabled.
epoch_train_losses = []
epoch_val_losses = []

for epoch in range(num_epochs):
    model.train()

    # Running totals over the whole epoch, so the recorded value is the mean
    # per-image loss of the epoch rather than that of the final batch only.
    epoch_loss_sum = 0.0
    epoch_loss_count = 0

    for images, targets in tqdm(train_data_loader):
        images = torch.stack(images, dim=0).to(device)

        targets_adjusted = adjust_targets(images, targets)
        batch_loss = []
        # Every image is forwarded on its own together with its target list.
        for image, image_targets in zip(images, targets_adjusted):
            # An image without annotations yields an empty target list, which
            # the model cannot pair with the image in training mode.
            if not image_targets:
                continue
            loss_dict = model(image.unsqueeze(0), image_targets)
            losses = sum(loss for loss in loss_dict.values())
            batch_loss.append(losses)

        del targets_adjusted
        torch.cuda.empty_cache()

        # No annotated image in this batch: nothing to back-propagate.
        if not batch_loss:
            continue

        batch_loss = torch.stack(batch_loss)
        optimizer.zero_grad()
        batch_loss.sum().backward()
        # Diagnostic only: report parameters whose gradient has collapsed.
        for name, parameter in model.named_parameters():
            if parameter.grad is not None:
                grad_norm = parameter.grad.norm()
                if grad_norm < 1e-8:  # threshold
                    print(f'❗️Layer {name} has vanishing gradients: {grad_norm}')
        optimizer.step()

        epoch_loss_sum += batch_loss.sum().item()
        epoch_loss_count += len(batch_loss)

    epoch_train_losses.append(
        epoch_loss_sum / epoch_loss_count if epoch_loss_count else float('nan'))

    model.eval()

    # NOTE(review): validation pass kept disabled. `print(loss_dict[0].keys())`
    # suggests the model returned a list here; torchvision detection models
    # return predictions rather than a loss dict in eval mode, so this loop
    # presumably needs reworking before it can be re-enabled — confirm.
    # val_loss_accumulator = []
    # with torch.no_grad():
    #     for images, targets in val_data_loader:
    #         images = torch.stack(images, dim=0).to(device)
    #         targets_adjusted = adjust_targets(images, targets)

    #         batch_loss = []
    #         for i in range(len(targets_adjusted)):
    #             loss_dict = model(images[i].unsqueeze(0), targets_adjusted[i])
    #             print(loss_dict[0].keys())
    #             losses = sum(loss for loss in loss_dict.values())
    #             batch_loss.append(losses)

    #         batch_loss = torch.stack(batch_loss).mean()  # Compute mean loss for the batch
    #         val_loss_accumulator.append(batch_loss.item())

    # epoch_val_loss = sum(val_loss_accumulator) / len(val_loss_accumulator)
    # epoch_val_losses.append(epoch_val_loss)
    # print(f' Epoch {epoch}. Train loss: {batch_loss.mean()}. Validation loss: {epoch_val_loss}')

    # Qualitative check: predict on the last image of the last training batch
    # and overlay the first predicted mask.
    with torch.no_grad():
        sample_image = images[-1]
        prediction = model(sample_image.unsqueeze(0))
        # The model may return no detections at all (e.g. early in training);
        # skip those entries instead of indexing into an empty tensor.
        # Each prediction also carries 'labels', 'scores' and 'boxes'.
        predicted_masks = [
            pred['masks'][0][0].detach().cpu().numpy()
            for pred in prediction
            if len(pred['masks']) > 0
        ]

        plt.figure(figsize=(10, 10))
        plt.imshow(sample_image.permute(1, 2, 0).detach().cpu().numpy())
        if predicted_masks:
            show_masks(predicted_masks, plt.gca(), random_color=False)
        plt.show()
        plt.close()

In [None]:
# Store only the learned parameters (the state dict), not the module object.
# NOTE(review): 'resent' looks like a typo for 'resnet'; the name is kept
# unchanged because other code may load this exact path.
checkpoint_path = 'mask_rcnn_resent_checkpoint.pth'
torch.save(model.state_dict(), checkpoint_path)

In [None]:
import os

# `savefig` does not create missing directories, so make sure the output
# folder exists before plotting.
plot_path = './plots/mask_rcnn.png'
os.makedirs(os.path.dirname(plot_path), exist_ok=True)

# One point per finished epoch.
plt.plot(list(range(len(epoch_train_losses))),
         epoch_train_losses, label='Training Loss')
# plt.plot(list(range(len(valid_bboxes_losses))), valid_bboxes_losses, label='Validation Loss')
plt.title('Mean epoch loss')
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.legend()
plt.savefig(plot_path)
plt.show()