In [1]:
import math
import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
import numpy as np
import pandas as pd
from tqdm import tqdm 
import pickle
import time
import random

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  
print("Use device:",device)

Use device: cuda:0


In [3]:
batch_size = 8
lr = 0.001
pretrain = False
mode = 'train'
epoch = 20
pid_range = [251,401]
train_range, valid_range, test_range = [251,351], [351,376], [376,401]
num_c = 1
no_nodule_patch_use_ratio = 0.05

rpn_pos_per_batch = []
rpn_neg_per_batch = []
val_rpn_pos_per_batch = []
val_rpn_neg_per_batch = []

# DataLoader

In [4]:

data_path = '/kaggle/input/lidcidri-250-500-continuous/data_continuous/data_continuous' 
pkl_path = '/kaggle/input/lidcidri-250-500-continuous' 
meta = pd.read_csv(f'{data_path}/Meta/meta_info.csv')
bbox_file = open(f'{pkl_path}/bbox.pkl', 'rb')
bbox_list = pickle.load(bbox_file)
all_data = {}

print("loading nodule slices...")
time.sleep(1)
for i in tqdm(range(len(meta.index))):
    
    pid = meta.iloc[i]['patient_id']
    if pid_range[0] <= pid < pid_range[1]:
        if pid not in all_data: all_data[pid] = {}

        slice = int(meta.iloc[i]['original_image'][-3:])
        all_data[pid][slice] = {}
        if meta.iloc[i]['is_clean'] == False: # nodule
            img = np.load(f'{data_path}/Image/LIDC-IDRI-' + meta.iloc[i]['original_image'][:4]\
                           + '/' + meta.iloc[i]['original_image'] + '.npy') 
        img = np.clip(img, -1200,1200)
        img = torch.Tensor(img)
        img = img.to(torch.float32) 
        min_value = img.min()
        max_value = img.max()    

        shifted_tensor = img - min_value   
        img = shifted_tensor / (max_value - min_value)

        all_data[pid][slice]["img"] = img
        all_data[pid][slice]["bbox"] = np.array(bbox_list[i])

print("loading other continuous slices...")
time.sleep(1)
for p in tqdm(os.listdir(f'{data_path}/Image')): # others
    pid = int(p[-4:])
    if pid_range[0] <= pid < pid_range[1]:
        for s in os.listdir(f'{data_path}/Image/{p}'):
            slice = int(s[-7:-4])
            if slice not in all_data[pid] and (slice+1 in all_data[pid] or slice-1 in all_data[pid]):
                all_data[pid][slice] = {}
                img = np.load(f'{data_path}/Image/{p}/{s}') 
                img = np.clip(img, -1200, 1200)
                img = torch.Tensor(img)
                img = img.to(torch.float32) 
                min_value = img.min()
                max_value = img.max()    

                shifted_tensor = img - min_value   
                img = shifted_tensor / (max_value - min_value)
                all_data[pid][slice]["img"] = img
                all_data[pid][slice]["bbox"] = np.array([[0,0,0,0]])

loading nodule slices...


100%|██████████| 6185/6185 [00:26<00:00, 230.82it/s] 


loading other continuous slices...


100%|██████████| 212/212 [00:06<00:00, 30.81it/s]


In [5]:
print("preparing training/validation/testing patches...")
time.sleep(1)
train_img, train_bbox, train_label = [], [], []
valid_img, valid_bbox, valid_label = [], [], []
test_img, test_bbox, test_label = [], [], []

for patient in tqdm(all_data.keys()):
    slice_list = list(sorted(all_data[patient].keys()))[num_c:-num_c]
    for slice in slice_list:
        if (all_data[patient][slice]["bbox"] != [[0,0,0,0]]).all(): # 挑 nodule 的出來切 patch
            slice_concat = range(slice-num_c, slice+num_c+1)
            imgs_num_c = np.stack(\
                [all_data[patient][s]["img"] for s in slice_concat], axis=0)

            # print(patient, slice, all_data[patient][slice]["bbox"])
            for h in range(0, 512-64, 64):
                for w in range(0, 512-64, 64):
                    patch_img = imgs_num_c[:, h:h+128, w:w+128]
                    patch_bbox = all_data[patient][slice]["bbox"] -\
                          np.tile([w,h,w,h], (len(all_data[patient][slice]["bbox"]), 1))
                    for b_idx, box in enumerate(patch_bbox):
                        if ((box[0] < 0 or box[0] >= 127) or (box[1] < 0 or box[1] >= 127))\
                            and ((box[2] <= 0 or box[2] > 127) or (box[3] <= 0 or box[3] > 127)):
                            patch_bbox[b_idx] = [0,0,0,0]
                        else:
                            patch_bbox[b_idx] = np.clip(box, 0, 127)

                    use = True
                    if all((bbox == [0,0,0,0]).all() for bbox in patch_bbox):
                        patch_bbox = [[0,0,0,0]]
                        label = torch.zeros(1).type(torch.int64)
                        if random.random() >= no_nodule_patch_use_ratio: use = False
                    else: 
                        mask = np.any(patch_bbox != [0,0,0,0], axis=1)
                        patch_bbox = patch_bbox[mask]
                        label = torch.ones(len(patch_bbox)).type(torch.int64)

                    if use == True:
                        if train_range[0] <= patient < train_range[1]:
                            train_img.append(patch_img)
                            train_bbox.append(patch_bbox)
                            train_label.append(label)
                    if valid_range[0] <= patient < valid_range[1]:
                        valid_img.append(patch_img)
                        valid_bbox.append(patch_bbox)
                        valid_label.append(label)
                    elif test_range[0] <= patient < test_range[1]:
                        test_img.append(patch_img)
                        test_bbox.append(patch_bbox)
                        test_label.append(label)

