forked from fastai/fastai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
learner.py
132 lines (111 loc) · 6.08 KB
/
learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"`Learner` support for computer vision"
from ..torch_core import *
from ..basic_train import *
from ..basic_data import *
from .image import *
from . import models
from ..callback import *
from ..layers import *
# Names exported by a star-import of this module.
__all__ = ['create_cnn', 'create_body', 'create_head', 'ClassificationInterpretation']
# Default split: a single boundary, placed at the model's second child.
def _default_split(m:nn.Module):
    "Return a 1-tuple holding `m[1]`, the point at which `m` is split by default."
    boundary = m[1]
    return (boundary,)
# Split points for a resnet style model.
def _resnet_split(m:nn.Module):
    "Return `(m[0][6], m[1])`: the seventh child of the first sub-module, then the second sub-module."
    inner_boundary = m[0][6]
    return (inner_boundary, m[1])
# Metadata returned by `cnn_config` for architectures missing from `model_meta`.
_default_meta = {'cut': -1, 'split': _default_split}
# Metadata shared by every resnet variant listed below.
_resnet_meta = {'cut': -2, 'split': _resnet_split}
# Architecture -> metadata. Each entry is its own shallow copy of `_resnet_meta`,
# so editing one architecture's entry does not affect the others.
model_meta = {
    arch: {**_resnet_meta}
    for arch in (models.resnet18, models.resnet34, models.resnet50,
                 models.resnet101, models.resnet152)
}
def cnn_config(arch):
    "Return the `model_meta` entry for `arch`, or `_default_meta` when `arch` is not listed."
    # Side effect: cudnn benchmark mode is switched on globally by every call.
    torch.backends.cudnn.benchmark = True
    if arch in model_meta:
        return model_meta[arch]
    return _default_meta
def create_body(model:nn.Module, cut:Optional[Union[int,Callable]]=None,
                body_fn:Callable[[nn.Module],nn.Module]=None):
    """Cut off the body of a typically pretrained `model` at `cut` or as specified by `body_fn`.

    :param model: the module to cut.
    :param cut: a non-zero int keeps `list(model.children())[:cut]`, wrapped in a new
        `nn.Sequential`; a callable is called with `model` and its result is returned.
        `None` (or 0) defers to `body_fn`.
    :param body_fn: used only when `cut` is `None`/0; called with `model`.
    :return: the cut body, or `model` itself when neither `cut` nor `body_fn` is given.
    """
    # `create_cnn` types its `cut` as `Union[int, Callable]` and forwards it here as `cut`.
    # A callable used as a slice bound raises TypeError, so dispatch on it first.
    if callable(cut):
        return cut(model)
    if cut:
        return nn.Sequential(*list(model.children())[:cut])
    return body_fn(model) if body_fn else model
def create_head(nf:int, nc:int, lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5):
    """Model head that takes `nf` features, runs through `lin_ftrs`, and ends with `nc` outputs.

    :param nf: input size of the first linear block.
    :param nc: output size of the last linear block.
    :param lin_ftrs: sizes of the intermediate linear layers (any collection of ints);
        `None` means one intermediate layer of size 512.
    :param ps: dropout, can be a single float or a list for each layer. A single value is
        used as-is for the last layer and halved for every layer before it.
    :return: an `nn.Sequential` made of `AdaptiveConcatPool2d()`, `Flatten()` and then the
        layers returned by `bn_drop_lin` for each consecutive pair of sizes.
    """
    # `list(...)` lets tuples and other collections be concatenated with the list ends.
    sizes = [nf, 512, nc] if lin_ftrs is None else [nf] + list(lin_ftrs) + [nc]
    n_hidden = len(sizes) - 2
    ps = listify(ps)
    if len(ps)==1: ps = [ps[0]/2] * n_hidden + ps
    # One ReLU instance per hidden layer: a single shared instance would be reported only
    # once by traversals that deduplicate modules by identity (e.g. `children()`).
    # The last layer gets no activation.
    actns = [nn.ReLU(inplace=True) for _ in range(n_hidden)] + [None]
    layers = [AdaptiveConcatPool2d(), Flatten()]
    # NOTE: `zip` stops at its shortest input, so a multi-element `ps` shorter than the
    # number of linear layers silently produces a shorter head.
    for ni,no,p,actn in zip(sizes[:-1], sizes[1:], ps, actns):
        layers += bn_drop_lin(ni, no, True, p, actn)
    return nn.Sequential(*layers)
