In [1]:
from DataLoader import MyOwnDataloader
from pycocotools.coco import COCO

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import torchvision

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
dataDir='/media/gamedisk/COCO_dataset/'
val='val2017'
train = 'train2017'

val_annFile='{}/annotations/instances_{}.json'.format(dataDir,val)
train_annFile='{}/annotations/instances_{}.json'.format(dataDir,train) 
# Batch size
batch_size = 8



classes = {
    "bird": 0,
    "cat": 1,
    "dog": 2,
    "horse": 3,
    "sheep": 4,
    "cow": 5,
    "elephant": 6,
    "bear": 7,
    "zebra": 8,
    "giraffe": 9
}


coco = COCO(val_annFile)
val_loader = MyOwnDataloader(dataDir = dataDir, dataType = val,
                     annFile = val_annFile, classes = classes, train_batch_size=batch_size)
data_loader = val_loader.concat_datasets()


loading annotations into memory...
Done (t=0.38s)
creating index...
index created!
loading annotations into memory...
Done (t=0.33s)
creating index...
index created!
bird 0 [16] 55299


In [4]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_model_instance_segmentation(num_classes):
    # load an instance segmentation model pre-trained pre-trained on COCO
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model
    

# 2 classes; Only target class or background

num_epochs = 2
model = get_model_instance_segmentation(10)

# move model to the right device
model.to(device)
    
# parameters
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

len_dataloader = len(data_loader)

for epoch in range(num_epochs):
    model.train()
    i = 0 
    print(f'epoch # {epoch}')   
    for i, (imgs, annotations) in tqdm(enumerate(data_loader)):
        i += 1
        imgs = list(img.to(device) for img in imgs)
        annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
        loss_dict = model(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if i%5:
            print(f'Iteration: {i}/{len_dataloader}, Loss: {losses}')

epoch # 0


1it [00:00,  1.07it/s]

Iteration: 1/137, Loss: 10.653388977050781


2it [00:01,  1.55it/s]

Iteration: 2/137, Loss: 13.098834991455078


3it [00:01,  1.83it/s]

Iteration: 3/137, Loss: 47.30316925048828


4it [00:02,  1.99it/s]

Iteration: 4/137, Loss: 9.815535545349121


6it [00:03,  2.11it/s]

Iteration: 6/137, Loss: 12.52155876159668


7it [00:03,  2.20it/s]

Iteration: 7/137, Loss: 3.575284957885742


8it [00:03,  2.27it/s]

Iteration: 8/137, Loss: 5.781394958496094


9it [00:04,  2.29it/s]

Iteration: 9/137, Loss: 12.71387767791748


11it [00:05,  2.24it/s]

Iteration: 11/137, Loss: 7.662105560302734


12it [00:05,  2.31it/s]

Iteration: 12/137, Loss: 27.923887252807617


13it [00:05,  2.36it/s]

Iteration: 13/137, Loss: 8.684164047241211


14it [00:06,  2.39it/s]

Iteration: 14/137, Loss: 9.958576202392578


16it [00:07,  2.34it/s]

Iteration: 16/137, Loss: 27.396018981933594


17it [00:07,  2.37it/s]

Iteration: 17/137, Loss: 13.241992950439453


18it [00:08,  2.38it/s]

Iteration: 18/137, Loss: 15.754155158996582


20it [00:08,  2.86it/s]

Iteration: 19/137, Loss: 14.33704948425293


21it [00:09,  2.45it/s]

Iteration: 21/137, Loss: 184458656.0


22it [00:09,  2.47it/s]

Iteration: 22/137, Loss: nan


23it [00:09,  2.56it/s]

Iteration: 23/137, Loss: nan


25it [00:10,  3.13it/s]

Iteration: 24/137, Loss: nan


26it [00:10,  2.60it/s]

Iteration: 26/137, Loss: nan


27it [00:11,  2.65it/s]

Iteration: 27/137, Loss: nan


28it [00:11,  2.67it/s]

Iteration: 28/137, Loss: nan


30it [00:12,  3.19it/s]

Iteration: 29/137, Loss: nan


31it [00:12,  2.66it/s]

Iteration: 31/137, Loss: nan


32it [00:13,  2.60it/s]

Iteration: 32/137, Loss: nan


33it [00:13,  2.65it/s]

Iteration: 33/137, Loss: nan


35it [00:14,  3.18it/s]

Iteration: 34/137, Loss: nan


36it [00:14,  2.65it/s]

Iteration: 36/137, Loss: nan


37it [00:14,  2.68it/s]

Iteration: 37/137, Loss: nan


38it [00:15,  2.71it/s]

Iteration: 38/137, Loss: nan


39it [00:15,  2.68it/s]

Iteration: 39/137, Loss: nan


41it [00:16,  2.52it/s]

Iteration: 41/137, Loss: nan


42it [00:16,  2.58it/s]

Iteration: 42/137, Loss: nan


43it [00:17,  2.63it/s]

Iteration: 43/137, Loss: nan


45it [00:17,  3.15it/s]

Iteration: 44/137, Loss: nan


46it [00:18,  2.62it/s]

Iteration: 46/137, Loss: nan


47it [00:18,  2.64it/s]

Iteration: 47/137, Loss: nan


48it [00:19,  2.68it/s]

Iteration: 48/137, Loss: nan


50it [00:19,  3.18it/s]

Iteration: 49/137, Loss: nan


51it [00:20,  2.64it/s]

Iteration: 51/137, Loss: nan


52it [00:20,  2.67it/s]

Iteration: 52/137, Loss: nan


53it [00:20,  2.70it/s]

Iteration: 53/137, Loss: nan


55it [00:21,  3.18it/s]

Iteration: 54/137, Loss: nan


56it [00:21,  2.61it/s]

Iteration: 56/137, Loss: nan


57it [00:22,  2.54it/s]

Iteration: 57/137, Loss: nan





KeyboardInterrupt: 