In [1]:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from os import walk, path
import flowiz
from skimage import io
import numpy as np
from random import shuffle
from unet_utils import Down, Up, DoubleConv

In [2]:
class FlowDataset(Dataset):
    """Flow dataset."""

    def __init__(self, root_dir='./data', transform=None, copy_from = None, *args, **kwargs):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        super(*args, **kwargs)

        self.root_dir = root_dir
        self.transform = transform

        self.images_paths = []
        self.flow_paths = []

        if copy_from:
            self.images_paths = copy_from[0]
            self.flow_paths = copy_from[1]
        else:
            for pardir, dirs, files in walk(path.join(root_dir, 'images')):
                image_name = path.basename(pardir)
                if not (image_name.startswith('IM') and '_' not in image_name):
                    continue

                # We have an image
                for frame_name in files:
                    frame_number = frame_name[6:10]
                    flow_name = 'frameGT_' + str(frame_number) + '.flo'
                    flow_path = path.join(root_dir, 'gt_flow', image_name, flow_name)
                    frame_path = path.join(pardir, frame_name)

                    if not path.exists(flow_path) or not path.exists(frame_path):
                        continue

                    self.images_paths.append(frame_path)
                    self.flow_paths.append(flow_path)

            self.images_paths = np.asarray(self.images_paths)
            self.flow_paths = np.asarray(self.flow_paths)

    def __len__(self):
        return len(self.images_paths)

    def __getitem__(self, idx):
        print('Called __getitem__')
        if torch.is_tensor(idx):
            idx = idx.tolist()

        print('Loading', idx)
        image_path, flow_path = self.images_paths[idx], self.flow_paths[idx]
        print(image_path)
        print(flow_path)
        image = io.imread(image_path)
        flow = flowiz.convert_from_file(flow_path)

        if self.transform:
            image = self.transform(image)

        return image, flow

    def slice(self, limit: int):
        return (
            FlowDataset(root_dir=self.root_dir, transform=self.transform, copy_from=[
                self.images_paths[:limit],
                self.flow_paths[:limit],
            ]),
            FlowDataset(root_dir=self.root_dir, transform=self.transform, copy_from=[
                self.images_paths[limit:],
                self.flow_paths[limit:],
            ])
        )

    def random_split(self, size: int):
        zipped = list(zip(self.images_paths, self.flow_paths))
        shuffle(zipped)
        img_paths, flow_paths = list(zip(*zipped))

        return FlowDataset(root_dir=self.root_dir, transform=self.transform, copy_from=[img_paths, flow_paths]).slice(size)

In [10]:
class LitUNet(pl.LightningModule):
    def __init__(
            self,
            num_classes: int,
            *,
            num_layers: int = 5,
            features_start: int = 64,
            batch_size: int = 10,
            num_workers: int = 10,
            bilinear: bool = False,

    ):
        """
        Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation
        <https://arxiv.org/abs/1505.04597>`_
        Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox
        Implemented by:
            - `Annika Brundyn <https://github.com/annikabrundyn>`_
            - `Akshay Kulkarni <https://github.com/akshaykvnit>`_
        Args:
            num_classes: Number of output classes required
            num_layers: Number of layers in each side of U-net (default 5)
            features_start: Number of features in first layer (default 64)
            bilinear (bool): Whether to use bilinear interpolation or transposed convolutions (default) for upsampling.
        """
        super().__init__()
        self.num_layers = num_layers

        layers = [DoubleConv(3, features_start)]

        feats = features_start
        for _ in range(num_layers - 1):
            layers.append(Down(feats, feats * 2))
            feats *= 2

        for _ in range(num_layers - 1):
            layers.append(Up(feats, feats // 2, bilinear))
            feats //= 2

        # TODO: change the convolution so that it predicts masks and not classes
        layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

        self.layers = nn.ModuleList(layers)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data = None

    def forward(self, x):
        xi = [self.layers[0](x)]

        # Down path
        for layer in self.layers[1:self.num_layers]:
            xi.append(layer(xi[-1]))

        # Up path
        for i, layer in enumerate(self.layers[self.num_layers:-1]):
            xi[-1] = layer(xi[-1], xi[-2 - i])

        pred = self.layers[-1](xi[-1])

        return pred

    @staticmethod
    def calculate_loss(logits, y):
        # TODO: Somehow, calculate loss using Low Entropy Motion Loss
        loss = ...
        return loss

    def training_step(self, batch, batch_nb):
        x, y = batch
        logits = self(x)

        loss = self.calculate_loss(logits, y)

        self.log('train_loss', loss)
        return loss

    def general_step(self, step_name, batch, batch_nb):
        x, y = batch
        logits = self(x)
        loss = self.calculate_loss(logits, y)

        self.log(step_name + '_loss', loss, prog_bar=True)

        # TODO: Calculate accuracy maybe ?
        # acc = ...
        # self.log(step_name + '_acc', acc, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_nb):
        self.general_step('val', batch, batch_nb)

    def test_step(self, batch, batch_nb):
        self.general_step('test', batch, batch_nb)

    def configure_optimizers(self):
        # TODO: Find some optimizer
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        pass

    ####################
    # DATA RELATED HOOKS
    ####################
    def setup(self, stage=None):
        print('setup')
        if self.data is None:
            self.data = FlowDataset()

        # Assign train/val datasets for use in dataloaders
        train_size = int(len(self.data) * 0.9)
        train_data, test_data = self.data.slice(train_size)

        if stage == 'fit' or stage is None:
            train_size = int(len(train_data) * 0.9)

            self.data_train, self.data_val = train_data.random_split(train_size)

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.data_test = test_data

        print('setup_end')

    def train_dataloader(self):
        print('train_dl')
        dl = DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
                          pin_memory=True)
        print('train_dl end')
        return dl

    def val_dataloader(self):
        print('val_dl')
        dl = DataLoader(self.data_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
                          pin_memory=True)
        print('val_dl end')
        return dl

    def test_dataloader(self):
        return DataLoader(self.data_test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers,
                          pin_memory=True)

In [11]:
model = LitUNet(5, num_workers=0, batch_size=1)

In [12]:
# Trainer
trainer = pl.Trainer(
    gpus=1, max_epochs=15, progress_bar_refresh_rate=20,
    precision=16,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


In [13]:
trainer.fit(model)


  | Name   | Type       | Params
--------------------------------------
0 | layers | ModuleList | 31.0 M
--------------------------------------
31.0 M    Trainable params
0         Non-trainable params
31.0 M    Total params


setup
setup_end
val_dl
val_dl end
Called __getitem__
Loading 0
./data\images\IM04\frame_0035.png
./data\gt_flow\IM04\frameGT_0035.flo


Validation sanity check: 0it [00:00, ?it/s]

RuntimeError: Given groups=1, weight of size [64, 3, 3, 3], expected input[1, 720, 1280, 3] to have 3 channels, but got 720 channels instead