def create_cnn(data:DataBunch, arch:Callable, cut:Union[int,Callable]=None, pretrained:bool=True,
               lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
               custom_head:Optional[nn.Module]=None, split_on:Optional[SplitFuncOrIdxList]=None,
               **kwargs:Any)->Learner:
    """Build convnet style learners.

    Looks up `arch`'s metadata with `cnn_config`, instantiates `arch(pretrained)`, cuts it
    into a body, attaches a head, wraps both in a `Learner` built with `**kwargs`, calls
    `learn.split`, calls `learn.freeze()` when `pretrained`, and finally runs `apply_init`
    with `nn.init.kaiming_normal_` on the head.

    :param data: passed to `Learner`; `data.c` is the head's output size.
    :param arch: callable invoked as `arch(pretrained)` to build the base model.
    :param cut: where to cut the base model. An int is a slice end over its children; a
        callable receives the model and returns the body. `None` uses the metadata value.
    :param lin_ftrs: forwarded to `create_head`.
    :param ps: forwarded to `create_head`.
    :param custom_head: used as the head instead of `create_head(...)` when not `None`.
    :param split_on: passed to `learn.split`; `None` uses the metadata value.
    :return: the configured `Learner`.
    """
    meta = cnn_config(arch)
    base_model = arch(pretrained)
    cut = ifnone(cut, meta['cut'])
    # A callable must be routed through `body_fn`; passed as `cut` it would be used as a
    # slice bound and raise TypeError.
    if callable(cut): body = create_body(base_model, body_fn=cut)
    else:             body = create_body(base_model, cut)
    # Explicit `is None`: module truthiness can go through `__len__`, so an empty
    # `nn.Sequential` passed as `custom_head` would otherwise be silently replaced.
    if custom_head is None:
        # NOTE(review): the factor 2 presumably matches `AdaptiveConcatPool2d` at the start
        # of `create_head` — confirm against that layer's definition.
        nf = num_features_model(body) * 2
        head = create_head(nf, data.c, lin_ftrs, ps)
    else:
        head = custom_head
    model = nn.Sequential(body, head)
    learn = Learner(data, model, **kwargs)
    learn.split(ifnone(split_on,meta['split']))
    if pretrained: learn.freeze()
    apply_init(model[1], nn.init.kaiming_normal_)
    return learn
