In [1]:
#!pip3 install torch torchvision torchaudio --upgrade -q
#!pip uninstall nvidia_cublas_cu11 -y -q

In [8]:
# Dataframe manipulation
import pandas as pd
import numpy as np
import math

# PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# TorchVision
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2
import torchvision as tv
from torchvision.transforms import functional as TF

# Sklearn
from sklearn.model_selection import train_test_split

# Tracking
import wandb

# Images manipulation
from skimage import measure
import nibabel as nib
import pydicom as dicom
import cv2 as cv
from PIL import Image

# Visualization
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# System interaction
import os
import gc
from pathlib import Path
import io

# Parallelization & time tests
from multiprocessing.pool import ThreadPool, Pool
import threading
import time

# Other
import base64
import warnings
warnings.filterwarnings('ignore')
NoneType = type(None)

In [9]:
# Seed every random number generator used in this notebook so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)
torch.cuda.manual_seed(42)
# cuDNN autotuning (benchmark=True) may select different kernels from run to run,
# which works against deterministic=True, so it is switched off here.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [10]:
# Load variables from the local `.env` file into the process environment
# (presumably the W&B credentials used by `wandb.login()` below — verify).
import dotenv
dotenv.load_dotenv('.env')

True

In [24]:
wandb.login()

# Initialize new experiment
wandb.init(
    project = 'test-logger',#'rsna-mask-r-cnn',
    notes = """
        Mask RCNN: Semantic
        Optimizer: Adam/0.002/None/0.00001
        Scheduler: CosineLR/0.0002
        Backbone: RN101-FPN from DLV3 (COCO)
        Anchors: (8,), (16,), (32,), (64,), (128,)
        Strategy: Train all
        Augmentation: None
    """,
    tags = ['baseline', 'mask-rcnn', 'kaggle'],
#          resume = 'must',
#          id = '1c6ogt8c',
    save_code = True
)

# Initialize config for this experiment.
# NOTE(review): assigning a plain dict replaces the `wandb.config` attribute, so these
# values are only readable locally via `wandb.config[...]`; confirm whether they should
# also be tracked with the run (e.g. through `wandb.init(config=...)`).
wandb.config = {
    'batch_size' : 8,
    'val_batch_size' : 8,
    'num_workers' : os.cpu_count(),
    # Only for code test. Training ratio: 0.935/0.033/0.032
    'train_size' : 0.34,
    'val_size' : 0.33,
    'test_size' : 0.33,
    'device' : 'cuda:0' if torch.cuda.is_available() else 'cpu',
    # Fixed: a stray pasted identifier after the literal made this cell a syntax error
    'n_epochs' : 10
}



In [25]:
# Sanity check: show which device the experiment config selected
print(wandb.config['device'])

cuda:0


## Dataset preparation

In [26]:
# UIDs of the studies that have a segmentation file: each file name with '.nii' removed
available_uids = []
for segmentation_file in os.listdir('../data/rsna-2022-cervical-spine-fracture-detection/segmentations/'):
    available_uids.append(segmentation_file.replace('.nii', ''))

In [27]:
# Per-slice metadata, restricted to the studies listed in `available_uids`
meta_segm = pd.read_csv('../data/rsna-2022-spine-fracture-detection-metadata/meta_segmentation_clean.csv')
has_segmentation = meta_segm['StudyInstanceUID'].isin(available_uids)
meta_segm = meta_segm.loc[has_segmentation]
meta_segm.head()

Unnamed: 0,StudyInstanceUID,Slice,ImageHeight,ImageWidth,SliceThickness,ImagePositionPatient_x,ImagePositionPatient_y,ImagePositionPatient_z,C1,C2,C3,C4,C5,C6,C7
0,1.2.826.0.1.3680043.10633,1,512,512,1.0,-68.0,98.0,314.099976,0,0,0,0,0,0,0
1,1.2.826.0.1.3680043.10633,2,512,512,1.0,-68.0,98.0,313.599976,0,0,0,0,0,0,0
2,1.2.826.0.1.3680043.10633,3,512,512,1.0,-68.0,98.0,313.099976,0,0,0,0,0,0,0
3,1.2.826.0.1.3680043.10633,4,512,512,1.0,-68.0,98.0,312.599976,0,0,0,0,0,0,0
4,1.2.826.0.1.3680043.10633,5,512,512,1.0,-68.0,98.0,312.099976,0,0,0,0,0,0,0


In [28]:
# Keep only the rows where at least one column from position 8 onward is non-zero
# (presumably the C1..C7 flags shown in the preview above — TODO confirm the offset).
# `axis` is passed by keyword because pandas >= 2.0 rejects the positional form `any(1)`.
meta_segm = meta_segm[(meta_segm.iloc[:, 8:] != 0).any(axis=1)]

In [29]:
# PATH VARS

# DICOM slices and segmentation volumes live under the same competition folder
_competition_root = Path('../data/rsna-2022-cervical-spine-fracture-detection')
dicom_path = _competition_root / 'train_images'
segm_path = _competition_root / 'segmentations'
# Folder used for saved checkpoints
checkpoint_path = Path('checkpoints')

In [30]:
# True means need to flip Z axis, False otherwise.
# The decision compares the z component of ImagePositionPatient of slice 10 and slice 20.
orientation_check = {}
for uid in meta_segm.StudyInstanceUID.unique():
    study_dir = dicom_path / uid
    z_slice_10 = dicom.dcmread(study_dir / '10.dcm').ImagePositionPatient[2]
    z_slice_20 = dicom.dcmread(study_dir / '20.dcm').ImagePositionPatient[2]
    orientation_check[uid] = True if (z_slice_10 - z_slice_20) > 0 else False

In [31]:
# Load every segmentation volume into memory, keyed by StudyInstanceUID.
masks = {}
for uid in meta_segm.StudyInstanceUID.unique():
    nifti_image = nib.load(segm_path / (uid + '.nii'))
    # `get_data()` was deprecated and then removed from nibabel; reading the array
    # through `dataobj` is the documented replacement.
    mask = np.asanyarray(nifti_image.dataobj)
    # Reverse the slice axis for the studies flagged in `orientation_check`
    if orientation_check[uid]:
        mask = mask[:, :, ::-1]
    # Rotate each slice by 90 degrees in the (0, 1) plane
    # (presumably to match the DICOM pixel layout — verify).
    mask = np.rot90(mask, k=1, axes=(0, 1))
    masks[uid] = mask

