In [1]:
%load_ext autoreload
%autoreload 1

In [2]:
import os
from time import sleep

import matplotlib.pyplot as plt
import cv2
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import PIL.Image as Image

from dataset import *
from utilities import *
from resnet import *
from config import *



In [3]:
color_to_gray_map, gray_to_color_map = None, None

In [4]:
np.random.seed(0)
train_imageset_path = '../trainval/DAVIS/ImageSets/2017/train.txt'
val_imageset_path = '../trainval/DAVIS/ImageSets/2017/val.txt'
testd_imageset_path = '../testd/DAVIS/ImageSets/2017/test-dev.txt'
trainval_image_root = '../trainval/DAVIS/JPEGImages/480p/'
trainval_mask_root = '../trainval/DAVIS/Annotations/480p/'
testd_image_root = '../testd/DAVIS/JPEGImages/480p/'
testd_mask_root = '../testd/DAVIS/Annotations/480p/'
models_root = '../models/'

train_list = []
val_list = []
test_list = []

with open(train_imageset_path, 'r') as f:
    for line in f:
        train_list.append(line.strip())
with open(val_imageset_path, 'r') as f:
    for line in f:
        val_list.append(line.strip())
with open(testd_imageset_path, 'r') as f:
    for line in f:
        test_list.append(line.strip())


In [5]:
def train(image_root, mask_root, target_list):

    for t in range(len(target_list)):
        if t != 0:
            continue
        print(target_list[t])
        image_path = os.path.join(image_root, target_list[t] + '/00000.jpg')
        mask_path = os.path.join(mask_root, target_list[t] + '/00000.png')
        model_save_path = os.path.join(models_root, target_list[t] + '.pt')
    
        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path)
        PIL_mask = Image.open(mask_path)
        color_to_gray_map, gray_to_color_map = get_map(mask, PIL_mask)
        del PIL_mask
        
        image = cv2.resize(image, Resize, interpolation=cv2.INTER_NEAREST)
        mask = cv2.resize(mask, Resize, interpolation=cv2.INTER_NEAREST)
        mask = convert_to_gray_mask(mask, color_to_gray_map)
        print('type_cnt:', len(color_to_gray_map))
    
        model = MyResNet(len(color_to_gray_map)).to(device)
        train_dataset = CustomDataset(image_path, mask_path, image_transform=train_image_transforms, mask_transform=train_mask_transforms, num_samples=augmentation_num)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
        val_dataset = CustomDataset(image_path, mask_path, image_transform=val_image_transforms, mask_transform=val_mask_transforms, num_samples=1)
        val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    
        opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)
    
        for i in range(train_epoch):
            print('epoch:', i)
            model.train()
            for image, mask in train_dataloader:
                mask = (mask * 255).long()
                input = torch.cat((image, mask), dim=1).to(device)
                output_mask = torch.argmax(model(input), dim=1)
                output = model(input)
                loss = F.cross_entropy(output, input[:, 3, :, :].long())
                # print(loss.item(), torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (224 * 224) / batch_size)
    
                opt.zero_grad()
                loss.backward()
                opt.step()
                sch.step()
    
            model.eval()
            for image, mask in val_dataloader:

                mask = (mask * 255).long()
                input = torch.cat((image, mask), dim=1).to(device)
                output_mask = torch.argmax(model(input), dim=1)
                output = model(input)
                loss = F.cross_entropy(output, input[:, 3, :, :].long())
                print(loss.item(), torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (224 * 224))

                # mask_np = mask.squeeze(0).cpu().numpy().transpose(1, 2, 0)
                # output_mask_np = output_mask.cpu().numpy().transpose(1, 2, 0)
                # image_np = image.squeeze(0).numpy().transpose(1, 2, 0)
                # print_images([image_np, mask_np, np.where(output_mask_np == mask_np, 1, 0)])
    
        if not os.path.exists(models_root):
            os.makedirs(models_root)
        torch.save(model.state_dict(), model_save_path)


In [6]:
# train(trainval_image_root, trainval_mask_root, train_list)
train(trainval_image_root, trainval_mask_root, val_list)
# train(testd_image_root, testd_mask_root, test_list)

bike-packing
Element 0 has count 363087
Element 1 has count 30321
Element 2 has count 43392
type_cnt: 3
epoch: 0


  return self.activation(x)
  print(loss.item(), torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (224 * 224))


0.6776339411735535 0.9025430484693877
epoch: 1
0.6766858696937561 0.9039979272959183
epoch: 2
0.6753642559051514 0.9046356823979592
epoch: 3
0.6757019758224487 0.9044164540816326
epoch: 4
0.6750971078872681 0.9060108418367347
epoch: 5
0.6755126714706421 0.9056122448979592
epoch: 6
0.6765565872192383 0.9039580676020408
epoch: 7
0.6755678653717041 0.9058314732142857
epoch: 8
0.6764638423919678 0.9041573660714286
epoch: 9
0.6733667850494385 0.9086615114795918
epoch: 10
0.6749585270881653 0.9064891581632653
epoch: 11
0.6751775741577148 0.9057318239795918
epoch: 12
0.6781514286994934 0.9023636798469388
epoch: 13
0.6743523478507996 0.9069076849489796
epoch: 14
0.674736499786377 0.9067083864795918
epoch: 15
0.6753599047660828 0.9055125956632653
epoch: 16
0.6752889156341553 0.9055923150510204
epoch: 17
0.6765099167823792 0.9037388392857143
epoch: 18
0.6757300496101379 0.9045758928571429
epoch: 19
0.6751980781555176 0.9057517538265306
