In [None]:
import sys
sys.path.append('..')
from pycocotools.coco import COCO

In [None]:
import os

# COCO() opens a single annotation JSON file, whereas the rest of this notebook
# treats path_to_annotations as a directory containing train.json / val.json.
# Load the training annotations from that directory; they are only used here to
# read the category id -> name mapping.
coco = COCO(os.path.join(path_to_annotations, 'train.json'))

In [None]:
# One name per category id from 1 up to the highest id present in the
# annotations; ids that have no category entry get the placeholder 'No class'.
cls = [
    coco.cats[category_id]['name'] if category_id in coco.cats else 'No class'
    for category_id in range(1, max(coco.cats) + 1)
]

In [None]:
# Index 0 is given to 'background' so that list positions line up with category
# ids. The length check keeps this cell idempotent: before the prepend the list
# has exactly max(category id) entries, afterwards one more, so re-running the
# cell does not insert 'background' a second time and shift every class index.
if len(cls) == max(coco.cats):
    cls = ['background'] + cls

In [None]:
# Central experiment configuration read by the cells below.
conf = {
    # Locations of the image directory and of the directory with train.json / val.json.
    'images_path': path_to_images,
    'annotations_path': path_to_annotations,
    # When True, clean_zero() is applied and the *_clean.json files are used.
    'remove_zeros': True,
    'train_batch_size': 16,
    'test_batch_size': 4,
    'class_names': cls,
    # Derived from the annotations instead of being hard-coded: highest category
    # id plus one for the background slot (this is 91 for the standard COCO ids).
    'num_classes': max(coco.cats) + 1,
    'nms_thresh': 0.01,
    'num_epochs': 9,
    # File the best model state_dict is written to and later restored from.
    'path_to_model': path_to_checkpoint
}

In [None]:
from utils.prepare_dataset import clean_zero
import os
# Optionally run clean_zero() over both annotation files (train first, then val).
if conf['remove_zeros']:
    for split_file in ('train.json', 'val.json'):
        clean_zero(os.path.join(conf['annotations_path'], split_file))

In [None]:
from loaders.dl import COCODataset, collate_fn 
import os

In [None]:
# Pick the annotation file names: the *_clean.json variants when zero-cleaning
# is enabled, the raw files otherwise.
train_name, val_name = (
    ('train_clean.json', 'val_clean.json')
    if conf['remove_zeros']
    else ('train.json', 'val.json')
)
train_annotations = os.path.join(conf['annotations_path'], train_name)
val_annotations = os.path.join(conf['annotations_path'], val_name)

In [None]:
# Both splits read images from the same directory and differ only in the
# annotation file they are paired with.
images_root = conf['images_path']
train_ds = COCODataset(root=images_root, annotation=train_annotations)
test_ds = COCODataset(root=images_root, annotation=val_annotations)


In [None]:
import torch
# Options shared by the two loaders that are built on top of test_ds.
val_loader_options = dict(shuffle=True, num_workers=4, collate_fn=collate_fn)

# Training loader: shuffled, no worker subprocesses.
data_loader_train = torch.utils.data.DataLoader(
    train_ds,
    batch_size=conf['train_batch_size'],
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

# Loader used for metric evaluation during and after training.
data_loader_test = torch.utils.data.DataLoader(
    test_ds,
    batch_size=conf['test_batch_size'],
    **val_loader_options,
)

# Single-image batches over the same dataset, used by the inference cell.
data_loader_infer = torch.utils.data.DataLoader(
    test_ds,
    batch_size=1,
    **val_loader_options,
)

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
import torchvision
from models.train_utils import *


### Модель и обучение

In [None]:
# Faster R-CNN (ResNet-50 + FPN) with the class count and box NMS threshold
# taken from the configuration, moved onto the selected device.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    num_classes=conf['num_classes'],
    box_nms_thresh=conf['nms_thresh'],
)
model.to(device)

In [None]:
# Best 'map' value seen so far. Starting below any possible score guarantees
# that the first evaluated epoch always writes a checkpoint, so the later
# load_state_dict cell never depends on a missing or stale file.
best_metric = float('-inf')
# Passed through to train_one_epoch as its print_freq argument.
print_freq = 25

In [None]:
# Hand only the parameters that require gradients to the optimizer.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# StepLR: multiply the learning rate by gamma once every step_size scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.25,
)

In [None]:
# Train for conf['num_epochs'] epochs, numbered from 1. The upper bound is
# num_epochs + 1 because range() excludes its end; without it the loop would
# run one epoch fewer than configured.
for epoch in range(1, conf['num_epochs'] + 1):
    train_one_epoch(model, device, optimizer, data_loader_train, print_freq, epoch)

    # Evaluate after every epoch and keep the weights with the best 'map' value.
    current_metric = evaluate(
        model=model,
        dl=data_loader_test,
        device=device,
        iou_thresholds=[0.5],
        max_detection_thresholds=[50, 100, 200],
        filter_fn=None,
    )
    epoch_map = current_metric['map'].item()
    print(epoch_map)
    if epoch_map > best_metric:
        torch.save(model.state_dict(), conf['path_to_model'])
        best_metric = epoch_map

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

In [None]:
# Re-evaluate the model as it currently sits in memory; the saved checkpoint is
# only reloaded in a later cell.
current_metric = evaluate(
    model=model,
    dl=data_loader_test,
    device=device,
    iou_thresholds=[0.5],
    max_detection_thresholds=[50, 100, 200],
    filter_fn=None,
)

In [None]:
# Display the result of the most recent evaluate() call as the cell output.
current_metric

### Инференс

In [None]:
# Restore the weights saved during training. map_location remaps the stored
# tensors onto the active device, so a checkpoint written on a GPU can still be
# loaded on a CPU-only machine.
model.load_state_dict(torch.load(conf['path_to_model'], map_location=device))
model.to(device)

In [None]:
from utils.visualisation import *

In [None]:
# Run the inference helper over the single-image loader, passing a 0.5
# threshold and the class-name list from the configuration.
inference(
    model,
    device,
    data_loader_infer,
    threshold=0.5,
    class_names=conf['class_names'],
)