In [33]:
"""
Split patients into train/val/test
"""

# The split is done per study (UID), so all rows of one study land in the same subset.
all_UIDs = meta_segm.StudyInstanceUID.unique()

# First cut: train vs. (val + test)
holdout_share = wandb.config['val_size'] + wandb.config['test_size']
train_UIDs, test_val_UIDs = train_test_split(all_UIDs, test_size=holdout_share, random_state=42)

# Second cut: divide the held-out UIDs into val and test
test_share_of_holdout = wandb.config['test_size'] / (wandb.config['test_size'] + wandb.config['val_size'])
val_UIDs, test_UIDs = train_test_split(test_val_UIDs, test_size=test_share_of_holdout, random_state=42)

print(f'Number of UIDs in train: {len(train_UIDs)}')
print(f'Number of UIDs in val: {len(val_UIDs)}')
print(f'Number of UIDs in test: {len(test_UIDs)}')

# Downstream cells work with plain Python lists
train_UIDs = train_UIDs.tolist()
val_UIDs = val_UIDs.tolist()
test_UIDs = test_UIDs.tolist()

Number of UIDs in train: 1
Number of UIDs in val: 1
Number of UIDs in test: 1


In [34]:
# Show the UIDs that ended up in the validation split
print(val_UIDs)

['1.2.826.0.1.3680043.10633']


In [35]:
"""
Write train/val/test splits
"""

def write_UIDs_in_txt(UIDs, txt_name):
    """Write the UIDs to `txt_name`, one per line, with no newline after the last one.

    An existing file is overwritten.
    """
    with open(txt_name, 'w') as handle:
        content = '\n'.join(UIDs)
        handle.write(content)
        
def read_UIDs_from_txt(txt_name):
    """Read UIDs from `txt_name`, one per line.

    Blank lines and surrounding whitespace are dropped, so an empty file yields an
    empty list and a trailing newline does not produce a phantom empty UID.

    Returns:
        list[str]: the UIDs in file order.
    """
    with open(txt_name, 'r') as file:
        stripped_lines = (line.strip() for line in file.read().splitlines())
        return [uid for uid in stripped_lines if uid]

# When train on kaggle, must be True
write_files = True
# Persist the split computed above so it can be reloaded from disk.
if write_files:   # switch to True to rewrite train/val/test split
    write_UIDs_in_txt(train_UIDs, 'train_UIDs.txt')
    write_UIDs_in_txt(val_UIDs, 'val_UIDs.txt')
    write_UIDs_in_txt(test_UIDs, 'test_UIDs.txt')
    write_files = False

# Always read the split back from the files, so the in-memory lists match what is on disk.
train_UIDs = read_UIDs_from_txt('train_UIDs.txt')
val_UIDs = read_UIDs_from_txt('val_UIDs.txt')
test_UIDs = read_UIDs_from_txt('test_UIDs.txt')

# Log these files
# NOTE: the upload calls below are disabled by being wrapped in a string literal;
# that string is the value this cell displays.
"""wandb.save('./train_UIDs.txt')
wandb.save('./val_UIDs.txt')
wandb.save('./test_UIDs.txt')"""

"wandb.save('./train_UIDs.txt')\nwandb.save('./val_UIDs.txt')\nwandb.save('./test_UIDs.txt')"

In [36]:
# Select the metadata rows of each subset by study UID
study_uid = meta_segm.StudyInstanceUID
train_df = meta_segm.loc[study_uid.isin(train_UIDs)]
val_df = meta_segm.loc[study_uid.isin(val_UIDs)]
test_df = meta_segm.loc[study_uid.isin(test_UIDs)]

print(f'Shape of train dataframe: {train_df.shape}')
print(f'Shape of val dataframe: {val_df.shape}')
print(f'Shape of test dataframe: {test_df.shape}')

Shape of train dataframe: (129, 15)
Shape of val dataframe: (222, 15)
Shape of test dataframe: (195, 15)


## Custom PyTorch Dataset

In [37]:
"""
Custom PyTorch Dataset
"""

