In [6]:
import random
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import numpy as np
import torch.utils.data
import cv2
import torchvision.models.segmentation
import torchvision.models.detection
import torch
import os
batchSize=1
imageSize=[512,512]

In [7]:
print(torchvision.__version__)
print(torch.__version__)

0.14.0
1.13.0


In [8]:
gpu = torch.device("cuda")
cpu = torch.device("cpu")

In [9]:
def loadData(path, batch_size):
    batch_Images = []
    batch_Data = []

    for i in range(batch_size):
        idImg = random.randint(0, len(path)-1)
        img = cv2.imread(path[idImg]["image"])
        maskImg = cv2.imread(path[idImg]["label"])
        maskImg = cv2.bitwise_not(maskImg)
        mask = [(maskImg > 0).astype(np.uint8)]

        gray = cv2.cvtColor(maskImg,cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(gray,128,255,cv2.THRESH_BINARY)[1]
        contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = contours[0] if len(contours) == 2 else contours[1]

        boxes = torch.zeros([len(contours),4], dtype=torch.float32)
        for i,contour in enumerate(contours) :
            x,y,w,h = cv2.boundingRect(contour)
            boxes[i] = torch.tensor([x,y, x+w, y+h])
            
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        img = torch.as_tensor(img, dtype=torch.float32)
        data = {
            "boxes" : boxes,
            "labels": torch.ones((len(contours),), dtype=torch.int64),
            "masks" : mask
        }
        batch_Data.append(data)
        batch_Images.append(img)

    batch_Images= torch.stack([torch.as_tensor(d) for d in batch_Images], 0)
    batch_Images = batch_Images.swapaxes(1, 3).swapaxes(2, 3)
    return batch_Images, batch_Data


In [11]:
import shutil
def train_mrcnn_v1(path, batch_size):
    if os.path.exists(path+"/result/model"):
        shutil.rmtree(path+"/result/model")

    os.mkdir(path+"/result/model")
    
    imageDir = os.path.join(path, "image")
    lableDir = os.path.join(path, "mask")
    imgs = [{"image":imageDir + '/' + imgPath, "label":lableDir+ '/' + labelPath} for imgPath, labelPath in zip(os.listdir(imageDir), os.listdir(lableDir))]
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features 
    model.roi_heads.box_predictor=FastRCNNPredictor(in_features,num_classes=2)
    model.to(gpu)# move model to the right devic
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
    model.train()
    for i in range(10001):
            images, targets = loadData(imgs, batch_size)
            images = list(image.to(gpu) for image in images)
            targets=[{k: v.to(gpu) for k,v in t.items()} for t in targets]

            optimizer.zero_grad()
            loss_dict = model(images, targets)

            losses = sum(loss for loss in loss_dict.values())
            losses.backward()
            optimizer.step()
            print(i,'loss:', losses.item())
            if i%500==0:
                torch.save(model.state_dict(), path +"/result/model/"+str(i)+".torch")


In [12]:
from torchvision.models.detection import maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_Weights

def train_mrcnn_v2(path,batch_size):
    if os.path.exists(path+"/result/modelv2"):
        shutil.rmtree(path+"/result/modelv2")

    os.mkdir(path+"/result/modelv2")
    
    imageDir = os.path.join(path, "image")
    lableDir = os.path.join(path, "mask")
    imgs = [{"image":imageDir + '/' + imgPath, "label":lableDir+ '/' + labelPath} for imgPath, labelPath in zip(os.listdir(imageDir), os.listdir(lableDir))]
    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features 
    model.roi_heads.box_predictor=FastRCNNPredictor(in_features,num_classes=2)
    model.to(gpu)# move model to the right devic
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)
    model.train()
    for i in range(10001):
            images, targets = loadData(imgs, batch_size)
            images = list(image.to(gpu) for image in images)
            targets=[{k: v.to(gpu) for k,v in t.items()} for t in targets]

            optimizer.zero_grad()
            loss_dict = model(images, targets)

            losses = sum(loss for loss in loss_dict.values())
            losses.backward()
            optimizer.step()
            print(i,'loss:', losses.item())
            if i%500==0:
                torch.save(model.state_dict(), path +"/result/model/v2"+str(i)+".torch")


In [13]:
data_dir = ["crop_128", "resize_128"]
data_dir = ["crop_512"]
data_dir = ["resize_512"]

In [14]:

for path in data_dir:
    train_mrcnn_v1(path,2)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


0 loss: 52.660133361816406
1 loss: 46.917808532714844
2 loss: 65.55193328857422
3 loss: 10.041972160339355
4 loss: 7.807111740112305
5 loss: 63.10396194458008
6 loss: 54.26588439941406
7 loss: 12.629693984985352
8 loss: 1.495792031288147
9 loss: 13.54576587677002
10 loss: 6.883968830108643
11 loss: 8.825187683105469
12 loss: 6.5233964920043945
13 loss: 11.971923828125
14 loss: 11.8583984375
15 loss: 11.942085266113281
16 loss: 7.584370136260986
17 loss: 4.6838459968566895
18 loss: 4.145400047302246
19 loss: 2.2261130809783936
20 loss: 5.36493444442749
21 loss: 1.2286041975021362
22 loss: 1.4529154300689697
23 loss: 1.7220481634140015
24 loss: 3.0095102787017822
25 loss: 2.741943836212158
26 loss: 2.951002597808838
27 loss: 0.9469656944274902
28 loss: 52.873870849609375
29 loss: 1.6723248958587646
30 loss: 1.7549488544464111
31 loss: 6.715487957000732
32 loss: 4.286092281341553
33 loss: 25.39372444152832
34 loss: 3.143489360809326
35 loss: 4.215130805969238
36 loss: 4.017560958862305
37

KeyboardInterrupt: 

In [None]:
for path in data_dir:
    train_mrcnn_v2(path,1)



0 loss: 2.586671829223633
1 loss: 1.3121163845062256
2 loss: 2.850853443145752
3 loss: 2.7258691787719727
4 loss: 2.5906190872192383
5 loss: 2.5761191844940186
6 loss: 2.7949576377868652
7 loss: 2.529034376144409
8 loss: 2.6956546306610107
9 loss: 2.206521511077881
10 loss: 2.2663960456848145
11 loss: 2.1497433185577393
12 loss: 2.3240256309509277
13 loss: 2.2314984798431396
14 loss: 6.211216449737549
15 loss: 2.0799548625946045
16 loss: 2.1932783126831055
17 loss: 0.8766852021217346
18 loss: 1.8942251205444336
19 loss: 2.190601110458374
20 loss: 2.122345447540283
21 loss: 1.9436993598937988
22 loss: 1.837454915046692
23 loss: 1.9542869329452515
24 loss: 1.5335601568222046
25 loss: 1.7408965826034546
26 loss: 1.6883593797683716
27 loss: 1.0469282865524292
28 loss: 1.545470118522644
29 loss: 1.7902435064315796
30 loss: 1.889362096786499
31 loss: 0.8759512901306152
32 loss: 0.7824288606643677
33 loss: 1.4204832315444946
34 loss: 1.8499600887298584
35 loss: 2.230243682861328
36 loss: 1.80

KeyboardInterrupt: 