In [1]:
import os
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import png
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from colormap.colors import Color, hex2rgb
from sklearn.metrics import average_precision_score as ap_score
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from tqdm import tqdm
import torch.nn.functional as F

from dataset import FacadeDataset

from train import Net
from train import train, test, get_result


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# TODO change data_range to include all train/evaluation/test data.
# TODO adjust batch_size.
BATCH_SIZE = 10

# Items [0, 800) of the 'train' split are used for training and [800, 906)
# are held out for validation; 'test_dev' items [0, 114) are the test set.
train_data = FacadeDataset(flag='train', data_range=(0, 800), onehot=False)
# shuffle=True: without it the loader yields the same batches in the same
# order every epoch, which usually hurts optimization/generalization.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

val_data = FacadeDataset(flag='train', data_range=(800, 906), onehot=False)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)

test_data = FacadeDataset(flag='test_dev', data_range=(0, 114), onehot=False)
test_loader = DataLoader(test_data, batch_size=5)

load train dataset start
    from: ./starter_set/
    range: [0, 800)
load dataset done
load train dataset start
    from: ./starter_set/
    range: [800, 906)
load dataset done
load test_dev dataset start
    from: ./starter_set/
    range: [0, 114)
load dataset done


In [7]:
N_CLASS = 5  # number of output classes = channels of Net's final conv layer
# TODO: consider torch.nn.ConvTranspose2d (learned upsampling) instead of
# F.interpolate in Net.forward.


class Net(nn.Module):
    """Small conv encoder/decoder producing per-pixel class logits.

    NOTE: this definition shadows the ``Net`` imported from ``train`` in the
    first cell; later cells that call ``Net()`` get this class.

    Args:
        out_size: optional fixed (H, W) for the output logits. When None
            (default), the logits are resized to the input's own spatial size,
            so they line up with a per-pixel target of the same resolution.
            Pass ``(256, 256)`` to reproduce the old hard-coded behaviour.

    forward(x) maps a (N, 3, H, W) batch to (N, N_CLASS, H', W') logits.
    """

    def __init__(self, out_size=None):
        super(Net, self).__init__()
        self.n_class = N_CLASS
        self.out_size = out_size
        # Leftover 1x1 starter head. forward() never uses it; it is kept only so
        # the module's parameters / state_dict keys are unchanged.
        self.layers = nn.Sequential(
            #########################################
            ###        TODO: Add more layers      ###
            #########################################
            nn.Conv2d(3, self.n_class, 1, padding=0),
            nn.ReLU(inplace=True),
        )
        # Two conv-conv-pool stages; each max-pool halves H and W (ceil_mode
        # rounds odd sizes up), so features come out at roughly (H/4, W/4).
        self.encoder = nn.Sequential(
            # conv layer 1
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            # conv layer 2
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
        )
        # The unpadded 5x5 conv shrinks each spatial dim by 4; the 1x1 convs
        # mix channels per location and end with N_CLASS logit channels.
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 256, 5),
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(256, 256, 1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(256, self.n_class, 1),
        )

    def forward(self, x):
        features = self.encoder(x)
        logits = self.decoder(features)
        # Resize the coarse logits to full resolution. This used to be a fixed
        # (256, 256), which silently mismatched targets of any other size.
        size = self.out_size if self.out_size is not None else tuple(x.shape[-2:])
        # align_corners=False is what PyTorch uses when the argument is omitted;
        # passing it explicitly keeps results identical and avoids the warning.
        return F.interpolate(logits, size=size, mode='bicubic', align_corners=False)

In [10]:
def cross_entropy2d(input, target, weight=None, size_average=True):
    """Per-pixel cross-entropy that ignores pixels with a negative label.

    Args:
        input: raw logits of shape (n, c, h, w); log-softmax is taken over c.
        target: integer class indices of shape (n, h, w). Pixels whose label
            is negative are excluded from both the loss sum and the count.
            Non-contiguous tensors are accepted.
        weight: optional per-class weight tensor of shape (c,), forwarded to
            ``F.nll_loss``.
        size_average: if True, divide the summed loss by the number of
            non-ignored pixels (note: by the pixel count, not by the sum of
            class weights); otherwise return the plain sum.

    Returns:
        A scalar loss tensor. If every pixel is ignored, returns a zero that is
        still connected to ``input``'s autograd graph, instead of 0/0 = NaN.
    """
    # (n, c, h, w) -> (n, h, w, c) so each pixel's class scores are the last dim.
    log_p = F.log_softmax(input, dim=1).permute(0, 2, 3, 1)
    mask = target >= 0
    n_valid = int(mask.sum())
    if n_valid == 0:
        # Nothing to supervise: a NaN here would poison the whole training run.
        return input.sum() * 0.0
    # Indexing (n, h, w, c) with an (n, h, w) boolean mask keeps the class dim,
    # giving (n_valid, c) without materializing an (n, h, w, c) mask.
    log_p = log_p[mask]
    # nll_loss requires int64 class indices.
    loss = F.nll_loss(log_p, target[mask].long(), weight=weight, reduction='sum')
    if size_average:
        loss = loss / n_valid
    return loss

