In [1]:
from model.rnn import GRUDecoder
import torch        

In [2]:
def get_device(verbose: bool = True) -> torch.device:
    """Pick the preferred torch device: first CUDA GPU, then Apple MPS, then CPU.

    Args:
        verbose: When True (the default, matching the original behaviour),
            print which device was selected.

    Returns:
        ``torch.device("cuda:0")`` if CUDA is available, otherwise
        ``torch.device("mps")`` if the MPS backend is available, otherwise
        ``torch.device("cpu")``.
    """
    # `torch.backends.mps` only exists on PyTorch 1.12 or newer. Look it up
    # defensively so older installs fall back to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        # Always the first CUDA device; no multi-GPU selection is done here.
        device = torch.device("cuda:0")
        if verbose:
            print("Using CUDA device:", torch.cuda.get_device_name(0))
    elif mps_backend is not None and mps_backend.is_available():
        # Metal Performance Shaders backend on supported macOS devices.
        device = torch.device("mps")
        if verbose:
            print("Using MPS (Metal Performance Shaders) device")
    else:
        # Fallback when neither CUDA nor MPS is usable.
        device = torch.device("cpu")
        if verbose:
            print("Using CPU")
    return device


In [3]:
# Seed torch's global RNG before constructing the model. Any random weight
# initialisation and the later `torch.randn` input then reproduce on
# Restart & Run All, so the displayed outputs are stable.
SEED = 0
torch.manual_seed(SEED)

# `input_size` must match the last dimension of the dummy input built in the
# next cell (12).
model = GRUDecoder(input_size=12, hidden_size=10).to(get_device())

Using MPS (Metal Performance Shaders) device


In [4]:
# Dummy batch of shape (4, 10, 12). This is presumably (batch, seq_len,
# features), assuming the model's GRU is batch_first (TODO confirm). The last
# dim must equal the model's input_size (12).
# Draw on the default (CPU) device so the values don't depend on the
# accelerator, then move the tensor to whatever device the model's parameters
# live on. Input and model stay co-located without probing the hardware (and
# printing) again through get_device().
x = torch.randn(4, 10, 12).to(next(model.parameters()).device)

Using MPS (Metal Performance Shaders) device


In [5]:
# Full forward pass on the dummy batch. For this (4, 10, 12) input the displayed
# result is a (4, 1) tensor: one value per row of the leading dimension.
model(x)

tensor([[-0.2743],
        [-0.2173],
        [-0.2639],
        [-0.2543]], device='mps:0', grad_fn=<LinearBackward0>)

In [10]:
# Re-run the model's first stages by hand so its attention step can be inspected.
# The GRU's second return value is unused. It goes into a named throwaway rather
# than `_`: once user code assigns `_`, IPython stops updating its `_`/`__`/`___`
# output-history variables.
gru_out, _gru_hidden = model.gru(x)
xx = model.post_gru(gru_out)

In [11]:
# The model's own attention module applied to the post-GRU features. The
# displayed result is (4, 10); the next two cells recompute it by hand for
# comparison.
model.attention(xx)

tensor([[-0.0596, -0.0882,  0.0289, -0.2966, -0.1535,  0.0709,  0.1236,  0.0680,
          0.2359,  0.1787],
        [-0.1526, -0.0519,  0.0643, -0.2133, -0.1781,  0.0261,  0.1416, -0.2112,
          0.1317,  0.1871],
        [-0.0015, -0.1295, -0.1343, -0.1155, -0.0357, -0.1870,  0.2212, -0.1039,
          0.2096,  0.1106],
        [-0.1767,  0.1607,  0.1776, -0.0506, -0.0305, -0.0209, -0.0028, -0.0815,
          0.0192,  0.2910]], device='mps:0', grad_fn=<SumBackward1>)

In [16]:
# Hand-computed attention weights. Project each position i of xx (b, i, j) onto
# the attention module's 1-D `context` vector, then normalise over i.
attn_logits = torch.einsum("bij, j -> bi", xx, model.attention.context)
scores = attn_logits.softmax(dim=1)
scores.shape

torch.Size([4, 10])

In [18]:
# Pool the post-GRU features with the hand-computed weights from the previous
# cell. Then check the result against the model's own attention module: the
# 4-decimal printouts can hide real discrepancies, so don't compare them by eye.
manual_attention = torch.einsum("bij, bi -> bj", xx, scores)
torch.testing.assert_close(manual_attention, model.attention(xx))
manual_attention

tensor([[-0.0596, -0.0882,  0.0289, -0.2966, -0.1535,  0.0709,  0.1236,  0.0680,
          0.2359,  0.1787],
        [-0.1526, -0.0519,  0.0643, -0.2133, -0.1781,  0.0261,  0.1416, -0.2112,
          0.1317,  0.1871],
        [-0.0015, -0.1295, -0.1343, -0.1155, -0.0357, -0.1870,  0.2212, -0.1039,
          0.2096,  0.1106],
        [-0.1767,  0.1607,  0.1776, -0.0506, -0.0305, -0.0209, -0.0028, -0.0815,
          0.0192,  0.2910]], device='mps:0', grad_fn=<ViewBackward0>)