-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
85 lines (71 loc) · 2.59 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import logging
import os
from torch.distributed import init_process_group
def get_logger(filename, local_rank):
    """Configure and return the root logger for (possibly multi-GPU) training.

    Only the primary process (``local_rank <= 0``) logs to *filename* and to
    stdout; every other rank gets a NullHandler so its records are dropped.
    Any handlers left over from a previous call are discarded first.
    """
    fmt = logging.Formatter(
        fmt='[%(asctime)s %(levelname)s] %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p',
    )
    logger = logging.getLogger()
    logger.handlers = []  # reset: drop handlers installed by earlier calls
    logger.setLevel(logging.INFO)
    logger.propagate = False
    if filename is not None and local_rank <= 0:
        # Primary process: append to the log file and echo to stdout.
        for handler in (logging.FileHandler(filename, 'a'), logging.StreamHandler()):
            handler.setLevel(logging.INFO)
            handler.setFormatter(fmt)
            logger.addHandler(handler)
    else:
        # Secondary processes (or no filename): swallow everything.
        sink = logging.NullHandler()
        sink.setLevel(logging.INFO)
        logger.addHandler(sink)
    return logger
def ddp_setup(rank: int, world_size: int, runid=None):
    """Initialize the NCCL process group for single-node DDP training.

    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
        runid: Optional run index appended to the base port so that
            concurrent runs on the same machine use distinct ports.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    # Fix: compare against None rather than truthiness — previously
    # runid=0 was falsy, silently fell back to the shared base port
    # "1235", and could collide with another run on the same machine.
    if runid is not None:
        os.environ["MASTER_PORT"] = "1235" + str(runid)
    else:
        os.environ["MASTER_PORT"] = "1235"
    # Let cuDNN autotune kernels; beneficial when input shapes are fixed.
    torch.backends.cudnn.benchmark = True
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
def tocuda(rank, data):
    """Move *data* (a tensor, or a list of tensors) to device *rank*.

    A single-element list is unwrapped and returned as a bare tensor;
    longer lists come back as a list of moved tensors.
    """
    # exact `type is list` check kept (subclasses are treated as tensors)
    if type(data) is list:
        moved = [item.to(rank) for item in data]
        return moved[0] if len(moved) == 1 else moved
    return data.to(rank)
def evaluate(rank, model, loader):
    """Return top-1 accuracy of *model* over *loader* on device *rank*.

    Each batch unpacks as (*inputs, target); the wrapped network is
    invoked via ``model.module`` (i.e. the bare module inside DDP).
    """
    model.eval()
    correct, total = 0, 0
    for *inputs, labels in loader:
        inputs, labels = tocuda(rank, inputs), tocuda(rank, labels)
        with torch.no_grad():
            logits = model.module(inputs)
            predictions = logits.data.max(1)[1]
            correct += predictions.eq(labels.data).sum().cpu().item()
            total += len(labels)
    return correct / total
class DataIterator(object):
    """Endlessly cycling iterator over a DataLoader.

    When the underlying loader is exhausted it is transparently restarted,
    so ``next()`` always yields an (x, y) batch (reshuffled each epoch if
    the loader was built with ``shuffle=True``).
    """
    def __init__(self, dataloader):
        assert isinstance(dataloader, torch.utils.data.DataLoader), 'Wrong loader type'
        self.loader = dataloader
        self.iterator = iter(self.loader)

    def __iter__(self):
        # Fix: implement the iterator protocol fully — previously only
        # explicit next() calls worked; for-loops and zip() raised TypeError.
        return self

    def __next__(self):
        try:
            x, y = next(self.iterator)
        except StopIteration:
            # Epoch boundary: restart from a fresh iterator.
            self.iterator = iter(self.loader)
            x, y = next(self.iterator)
        return x, y