In [1]:
import torch
import torch.nn as nn
import torchvision

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import time
import os

from torch.utils.data import Dataset, DataLoader
# from torch.utils.data.sampler import Sampler
import torch.optim as optim
import sys
sys.path.append('../')
sys.path.append('../../')

from dataset import CocoDetection, train_transforms, val_transforms, test_transforms
from visualize import visualize
# from rcnn_model import fasterrcnn_resnet201_fpn, FastRCNNPredictor
from engine import evaluate
import utils
from models.swin import *

In [2]:
from models.detection.backbone_utils import swin_fpn_backbone, _validate_trainable_layers
from ops.feature_pyramid_network import LastLevelP6P7, LastLevelMaxPool
from models.detection.retinanet import RetinaNet
from torch.hub import load_state_dict_from_url
from models.detection.anchor_utils import AnchorGenerator
# from models.detection.backbone_utils import mobilenet_backbone

In [3]:
def retinanet_swin_t_fpn(pretrained=False, progress=True,
                           num_classes=91, pretrained_backbone=False, trainable_backbone_layers=None, **kwargs):
    """Build a RetinaNet detector on top of a Swin-T + FPN backbone.

    Args:
        pretrained: request full-detector COCO weights. No such checkpoint exists for
            this Swin-T variant, so ``True`` raises ``NotImplementedError``.
        progress: kept for signature compatibility with torchvision-style builders;
            unused because no detector weights are downloaded here.
        num_classes: number of classification outputs, passed to ``RetinaNet``.
        pretrained_backbone: forwarded to ``swin_fpn_backbone``.
        trainable_backbone_layers: validated by ``_validate_trainable_layers`` with
            (presumably, as in torchvision) max 5 / default 3, then forwarded to the backbone.
        **kwargs: forwarded to ``RetinaNet`` (e.g. ``min_size`` / ``max_size``).

    Returns:
        The constructed ``RetinaNet`` model.

    Raises:
        NotImplementedError: if ``pretrained`` is True.
    """
    if pretrained:
        # The previous implementation tried to load the 'retinanet_resnet50_fpn_coco'
        # checkpoint (built for a ResNet-50 backbone) through the names `model_urls` and
        # `overwrite_eps`, neither of which is defined in this notebook, so it could only
        # fail with a NameError. Fail early with an actionable message instead.
        raise NotImplementedError(
            "retinanet_swin_t_fpn has no pretrained COCO detector weights; "
            "use pretrained=False (optionally with pretrained_backbone=True).")

    # `pretrained` is False at this point, so only the backbone flag matters.
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained_backbone, trainable_backbone_layers, 5, 3)

    # One anchor spec per feature level: the 3 returned backbone levels plus P6/P7 = 5 levels.
    # Every level receives all 5 sizes x 5 aspect ratios (25 anchors per location, assuming a
    # torchvision-style AnchorGenerator). This differs from torchvision's RetinaNet, which uses
    # one base size per level.
    anchor_sizes = ((32, 64, 128, 256, 512),) * 5
    aspect_ratios = ((0.5, 0.75, 1.0, 1.5, 2.0),) * len(anchor_sizes)
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

    # Skip the stride-4 level because it generates too many anchors (per the RetinaNet paper).
    backbone = swin_fpn_backbone('swin_t', pretrained_backbone, returned_layers=[2, 3, 4],
                                 extra_blocks=LastLevelP6P7(256, 256),
                                 trainable_layers=trainable_backbone_layers)

    return RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, **kwargs)

In [4]:
NUM_CLASS = 91          # number of classification outputs passed to the detector
IMG_SIZE = 448*2        # forwarded to RetinaNet as both min_size and max_size (input resize bounds)
model = retinanet_swin_t_fpn(pretrained=False, min_size=IMG_SIZE, max_size=IMG_SIZE, num_classes=NUM_CLASS)

# Fall back to CPU instead of crashing when CUDA is unavailable (consistent with the
# device-selection cell further down, which also has a CPU fallback).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print('model is loaded to gpu' if device.type == 'cuda' else 'model is loaded to cpu')

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


return_layers {'layer2': '0', 'layer3': '1', 'layer4': '2'}
model is loaded to gpu


In [5]:
# model

In [6]:
import easydict 
args = easydict.EasyDict({ "batch_size": 2, 
                          "epochs": 90, 
                          "data": 0, 
                          'lr':0.002,
                         'momentum':0.9,
                         'weight_decay':1e-4,
                         'start_epoch':0,
                         'gpu':0,
                          'workers':12,
                         'print_freq':1000,
                         'output_dir':'../trained_model/retinanet_swin_v2_t_fpn/'})

In [7]:
from pathlib import Path

# Create the checkpoint directory. Everything from the first 'checkpoint' onward is dropped,
# so an output_dir that names a checkpoint file still yields its directory;
# str.partition leaves the whole string in [0] when the marker is absent.
path = Path(args.output_dir.partition('checkpoint')[0])
path.mkdir(exist_ok=True, parents=True)