class RSNADataset(Dataset):
    """Sequential batch loader for DICOM slices and their vertebra masks.

    Each row of ``dataframe`` identifies one slice: column 0 is used as the study UID
    and column 1 as the slice number (the DICOM file is ``<number>.dcm`` and the mask
    slice index is ``number - 1``).

    Unlike a regular map-style ``Dataset``, ``__getitem__`` takes no index: every call
    advances an internal batch counter and returns the next ``(images, targets)`` pair,
    or ``(None, None)`` once all full batches have been consumed. Call
    ``reset_batch()`` to reshuffle and start over.
    """

    def __init__(self, dataframe, dicom_path, segm_path, batch_size, num_workers, device, masks, transform = None) -> None:
        """Store the configuration; no data is loaded here.

        Args:
            dataframe: slice-level metadata (see class docstring for the columns used).
            dicom_path: root folder with one sub-folder of ``.dcm`` files per study.
            segm_path: folder with segmentation files (stored, not read by this class).
            batch_size: number of dataframe rows consumed per batch.
            num_workers: stored only; loading in this class is sequential.
            device: device the image and target tensors are moved to.
            masks: mapping ``study UID -> 3-D mask array`` indexed as ``[:, :, slice]``.
            transform: optional dict with the keys ``p_original``, ``p_hflip``,
                ``p_affine`` and ``p_cjitter``; a falsy value disables augmentation.
        """
        super(RSNADataset, self).__init__()
        self.dataframe = dataframe
        self.dicom_path = dicom_path
        self.segm_path = segm_path
        self.masks = masks

        # Index of the last batch handed out; -1 means "nothing served yet"
        self.batch_index = -1

        self.batch_size = batch_size
        self.device = device
        self.num_workers = num_workers
        self.transform = transform

        self.timer = 0

        # These variables are used to keep data of mask and save computational resources.
        self.current_mask_UID = None
        self.current_mask_data = None

    def load_dicom(self, path):
        """Read a DICOM file and return ``(rgb_uint8_image, dicom_dataset)``.

        Pixel values are min-max scaled to 0..255 and replicated to three channels.
        """
        # Source: https://www.kaggle.com/code/vslaykovsky/pytorch-effnetv2-vertebrae-detection-acc-0-95
        img = dicom.dcmread(path)
        img.PhotometricInterpretation = 'YBR_FULL'
        data = img.pixel_array
        data = data - np.min(data)
        # A constant image has max 0 after the shift; skip the division in that case
        if np.max(data) != 0:
            data = data / np.max(data)
        data = (data * 255).astype(np.uint8)
        return cv.cvtColor(data, cv.COLOR_GRAY2RGB), img

    def generate_data_batch(self, ind):
        """Build batch number ``ind``.

        Returns:
            ``(images, targets)`` where ``images`` is a float tensor on ``self.device``
            and ``targets`` is a list with one dict (``boxes``, ``labels``, ``masks``)
            per image. If no row of the batch produced a usable sample, the next batch
            is returned instead (or ``(None, None)`` at the end of the data).
        """
        # Targets - y
        # Targets is array of dictionaries. Each dictionary contains masks, bounding boxes and labels for 1 record of patient data
        self.targets = []

        # Imgs - X
        # Imgs contains DICOM images. The shape is (batch_size, 3, 512, 512)
        self.imgs = torch.empty((0, 3, 512, 512), dtype=torch.float32)

        # Get records from dataframe about what data should be loaded in the current batch
        rows = self.dataframe.iloc[ind*self.batch_size:(ind+1)*self.batch_size, :].to_numpy()

        # Rows are processed one by one; each usable row appends to self.imgs / self.targets
        for row in rows:
            self.extract_data_from_batch(row)

        if self.imgs.shape[0] == 0:
            print('no imgs')
            return self.__getitem__()

        # Return batch data, e.g. X and y
        return self.imgs.to(self.device, non_blocking=True), self.targets

    @staticmethod
    def _should_apply(probability):
        """Return True with the given probability (0 -> never, 1 -> always)."""
        return np.random.random() < probability

    def transform_function(self, img, mask):
        """Randomly augment an image and its mask with identical geometric transforms.

        With probability ``p_original`` the pair is returned unchanged. Otherwise the
        horizontal flip, affine transform and colour jitter are each applied with
        probability ``p_hflip``, ``p_affine`` and ``p_cjitter`` respectively (colour
        jitter touches the image only).

        Returns:
            ``(img, mask)`` as numpy arrays.
        """
        transform_config = self.transform

        img, mask = TF.to_pil_image(img), TF.to_pil_image(mask.astype(np.uint8))

        if not self._should_apply(transform_config['p_original']):
            # Fixed: the comparisons used to be `p < random()`, which applied each
            # augmentation with probability 1 - p (p = 0 meant "always").

            # Horizontal flip
            if self._should_apply(transform_config['p_hflip']):
                img, mask = TF.hflip(img), TF.hflip(mask)

            # Affine: the same sampled parameters are used for image and mask
            if self._should_apply(transform_config['p_affine']):
                affine_params = tv.transforms.RandomAffine(30).get_params((-15,15), (0.1, 0.1), (1, 1), (-15, 15), (512, 512))
                img = TF.affine(img, *affine_params)
                mask = TF.affine(mask, *affine_params)

            # Colour jitter
            if self._should_apply(transform_config['p_cjitter']):
                img = tv.transforms.ColorJitter(*(0.3, 0.3, 0.1, 0.1))(img)

        return np.asarray(img), np.asarray(mask)

    def extract_data_from_batch(self, batch):
        """Load one slice and append its image/target to ``self.imgs`` / ``self.targets``.

        Args:
            batch: one dataframe row as an array; ``batch[0]`` is the study UID and
                ``batch[1]`` the slice number.

        Returns:
            None. The slice is silently skipped when it has fewer than two labels or
            when no object survives the bounding-box filters.
        """
        # Define initial structure of y
        target = {
            'boxes':[],
            'labels':None,
            'masks':[]
        }

        # Get mask from array of masks
        mask = self.masks[batch[0]][:, :, batch[1] - 1]

        # Get vertebrae numbers then is presented on mask. C1 - 1, C2 - 2 ...
        # Background (0) is removed by value instead of assuming it is the first entry.
        labels = np.unique(mask)
        labels = labels[labels != 0]

        labels_to_dict = []

        # NOTE(review): slices with fewer than two labels are skipped, i.e. slices with
        # exactly one vertebra are dropped too — confirm that this is intended.
        if len(labels) < 2:
            return None

        # Load Dicom image
        img = self.load_dicom(self.dicom_path / batch[0] / (str(batch[1]) + '.dcm'))[0]

        # Apply transformations
        if self.transform:
            img, mask = self.transform_function(img, mask)

        for label in labels:
            label_mask = (mask == label).astype(np.uint8)

            x, y, w, h = cv.boundingRect(label_mask)

            # An all-zero rectangle means the label has no pixels left in the mask
            # (possible after augmentation)
            if sum([x, y, w, h]) == 0:
                print('cv.boundingRect not found mask')
                continue
            # Ignore objects smaller than 5 px in both dimensions
            elif w <  5 and h < 5:
                continue

            if label.item() < 8:
                labels_to_dict.append(label.item())
            else:
                # 0 is the class for other vertebraes that can be faced on the mask with C1-C7
                # NOTE(review): torchvision detection models treat label 0 as background — verify.
                labels_to_dict.append(0)
            target['masks'].append(label_mask)
            target['boxes'].append(np.array([x, y, x+w, y+h]))

        # Every object was filtered out: skip the slice rather than emit a target whose
        # boxes tensor has shape (0,) instead of (N, 4).
        if not target['boxes']:
            return None

        # Convert to tensor
        target['labels'] = torch.from_numpy(np.array(labels_to_dict).astype(np.int64)).to(self.device, non_blocking=True)
        target['boxes'] = torch.from_numpy(np.array(target['boxes']).astype(np.float32)).to(self.device, non_blocking=True)
        target['masks'] = torch.from_numpy(np.array(target['masks']).astype(np.float32)).to(self.device, non_blocking=True)

        img = tv.transforms.ToTensor()(img).unsqueeze(dim=0)

        self.imgs = torch.cat([self.imgs, img], dim=0)
        self.targets.append(target)

    def __len__(self):
        """Number of full batches; a trailing partial batch is not counted."""
        return int(len(self.dataframe) / self.batch_size)

    # the internal counter of batches
    def reset_batch(self):
        """Shuffle the rows and rewind the batch counter."""
        self.dataframe = self.dataframe.sample(frac=1).reset_index(drop=True)
        self.batch_index = -1

    # Check are there batches or it's end of dataframe and we have to do next iteration
    def is_end(self):
        """Return True once the last full batch has been served."""
        return self.batch_index >= self.__len__() - 1

    def __getitem__(self):
        """Advance to the next batch and return it, or ``(None, None)`` at the end."""
        if self.is_end():
            return None, None

        self.batch_index += 1
        return self.generate_data_batch(self.batch_index)