train_img = np.array(train_img)
valid_img = np.array(valid_img)
test_img = np.array(test_img)

del all_data

preparing training/validation/testing patches...


100%|██████████| 150/150 [00:04<00:00, 31.76it/s]


In [6]:
print(len(train_img[:-1]), len(valid_img), len(test_img), len(train_bbox[:-2]), len(valid_bbox), len(test_bbox), len(train_label), len(valid_label), len(test_label))

3910 6076 9800 3909 6076 9800 3911 6076 9800


In [7]:
import torch
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class FasterRCNNDataset(Dataset):
    def __init__(self, image_list, bboxes_list, labels_list, transform=None):
        self.image_list = image_list
        self.bboxes_list = bboxes_list
        self.labels_list = labels_list
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image = self.image_list[idx]
        bboxes = self.bboxes_list[idx]
        labels = self.labels_list[idx]

        image = torch.tensor(image)
#         image = torch.unsqueeze(image, 0)
        if self.transform is not None:
            image = self.transform(image)
        bboxes = torch.tensor(bboxes)
        labels = labels

        return image, bboxes, labels


def collate_fn(batch):
    image_list = []
    bboxes_list = []
    labels_list = []
    for item in batch:
        image_list.append(item[0])
        bboxes_list.append(item[1])
        labels_list.append(item[2])

    # Pad the lists of bounding boxes with -1
    max_num_bboxes = max(len(bboxes) for bboxes in bboxes_list)
    padded_bboxes_list = []
    for bboxes in bboxes_list:
        padded_bboxes = torch.cat((bboxes, torch.tensor([[-1, -1, -1, -1]]).expand((max_num_bboxes - len(bboxes)), -1)), dim=0)
        padded_bboxes_list.append(padded_bboxes)

    # Pad label
    max_num_labels = max_num_bboxes
    padded_labels_list = []
    for labels in labels_list:
        padded_labels = torch.cat((labels, torch.zeros(max_num_labels - len(labels)).type(torch.int64)), dim=0)
        padded_labels_list.append(padded_labels)
    # Convert images, bboxes, and labels to tensors
    image_list = torch.stack(image_list)
    padded_bboxes_list = torch.stack(padded_bboxes_list)
    padded_labels_list = torch.stack(padded_labels_list)

    return image_list, padded_bboxes_list, padded_labels_list

# transform  = transforms.Resize([600,600]) 
train_dataset = FasterRCNNDataset(train_img, train_bbox, train_label)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)

valid_dataset = FasterRCNNDataset(valid_img, valid_bbox, valid_label)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
valid_iou_dataloader = DataLoader(valid_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

test_dataset = FasterRCNNDataset(test_img, test_bbox, test_label)
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)

for images, bboxes, labels in train_dataloader:
    print("Images shape:", images.shape)
    print("Bounding boxes shape:", bboxes.shape)
    print("Labels shape:", labels.shape, labels)
    break


Images shape: torch.Size([8, 3, 128, 128])
Bounding boxes shape: torch.Size([8, 2, 4])
Labels shape: torch.Size([8, 2]) tensor([[1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 1],
        [1, 0],
        [1, 0],
        [1, 0]])


# Train

In [8]:
import torch
import torch.nn.functional as F
from torch import nn

def custom_rpn_cls_loss(logits, labels, num_hard=2):
        
    classify_loss = nn.BCELoss()
    probs = torch.sigmoid(logits)[:] #.view(-1, 1)
    pos_idcs = labels[:] == 1
    pos_prob = probs[pos_idcs]
    pos_labels = labels[pos_idcs]

    neg_idcs = labels[:] == 0
    neg_prob = probs[neg_idcs]
    neg_labels = labels[neg_idcs]
    p_indices = torch.where(labels == 1)
    n_indices = torch.where(labels == 0)
