In [8]:
from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np

from torch.utils import data
#from datasets import VOCSegmentation, Cityscapes
#from utils import ext_transforms as et
#from metrics import StreamSegMetrics

import torch
import torch.nn as nn
#from utils.visualizer import Visualizer

from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import glob
import yaml
from datetime import datetime, timezone, timedelta


In [6]:
def get_harbor_idx(data_path, train = True, is_label = False, label_num = 5):
    if train:
        if is_label:
            classes = ['ship', 'container_truck', 'forklift', 'reach_stacker']
            image_path = glob(os.path.join(root, 'train', 'labeled_images', '*.jpg'))
            image_idx_list = list(map(lambda x : x.split('/')[-1].split('.')[0], image_path))
            train_idx = []
            valid_idx = []
            for c in classes:
                matched_idx = [i for i in image_idx_list if c in i]
                train_idx.extend(matched_idx[label_num:])
                valid_idx.extend(matched_idx[:label_num])
            return train_idx, valid_idx
        else:
            image_path = glob(os.path.join(root, 'train', 'unlabeled_images', '*.jpg'))
            train_idx = list(map(lambda x: x.split('/')[-1].split('.')[0], image_path))
            return train_idx
    else:
        image_path = glob(os.path.join(root, 'test', 'images', '*.jpg'))
        test_idx = list(map(lambda x: x.split('/')[-1].split('.')[0], image_path))
        return test_idx



class BuildDataLoader:
    def __init__(self, num_labels, dataset_path, batch_size):
        self.data_path = dataset_path
        self.im_size = [513, 513]
        self.crop_size = [430, 430]
        self.num_segments = 5
        self.scale_size = (0.5, 1.5)
        self.batch_size = batch_size
        self.train_l_idx, self.valid_l_idx = get_harbor_idx(self.data_path, train=True, is_label=True, label_num=num_labels)
        self.train_u_idx = get_harbor_idx(self.data_path, train=True, is_label=False)
        self.test_idx = get_harbor_idx(self.data_path, train=False)

        if num_labels == 0:  # using all data
            self.train_l_idx = self.train_u_idx

    def build(self, supervised=False):
        train_l_dataset = BuildDataset(self.data_path, self.train_l_idx,
                                       crop_size=self.crop_size, scale_size=self.scale_size,
                                       augmentation=True, train=True, is_label=True)
        train_u_dataset = BuildDataset(self.data_path, self.train_u_idx,
                                       crop_size=self.crop_size, scale_size=(1.0, 1.0),
                                       augmentation=False, train=True, is_label=False)
        valid_l_dataset = BuildDataset(self.data_path, self.valid_l_idx,
                                       crop_size=self.crop_size, scale_size=self.scale_size,
                                       augmentation=False, train=True, is_label=True)
        test_dataset    = BuildDataset(self.data_path, self.test_idx,
                                       crop_size=self.im_size, scale_size=(1.0, 1.0),
                                       augmentation=False, train=False, is_label=True)

        if supervised:  # no unlabelled dataset needed, double batch-size to match the same number of training samples
            self.batch_size = self.batch_size * 2

        num_samples = self.batch_size * 200  # for total 40k iterations with 200 epochs
        # num_samples = self.batch_size * 2
        train_l_loader = torch.utils.data.DataLoader(
            train_l_dataset,
            batch_size=self.batch_size,
            sampler=sampler.RandomSampler(data_source=train_l_dataset,
                                          replacement=True,
                                          num_samples=num_samples),
            drop_last=True,)


        valid_l_loader = torch.utils.data.DataLoader(
            valid_l_dataset,
            batch_size=self.batch_size,
            sampler=sampler.RandomSampler(data_source=valid_l_dataset,
                                          replacement=True,
                                          num_samples=num_samples),
            drop_last=True,)

        if not supervised:
            train_u_loader = torch.utils.data.DataLoader(
                train_u_dataset,
                batch_size=self.batch_size,
                sampler=sampler.RandomSampler(data_source=train_u_dataset,
                                              replacement=True,
                                              num_samples=num_samples),
                drop_last=True,)

        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
        )
        if supervised:
            return train_l_loader, valid_l_loader, test_loader
        else:
            return train_l_loader, train_u_loader, valid_l_loader, test_loader

