# -*- coding: utf-8 -*-
"""Distributed multiplication usage example."""
import torch
import torch.nn as nn
from distributed_dot_product.utils.comm import get_rank, get_world_size
from distributed_dot_product.module import DistributedDotProductAttn
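
# Seed the RNG so the randomly generated inputs are reproducible.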
torch.manual_seed(111)
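
# Pin each rank to its own GPU when CUDA is available; otherwise fall back to CPU.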
device = torch.device('cpu')
if torch.cuda.is_available():
    torch.cuda.set_device(get_rank())
    device = torch.device('cuda')
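
# Distributed attention module: embedding size 768, 2 attention heads.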
module = DistributedDotProductAttn(768, num_heads=2, offset=64)
module.to(device)
criterion = nn.MSELoss()
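
# Each rank works on its local shard (length // world_size tokens) of the 4096-token sequence.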
length = 4096
world_size = get_world_size()
x = torch.rand(1, length // world_size, 768, device=device)
y = torch.rand(1, length // world_size, 768, device=device)
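
# Boolean attention mask sized (batch, local query positions, global sequence length).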
mask = torch.zeros(1, length // world_size, length, device=device).bool()
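
# Self-attention forward pass (query = key = value = x), MSE loss, and backward pass.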
out = module(x, x, x, mask)
loss = criterion(out, y)
loss.backward()