# UNIT Project in Colab
[on GitHub](https://github.com/mingyuliutw/UNIT)


## Workspace Setup
Shell commands that set up the current Jupyter Notebook session.

### Google Drive Folder Mounting
The following code mounts your Google Drive into the container. Choose **one** of the **two** methods below.

In [0]:
#@markdown #### Mounting (built-in)
print('Mounting...')
import os
from google.colab import drive
os.chdir('/content/')
drive.mount('/drive/', force_remount=True)
if os.path.exists('/content/drive'):
    os.unlink('/content/drive')
os.symlink('/drive/My Drive', '/content/drive')
!ls '/content/drive/'
print('Mounted!')

In [0]:
#@markdown #### Mounting With Fuse Driver (google-drive-ocamlfuse)
print('Mounting Google.Drive with google-drive-ocamlfuse...')
% cd /content/
print('Installing required software')
! apt-get install -y -qq software-properties-common module-init-tools 2>&1 > /dev/null
print('Add apt-repository with Google.Drive Fuse')
! add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
print('Updating packages...')
! apt-get update -y -qq
print('Installing google-drive-ocamlfuse fuse...')
! apt-get install -y -qq google-drive-ocamlfuse fuse
print('Authenticate Fuse in Google.Drive...')
from google.colab import auth
from oauth2client.client import GoogleCredentials
import getpass
auth.authenticate_user()
creds = GoogleCredentials.get_application_default()
! google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
vcode = getpass.getpass('Enter auth code here: ')
! echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}
print('Authenticated!')
print('Creating mount directory')
! mkdir /drive2
print('Mounting...')
! google-drive-ocamlfuse /drive2
if os.path.exists('/content/drive'):
    os.unlink('/content/drive')
os.symlink('/drive2', '/content/drive')
!ls '/content/drive/'
print('Mounted!')

### SSH Tunnel

In [0]:
#@markdown ## Connect to Colab session
#@markdown Using ngrok
port = 4040 #@param {type:"integer"}
only_show_credetionals = True #@param {type:"boolean"}
if not only_show_credetionals:
    print('Generate root password')
    import secrets, string
    password = ''.join(secrets.choice(string.ascii_letters + string.digits) for i in range(20))
    ! echo "Password: $password" > /content/save_pswd
    print('Download ngrok')
    ! wget -q -c -nc https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
    ! unzip -qq -n ngrok-stable-linux-amd64.zip
    print('Setup sshd')
    ! apt-get install -qq -o=Dpkg::Use-Pty=0 openssh-server pwgen > /dev/null
    print('Set root password')
    ! echo root:$password | chpasswd
    ! mkdir -p /var/run/sshd
    ! echo "PermitRootLogin yes" >> /etc/ssh/sshd_config
    ! echo "PasswordAuthentication yes" >> /etc/ssh/sshd_config
    ! echo "LD_LIBRARY_PATH=/usr/lib64-nvidia" >> /root/.bashrc
    ! echo "export LD_LIBRARY_PATH" >> /root/.bashrc

    print('Run sshd')
    get_ipython().system_raw('/usr/sbin/sshd -D &')

    print("Copy authtoken from https://dashboard.ngrok.com/auth")
    import getpass
    authtoken = getpass.getpass()

    print('Create tunnel')
    get_ipython().system_raw('./ngrok authtoken $authtoken && ./ngrok tcp 22 &')

print('---------')
import sys, json, os
try:
    s = !curl -s http://localhost:$port/api/tunnels
    addr = str(json.loads(s[0])['tunnels'][0]['public_url'])
    print('Use:', end=' ')
    print('ssh root@' + addr[6:addr.find(':', 6)] + ' -p ' + addr[addr.find(':', 6)+1:])
    if os.path.exists('/content/save_pswd'):
        ! cat /content/save_pswd
except:
    print('Tunnel was closed!')


# Project

### Install Dependencies
Installs the project dependencies; run once at the start of every session.