In [None]:
import torch.utils.data as data
from collections import namedtuple

class Comp_dataset(data.Dataset):
    DatasetClass = namedtuple('DataClass', ['name', 'id'])
    classes = [
        DatasetClass('unlabeled', 0),
        DatasetClass('container_truck', 1),
        DatasetClass('forklift', 2),
        DatasetClass('reach_stacker', 3),
        DatasetClass('ship', 4)
    ]

    id_to_train_id = np.array([c.train_id for c in classes])

    def __init__(self, root, split = 'train', transform = None):
        self.root = os.path.expanduser(root)
        self.image_dir = os.path.join(self.root, 'data', split)

        self.transform = transform

        self.split = split

In [None]:
def get_dataset():
    train_transform = extcompose([])

    train_dataset = Comp_dataset(root=, split=, transform=)
    return train_dataset, val_dataset

In [9]:

# Root directory
PROJECT_DIR = os.path.dirname(__file__)

# Load config
config_path = os.path.join(PROJECT_DIR, 'config', 'train_config.yml')
config = load_yaml(config_path)

# Train Serial
kst = timezone(timedelta(hours=9))
train_serial = datetime.now(tz=kst).strftime("%Y%m%d_%H%M%S")

# Recorder directory
RECORDER_DIR = os.path.join(PROJECT_DIR, 'results', 'train', train_serial)
os.makedirs(RECORDER_DIR, exist_ok=True)

# Data directory
DATA_DIR = os.path.join(PROJECT_DIR, 'data', config['DIRECTORY']['dataset'])

# GPU
os.environ['CUDA_VISIBLE_DEVICES'] = config['gpu_id']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: {}".format(device))

num_classes = 5

# Setup random seed
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Setup dataloader
batch_size = 5

train_dataset, valid_dataset = get_dataset(config)
data_loader = BuildDataLoader(num_labels = 5, dataset_path = "", batch_size = 5)


train_l_loader, train_u_loader, valid_l_loader, _ = data_loader.build()
print("Train set: {}, train unsup set: {}, Val set: {}".format(len(train_l_loader), len(train_u_loader), len(valid_l_loader)))


####
#set up model (all models are 'constructed at network.modeling)
model = network.modeling.__dict__[config['model'](num_classes = num_classes, output_stride = config['output_stride'])]

#####
#set up metrics
metrics = 
    
#####
#set up optimizer
optimizer = torch.optim.SGD()  
if config['lr_policy'] == 'poly':
    scheduler = PolyLR
elif config['lr_policy'] == 'step':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config['step_size'], gamma=0.1)


######
#set up criterion
# criterion = utils.get_loss(config['loss_type'])
if config['loss_type'] == 'focal_loss':
    criterion = utils.FocalLoss(ignore_index = 255, size_average = True)
elif config['loss_type'] == 'cross_entropy':
    criterion = nn.CrossEntropyLoss(ignore_index=255, reduction = 'mean')

def save_ckpt(path):
    """save current model"""
    torch.save()
    print("Model saved as {}".format(path))

mkdir('checkpoints')

# Restore
best_score == 0.0
cur_itrs = 0
cur_epochs = 0
if config['ckpt'] not None and os.path.isfile(config['ckpt']): # if checkpoint is available
    checkpoint = torch.load(config['ckpt'], map_location = torch.device['cpu'])
    model.load_state_dict(checkpoint['model_state'])
    model = nn.DataParallel(model)
    model.to(device)
else:
    print("[!] Retrain")
    model = nn.DataParallel(model)
    model.to(device)


# train Loop



SyntaxError: invalid syntax (477171462.py, line 44)