In [1]:
import torch
import torch.nn as nn
import torchvision
import time
import os

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from torch.utils.data import Dataset, DataLoader
# from torch.utils.data.sampler import Sampler
import torch.optim as optim
import sys
sys.path.append('../')

from dataset import LbpDataset, train_transforms, val_transforms, test_transforms, collate_fn, get_data
from visualize import visualize
from rcnn_model import fasterrcnn_resnet18_fpn, fasterrcnn_resnet201_fpn, FastRCNNPredictor
from engine import evaluate
import utils

In [2]:
from train_lbp import get_train_test_list

In [3]:
import easydict

# Run configuration, wrapped in an EasyDict so values can be read as
# attributes (args.batch_size, args.workers, ...).
# NOTE(review): in the cells shown in this notebook only batch_size, workers
# and output_dir are read directly; the optimizer cell hard-codes its own
# lr/momentum/weight_decay and the training loop hard-codes its epoch count.
# The whole object is also stored inside saved checkpoints under 'args'.
args = easydict.EasyDict({
    'batch_size': 2,
    'epochs': 50,
    'data': 0,
    'lr': 0.1,
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'start_epoch': 0,
    'gpu': 5,
    'workers': 2,
    'output_dir': '../trained_model/',
})

In [4]:
# Load the annotation table and preview its first five rows.
df = pd.read_csv('../../data/df.csv')
df.head(n=5)
# Presumably a one-off maintenance step that added the ID column and rewrote
# the CSV in place; kept commented out for reference.
# df.insert(0, 'ID', range(0, len(df)))
# df.to_csv('../../data/df.csv', index=None)

Unnamed: 0,ID,file_name,task,bbox,xmin,ymin,w,h,label,occluded,des,cell_type
0,0,patch_images/2021.01.12/LBC305-20210108(1)/LBC...,[ASCUS] LBC305,"[56, 35, 1980, 1985]",56,35,1980,1985,판독불가,0,,
1,1,patch_images/2021.01.12/LBC305-20210108(1)/LBC...,[ASCUS] LBC305,"[56, 30, 1912, 1937]",56,30,1912,1937,판독불가,0,,
2,2,patch_images/2021.01.12/LBC305-20210108(1)/LBC...,[ASCUS] LBC305,"[21, 12, 2010, 2027]",21,12,2010,2027,판독불가,0,,
3,3,patch_images/2021.01.06/LBC37-20210102(1)/LBC3...,[ASCUS] LBC37,"[1349, 420, 100, 113]",1349,420,100,113,ASC-US,0,,Atypical squamous cells of undetermined signif...
4,4,patch_images/2021.01.06/LBC37-20210102(1)/LBC3...,[ASCUS] LBC37,"[1575, 720, 163, 213]",1575,720,163,213,ASC-US,0,,Atypical squamous cells of undetermined signif...


In [5]:
# Data loading code
data_dir = '../../data/df.csv'
train_list, test_list = get_train_test_list(data_dir)
train_dataset = LbpDataset(train_list, transform=train_transforms)
test_dataset = LbpDataset(test_list, transform=val_transforms)  

total 4019 train 3014 test 1005
3014
1005


In [6]:
# Training loader: samples are drawn in random order.
train_sampler = torch.utils.data.RandomSampler(train_dataset)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.workers,
    collate_fn=utils.collate_fn,
)

# Test loader: samples are visited in dataset order.
test_sampler = torch.utils.data.SequentialSampler(test_dataset)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=args.batch_size,
    sampler=test_sampler,
    num_workers=args.workers,
    collate_fn=utils.collate_fn,
)

In [7]:
# next(iter(train_loader))

In [8]:
# Build the detector and swap its box-predictor head for one sized to our
# label set.
# NOTE(review): num_classes = 2 presumably means background + one foreground
# class (the usual Faster R-CNN convention) -- confirm against the dataset.
num_classes = 2
device = torch.device('cuda')

model = fasterrcnn_resnet18_fpn(pretrained=False, min_size=2048, max_size=2048)

# Reuse the existing head's input width so the new predictor fits the
# features produced by the ROI heads.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# nn.Module.to moves the parameters in place and returns the same module.
model = model.to(device)
print('model is loaded to gpu')

model is loaded to gpu


In [9]:
# Optimise only the parameters that are trainable.
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# SGD with momentum. The hyper-parameters are hard-coded here and are not
# taken from `args`.
# Alternative tried earlier: torch.optim.Adam(params, lr=0.0001)
optimizer = optim.SGD(
    params,
    lr=0.001,
    momentum=0.9,
    weight_decay=1e-4,
)

