In [1]:
# Mount Google Drive at /content/drive so the project folder and the datasets
# stored under MyDrive can be reached from this Colab runtime.
from google.colab import drive
# force_remount=True requests a fresh mount even when the drive is already
# mounted (e.g. when this cell is re-run in the same session).
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
import os

# Work from the project folder on Drive: the next cell imports the
# project-local modules (datasets, segnet, utils) straight from the current
# working directory, because the file structure there is flat.
os.chdir('/content/drive/MyDrive/SegNet')
print('Directory Changed')

Directory Changed


In [3]:

import sys

# this next import comes from within our file structure (our datasets script)
! pip install import-ipynb 
import import_ipynb
from datasets import DatasetTrain, DatasetVal # (this needs to be imported before torch, because cv2 needs to be imported before torch for some reason)
# import os

# Mount a file stored on Google Drive:
#from google.colab import drive
#drive.mount('/content/drive/MyDrive/Deeplabv3Flat') 
#Use code block at the top of page


# These commented-out lines are not needed on Colab: the file structure is flat.
#cwd = '/content/drive/MyDrive/Deeplabv3Flat'

#os.chdir('/content/drive/MyDrive/Deeplabv3Flat/model')
#sys.path.append("/content/drive/MyDrive/Deeplabv3Flat/model")
from segnet import SegNet

#os.chdir('/content/drive/MyDrive/Deeplabv3Flat/utils')
#sys.path.append("/content/drive/MyDrive/Deeplabv3Flat/utils")
from utils import add_weight_decay

#os.chdir(cwd)

import torch
import torch.utils.data
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import pickle
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import cv2

import time
from tqdm import tqdm


# NOTE: change model_id for every new training run. Checkpoint files are named
# after it, so reusing an id overwrites the log data of the earlier run.
model_id = "1"

# Constants
NUM_INPUT_CHANNELS = 3  # We may need to change this when doing greyscale baselines later.

num_epochs = 100
batch_size = 8  # Tried 3 -> 24 -> 16 -> 8; 24 and 16 gave out-of-memory (OOM) errors.
learning_rate = 0.0001

network = SegNet(model_id, project_dir="/content/drive/MyDrive/SegNet", input_channels=NUM_INPUT_CHANNELS).cuda()

# Optional warm start (Amir's suggestion): set resume = True to initialise the
# network from a previously saved checkpoint; leave it False to train from scratch.
# Only the model weights are restored -- the optimizer state, the epoch counter
# and the loss history all start fresh.
# NOTE(review): the checkpoint below belongs to model_3 (epoch 100), not to the
# current model_id -- confirm this is the intended starting point.
resume = False
resume_checkpoint_path = "/content/drive/MyDrive/SegNet/training_logs/model_3/checkpoints/model_3_epoch_100.pth"

if resume:
    # Load onto the CPU first: this works regardless of the device the
    # checkpoint was saved from and avoids holding a second copy of the weights
    # on the GPU. load_state_dict then copies the values into the CUDA parameters.
    state_dict = torch.load(resume_checkpoint_path, map_location="cpu")
    network.load_state_dict(state_dict)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


# Locations of the Cityscapes data and its metadata on Drive.
cityscapes_data_path = "/content/drive/MyDrive/Datasets/Cityscapes"
cityscapes_meta_path = cityscapes_data_path + "/meta"

train_dataset = DatasetTrain(cityscapes_data_path=cityscapes_data_path,
                             cityscapes_meta_path=cityscapes_meta_path)
val_dataset = DatasetVal(cityscapes_data_path=cityscapes_data_path,
                         cityscapes_meta_path=cityscapes_meta_path)

# num_workers=2: tried 1 and 4; with 4, a warning suggested a maximum of 2 so
# that the dataloaders do not freeze.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=2)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                         batch_size=batch_size, shuffle=False,
                                         num_workers=2)

# Ask the loaders for the batch count: with the default drop_last=False the
# final partial batch is kept, so int(len(dataset)/batch_size) would under-report
# by one whenever the dataset size is not a multiple of batch_size.
num_train_batches = len(train_loader)
num_val_batches = len(val_loader)
print ("num_train_batches:", num_train_batches)
print ("num_val_batches:", num_val_batches)

# add_weight_decay comes from the project's utils module; presumably it builds
# the optimizer parameter groups with the given L2 value -- see utils for details.
params = add_weight_decay(network, l2_value=0.0001)
optimizer = torch.optim.Adam(params, lr=learning_rate)

