# Tiramisu

In [None]:
# coding: utf-8
# ## Dependencies
# In[1]:

import time
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from models import tiramisu
from datasets import camvid
from datasets import joint_transforms
import utils.imgs
import utils.training as train_utils
import datetime
from IPython.display import Image, display
import os

In [None]:
# Locations (relative to the current working directory) of the input data,
# the prediction output, and the model checkpoints.
DATA_PATH = Path('./data/')
RESULTS_PATH = Path('./output/')
WEIGHTS_PATH = Path('./weights/')

# Create the output and checkpoint directories; an existing directory is fine.
RESULTS_PATH.mkdir(exist_ok=True)
WEIGHTS_PATH.mkdir(exist_ok=True)

# Loader batch size and optimizer learning rate.
batch_size = 20
LR = 1e-4

### Load the test data and the model

In [None]:
# Normalize input images with the channel statistics exposed by the camvid module.
normalize = transforms.Normalize(mean=camvid.mean, std=camvid.std)

# CamVid 'test' split: images are resized, converted to tensors and normalized;
# labels are converted with camvid.LabelToLongTensor.
# NOTE(review): only the image is resized to 132x132 and joint_transform is None,
# so the label gets no matching resize. Confirm the stored labels already have
# this size; otherwise predictions and targets will not line up.
test_dset = camvid.CamVid(
    DATA_PATH, 'test', joint_transform=None,
    transform=transforms.Compose([
        transforms.Resize([132, 132]),
        transforms.ToTensor(),
        normalize,
    ]),
    target_transform=transforms.Compose([
        camvid.LabelToLongTensor(),
    ]))

# Iterate the test set in a fixed order (no shuffling).
test_loader = torch.utils.data.DataLoader(
    test_dset, batch_size=batch_size, shuffle=False)

# Seed PyTorch's RNG on all devices (CPU and CUDA). The model is built on the
# CPU in the next cell, so seeding only CUDA (torch.cuda.manual_seed) would leave
# any CPU-side randomness, e.g. in the weight initialisation, unseeded.
torch.manual_seed(0)

In [None]:
# FC-DenseNet67 with 4 output classes, placed on the CPU.
model = tiramisu.FCDenseNet67(n_classes=4)
model = model.cpu()

# Apply the project's weight initialiser to every submodule.
model.apply(train_utils.weights_init)

# RMSprop with L2 weight decay over all model parameters.
optimizer = optim.RMSprop(model.parameters(), lr=LR, weight_decay=1e-4)

# Negative log-likelihood loss, weighted per class by camvid.class_weight.
criterion = nn.NLLLoss(weight=camvid.class_weight.cpu())
criterion = criterion.cpu()

In [None]:
# Restore the most recent checkpoint from WEIGHTS_PATH, then evaluate it on the
# test loader. The path is joined with pathlib rather than string concatenation
# so the separator is correct on every platform.
train_utils.load_weights(model, str(WEIGHTS_PATH / 'latest.th'))
train_utils.test(model, test_loader, criterion, epoch=1)

### Compute segmentations for sample test images (saved to the "output" folder)

In [None]:
# Predict on samples from the test loader and render them (n=10 requested).
# NOTE(review): the meaning of the positional 0 is defined in train_utils;
# presumably a batch/epoch index. Confirm before changing it.
train_utils.view_sample_predictions(
    model,
    test_loader,
    0,
    n=10,
)

### Show the results

In [None]:

# Display every PNG in the results folder (RESULTS_PATH, i.e. ./output).
# - Names are sorted because os.listdir returns entries in arbitrary order,
#   which would make the display order nondeterministic.
# - The extension is matched with endswith, because a substring test would
#   also pick up names such as "x.png.bak".
listOfImageNames = [
    str(RESULTS_PATH / f)
    for f in sorted(os.listdir(RESULTS_PATH))
    if f.endswith(".png")
]

for imageName in listOfImageNames:
    display(Image(filename=imageName))