In [30]:
#| default_exp preprocessing.lesson15

: 

# Lesson15 Scripts
> Reproducing lesson 15

In [1]:
#| hide
%load_ext autoreload
%autoreload 2

In [2]:
#| export
from cv_tools.core import *


## learner[ 1:16:01]

In [3]:
import torch
import torchvision.transforms.functional as TF
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections.abc import Mapping

In [4]:
from datasets import load_dataset, load_dataset_builder

In [5]:
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)

import logging
logging.disable(logging.WARNING)

In [6]:
name = 'fashion_mnist'
dsd = load_dataset(name)
x, y = 'image', 'label'


In [7]:
from miniai.datasets import *
from operator import itemgetter
from torch.utils.data import default_collate
from torch.utils.data import DataLoader
from fastcore.all import *
from miniai.training import *

In [8]:
def inplace(f):
    def _f(b):
        f(b)
        return b
    return _f

In [9]:
@inplace
def transformi(b):
    b[x] = [torch.flatten(TF.to_tensor(o)) for o in b[x]]

In [10]:
bs = 1024
tsd = dsd.with_transform(transformi)

In [41]:
dsd['train'][0][x].size

(28, 28)

In [42]:
tsd['train']['image'][1].shape


torch.Size([784])

In [11]:
def get_dls(train_ds, valid_ds, bs, **kwargs):
    return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
            DataLoader(valid_ds, batch_size=bs*2, **kwargs))

In [12]:
tsd

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

In [13]:
dls = get_dls(tsd['train'], tsd['test'], bs)

In [14]:
def collate_dict(ds):
    get = itemgetter(*ds.features)
    def _f(b): return get(default_collate(b))
    return _f

In [15]:
class DataLoaders:
    def __init__(self, *dls): self.train,self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        f = collate_dict(dd['train'])
        return cls(*get_dls(*dd.values(), bs=batch_size, collate_fn=f, **kwargs))

In [16]:
from miniai.conv import *

In [17]:
def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

def to_device(x, device=def_device):
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, Mapping): return {k:v.to(device) for k,v in x.items()}
    return type(x)(to_device(o, device) for o in x)

In [18]:
tds = dsd.with_transform(transformi)

In [29]:
class Learner:
	def __init__(self, model, dls, loss_func, opt_func=optim.SGD, lr=1e-1):
		store_attr()

	def one_batch(self):
		self.xb, self.yb = to_device(self.batch)
		self.preds = self.model(self.xb)
		self.loss = self.loss_func(self.preds, self.yb)
		if self.model.training:
			self.loss.backward()
			self.opt.step()
			self.opt.zero_grad()
		with torch.no_grad():
			self.calc_stats()
	
	def calc_stats(self):
		acc = (self.preds.argmax(dim=1).eq(self.yb).float().sum())
		self.accs.append(acc)
		n = len(self.yb)
		self.losses.append(self.loss*n)
		self.ns.append(n)

	def one_epoch(self, train):
		self.model.training = train
		dl = self.dls.train if train else self.dls.valid
		for self.num, self.batch in enumerate(dl):
			self.one_batch()
		n = sum(self.ns)
		#print(f'epoch {self.epoch} , {self.model.training}, {(self.losses).item()/n:.2f} , {sum(self.accs).item()/n:.2f}')
		print(self.epoch, self.model.training, sum(self.losses).item()/n, sum(self.accs).item()/n)
	

	def fit(self, n_epochs):
		self.accs, self.losses, self.ns = [], [], []

		self.model.to(def_device)
		self.opt = self.opt_func(self.model.parameters(), self.lr)
		self.n_epochs = n_epochs
		for self.epoch in range(n_epochs):
			self.one_epoch(True)
			with torch.no_grad():
				self.one_epoch(False)
	


In [20]:
dls = DataLoaders.from_dd(tsd, bs)

In [21]:
m, nh =28*28,50
model = nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10))

learner = Learner(model=model, dls=dls, loss_func=F.cross_entropy, opt_func=optim.SGD, lr=0.2)


In [58]:
dls.valid

<torch.utils.data.dataloader.DataLoader at 0x7f09f153a080>

In [56]:
dls.train#

<torch.utils.data.dataloader.DataLoader at 0x7f09f1539e10>

In [22]:

learner.fit(1)


0 True 1.1903846354166667 0.6001833333333333
0 False 1.1394975446428572 0.6130857142857142


> More Flexible learner we need to make it more flexible

- let's see how we can change it if we want to use different Metric

In [23]:
class Metric:
    def __init__(self):
        self.reset()

    def reset(self):
        self.vals, self.ns = [], []

    def add(self, inp, targ=None, n=1):
        self.last = self.calc(inp, targ)
        self.vals.append(self.last)
        self.ns.append(n)

    @property
    def value(self):
        ns = torch.tensor(self.ns)
        return (torch.tensor(self.vals)*ns).sum()/ns.sum()

    def calc(self, inp, targ):
        return inp


