In [1]:
from utils.masks_bb import masks_bb
from utils.dataset import StrawberryDataset
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torch.utils.data import DataLoader
from torchvision.ops import box_iou
from tqdm.notebook import tqdm 

import torch
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [2]:
#  partition_data('data/Images_resized/', 'data/masks_resized/', 'data/train/Images/', 'data/train/masks/', 'data/test/Images/', 'data/test/masks/')

In [3]:
masks_train, boxes_train = masks_bb('data/train/masks/')
masks_test, boxes_test = masks_bb('data/test/masks/')

In [4]:
dataset_train = StrawberryDataset('data/train/Images/', 'data/train/masks/', boxes_train, masks_train)
dataset_test = StrawberryDataset('data/test/Images/', 'data/test/masks/', boxes_test, masks_test)

In [5]:
model = maskrcnn_resnet50_fpn(progress=True, pretrained_backbone=True, num_classes=2, box_detections_per_img=12)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

def collate_fn(batch):
    return tuple(zip(*batch))

data_loader_train = DataLoader(dataset_train, batch_size=4, collate_fn=collate_fn, shuffle=True)
data_loader_test = DataLoader(dataset_test, batch_size=4, collate_fn=collate_fn, shuffle=True)



In [6]:
batch = next(iter(data_loader_train))
images, targets = batch
images = torch.stack(images).to(device) 
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
print(targets[0])

{'boxes': tensor([[752., 269., 871., 416.],
        [786., 578., 898., 737.],
        [549., 285., 640., 392.],
        [568., 566., 665., 703.],
        [319., 250., 434., 416.],
        [348., 587., 436., 708.],
        [118., 607., 207., 730.]], device='cuda:0'), 'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
   

In [None]:
losses_dict = {'loss_classifier': [], 'loss_box_reg': [], 'loss_mask': [], 'loss_objectness': [], 'loss_rpn_box_reg': [], 'total_loss': []}

for epoch in tqdm(range(10), desc='Epochs', colour='green'):
    losses_avg = {'loss_classifier': [], 'loss_box_reg': [], 'loss_mask': [], 'loss_objectness': [], 'loss_rpn_box_reg': [], 'total_loss': []}
    model.train()
    for batch in tqdm(data_loader_train, desc='data_loader', colour='blue'):
        images, targets = batch
        images = torch.stack(images).to(device)
        targets = [{k: v.to(device) for k, v in target.items()} for target in targets]
        
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        losses.backward()
        optimizer.step()
        optimizer.zero_grad()  

        losses_avg['loss_classifier'].append(loss_dict['loss_classifier'].item())
        losses_avg['loss_box_reg'].append(loss_dict['loss_box_reg'].item())
        losses_avg['loss_mask'].append(loss_dict['loss_mask'].item())   
        losses_avg['loss_objectness'].append(loss_dict['loss_objectness'].item())
        losses_avg['loss_rpn_box_reg'].append(loss_dict['loss_rpn_box_reg'].item())
        losses_avg['total_loss'].append(losses.item())

    losses_dict['loss_classifier'].append(sum(losses_avg['loss_classifier'])/len(losses_avg['loss_classifier']))
    losses_dict['loss_box_reg'].append(sum(losses_avg['loss_box_reg'])/len(losses_avg['loss_box_reg']))
    losses_dict['loss_mask'].append(sum(losses_avg['loss_mask'])/len(losses_avg['loss_mask']))
    losses_dict['loss_objectness'].append(sum(losses_avg['loss_objectness'])/len(losses_avg['loss_objectness']))
    losses_dict['loss_rpn_box_reg'].append(sum(losses_avg['loss_rpn_box_reg'])/len(losses_avg['loss_rpn_box_reg']))
    losses_dict['total_loss'].append(sum(losses_avg['total_loss'])/len(losses_avg['total_loss']))


    model.eval()
    total_iou_bb_avg = []
    total_iou_mask_avg = []
    total_precision = []
    total_recall = []

    for batch in tqdm(data_loader_test, desc='validation', colour='red'):
        images, targets = batch
        images = torch.stack(images).to(device)
        targets = [{k: v.to(device) for k, v in target.items()} for target in targets]
        outputs = model(images)
        iou_bb_avg_list = []
        iou_mask_avg_list = []
        precision_list = []
        recall_list = []
        for output, target in zip(outputs, targets):
            score = output["scores"] > 0.5
            output["boxes"] = output["boxes"][score]
            output["labels"] = output["labels"][score]
            iou = box_iou(output['boxes'], target['boxes'])

            true_positive = []
            false_positive = []
            for i in range(len(iou)):
                if iou[i].max().item() > 0.5:
                    true_positive.append(iou[i].max().item())
                else:
                    false_positive.append(iou[i].max().item())
            bb_iou_avg = sum(true_positive)/len(true_positive) if len(true_positive) > 0 else 0
            intersection = torch.logical_and(output['masks'], target['masks']).sum().float()
            union = torch.logical_or(output['masks'], target['masks']).sum().float()
            iou_masks_score = intersection / (union + 1e-6)
            precision = len(true_positive) / (len(true_positive) + len(false_positive)) if len(true_positive)+len(false_positive) > 0 else 0
            recall = sum(true_positive) / len(target['boxes']) if len(target['boxes']) > 0 else 0
            iou_bb_avg_list.append(bb_iou_avg)
            iou_mask_avg_list.append(iou_masks_score.item())
            precision_list.append(precision)
            recall_list.append(recall)
        total_iou_bb_avg.append(sum(iou_bb_avg_list)/len(iou_bb_avg_list))
        total_iou_mask_avg.append(sum(iou_mask_avg_list)/len(iou_mask_avg_list))
        total_precision.append(sum(precision_list)/len(precision_list))
        total_recall.append(sum(recall_list)/len(recall_list))

    print(f"Epoch {epoch + 1}\nloss_classifier: {losses_dict['loss_classifier']}\nloss_box_reg: {losses_dict['loss_box_reg']}\nloss_mask: {losses_dict['loss_mask']}\nloss_objectness: {losses_dict['loss_objectness']}\nloss_rpn_box_reg: {losses_dict['loss_rpn_box_reg']}\ntotal_loss: {losses_dict['total_loss']}\nAvg iou bb: {sum(total_iou_bb_avg)/len(total_iou_bb_avg)}\nAvg iou masks: {sum(total_iou_mask_avg)/len(total_iou_mask_avg)}\nPrecision: {sum(total_precision)/len(total_precision)}\nRecall: {sum(total_recall)/len(total_recall)}")    


In [7]:
torch.save(model.state_dict(), 'strawberry_model.pth')