# UNIT Project in Colab
[on GitHub](https://github.com/mingyuliutw/UNIT)


## Workspace Setup
Shell commands that set up the current Jupyter Notebook session.

### Google Drive Folder Mounting
The following cells mount your Google Drive into the container. Choose **one** of the **two** methods below.

In [17]:
#@markdown #### Mounting (built-in)
print('Mounting...')
import os
from google.colab import drive
os.chdir('/content/')
drive.mount('/drive/', force_remount=True)
if os.path.exists('/content/drive'):
    os.unlink('/content/drive')
os.symlink('/drive/My Drive', '/content/drive')
!ls '/content/drive/'
print('Mounted!')

Mounting...
Mounted at /drive/
Colab  datasets  Graduation  ParseData	README.md  UNIT
Mounted!


In [0]:
#@markdown #### Mounting With Fuse Driver (google-drive-ocamlfuse)
print('Mounting Google.Drive with google-drive-ocamlfuse...')
% cd /content/
print('Installing required software')
! apt-get install -y -qq software-properties-common module-init-tools 2>&1 > /dev/null
print('Add apt-repository with Google.Drive Fuse')
! add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
print('Updating packages...')
! apt-get update -y -qq
print('Installing google-drive-ocamlfuse fuse...')
! apt-get install -y -qq google-drive-ocamlfuse fuse
print('Authenticate Fuse in Google.Drive...')
from google.colab import auth
from oauth2client.client import GoogleCredentials
import getpass
auth.authenticate_user()
creds = GoogleCredentials.get_application_default()
! google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
vcode = getpass.getpass('Enter auth code here: ')
! echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}
print('Authenticated!')
print('Creating mount directory')
! mkdir /drive2
print('Mounting...')
! google-drive-ocamlfuse /drive2
if os.path.exists('/content/drive'):
    os.unlink('/content/drive')
os.symlink('/drive2', '/content/drive')
!ls '/content/drive/'
print('Mounted!')

### SSH Tunnel

In [0]:
#@markdown ## Connect to Colab session
#@markdown Using ngrok
port = 4040 #@param {type:"integer"}
only_show_credetionals = True #@param {type:"boolean"}
%cd /content/
if not only_show_credetionals:
    print('Generate root password')
    import secrets, string
    password = ''.join(secrets.choice(string.ascii_letters + string.digits) for i in range(20))
    ! echo "Password: $password" > /content/save_pswd
    print('Download ngrok')
    ! wget -q -c -nc https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
    ! unzip -qq -n ngrok-stable-linux-amd64.zip
    print('Setup sshd')
    ! apt-get install -qq -o=Dpkg::Use-Pty=0 openssh-server pwgen > /dev/null
    print('Set root password')
    ! echo root:$password | chpasswd
    ! mkdir -p /var/run/sshd
    ! echo "PermitRootLogin yes" >> /etc/ssh/sshd_config
    ! echo "PasswordAuthentication yes" >> /etc/ssh/sshd_config
    ! echo "LD_LIBRARY_PATH=/usr/lib64-nvidia" >> /root/.bashrc
    ! echo "export LD_LIBRARY_PATH" >> /root/.bashrc

    print('Run sshd')
    get_ipython().system_raw('/usr/sbin/sshd -D &')

    print("Copy authtoken from https://dashboard.ngrok.com/auth")
    import getpass
    authtoken = getpass.getpass()

    print('Create tunnel')
    get_ipython().system_raw('./ngrok authtoken $authtoken && ./ngrok tcp 22 &')

print('---------')
import sys, json, os
try:
    s = !curl -s http://localhost:$port/api/tunnels
    addr = str(json.loads(s[0])['tunnels'][0]['public_url'])
    print('Use:', end=' ')
    print('ssh root@' + addr[6:addr.find(':', 6)] + ' -p ' + addr[addr.find(':', 6)+1:])
    if os.path.exists('/content/save_pswd'):
        ! cat /content/save_pswd