#     print('NUM', p_indices[0].shape, n_indices[0].shape)
    if mode == 'valid': 
        num_hard=1000000000
        val_rpn_pos_per_batch.append(p_indices[0].shape)
        val_rpn_neg_per_batch.append(n_indices[0].shape) 
    else :
        rpn_pos_per_batch.append(p_indices[0].shape)
        rpn_neg_per_batch.append(n_indices[0].shape)
    
    if num_hard > 0:
        neg_prob, neg_labels = OHEM(neg_prob, neg_labels, num_hard * len(pos_prob))

    pos_correct = 0
    pos_total = 0
    if len(pos_prob) > 0:
        cls_loss = 0.5 * classify_loss(
            pos_prob, pos_labels.float()) + 0.5 * classify_loss(
            neg_prob, neg_labels.float())
        pos_correct = (pos_prob >= 0.5).sum()
        pos_total = len(pos_prob)
    else:
        cls_loss = 0.5 * classify_loss(
            neg_prob, neg_labels.float())


    neg_correct = (neg_prob < 0.5).sum()
    neg_total = len(neg_prob)
    return cls_loss #pos_correct, pos_total, neg_correct, neg_total


def OHEM(neg_output, neg_labels, num_hard):
    _, idcs = torch.topk(neg_output, min(num_hard, len(neg_output)))
    neg_output = torch.index_select(neg_output, 0, idcs)
    neg_labels = torch.index_select(neg_labels, 0, idcs)
    return neg_output, neg_labels

In [9]:
def custom_rcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor)
        box_regression (Tensor)
        labels (list[BoxList])
        regression_targets (Tensor)

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)
    
    batch_size, num_class = class_logits.shape[:2]
#     print('Samuel')
    weight = torch.ones(num_class)
    if torch.cuda.is_available():
        weight = weight.cuda()
#     print('Samuel222')
    total = len(labels)
#     print('Samuel3333')
    for i in range(num_class):
        num_pos = float((labels == i).sum())
        num_pos = max(num_pos, 1)
        weight[i] = total / num_pos
    
#     print('Samuel444')
    weight = weight / weight.sum()
#     print('Samuel5555')

    classification_loss = F.cross_entropy(class_logits, labels, weight=weight)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss

def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN.

    Args:
        class_logits (Tensor)
        box_regression (Tensor)
        labels (list[BoxList])
        regression_targets (Tensor)

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
    """

    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N, num_classes = class_logits.shape
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()

    return classification_loss, box_loss

In [10]:
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection import _utils as det_utils

class CustomRoIHead(RoIHeads):
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Mask
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
    ):
        super().__init__(box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Mask
        mask_roi_pool,
        mask_head,
        mask_predictor,
        keypoint_roi_pool,
        keypoint_head,
        keypoint_predictor,
        )
        
    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Args:
            features (List[Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
        """
        if targets is not None:
            for t in targets:
                # TODO: https://github.com/pytorch/pytorch/issues/26731
                floating_point_types = (torch.float, torch.double, torch.half)
                if not t["boxes"].dtype in floating_point_types:
                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
                if not t["labels"].dtype == torch.int64:
                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")
                if self.has_keypoint():
                    if not t["keypoints"].dtype == torch.float32:
                        raise TypeError(f"target keypoints must of float type, instead got {t['keypoints'].dtype}")

        if self.training:
            proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
        else:
            labels = None
            regression_targets = None
            matched_idxs = None

        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        if self.training:
            if labels is None:
                raise ValueError("labels cannot be None")
            if regression_targets is None:
                raise ValueError("regression_targets cannot be None")
            loss_classifier, loss_box_reg = custom_rcnn_loss(class_logits, box_regression, labels, regression_targets)
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
        else:
            boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

        if self.has_mask():
            mask_proposals = [p["boxes"] for p in result]
            if self.training:
                if matched_idxs is None:
                    raise ValueError("if in training, matched_idxs should not be None")

                # during training, only focus on positive boxes
                num_images = len(proposals)
                mask_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    mask_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            if self.mask_roi_pool is not None:
                mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
                mask_features = self.mask_head(mask_features)
                mask_logits = self.mask_predictor(mask_features)
            else:
                raise Exception("Expected mask_roi_pool to be not None")

            loss_mask = {}
            if self.training:
                if targets is None or pos_matched_idxs is None or mask_logits is None:
                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")

                gt_masks = [t["masks"] for t in targets]
                gt_labels = [t["labels"] for t in targets]
                rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
                loss_mask = {"loss_mask": rcnn_loss_mask}
            else:
                labels = [r["labels"] for r in result]
                masks_probs = maskrcnn_inference(mask_logits, labels)
                for mask_prob, r in zip(masks_probs, result):
                    r["masks"] = mask_prob

            losses.update(loss_mask)

        # keep none checks in if conditional so torchscript will conditionally
        # compile each branch
        if (
            self.keypoint_roi_pool is not None
            and self.keypoint_head is not None
            and self.keypoint_predictor is not None
        ):
            keypoint_proposals = [p["boxes"] for p in result]
            if self.training:
                # during training, only focus on positive boxes
                num_images = len(proposals)
                keypoint_proposals = []
                pos_matched_idxs = []
                if matched_idxs is None:
                    raise ValueError("if in trainning, matched_idxs should not be None")

                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    keypoint_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None

            keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes)
            keypoint_features = self.keypoint_head(keypoint_features)
            keypoint_logits = self.keypoint_predictor(keypoint_features)

            loss_keypoint = {}
            if self.training:
                if targets is None or pos_matched_idxs is None:
                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")

                gt_keypoints = [t["keypoints"] for t in targets]
                rcnn_loss_keypoint = keypointrcnn_loss(
                    keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
                )
                loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
            else:
                if keypoint_logits is None or keypoint_proposals is None:
                    raise ValueError(
                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
                    )

                keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                    r["keypoints"] = keypoint_prob
                    r["keypoints_scores"] = kps
            losses.update(loss_keypoint)

        return result, losses