> Now do subclassing

In [24]:
class Accuracy(Metric):
    def calc(self, inp, targ):
        return (inp == targ).float().mean()



In [25]:
acc = Accuracy()
acc.add(torch.tensor([0,1,2,3]), torch.tensor([0,1,2,3]))
acc.value


tensor(1.)

In [26]:
loss = Metric()
loss.add(0.6, n=10) # we can do batching and calculate loss for each batch
loss.add(0.5, n=20)
loss.value


tensor(0.53)

> Now with that we are able to use any metric we want

# Callback Learner


In [39]:
class with_cbs:
	"Decorator to add callbacks to a function."
	def __init__(self, nm): self.nm = nm
	def __call__(self, f):
		"Call the function with the callbacks."
		def _f(o, *args, **kwargs):
			"Try to run the function with the callbacks."
			try:
				# First, we call the callback function with the name 'before_' followed by the name of the function we're decorating.
				# This allows us to execute any necessary setup or preprocessing before the function is called.
				o.callback(f'before_{self.nm}')
				# Next, we call the original function with the provided arguments and keyword arguments.
				# This is where the actual work of the function is done.
				f(o, *args, **kwargs)
				# Finally, we call the callback function again, this time with the name 'after_' followed by the name of the function.
				# This allows us to execute any necessary cleanup or postprocessing after the function has been called.
				o.callback(f'after_{self.nm}')
			except globals()[f'Cancel{self.nm.title()}Exception']: pass
			# This line catches a specific exception that is dynamically generated based on the name of the function being decorated.
			# The exception name is constructed by concatenating 'Cancel' with the capitalized first letter of the function name.
			# If this exception is caught, the code simply passes, effectively ignoring the exception and not propagating it further.
		return _f


In [49]:
#| export
def identity(*args): 
	# This line checks if there are any arguments passed to the function. If not, it returns immediately.
	if not args: return 
	# This line unpacks the arguments into a first element 'x' and the rest of the arguments in 'args'.
	x, *args = args
	# This line returns a tuple containing 'x' and the rest of the arguments if there are any. 
	# If there are no additional arguments, it returns 'x' itself.
	return (x,) + tuple(args) if args else x

In [51]:
print(identity(1))
print(identity(1,2))
print(identity(1,'q',3))
print(identity())


1
(1, 2)
(1, 'q', 3)
None


In [52]:
#| export
class CallbackLearner:
    def __init__(self, model, dls, loss_func, cbs, opt_func=optim.SGD, lr=1e-1):
        store_attr()

        # Here we're looping through each callback in self.cbs and setting its learn attribute to self, 
        # which is the current CallbackLearner instance. This allows each callback to access the learner.
        for callback in self.cbs: callback.learn = self

    def one_batch(self):
        self.xb, self.yb = to_device(self.batch)
        self.preds = self.model(self.xb)
        self.loss = self.loss_func(self.preds, self.yb)
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
        with torch.no_grad():
            self.calc_stats()
    
    def calc_stats(self):
        acc = (self.preds.argmax(dim=1).eq(self.yb).float().sum())
        self.accs.append(acc)
        n = len(self.yb)
        self.losses.append(self.loss*n)
        self.ns.append(n)

    def one_epoch(self, train):
        self.model.training = train
        dl = self.dls.train if train else self.dls.valid
        for self.num, self.batch in enumerate(dl):
            self.one_batch()
        n = sum(self.ns)
        #print(f'epoch {self.epoch} , {self.model.training}, {(self.losses).item()/n:.2f} , {sum(self.accs).item()/n:.2f}')
        print(self.epoch, self.model.training, sum(self.losses).item()/n, sum(self.accs).item()/n)
    

    def fit(self, n_epochs):
        self.n_epochs = n_epochs
        self.epochs = range(n_epochs)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        # This is not the bottom _fit function, but the with_cbs function, 
        # which will run the callback functions of the learner
        self._fit()
    
    @with_cbs('fit')
    def _fit(self):
        for self.epoch in self.epochs:
            self.one_epoch(True)
            with torch.no_grad():
                self.one_epoch(False)

    def callback(self, method_nm):
        # Sort the callbacks based on their order attribute
        for cb in sorted(self.cbs, key=attrgetter('order')):
            # Attempt to call the method specified by method_nm on the callback
            # If the method does not exist, call the identity method instead
            getattr(cb, method_nm, self.identity)()


In [53]:
class CallBack:
	order = 0

In [31]:
#| hide
import nbdev; nbdev.nbdev_export('01_preprocessing.lesson_15.ipynb')