In [2]:
from __future__ import division

import os
import socket
import timeit
from datetime import datetime
from tensorboardX import SummaryWriter

# PyTorch includes
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

# Custom includes
from dataloaders import davis_2016 as db
from dataloaders import custom_transforms as tr
from util import visualize as viz
import scipy.misc as sm
import networks.vgg_osvos as vo
from layers.osvos_layers import class_balanced_cross_entropy_loss
from dataloaders.helpers import *
from mypath import Path
import imageio

In [3]:
# Setting of parameters
# Sequence used for the test split; override it with the SEQ_NAME environment
# variable, otherwise 'blackswan' is used.
seq_name = os.environ.get('SEQ_NAME', 'blackswan')
db_root_dir = Path.db_root_dir()
save_dir = Path.save_root_dir()

# exist_ok=True removes the check-then-create race and keeps the cell re-runnable.
os.makedirs(save_dir, exist_ok=True)

gpu_id = 0
device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else "cpu")

# Build the network and restore previously saved weights.
# NOTE(review): the checkpoint path is fixed to the 'blackswan' weights even when
# SEQ_NAME selects another sequence -- confirm this is intended.
# NOTE(review): `device` is computed above but `net` is not moved to it in this cell.
net = vo.OSVOS(pretrained=0)
# map_location='cpu' remaps every stored tensor onto the CPU while loading;
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code.
net.load_state_dict(torch.load('models/blackswan_epoch-499.pth',
                               map_location='cpu',
                               weights_only=True))

# Augmentations applied to every training sample: random horizontal flip,
# random scale/rotation drawn from the given ranges, then conversion to tensors.
composed_transforms = transforms.Compose([
    tr.RandomHorizontalFlip(),
    tr.ScaleNRotate(rots=(-30, 30), scales=(.75, 1.25)),
    tr.ToTensor(),
])

# Training dataset and its iterator
db_train = db.DAVIS2016(train=True, db_root_dir=db_root_dir, transform=composed_transforms)
trainloader = DataLoader(db_train, batch_size=1, shuffle=True, num_workers=1)

# Testing dataset and its iterator (restricted to seq_name, no augmentation, fixed order)
db_test = db.DAVIS2016(train=False, db_root_dir=db_root_dir, transform=tr.ToTensor(), seq_name=seq_name)
testloader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1)

Constructing OSVOS architecture..
Initializing weights..
Done initializing ImageSets/480p/train Dataset
Done initializing ImageSets/480p/val Dataset


  net.load_state_dict(torch.load('models/blackswan_epoch-499.pth',


In [4]:
# Fetch one training example by index and unpack its two entries:
# the input under 'image' and the annotation under 'gt'.
sample = db_train[99]
img, gt = sample['image'], sample['gt']


JPEGImages/480p/bmx-bumps/00017.jpg
./DAVIS/JPEGImages/480p/bmx-bumps/00017.jpg