In [11]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.rpn import AnchorGenerator
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torch import Tensor
from typing import Tuple, List, Dict
import torchvision.ops as ops
from torchvision.models.detection.faster_rcnn import TwoMLPHead

# model = models.detection.fasterrcnn_resnet50_fpn(pretrained=pretrain)
model = models.detection.fasterrcnn_resnet50_fpn()

model.transform = GeneralizedRCNNTransform(128,128,[0, 0, 0], [1, 1, 1])
model.backbone.body.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)


# rpn replace 
from torchvision.models.detection.rpn import RegionProposalNetwork
from torchvision.models.detection.rpn import RPNHead
class CustomRPN(RegionProposalNetwork):
    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        head: nn.Module,
        # Faster-RCNN Training
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        # Faster-RCNN Inference
        pre_nms_top_n: Dict[str, int],
        post_nms_top_n: Dict[str, int],
        nms_thresh: float,
        score_thresh: float = 0.0,
    ) -> None:
        super().__init__(anchor_generator, head, fg_iou_thresh, bg_iou_thresh, batch_size_per_image, positive_fraction, pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh)

    def compute_loss(
        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
#         print('sampled_pos_inds', type(sampled_pos_inds[0]), sampled_pos_inds[0].shape)
#         print('sampled_neg_inds', type(sampled_neg_inds[0]), sampled_neg_inds[0].shape)
        
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = F.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            reduction="sum",
        ) / (sampled_inds.numel())
        objectness_loss = custom_rpn_cls_loss(objectness[sampled_inds], labels[sampled_inds]) # F.binary_cross_entropy_with_logits

        return objectness_loss, box_loss
aspect_ratios = [(0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0)]
sizes = ((4,), (8,), (16,), (32,), (64,)) 
model.rpn =  CustomRPN(AnchorGenerator(sizes=sizes, aspect_ratios=aspect_ratios), RPNHead(256, 15), 0.7, 0.3, 256, 0.5, {'training': 12000, 'testing': 3000}, {'training': 600, 'testing': 300}, 0.7)

in_feature = model.roi_heads.box_predictor.cls_score.in_features
# model.roi_heads.box_predictor =FastRCNNPredictor(in_feature, 2)
model.roi_heads = CustomRoIHead(ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2),
                                TwoMLPHead(12544, 1024),
                                FastRCNNPredictor(in_feature, 2),
                                fg_iou_thresh = 0.5,
                                bg_iou_thresh = 0.5,
                                score_thresh = 0.05,
                                nms_thresh = 0.5,
                                detections_per_img = 100,
                                batch_size_per_image = 512,
                                positive_fraction = 0.25,
                                bbox_reg_weights = None,
                                )
# 'box_roi_pool', 'box_head', 'box_predictor', 'fg_iou_thresh', 
# 'bg_iou_thresh', 'batch_size_per_image', 'positive_fraction', 'bbox_reg_weights', 
# 'score_thresh', 'nms_thresh', and 'detections_per_img'

# print(model)

