In [1]:
# Mount Google Drive at /content/drive so the project files under MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Report the attached GPU via a shell command (IPython `!` escape).
!nvidia-smi

Mounted at /content/drive
Thu Aug 19 13:16:00 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------------------------------------------------------------

In [2]:
# Project root on Google Drive. This is a relative path, so it presumably relies on the
# working directory being /content (the Colab default), where the drive is mounted.
# Keep the trailing slash: other cells build paths by plain string concatenation.
BASE_PATH = "drive/MyDrive/p2p-torch/"

In [4]:
# Optional dependency upgrades; uncomment only if the runtime's preinstalled versions
# do not work. NOTE(review): these installs are unpinned — consider pinning versions
# (and using %pip so they target the kernel's environment) for reproducibility.
# !pip install -U git+https://github.com/albu/albumentations --no-cache-dir
# !pip install wandb --upgrade

In [5]:
# Authenticate with Weights & Biases. No visible cell in this notebook calls wandb
# directly; presumably this is only needed when running the training script.
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## Train Model

In [6]:
# Uncomment to (re)train. The cells below only load previously saved checkpoints from
# BASE_PATH + "models/" (presumably written by this script).
# !python drive/MyDrive/p2p-torch/code/train.py

## Create symlinks for the required modules

In [7]:
!ln -s drive/MyDrive/p2p-torch/code/generator_model.py generator_model.py
!ln -s drive/MyDrive/p2p-torch/code/generator_resnet.py generator_resnet.py
!ln -s drive/MyDrive/p2p-torch/code/config.py config.py

## Helpers

In [71]:
import os
import time
import config
import torch
import torchvision
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from torch.utils.mobile_optimizer import optimize_for_mobile

