In [1]:
import subprocess, re, os, sys

# GPU picking helpers.
# Based on http://stackoverflow.com/a/41638727/419116
# They parse nvidia-smi output to find per-GPU memory usage.
# Tested on nvidia-smi 370.23.

def run_command(cmd):
    """Run a shell command and return its standard output as a string.

    Args:
        cmd: Command line, executed through the shell (``shell=True``), so it
            must only ever be built from trusted, constant text.

    Returns:
        Everything the command wrote to stdout, decoded as UTF-8. Bytes that
        are not valid UTF-8 are replaced with U+FFFD instead of raising, so a
        stray non-ASCII process or GPU name cannot break the parsers that
        consume this output. Stderr is not captured and the exit status is
        not checked.
    """
    completed = subprocess.run(cmd, stdout=subprocess.PIPE, shell=True)
    # Strict "ascii" decoding used to raise UnicodeDecodeError on any byte >= 0x80.
    return completed.stdout.decode("utf-8", errors="replace")

def list_available_gpus():
    """Return the integer ids of the GPUs listed by ``nvidia-smi -L``.

    Top-level lines have the form ``GPU 0: TITAN X ...``; the number after
    ``GPU`` is collected, in output order.

    Returns:
        List of GPU ids, e.g. ``[0, 1, 2]``. Empty when the command prints
        nothing (for example when nvidia-smi is missing or sees no GPUs).

    Raises:
        AssertionError: (with assertions enabled) if a non-blank, non-indented
            line does not start with ``GPU <id>:``.
    """
    output = run_command("nvidia-smi -L")
    gpu_regex = re.compile(r"GPU (?P<gpu_id>\d+):")
    result = []
    for line in output.splitlines():
        # Blank lines carry no GPU. Empty output used to become [""] and trip
        # the assertion below.
        if not line.strip():
            continue
        # Indented lines are sub-entries printed under their parent GPU
        # (presumably MIG device lines on newer drivers -- verify against your
        # nvidia-smi). They are not separate GPU ids.
        if line[0].isspace():
            continue
        m = gpu_regex.match(line)
        assert m, "Couldnt parse "+line
        result.append(int(m.group("gpu_id")))
    return result

def gpu_memory_map():
    """Return a map of GPU id to the memory (MiB) used by processes on it.

    Parses the process table printed by plain ``nvidia-smi`` (the text that
    starts at the ``GPU Memory`` column header) and sums the ``<N>MiB`` figure
    of every process row per GPU. Every id from ``list_available_gpus()`` is
    present in the result, with 0 when no process row mentions it. A row for
    an id that ``list_available_gpus()`` did not report is added as a new key.

    NOTE(review): only processes listed in nvidia-smi's process table are
    counted. Memory held by processes it does not show (e.g. from other
    containers) is presumably missed, so this can disagree with the free
    memory reported by ``get_gpu_memory()``. Confirm before relying on it.
    """
    output = run_command("nvidia-smi")
    # Only the process table (from the "GPU Memory" header on) holds
    # per-process rows. str.find returns -1 when the header is absent, and
    # slicing from -1 would keep just the last character; use "" explicitly.
    start = output.find("GPU Memory")
    gpu_output = output[start:] if start >= 0 else ""
    # Process rows look like:
    # |    0      8734    C   python                                       11705MiB |
    memory_regex = re.compile(r"[|]\s+?(?P<gpu_id>\d+)\D+?(?P<pid>\d+).+[ ](?P<gpu_memory>\d+)MiB")
    result = {gpu_id: 0 for gpu_id in list_available_gpus()}
    for row in gpu_output.splitlines():
        m = memory_regex.search(row)
        if not m:
            continue
        gpu_id = int(m.group("gpu_id"))
        gpu_memory = int(m.group("gpu_memory"))
        # .get() so a row for an unlisted GPU adds an entry instead of raising KeyError.
        result[gpu_id] = result.get(gpu_id, 0) + gpu_memory
    return result

def pick_gpu_lowest_memory(memory_map=None):
    """Return the id of the GPU with the least allocated memory.

    Args:
        memory_map: Optional mapping of GPU id -> allocated memory. When
            omitted (the default), ``gpu_memory_map()`` is queried.

    Returns:
        The GPU id with the smallest memory value; ties go to the lowest id.

    Raises:
        IndexError: if there are no GPUs to choose from.
    """
    if memory_map is None:
        memory_map = gpu_memory_map()
    if not memory_map:
        # Same exception type as the old sorted(...)[0] on an empty list,
        # but with a message that says what actually went wrong.
        raise IndexError("No GPUs found to pick from")
    # (memory, gpu_id) pairs compare by memory first, then by the lower id.
    _, best_gpu = min((memory, gpu_id) for gpu_id, memory in memory_map.items())
    return best_gpu