name = 'starter_net'

# Instantiate the model (Net as defined in the cell above) on the selected device.
net = Net().to(device)

# Loss: the custom per-pixel cross_entropy2d defined above.
# nn.CrossEntropyLoss() was the alternative under consideration.
criterion = cross_entropy2d  # TODO decide loss

# Adam, learning rate 1e-3, with a small L2 weight decay.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5)


In [None]:
# train() is imported from the local train.py module. The 3 is passed as its
# sixth positional argument (presumably the number of epochs; TODO confirm
# against train.py). Whether and when the loss is printed is decided inside
# train(), which is not visible in this notebook.
N_EPOCHS = 3
train(train_loader, net, criterion, optimizer, device, N_EPOCHS)


  0%|          | 0/80 [00:00<?, ?it/s]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


  1%|▏         | 1/80 [00:10<14:04, 10.69s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


  2%|▎         | 2/80 [00:23<14:37, 11.24s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


  4%|▍         | 3/80 [00:36<15:04, 11.75s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


  5%|▌         | 4/80 [00:48<14:59, 11.84s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


  6%|▋         | 5/80 [01:01<15:11, 12.15s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


  8%|▊         | 6/80 [01:14<15:23, 12.47s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


  9%|▉         | 7/80 [01:27<15:36, 12.83s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 10%|█         | 8/80 [01:42<16:10, 13.49s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 11%|█▏        | 9/80 [01:58<16:42, 14.12s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 12%|█▎        | 10/80 [02:36<24:45, 21.22s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 14%|█▍        | 11/80 [02:56<24:08, 20.99s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 15%|█▌        | 12/80 [03:10<21:24, 18.89s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 16%|█▋        | 13/80 [03:28<20:33, 18.42s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 18%|█▊        | 14/80 [03:44<19:38, 17.85s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 19%|█▉        | 15/80 [04:11<22:11, 20.49s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 20%|██        | 16/80 [04:45<26:18, 24.66s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 21%|██▏       | 17/80 [05:10<25:55, 24.69s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 22%|██▎       | 18/80 [05:24<22:08, 21.43s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 24%|██▍       | 19/80 [05:36<18:58, 18.66s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 25%|██▌       | 20/80 [05:48<16:45, 16.75s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 26%|██▋       | 21/80 [06:01<15:10, 15.42s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 28%|██▊       | 22/80 [06:13<14:07, 14.62s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 29%|██▉       | 23/80 [06:26<13:17, 13.99s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 30%|███       | 24/80 [06:37<12:15, 13.14s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 31%|███▏      | 25/80 [06:48<11:22, 12.41s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 32%|███▎      | 26/80 [06:58<10:42, 11.90s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 34%|███▍      | 27/80 [07:11<10:48, 12.23s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 35%|███▌      | 28/80 [07:26<11:13, 12.94s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 36%|███▋      | 29/80 [07:38<10:42, 12.59s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 38%|███▊      | 30/80 [07:49<10:04, 12.08s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 39%|███▉      | 31/80 [08:00<09:43, 11.91s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 40%|████      | 32/80 [08:12<09:23, 11.73s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 41%|████▏     | 33/80 [08:23<09:06, 11.62s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 42%|████▎     | 34/80 [08:48<11:58, 15.62s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 44%|████▍     | 35/80 [09:11<13:21, 17.81s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 45%|████▌     | 36/80 [09:29<13:02, 17.79s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 46%|████▋     | 37/80 [09:51<13:43, 19.15s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 48%|████▊     | 38/80 [10:32<17:58, 25.67s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 49%|████▉     | 39/80 [11:04<18:55, 27.68s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 50%|█████     | 40/80 [11:36<19:22, 29.06s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 51%|█████▏    | 41/80 [11:57<17:19, 26.66s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 52%|█████▎    | 42/80 [12:20<16:04, 25.37s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 54%|█████▍    | 43/80 [12:47<16:02, 26.01s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 55%|█████▌    | 44/80 [13:13<15:36, 26.02s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 56%|█████▋    | 45/80 [13:30<13:32, 23.20s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 57%|█████▊    | 46/80 [13:49<12:25, 21.92s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 59%|█████▉    | 47/80 [14:05<11:08, 20.26s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 60%|██████    | 48/80 [14:24<10:29, 19.68s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 61%|██████▏   | 49/80 [14:40<09:42, 18.80s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 62%|██████▎   | 50/80 [14:56<08:56, 17.87s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 64%|██████▍   | 51/80 [15:10<08:03, 16.66s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 65%|██████▌   | 52/80 [15:24<07:23, 15.84s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 66%|██████▋   | 53/80 [15:41<07:21, 16.36s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 68%|██████▊   | 54/80 [15:56<06:48, 15.70s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])


 69%|██████▉   | 55/80 [16:10<06:26, 15.46s/it]

torch.Size([10, 3, 256, 256])
torch.Size([10, 128, 64, 64])
torch.Size([10, 5, 60, 60])
torch.Size([10, 5, 256, 256])
