In [1]:
%load_ext autoreload
%autoreload 1

In [2]:
import os
from time import sleep

import matplotlib.pyplot as plt
import cv2
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import PIL.Image as Image

from dataset import *
from utilities import *
from resnet import *
from config import *



In [3]:
color_to_gray_map, gray_to_color_map = None, None

In [4]:
np.random.seed(0)
train_imageset_path = '../trainval/DAVIS/ImageSets/2017/train.txt'
val_imageset_path = '../trainval/DAVIS/ImageSets/2017/val.txt'
testd_imageset_path = '../testd/DAVIS/ImageSets/2017/test-dev.txt'
trainval_image_root = '../trainval/DAVIS/JPEGImages/480p/'
trainval_mask_root = '../trainval/DAVIS/Annotations/480p/'
testd_image_root = '../testd/DAVIS/JPEGImages/480p/'
testd_mask_root = '../testd/DAVIS/Annotations/480p/'
models_root = '../models/'

train_list = []
val_list = []
test_list = []

with open(train_imageset_path, 'r') as f:
    for line in f:
        train_list.append(line.strip())
with open(val_imageset_path, 'r') as f:
    for line in f:
        val_list.append(line.strip())
with open(testd_imageset_path, 'r') as f:
    for line in f:
        test_list.append(line.strip())


In [5]:
def train(image_root, mask_root, target_list):

    for t in range(len(target_list)):
        if t != 29:
            continue
        print(target_list[t])
        image_path = os.path.join(image_root, target_list[t] + '/00000.jpg')
        mask_path = os.path.join(mask_root, target_list[t] + '/00000.png')
        model_save_path = os.path.join(models_root, target_list[t] + '.pt')
    
        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path)
        PIL_mask = Image.open(mask_path)
        color_to_gray_map, gray_to_color_map = get_map(mask, PIL_mask)
        del PIL_mask
        
        image = cv2.resize(image, Resize, interpolation=cv2.INTER_NEAREST)
        mask = cv2.resize(mask, Resize, interpolation=cv2.INTER_NEAREST)
        mask = convert_to_gray_mask(mask, color_to_gray_map)
        print('type_cnt:', len(color_to_gray_map))
    
        model = MyResNet(len(color_to_gray_map)).to(device)
        train_dataset = CustomDataset(image_path, mask_path, image_transform=train_image_transforms, mask_transform=train_mask_transforms, num_samples=augmentation_num)
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
        val_dataset = CustomDataset(image_path, mask_path, image_transform=val_image_transforms, mask_transform=val_mask_transforms, num_samples=1)
        val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    
        opt = torch.optim.Adam(model.parameters(), lr=learning_rate)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)
        
    
        for i in range(train_epoch):
            print('epoch:', i)
            model.train()
            for image, mask in train_dataloader:
                mask = (mask * 255).long()
                input = torch.cat((image, mask), dim=1).to(device)
                output_mask = torch.argmax(model(input), dim=1)
                output = model(input)
                loss = F.cross_entropy(output, input[:, 3, :, :].long())
                # print(loss.item(), torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (224 * 224) / batch_size)
                
                opt.zero_grad()
                loss.backward()
                opt.step()
                sch.step()
    
            model.eval()
            for image, mask in val_dataloader:

                mask = (mask * 255).long()
                input = torch.cat((image, mask), dim=1).to(device)
                output_mask = torch.argmax(model(input), dim=1)
                output = model(input)
                loss = F.cross_entropy(output, input[:, 3, :, :].long())
                print(loss.item(), torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (Resize[0] * Resize[1]))

                # mask_np = mask.squeeze(0).cpu().numpy().transpose(1, 2, 0)
                # output_mask_np = output_mask.cpu().numpy().transpose(1, 2, 0)
                # image_np = image.squeeze(0).numpy().transpose(1, 2, 0)
                # print_images([image_np, mask_np, np.where(output_mask_np == mask_np, 1, 0)])
    
        if not os.path.exists(models_root):
            os.makedirs(models_root)
        torch.save(model.state_dict(), model_save_path)


In [6]:
# train(trainval_image_root, trainval_mask_root, train_list)
train(trainval_image_root, trainval_mask_root, val_list)
# train(testd_image_root, testd_mask_root, test_list)

soapbox
type_cnt: 4
epoch: 0


  return self.activation(x)
  print(loss.item(), torch.sum(output_mask == torch.tensor(input[:, 3, :, :]).to(device)).item() / (Resize[0] * Resize[1]))


1.2383451461791992 0.8644172512755102
epoch: 1
0.8579269647598267 0.9835180165816326
epoch: 2
0.8062392473220825 0.9890186543367347
epoch: 3
0.9365745186805725 0.803352200255102
epoch: 4
0.7676920890808105 0.9930644132653061
epoch: 5
0.7638428211212158 0.9933035714285714
epoch: 6
0.7599223256111145 0.9935028698979592
epoch: 7
0.759382963180542 0.9935427295918368
epoch: 8
0.7566217184066772 0.9939413265306123
epoch: 9
0.7555509805679321 0.9943598533163265
epoch: 10
0.7550023794174194 0.9949378188775511
epoch: 11
0.7540321946144104 0.9948182397959183
epoch: 12
0.7533095479011536 0.9949577487244898
epoch: 13
0.7527804374694824 0.9948780293367347
epoch: 14
0.7523694634437561 0.9950374681122449
epoch: 15
0.7520291209220886 0.9953563456632653
epoch: 16
0.7518365979194641 0.9952367665816326
epoch: 17
0.7514084577560425 0.9954559948979592
epoch: 18
0.7511677145957947 0.9954161352040817
epoch: 19
0.7509124875068665 0.9955357142857143
epoch: 20
0.7509933710098267 0.9955357142857143
epoch: 21
0.7