In [1]:
import torch
import torchvision
from models import DeepLabV3
from datasets import gta5dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision import transforms as T
from torchvision.utils import make_grid
import torch.nn as nn
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import time
import os
from tqdm.notebook import tqdm
from utils import colorize_mask
from PIL import Image

In [2]:
# Data
dataset = gta5dataset("./datas/")
dataloader = DataLoader(dataset,
                        shuffle=True,
                        batch_size=4,
                        num_workers=0,
                        pin_memory=False)

# Fall back to CPU so the notebook still runs (slowly) on machines without CUDA,
# instead of crashing on a hard-coded "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DeepLabV3().to(device)

# range(31) ends at epoch 30, so the `epoch % 10 == 0` checkpoint rule in the
# training cell also saves the last epoch.
epochs = 31

# Optimizer / LR schedule over all trainable parameters.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=1e-4)
# NOTE(review): the training cell calls lr_scheduler.step() once per epoch, so with
# step_size=5, gamma=0.1 the lr is 1e-5 from epoch 15 on and 1e-8 in epoch 30.
# Updates in the later epochs are likely negligible; consider a larger step_size
# or a poly/cosine schedule.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# loss
def compute_loss(output,target):
    """Cross-entropy between class scores ``output`` and integer labels ``target``.

    Target entries equal to 255 are excluded from the loss. The remaining
    entries are averaged (the default ``reduction='mean'``).
    """
    return nn.functional.cross_entropy(output, target, ignore_index=255)

# output dir
# One directory per run, named by its local start time.
run_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
output_dir = os.path.join("./train-runs", run_time)
# Create the run directory up front. The previous code tested the inverted
# condition (`if os.path.exists`) and called the non-existent `os.makedir`.
# exist_ok guards against a re-run within the same second.
os.makedirs(output_dir, exist_ok=True)
# Tensorboard
writer = SummaryWriter(os.path.join(output_dir, "log"))

# visualize
def visualize(im,gt,pred):
    """Build an [input | ground truth | prediction] image grid for TensorBoard.

    Only the first sample of each batch is rendered.

    Args:
        im: batch of input images. Element 0 is permuted (C, H, W) -> (H, W, C)
            and cast to uint8. NOTE(review): this presumably expects pixel values
            in [0, 255]; normalized floats would be truncated. Confirm against
            gta5dataset.
        gt: batch of label maps. Element 0 is colorized with ``colorize_mask``.
        pred: batch of per-class scores. The argmax over dim 0 of element 0 is
            colorized with ``colorize_mask``.

    Returns:
        The ``make_grid`` tensor with the three panels in one row, 5 px padding.
    """
    to_tensor = T.Compose([T.ToTensor()])

    def _mask_panel(label_map):
        # Colorize a label map and force RGB so ToTensor yields a 3-channel panel.
        return to_tensor(colorize_mask(label_map).convert('RGB'))

    image_hwc = im[0].data.cpu().permute(1, 2, 0).numpy().astype(np.uint8)
    image_panel = to_tensor(image_hwc)

    gt_panel = _mask_panel(gt[0].data.cpu().numpy())

    pred_labels = pred[0].data.max(dim=0).indices
    pred_panel = _mask_panel(pred_labels.cpu().numpy())

    panels = torch.stack([image_panel, gt_panel, pred_panel], 0)
    return make_grid(panels.cpu(), nrow=3, padding=5)



In [3]:
# Train loop

interval = 500  # log loss + a visualization grid every `interval` global steps
checkpoint_dir = os.path.join(output_dir, "checkpoint")

for epoch in range(0, epochs):
    # one epoch
    model.train()
    print(f"epoch {epoch} starts at {time.strftime(r'%Y-%m-%d_%H-%M-%S',time.localtime())}")
    # enumerate() hides len(dataloader), so pass total= explicitly; otherwise
    # tqdm renders only an indeterminate bar.
    for i, (_, image, label) in tqdm(enumerate(dataloader), total=len(dataloader)):
        output = model(image.to(device))['out']

        loss = compute_loss(output, label.to(device))
        # optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step = epoch * len(dataloader) + i
        if step % interval == 0:
            # .item() logs a plain float rather than the graph-attached tensor.
            writer.add_scalar("total_loss", loss.item(), step)
            grid = visualize(image, label, output)
            writer.add_image('img_gt_pred', grid, step)

    # TODO: validation pass (model.eval()) is not implemented yet.
    lr_scheduler.step()

    # Checkpoint every 10 epochs and always after the final epoch.
    if epoch % 10 == 0 or epoch == epochs - 1:
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(),
                   os.path.join(checkpoint_dir, f"deeplabv3-{epoch}.pth"))

# Make sure buffered TensorBoard events reach disk once training finishes.
writer.flush()


epoch 0 starts at 2022-09-22_22-44-46


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 1 starts at 2022-09-22_23-03-28


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 2 starts at 2022-09-22_23-22-08


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 3 starts at 2022-09-22_23-40-30


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 4 starts at 2022-09-22_23-57-58


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 5 starts at 2022-09-23_00-15-40


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 6 starts at 2022-09-23_00-33-07


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 7 starts at 2022-09-23_00-50-36


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 8 starts at 2022-09-23_01-08-40


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 9 starts at 2022-09-23_01-26-35


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 10 starts at 2022-09-23_01-44-31


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 11 starts at 2022-09-23_02-02-08


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 12 starts at 2022-09-23_02-19-36


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 13 starts at 2022-09-23_02-37-05


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 14 starts at 2022-09-23_02-54-44


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 15 starts at 2022-09-23_03-12-22


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 16 starts at 2022-09-23_03-30-12


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 17 starts at 2022-09-23_03-47-56


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 18 starts at 2022-09-23_04-05-25


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 19 starts at 2022-09-23_04-22-55


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 20 starts at 2022-09-23_04-40-24


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 21 starts at 2022-09-23_04-57-56


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 22 starts at 2022-09-23_05-15-26


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 23 starts at 2022-09-23_05-32-55


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 24 starts at 2022-09-23_05-50-40


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 25 starts at 2022-09-23_06-08-09


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 26 starts at 2022-09-23_06-25-40


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 27 starts at 2022-09-23_06-43-21


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 28 starts at 2022-09-23_07-00-56


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 29 starts at 2022-09-23_07-18-26


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


epoch 30 starts at 2022-09-23_07-36-03


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…


