In [1]:
# all_slow

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from generative_models.layers import scale, unscale

In [3]:
from torchvision.datasets import MNIST
import torchvision.transforms as T

In [4]:
dir = "/media/arto/work/data/mnist"

In [5]:
from pathlib import Path

In [6]:
path = Path(dir)
path.exists()

True

In [7]:
# train_dl = DataLoader(MNIST(dir, transform=T.ToTensor()), batch_size=128, shuffle=True, drop_last=True)
valid_dl = DataLoader(MNIST(dir, transform=T.ToTensor(), train=False), batch_size=32, shuffle=False, drop_last=False)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [8]:
b = next(iter(valid_dl))

In [9]:
b[0].shape, b[1].shape

(torch.Size([32, 1, 28, 28]), torch.Size([32]))

In [10]:
from generative_models.layers import ConvNet

In [11]:
model = nn.Sequential(
    ConvNet(1),
    nn.Flatten(),
    nn.BatchNorm1d(64*2*2),
    nn.ReLU(),
    nn.Linear(64*2*2, 10)
)

In [12]:
out = model(b[0])
out.shape

torch.Size([32, 10])

In [13]:
from tqdm.auto import tqdm, trange

In [14]:
def train_step(batch, model, optimizer, loss_func, scheduler=None, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    xb, yb = batch[0].to(device), batch[1].to(device)
    preds = model(xb)

    loss = loss_func(preds, yb)
    
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    optimizer.zero_grad()
    return loss.item()

In [15]:
def accuracy(pred, targ):
    return (pred.argmax(-1) == targ).float().mean()
def eval_step(batch, model, loss_func, device=None):
    if device == None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    xb, yb = batch[0].to(device), batch[1].to(device)
    with torch.no_grad():
        preds = model(xb)
        loss = loss_func(preds, yb)
    return loss.item(), accuracy(preds, yb).item()

In [30]:
def fit(n_epoch, model, train_dl, valid_dl, train_step, eval_step, optimizer, loss_func, scheduler=None, device=None):
    
    if device == None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    steps_per_epoch = len(train_dl)
    total_steps = n_epoch * steps_per_epoch
    train_losses = []
    valid_losses = []
    valid_accs = []
    for e in trange(n_epoch):
        
        model.train()
        train_pbar = tqdm(train_dl, leave=False)
        for step, batch in enumerate(train_pbar):
            total_step = (e*steps_per_epoch)+step
            loss = train_step(batch, model, optimizer, loss_func, scheduler, device=device)
            train_losses.append(loss)
            train_pbar.set_description(f"{loss:.2f}")

        model.eval()
        avg_valid_loss = 0
        for step, batch in enumerate(valid_dl):
            loss = eval_step(batch, model, loss_func, device=device)
            avg_valid_loss += (loss-avg_valid_loss) / (step+1)
        valid_losses.append(avg_valid_loss)
    return train_losses, valid_losses, valid_accs

In [24]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

tl, vl, va = fit(2, model, train_dl, valid_dl, train_step, eval_step, optimizer, nn.CrossEntropyLoss())

100%|██████████| 2/2 [00:30<00:00, 15.36s/it]


In [25]:
va

[0.9798259493670884, 0.9860561708860759]

In [17]:
from generative_models.pixelcnn import SimplePixelCNN

In [18]:
model = SimplePixelCNN(ks=5)

SimplePixelCNN(
  (layers): ModuleList(
    (0): MaskedConv2d()
    (1): ChanLayerNorm(
      (ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (2): ReLU()
    (3): MaskedConv2d()
    (4): ChanLayerNorm(
      (ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (5): ReLU()
    (6): MaskedConv2d()
    (7): ChanLayerNorm(
      (ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (8): ReLU()
    (9): MaskedConv2d()
    (10): ChanLayerNorm(
      (ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (11): ReLU()
    (12): MaskedConv2d()
    (13): ChanLayerNorm(
      (ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (14): ReLU()
    (15): MaskedConv2d()
    (16): ChanLayerNorm(
      (ln): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (17): ReLU()
    (18): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (19): ReLU()
    (20): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [27]:
def train_step(batch, model, optimizer, loss_func, scheduler=None, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    xb = scale(batch[0].to(device))
    preds = model(xb)

    loss = loss_func(preds, xb)
    
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.)
    optimizer.step()
    if scheduler is not None:
        scheduler.step()
    optimizer.zero_grad()
    return loss.item()

In [28]:
def eval_step(batch, model, loss_func, device=None):
    if device == None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    xb = scale(batch[0].to(device))
    with torch.no_grad():
        preds = model(xb)
        loss = loss_func(preds, xb)
    return loss.item()

In [22]:
optimizer = torch.optim.Adam(model.parameters())

In [31]:
fit(1, model, valid_dl, valid_dl, train_step, eval_step, optimizer, nn.MSELoss())

100%|██████████| 1/1 [01:13<00:00, 73.23s/it]


([0.025165459141135216,
  0.0241972878575325,
  0.026837032288312912,
  0.02695108950138092,
  0.024686306715011597,
  0.024144113063812256,
  0.024485789239406586,
  0.025522911921143532,
  0.023379918187856674,
  0.02602148987352848,
  0.022684287279844284,
  0.023620959371328354,
  0.024264581501483917,
  0.025816742330789566,
  0.03121105395257473,
  0.02405347116291523,
  0.02727814018726349,
  0.025518158450722694,
  0.024460557848215103,
  0.023682404309511185,
  0.02358895353972912,
  0.022790253162384033,
  0.02382190153002739,
  0.020986245945096016,
  0.02433934435248375,
  0.027637461200356483,
  0.021911630406975746,
  0.027573779225349426,
  0.02243075519800186,
  0.02319721318781376,
  0.021680839359760284,
  0.024634309113025665,
  0.024297423660755157,
  0.02423711121082306,
  0.024424802511930466,
  0.02504887990653515,
  0.02586372010409832,
  0.024227755144238472,
  0.024355586618185043,
  0.025468679144978523,
  0.022759485989809036,
  0.024881746619939804,
  0.022