In [8]:
ngpus_per_node = torch.cuda.device_count()
print(ngpus_per_node)
GPU_NUM = args.gpu  # index of the GPU to train on
device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
# torch.cuda.set_device raises ValueError for a non-CUDA device, so only call it when
# the CUDA branch above was taken; otherwise the CPU fallback would crash here.
if device.type == 'cuda':
    torch.cuda.set_device(device)
# The model was moved to a device in an earlier cell, before args.gpu was known.
# Re-home it on the device chosen here so parameters and training batches agree.
model.to(device)
print(device)

1
cuda:0


In [9]:
from dataset import CocoDetection, train_transforms, val_transforms, test_transforms

# Both splits read images from the same root; they differ only in the COCO-format
# annotation file and in the transform pipeline (test uses the validation transforms).
DATA_ROOT = '/home/Dataset/scl/'

train_dataset = CocoDetection(
    transforms=train_transforms,
    annFile='../../data/train.json',
    root=DATA_ROOT,
)
test_dataset = CocoDetection(
    transforms=val_transforms,
    annFile='../../data/test.json',
    root=DATA_ROOT,
)

loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


In [10]:
# Peek at the first training sample; iterating a map-style dataset starts at index 0.
image, target = train_dataset[0]
target
                    

{'boxes': tensor([[667.1875, 472.5000, 743.3125, 542.9375]]),
 'category_id': tensor([1]),
 'labels': tensor([1]),
 'image_id': tensor([1]),
 'area': tensor([5362.0547]),
 'iscrowd': tensor([0])}

In [11]:
# Shuffle the training split every pass; keep the evaluation order fixed.
train_sampler = torch.utils.data.RandomSampler(train_dataset)
test_sampler = torch.utils.data.SequentialSampler(test_dataset)

# Settings shared by both loaders. utils.collate_fn replaces the default collate,
# presumably because per-image detection targets (variable box counts) cannot be
# stacked into a single tensor — confirm in utils.
loader_kwargs = dict(batch_size=args.batch_size,
                     num_workers=args.workers,
                     collate_fn=utils.collate_fn)

train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_kwargs)
test_loader = DataLoader(test_dataset, sampler=test_sampler, **loader_kwargs)

In [12]:
# Optimize only parameters that are still trainable (backbone layers may have been
# frozen via trainable_backbone_layers).
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params,
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
# Halve the learning rate at each milestone (counted in scheduler.step() calls).
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, gamma=0.5, milestones=[15, 30, 45, 60, 75])

In [None]:
from engine import train_one_epoch

# Periodic checkpoint / evaluation cadence. The final epoch is always checkpointed
# and evaluated. With epochs=90, the old `epoch % 5 == 0` conditions skipped epoch 89,
# so the fully trained weights were never saved and never evaluated.
CKPT_AFTER_EPOCH = 60   # periodic checkpoints only once training is well underway
EVAL_AFTER_EPOCH = 5
EVERY_N_EPOCHS = 5

start_time = time.time()
final_epoch = args.epochs - 1
# NOTE: a non-zero args.start_epoch only makes sense after restoring model/optimizer/
# lr_scheduler state from a checkpoint; with the default of 0 this is a full run.
for epoch in range(args.start_epoch, args.epochs):
    train_one_epoch(model, optimizer, train_loader, device, epoch, args.print_freq)
    lr_scheduler.step()

    is_final = epoch == final_epoch
    on_period = epoch % EVERY_N_EPOCHS == 0

    if args.output_dir and ((epoch > CKPT_AFTER_EPOCH and on_period) or is_final):
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'args': args,
            'epoch': epoch
        }
        # Per-epoch snapshot, plus 'checkpoint.pth', which is overwritten with the latest state.
        utils.save_on_master(
            checkpoint,
            os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
        utils.save_on_master(
            checkpoint,
            os.path.join(args.output_dir, 'checkpoint.pth'))

    # Evaluate on the test split every EVERY_N_EPOCHS epochs after EVAL_AFTER_EPOCH,
    # and once more after the final epoch.
    if (epoch > EVAL_AFTER_EPOCH and on_period) or is_final:
        evaluate(model, test_loader, device=device)
print('total time is {}'.format(time.time() - start_time))

Epoch: [0]  [   0/3098]  eta: 1:39:47  lr: 0.000004  loss: 3.0876 (3.0876)  classification: 2.3525 (2.3525)  bbox_regression: 0.7351 (0.7351)  time: 1.9328  data: 1.0108  max mem: 7460
Epoch: [0]  [1000/3098]  eta: 0:18:23  lr: 0.002000  loss: 1.6415 (2.1396)  classification: 1.0110 (1.5201)  bbox_regression: 0.5783 (0.6195)  time: 0.5292  data: 0.0047  max mem: 7789
Epoch: [0]  [2000/3098]  eta: 0:09:39  lr: 0.002000  loss: 1.4683 (1.9623)  classification: 0.9360 (1.3566)  bbox_regression: 0.5172 (0.6057)  time: 0.5285  data: 0.0048  max mem: 7789


In [None]:
# Display the detector's module tree via its nn.Module repr.
model