In [1]:
import os
import random

from collections import OrderedDict
import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.figsize'] = [5, 5]
matplotlib.rcParams['figure.dpi'] = 200

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torchvision.models as models

import cv2
#from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label


from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box, compute_ts_road_map
from darknet import *
from util import *

In [3]:
########### helper functions ##########


def sew_images(sing_samp):
    """Tile six camera views into one 2 x 3 mosaic and return it as a PIL image.

    Layout of the result (indices refer to the first axis of ``sing_samp``)::

        #############
        # 0 | 1 | 2 #
        # 3 | 4 | 5 #
        #############

    Args:
        sing_samp: one sample holding six views, indexable as
            ``sing_samp[view]`` with every view a ``[3, H, W]`` tensor
            (e.g. a stacked tensor of shape ``[6, 3, 256, 306]``).

    Returns:
        The object produced by ``transforms.ToPILImage()`` for the combined
        ``[3, 2*H, 3*W]`` tensor; for 256 x 306 inputs that is a mosaic
        512 pixels tall and 918 pixels wide.

    Raises:
        IndexError: if ``sing_samp`` holds fewer than six views.
    """
    # Concatenate whole [3, H, W] views instead of handling each colour
    # channel by hand; the per-channel version copied channel 1 of the third
    # view into its channel 2 slot, corrupting that view's colours.
    top_row = torch.cat([sing_samp[0], sing_samp[1], sing_samp[2]], dim=2)     # [3, H, 3W]
    bottom_row = torch.cat([sing_samp[3], sing_samp[4], sing_samp[5]], dim=2)  # [3, H, 3W]

    # Stack the two rows vertically -> [3, 2H, 3W].
    comb = torch.cat([top_row, bottom_row], dim=1)

    # TODO: maybe should flip and face right
    toImg = transforms.ToPILImage()
    return toImg(comb)
    
############
## usage:
# image_tensor = torch.stack(images)   # should be [6, 3, 256, 306]
# m = transforms.Resize((800,800))
# comb_img = sew_images(image_tensor)  # PIL image: 3 views wide x 2 views tall (918 x 512 px)
# img = m(comb_img)                    # PIL image resized to 800 x 800 (not a tensor yet)
############

In [5]:
# Seed every random number generator used here so reruns are repeatable.
torch.manual_seed(0);
np.random.seed(0)
random.seed(0)

# Locations of the images and of the annotation table.
image_folder = "../data"
annotation_csv = "../data/annotation.csv"

# Scene ids 106..133 are passed below as `scene_index` for the labeled data;
# only a small slice of them is used for the train and validation sets.
labeled_scene_index = np.arange(106, 134)
train_index = np.arange(106, 108)
val_index = np.arange(128, 130)

transform = torchvision.transforms.ToTensor()

# The two datasets are built identically and differ only in their scene ids.
labeled_trainset, labeled_valset = (
    LabeledDataset(
        image_folder=image_folder,
        annotation_file=annotation_csv,
        scene_index=scenes,
        transform=transform,
        extra_info=False,
    )
    for scenes in (train_index, val_index)
)

# Both loaders share the same batching, shuffling and collation settings.
trainloader, valloader = (
    torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        num_workers=2,
        collate_fn=collate_fn,
    )
    for dataset in (labeled_trainset, labeled_valset)
)

In [11]:
#### initialize model ####
# Batching and detection-threshold settings.
batch_size, start = 2, 0
confidence, nms_thesh = 0.5, 0.4

# Use the GPU when one is visible, otherwise fall back to the CPU.
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if CUDA else "cpu")

# Build the network from its config file and move it to the chosen device.
model = Darknet('yolov3.cfg')
model = model.to(device)
model.net_info["height"] = 800

criterion = torch.nn.MSELoss()

# Only parameters that still require gradients are handed to the optimizer.
optimizer = torch.optim.SGD(
    [
        {
            'params': (p for p in model.parameters() if p.requires_grad),
            'lr': 1e-4,
        }
    ],
    lr=1e-4,
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=False,
)

# NOTE(review): 100 presumably acts as a "worse than anything" starting value
# for the best validation loss; confirm real losses stay below it.
best_val_loss = 100

epochs = 10