In [20]:
import copy
import datetime
import errno
import hashlib
import os
import time
from collections import defaultdict, deque, OrderedDict
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = reduce_across_processes([self.count, self.total])
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
        )

In [16]:
s = SmoothedValue()

In [17]:
s.update(1)

In [19]:
s.update(1, 2, 2)

TypeError: update() takes from 2 to 3 positional arguments but 4 were given

In [8]:
s.deque

deque([1, 2, 2])

In [9]:
print(s)

2.0000 (1.7500)


In [10]:
s = dict()

In [13]:
s['1'] = []

In [14]:
s['1'].append(1)

In [15]:
s

{'1': [1]}

In [None]:
torch.tesno

In [None]:
loss_fn()

In [21]:
loss = nn.CrossEntropyLoss()

In [None]:
torch.Tensor([1,1,1])

In [22]:
loss([1,2,3], [1,2,3])

TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not list

In [25]:
s = SmoothedValue()

In [26]:
s.update(10, n=32)

In [28]:
print(s)

10.0000 (10.0000)


In [29]:
s.deque

deque([10])

In [30]:
s.total

320.0

In [None]:
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)