## Setup


In [0]:
! [[ -d box-unet ]] || git clone --quiet https://github.com/sdll/box-unet.git

In [2]:
%cd box-unet

/content/box-unet


In [6]:
! [[ -f data.zip ]] || wget https://www.dropbox.com/s/m1ie2zq8nkburar/data.zip?raw=1 -O data.zip && unzip data.zip

Archive:  data.zip
replace __MACOSX/._data? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: __MACOSX/._data         
  inflating: __MACOSX/data/._train   
  inflating: __MACOSX/data/train/._ground_truth  
  inflating: data/val/.DS_Store      
  inflating: __MACOSX/data/val/._.DS_Store  
  inflating: data/train/ground_truth/normalized_data.tensor  
  inflating: __MACOSX/data/train/ground_truth/._normalized_data.tensor  
  inflating: data/train/ground_truth/normed_crops.33.tensor  
  inflating: data/train/noisy/normalized_data.tensor  
  inflating: __MACOSX/data/train/noisy/._normalized_data.tensor  
  inflating: data/train/noisy/normed_crops.33.tensor  
  inflating: __MACOSX/data/train/noisy/._normed_crops.33.tensor  
  inflating: data/val/ground_truth/normalized_data.tensor  
  inflating: __MACOSX/data/val/ground_truth/._normalized_data.tensor  
  inflating: data/val/ground_truth/normed_crops.33.tensor  
  inflating: __MACOSX/data/val/ground_truth/._normed_crops.33.tensor  
  infl

In [7]:
! pip install -q gsheet-keyring ipython-secrets comet_ml tqdm