class ClassificationInterpretation():
    "Interpretation methods for classification models."
    def __init__(self, data:DataBunch, probs:Tensor, y_true:Tensor, losses:Tensor, sigmoid:bool=None):
        "Store `data`, `probs`, `y_true` and `losses`; `pred_class` is the argmax of `probs` over dim 1. `sigmoid` is deprecated and only triggers a warning."
        if sigmoid is not None: warnings.warn("`sigmoid` argument is deprecated, the learner now always return the probabilities")
        self.data,self.probs,self.y_true,self.losses = data,probs,y_true,losses
        self.pred_class = self.probs.argmax(dim=1)

    @classmethod
    def from_learner(cls, learn:Learner, sigmoid:bool=None, tta=False):
        "Factory method to create from a Learner, using `learn.TTA` when `tta` is true and `learn.get_preds` otherwise."
        preds = learn.TTA(with_loss=True) if tta else learn.get_preds(with_loss=True)
        return cls(learn.data, *preds, sigmoid=sigmoid)

    def top_losses(self, k, largest=True):
        "`k` largest(/smallest) losses, as the `(values, indices)` result of `losses.topk`."
        return self.losses.topk(k, largest=largest)

    def plot_top_losses(self, k, largest=True, figsize=(12,12)):
        "Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class."
        _,tl_idx = self.top_losses(k,largest)
        classes = self.data.classes
        # Smallest square grid with at least `k` cells.
        rows = math.ceil(math.sqrt(k))
        # `squeeze=False` keeps `axes` a 2-D array even for a 1x1 grid; without it `k=1`
        # yields a bare Axes object, which has no `.flat`.
        fig,axes = plt.subplots(rows,rows,figsize=figsize,squeeze=False)
        for i,idx in enumerate(tl_idx):
            t=self.data.valid_ds[idx]
            t[0].show(ax=axes.flat[i], title=
                f'{classes[self.pred_class[idx]]}/{classes[t[1]]} / {self.losses[idx]:.2f} / {self.probs[idx][t[1]]:.2f}')

    def confusion_matrix(self):
        "Confusion matrix as an `np.ndarray`; rows are actual classes, columns are predicted classes."
        x=torch.arange(0,self.data.c)
        # Assuming `pred_class` and `y_true` are 1-D of length n, broadcasting gives a
        # (c, c, n) boolean tensor indexed [actual, predicted, sample]; summing over the
        # last dim counts the samples in each (actual, predicted) cell.
        cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2)
        return to_np(cm)

    def plot_confusion_matrix(self, normalize:bool=False, title:str='Confusion matrix', cmap:Any="Blues", **kwargs)->None:
        "Plot the confusion matrix, passing `kwargs` to `plt.figure`. With `normalize`, each row is divided by its sum."
        # This function is mainly copied from the sklearn docs
        cm = self.confusion_matrix()
        if normalize:
            # Normalise before drawing so the cell colours and the printed numbers agree.
            row_sums = cm.sum(axis=1)[:, np.newaxis]
            # Rows for classes with no samples stay all-zero instead of becoming NaN.
            cm = cm.astype('float') / np.where(row_sums==0, 1, row_sums)
        plt.figure(**kwargs)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        tick_marks = arange_of(self.data.classes)
        plt.xticks(tick_marks, self.data.classes, rotation=90)
        plt.yticks(tick_marks, self.data.classes, rotation=0)
        # Text switches to white on cells darker than half the maximum value.
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            # Ratios are rounded to two decimals; raw counts are printed as they are.
            label = f'{cm[i, j]:.2f}' if normalize else f'{cm[i, j]}'
            plt.text(j, i, label, horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")
        plt.tight_layout()
        plt.ylabel('Actual')
        plt.xlabel('Predicted')

    def most_confused(self, min_val:int=1)->Collection[Tuple[str,str,int]]:
        "Sorted descending list of `(actual, predicted, count)` for non-diagonal confusion matrix entries with count >= `min_val`."
        cm = self.confusion_matrix()
        # Mask out the diagonal (correct predictions) explicitly, so it stays excluded
        # whatever `min_val` is.
        off_diagonal = ~np.eye(cm.shape[0], dtype=bool)
        # `>=` so that `min_val` is the smallest count reported (the default of 1 now
        # includes pairs confused exactly once).
        res = [(self.data.classes[i],self.data.classes[j],cm[i,j])
               for i,j in zip(*np.where((cm>=min_val) & off_diagonal))]
        return sorted(res, key=itemgetter(2), reverse=True)
def Image_predict(img, learn):
    "Run `learn.model` on the single image `img`, after applying the validation set's transforms, and return `get_preds(...)[0][0]`."
    data = learn.data
    valid_ds = data.valid_ds
    img = apply_tfms(valid_ds.tfms, img, **valid_ds.kwargs)
    # `img.data[None]` adds a leading axis of size 1; the zero tensor is presumably just a
    # placeholder target so the pair forms a one-item dataset.
    single_ds = TensorDataset(img.data[None], torch.zeros(1))
    loader = DeviceDataLoader.create(single_ds, bs=1, shuffle=False, device=data.device,
                                     tfms=data.valid_dl.tfms, num_workers=0)
    outputs = get_preds(learn.model, loader, cb_handler=CallbackHandler(learn.callbacks, []))
    return outputs[0][0]
# Attach the function to `Image` so it can be called as `img.predict(learn)`.
Image.predict = Image_predict