# train

In [1]:
# Library and module imports
from pycocotools.coco import COCO
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from effdet import get_efficientdet_config, EfficientDet, DetBenchTrain
from effdet.efficientdet import HeadNet
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

import warnings
warnings.filterwarnings(action='ignore')

In [2]:
import random

CFG = {
    'NUM_CLASS': 34,
    'EPOCHS': 30,
    'ACCUMULATE': 4,   # batches per optimizer step (gradient accumulation in train_fn)
    'LR': 3e-4,
    'BATCH_SIZE': 8,
    'SEED': 41,
}
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs so training runs are repeatable.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable


# CFG['SEED'] was previously declared but never applied.
seed_everything(CFG['SEED'])


In [3]:
class CustomDataset(Dataset):
    """COCO-format detection dataset yielding samples for an effdet DetBenchTrain.

    Args:
        annotation: path to a COCO annotation JSON file.
        data_dir: directory that the annotation's ``file_name`` entries are relative to.
        transforms: optional albumentations pipeline called with ``image``, ``bboxes``
            (pascal_voc) and ``labels`` keyword arguments.
        max_retries: how many times to re-run ``transforms`` when it drops every box.
            Retrying only helps with random transforms; a deterministic pipeline gives
            the same result every time, which previously made this loop hang forever.
    """

    def __init__(self, annotation, data_dir, transforms=None, max_retries=10):
        super().__init__()
        self.data_dir = data_dir
        self.coco = COCO(annotation)
        self.predictions = {
            "images": self.coco.dataset["images"].copy(),
            "categories": self.coco.dataset["categories"].copy(),
            "annotations": None
        }
        self.transforms = transforms
        self.max_retries = max_retries

    def __getitem__(self, index: int):
        """Load image ``index`` and its annotations.

        Returns:
            ``(image, target, image_id)``. ``image`` is the float image scaled to
            [0, 1], or whatever ``transforms`` returns. ``target`` holds ``boxes``,
            ``labels``, ``image_id``, ``area`` and ``iscrowd``. When ``transforms`` is
            set, ``boxes`` are returned as [y1, x1, y2, x2] and ``labels`` as
            ``category_id + 1``; ``area`` and ``iscrowd`` are left untouched.
        """
        # NOTE(review): `index` is used directly as the COCO image id, which assumes
        # ids run 0..len(self)-1 -- TODO confirm against the annotation file.
        image_id = self.coco.getImgIds(imgIds=index)
        image_info = self.coco.loadImgs(image_id)[0]
        image = np.array(Image.open(os.path.join(self.data_dir, image_info['file_name'])).convert('RGB'))
        # Colour jitter runs on the uint8 image before scaling to [0, 1]. It is applied
        # to every sample, independent of `transforms`.
        image = A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3, p=0.8)(image=image)['image']
        image = image.astype(np.float32) / 255.

        ann_ids = self.coco.getAnnIds(imgIds=image_info['id'])
        anns = self.coco.loadAnns(ann_ids)
        # COCO [x, y, w, h] -> [x1, y1, x2, y2]. reshape(-1, 4) keeps a (0, 4) array for
        # images without annotations; a 1-D empty array would break the column indexing.
        boxes = np.array([x['bbox'] for x in anns]).reshape(-1, 4)
        boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
        boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
        labels = torch.as_tensor(np.array([x['category_id'] for x in anns]), dtype=torch.int64)
        areas = torch.as_tensor(np.array([x['area'] for x in anns]), dtype=torch.float32)
        is_crowds = torch.as_tensor(np.array([x['iscrowd'] for x in anns]), dtype=torch.int64)
        target = {'boxes': boxes, 'labels': labels, 'image_id': torch.tensor([index]), 'area': areas,
                  'iscrowd': is_crowds}

        if self.transforms:
            sample = None
            for _ in range(max(1, self.max_retries)):
                # Labels are shifted by +1 before being handed to the model.
                sample = self.transforms(image=image, bboxes=target['boxes'], labels=labels + 1)
                if len(sample['bboxes']) > 0:
                    break
            image = sample['image']
            if len(sample['bboxes']) > 0:
                boxes_xyxy = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)
                # Reorder columns x1,y1,x2,y2 -> y1,x1,y2,x2 (the yxyx layout used for effdet).
                target['boxes'] = boxes_xyxy[:, [1, 0, 3, 2]]
                target['labels'] = torch.tensor(sample['labels'])
            else:
                # No box survived (or the image had none): return an empty, well-shaped
                # target instead of retrying forever.
                target['boxes'] = torch.zeros((0, 4), dtype=torch.float32)
                target['labels'] = torch.zeros((0,), dtype=torch.int64)
        return image, target, image_id

    def __len__(self) -> int:
        """Number of images in the annotation file."""
        return len(self.coco.getImgIds())