In [38]:
# Shuffle the rows once before the datasets are built (train first, then val)
train_df = train_df.sample(frac=1).reset_index(drop = True)
val_df = val_df.sample(frac=1).reset_index(drop = True)

# Augmentation probabilities handed to RSNADataset
transform = {
    'p_original' : 1,
    'p_hflip' : 0,
    'p_affine' : 0,
    'p_cjitter' : 0
}

# Arguments shared by the train and validation datasets
_shared_dataset_args = dict(
    dicom_path=dicom_path,
    segm_path=segm_path,
    num_workers=wandb.config['num_workers'],
    device=wandb.config['device'],
    masks=masks,
    transform=transform,
)

train_dataset = RSNADataset(train_df, batch_size=wandb.config['batch_size'], **_shared_dataset_args)
val_dataset = RSNADataset(val_df, batch_size=wandb.config['val_batch_size'], **_shared_dataset_args)

In [39]:
# Row/column counts after shuffling; these should match the shapes printed for the split above
print(train_df.shape)
print(val_df.shape)

(129, 15)
(222, 15)


## Visualizer

In [40]:
plt.ioff()

class Visualizer:
    """Draws ground truth vs. predictions for 16 slices and logs the figure to W&B.

    Ground truth is drawn in ``gt_color`` and predictions in ``pr_color``; each panel
    title lists the ground-truth labels and the predicted labels above the confidence
    threshold.
    """

    def __init__(self, dataframe, masks, confidence_threshold = 0.15):
        """
        Args:
            dataframe: slice-level metadata to sample the visualised slices from.
            masks: mapping ``study UID -> mask volume`` forwarded to ``RSNADataset``.
            confidence_threshold: predictions with a score at or below it are not drawn.
        """
        self.dataframe = dataframe.sample(frac=1).reset_index(drop=True)
        self.confidence_threshold = confidence_threshold
        # Batch size 1 and no augmentation: one dataset "batch" is exactly one slice
        self.dataset = RSNADataset(self.dataframe, dicom_path, segm_path, 1, wandb.config['num_workers'], wandb.config['device'], masks, None)
        candidates = self._candidate_rows()
        # Stored as a list so single slots can be re-drawn later; sample with replacement
        # only when there are fewer than 16 candidates.
        self.ids_random = candidates.sample(16, replace=len(candidates) < 16).index.tolist()
        self.gt_color = [0, 255, 0, 128]
        self.pr_color = [0, 0, 255, 128]

    def _candidate_rows(self):
        """Rows with at least one non-zero value from column 8 onward."""
        return self.dataframe[(self.dataframe.iloc[:, 8:] != 0).any(axis=1)]

    def _resample_slot(self, slot):
        """Replace the row index stored for one panel with a newly sampled one."""
        self.ids_random[slot] = self._candidate_rows().sample(1).index[0]

    def visualize(self, model):
        """Run ``model`` on the sampled slices and log a 4x4 figure as ``my_plot``.

        The model is put into eval mode for inference and back into train mode afterwards.
        """
        model.eval()

        fig, axs = plt.subplots(4, 4, figsize = (20, 20))

        try:
            i = 0
            # Upper bound on load attempts, so unusable slices cannot cause an endless loop
            attempts_left = 16 * 10

            while i < 16 and attempts_left > 0:
                attempts_left -= 1

                # __getitem__ increments the counter before loading, so point it one
                # position before the row that should be shown.
                self.dataset.batch_index = self.ids_random[i] - 1

                X, y = self.dataset.__getitem__()
                if X is None or len(y) == 0:
                    # Only this panel's index is re-drawn; the other 15 stay as they are
                    self._resample_slot(i)
                    continue

                # Inference only: no autograd graph is needed
                with torch.no_grad(), torch.cuda.amp.autocast():
                    preds = model(X)

                if len(preds) == 0:
                    self._resample_slot(i)
                    continue

                preds = preds[0]
                y = y[0]

                ax = axs[i//4, i%4]

                # Orig image
                img = np.transpose((X.squeeze(dim = 0).cpu().numpy() * 255.0).astype(np.uint8), (1, 2, 0)).copy()

                # Scores
                scores = preds['scores'].detach().cpu().numpy().astype(np.float32)
                score_inds = np.where(scores > self.confidence_threshold)[0]

                # GT Bounding boxes
                gt_bounding_boxes = y['boxes'].cpu().numpy().astype(np.int32).tolist()
                pr_bounding_boxes = preds['boxes'].detach().cpu().numpy().astype(np.int32)[score_inds, :].tolist()

                for contour in gt_bounding_boxes:
                    if len(contour) == 0:
                        continue
                    cv.rectangle(img, (contour[0], contour[1]), (contour[2], contour[3]), self.gt_color[:3], 2)

                for contour in pr_bounding_boxes:
                    if len(contour) == 0:
                        continue
                    cv.rectangle(img, (contour[0], contour[1]), (contour[2], contour[3]), self.pr_color[:3], 2)

                # Get GT Masks
                gt_masks = y['masks'].sum(axis=0).cpu().numpy().astype(np.uint8)
                # Predicted masks are soft (values in 0..1): binarise each at 0.5 and merge.
                # Casting the raw sum to uint8 would truncate almost every pixel to 0.
                pr_masks = (preds['masks'][score_inds, :, :, :] > 0.5).any(dim=0).squeeze().detach().cpu().numpy().astype(np.uint8)

                # Convert numpy array of image into PIL Image object with alpha channel
                img_pil = Image.fromarray(cv.cvtColor(img, cv.COLOR_RGB2RGBA))

                # Build RGBA overlays: the mask is replicated to 4 channels and scaled by the colour
                gt_mask_pil = cv.cvtColor(gt_masks, cv.COLOR_GRAY2RGB)
                pr_mask_pil = cv.cvtColor(pr_masks, cv.COLOR_GRAY2RGB)

                gt_mask_pil = np.dstack([gt_mask_pil, gt_mask_pil[:, :, 0]])
                pr_mask_pil = np.dstack([pr_mask_pil, pr_mask_pil[:, :, 0]])

                gt_mask_pil = gt_mask_pil * np.array(self.gt_color)
                pr_mask_pil = pr_mask_pil * np.array(self.pr_color)

                gt_mask_pil = Image.fromarray(gt_mask_pil.astype(np.uint8))
                pr_mask_pil = Image.fromarray(pr_mask_pil.astype(np.uint8))

                img_pil.paste(gt_mask_pil, mask = gt_mask_pil)
                img_pil.paste(pr_mask_pil, mask = pr_mask_pil)

                # Plot image
                ax.imshow(img_pil)

                # Labels
                gt_labels = y['labels'].cpu().numpy().tolist()
                pr_labels = preds['labels'][score_inds].detach().cpu().numpy().tolist()

                ax.set_title(f'{gt_labels}|{pr_labels if len(pr_labels) < 5 else str(pr_labels)[:16] + "..."}')

                i += 1

            wandb.log({"my_plot": fig})
        finally:
            # Release the figure (interactive mode is off, so it would otherwise stay
            # alive) and always hand the model back in training mode.
            plt.close(fig)
            model.train()

## Evaluator

In [41]:
class Evaluator:
    """Accumulates metrics between uploads to W&B and provides mask-IoU based AP helpers.

    ``log`` sums metric values, ``send_logs`` uploads their mean under ``prefix`` and
    resets the accumulator. The AP helpers are adapted from the Matterport Mask R-CNN code.
    """

    def __init__(self, prefix, metrics_to_log):
        """
        Args:
            prefix: string prepended to every metric name when it is uploaded.
            metrics_to_log: iterable of metric names to accumulate (only the keys are
                used when a dict is passed).
        """
        self.metrics_to_log = metrics_to_log
        self.prefix = prefix
        self.reset()

    def reset(self):
        """Zero all accumulators and the update counter."""
        self.update_counter = 0
        self.metrics = {key: 0 for key in self.metrics_to_log}

    def log(self, metrics):
        """Add one set of metric values; every key must be one of ``metrics_to_log``."""
        self.update_counter += 1
        for key in metrics.keys():
            self.metrics[key] += metrics[key]

    def send_logs(self):
        """Upload the mean of the accumulated metrics, then reset. No-op if nothing was logged."""
        if self.update_counter != 0:

            self.metrics = {self.prefix + key : (value / self.update_counter) for key, value in self.metrics.items()}
            wandb.log(self.metrics)

            self.reset()

    # Code source: https://github.com/matterport/Mask_RCNN/blob/master/samples/shapes/train_shapes.ipynb
    def compute_overlaps_masks(self, masks1, masks2):
        """Computes IoU overlaps between two sets of masks.
        masks1, masks2: [Height, Width, instances]

        Returns an ``[instances1, instances2]`` array. A pair whose union is empty
        gets IoU 0 rather than NaN.
        """

        # If either set of masks is empty return empty result
        if masks1.shape[-1] == 0 or masks2.shape[-1] == 0:
            return np.zeros((masks1.shape[-1], masks2.shape[-1]))
        # flatten masks and compute their areas
        masks1 = np.reshape(masks1 > .5, (-1, masks1.shape[-1])).astype(np.float32)
        masks2 = np.reshape(masks2 > .5, (-1, masks2.shape[-1])).astype(np.float32)
        area1 = np.sum(masks1, axis=0)
        area2 = np.sum(masks2, axis=0)

        # intersections and union
        intersections = np.dot(masks1.T, masks2)
        union = area1[:, None] + area2[None, :] - intersections
        # 0/0 would give NaN, and NaN slips through the `iou < threshold` test in
        # compute_matches and gets counted as a match.
        overlaps = np.divide(intersections, union, out=np.zeros_like(union), where=union > 0)

        return overlaps

    def trim_zeros(self, x):
        """It's common to have tensors larger than the available data and
        pad with zeros. This function removes rows that are all zeros.

        x: [rows, columns].
        """
        assert len(x.shape) == 2
        return x[~np.all(x == 0, axis=1)]

    def compute_matches(self, gt_boxes, gt_class_ids, gt_masks,
                    pred_boxes, pred_class_ids, pred_scores, pred_masks,
                    iou_threshold=0.5, score_threshold=0.0):
        """Finds matches between prediction and ground truth instances.

        Returns:
            gt_match: 1-D array. For each GT box it has the index of the matched
                      predicted box.
            pred_match: 1-D array. For each predicted box, it has the index of
                        the matched ground truth box.
            overlaps: [pred_boxes, gt_boxes] IoU overlaps.
        """
        # Trim zero padding
        # TODO: cleaner to do zero unpadding upstream
        # Fixed: sibling helpers must be called through `self` (they are methods here,
        # not module-level functions as in the original source).
        gt_boxes = self.trim_zeros(gt_boxes)
        gt_masks = gt_masks[..., :gt_boxes.shape[0]]
        pred_boxes = self.trim_zeros(pred_boxes)
        pred_scores = pred_scores[:pred_boxes.shape[0]]
        # Sort predictions by score from high to low
        indices = np.argsort(pred_scores)[::-1]
        pred_boxes = pred_boxes[indices]
        pred_class_ids = pred_class_ids[indices]
        pred_scores = pred_scores[indices]
        pred_masks = pred_masks[..., indices]

        # Compute IoU overlaps [pred_masks, gt_masks]
        overlaps = self.compute_overlaps_masks(pred_masks, gt_masks)

        # Loop through predictions and find matching ground truth boxes
        match_count = 0
        pred_match = -1 * np.ones([pred_boxes.shape[0]])
        gt_match = -1 * np.ones([gt_boxes.shape[0]])
        for i in range(len(pred_boxes)):
            # Find best matching ground truth box
            # 1. Sort matches by score
            sorted_ixs = np.argsort(overlaps[i])[::-1]
            # 2. Remove low scores
            low_score_idx = np.where(overlaps[i, sorted_ixs] < score_threshold)[0]
            if low_score_idx.size > 0:
                sorted_ixs = sorted_ixs[:low_score_idx[0]]
            # 3. Find the match
            for j in sorted_ixs:
                # If ground truth box is already matched, go to next one
                if gt_match[j] > -1:
                    continue
                # If we reach IoU smaller than the threshold, end the loop
                iou = overlaps[i, j]
                if iou < iou_threshold:
                    break
                # Do we have a match?
                if pred_class_ids[i] == gt_class_ids[j]:
                    match_count += 1
                    gt_match[j] = i
                    pred_match[i] = j
                    break

        return gt_match, pred_match, overlaps

    def compute_ap(self, gt_boxes, gt_class_ids, gt_masks,
               pred_boxes, pred_class_ids, pred_scores, pred_masks,
               iou_threshold=0.5):
        """Compute Average Precision at a set IoU threshold (default 0.5).

        Returns:
        mAP: Mean Average Precision
        precisions: List of precisions at different class score thresholds.
        recalls: List of recall values at different class score thresholds.
        overlaps: [pred_boxes, gt_boxes] IoU overlaps.
        """
        # Get matches and overlaps
        gt_match, pred_match, overlaps = self.compute_matches(
            gt_boxes, gt_class_ids, gt_masks,
            pred_boxes, pred_class_ids, pred_scores, pred_masks,
            iou_threshold)

        # Compute precision and recall at each prediction box step
        precisions = np.cumsum(pred_match > -1) / (np.arange(len(pred_match)) + 1)
        recalls = np.cumsum(pred_match > -1).astype(np.float32) / len(gt_match)

        # Pad with start and end values to simplify the math
        precisions = np.concatenate([[0], precisions, [0]])
        recalls = np.concatenate([[0], recalls, [1]])

        # Ensure precision values decrease but don't increase. This way, the
        # precision value at each recall threshold is the maximum it can be
        # for all following recall thresholds, as specified by the VOC paper.
        for i in range(len(precisions) - 2, -1, -1):
            precisions[i] = np.maximum(precisions[i], precisions[i + 1])

        # Compute mean AP over recall range
        indices = np.where(recalls[:-1] != recalls[1:])[0] + 1
        mAP = np.sum((recalls[indices] - recalls[indices - 1]) *
                     precisions[indices])

        return mAP, precisions, recalls, overlaps

    def compute_ap_range(self, gt_box, gt_class_id, gt_mask,
                     pred_box, pred_class_id, pred_score, pred_mask,
                     iou_thresholds=None, verbose=1):
        """Compute AP over a range or IoU thresholds. Default range is 0.5-0.95.

        Returns:
            AP: list with one AP value per threshold.
            mAP: mean of those values.
        """
        # Default is 0.5 to 0.95 with increments of 0.05.
        # An explicit None/empty check is used because `array or default` raises for
        # numpy arrays with more than one element.
        if iou_thresholds is None or len(iou_thresholds) == 0:
            iou_thresholds = np.arange(0.5, 1.0, 0.05)

        # Compute AP over range of IoU thresholds
        AP = []
        for iou_threshold in iou_thresholds:
            ap, precisions, recalls, overlaps =\
                self.compute_ap(gt_box, gt_class_id, gt_mask,
                                pred_box, pred_class_id, pred_score, pred_mask,
                                iou_threshold=iou_threshold)
            AP.append(ap)
        mAP = np.array(AP).mean()
        return AP, mAP

## Model implementation

In [42]:
# Take the ResNet-101 backbone weights from a COCO-pretrained DeepLabV3 model ...
coco_weights = tv.models.segmentation.DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1
dlv3 = tv.models.segmentation.deeplabv3.deeplabv3_resnet101(weights = coco_weights)
dlv3_backbone_state = dlv3.get_submodule('backbone').state_dict()
dlv3_keys = list(dlv3_backbone_state.keys())

# ... and load them into a plain ResNet-101 whose classification layer was removed,
# so that the two state dicts have the same keys.
rn101_original = tv.models.resnet.resnet101()
del rn101_original.fc
rn101_keys = list(rn101_original.state_dict().keys())

rn101_original.load_state_dict(dlv3_backbone_state)

Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /home/agavrilko/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth
100.0%


<All keys matched successfully>

In [43]:
from torchvision.models.detection.backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from torchvision.models.resnet import resnet101
from torchvision.models.detection.mask_rcnn import MaskRCNN, MaskRCNNHeads
from torchvision.models.detection.faster_rcnn import _default_anchorgen, RPNHead, FastRCNNConvFCHead

# Building blocks for the custom Mask R-CNN assembled in the next cell.
trainable_backbone_layers = None
# The dataset emits labels 0..7 (1-7 for C1-C7, 0 for any other vertebra)
num_classes = 8
is_trained = False

# Resolve how many backbone stages are trainable (arguments: max value 5, default 3)
trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

# Backbone
# ResNet-101 carrying the DeepLabV3 weights, wrapped with an FPN
backbone = rn101_original
backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers, norm_layer = nn.BatchNorm2d)