# Halve the learning rate once each of these epochs has been reached.
lr_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[10, 15, 20, 25],
    gamma=0.5,
)

In [10]:
# checkpoint = torch.load('../trained_model/model.pt')
# model.load_state_dict(checkpoint['model'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
# epoch = checkpoint['epoch']
# print(epoch)

In [11]:
# evaluate(model, test_loader, device=device) 

In [12]:
from train_lbp import train_one_epoch

# Make sure the checkpoint directory exists before training starts.
# The previous guard was inverted (`if not args.output_dir: os.mkdir(...)`):
# it only attempted the mkdir when no directory was configured (which raises)
# and never created the directory when one was configured, so the first
# checkpoint save would fail if the directory was missing.
if args.output_dir and not os.path.isdir(args.output_dir):
    # makedirs also creates missing parents; exist_ok guards against a race.
    os.makedirs(args.output_dir, exist_ok=True)
    print('{} was made. '.format(args.output_dir))

start_time = time.time()
# NOTE(review): the epoch count is hard-coded to 30; args.epochs and
# args.start_epoch are not consulted here.
for epoch in range(30):
    # The trailing 400 appears to be the logging interval (the captured log
    # prints a line every 400 iterations) -- confirm in train_lbp.
    train_one_epoch(model, optimizer, train_loader, device, epoch, 400)
    lr_scheduler.step()

    # Checkpoint and evaluate on even epochs from epoch 2 onward.
    if epoch > 1 and epoch % 2 == 0:
        if args.output_dir:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'args': args,
                'epoch': epoch
            }
            # Keep a per-epoch snapshot ...
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
            # ... and overwrite the rolling "latest" checkpoint.
            utils.save_on_master(
                checkpoint,
                os.path.join(args.output_dir, 'checkpoint.pth'))

        # Evaluate on the test split. This sits inside the every-other-epoch
        # branch, so it does NOT run after every epoch.
        evaluate(model, test_loader, device=device)
print('total time is {}'.format(time.time() - start_time))

Epoch: [0]  [   0/1507]  eta: 0:33:03  lr: 0.001000  loss: 1.3549 (1.3549)  loss_classifier: 0.6799 (0.6799)  loss_box_reg: 0.0008 (0.0008)  loss_objectness: 0.6646 (0.6646)  loss_rpn_box_reg: 0.0095 (0.0095)  time: 1.3165  data: 0.8182  max mem: 6289
Epoch: [0]  [ 400/1507]  eta: 0:17:26  lr: 0.001000  loss: 0.1034 (0.1650)  loss_classifier: 0.0320 (0.0349)  loss_box_reg: 0.0177 (0.0113)  loss_objectness: 0.0331 (0.0586)  loss_rpn_box_reg: 0.0062 (0.0602)  time: 0.8884  data: 0.4096  max mem: 6451
Epoch: [0]  [ 800/1507]  eta: 0:10:53  lr: 0.001000  loss: 0.0755 (0.1430)  loss_classifier: 0.0298 (0.0346)  loss_box_reg: 0.0102 (0.0125)  loss_objectness: 0.0347 (0.0521)  loss_rpn_box_reg: 0.0045 (0.0437)  time: 0.9518  data: 0.4698  max mem: 6451
Epoch: [0]  [1200/1507]  eta: 0:04:50  lr: 0.001000  loss: 0.1203 (0.1363)  loss_classifier: 0.0357 (0.0356)  loss_box_reg: 0.0143 (0.0138)  loss_objectness: 0.0277 (0.0480)  loss_rpn_box_reg: 0.0056 (0.0390)  time: 1.3601  data: 0.8686  max me

In [13]:
# Snapshot the current training state (same layout as the checkpoints written
# inside the training loop): state dicts first, then run metadata.
checkpoint = {
    key: component.state_dict()
    for key, component in (
        ('model', model),
        ('optimizer', optimizer),
        ('lr_scheduler', lr_scheduler),
    )
}
# `epoch` is whatever value the training loop left behind when it finished.
checkpoint.update(args=args, epoch=epoch)

In [14]:
# torch.save(checkpoint, '../trained_model/model.pt')

In [15]:
# checkpoint = torch.load('../trained_model/model.pt')
# # checkpoint['model']

In [16]:
# model.load_state_dict(checkpoint['model'])