In [1]:
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from dataset import TrainDataset

import torch
import torchvision
from torch.utils.data import DataLoader

In [2]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_model(num_classes, model_name):
    if model_name == 'FasterRCNN':
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        print(model, file=open('Faster-RCNN.txt', 'w'))

    elif model_name == 'SSD':
        model = torchvision.models.detection.ssd300_vgg16(pretrained=True, num_classes=num_classes)
        print(model, file=open('SSD.txt', 'w'))

    elif model_name == 'RetinaNet':
        # model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
        model = torchvision.models.detection.retinanet_resnet50_fpn(num_classes=num_classes)
        print(model, file=open('RetinaNet.txt', 'w'))
        
    else:
        raise "Model name error!"

    return model

def get_optimizer(model):
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=1e-4)
    return optimizer

In [3]:
def collate_fn(batch):
    return tuple(zip(*batch))

def load_model():
    # model = torch.load(os.path.join('model', 'faster-rcnn-1580.pkl'))
    model = torch.load(os.path.join('model', 'RetinaNet-2927.pkl'))
    return model

def train():
    model_name = 'RetinaNet'

    mat_path = os.path.join('..', 'data', "train answer", 'digitStruct.mat')
    print(f'mat_path: {mat_path}')
    image_dir = os.path.join('..', 'data', 'train')
    print(f'image_dir: {image_dir}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'device: {device}')

    dataset = TrainDataset(mat_path, image_dir)
    # data_loader = torch.utils.data.DataLoader(
    #     dataset, batch_size=2, shuffle=False, num_workers=4,
    #     collate_fn=collate_fn)
    data_loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

    # model = get_model(num_classes=11, model_name=model_name)
    model = load_model()
    model.train()
    model.to(device)

    optimizer = get_optimizer(model)

    best_losses = None

    epochs = 10
    print(f'epochs: {epochs}')
    for epoch in range(epochs):
        print(f'epoch {epoch} strat!')
        epoch_losses = None

        pbar = tqdm(data_loader)

        # print(f'len dataset: {len(dataset)}')
        for i, (images, targets) in enumerate(pbar):
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)

            loss_message = []
            for k in loss_dict.keys():
                loss_message.append(f'{k}: {loss_dict[k]:.4f}')

            losses = sum(loss for loss in loss_dict.values())
            if epoch_losses is None:
                epoch_losses = losses
            else:
                epoch_losses += losses

            loss_message.append(f'epoch_losses: {epoch_losses/i+1:.4f}')

            pbar.set_description(', '.join(loss_message))

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
        
        if best_losses is None or best_losses > epoch_losses:
            best_losses = epoch_losses
            torch.save(model, os.path.join('model', f'{model_name}-{best_losses:.0f}.pkl'))




In [4]:
train()

mat_path: ..\data\train answer\digitStruct.mat
image_dir: ..\data\train
device: cuda


  0%|          | 0/8351 [00:00<?, ?it/s]

epochs: 10
epoch 0 strat!


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
classification: 0.1061, bbox_regression: 0.2393, epoch_losses: 1.3460:  95%|█████████▍| 7897/8351 [56:05<03:19,  2.27it/s]