[?25l[K     |▏                               | 10kB 33.2MB/s eta 0:00:01[K     |▎                               | 20kB 6.5MB/s eta 0:00:01[K     |▍                               | 30kB 9.2MB/s eta 0:00:01[K     |▋                               | 40kB 5.8MB/s eta 0:00:01[K     |▊                               | 51kB 7.1MB/s eta 0:00:01[K     |▉                               | 61kB 8.4MB/s eta 0:00:01[K     |█                               | 71kB 9.6MB/s eta 0:00:01[K     |█▏                              | 81kB 10.7MB/s eta 0:00:01[K     |█▎                              | 92kB 11.8MB/s eta 0:00:01[K     |█▍                              | 102kB 9.4MB/s eta 0:00:01[K     |█▋                              | 112kB 9.4MB/s eta 0:00:01[K     |█▊                              | 122kB 9.4MB/s eta 0:00:01[K     |█▉                              | 133kB 9.4MB/s eta 0:00:01[K     |██                              | 143kB 9.4MB/s eta 0:00:01[K     |██▏                     

In [8]:
! python3 -m pip install -q git+https://github.com/shrubb/box-convolutions.git

  Building wheel for box-convolution (setup.py) ... [?25l[?25hdone


## Imports

In [0]:
from comet_ml import Experiment

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm as tqdm_base

from box_unet import BoxUNet as Model
from ipython_secrets import get_secret
from pytorch_ssim import ssim

sns.set()


def tqdm(*args, **kwargs):
    if hasattr(tqdm_base, "_instances"):
        for instance in list(tqdm_base._instances):
            tqdm_base._decr_instances(instance)
    return tqdm_base(*args, **kwargs)

## Environment

In [0]:
DATA_PATH = "data"
GROUND_TRUTH_LABEL = "ground_truth"
NOISY_IMAGES_LABEL = "noisy"
TRAIN_LABEL = "train"
TEST_LABEL = "val"
TRAIN_POSTFIX = "normed_crops.33.tensor"
TEST_POSTFIX = TRAIN_POSTFIX
TRAIN_GT_DATA = Path(DATA_PATH) / TRAIN_LABEL / GROUND_TRUTH_LABEL / TRAIN_POSTFIX
TRAIN_NOISY_DATA = Path(DATA_PATH) / TRAIN_LABEL / NOISY_IMAGES_LABEL / TRAIN_POSTFIX
TEST_GT_DATA = Path(DATA_PATH) / TEST_LABEL / GROUND_TRUTH_LABEL / TEST_POSTFIX
TEST_NOISY_DATA = Path(DATA_PATH) / TEST_LABEL / NOISY_IMAGES_LABEL / TEST_POSTFIX

DEVICE = "cuda"

PROJECT = "fastrino"
COMET_ML_API_KEY = get_secret("comet-{}".format(PROJECT))

## Utilities

In [0]:
def get_arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-input-h", type=int, default=64)
    parser.add_argument("--max-input-w", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-epochs", type=int, default=5)
    parser.add_argument("--seed", type=int, default=42)

    return parser


def get_criterion():
    return nn.MSELoss()


def get_optimizer(model, lr=0.001):
    return optim.Adam(model.parameters(), lr)


def psnr(prediction, target, max_pixel=255.0):
    return 10.0 * ((max_pixel ** 2) / ((prediction - target) ** 2).mean()).log10()


class PlaneLoader(torch.utils.data.Dataset):
    def __init__(self, gt_data, noisy_data):
        self.gt_data = torch.load(gt_data)
        self.noisy_data = torch.load(noisy_data)

    def __len__(self):
        return len(self.noisy_data)

    def __getitem__(self, index):
        return (
            self.gt_data[index][:, :],
            self.noisy_data[index][:, :],
        )


def train():
    parser = get_arg_parser()
    args = parser.parse_args(args=[])
    train_loader = torch.utils.data.DataLoader(
        PlaneLoader(TRAIN_GT_DATA, TRAIN_NOISY_DATA),
        batch_size=args.batch_size,
        shuffle=True,
    )

    test_loader = torch.utils.data.DataLoader(
        PlaneLoader(TEST_GT_DATA, TEST_NOISY_DATA),
        batch_size=args.batch_size,
        shuffle=False,
    )
    image, noise = next(iter(train_loader))

    args.in_channels = 1 if len(image.shape) == 3 else image.shape[1]
    experiment = Experiment(
        api_key=COMET_ML_API_KEY,
        project_name=PROJECT,
        workspace=PROJECT,
        auto_output_logging=None,
    )
    experiment.log_parameters(vars(args))

    model = Model(
        args.in_channels, args.in_channels, args.max_input_h, args.max_input_w,
    ).to(DEVICE)

    criterion = get_criterion()
    optimizer = get_optimizer(model, args.lr)

    for epoch in tqdm(range(args.num_epochs), desc="Epoch", unit="epochs"):
        with experiment.train():
            model.train()
            train_psnr = []
            train_ssim = []

            for image, noise in tqdm(train_loader, desc="Train images", unit="images"):
                image = image.to(DEVICE)
                noise = noise.to(DEVICE)

                prediction = model(image)
                loss = criterion(prediction, noise)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                current_psnr = psnr(image - prediction, image - noise).data.item()
                current_ssim = ssim(image - prediction, image - noise).data.item()
                train_psnr.append(current_psnr)
                train_ssim.append(current_ssim)
                experiment.log_metric("psnr", current_psnr)
                experiment.log_metric("ssim", current_ssim)
                experiment.log_metric("loss", loss.data.item())

            experiment.log_metric("mean_psnr", np.mean(train_psnr))
            experiment.log_metric("mean_ssim", np.mean(train_ssim))

    with experiment.test():
        model.eval()
        test_losses = []
        test_psnr = []
        test_ssim = []
        for image, noise in test_loader:
            image = image.to(DEVICE)
            noise = noise.to(DEVICE)
            prediction = model(image)
            current_psnr = psnr(image - prediction, image - noise).data.item()
            current_ssim = ssim(image - prediction, image - noise).data.item()
            test_psnr.append(current_psnr)
            test_ssim.append(current_ssim)
            test_losses.append(criterion(prediction, noise).data.item())

        test_psnr = np.mean(test_psnr)
        test_ssim = np.mean(test_ssim)
        test_loss = np.mean(test_losses)

        experiment.log_metric("mean_psnr", test_psnr)
        experiment.log_metric("mean_ssim", test_ssim)
        experiment.log_metric("mean_loss", test_loss)

    return test_psnr, test_ssim, test_loss

In [13]:
test_psnr, test_ssim, test_loss = train()

COMET INFO: ----------------------------
COMET INFO: Comet.ml Experiment Summary:
COMET INFO:   Data:
COMET INFO:     url: https://www.comet.ml/fastrino/fastrino/90886c615c8f4643ab70a01aa5e22457
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     sys.cpu.percent.01 [3]             : (0.9, 49.6)
COMET INFO:     sys.cpu.percent.02 [3]             : (0.9, 40.5)
COMET INFO:     sys.cpu.percent.avg [3]            : (0.9, 45.05)
COMET INFO:     sys.gpu.0.free_memory [3]          : (14455537664.0, 15812198400.0)
COMET INFO:     sys.gpu.0.gpu_utilization [3]      : (0.0, 0.0)
COMET INFO:     sys.gpu.0.total_memory             : (15812263936.0, 15812263936.0)
COMET INFO:     sys.gpu.0.used_memory [3]          : (65536.0, 1356726272.0)
COMET INFO:     sys.ram.total [3]                  : (13655232512.0, 13655232512.0)
COMET INFO:     sys.ram.used [3]                   : (1343336448.0, 4130967552.0)
COMET INFO:     test_mean_loss                     : (0.018405649555882756, 0.01840564955588

In [14]:
print("Mean Test PSNR: {}\nMean Test SSIM: {}\nMean Test Loss: {}".format(test_psnr, test_ssim, test_loss))

Mean Test PSNR: 70.6452029292195
Mean Test SSIM: 0.4837297156405525
Mean Test Loss: 0.02190944241949717