except:
    print('Tunnel was closed!')


# Project

### Install Dependencies
Session startup installation

In [18]:
#@markdown ## Dependencies
Project = 'google.drive' #@param ['google.drive', 'clone github'] 
copy_pre_trained_model = False #@param {type:"boolean"}

if Project == 'clone github':
    print('Cloning GitHub project...')
    !git clone https://github.com/SoleSensei/UNIT.git
    if copy_pre_trained_model:
        print('Copying PT model gta2city...')
        !mkdir /content/UNIT/models
        !mkdir /content/UNIT/outputs
        !cp -r /content/drive/UNIT/models/ /content/UNIT/
        !cp -r /content/drive/UNIT/output/ /content/UNIT/


print('Installing system packages...')
!apt-get install -y -qq axel imagemagick 2>&1 > /dev/null
print('Installing project dependencies...')
!pip3 install http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl 2>&1 > /dev/null
!pip3 install torchvision 2>&1 > /dev/null
!pip3 install tensorboard tensorboardX 2>&1 > /dev/null
print('Complete!')

Installing system packages...
Installing project dependencies...
Complete!


## Train

### Day-2-Night Translation

In [0]:
#@title Shift Domains (nexet dataset)
#@markdown Script parsing dataset folder to several domains by states from csv file

import pandas as pd
import os, sys
from shutil import copy, move

# ------------------------ Variables ------------------------ 
datapath = '/content/drive/datasets/nexet/nexet_2017_1/' # path to dataset directory
csvfile = '/content/drive/datasets/nexet/train.csv' # path to csv file
col_name = 'image_filename' # column name with dataset's filenames
col_state = 'lighting' # column name with dataset's states 
domains = {
            'trainA' : 'Day',
            'trainB' : ['Night', 'Twilight'],
            'testA' : 'Day',
            'testB' : ['Night', 'Twilight']
          }  # making domain directories {Domain_Name : States}
# -----------------------------------------------------------

# ------------------------ Arguments ------------------------ 
# `mode` and `domains2data` are set only here, through the Colab form; the
# earlier duplicate assignments were always overwritten by these lines.
# mode: 'move' | 'copy' files between the dataset folder and the domains
# domains2data: set True to shift all files back to datapath
mode = "move" #@param ["move", "copy", "none"]
domains2data = False #@param {type:"boolean"}
# show_errors / show_log: number of trailing lines of err.txt / log.txt to print
show_errors = 5 #@param {type:"slider", min:0, max:100, step:1}
show_log = 10 #@param {type:"slider", min:0, max:100, step:1}
# -----------------------------------------------------------


