In [1]:
import torch
import time

In [2]:
def try_all_gpus():
    """Return all visible CUDA devices, falling back to the CPU.

    Returns:
        list[torch.device]: ``cuda:0 .. cuda:{n-1}`` for the ``n`` devices
        reported by ``torch.cuda.device_count()``, or ``[torch.device('cpu')]``
        when no CUDA device is available.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        return [torch.device('cpu')]
    return [torch.device(f'cuda:{i}') for i in range(gpu_count)]

class Benchmark:
    """Context manager that prints how long its ``with`` body took.

    A ``Timer`` is created on entry; on exit the elapsed seconds are printed
    as ``'<description>: <seconds> sec'`` with four decimals. ``__exit__``
    returns ``None``, so exceptions raised in the body are not suppressed.
    """

    def __init__(self, description='Done'):
        # Label printed in front of the elapsed time.
        self.description = description

    def __enter__(self):
        # Constructing a Timer records the start time immediately.
        timer = Timer()
        self.timer = timer
        return self

    def __exit__(self, *exc_info):
        # exc_info is the (type, value, traceback) triple; it is ignored.
        print(f'{self.description}: {self.timer.stop():.4f} sec')
        
class Timer:
    """Record multiple running times.

    Timing starts on construction. Each ``stop()`` appends the seconds elapsed
    since the most recent ``start()`` (or construction) to ``self.times``.
    """

    def __init__(self):
        # Elapsed durations in seconds, one entry per stop() call.
        self.times = []
        self.start()

    def start(self):
        """Start (or restart) the timer."""
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump if the system clock is adjusted.
        self.tik = time.perf_counter()

    def stop(self):
        """Stop the timer, record the elapsed time and return it in seconds."""
        self.times.append(time.perf_counter() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the mean recorded time (ZeroDivisionError if none recorded)."""
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the total of all recorded times."""
        # Inside a method body, `sum` resolves to the builtin, not this method.
        return sum(self.times)

    def cumsum(self):
        """Return the running totals of the recorded times as a list."""
        # itertools.accumulate avoids a numpy dependency; the notebook's
        # import cell does not import numpy, so `np` would be undefined here.
        from itertools import accumulate
        return list(accumulate(self.times))

In [3]:
devices = try_all_gpus()


def run(x):
    """Compute ``x @ x`` twenty times and return the products as a list.

    This is the benchmark workload. On a CUDA tensor the matmul kernels are
    launched asynchronously, so time it together with a device synchronize.
    """
    products = []
    for _ in range(20):
        products.append(x.mm(x))
    return products


# Two independent 4000x4000 operands for the benchmarks below.
# NOTE(review): both are placed on devices[0], so the "GPU1"/"GPU2" timings
# share a single device; on a multi-GPU machine devices[1] was presumably
# intended for x_gpu2 — confirm (the later synchronize calls must match).
x_gpu1 = torch.rand(size=(4000, 4000), device=devices[0])
x_gpu2 = torch.rand(size=(4000, 4000), device=devices[0])

In [4]:
# Warm-up pass: run both workloads once before timing so one-time setup cost
# is excluded from the measurements below.
run(x_gpu1)
run(x_gpu2)
# Wait on the device each tensor actually lives on (they may differ).
torch.cuda.synchronize(x_gpu1.device)
torch.cuda.synchronize(x_gpu2.device)

with Benchmark('GPU1 time'):
    run(x_gpu1)
    # Kernels are queued asynchronously; synchronize so the timer covers them.
    torch.cuda.synchronize(x_gpu1.device)

with Benchmark('GPU2 time'):
    run(x_gpu2)
    torch.cuda.synchronize(x_gpu2.device)

GPU1 time: 0.3610 sec
GPU2 time: 0.3690 sec


In [5]:
with Benchmark('GPU1 & GPU2'):
    # Enqueue both workloads before waiting so that, when the tensors live on
    # different GPUs, they can execute concurrently.
    run(x_gpu1)
    run(x_gpu2)
    # torch.cuda.synchronize() without an argument waits only on the current
    # device, so wait on each tensor's device explicitly.
    torch.cuda.synchronize(x_gpu1.device)
    torch.cuda.synchronize(x_gpu2.device)

GPU1 & GPU2: 0.7260 sec


In [6]:
def copy_to_cpu(x, non_blocking=False):
    """Copy every tensor in ``x`` to host (CPU) memory.

    Args:
        x: iterable of tensors.
        non_blocking: forwarded to ``Tensor.to``. When True, a device-to-host
            copy may be asynchronous, so synchronize before using the results.

    Returns:
        list of the CPU tensors, in input order.
    """
    host_copies = []
    for tensor in x:
        host_copies.append(tensor.to('cpu', non_blocking=non_blocking))
    return host_copies

with Benchmark('在GPU1上运行'):  # label: "Run on GPU1"
    y = run(x_gpu1)
    # Wait on the device that holds x_gpu1 rather than whichever device is
    # current, so the measurement does not depend on torch.cuda.current_device().
    torch.cuda.synchronize(x_gpu1.device)

with Benchmark('复制到CPU'):  # label: "Copy to CPU"
    y_cpu = copy_to_cpu(y)
    torch.cuda.synchronize(x_gpu1.device)

在GPU1上运行: 0.3644 sec
复制到CPU: 0.2743 sec


In [7]:
with Benchmark('在GPU1上运行并复制到CPU'):  # label: "Run on GPU1 and copy to CPU"
    y = run(x_gpu1)
    # Non-blocking copies are enqueued behind the matmuls and may overlap with
    # them; y_cpu is only guaranteed complete after the synchronize below.
    y_cpu = copy_to_cpu(y, non_blocking=True)
    torch.cuda.synchronize(x_gpu1.device)

在GPU1上运行并复制到CPU: 0.5100 sec