def bbox_iou(bbox_a, bbox_b):
    if bbox_a.shape[1]!=4 or bbox_b.shape[1]!=4:
        raise IndexError
    tl = np.maximum(bbox_a[:, None, :2], bbox_b[:, :2])
    br = np.minimum(bbox_a[:, None, 2:], bbox_b[:, 2:])

    area_i = np.prod(br-tl, axis=2) * (tl<br).all(axis=2)
    area_a = np.prod(bbox_a[:, 2:] - bbox_a[:, :2], axis=1)
    area_b = np.prod(bbox_b[:, 2:] - bbox_b[:, :2], axis=1)

    return area_i / (area_a[:, None] + area_b - area_i)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 207MB/s]


In [12]:
ious = bbox_iou(np.array([[1,1,6,6], [5,1,10,6]]), np.array([[4,4,8,8],[9,9,10,10]]))
print('ious', ious)
max_ious = np.max(ious, axis=1)
print('max_ious', type(max_ious))

ious [[0.10810811 0.        ]
 [0.17142857 0.        ]]
max_ious <class 'numpy.ndarray'>


In [13]:
# optimizer = torch.optim.SGD(model.parameters(),  lr=0.0001, momentum=0.9)   # optimize all cnn parameters
# torch.optim.Adam(model.parameters(), lr=1e-4, betas = (0.9, 0.999), weight_decay=5e-4)
if torch.cuda.is_available():
    model.cuda()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas = (0.9, 0.999), weight_decay=5e-4)
optimizer = torch.optim.SGD(model.parameters(),  lr=lr, momentum=0.9, weight_decay=0.0005)   # optimize all cnn parameters
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=1,gamma=0.95)
# Freeze_Epoch = 50
Unfreeze_Epoch = epoch
best_valid_loss_avg = np.inf

# Unfreezed training and validating

print('start training (unfreeze ResNet50)')

for epoch in tqdm(range(Unfreeze_Epoch)):
#     total_loss = np.array([])
    rpn_loc_loss = []
    rpn_cls_loss = []
    roi_loc_loss = []
    roi_cls_loss = []
    val_rpn_loc_loss = []
    val_rpn_cls_loss = []
    val_roi_loc_loss = []
    val_roi_cls_loss = []
#     val_toal_loss = np.array([])

    trainlosslist = []
    validlosslist = []