# Anchor generator
# One anchor size per feature level, three aspect ratios each
anchor_sizes = ((8,), (16,), (32,), (64,), (128,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

# RPN Module
rpn_head = RPNHead(backbone.out_channels, rpn_anchor_generator.num_anchors_per_location()[0], conv_depth=2)

# Box Module
# NOTE(review): box features are pooled from feature map "0" only, while the mask
# branch below pools from "0".."3" — confirm this asymmetry is intentional.
box_roi_pool = MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2)
box_head = FastRCNNConvFCHead(
        (backbone.out_channels, 7, 7), [256, 256, 256, 256], [1024], norm_layer=nn.BatchNorm2d
)

# Mask Module
mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
mask_head = MaskRCNNHeads(backbone.out_channels, [256, 256, 256, 256], 1, norm_layer=nn.BatchNorm2d)

In [44]:
# Assemble Mask R-CNN from the backbone, RPN, box and mask components built above.
model = MaskRCNN(
    backbone,
    num_classes=num_classes,
    
    # transform parameters ------------------
    # Bounds for the resize done by the model's internal transform
    min_size=500,
    max_size=600,
    image_mean=None,
    image_std=None,
    
    # RPN parameters -------------------------
    rpn_anchor_generator=rpn_anchor_generator,
    rpn_head=rpn_head,
    rpn_pre_nms_top_n_train=500 * 4,
    rpn_pre_nms_top_n_test=250 * 4,
    rpn_post_nms_top_n_train=500 * 4,
    rpn_post_nms_top_n_test=250 * 4,
    rpn_nms_thresh=0.7,
    rpn_fg_iou_thresh=0.7,
    rpn_bg_iou_thresh=0.3,
    rpn_batch_size_per_image=256,
    rpn_positive_fraction=0.5,
    rpn_score_thresh=0.0,
    
    # Box parameters ---------------
    box_roi_pool=box_roi_pool,
    box_head=box_head,
    
    # must be None when num_classes specified
    box_predictor=None,
    
    box_score_thresh=0.05,
    box_nms_thresh=0.5,
    box_detections_per_img=100,
    box_fg_iou_thresh=0.5,
    box_bg_iou_thresh=0.5,
    box_batch_size_per_image=512,
    box_positive_fraction=0.25,
    bbox_reg_weights=None,
    
    # Mask parameters --------------------------
    mask_roi_pool=mask_roi_pool,
    mask_head=mask_head,
    
    # must be None when num_classes specified
    mask_predictor=None
)

In [45]:
# Download the COCO-pretrained Mask R-CNN (ResNet50-FPN v2) checkpoint.
# Note: this runs on every execution of the cell, even if the file already exists.
!wget https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth

--2023-05-06 23:46:31--  https://download.pytorch.org/models/maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth
Resolving download.pytorch.org (download.pytorch.org)... 13.33.243.80, 13.33.243.23, 13.33.243.91, ...
Connecting to download.pytorch.org (download.pytorch.org)|13.33.243.80|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 185828065 (177M) [application/x-www-form-urlencoded]
Saving to: ‘maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth’


2023-05-06 23:48:03 (1.93 MB/s) - ‘maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth’ saved [185828065/185828065]



In [46]:
# Load the downloaded checkpoint; the following cells use it as a mapping of
# parameter name -> tensor.
weights = torch.load('maskrcnn_resnet50_fpn_v2_coco-73cbd019.pth')

In [47]:
# Parameter names and shapes on both sides
model_shapes = {name : tensor.shape for name, tensor in model.state_dict().items()}
weight_shapes = {name : tensor.shape for name, tensor in weights.items()}

model_keys = set(model.state_dict().keys())
weight_keys = set(weights.keys())

# Collect the parameters that can be copied from the checkpoint: everything outside
# `backbone.body` that exists in both state dicts with an identical shape.
mapping = []

for key in model_keys:
    scope = key.split('.')
    if scope[0] == 'backbone' and scope[1] == 'body':
        continue
    if key not in weight_keys:
        continue
    if weight_shapes[key] != model_shapes[key]:
        continue
    mapping.append(key)

In [48]:
# Overwrite the selected entries of the model's state dict with the checkpoint tensors,
# then load the merged dict back into the model.
model_dict = model.state_dict()
model_dict.update({key: weights[key] for key in mapping})

model.load_state_dict(model_dict)

<All keys matched successfully>

In [49]:
# api = wandb.Api()
# model_weights = api.artifact('hitogamiag/test-logger/chkp_mrcnn_10:v10')
# model_weights.download()

# model_weights = torch.load('artifacts/' + model_weights.name + '/chkp_mrcnn_10.pth')
# model_weights = {key.replace('module.', '') : value for key, value in model_weights.items()}

# model.load_state_dict(model_weights)

# Utils

In [50]:
def checkpoint(checkpoint_path, model_name, model_to_save=None):
    """Save a model's weights locally and upload them to W&B as a 'checkpoint' artifact.

    Args:
        checkpoint_path: `Path` of the folder for checkpoints; created (including
            missing parents) if it does not exist.
        model_name: artifact name, also used as the file name (`<model_name>.pth`).
        model_to_save: module whose `state_dict()` is saved. Defaults to the
            notebook-level `model`, which is what this function always used before.
    """
    if model_to_save is None:
        model_to_save = model

    # Also works when parent folders are missing or the folder already exists
    checkpoint_path.mkdir(parents=True, exist_ok=True)

    # Save trained model weights
    save_path = checkpoint_path / (model_name + '.pth')
    torch.save(model_to_save.state_dict(), save_path)

    # Upload them on wandb (the path is passed as a plain string)
    artifact = wandb.Artifact(model_name, type='checkpoint')
    artifact.add_file(str(save_path))
    wandb.log_artifact(artifact)
    print(f'Logged {model_name}')

# Training

In [51]:
#optimizer = torch.optim.SGD(params=model.parameters(), lr=0.0002, momentum=0.9, weight_decay=0.00001)

# Total number of iterations over all epochs; used as the cosine schedule length (T_max)
max_num_of_iters = int(wandb.config['n_epochs'] * train_dataset.dataframe.shape[0] / wandb.config['batch_size'])

#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_num_of_iters, eta_min=0)
# AdamW starting at lr 0.002, annealed along a cosine curve down to 0.0002
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.002, weight_decay=0.00001)
#scheduler = torch.optim.lr_scheduler.StepLAR(optimizer, step_size=3, gamma=0.1, verbose=False)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_num_of_iters, eta_min=0.0002)