In [0]:
#@markdown ## Dependencies
Project = 'google.drive' #@param ['google.drive', 'clone github'] 
copy_pre_trained_model = False #@param {type:"boolean"}

if Project == 'clone github':
    print('Cloning GitHub project...')
    !git clone https://github.com/SoleSensei/UNIT.git
    if copy_pre_trained_model:
        print('Copying PT model gta2city...')
        !mkdir /content/UNIT/models
        !mkdir /content/UNIT/outputs
        !cp -r /content/drive/UNIT/models/ /content/UNIT/
        !cp -r /content/drive/UNIT/output/ /content/UNIT/


print('Installing system packages...')
!apt-get install -y -qq axel imagemagick 2>&1 > /dev/null
print('Installing project dependencies...')
!pip3 install http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl 2>&1 > /dev/null
!pip3 install torchvision 2>&1 > /dev/null
!pip3 install tensorboard tensorboardX 2>&1 > /dev/null
print('Complete!')

## Train

### Day-2-Night Translation

In [0]:
#@title Shift Domains (nexet dataset)
processing = "google.drive" #@param ["google.drive", "copy nexet → localy"]
shift_domain_mode = "move" #@param ["move", "copy", "none"]
domains2data = True #@param {type:"boolean"}
show_errors = 5 #@param {type:"slider", min:0, max:100, step:1}
show_log = 10 #@param {type:"slider", min:0, max:100, step:1}


datapath = ''
csvfile = ''
if processing != 'google.drive':
    print('Coping nexet from g.drive to local space...')
    !cp -r /content/drive/datasets/nexet/ /content/ 
    print('Copied!')
    datapath = '/content/nexet/nexet_2017_1/'
    csvfile = '/content/nexet/train.cvs'

%cd /content/drive/ParseData 
if shift_domain_mode != 'none':
    print('python data2domains.py', shift_domain_mode, domains2data, datapath, csvfile)
    !python data2domains.py $shift_domain_mode $domains2data $datapath $csvfile
if show_errors:
    print('Error log:')
    !tail -n $show_errors err.txt
if show_log:
    print('Log:')
    !tail -n $show_log log.txt
print('Completed!')


### Summer-2-Winter Translation

In [0]:
#@markdown ### Training Winter 2 Summer
Resume_from_last_checkpoint = True #@param {type:"boolean"}
Checkpoint_every_iteration = 1000 #@param {type:"slider", min:100, max:5000, step:100}



if Resume_from_last_checkpoint:
    rsm = '--resume'
else:
    rsm = ''

import os
if not os.path.isdir('/content/drive'):
    print('=====================================================================')
    print('Session ended! Please remount google.drive and reinstall dependences!')
    print('=====================================================================')
else:
    os.chdir('/content/drive/UNIT')

    print('=====================================================================')
    print('Training a model Summer to Winter image translation')
    print('=====================================================================')
    !python train.py --config configs/unit_summer2winter_yosemite256_folder.yaml --trainer UNIT $rsm
    print('=====================================================================')
    print('Fully Trained!')
    print('=====================================================================')

In [0]:
#@markdown ### Results
# Form parameters for the snapshot viewer below.
# `iterations`: minimum gap, in training iterations, between displayed snapshots.
# `image_size`: width and height passed as matplotlib `figsize` for each image.
iterations = 5004 #@param {type:"slider", min:10, max:10000, step:1}
image_size = 30 #@param {type:"slider", min:15, max:50, step:1}
# Coerce defensively; the value is compared with integer iteration numbers.
iterations = int(iterations)
import os
from IPython.display import Image, display, HTML
import cv2
import matplotlib.pyplot as plt


