In [1]:
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from utils import ARDataset, AutoRegressiveNetwork, device
from data_rnn import load_toy
from tqdm import tqdm

def sample(lnprobs, temperature=1.0):
    """Pick an index from a tensor of unnormalised log-probabilities.

    Args:
        lnprobs: Tensor of scores. The softmax is taken over dim 0, so this is
            presumably a 1-D vector with one entry per class/token — TODO confirm
            against callers.
        temperature: Divisor applied to the scores before the softmax. A value of
            exactly 0.0 switches to greedy decoding.

    Returns:
        A tensor holding the chosen index: the argmax of ``lnprobs`` when
        ``temperature == 0.0``, otherwise a draw from the categorical
        distribution defined by ``softmax(lnprobs / temperature)``.
    """
    greedy = temperature == 0.0
    if not greedy:
        # Rescale the scores, normalise them, then draw a single index.
        scaled_scores = lnprobs / temperature
        probabilities = F.softmax(scaled_scores, dim=0)
        return dist.Categorical(probabilities).sample()
    # Zero temperature: no randomness, just take the highest-scoring entry.
    return lnprobs.argmax()

In [2]:
# Load the toy corpus; n=150_000 is forwarded to load_toy. The second return value
# unpacks into two vocabulary objects — presumably i2w maps index -> token and
# w2i maps token -> index (TODO confirm in data_rnn).
x_train, (i2w, w2i) = load_toy(n=150_000)

In [3]:
# Network, moved onto the compute device right after construction.
model = AutoRegressiveNetwork(
    w2i,
    emb=64,
    h=128,
).to(device)

# Only parameters that require gradients are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adam(
    trainable_params,
    lr=3e-4,
    weight_decay=1e-4,
)

# Batching helper over the training data (bs is presumably the batch size — TODO confirm in utils).
dl = ARDataset(
    x_train,
    w2i,
    bs=8,
    maxsize=300,
)

# Training objective and the TensorBoard writer for this run.
criterion = nn.CrossEntropyLoss()
sw = SummaryWriter('runs/lang')

In [4]:
def norm(model: nn.Module) -> float:
    """Return the global L2 norm of the gradients currently stored on ``model``.

    The per-parameter gradient L2 norms are squared, summed, and the square root
    of the total is returned, i.e. the L2 norm of all gradients taken together.

    Parameters whose ``.grad`` is ``None`` (frozen parameters, parameters that
    did not take part in the last backward pass, or any parameter before the
    first ``backward()``) are skipped instead of raising ``AttributeError``.

    Args:
        model: Module whose parameters' gradients are measured.

    Returns:
        The combined gradient norm as a Python float; ``0.0`` when no parameter
        has a gradient.
    """
    squared_sum = 0.0
    for p in model.parameters():
        grad = p.grad
        if grad is None:
            # Nothing to measure for this parameter.
            continue
        # detach() keeps the measurement out of the autograd graph.
        squared_sum += grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5

# Train for 10 epochs. Per step, the gradient norm is logged under a per-epoch
# TensorBoard tag; per epoch, the mean batch loss is logged and printed.
# The loop is wrapped in try/finally so that pending TensorBoard events are
# flushed even when training is interrupted or fails part-way through.
try:
    for epoch in range(10):
        model.train()
        dl.shuffle()
        total_loss = 0
        c = 0  # batches processed this epoch; also the step index for the norm curve
        for x, y in tqdm(dl.dataloader()):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Gradients from backward() are still attached here: zero_grad()
            # only runs at the top of the next iteration.
            sw.add_scalar(f'Norm/epoch{epoch}', norm(model), c)
            c += 1
        if c == 0:
            # Without this guard the averaging below fails with an opaque ZeroDivisionError.
            raise RuntimeError('dl.dataloader() yielded no batches; cannot compute the epoch loss')
        mean_loss = total_loss / c
        sw.add_scalar('Loss/train', mean_loss, epoch)
        print(f'Epoch {epoch}, Train Loss: {mean_loss:.2f}')
finally:
    sw.flush()

2742it [00:43, 62.35it/s]


Epoch 0, Train Loss: 0.55


2745it [00:44, 61.94it/s]


Epoch 1, Train Loss: 0.35


2746it [00:47, 58.01it/s]


Epoch 2, Train Loss: 0.34


2745it [00:47, 57.68it/s]


Epoch 3, Train Loss: 0.34


2743it [00:47, 58.13it/s]


Epoch 4, Train Loss: 0.34


2742it [00:50, 54.13it/s]


Epoch 5, Train Loss: 0.34


2743it [00:47, 57.72it/s]


Epoch 6, Train Loss: 0.34


2743it [00:47, 58.36it/s]


Epoch 7, Train Loss: 0.34


2742it [00:46, 58.43it/s]


Epoch 8, Train Loss: 0.34


2744it [00:51, 53.00it/s]

Epoch 9, Train Loss: 0.34