# Gradient scaler for mixed-precision training
scaler = torch.cuda.amp.GradScaler()

In [52]:
# Prepare the model, loggers, visualizer and logging/validation intervals for training.

# Head - 5 epochs; resnet4+ 10 epochs; all - 15 epochs; with lr decreased by 10 every stage
# Another option is to train only the head for 15 epochs
# Guard the wrap so that re-running this cell does not nest DataParallel inside DataParallel.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
model = model.to(f'{wandb.config["device"]}', non_blocking=True)

send_train_logs_each_n_batches = 10

# Create new loggers
# Cannot calculate mAPs for train because of dropout layers & batchnorm
# Both loggers track the same loss names; dict.fromkeys builds a separate
# zero-initialised dict for each of them.
logged_loss_names = (
    'loss_classifier',
    'loss_box_reg',
    'loss_mask',
    'loss_objectness',
    'loss_rpn_box_reg',
    'total_loss',
)

train_logger = Evaluator(prefix='train_', metrics_to_log=dict.fromkeys(logged_loss_names, 0))

val_logger = Evaluator(prefix='val_', metrics_to_log=dict.fromkeys(logged_loss_names, 0))

# Visualizer
visualizer = Visualizer(val_df, masks)

start_time = time.time()

model.train()

# Note: we don't use a DataLoader because it does not work correctly with dictionaries
# (and our "y" is a list of dictionaries)
# During the epoch, the model will be evaluated approx "validate_n_times_per_epoch" times
train_send_logs_each_n_iters = 5

