In [1]:
import torch
import torch.nn.parallel
import torch.utils
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torchvision.transforms import ToPILImage
import time
from PIL import Image
import dataGenerator
import math
import os
import argparse
import sys
from datetime import datetime

# Define DataLoader

In [2]:
# default values
NUM_CLASSES = 1
EPOCHS = 200
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batchSize = 3
iteration = "1"
newTraining = True
imageSize = 416
cwd = os.getcwd()
SAVE_LOCATION = cwd + "/data/models/model_test"
LOAD_LOCATION = cwd + "/data/input/"
#data_dir = cwd + "/data/photos/"
data_dir = "./data/"

from torchvision import datasets

class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """

    # override the __getitem__ method. this is the method dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns 
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

# Load Data

In [3]:
data_transforms = transforms.Compose([transforms.Resize([imageSize,imageSize]),
                                      transforms.ToTensor()
                                     ])

# instantiate the dataset and dataloader
dataset = ImageFolderWithPaths(data_dir, transform=data_transforms) # our custom dataset

#loads only photos
dataloaders = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle=True)

new_road_factory = dataGenerator.dataGenerator(IMAGE_SIZE=imageSize)

# Training W/O Val Steps

In [4]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=4):
    since = time.time()
    best_model = None
    best_loss = math.inf
    model.train()  
    for epoch in range(1,num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1), flush=True)
        epoch_loss = 0
   
        #BATCH TUPLE
        inputs, labels, paths = next(iter(dataloaders))
        inputs.to(device)
                
        #build ground-truth batch tensor
        for locations in paths:
            i = 0
            #dtype=torch.int64
            labels = torch.zeros(batchSize,NUM_CLASSES,imageSize,imageSize, dtype = torch.float32).to(device)
            labels[i] = torch.load(locations.replace(".png", ".pt").replace("photos", "tensors")) 
            i += 1
            
        # forward

        with torch.set_grad_enabled(True):
            #build input-truth batch tensor
            outputs = model(inputs.to(device)).to(device)
            loss = criterion(outputs, labels) #ground truth comparison

            # zero the parameter gradients
            optimizer.zero_grad()
            
            # backward + optimize 
            loss.backward()
            optimizer.step()
            
            # statistics
            epoch_loss += loss.item()

        print("loss: {}".format(epoch_loss), flush=True)
        print('---------------', flush=True)
        
        #save best copy of  model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), SAVE_LOCATION.replace("test", "best") + "-" + str(datetime.now().date()))
            #torch.save(model, SAVE_LOCATION)
            #torch.save(model, SAVE_LOCATION.replace("model", "model_best"))
        
    time_elapsed = time.time() - since
    
    print('Training complete in {:.1f}m {:.1f}s'.format(time_elapsed // 60, time_elapsed % 60), flush=True)

    #completed model
    torch.save(model,SAVE_LOCATION)
    return model

# Load Pretrained Model Weights

In [5]:
from unet_models import *

#imports related to UNet
if newTraining:
    model = UNet16(num_classes=1, num_filters=32,pretrained=False, is_deconv=True)
    
    print('initializing model with random weights', flush=True)
    torch.nn.init.xavier_uniform_(next(model.center.children())[1].weight)
    
    torch.nn.init.xavier_uniform_(next(model.dec5.children())[1].weight)
    
    torch.nn.init.xavier_uniform_(next(model.dec4.children())[1].weight)

    torch.nn.init.xavier_uniform_(next(model.dec3.children())[1].weight)

    torch.nn.init.xavier_uniform_(next(model.dec2.children())[1].weight)

    torch.nn.init.xavier_uniform_(next(model.dec1.children()).weight)

    torch.nn.init.xavier_uniform_(model.final.weight)
                           
else:
    print("loading weights from", SAVE_LOCATION, flush=True)
    model = torch.load(SAVE_LOCATION)
    #model = torch.load(LOAD_LOCATION)

model = model.to(device)

initializing model with random weights


# Training and Results

In [6]:
criterion = torch.nn.BCEWithLogitsLoss()
#criterion = DICELossMultiClass()
#criterion = IOU_BCELoss()

#Observe adjustments in learning rate
optimizer_ft = optim.Adam(model.parameters(), lr=0.05,weight_decay=0, amsgrad=False, eps=0.1)

# Osscilate between high and low learning rates over time
exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(optimizer_ft, T_max=EPOCHS,eta_min=0.001)

try:
    model = train_model(model, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=EPOCHS)

except KeyboardInterrupt:
    torch.save(model, SAVE_LOCATION.replace("model",'INTERRUPTED'))
    print('Saved interrupt', flush=True)

Epoch 1/199
loss: 0.6917615532875061
---------------
Epoch 2/199
loss: 0.6713278889656067
---------------
Epoch 3/199
loss: 0.6325186491012573
---------------
Epoch 4/199
loss: 0.6196605563163757
---------------
Epoch 5/199
loss: 0.5703428983688354
---------------
Epoch 6/199
loss: 0.5609607696533203
---------------
Epoch 7/199
loss: 0.5559383034706116
---------------
Epoch 8/199
loss: 0.5312397480010986
---------------
Epoch 9/199
loss: 0.5365355014801025
---------------
Epoch 10/199
loss: 0.47386443614959717
---------------
Epoch 11/199
loss: 0.516312301158905
---------------
Epoch 12/199
loss: 0.4716586768627167
---------------
Epoch 13/199
loss: 0.5389310717582703
---------------
Epoch 14/199
loss: 0.42445486783981323
---------------
Epoch 15/199
loss: 0.4824327230453491
---------------
Epoch 16/199
loss: 0.4066139757633209
---------------
Epoch 17/199
loss: 0.4568275809288025
---------------
Epoch 18/199
loss: 0.4975232183933258
---------------
Epoch 19/199
loss: 0.519110262393951

loss: 0.4478919804096222
---------------
Epoch 152/199
loss: 0.43676286935806274
---------------
Epoch 153/199
loss: 0.41269710659980774
---------------
Epoch 154/199
loss: 0.4429226815700531
---------------
Epoch 155/199
loss: 0.4784122109413147
---------------
Epoch 156/199
loss: 0.47655564546585083
---------------
Epoch 157/199
loss: 0.3828109800815582
---------------
Epoch 158/199
loss: 0.5004972219467163
---------------
Epoch 159/199
loss: 0.42439398169517517
---------------
Epoch 160/199
loss: 0.49712568521499634
---------------
Epoch 161/199
loss: 0.5008724331855774
---------------
Epoch 162/199
loss: 0.4684484302997589
---------------
Epoch 163/199
loss: 0.41618672013282776
---------------
Epoch 164/199
loss: 0.47222277522087097
---------------
Epoch 165/199
loss: 0.4366042912006378
---------------
Epoch 166/199
loss: 0.4745142161846161
---------------
Epoch 167/199
loss: 0.39068400859832764
---------------
Epoch 168/199
loss: 0.4390754699707031
---------------
Epoch 169/199
lo