# Warp PyTorch Tutorial: Basics


In [None]:
!pip install warp-lang torch

In [None]:
import warp as wp
import numpy as np
import torch

# Quiet Warp's console output. Presumably this silences informational messages
# (e.g. at initialization); it is set before wp.init() below, likely so that
# init output is suppressed too — NOTE(review): confirm against warp.config docs.
wp.config.quiet = True

# Explicitly initializing Warp is not necessary but
# we do it here to ensure everything is good to go.
wp.init()

# Converting Arrays To/From PyTorch

Warp provides helper functions to convert arrays to and from PyTorch without copying the underlying data, regardless of whether the data lives on the CPU or a GPU. If an associated gradient array exists, it is converted at the same time.

In [None]:
"""Warp -> PyTorch"""

# Build a Warp array on Warp's current device; requires_grad=True also
# gives it a companion gradient array.
w = wp.array(
    [1.0, 2.0, 3.0],
    dtype=wp.float32,
    device=wp.get_device(),
    requires_grad=True,
)

# View the same memory as a torch.Tensor (no copy); the gradient array
# comes along as t.grad.
t = wp.to_torch(w)

for shown in (t, t.grad):
    print(shown)

In [None]:
"""PyTorch -> Warp"""

# Construct a Torch tensor, including gradient tensor.
# Use the first GPU when PyTorch can see one and fall back to the CPU otherwise;
# hard-coding "cuda:0" would make this cell fail on CPU-only machines.
t = torch.tensor(
    [1.0, 2.0, 3.0],
    dtype=torch.float32,
    requires_grad=True,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)

# Convert to Warp array
w = wp.from_torch(t)

print(w)
print(w.grad)

In [None]:
"""PyTorch -> Warp (+ allocate a new gradient array)"""

# Construct a Torch tensor, excluding gradient tensor.
# Use the first GPU when PyTorch can see one and fall back to the CPU otherwise;
# hard-coding "cuda:0" would make this cell fail on CPU-only machines.
t = torch.tensor(
    [1.0, 2.0, 3.0],
    dtype=torch.float32,
    requires_grad=False,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)

# Convert to Warp array and allocate a gradient array
w = wp.from_torch(t, requires_grad=True)

print(w)
print(w.grad)

In [None]:
"""Map Warp device and dtype to Torch device and dtype"""

# Translate Warp's current device and a Warp scalar type into their Torch counterparts
device, dtype = wp.device_to_torch(wp.get_device()), wp.dtype_to_torch(wp.float32)

# A Torch tensor that lives on the same device, with the same element type, as Warp uses
t = torch.tensor([1.0, 2.0, 3.0], device=device, dtype=dtype)
print(t)

In [None]:
"""Map Torch device and dtype to Warp device and dtype"""

# Torch device + dtype to Warp.
# Use the first GPU when PyTorch can see one and fall back to the CPU otherwise;
# hard-coding "cuda:0" would make this cell fail on CPU-only machines.
device = wp.device_from_torch(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
dtype = wp.dtype_from_torch(torch.float32)

# Construct a Warp array, ensuring we are using the same dtype/device as Torch
w = wp.array([1.0, 2.0, 3.0], dtype=dtype, device=device)

print(w)