class MapDataset(Dataset):
    """Dataset of paired images stored side by side in a single file.

    Each file in ``root_dir`` is split at column 512: columns ``[0, 512)`` become the
    input image and columns ``[512, W)`` the target image. The pair is first passed
    jointly through ``config.both_transform`` (as ``image`` / ``image0``), presumably so
    both halves receive the same random augmentation, and then through
    ``config.transform_only_input`` / ``config.transform_only_mask`` respectively.

    Files are indexed in sorted name order so that index ``i`` always refers to the same
    file; downstream code names saved outputs by this index.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sort for a stable index -> file
        # mapping. Sub-directories (e.g. .ipynb_checkpoints) are skipped because they
        # cannot be opened as images.
        self.list_files = sorted(
            name
            for name in os.listdir(self.root_dir)
            if os.path.isfile(os.path.join(self.root_dir, name))
        )

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        """Return ``(input_image, target_image)`` for the file at ``index``."""
        img_path = os.path.join(self.root_dir, self.list_files[index])
        # Context manager closes the file handle right after decoding.
        with Image.open(img_path) as img:
            image = np.array(img)
        # Three-axis slicing: expects an (H, W, C) array, i.e. a colour image.
        input_image = image[:, :512, :]
        target_image = image[:, 512:, :]

        augmentations = config.both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = config.transform_only_input(image=input_image)["image"]
        target_image = config.transform_only_mask(image=target_image)["image"]

        return input_image, target_image

def print_model_size(mdl):
    """Print the serialized size of ``mdl.state_dict()`` in megabytes (1 MB = 1e6 bytes).

    The state_dict is serialized into an in-memory buffer, so nothing is written to the
    working directory (the previous fixed ``tmp.pt`` file could clobber an existing file
    and was left behind if saving failed).
    """
    import io  # local import keeps this helper self-contained

    buffer = io.BytesIO()
    torch.save(mdl.state_dict(), buffer)
    print("Model Size: %.2f MB" % (buffer.getbuffer().nbytes / 1e6))

def model_speedrun(M, tag="", path_to_data=config.VAL_DIR):
    """Run ``M`` over every image in ``path_to_data`` and save inputs/outputs as PNGs.

    Files are written to ``BASE_PATH + "evaluation/"`` as ``{tag}_in_{i}.png`` and
    ``{tag}_out_{i}.png``, where ``i`` is the dataset index. Both tensors are rescaled
    with ``x * 0.5 + 0.5`` before saving, which maps [-1, 1] to [0, 1] — assumes the
    ``config`` transforms normalize to [-1, 1] (TODO confirm). The reported wall-clock
    time includes image saving, not just inference.

    Args:
        M: model to evaluate (an ``nn.Module`` or a TorchScript module); it is switched
            to eval mode.
        tag: prefix for the saved file names.
        path_to_data: directory of paired images readable by ``MapDataset``.
    """
    M.eval()
    out_dir = f"{BASE_PATH}evaluation"
    # save_image does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    loader = DataLoader(MapDataset(path_to_data), batch_size=1)

    start = time.perf_counter()
    # Inference only: disable autograd so no graph is recorded for each forward pass.
    with torch.no_grad():
        for index, (x, _) in enumerate(loader):
            save_image(x * 0.5 + 0.5, f"{out_dir}/{tag}_in_{index}.png")
            y_fake = M(x)
            save_image(y_fake * 0.5 + 0.5, f"{out_dir}/{tag}_out_{index}.png")
    elapsed = time.perf_counter() - start
    print('''\nelapsed time (seconds): {0:.1f}'''.format(elapsed))
    print("Saved Input and Output Images to " + BASE_PATH + "evaluation")
      

def model_trace(M, name="pix2pix_traced", input_shape=(1, 3, 512, 512)):
    """Trace ``M``, optimize it for mobile and save it for the PyTorch Lite interpreter.

    The module is written to ``BASE_PATH + "checkpoints/{name}.ptl"``.

    Args:
        M: model to trace; it is switched to eval mode first.
        name: file name (without extension) of the saved module.
        input_shape: shape of the random example input used for tracing. Tracing only
            records the operations executed for this input, so data-dependent control
            flow inside ``M`` is not captured.

    Returns:
        The path the module was saved to.
    """
    M.eval()
    out_dir = f"{BASE_PATH}checkpoints"
    # _save_for_lite_interpreter does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    example = torch.rand(*input_shape)
    traced = torch.jit.trace(M, example)
    optimized = optimize_for_mobile(traced)
    out_path = f"{out_dir}/{name}.ptl"
    optimized._save_for_lite_interpreter(out_path)
    return out_path

def get_dataset_path(REGION):
    """Return the validation-image directory for ``REGION``: ``BASE_PATH/cherries/<REGION>/val``."""
    parts = (BASE_PATH, "cherries/", REGION, "/val")
    return "".join(parts)


## Load and test model

In [73]:
REGION = "ethiopia"
EPOCH = "99"
generator_path = f"{BASE_PATH}models/{REGION}_MODEL_{EPOCH}.pt"

# The checkpoint is a whole pickled model object (not just a state_dict), so the
# generator's defining modules must be importable; only load trusted checkpoints.
model = torch.load(generator_path, map_location="cpu")

print_model_size(model)
model_speedrun(model, tag=f"{REGION}{EPOCH}", path_to_data=get_dataset_path(REGION))

Model Size: 7.98 MB

elapsed time (seconds): 18.1
Saved Input and Output Images to drive/MyDrive/p2p-torch/evaluation


## Test-run all models

In [74]:
# Regions with trained generators (the plotting cell below defines the same list again).
REGIONS = ["guatemala", "honduras", "ethiopia"]

# Takes approx 75 minutes to run
# Required output for next cell has already been generated
# Run only when you have retrained models
# The loop evaluates epochs 0-99 of every region on that region's validation set and
# saves "{region}{epoch}_in_{i}.png" / "{region}{epoch}_out_{i}.png" under
# BASE_PATH + "evaluation/", which the plotting cell reads back.
# Note: when uncommented it also overwrites the globals REGION, EPOCH and model.

# for REGION in REGIONS:
#   for EPOCH in range(100):
#     generator_path = BASE_PATH + "models/" + REGION + "_MODEL_" + str(EPOCH) + ".pt"
#     model = torch.load(generator_path, map_location="cpu")
#     print(f"Loading model for {REGION}, epoch: {EPOCH}")
#     model_speedrun(model, REGION + str(EPOCH), get_dataset_path(REGION))

Loading model for guatemala, epoch: 0

elapsed time (seconds): 17.7
Saved Input and Output Images to drive/MyDrive/p2p-torch/evaluation
Loading model for guatemala, epoch: 1

elapsed time (seconds): 17.5
Saved Input and Output Images to drive/MyDrive/p2p-torch/evaluation
Loading model for guatemala, epoch: 2

elapsed time (seconds): 17.4
Saved Input and Output Images to drive/MyDrive/p2p-torch/evaluation
Loading model for guatemala, epoch: 3

elapsed time (seconds): 17.1
Saved Input and Output Images to drive/MyDrive/p2p-torch/evaluation
Loading model for guatemala, epoch: 4

elapsed time (seconds): 16.7
Saved Input and Output Images to drive/MyDrive/p2p-torch/evaluation
Loading model for guatemala, epoch: 5

elapsed time (seconds): 16.9
Saved Input and Output Images to drive/MyDrive/p2p-torch/evaluation
Loading model for guatemala, epoch: 6

elapsed time (seconds): 16.7
Saved Input and Output Images to drive/MyDrive/p2p-torch/evaluation
Loading model for guatemala, epoch: 7

elapsed t

## Plot inputs and outputs for all epochs

In [75]:
from matplotlib import pyplot as plt
import matplotlib.image as mpimg

REGIONS = ["guatemala", "honduras", "ethiopia"]

def plot_graphs_for_region(REGION, num_epochs=100, num_images=4):
    """Show saved evaluation inputs (top row) and outputs (bottom row), one figure per epoch.

    Reads ``{REGION}{epoch}_in_{i}.png`` and ``{REGION}{epoch}_out_{i}.png`` from
    ``BASE_PATH + "evaluation/"`` for every ``epoch in range(num_epochs)`` and
    ``i in range(num_images)``.

    Args:
        REGION: region name used as the file-name prefix.
        num_epochs: number of epochs to plot, starting from epoch 0.
        num_images: number of image pairs per epoch (one column each).
    """
    for EPOCH in range(num_epochs):
        # squeeze=False keeps axes 2-D (rows x columns) even when num_images == 1.
        fig, axes = plt.subplots(nrows=2, ncols=num_images, figsize=(14, 7), squeeze=False)
        fig.subplots_adjust(wspace=0, hspace=0)
        for a in axes.flat:
            a.set_xticklabels([])
            a.set_yticklabels([])

        tag = f"{REGION}{EPOCH}"
        for index in range(num_images):
            in_img = mpimg.imread(f"{BASE_PATH}evaluation/{tag}_in_{index}.png")
            out_img = mpimg.imread(f"{BASE_PATH}evaluation/{tag}_out_{index}.png")
            axes[0, index].imshow(in_img)
            axes[1, index].imshow(out_img)

        fig.suptitle(f"{REGION} - Epoch: {EPOCH}")
        # Display and release each figure now. Otherwise every figure stays open until
        # the cell ends: matplotlib warns after 20 open figures and memory keeps growing.
        plt.show()
        plt.close(fig)

In [76]:
# Per-epoch evaluation grids for REGIONS[0] ("guatemala").
plot_graphs_for_region(REGIONS[0])

Output hidden; open in https://colab.research.google.com to view.

In [77]:
# Per-epoch evaluation grids for REGIONS[1] ("honduras").
plot_graphs_for_region(REGIONS[1])

Output hidden; open in https://colab.research.google.com to view.

In [78]:
# Per-epoch evaluation grids for REGIONS[2] ("ethiopia").
plot_graphs_for_region(REGIONS[2])

Output hidden; open in https://colab.research.google.com to view.

## Trace Model and Optimize for Android

After inspecting the plots above, the models from the following epochs were judged to give decent results. Only the last epoch listed for each region is traced and exported below.

In [80]:
# Epochs per region whose outputs looked acceptable when inspecting the plots above
# (listed in ascending order).
decent_models = dict(
    guatemala=[52, 54, 65],
    ethiopia=[32, 36, 44, 56, 67, 85],
    honduras=[40],
)
decent_models

{'ethiopia': [32, 36, 44, 56, 67, 85],
 'guatemala': [52, 54, 65],
 'honduras': [40]}

In [85]:
# Trace only the last epoch listed for each region and export it as "<region>.ptl".
for region, epochs in decent_models.items():
    chosen_epoch = epochs[-1]
    generator_path = f"{BASE_PATH}models/{region}_MODEL_{chosen_epoch}.pt"
    model = torch.load(generator_path, map_location="cpu")
    print(f"Tracing model at {generator_path}")
    model_trace(model, region)

Tracing model at drive/MyDrive/p2p-torch/models/guatemala_MODEL_65.pt
Tracing model at drive/MyDrive/p2p-torch/models/ethiopia_MODEL_85.pt
Tracing model at drive/MyDrive/p2p-torch/models/honduras_MODEL_40.pt


## Test Traced Model

In [86]:
REGION = "honduras"
traced_model = torch.jit.load(f"{BASE_PATH}checkpoints/{REGION}.ptl")
# Evaluate on REGION's own validation set (as the un-traced evaluation cell does) rather
# than the config default, and tag the outputs so they are identifiable and do not get
# written with an empty "_in_{i}.png" prefix.
model_speedrun(traced_model, tag=f"{REGION}_traced", path_to_data=get_dataset_path(REGION))


elapsed time (seconds): 23.0
Saved Input and Output Images to drive/MyDrive/p2p-torch/evaluation


## Quantization

## Load Model

In [None]:
# FIXME: MODEL_PATH is still a placeholder that points at the project directory itself;
# append the checkpoint to quantize, e.g. "models/<region>_MODEL_<epoch>.pt".
MODEL_PATH = BASE_PATH + ""
# Fail fast with a clear message instead of an IsADirectoryError from torch.load.
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(
        f"MODEL_PATH must point to a saved model file, got {MODEL_PATH!r}"
    )
model = torch.load(MODEL_PATH, map_location="cpu")
print_model_size(model)