Skip to content

Commit 95c630a

Browse files
committed
add huggingface/transformer
1 parent a927117 commit 95c630a

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,7 @@ RUN pip install flashtext && \
481481
pip install qgrid && \
482482
pip install bqplot && \
483483
pip install earthengine-api && \
484+
pip install transformers && \
484485
/tmp/clean-layer.sh
485486

486487
# Tesseract and some associated utility packages

tests/test_transformers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import unittest
2+
3+
import torch
4+
from transformers import AdamW
5+
6+
7+
class TestTransformers(unittest.TestCase):
8+
def assertListAlmostEqual(self, list1, list2, tol):
9+
self.assertEqual(len(list1), len(list2))
10+
for a, b in zip(list1, list2):
11+
self.assertAlmostEqual(a, b, delta=tol)
12+
13+
def test_adam_w(self):
14+
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
15+
target = torch.tensor([0.4, 0.2, -0.5])
16+
criterion = torch.nn.MSELoss()
17+
# No warmup, constant schedule, no gradient clipping
18+
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
19+
for _ in range(100):
20+
loss = criterion(w, target)
21+
loss.backward()
22+
optimizer.step()
23+
w.grad.detach_() # No zero_grad() function on simple tensors. we do it ourselves.
24+
w.grad.zero_()
25+
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

0 commit comments

Comments
 (0)