#     loss_weight = [100,1,100,10]
    
    rpn_pos_per_batch = []
    rpn_neg_per_batch = []
    val_rpn_pos_per_batch = []
    val_rpn_neg_per_batch = []

    # train
    model.train()
    for step, (imgs, bboxes, labels) in enumerate(train_dataloader):
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            bboxes = bboxes.cuda()
            labels = labels.cuda()
        targets = []
        mode = 'train'
        for i in range(imgs.shape[0]):
            bbox = bboxes[i]
            label = labels[i]
            try:
                label_num = torch.nonzero(label == 0)[0][0]
            except:
                label_num = label.shape[0]
            label = label[:label_num]
            bbox = bbox[:label_num]
            d = {}
            d['boxes'] = bbox
            d['labels'] = label
            targets.append(d)
        output = model(imgs, targets)
        optimizer.zero_grad()
        losses = output['loss_rpn_box_reg']*100 + output['loss_objectness']*1 + output['loss_box_reg']*100 + output['loss_classifier']*1
        losses.backward()
        optimizer.step()
        rpn_loc_loss.append(output['loss_rpn_box_reg'].item())
        rpn_cls_loss.append(output['loss_objectness'].item())
        roi_loc_loss.append(output['loss_box_reg'].item())
        roi_cls_loss.append(output['loss_classifier'].item())
        trainlosslist.append(losses.item())

    # validate
    with torch.no_grad():
        for step, (imgs, bboxes, labels) in enumerate(valid_dataloader):
            if torch.cuda.is_available():
                imgs = imgs.cuda()
                bboxes = bboxes.cuda()
                labels = labels.cuda()
            targets = []
            mode = 'valid'
            for i in range(imgs.shape[0]):
                bbox = bboxes[i]
                label = labels[i]
                try:
                    label_num = torch.nonzero(label == 0)[0][0]
                except:
                    label_num = label.shape[0]
                label = label[:label_num]
                bbox = bbox[:label_num]
                d = {}
                d['boxes'] = bbox
                d['labels'] = label
                targets.append(d)    
            output = model(imgs, targets) 
            losses = output['loss_rpn_box_reg']*100 + output['loss_objectness']*1 + output['loss_box_reg']*100 + output['loss_classifier']*1
            val_rpn_loc_loss.append(output['loss_rpn_box_reg'].item())
            val_rpn_cls_loss.append(output['loss_objectness'].item())
            val_roi_loc_loss.append(output['loss_box_reg'].item())
            val_roi_cls_loss.append(output['loss_classifier'].item())
            validlosslist.append(losses.item())
    # validate iou
    model.eval()
    true_pos = 0
    true_neg = 0
    false_pos = 0
    false_neg = 0
    iou_array = np.array([])
    for step, (imgs, bboxes, labels) in enumerate(valid_iou_dataloader):
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            bboxes = bboxes.cuda()
            labels = labels.cuda()
        bbox, label = 0, 0
        mode = 'valid'
        for i in range(imgs.shape[0]):
            bbox = bboxes[i]
            label = labels[i]
            try:
                label_num = torch.nonzero(label == 0)[0][0]
            except:
                label_num = label.shape[0]
            label = label[:label_num]
            bbox = bbox[:label_num]
        output = model(imgs)
        indices = torch.where(output[0]['scores'] > 0.5)[0]
        if indices.shape[0] > 0:
            if label.shape[0] > 0:
                true_pos += 1
                ious = bbox_iou(output[0]['boxes'][indices].cpu().detach().numpy(), bbox.cpu().numpy())
                max_ious = np.max(ious, axis=1)
                iou_array = np.append(iou_array, max_ious)
            else :
                false_pos += 1
        else :
            if label.shape[0] > 0:
                false_neg += 1
            else :
                true_neg += 1

    
    train_loss_avg = np.mean(trainlosslist)
    rpn_loc_loss_avg = np.mean(rpn_loc_loss)
    rpn_cls_loss_avg = np.mean(rpn_cls_loss)
    roi_loc_loss_avg = np.mean(roi_loc_loss)
    roi_cls_loss_avg = np.mean(roi_cls_loss)
    valid_loss_avg = np.mean(validlosslist)
    val_rpn_loc_loss_avg = np.mean(val_rpn_loc_loss)
    val_rpn_cls_loss_avg = np.mean(val_rpn_cls_loss)
    val_roi_loc_loss_avg = np.mean(val_roi_loc_loss)
    val_roi_cls_loss_avg = np.mean(val_roi_cls_loss)

    print('===================RPN========================')
    print('rpn_pos_per_batch', np.mean(rpn_pos_per_batch), len(rpn_pos_per_batch))
    print('rpn_neg_per_batch', np.mean(rpn_neg_per_batch), len(rpn_neg_per_batch))
    print('val_rpn_pos_per_batch', np.mean(val_rpn_pos_per_batch), len(val_rpn_pos_per_batch))
    print('val_rpn_neg_per_batch', np.mean(val_rpn_neg_per_batch), len(val_rpn_neg_per_batch))
    
    print('===================ROI========================')
    print('valid_iou : ', np.mean(iou_array), 'valid_num : ', iou_array.size)
    try:
        print('valid_recall : ', true_pos/(true_pos+false_neg))
    except:
        print('valid_recall : ', 0)
    try:
        print('valid_precision : ', true_pos/(true_pos+false_pos))
    except:
        print('valid_precision : ', 0)
    
    print('===================LOSS========================')
    print('loss : ', rpn_loc_loss_avg, rpn_cls_loss_avg, roi_loc_loss_avg, roi_cls_loss_avg)
    print('val_loss : ', val_rpn_loc_loss_avg, val_rpn_cls_loss_avg, val_roi_loc_loss_avg, val_roi_cls_loss_avg)
    print('epoch : ', epoch, ',train_loss : ',train_loss_avg, ',valid_loss : ', valid_loss_avg)
    # wandb.log({"EPOCHS":epoch, "Train Loss":train_loss_avg, "Valid Loss":valid_loss_avg}) 
    if valid_loss_avg < best_valid_loss_avg :
        best_valid_loss_avg = valid_loss_avg
        print('saving model weight') 
        torch.save(model.state_dict(), '/kaggle/working/weight.pt')
    lr_scheduler.step()
# wandb.finish() 