class DomainShifter(object):
    """
        Class creating dataset's domains from csv

        The csv is read with pandas. shift_domains() moves/copies the files
        named in the csv from the dataset folder into per-domain sub-folders;
        back_data() moves/copies files from those sub-folders back.
        Progress is printed; details are appended to log.txt and err.txt in
        the current working directory.
    """

    @staticmethod
    def _as_states(value):
        """ Normalise a domain's configured state(s) to a tuple of states """
        # A bare string must be treated as one state: `x in 'Day'` would be a
        # substring test, not a membership test.
        if isinstance(value, str):
            return (value,)
        return tuple(value)

    @staticmethod
    def _pick_shift(mode):
        """ Return the shutil function for `mode` ('copy' | 'move'), raise otherwise """
        if mode == 'copy':
            return copy
        if mode == 'move':
            return move
        raise Exception(f'Shift Domains: no {mode} found!')

    def get_states(self, column):
        """ Getting states by csv file column """

        print(f'Searching states in {column}...')
        states = set(self.csv[column])
        print("States:", *states)
        return states

    def __init__(self, data, file, domains, col_name, col_state, sep=','):
        """
            data      - path to the dataset folder (must exist)
            file      - path to the csv file (must exist)
            domains   - {domain folder name : state or list of states}
            col_name  - csv column holding the file names
            col_state - csv column holding the states
            sep       - csv separator

            Raises FileNotFoundError for a missing path and Exception for a
            missing csv column.
        """

        # Check datasets paths
        if not os.path.exists(data):
            raise FileNotFoundError(f"No dataset '{os.path.abspath(data)}' folder found!")
        if not os.path.exists(file):
            raise FileNotFoundError(f"No csv file '{os.path.abspath(file)}' found!")

        # Initialize class local variables
        self.dataset = data # dataset path
        self.file = file # csv file path
        self.domains = domains # domains to create
        # Column names are kept on the instance so the methods do not depend
        # on same-named module-level variables.
        self.col_name = col_name # csv column with file names
        self.col_state = col_state # csv column with states
        self.csv = pd.read_csv(file, sep=sep, encoding='utf8') # read csv with pandas

        # Check that the column names exist in csv
        for col in (col_name, col_state):
            if col not in self.csv.columns:
                raise Exception(f'Column name "{col}" is not found in {self.file}!')

        self.states = self.get_states(col_state) # get all states from csv

    def back_data(self, mode='move'):
        """ Backing up data from domain folders to dataset folder """
        shift = self._pick_shift(mode)

        print('Backup shifting starts...')
        print(f'Mode: {shift.__name__}')

        with open('log.txt', 'a', encoding="utf-8") as log, open('err.txt', 'a', encoding="utf-8") as err:
            print('-------- back data ----------', file=log)
            print('-------- back data ----------', file=err)
            for root, sdir, _ in os.walk(self.dataset):
                for folder in sdir:
                    if folder in self.domains.keys():
                        print(f'Start parsing {folder}')
                        print(f'Start parsing {folder}', file=log)
                        for r, _, files in os.walk(os.path.join(root, folder)):
                            nfile = len(files)
                            print('Files:', nfile)
                            for i, name in enumerate(files):
                                # report progress roughly 30 times per folder
                                if i % (nfile // 30 + 1) == 0:
                                    print(i, 'files shifted')
                                src = os.path.join(r, name)
                                # files go to the folder that contains the domain folder
                                dst = os.path.join(root, name)
                                if mode == 'move' or not os.path.exists(dst):
                                    shift(src, dst)
                        print(f'Parsed: {folder}')
                        print(f'Parsed: {folder}', file=log)
                    else:
                        print(f'Not domain folder {folder} found')
                        print(f'Not domain folder {folder} found', file=log)


    def shift_domains(self, mode='move'):
        """ Creating domain folders and parsing dataset folder by csv """
        shift = self._pick_shift(mode)
        print('Shifting domains starts...')
        print(f'Mode: {shift.__name__}')

        # Every state present in the csv must be claimed by some domain
        domain_states = {name: self._as_states(value) for name, value in self.domains.items()}
        for state in self.states:
            if not any(state in states for states in domain_states.values()):
                raise Exception(f'States Error: no {state} found in configuration')

        # Creating directories
        print('Creating directories...')
        base = self.dataset
        for ndir in self.domains.keys():
            path = os.path.join(base, ndir)
            if not os.path.isdir(path):
                os.mkdir(path)
                print(f'{path} created!')
        print('Created!')

        k = 0
        with open('log.txt', 'a', encoding="utf-8") as log, open('err.txt', 'a', encoding="utf-8") as err:
            print('-------- shift domains ----------', file=err)
            print('-------- shift domains ----------', file=log)
            for i, row in self.csv.iterrows():
                if i % 1000 == 0:
                    print(i, 'files processed')
                name = str(row[self.col_name])
                src = os.path.join(base, name)
                is_shifted = False

                # Rows alternate between train and test
                k += 1
                if k % 2 == 0: #TODO: add domain split
                    domain_type = 'test'
                else:
                    domain_type = 'train'

                for domain, states in domain_states.items():
                    # domain[:-1] drops the last character of the domain name
                    # ('trainA' -> 'train') to compare it with domain_type
                    if row[self.col_state] in states and domain[:-1] == domain_type:
                        dst = os.path.join(base, domain)
                        dstname = os.path.join(dst, name)
                        if os.path.exists(src) and (mode == 'move' or not os.path.exists(dstname)):
                            # Target the file path, not the directory: moving
                            # into a directory raises when the file is already there.
                            shift(src, dstname)
                            print(f'{shift.__name__}: {src} → {dst}', file=log)
                            is_shifted = True
                        elif os.path.exists(dstname):
                            is_shifted = True
                        break
                if not is_shifted:
                    print(f'{row[self.col_name]} file not shifted', file=err)
            # Report how many files each domain folder holds now
            for root, sdir, _ in os.walk(self.dataset):
                for folder in sdir:
                    if folder in self.domains.keys():
                        for r, _, files in os.walk(os.path.join(root, folder)):
                            nfile = len(files)
                            print(f'Files in domain {folder}: {nfile}')
                            print(f'Files in domain {folder}: {nfile}', file=log)
        print('Shifiting completed!')

# Main
ds = DomainShifter(datapath, csvfile, domains, col_name, col_state)
if not domains2data:
    ds.shift_domains(mode)
else:
    ds.back_data(mode)


if show_errors:
    print('Error log:')
    !tail -n $show_errors err.txt
if show_log:
    print('Log:')
    !tail -n $show_log log.txt
print('Completed!')


In [0]:
#@markdown ### Training Day 2 Night
Resume_from_last_checkpoint = False #@param {type:"boolean"}
Checkpoint_every_iteration = 1000 #@param {type:"slider", min:100, max:5000, step:100}
# TODO: add changing yaml file

if Resume_from_last_checkpoint:
    rsm = '--resume'
else:
    rsm = ''

import os
if not os.path.isdir('/content/drive'):
    print('=====================================================================')
    print('Session ended! Please remount google.drive and reinstall dependences!')
    print('=====================================================================')
else:
    os.chdir('/content/drive/UNIT')

    print('=====================================================================')
    print('Training a model Day to Night image translation')
    print('=====================================================================')
    !python train.py --config configs/unit_day2night.yaml --trainer UNIT $rsm
    print('=====================================================================')
    print('Fully Trained!')
    print('=====================================================================')

Training a model Day to Night image translation
Elapsed time in update: 6.014165
Elapsed time in update: 1.910728
Elapsed time in update: 1.916571
Elapsed time in update: 1.932878
Elapsed time in update: 1.904431


### Summer-2-Winter Translation

In [0]:
#@markdown ### Training Winter 2 Summer
Resume_from_last_checkpoint = True #@param {type:"boolean"}
Checkpoint_every_iteration = 1000 #@param {type:"slider", min:100, max:5000, step:100}



if Resume_from_last_checkpoint:
    rsm = '--resume'
else:
    rsm = ''

import os
if not os.path.isdir('/content/drive'):
    print('=====================================================================')
    print('Session ended! Please remount google.drive and reinstall dependences!')
    print('=====================================================================')
else:
    os.chdir('/content/drive/UNIT')

    print('=====================================================================')
    print('Training a model Summer to Winter image translation')
    print('=====================================================================')
    !python train.py --config configs/unit_summer2winter_yosemite256_folder.yaml --trainer UNIT $rsm
    print('=====================================================================')
    print('Fully Trained!')
    print('=====================================================================')

In [0]:
#@markdown ### Results
# Minimum gap, in training iterations, between two displayed snapshots.
iterations = 5004 #@param {type:"slider", min:10, max:10000, step:1}
# Side length passed to matplotlib's figsize for every displayed image.
image_size = 30 #@param {type:"slider", min:15, max:50, step:1}
# Make sure the value is an int before it is used in comparisons below.
iterations = int(iterations)
import os
from IPython.display import Image, display, HTML
import cv2
import matplotlib.pyplot as plt


def getTitle(name, it):
    """
    Build a figure title from a snapshot file name and an iteration number.

    The character at index 4 of `name` selects the direction ('a' gives
    'Summer 2 Winter', anything else 'Winter 2 Summer'); characters 7..12
    equal to '_test_' select 'Test', otherwise 'Train'. `it` is appended.
    """
    direction = 'Summer 2 Winter' if name[4] == 'a' else 'Winter 2 Summer'
    split = ' - Test ' if name[7:13] == '_test_' else ' - Train '
    return f'{direction}{split}{it}'

def displayImage(file, title):
    """
    Show the image stored at `file` in a new square matplotlib figure.

    file  - path of the image to read with OpenCV
    title - text placed above the figure

    The figure side length comes from the module-level `image_size`.
    Raises FileNotFoundError when OpenCV cannot read the file.
    """
    image = cv2.imread(file)
    # imread() signals failure by returning None instead of raising; fail here
    # with the offending path rather than inside cvtColor with an opaque error.
    if image is None:
        raise FileNotFoundError(f"Cannot read image '{file}'")
    # OpenCV loads BGR, matplotlib expects RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(image_size, image_size))
    plt.title(title)
    plt.imshow(image)
    plt.grid(False)

# Display snapshots written by the summer2winter run, thinned so that two
# consecutive displayed iterations are at least `iterations` apart. All
# snapshots that share a displayed iteration number are shown together.
os.chdir('/content/drive/UNIT/outputs/unit_summer2winter_yosemite256_folder/images/')
print(f'Selected iterations: {iterations}')
for root, _, filenames in os.walk('./'):
    # Keep only files whose 8 characters before the 4-character extension are
    # an iteration number; this skips '*_current.*' and any unrelated file.
    snapshots = []
    for filename in filenames:
        stamp = filename[-12:-4]
        if stamp.isdigit():
            snapshots.append((int(stamp), filename))
    # os.walk() returns names in arbitrary order; the thinning below needs
    # them ordered by iteration.
    snapshots.sort()

    last_shown = None  # iteration of the most recently displayed snapshot
    for num_it, filename in snapshots:
        if last_shown is None or num_it == last_shown or num_it - last_shown >= iterations:
            # join with root so files found in sub-folders resolve correctly
            displayImage(os.path.join(root, filename), getTitle(filename, num_it))
            last_shown = num_it





<!-- ![]()
 ![Google's logo](/content/drive/UNIT/outputs/unit_summer2winter_yosemite256_folder/images/gen_a2b_test_00000010.png) -->


## Test

In [0]:
from IPython.display import Image, display
import os

#@markdown ## Testing on PT models
Model = "gta2city" #@param ['gta2city', 'day2night', 'summer2winter']
backward = False #@param {type:"boolean"}


if not os.path.isdir('/content/drive'):
    print('=====================================================================')
    print('Session ended! Please remount google.drive and reinstall dependences!')
    print('=====================================================================')
    assert(False)
    
os.chdir('/content/drive/UNIT')


if not backward:
    print('=====================================================================')
    print('Testing: gta2city')
    print('=====================================================================')


    ! python test.py --trainer UNIT --config configs/unit_gta2city_list.yaml --input inputs/gta3.jpg --output_folder results/gta2city --checkpoint models/unit_gta2city.pt --a2b 1

    print('=====================================================================')
    print('Input')
    display(Image('/content/drive/UNIT/results/gta2city/input.jpg', width=800))
    print('Output')
    display(Image('/content/drive/UNIT/results/gta2city/output.jpg', width=800))

if backward:
    print('=====================================================================')
    print('Testing: city2gta')
    print('=====================================================================')
   
    ! python test.py --trainer UNIT --config configs/unit_gta2city_list.yaml --input inputs/city_example.jpg --output_folder results/city2gta --checkpoint models/unit_gta2city.pt --a2b 0

    print('=====================================================================')
    print('Input')
    display(Image('/content/drive/UNIT/results/city2gta/input.jpg', width=800))
    print('Output')
    display(Image('/content/drive/UNIT/results/city2gta/output.jpg', width=800))