In [18]:
from dataclasses import dataclass
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import (
    DataLoader, 
    Dataset, 
    ConcatDataset, 
    random_split
)

from segmentation import UNet
from segmentation.eval import eval_segmentation_net
from utils.manager import RenderManager

### Params Dataclass

In [19]:
@dataclass
class Params:
    """Experiment configuration for training the event-frame -> mask UNet.

    The notebook reads these values straight off the class (``Params.batch_size``,
    ``Params.dir_num``, ...) as well as being usable as an instance, so every
    default must also exist as a plain class attribute.
    """

    # Dataset params
    # Render directory numbers, each forwarded to ``RenderManager.from_directory``.
    # Immutable tuple default: dataclasses reject mutable defaults such as ``[]``
    # (ValueError at class creation), and a ``default_factory`` would remove the
    # class attribute ``Params.dir_num`` that the notebook reads directly.
    # Instances may still be given a list, e.g. ``Params(dir_num=[1, 2])``.
    dir_num: tuple = ()

    # Image params
    # Size handed to PIL ``Image.resize`` -- i.e. (width, height).
    img_size: tuple = (120, 120)

    # Training params
    # Deliberately unannotated: a plain class attribute, not a dataclass field,
    # so it is not an ``__init__`` argument. The notebook passes ``momentum``,
    # which RMSprop accepts but optim.Adam does not.
    optimizer = optim.RMSprop
    batch_size: int = 1
    epochs: int = 5
    learning_rate: float = .0001
    # Fraction of the concatenated dataset held out for validation.
    val_split: float = .1

### Dataset

In [24]:
class EvMaskDataset(Dataset):
    """Pairs each rendered event frame with its silhouette mask.

    Items are fetched lazily from a ``RenderManager`` built for one render
    directory, resized to ``image_params.img_size`` and converted to
    channel-first ``float32`` numpy arrays. The default ``DataLoader`` collate
    turns these into float tensors.

    NOTE(review): the notebook builds UNet with ``n_channels=1`` and
    ``n_classes=1``. If the silhouettes or event frames are RGB/RGBA, the
    channel count will not match -- confirm the image modes.
    """

    def __init__(self,
                 dir_num: int,
                 image_params: "Params",
                 transforms: list = None):
        """
        Args:
            dir_num: Render directory number forwarded to
                ``RenderManager.from_directory``.
            image_params: Any object exposing ``img_size``. The notebook
                passes the ``Params`` class itself.
            transforms: Optional list of transforms. It is stored but not
                applied anywhere in this class. ``None`` means a fresh empty
                list per instance; a shared ``[]`` default would leak
                between instances.
        """
        self.img_size = image_params.img_size
        self.transforms = transforms if transforms is not None else []
        self.render_manager = RenderManager.from_directory(
            dir_num = dir_num
        )

    def __len__(self):
        """Number of items, as reported by the underlying RenderManager."""
        # NOTE(review): a past run failed on a missing silhouette for index 4,
        # so this length may count renders that lack a silhouette -- verify.
        return len(self.render_manager)

    def preprocess(self, img: "Image.Image") -> np.ndarray:
        """Resize ``img`` and return it as a CHW ``float32`` array.

        2-D (single-channel) results get a channel axis. If the maximum value
        exceeds 1, the array is divided by 255 (treated as 0-255 data).
        Otherwise it is left as-is, e.g. a 0/1 mask. The result is always
        float32 so it matches float32 model weights and BCE targets.
        """
        resized = np.asarray(img.resize(self.img_size), dtype=np.float32)

        if resized.ndim == 2:
            # HW -> HW1 so the transpose below also handles grayscale images.
            resized = np.expand_dims(resized, axis=2)

        # HWC to CHW
        chw = resized.transpose((2, 0, 1))
        if chw.max() > 1:
            chw = chw / 255  # float32 / int stays float32
        return chw

    def __getitem__(self, index: int):
        """Return ``(event_frame, mask)`` for ``index``, both preprocessed."""
        mask = self.render_manager.get_image("silhouette", index)

        event_frame = self.render_manager.get_event_frame(index)

        assert mask.size == event_frame.size, \
            "Mask and event frame must be same size"

        mask = self.preprocess(mask)
        event_frame = self.preprocess(event_frame)

        return event_frame, mask
           

### Evaluation

In [25]:
from losses import DiceCoeffLoss

def eval_seg_net(net, loader, criterion=None, device=None):
    """Average a segmentation loss over every batch of ``loader``.

    Args:
        net: Model mapping an input batch ``X`` to predictions that
            ``criterion`` can compare with ``y``.
        loader: Re-iterable of ``(X, y)`` batches supporting ``len()``
            (e.g. a ``DataLoader``).
        criterion: Callable ``criterion(out, y)`` returning a scalar loss.
            Defaults to ``DiceCoeffLoss().forward``, matching the original
            direct ``.forward`` call.
            NOTE(review): ``net`` may output logits -- confirm whether
            DiceCoeffLoss expects logits or probabilities.
        device: Device the batches are moved to. ``None`` uses the device of
            ``net``'s first parameter, or ``"cuda"`` if it has none (the
            original hard-coded behaviour).

    Returns:
        Mean of the per-batch losses (a 0-d tensor for tensor losses).

    Raises:
        ValueError: if ``loader`` has no batches, since the mean is undefined.

    Runs under ``torch.no_grad()`` so no autograd graph is built or retained.
    The caller is responsible for switching ``net`` into eval mode.
    """
    if criterion is None:
        # Built once rather than per batch.
        criterion = DiceCoeffLoss().forward

    if device is None:
        params = net.parameters() if hasattr(net, "parameters") else iter(())
        first = next(iter(params), None)
        device = first.device if first is not None else "cuda"

    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError(
            "eval_seg_net: loader yields no batches (check val split / drop_last)"
        )

    total = 0.0
    with torch.no_grad():
        for X, y in loader:
            out = net(X.to(device))
            total += criterion(out, y.to(device))
    return total / n_batches

