In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from pathlib import Path
from IPython.core.debugger import set_trace
from fastai import datasets
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt
from torch import tensor, nn, optim
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
import torch.nn.functional as F

In [6]:
class Callback():
    def begin_fit(self, model, optimizer, loss_func, train_data, valid_data):
        self.model = model
        self.opt = optimizer
        self.loss_function = loss_func
        self.train_dl = train_data
        self.valid_dl = valid_data
        return True
    
    def after_fit(self):
        return True
    
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    
    def begin_validate(self):
        return True
    
    def after_epoch(self):
        return True
    
    def begin_batch(self, x_batch, y_batch):
        self.x_mini_batch = x_batch
        self.y_mini_batch = y_batch
        return True

    def after_loss(self, loss):
        self.loss = loss
        return True
    
    def after_backward(self):
        return True
    def after_step(self):
        return True
    
class CallbackHandler():
    def __init__(self, cbs=None):
        self.callbacks = cbs if cbs else []
        
    def begin_fit(self, model, opt, loss_func, train_dl, valid_dl):
        self.model = model
        self.optimizer = opt
        self.loss_function = loss_func
        self.train_dataloader = train_dl
        self.valid_dataloader = valid_dl
        self.in_train = True #flag for if in training or evaluation
        self.stop = False #flag that gets thrown to stop the model training
        res = True
        for cb in self.callbacks:
            res = res and callbacks.begin_fit(model, opt, loss_func, train_dl ,valid_dl)
        return res
    
    def after_fit(self):
        result = not self.in_train
        for cb in self.callbacks:
            result = result and callbacks.after_fit()
        return result
    
    def begin_epoch(self, epoch):
        self.model.train()
        self.in_train = True
        result = True
        for cb in self.callbacks:
            result = result and callbacks.begin_epoch(epoch)
        return result
    
    def begin_validate(self):
        self.model.eval()
        self.in_train = False
        result = True
        for cb in self.callbacks:
            result = result and callbacks.begin_validate()
        return result
    
    def after_epoch(self):
        result = True
        for cb in self.callbacks:
            result = result and cb.after_epoch()
        return result
    
    def begin_batch(self, x_batch, y_batch):
        result = True
        for cb in self.cbs:
            result = result and cb.begin_batch(xb, yb)
        return result
    
    def after_loss(self, loss):
        result = self.in_train
        for cb in self.callbacks:
            result = result and cb.after_loss(loss)
        return result
            
    def after_backward(self):
        res = True
        for cb in self.cbs: res = res and cb.after_backward()
        return res

    def after_step(self):
        res = True
        for cb in self.cbs: res = res and cb.after_step()
        return res
    
    def do_stop(self):
        try:     return self.learn.stop
        finally: self.learn.stop = False        

In [7]:
#redefining fit()
def one_batch(x_minib, y_minib, callbacks):
    if not callbacks.begin_batch(xb,yb):
        return
    loss = callbacks.loss_function(callbacks.model(x_minib), y_minib)
    if not callbacks.after_loss(loss):
        return
    loss.backward()
    if callbacks.after_backward(): 
        callbacks.opt.step()
    if callbacks.after_step():
        callbacks.opt.zero_grad()
    
def all_batches(dataloader, callbacks):
    for x_minib, y_minib in dataloader:
        one_batch(x_minib, y_minib, callbacks)
        if callbacks.do_stop():
            return
    
def fit(num_epochs, model, optimizer, loss_func, train_dataloader, valid_dataloader):
    if not callbacks.begin_fit(model, optimizer, loss_func, train_dataloader, valid_datalader):
        return
    for epoch in range(num_epochs):
        if not callbacks.begin_epoch(epoch):
            continue
        all_batches(train_dataloader, callbacks)
        
        if callbacks.begin_validate():
            with torch.no_grad():
                all_batches(valid_dataloader, callbacks)
        if callbacks.do_stop() or not callbacks.after_epoch():
            break
        callbacks.after_fit()

In [8]:
class TestCallback(Callback):
    def begin_fit(self,learn):
        super().begin_fit(learn)
        self.n_iters = 0
        return True
        
    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        if self.n_iters>=10: self.learn.stop = True
        return True

In [16]:
#create simple 3 layer example model
def get_model(training_data, lr=0.5, nh=50):
    m = training_data.x.shape[1]
    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,data.c))
    return model, optim.SGD(model.parameters(), lr=lr)


In [17]:
number_hidden = 50
batch_size = 64
loss_func = F.CrossEntropyLoss
x_train,y_train,x_valid,y_valid = get_data()

#setup data
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
valid_dl = torch.utils.data.DataLoader(valid_dataset, batch_size, shuffle=False)
model, optimizer = get_model(train_data)

AttributeError: module 'torch.nn.functional' has no attribute 'CrossEntropyLoss'

In [14]:
fit(1, model, optimizer, loss_func, train_dl, valid_dl, cb=CallbackHandler([TestCallback()]))

NameError: name 'model' is not defined