def getTitle(name, it):
    """Build a plot title from a UNIT snapshot file name and an iteration.

    The character at index 4 picks the translation direction ('a' means
    Summer 2 Winter, anything else Winter 2 Summer). The slice name[7:13]
    equal to '_test_' marks a test snapshot; everything else is labelled Train.
    The iteration `it` is appended as text.
    """
    direction = 'Summer 2 Winter' if name[4] == 'a' else 'Winter 2 Summer'
    split = ' - Test ' if name[7:13] == '_test_' else ' - Train '
    return direction + split + str(it)

def displayImage(file, title, size=None):
    """Show an image file in the notebook with matplotlib.

    Args:
        file: path of the image; read with OpenCV (BGR) and converted to RGB.
        title: figure title.
        size: figure width and height for `figsize`; defaults to the global
            `image_size` form parameter of this cell.

    Raises:
        FileNotFoundError: if OpenCV cannot read `file`.
    """
    image = cv2.imread(file)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'Cannot read image: {file}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if size is None:
        size = image_size
    plt.figure(figsize=(size, size))
    plt.title(title)
    plt.imshow(image)
    plt.grid(False)
    # Render now; with the inline backend this also closes the figure, so a
    # loop over many snapshots does not accumulate open figures.
    plt.show()

os.chdir('/content/drive/UNIT/outputs/unit_summer2winter_yosemite256_folder/images/')
print(f'Selected iterations: {iterations}')


def _snapshot_iteration(name):
    """Return the 8-digit iteration encoded just before the extension, or None.

    Files such as 'gen_a2b_train_current.jpg' (or anything else without a
    numeric suffix) yield None and are skipped.
    """
    stem = name[-12:-4]
    return int(stem) if stem.isdigit() else None


# Collect (iteration, name) pairs and sort them: directory listing order is
# arbitrary, and the selection below needs ascending iterations, with all
# images of the same iteration grouped together.
snapshots = []
for name in os.listdir('.'):
    it = _snapshot_iteration(name)
    if it is not None:
        snapshots.append((it, name))
snapshots.sort()

if not snapshots:
    print('No snapshot images found.')

# Show every image of the first iteration, then every image of each iteration
# that lies at least `iterations` after the last shown one.
shown_it = None
for it, name in snapshots:
    if shown_it is None or it == shown_it or it - shown_it >= iterations:
        displayImage(name, getTitle(name, it))
        shown_it = it






<!-- ![]()
 ![Google's logo](/content/drive/UNIT/outputs/unit_summer2winter_yosemite256_folder/images/gen_a2b_test_00000010.png) -->


## Test

In [0]:
from IPython.display import Image, display
import os

#@markdown ## Testing on PT models
Model = "gta2city" #@param ['gta2city', 'day2night', 'summer2winter']
backward = False #@param {type:"boolean"}


if not os.path.isdir('/content/drive'):
    print('=====================================================================')
    print('Session ended! Please remount google.drive and reinstall dependences!')
    print('=====================================================================')
    assert(False)
    
os.chdir('/content/drive/UNIT')


if not backward:
    print('=====================================================================')
    print('Testing: gta2city')
    print('=====================================================================')


    ! python test.py --trainer UNIT --config configs/unit_gta2city_list.yaml --input inputs/gta3.jpg --output_folder results/gta2city --checkpoint models/unit_gta2city.pt --a2b 1

    print('=====================================================================')
    print('Input')
    display(Image('/content/drive/UNIT/results/gta2city/input.jpg', width=800))
    print('Output')
    display(Image('/content/drive/UNIT/results/gta2city/output.jpg', width=800))

if backward:
    print('=====================================================================')
    print('Testing: city2gta')
    print('=====================================================================')
   
    ! python test.py --trainer UNIT --config configs/unit_gta2city_list.yaml --input inputs/city_example.jpg --output_folder results/city2gta --checkpoint models/unit_gta2city.pt --a2b 0

    print('=====================================================================')
    print('Input')
    display(Image('/content/drive/UNIT/results/city2gta/input.jpg', width=800))
    print('Output')
    display(Image('/content/drive/UNIT/results/city2gta/output.jpg', width=800))