In [2]:
# Start by first loading the net. Place it on the GPU when one is available;
# the training loop reads the device back from the model's parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet(
    n_channels = 1,
    n_classes = 1,
    bilinear=True
).to(device)

# Create Train and Val DataLoaders.
# ConcatDataset refuses an empty list, so fail early with an actionable message.
if not Params.dir_num:
    raise ValueError(
        "Params.dir_num is empty: set the render directory numbers to train on"
    )
datasets = [EvMaskDataset(dir_num, Params) for dir_num in Params.dir_num]
dataset = ConcatDataset(datasets)

val_size = int(len(dataset) * Params.val_split)
train_size = len(dataset) - val_size
# val_loader uses drop_last=True. Fewer than batch_size validation samples
# would therefore give zero batches, leaving nothing to evaluate.
if val_size < Params.batch_size:
    raise ValueError(
        f"Validation split has {val_size} samples, fewer than "
        f"batch_size={Params.batch_size}; raise Params.val_split or add data"
    )

# random_split draws from torch's global RNG; seed it so the split is reproducible.
SPLIT_SEED = 0
torch.manual_seed(SPLIT_SEED)
train, val = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(
    train,
    batch_size = Params.batch_size,
    shuffle = True,
    num_workers = 8
)
val_loader = DataLoader(
    val,
    batch_size = Params.batch_size,
    shuffle = False,
    num_workers = 8,
    drop_last = True
)

# `momentum` is an RMSprop argument; switching Params.optimizer to optim.Adam
# would require dropping it.
optimizer = Params.optimizer(
    unet.parameters(),
    lr = Params.learning_rate,
    weight_decay = 1e-8,
    momentum = 0.9
)

# TODO: evaluate whether a cross-entropy loss would work better here.
if unet.n_classes > 1:
    criterion = nn.CrossEntropyLoss()
else:
    criterion = nn.BCEWithLogitsLoss()


NameError: name 'UNet' is not defined

### Training Loop

In [27]:
# Bookkeeping for the loss curves: one entry per evaluation point.
EVAL_EVERY = 100  # batches between validation passes
iters = []
train_losses = []
val_losses = []

step = 0             # global batch counter across all epochs
min_loss = np.inf    # best validation loss seen so far

# Send batches to wherever the model lives.
device = next(unet.parameters()).device

unet.train()
for epoch in range(Params.epochs):
    for i, (X, y) in enumerate(train_loader):
        # Float inputs for the network; BCEWithLogitsLoss also needs float targets.
        X = X.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        output = unet(X)
        loss = criterion(output, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % EVAL_EVERY == 0:
            # Record the global step so the x-axis keeps increasing across epochs.
            iters.append(step)
            # .item() detaches; storing `loss` would keep each batch's graph alive.
            train_losses.append(loss.item())

            unet.eval()
            val_loss = float(eval_seg_net(unet, val_loader))
            unet.train()
            val_losses.append(val_loss)
            min_loss = min(min_loss, val_loss)

        step += 1
             

Exception: Caught Exception in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/alexis/anaconda3/envs/pytorch3d/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/alexis/anaconda3/envs/pytorch3d/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/alexis/anaconda3/envs/pytorch3d/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/alexis/anaconda3/envs/pytorch3d/lib/python3.7/site-packages/torch/utils/data/dataset.py", line 257, in __getitem__
    return self.dataset[self.indices[idx]]
  File "<ipython-input-24-56d0e06591fd>", line 43, in __getitem__
    mask = self.render_manager.get_image("silhouette", index)
  File "/home/alexis/Desktop/e3d/e3d/utils/manager.py", line 217, in get_image
    return img_manager._load
  File "/home/alexis/Desktop/e3d/e3d/utils/manager.py", line 91, in _load
    raise Exception(f"Image path {self.image_path} does not exist")
Exception: Image path data/renders/001-teapot_2020-08-25T00:45:24/silhouette/4_silhouette.png does not exist


In [None]:
# Loss curves recorded by the training loop, one point per evaluation.
# NOTE(review): `plt` is not imported in any visible cell -- add
# `import matplotlib.pyplot as plt` to the imports cell.
# NOTE(review): the curves are different losses. Training uses `criterion`
# (BCEWithLogitsLoss when n_classes == 1); validation uses eval_seg_net's Dice
# loss. Compare trends, not absolute values.
fig, ax = plt.subplots()
ax.plot(train_losses, label="train (single batch)")
ax.plot(val_losses, label="validation")
ax.set(xlabel="evaluation point", ylabel="loss", title="UNet training curves")
ax.legend()
plt.show()