In [None]:
from z import *
from trainer import *
from torchvision import datasets, transforms
from torchsummary import summary

In [None]:
# Training-time pipeline: photometric jitter, small rotation, padded random crop
# and random horizontal flip, then tensor conversion and per-channel normalization.
# NOTE(review): mean 0.5 / std 0.2 are round numbers rather than measured CIFAR-10
# channel statistics — confirm this is intentional.
transform_train = transforms.Compose([
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15),
    transforms.RandomRotation(degrees=15),
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])

# Evaluation-time pipeline: no augmentation, only the same tensor conversion and
# normalization as training so train and test inputs share one scale.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])

In [None]:
# CIFAR-10 ships as a single archive holding both splits. Passing download=True to
# both constructors makes each line self-sufficient instead of the test split
# silently relying on the train line having fetched the archive first; torchvision
# does not re-download a dataset that is already present.
train_dataset = datasets.CIFAR10(root='data/', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='data/', train=False, download=True, transform=transform_test)

In [None]:
def DenseConv(in_channels, mid_channels, out_channels, layers=3, routes=4, kernel_size=3, padding=1):
  """Build a DenseBase block whose inner layers are ZConv2d convolutions.

  Args:
    in_channels, mid_channels, out_channels: forwarded unchanged to DenseBase.
    layers: forwarded to DenseBase as its fifth argument (presumably the number
      of inner layers — TODO confirm against DenseBase).
    routes: passed as the third positional argument of every ZConv2d.
    kernel_size, padding: forwarded to every ZConv2d.

  Returns:
    The object constructed by DenseBase(...).
  """
  def conv_factory(i, o, l):
    # DenseBase supplies the input/output channel counts; the third argument
    # it passes is not used by this factory.
    return ZConv2d(i, o, routes, kernel_size, padding=padding)

  return DenseBase(conv_factory, in_channels, mid_channels, out_channels, layers)

In [None]:
# Seed torch's global RNG (CPU and all CUDA devices) right before the first RNG
# consumer so weight initialisation is reproducible on Restart & Run All, and
# re-running this cell rebuilds the same initial weights.
SEED = 0
torch.manual_seed(SEED)

# Three DenseConv stages, each followed by SoftPool2d, then a linear classifier.
# NOTE(review): ZLinear's 64 * 4 * 4 input assumes each SoftPool2d(c, 2, 2)
# halves H and W of the 32x32 input (32 -> 16 -> 8 -> 4) — confirm.
model = nn.Sequential(
    DenseConv(3, 16, 16),
    SoftPool2d(16, 2, 2),
    DenseConv(16, 16, 32),
    SoftPool2d(32, 2, 2),
    DenseConv(32, 32, 64),
    SoftPool2d(64, 2, 2),
    nn.Flatten(),
    nn.Dropout(0.5),
    ZLinear(64 * 4 * 4, 10, 4),  # presumably (in_features, 10 CIFAR-10 classes, routes)
).to(0)  # device index 0, i.e. the first CUDA device

In [None]:
# Per-layer summary for a single 3x32x32 (C, H, W) CIFAR-10 image.
summary(model, input_size=(3, 32, 32))

In [None]:
def loss_func(model, batch, scope):
  """Cross-entropy loss for one (inputs, targets) batch, plus an accuracy metric.

  Side effect: writes scope["metrics"]["Accuracy"] as the number of correct
  argmax predictions in this batch divided by len(scope["dataset"]).
  NOTE(review): dividing by the dataset length rather than the batch size only
  yields an overall accuracy if the trainer sums this metric across batches —
  confirm against trainer.train.

  Returns:
    (loss, logits): the cross-entropy loss and the raw model output.
  """
  inputs, targets = batch
  logits = model(inputs)
  loss = F.cross_entropy(logits, targets)
  n_correct = (logits.argmax(dim=1) == targets).sum().item()
  scope["metrics"]["Accuracy"] = n_correct / len(scope["dataset"])
  return loss, logits

In [None]:
# AdamW with learning rate 1e-3; every other hyper-parameter keeps PyTorch's default.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

In [None]:
# Best validation accuracy seen so far; updated by on_epoch and printed at the end.
acc = 0

def on_epoch(scope):
  """Validation-epoch hook that checkpoints `model` when accuracy improves.

  Reads scope["val_metrics"]["Accuracy"]. If it strictly exceeds the global best
  `acc`, updates `acc`, saves model.state_dict() to "cifar.model" and prints a
  confirmation; otherwise it does nothing.
  """
  global acc
  val_accuracy = scope["val_metrics"]["Accuracy"]
  if not val_accuracy > acc:
    return
  acc = val_accuracy
  torch.save(model.state_dict(), "cifar.model")
  print("Model saved!")

In [None]:
# Run training on CUDA device 0. test_dataset is passed in the position the trainer
# presumably uses for validation; on_epoch is registered as the validation-epoch hook
# so the best model gets checkpointed.
train(
    model,
    loss_func,
    train_dataset,
    test_dataset,
    optimizer,
    device=0,
    epochs=150,
    batch_size=256,
    on_val_epoch=on_epoch,
)

In [None]:
# Report the best validation accuracy recorded by on_epoch (!s keeps str() formatting).
print(f"Best Accuracy = {acc!s}")