Skip to content

Commit

Permalink
Merge pull request #20 from TezRomacH/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
TezRomacH committed Sep 17, 2020
2 parents 54ebff9 + 148663d commit 8ee55fd
Show file tree
Hide file tree
Showing 10 changed files with 11,818 additions and 11,508 deletions.
45 changes: 35 additions & 10 deletions README.md
Expand Up @@ -19,7 +19,9 @@ PyTorch implementation of L2L execution algorithm from paper [Training Large Neu

You need to define a torch model where all layers are specified in ModuleList.

for example
See [examples folder](examples)

### Basic usage

```python
import torch
Expand Down Expand Up @@ -55,15 +57,15 @@ class M(nn.Module):

return x


model = M(depth=5, dim=40).train() # on CPU
```

Then, you can use the L2L wrapper over this model.

```python
from layer_to_layer_pytorch.l2l import Layer2Layer

model = M(depth=5, dim=40).train() # on CPU

l2l_model = Layer2Layer(
model,
layers_attr="layers", # attribute with ModuleList
Expand All @@ -81,23 +83,46 @@ x = torch.rand(1_000, 40) # on CPU
y = torch.rand(1_000, 40) # on CPU

losses = []
loss_fn = nn.MSELoss(reduction="sum") # since L2L calcs average losses itself, we just need to save them
criterion = nn.MSELoss()

optimizer = optim.AdamW(l2l_model.main_model.parameters(), lr=0.001) # optimizer works with the main model on CPU
optimizer = optim.AdamW(l2l_model.main_params) # optimizer works with the main model on CPU

for i in trange(5000):
for i in trange(2000):
l2l_model.zero_grad()
l2l_model.forward(x)
_ = l2l_model.forward(x)

loss_value = l2l_model.backward(x, y, loss_fn)
loss_value: float = l2l_model.compute_loss(y, criterion)

if i % 50 == 0:
tqdm.write(f"[{i}] loss = {loss_value.item()}")
losses.append(loss_value.item())
tqdm.write(f"[{i}] loss = {loss_value}")
losses.append(loss_value)


l2l_model.backward()
optimizer.step()
l2l_model.update_main_model_params() # Sync params with CPU
```

### FP-16 usage

Cross-mixes-precision available in init params

```python
from layer_to_layer_pytorch.l2l import Layer2Layer

l2l_model = Layer2Layer(
model,
layers_attr="layers",
microbatch_size=100,

# fp-16
mixed_precision=True,
loss_scale = 128.0
)
```

And then train the same way 😉

## Installation

```bash
Expand Down

0 comments on commit 8ee55fd

Please sign in to comment.