def setup_one_gpu():
    """Restrict CUDA to the single GPU with the least allocated memory.

    Must run before TensorFlow is imported; an assertion enforces that. The
    chosen id is printed, then exported through CUDA environment variables.
    """
    assert 'tensorflow' not in sys.modules, "GPU setup must happen before importing TensorFlow"
    chosen = str(pick_gpu_lowest_memory())
    print("Picking GPU " + chosen)
    # Number devices by PCI bus id, presumably so CUDA's ordering agrees with
    # the ids nvidia-smi reported (confirm for your driver).
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = chosen

def setup_no_gpu():
    """Hide all GPUs from CUDA by setting CUDA_VISIBLE_DEVICES to ''.

    An already-imported TensorFlow only triggers a printed warning (no
    exception); the environment variable is set either way.
    """
    tf_already_loaded = 'tensorflow' in sys.modules
    if tf_already_loaded:
        print("Warning, GPU setup must happen before importing TensorFlow")
    os.environ["CUDA_VISIBLE_DEVICES"] = ''

In [2]:
# Id of the GPU that setup_one_gpu() would pick right now (least process-allocated memory).
pick_gpu_lowest_memory()

0

In [3]:
# GPU ids reported by `nvidia-smi -L`.
list_available_gpus()

[0, 1, 2]

In [4]:
# MiB allocated per GPU, summed over the processes in nvidia-smi's process table.
gpu_memory_map()

{0: 4, 1: 11883, 2: 301}

In [5]:
import tensorflow as tf

In [9]:
# GPUs visible to TensorFlow. No setup_one_gpu()/setup_no_gpu() call appears
# before the TensorFlow import in this notebook, so CUDA_VISIBLE_DEVICES was
# not narrowed by these helpers.
# NOTE(review): newer TF 2.x releases also expose the non-experimental
# tf.config.list_physical_devices -- confirm the installed version before switching.
tf.config.experimental.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU')]

In [10]:
import subprocess as sp
import os

def get_gpu_memory():
    """Return the free memory of each GPU reported by nvidia-smi, in MiB.

    Runs ``nvidia-smi --query-gpu=memory.free``. With ``noheader`` and
    ``nounits`` each output line is a bare integer, so no header row or
    " MiB" suffix has to be stripped by position.

    Returns:
        One int per GPU line of nvidia-smi's output, in output order.

    Raises:
        subprocess.CalledProcessError: if nvidia-smi exits with an error.
        OSError: (e.g. FileNotFoundError) if nvidia-smi cannot be executed.
    """
    command = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
    output = sp.check_output(command.split()).decode('ascii')
    # splitlines() copes with CRLF endings and a missing trailing newline; the
    # old "[:-1][1:]" slicing silently dropped a GPU in the latter case.
    return [int(line) for line in output.splitlines() if line.strip()]

# Free memory per GPU as reported by nvidia-smi. This is a different measure
# from the process-allocated totals of gpu_memory_map().
get_gpu_memory()

[12187, 308, 7390]

In [11]:
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
# Print the name and device type of each GPU TensorFlow can see.
for gpu in gpus:
    print("Name:", gpu.name, "  Type:", gpu.device_type)

Name: /physical_device:GPU:0   Type: GPU
Name: /physical_device:GPU:1   Type: GPU
Name: /physical_device:GPU:2   Type: GPU


In [14]:
# NOTE(review): an assignment displays nothing, so the output shown under this
# cell is stale. It appears to come from an earlier version of the cell
# (something like `gpus[0].index`). Re-run to refresh it.
gpu0=gpus[0]

<function PhysicalDevice.index(value, start=0, stop=9223372036854775807, /)>

In [3]:
import jax.numpy as jnp

In [4]:
# No sample count is passed; the output shows 50 float32 values from 1 to 200
# inclusive. NOTE(review): execution counts restart at [3] here, which suggests
# a kernel restart -- the GPU helpers above would not be defined in this session.
jnp.linspace(1,200)



DeviceArray([  1.       ,   5.0612245,   9.122449 ,  13.183674 ,
              17.244898 ,  21.306122 ,  25.367346 ,  29.428572 ,
              33.489796 ,  37.55102  ,  41.612244 ,  45.673466 ,
              49.73469  ,  53.79592  ,  57.857143 ,  61.918365 ,
              65.97959  ,  70.04082  ,  74.10204  ,  78.16326  ,
              82.22449  ,  86.28571  ,  90.34693  ,  94.40816  ,
              98.46938  , 102.53062  , 106.59184  , 110.65306  ,
             114.71429  , 118.77551  , 122.83673  , 126.89796  ,
             130.95918  , 135.02042  , 139.08163  , 143.14287  ,
             147.20409  , 151.2653   , 155.32652  , 159.38776  ,
             163.44897  , 167.51021  , 171.57143  , 175.63266  ,
             179.69386  , 183.7551   , 187.81631  , 191.87755  ,
             195.93877  , 200.       ], dtype=float32)