# Per-class weights for the loss, stored as a pickle next to the dataset metadata.
# NOTE: pickle.load can execute arbitrary code -- only load files this project produced.
with open(cityscapes_meta_path + "/class_weights.pkl", "rb") as file: # (binary mode is needed for python3)
    class_weights = np.array(pickle.load(file))
# float32 tensor on the GPU (the Variable wrapper is no longer required).
class_weights = torch.from_numpy(class_weights).float().cuda()

# loss function
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

def _save_loss_history(losses, split, model_dir):
    """Persist the per-epoch losses of one split as a pickle and as a PNG curve.

    Args:
        losses: list with one mean loss value per finished epoch.
        split: "train" or "val"; used in the file names and the plot title.
        model_dir: directory that receives epoch_losses_<split>.pkl / .png.
    """
    with open("%s/epoch_losses_%s.pkl" % (model_dir, split), "wb") as file:
        pickle.dump(losses, file)

    fig, ax = plt.subplots()
    ax.plot(losses, "k^")  # one marker per epoch ...
    ax.plot(losses, "k")   # ... joined by a line
    ax.set_ylabel("loss")
    ax.set_xlabel("epoch")
    ax.set_title("%s loss per epoch" % split)
    fig.savefig("%s/epoch_losses_%s.png" % (model_dir, split))
    plt.close(fig)  # free the figure; a new one is created every epoch


epoch_losses_train = []
epoch_losses_val = []
for epoch in range(num_epochs):
    print ("###########################")
    print ("######## NEW EPOCH ########")
    print ("###########################")
    print ("epoch: %d/%d" % (epoch+1, num_epochs))

    ############################################################################
    # train:
    ############################################################################
    network.train() # (set in training mode, this affects BatchNorm and dropout)
    batch_losses = []

    # Progress bar over the training batches.
    loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
    loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")

    for step, (imgs, label_imgs) in loop:
        step_start = time.time()

        imgs = imgs.cuda() # (shape: (batch_size, 3, img_h, img_w))
        label_imgs = label_imgs.long().cuda() # (shape: (batch_size, img_h, img_w))

        # The SegNet class gives two outputs; the loss uses the first one.
        predicted_tensor, softmaxed_tensor = network(imgs)

        # compute the loss:
        loss = loss_fn(predicted_tensor, label_imgs)

        # optimization step:
        optimizer.zero_grad() # (reset gradients)
        loss.backward() # (compute gradients)
        optimizer.step() # (perform optimization step)

        # Read the scalar once; it is reused for the epoch mean and the progress bar.
        loss_value = loss.item()
        batch_losses.append(loss_value)

        # Show loss and step duration in the progress bar. Printing the duration
        # on its own line every step floods the cell output and breaks the bar.
        loop.set_postfix(loss=loss_value, step_s=time.time() - step_start)

    epoch_loss = np.mean(batch_losses)
    epoch_losses_train.append(epoch_loss)
    print ("train loss: %g" % epoch_loss)
    _save_loss_history(epoch_losses_train, "train", network.model_dir)

    print ("####")

    ############################################################################
    # val:
    ############################################################################
    network.eval() # (set in evaluation mode, this affects BatchNorm and dropout)
    batch_losses = []
    with torch.no_grad(): # (no gradients are tracked during validation, which reduces memory consumption)
        for step, (imgs, label_imgs, img_ids) in enumerate(val_loader):
            imgs = imgs.cuda() # (shape: (batch_size, 3, img_h, img_w))
            label_imgs = label_imgs.long().cuda() # (shape: (batch_size, img_h, img_w))

            predicted_tensor, softmaxed_tensor = network(imgs)

            # compute the loss:
            loss = loss_fn(predicted_tensor, label_imgs)
            loss_value = loss.item()
            batch_losses.append(loss_value)

    epoch_loss = np.mean(batch_losses)
    epoch_losses_val.append(epoch_loss)
    print ("val loss: %g" % epoch_loss)
    _save_loss_history(epoch_losses_val, "val", network.model_dir)

    # save the model weights to disk (one checkpoint per epoch):
    checkpoint_path = "%s/model_%s_epoch_%d.pth" % (network.checkpoints_dir, model_id, epoch+1)
    torch.save(network.state_dict(), checkpoint_path)
    
    

Output hidden; open in https://colab.research.google.com to view.