In [4]:
def train_transform():
    """Build the training pipeline: resize to 512x512, then convert to a tensor.

    Boxes are expected in ``pascal_voc`` format ([x_min, y_min, x_max, y_max] in
    pixels), and their class ids travel in the ``labels`` field.
    """
    bbox_cfg = {'format': 'pascal_voc', 'label_fields': ['labels']}
    steps = [
        A.Resize(512, 512),
        # Colour jitter is applied earlier, inside CustomDataset.__getitem__.
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, bbox_params=bbox_cfg)

def valid_transform():
    """Build the evaluation pipeline: resize to 512x512 and convert to a tensor (no augmentation).

    Boxes use ``pascal_voc`` format, with class ids in the ``labels`` field.
    """
    bbox_cfg = {'format': 'pascal_voc', 'label_fields': ['labels']}
    steps = [
        A.Resize(512, 512),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, bbox_params=bbox_cfg)

In [5]:
# Loss tracking
class Averager:
    """Accumulate scalar values (e.g. losses) and report their running mean."""

    def __init__(self):
        self.reset()

    def send(self, value):
        """Record one observation."""
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        """Mean of all values sent since the last reset; 0 if nothing was sent."""
        if not self.iterations:
            return 0
        return 1.0 * self.current_total / self.iterations

    def reset(self):
        """Forget every recorded observation."""
        self.current_total = 0.0
        self.iterations = 0.0

def collate_fn(batch):
    """Transpose a list of per-sample tuples into a tuple of per-field tuples.

    Per-sample targets of different sizes are kept as tuples rather than stacked.
    """
    fields = zip(*batch)
    return tuple(fields)

In [6]:
# https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/config/model_config.py
def get_net(checkpoint_path=None):
    """Build a tf_efficientdet_d0 detector with CFG['NUM_CLASS'] classes, wrapped in DetBenchTrain.

    Args:
        checkpoint_path: optional checkpoint to resume from. Two layouts are accepted:
            a dict with a ``'model_state_dict'`` entry holding EfficientDet weights, or
            a plain DetBenchTrain state dict such as the ``epoch_{n}.pth`` files that
            ``train_fn`` writes.

    Returns:
        The DetBenchTrain-wrapped model, on CPU. The caller is expected to ``.to(device)`` it.
    """
    config = get_efficientdet_config('tf_efficientdet_d0')
    config.num_classes = CFG['NUM_CLASS']
    config.image_size = (512, 512)

    config.soft_nms = False
    config.max_det_per_image = 25

    net = EfficientDet(config, pretrained_backbone=True)
    net.class_net = HeadNet(config, num_outputs=config.num_classes)
    bench = DetBenchTrain(net)

    if checkpoint_path:
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            # Raw EfficientDet weights; bench.model is `net`, so loading here updates the bench.
            net.load_state_dict(checkpoint['model_state_dict'])
        else:
            # Full bench state dict (keys prefixed with 'model.'), as saved by train_fn.
            bench.load_state_dict(checkpoint)

    return bench
    