start training (unfreeze ResNet50)


  5%|▌         | 1/20 [12:28<3:56:55, 748.21s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.027928247863623454 valid_num :  11168
valid_recall :  0.9923954372623575
valid_precision :  0.09563942836203737
loss :  0.0029273256861270503 0.7010446590636161 0.001132923849785517 0.26184874890345494
val_loss :  0.00030208105449913236 nan 0.00011791440192909202 0.13376822728590157
epoch :  0 ,train_loss :  1.368918360010978 ,valid_loss :  nan


 10%|█         | 2/20 [24:50<3:43:26, 744.80s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.05479522822959775 valid_num :  4233
valid_recall :  0.9486692015209125
valid_precision :  0.11980792316926771
loss :  0.002908095868990246 0.6968342795937583 0.0013662579110573154 0.06545698025848541
val_loss :  0.00029985850486283244 nan 8.350736346340987e-05 0.14029982559762796
epoch :  1 ,train_loss :  1.1897266381601126 ,valid_loss :  nan


 15%|█▌        | 3/20 [37:20<3:31:40, 747.06s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.04672638686821599 valid_num :  3643
valid_recall :  0.9562737642585551
valid_precision :  0.11405895691609977
loss :  0.0029044591232868584 0.6959340912432759 0.001418562167503897 0.05715084321614058
val_loss :  0.0002994337758106045 nan 6.504105550525789e-05 0.10801522049525948
epoch :  2 ,train_loss :  1.1853870648304135 ,valid_loss :  nan


 20%|██        | 4/20 [49:42<3:18:39, 744.96s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.07559498949885131 valid_num :  2367
valid_recall :  0.9011406844106464
valid_precision :  0.12562947256824808
loss :  0.0029083583919985645 0.6953530506609895 0.001418531447677349 0.03126670048973982
val_loss :  0.00030001181273148274 nan 0.0001303345607933998 0.1476314768805184
epoch :  3 ,train_loss :  1.1593087391131738 ,valid_loss :  nan


 25%|██▌       | 5/20 [1:02:05<3:06:07, 744.47s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.056236421118250904 valid_num :  2384
valid_recall :  0.903041825095057
valid_precision :  0.11585365853658537
loss :  0.002901593346860791 0.6951632749571147 0.0011017122127620343 0.03197273692679887
val_loss :  0.00030028642272315523 nan 0.00011178094330770555 0.14060233186679205
epoch :  4 ,train_loss :  1.1274665659922032 ,valid_loss :  nan


 30%|███       | 6/20 [1:14:26<2:53:22, 743.07s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.06056689707278333 valid_num :  1925
valid_recall :  0.8574144486692015
valid_precision :  0.1174785100286533
loss :  0.002900739431934496 0.6947257923446062 0.0010455321631564353 0.02076391029459232
val_loss :  0.0003000105797151551 nan 4.869215593686623e-05 0.14906837604388495
epoch :  5 ,train_loss :  1.110116861836798 ,valid_loss :  nan


 35%|███▌      | 7/20 [1:26:58<2:41:38, 746.01s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.06850859400471572 valid_num :  1696
valid_recall :  0.8136882129277566
valid_precision :  0.1296969696969697
loss :  0.0028992444172357295 0.6945296632000274 0.0009289623753002143 0.0156265162400509
val_loss :  0.00029851796706632165 nan 5.701274909776992e-05 0.18858272139652016
epoch :  6 ,train_loss :  1.092976862058074 ,valid_loss :  nan


 40%|████      | 8/20 [1:39:27<2:29:24, 747.07s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.10505408668506218 valid_num :  654
valid_recall :  0.5627376425855514
valid_precision :  0.15171706817016914
loss :  0.00289664376712288 0.6943861543522778 0.0008090499337694195 0.0141074667943752
val_loss :  0.0002999851066832841 nan 8.401928468508027e-05 0.28871850838916685
epoch :  7 ,train_loss :  1.0790629901281408 ,valid_loss :  nan


 45%|████▌     | 9/20 [1:51:53<2:16:53, 746.72s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.12082660468501637 valid_num :  571
valid_recall :  0.5342205323193916
valid_precision :  0.1341288782816229
loss :  0.0028975678611154417 0.6942793978016313 0.0007798456516873721 0.01015897222974572
val_loss :  0.0002994782118022206 nan 8.997159295833662e-05 0.3057670125065335
epoch :  8 ,train_loss :  1.072179715448118 ,valid_loss :  nan


 50%|█████     | 10/20 [2:04:16<2:04:14, 745.47s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.13012715597694346 valid_num :  898
valid_recall :  0.7129277566539924
valid_precision :  0.15611990008326396
loss :  0.0028933105076823177 0.6942090615173059 0.0006566636097198492 0.008797190074141353
val_loss :  0.0003010709109012647 nan 0.00012828121216840242 0.30909243367213524
epoch :  9 ,train_loss :  1.0580036601162153 ,valid_loss :  nan


 55%|█████▌    | 11/20 [2:16:31<1:51:21, 742.41s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.0922899824041976 valid_num :  1405
valid_recall :  0.8212927756653993
valid_precision :  0.1390408754425491
loss :  0.0028956187340232844 0.694144741767267 0.0006963298517046213 0.010657988763166733
val_loss :  0.00029843421324537636 nan 9.573840309856582e-05 0.21993066734005545
epoch :  10 ,train_loss :  1.0639975904687051 ,valid_loss :  nan


 60%|██████    | 12/20 [2:28:52<1:38:56, 742.11s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.13140606854090398 valid_num :  760
valid_recall :  0.6045627376425855
valid_precision :  0.1515007146260124
loss :  0.002891180679549183 0.6941157719596519 0.0006239134707939303 0.007130925277684165
val_loss :  0.0002998354549354668 nan 0.00011105996114181575 0.25189588693942583
epoch :  11 ,train_loss :  1.052756115337091 ,valid_loss :  nan


 65%|██████▌   | 13/20 [2:41:10<1:26:24, 740.64s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.13122621672417104 valid_num :  480
valid_recall :  0.48098859315589354
valid_precision :  0.14988151658767773
loss :  0.0028805008453976675 0.6941033382357263 0.0005758459994345734 0.00847507777745485
val_loss :  0.00029894384105749565 nan 9.641681714576951e-05 0.3583967070951279
epoch :  12 ,train_loss :  1.0482131035537563 ,valid_loss :  nan


 70%|███████   | 14/20 [2:53:28<1:13:58, 739.81s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.11891018647961532 valid_num :  1069
valid_recall :  0.7224334600760456
valid_precision :  0.149547422274695
loss :  0.0028831390902942337 0.6940489930609253 0.0005206402767515934 0.005562303203433287
val_loss :  0.000299741820055332 nan 9.804112639685806e-05 0.30375134906784007
epoch :  13 ,train_loss :  1.0399892324075133 ,valid_loss :  nan


 75%|███████▌  | 15/20 [3:06:01<1:01:59, 743.92s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.1780497255930727 valid_num :  491
valid_recall :  0.5038022813688213
valid_precision :  0.16327788046826863
loss :  0.002880874860285189 0.6940559308465517 0.00048720897575735305 0.006662459453795932
val_loss :  0.0002972784228739712 nan 0.00012795506922959465 0.35041787760598564
epoch :  14 ,train_loss :  1.0375267747477277 ,valid_loss :  nan


 80%|████████  | 16/20 [3:18:20<49:29, 742.48s/it]  

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.16497948975951593 valid_num :  532
valid_recall :  0.5342205323193916
valid_precision :  0.15938740782756664
loss :  0.002892682594087657 0.693967496203011 0.0004107993349453047 0.004242853537362101
val_loss :  0.00029966747838204214 nan 0.00013105087512644187 0.3348370882481534
epoch :  15 ,train_loss :  1.0285585410512055 ,valid_loss :  nan


 85%|████████▌ | 17/20 [3:31:00<37:22, 747.61s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.1901388319624528 valid_num :  393
valid_recall :  0.44106463878326996
valid_precision :  0.16453900709219857
loss :  0.0028864560778069943 0.693928326809577 0.00042891421207503984 0.004803207180821546
val_loss :  0.00029454295947384945 nan 0.00013673720172040611 0.3973947381827172
epoch :  16 ,train_loss :  1.0302685638146898 ,valid_loss :  nan


 90%|█████████ | 18/20 [3:43:48<25:07, 753.75s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.17287682214577305 valid_num :  634
valid_recall :  0.6140684410646388
valid_precision :  0.1650485436893204
loss :  0.0028848646891066246 0.6938906058455782 0.0003517764945218563 0.0044779622873028545
val_loss :  0.0002984844295832246 nan 0.0001636785890372095 0.3540373016297029
epoch :  17 ,train_loss :  1.0220326904382686 ,valid_loss :  nan


 95%|█████████▌| 19/20 [3:56:05<12:28, 748.72s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.17808608486683308 valid_num :  422
valid_recall :  0.49049429657794674
valid_precision :  0.15674362089914945
loss :  0.0028780959167442493 0.6939147407291857 0.0003630553539903781 0.004320601947159649
val_loss :  0.0002958312536257405 nan 0.00016718015739711298 0.3635833885101623
epoch :  18 ,train_loss :  1.0223504731503976 ,valid_loss :  nan


100%|██████████| 20/20 [4:08:22<00:00, 745.13s/it]

rpn_pos_per_batch 7.9079754601226995 489
rpn_neg_per_batch 2039.5685071574642 489
val_rpn_pos_per_batch 0.9592105263157895 760
val_rpn_neg_per_batch 2045.6934210526315 760
valid_iou :  0.23603939132368273 valid_num :  310
valid_recall :  0.3916349809885932
valid_precision :  0.17195325542570952
loss :  0.002895982982686022 0.6938687122916395 0.000354015469853366 0.003798156544546633
val_loss :  0.00029846280557122603 nan 0.00016984277765771675 0.4438335094337052
epoch :  19 ,train_loss :  1.0226667140646701 ,valid_loss :  nan





## 拜託好運貓貓讓我 train 成功 
![](https://hips.hearstapps.com/hmg-prod/images/beautiful-smooth-haired-red-cat-lies-on-the-sofa-royalty-free-image-1678488026.jpg?crop=1.00xw:0.752xh;0,0.0457xh&resize=640:*)