batches_per_epoch = train_dataset.dataframe.shape[0] / train_dataset.batch_size

# The intervals are used as modulo divisors in the training loop, so they are
# clamped to at least 1: on a small dataset round() can return 0, which would
# raise ZeroDivisionError there.
validate_n_times_per_epoch = 5
validate_each_n_iters = max(1, round(batches_per_epoch / validate_n_times_per_epoch))

# Visualize n times per epoch
vizualize_n_times_per_epoch = 3
vizualize_each_n_iters = max(1, round(batches_per_epoch / vizualize_n_times_per_epoch))

def _loss_scalars(loss_dict, total):
    """Convert a loss dict and its summed total into plain floats for logging.

    Args:
        loss_dict: Mapping returned by the model call; must contain the five
            loss names listed below. Each entry is reduced with ``.sum()``
            (presumably one value per DataParallel replica — confirm).
        total: Scalar tensor holding the sum of all losses.

    Returns:
        dict mapping each loss name, plus 'total_loss', to a Python float.
    """
    scalars = {
        name: loss_dict[name].sum().item()
        for name in (
            'loss_classifier',
            'loss_box_reg',
            'loss_mask',
            'loss_objectness',
            'loss_rpn_box_reg',
        )
    }
    scalars['total_loss'] = total.item()
    return scalars


