In [1]:
import time
import os

import torch

import pytorch_mask_rcnn as pmr
    
# ------------------ adjustable parameters ---------------------
use_cuda = True                              # choose to use GPU or not
train_num_samples = 1463                     # number of samples used for training, between 1 and 1463
# Where the checkpoint is saved to / resumed from. Set the PMR_CKPT_PATH
# environment variable to override without editing the notebook.
ckpt_path = os.environ.get('PMR_CKPT_PATH', '../checkpoint.pth')
# VOC2012 dataset directory. The default is a machine-specific absolute path;
# set the VOC2012_DIR environment variable to point at your local copy instead.
data_dir = os.environ.get('VOC2012_DIR', 'E:/PyTorch/data/VOC2012')
# ------------------ adjustable parameters ---------------------

# Fall back to the CPU when CUDA is unavailable, even if use_cuda is True.
device = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu')
print('cuda: {}\nuse_cuda: {}\n{} GPU(s) available'.format(torch.cuda.is_available(), use_cuda, torch.cuda.device_count()))
print('\ndevice: {}'.format(device))

# Seed torch's global RNG before any random draw. This makes both the training
# subset/order picked by randperm and the model's weight initialisation
# reproducible across kernel restarts. Seeding only after randperm would leave
# the sample order random on every run.
torch.manual_seed(3)

# NOTE(review): the positional `True` presumably selects training-mode behaviour
# inside pmr.VOCDataset - confirm against the library signature.
full_trainset = pmr.VOCDataset(data_dir, 'train', True, device=device)
if not 1 <= train_num_samples <= len(full_trainset):
    # Fail early. 0 would otherwise crash later in the training cell, and
    # oversized values would be silently truncated by the slice below.
    raise ValueError('train_num_samples must be between 1 and {}, got {}'.format(
        len(full_trainset), train_num_samples))
indices = torch.randperm(len(full_trainset)).tolist()
trainset = torch.utils.data.Subset(full_trainset, indices[:train_num_samples])

# NOTE(review): presumably (pretrained=False, num_classes=21 = 20 VOC classes +
# background) - confirm against pmr.maskrcnn_resnet50.
model = pmr.maskrcnn_resnet50(False, 21).to(device)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)

cuda: True
use_cuda: True
1 GPU(s) available

device: cuda


In [2]:
# Resume model/optimizer state from a previous run if a checkpoint exists;
# otherwise start a fresh batch counter at zero.
if os.path.exists(ckpt_path):
    # map_location loads the tensors onto the device chosen for this run. Without
    # it, a checkpoint saved from the GPU fails to load on a CPU-only run, and is
    # needlessly materialised on the GPU when use_cuda is False.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # The state now lives in model/optimizer. Drop the checkpoint's copies so only
    # the bookkeeping entries (e.g. 'num_batches') are kept around.
    del checkpoint['model_state_dict']
    del checkpoint['optimizer_state_dict']
    torch.cuda.empty_cache()
else:
    checkpoint = dict(num_batches=0)
print('already trained: {}'.format(checkpoint['num_batches']))

since = time.time()
# ------------------train---------------------
model.train()
# Count steps explicitly rather than reusing the loop index afterwards. If
# trainset were empty, `i` would be undefined on a fresh kernel, or silently
# stale from an earlier run of this cell.
num_batches_this_run = 0
for i, data in enumerate(trainset):
    optimizer.zero_grad()
    # In train mode the model returns a mapping of scalar loss tensors.
    losses = model(*data)
    if i % 100 == 0:
        print(' '.join(str(round(loss_term.item(), 3)) for loss_term in losses.values()))
    loss = sum(losses.values())
    loss.backward()
    optimizer.step()
    num_batches_this_run += 1
# ------------------train---------------------
print('total time: {:.2f} s'.format(time.time() - since))

# Persist weights, optimizer state and the cumulative step count so the next run
# of the resume cell can continue from here.
checkpoint['model_state_dict'] = model.state_dict()
checkpoint['optimizer_state_dict'] = optimizer.state_dict()
checkpoint['num_batches'] += num_batches_this_run
torch.save(checkpoint, ckpt_path)

num_batches = checkpoint['num_batches']
# Release the checkpoint's references to the state dicts before clearing the cache.
del checkpoint
torch.cuda.empty_cache()

print('\nalready trained: {}'.format(num_batches))

already trained: 0
total time: 764.59 s

already trained: 1463