# train function
def train_fn(num_epochs, train_loader, optimizer, scheduler, model, device, clip=35, ckpt_dir='./ckp'):
    """Train ``model`` with gradient accumulation and save a checkpoint after every epoch.

    Args:
        num_epochs: number of passes over ``train_loader``.
        train_loader: yields ``(images, targets, _)`` batches, where ``images`` is a
            sequence of image tensors and each target dict has ``'boxes'`` and ``'labels'``.
        optimizer: stepped once every ``CFG['ACCUMULATE']`` batches.
        scheduler: stepped right after each optimizer step (not once per epoch).
        model: bench model called as ``model(images, {'bbox': ..., 'cls': ...})`` that
            returns a dict with ``'loss'``, ``'class_loss'`` and ``'box_loss'``.
        device: device each batch is moved to.
        clip: max gradient norm passed to ``clip_grad_norm_``.
        ckpt_dir: directory for the ``epoch_{n}.pth`` state dicts; created if missing,
            so an absent folder no longer crashes after a full epoch of training.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    accumulate = CFG['ACCUMULATE']
    model.train()
    step = 0
    for epoch in range(num_epochs):
        with tqdm(train_loader, unit='batch') as tepoch:
            tepoch.set_description(f'epoch {epoch+1}/{num_epochs}')
            for images, targets, _ in tepoch:
                # (batch, channels, height, width)
                images = torch.stack(images).to(device).float()
                boxes = [t['boxes'].to(device).float() for t in targets]
                labels = [t['labels'].to(device).float() for t in targets]

                # Read losses by key rather than relying on the output dict's ordering.
                output = model(images, {"bbox": boxes, "cls": labels})
                loss = output['loss']
                cls_loss = output['class_loss']
                box_loss = output['box_loss']

                # Scale so the accumulated gradient equals the mean over `accumulate` batches.
                (loss / accumulate).backward()

                step += 1
                if step % accumulate:
                    continue
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                tepoch.set_postfix({
                    'LR': round(scheduler.get_last_lr()[0], 6),
                    'loss': loss.item(),
                    'loss_bbox': box_loss.item(),
                    'loss_cls': cls_loss.item(),
                })

            torch.save(model.state_dict(), os.path.join(ckpt_dir, f'epoch_{epoch+1}.pth'))

In [7]:
annotation = './dataset/train.json'
data_dir = './dataset/train'
train_dataset = CustomDataset(annotation, data_dir, train_transform())
train_data_loader = DataLoader(
    train_dataset,
    batch_size=CFG['BATCH_SIZE'],
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

model = get_net()
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=CFG['LR'])

# train_fn steps the scheduler once per optimizer step (every CFG['ACCUMULATE'] batches),
# so T_0 is counted in optimizer steps. The first cosine cycle spans RESTART_EPOCHS epochs,
# and T_mult=2 doubles each later cycle. With 811 batches per epoch this reproduces
# the previously hard-coded T_0=405, but it now tracks the dataset size.
RESTART_EPOCHS = 2
scheduler = CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    eta_min=1e-6,
    T_0=max(1, len(train_data_loader) * RESTART_EPOCHS // CFG['ACCUMULATE']),
    T_mult=2,
)

loading annotations into memory...
Done (t=0.10s)
creating index...
index created!


In [8]:
# labels = []
# for images, targets, _ in tqdm(train_data_loader):
#     labels += [target['labels'].tolist() for target in targets]

In [9]:
# import itertools
# output = list(itertools.chain(*labels))
# np.unique(output,return_counts=True)

In [8]:
# Launch training; train_fn writes one checkpoint per epoch.
train_fn(
    num_epochs=CFG['EPOCHS'],
    train_loader=train_data_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    model=model,
    device=device,
)

epoch 1/30: 100%|██████████| 811/811 [15:25<00:00,  1.14s/batch, LR=0.000151, loss=1.16e+3, loss_bbox=0.00504, loss_cls=1.16e+3]
epoch 2/30: 100%|██████████| 811/811 [14:23<00:00,  1.06s/batch, LR=0.0003, loss=636, loss_bbox=0.00534, loss_cls=635]         
epoch 3/30: 100%|██████████| 811/811 [10:30<00:00,  1.29batch/s, LR=0.000256, loss=4.65, loss_bbox=0.00202, loss_cls=4.55]
epoch 4/30: 100%|██████████| 811/811 [10:34<00:00,  1.28batch/s, LR=0.00015, loss=1.74, loss_bbox=0.000885, loss_cls=1.7]  
epoch 5/30: 100%|██████████| 811/811 [10:32<00:00,  1.28batch/s, LR=4.5e-5, loss=1.53, loss_bbox=0.000527, loss_cls=1.5]   
epoch 6/30: 100%|██████████| 811/811 [11:57<00:00,  1.13batch/s, LR=0.0003, loss=1.69, loss_bbox=0.000446, loss_cls=1.67]
epoch 7/30: 100%|██████████| 811/811 [14:46<00:00,  1.09s/batch, LR=0.000288, loss=1.02, loss_bbox=0.000222, loss_cls=1.01]
epoch 8/30: 100%|██████████| 811/811 [09:49<00:00,  1.37batch/s, LR=0.000256, loss=0.908, loss_bbox=0.000315, loss_cls=0.893]


# validation

In [3]:
from effdet import DetBenchPredict
import gc

# Build the model from an effdet config and load checkpoint weights for inference.
def load_net(checkpoint_path, device, num_classes=34):
    """Create a tf_efficientdet_d0 DetBenchPredict and load ``checkpoint_path`` into it.

    Args:
        checkpoint_path: either a DetBenchTrain/DetBenchPredict state dict (as saved by
            ``train_fn``) or a dict with a ``'model_state_dict'`` entry of raw
            EfficientDet weights.
        device: device the returned model is moved to.
        num_classes: number of object classes the checkpoint was trained with
            (default 34, the previously hard-coded value).

    Returns:
        The model in eval mode, on ``device``.
    """
    config = get_efficientdet_config('tf_efficientdet_d0')
    config.num_classes = num_classes
    config.image_size = (512, 512)

    config.soft_nms = False
    config.max_det_per_image = 25

    net = EfficientDet(config, pretrained_backbone=False)
    net.class_net = HeadNet(config, num_outputs=config.num_classes)

    # Load on CPU first so GPU-saved checkpoints also work on CPU-only machines.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        net.load_state_dict(checkpoint['model_state_dict'])
        net = DetBenchPredict(net)
    else:
        # Bench state dict: keys are prefixed with 'model.', so load after wrapping.
        net = DetBenchPredict(net)
        net.load_state_dict(checkpoint)
    net.eval()

    return net.to(device)

In [4]:
def denormalize_box(box, orig_w=1920, orig_h=1080, input_size=512):
    """Rescale an ``[x1, y1, x2, y2]`` box from the square model input back to original pixels.

    Args:
        box: any iterable whose first four items are x1, y1, x2, y2 in
            ``input_size`` x ``input_size`` pixel space. Iterators are fine too:
            the input is materialised only once.
        orig_w, orig_h: original image width and height. The defaults reproduce the
            previously hard-coded 1920x1080.
        input_size: side length of the resized model input.

    Returns:
        Tuple ``(x1, y1, x2, y2)`` in original-image pixels.
    """
    x1, y1, x2, y2 = list(box)[:4]
    scale_x = orig_w / input_size
    scale_y = orig_h / input_size
    return x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y

In [13]:
from utils import nms
from glob import glob
class ValidDataset(Dataset):
    """Inference dataset over a list of image paths.

    Each item is ``(path, image)``. ``image`` is the file converted to RGB, scaled to
    float32 in [0, 1], then passed through ``transform``.
    """

    def __init__(self, img_list, transform):
        super().__init__()
        self.img_list = img_list
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        path = self.img_list[idx]
        pixels = np.asarray(Image.open(path).convert('RGB'), dtype=np.float32) / 255.0
        transformed = self.transform(image=pixels)
        return path, transformed['image']
# Declare augmentation pipelines with Albumentations (images only, no bbox handling).
def get_train_transform():
    """Resize to 512x512, flip at random with p=0.5, and convert to a tensor."""
    steps = [
        A.Resize(512, 512),
        A.Flip(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps)


def get_valid_transform():
    """Resize to 512x512 and convert to a tensor; no random augmentation."""
    steps = [
        A.Resize(512, 512),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps)
# NOTE(review): these images come from the *training* directory, so predictions here are
# a sanity check, not a held-out evaluation.
# sorted() makes the file order deterministic, and therefore valid_dataset[0] too; raw glob()
# order depends on the filesystem.
img_list = sorted(glob('./dataset/train/*.png'))
valid_dataset = ValidDataset(img_list, get_valid_transform())

In [19]:
checkpoint_path = './ckp/epoch_30.pth'
SCORE_THRESHOLD = 0.5  # detections scoring below this are discarded
NMS_THRESHOLD = 0.5    # last argument to utils.nms (presumably an IoU threshold -- TODO confirm)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = load_net(checkpoint_path, device)  # load_net already returns the model on `device`

# Sanity-check inference on the first dataset entry.
file_name, image = valid_dataset[0]
image = image.to(device).float()
with torch.no_grad():
    output = model(image.unsqueeze(0))

# Split each image's detection tensor by column: [:4] box, [4] score, [-1] label.
outputs = []
for out in output:
    det = out.detach().cpu().numpy()
    outputs.append({'boxes': det[:, :4], 'scores': det[:, 4], 'labels': det[:, -1]})

final_box = []
final_score = []
final_label = []
for per_image in outputs:
    # Rescale boxes from the 512x512 model input to original-image pixels before NMS.
    boxes = [list(denormalize_box(box)) for box in per_image['boxes']]
    scores = list(per_image['scores'])
    labels = list(per_image['labels'])
    picked_boxes, picked_score, picked_labels = nms(boxes, scores, labels, NMS_THRESHOLD)

    for box, score, label in zip(picked_boxes, picked_score, picked_labels):
        # `continue`, not `break`: don't assume nms() returns detections sorted by score,
        # otherwise one low-score box would hide every later high-score one.
        if score < SCORE_THRESHOLD:
            continue
        final_box.append(box)
        final_score.append(score)
        final_label.append(label)

In [25]:
from PIL import Image, ImageDraw

# Overlay the kept detections on the original-resolution image.
# convert('RGB') makes colour drawing behave for palette/greyscale PNGs too.
img = Image.open(file_name).convert('RGB')
draw = ImageDraw.Draw(img, "RGBA")

for box, score, label in zip(final_box, final_score, final_label):
    draw.rectangle(tuple(box), outline='red', width=1)
    # Labels come out of the detection tensor as floats; show them as integers.
    # NOTE(review): training fed `category_id + 1` to the model (labels+1 in
    # CustomDataset). If predictions use the same ids, subtract 1 to get the COCO
    # category_id -- TODO confirm.
    draw.text((box[0], box[1]), text=str(int(label)))
img.show()