# Main loop: one pass over train_dataset per epoch, with periodic
# visualization and validation, and a checkpoint at the end of every epoch.
for epoch in range(wandb.config['n_epochs']):
    wandb.log({'epoch' : epoch + 1})
    train_dataset.reset_batch()
    while not train_dataset.is_end():
        X, y = train_dataset.__getitem__()
        if X is None:
            break

        # Clear gradients before the forward pass; set_to_none releases their
        # memory instead of keeping zero-filled tensors alive during the forward.
        optimizer.zero_grad(set_to_none=True)

        # TODO: call func to get predictions and losses at once
        with torch.cuda.amp.autocast():
            loss_dict = model(X, y)

        losses = sum(loss for loss in loss_dict.values()).sum()
        scaler.scale(losses).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()

        #warmup.step()
        # TODO: Calculate APs and mAP
        # Log iteration results
        train_logger.log(_loss_scalars(loss_dict, losses))

        if train_logger.update_counter == train_send_logs_each_n_iters:
            wandb.log({'lr' : optimizer.param_groups[0]['lr']})
            train_logger.send_logs()

        # Visualize model each N iterations
        if (train_dataset.batch_index + 1) % vizualize_each_n_iters == 0:
            visualizer.visualize(model)

        # Validate model each N iterations
        if (train_dataset.batch_index + 1) % validate_each_n_iters == 0:

            val_dataset.reset_batch()
            while not val_dataset.is_end():
                # Separate names so validation never overwrites the training batch/losses
                val_X, val_y = val_dataset.__getitem__()
                if val_X is None:
                    break

                # The model is called with targets exactly as in training to obtain
                # the loss dict, but no backward pass follows: no_grad stops autograd
                # from building and holding a graph for every validation batch.
                # TODO: call func to get predictions and losses at once
                with torch.no_grad(), torch.cuda.amp.autocast():
                    val_loss_dict = model(val_X, val_y)

                val_losses = sum(loss for loss in val_loss_dict.values()).sum()

                # TODO: Calculate APs and mAP
                # Log validation results
                val_logger.log(_loss_scalars(val_loss_dict, val_losses))
            val_logger.send_logs()

    train_logger.send_logs()
    visualizer.visualize(model)

    checkpoint(checkpoint_path, f'chkp_mrcnn_{epoch + 1}')

finish_time = time.time()

OutOfMemoryError: CUDA out of memory. Tried to allocate 172.00 MiB (GPU 0; 3.95 GiB total capacity; 3.15 GiB already allocated; 19.69 MiB free; 3.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [53]:
# Mark the current W&B run as finished.
wandb.finish()

0,1
epoch,▁

0,1
epoch,1
