"`Learner` support for computer vision"
from ..torch_core import *
from ..basic_train import *
from ..basic_data import *
from .image import *
from . import models
from ..callback import *
from ..layers import *
from ..callbacks.hooks import num_features_model
__all__ = ['ClassificationLearner', 'create_cnn', 'create_body', 'create_head', 'ClassificationInterpretation', 'ImageLearner']
# By default split models between first and second layer
def _default_split(m:nn.Module): return (m[1],)
# Split a resnet style model
def _resnet_split(m:nn.Module): return (m[0][6],m[1])
_default_meta = {'cut':-1, 'split':_default_split}
_resnet_meta = {'cut':-2, 'split':_resnet_split }
model_meta = {
models.resnet18 :{**_resnet_meta}, models.resnet34: {**_resnet_meta},
models.resnet50 :{**_resnet_meta}, models.resnet101:{**_resnet_meta},
models.resnet152:{**_resnet_meta}}
def cnn_config(arch):
torch.backends.cudnn.benchmark = True
return model_meta.get(arch, _default_meta)
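# Example (a sketch): the resnet meta cuts the last two children (pooling and fc
# head) and splits parameter groups at the penultimate block and the head; any
# architecture absent from `model_meta` falls back to `_default_meta`.
#   cnn_config(models.resnet34)   # -> {'cut': -2, 'split': _resnet_split}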
def create_body(model:nn.Module, cut:Optional[int]=None, body_fn:Callable[[nn.Module],nn.Module]=None):
"Cut off the body of a typically pretrained `model` at `cut` or as specified by `body_fn`."
return (nn.Sequential(*list(model.children())[:cut]) if cut
else body_fn(model) if body_fn else model)
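# Example usage (a sketch; pretrained=False avoids a weight download): keep
# resnet34 up to, but not including, its pooling and fc layers.
#   body = create_body(models.resnet34(pretrained=False), cut=-2)
#   num_features_model(body)   # -> 512 for resnet34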
def create_head(nf:int, nc:int, lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5):
"""Model head that takes `nf` features, runs through `lin_ftrs`, and about `nc` classes.
:param ps: dropout, can be a single float or a list for each layer."""
lin_ftrs = [nf, 512, nc] if lin_ftrs is None else [nf] + lin_ftrs + [nc]
ps = listify(ps)
if len(ps)==1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]
layers = [AdaptiveConcatPool2d(), Flatten()]
for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns):
layers += bn_drop_lin(ni,no,True,p,actn)
return nn.Sequential(*layers)
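# Example (a sketch): for a resnet34 body, `create_cnn` below passes nf = 512 * 2
# = 1024, since AdaptiveConcatPool2d concatenates average and max pooling. With
# the defaults the linear sizes are [1024, 512, nc], and a single ps=0.5 is halved
# for all but the last layer, giving dropout [0.25, 0.5]:
#   head = create_head(1024, 10)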
class ImageLearner(Learner):
def show_results(self, ds_type=DatasetType.Valid, rows:int=3, figsize:Tuple[int,int]=None):
dl = self.dl(ds_type)
        preds = self.pred_batch(ds_type)
figsize = ifnone(figsize, (8,3*rows))
_,axs = plt.subplots(rows, 2, figsize=figsize)
axs[0,0].set_title('Predictions')
axs[0,1].set_title('Ground truth')
for i in range(rows):
x,y = dl.dataset[i]
            x.show(ax=axs[i,1], y=y) # Show the ground truth first: this updates x before we pass it to reconstruct_output
pred = dl.reconstruct_output(preds[i], x)
x.show(ax=axs[i,0], y=pred)
plt.tight_layout()
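# Example usage (a sketch; assumes `learn` is a trained `ImageLearner`): show three
# rows of prediction/ground-truth pairs from the validation set.
#   learn.show_results(rows=3)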
class ClassificationLearner(ImageLearner):
def predict(self, img:Image):
"Return prect class, label and probabilities for `img`."
ds = self.data.valid_ds
ds.set_item(img)
res = self.pred_batch()[0]
ds.clear_item()
pred_max = res.argmax()
return self.data.classes[pred_max],pred_max,res
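# Example usage (a sketch; `open_image` is fastai's image loader and `learn` a
# trained `ClassificationLearner`):
#   cls_name, cls_idx, probs = learn.predict(open_image('path/to/img.jpg'))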
def create_cnn(data:DataBunch, arch:Callable, cut:Union[int,Callable]=None, pretrained:bool=True,
lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5,
custom_head:Optional[nn.Module]=None, split_on:Optional[SplitFuncOrIdxList]=None,
classification:bool=True, **kwargs:Any)->Learner:
"Build convnet style learners."
assert classification, 'Regression CNN not implemented yet, bug us on the forums if you want this!'
meta = cnn_config(arch)
body = create_body(arch(pretrained), ifnone(cut,meta['cut']))
nf = num_features_model(body) * 2
head = custom_head or create_head(nf, data.c, lin_ftrs, ps)
model = nn.Sequential(body, head)
learner_cls = ifnone(data.learner_type(), ClassificationLearner)
learn = learner_cls(data, model, **kwargs)
learn.split(ifnone(split_on,meta['split']))
if pretrained: learn.freeze()
apply_init(model[1], nn.init.kaiming_normal_)
return learn
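# Example usage (a sketch; `data` is an `ImageDataBunch`): build a frozen,
# pretrained resnet34 classifier, then train the new head; `fit_one_cycle` comes
# from fastai's training module.
#   learn = create_cnn(data, models.resnet34)
#   learn.fit_one_cycle(1)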
def Learner_create_unet(data:DataBunch, arch:Callable, pretrained:bool=True,
                        split_on:Optional[SplitFuncOrIdxList]=None, **kwargs:Any)->Learner:
"Build Unet learners."
meta = cnn_config(arch)
body = create_body(arch(pretrained), meta['cut'])
model = to_device(models.unet.DynamicUnet(body, n_classes=data.c), data.device)
learner_cls = ifnone(data.learner_type(), Learner)
learn = learner_cls(data, model, **kwargs)
learn.split(ifnone(split_on,meta['split']))
if pretrained: learn.freeze()
apply_init(model[2], nn.init.kaiming_normal_)
return learn
Learner.create_unet = Learner_create_unet
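# Example usage (a sketch; `data` is a segmentation `DataBunch` whose `c` sets the
# number of output channels):
#   learn = Learner.create_unet(data, models.resnet34)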
class ClassificationInterpretation():
"Interpretation methods for classification models."
def __init__(self, data:DataBunch, probs:Tensor, y_true:Tensor, losses:Tensor, sigmoid:bool=None):
        if sigmoid is not None: warnings.warn("`sigmoid` argument is deprecated, the learner now always returns the probabilities")
self.data,self.probs,self.y_true,self.losses = data,probs,y_true,losses
self.pred_class = self.probs.argmax(dim=1)
@classmethod
def from_learner(cls, learn:Learner, ds_type:DatasetType=DatasetType.Valid, sigmoid:bool=None, tta=False):
"Create an instance of `ClassificationInterpretation`. `tta` indicates if we want to use Test Time Augmentation."
preds = learn.TTA(with_loss=True) if tta else learn.get_preds(ds_type=ds_type, with_loss=True)
return cls(learn.data, *preds, sigmoid=sigmoid)
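    # Example usage (a sketch; `learn` is a trained classifier):
    #   interp = ClassificationInterpretation.from_learner(learn)
    #   interp.plot_top_losses(9)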
def top_losses(self, k:int=None, largest=True):
"`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`)."
return self.losses.topk(ifnone(k, len(self.losses)), largest=largest)
def plot_top_losses(self, k, largest=True, figsize=(12,12)):
"Show images in `top_losses` along with their prediction, actual, loss, and probability of predicted class."
tl_val,tl_idx = self.top_losses(k,largest)
classes = self.data.classes
rows = math.ceil(math.sqrt(k))
fig,axes = plt.subplots(rows,rows,figsize=figsize)
fig.suptitle('prediction/actual/loss/probability', weight='bold', size=14)
for i,idx in enumerate(tl_idx):
t=self.data.valid_ds[idx]
t[0].show(ax=axes.flat[i], title=
f'{classes[self.pred_class[idx]]}/{classes[t[1]]} / {self.losses[idx]:.2f} / {self.probs[idx][t[1]]:.2f}')
def confusion_matrix(self, slice_size:int=None):
"Confusion matrix as an `np.ndarray`."
x=torch.arange(0,self.data.c)
if slice_size is None: cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2)
else:
cm = torch.zeros(self.data.c, self.data.c, dtype=x.dtype)
for i in range(0, self.y_true.shape[0], slice_size):
cm_slice = ((self.pred_class[i:i+slice_size]==x[:,None])
& (self.y_true[i:i+slice_size]==x[:,None,None])).sum(2)
torch.add(cm, cm_slice, out=cm)
return to_np(cm)
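    # Sketch of the broadcasting above: `x[:,None]` is (c,1), so the first comparison
    # is (c,n); `x[:,None,None]` is (c,1,1), so the second is (c,1,n). Their `&`
    # broadcasts to (c,c,n) and summing over dim 2 yields cm[actual, predicted].
    # E.g. preds=[0,1,1] with targets=[0,0,1] gives [[1,1],[0,1]].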
def plot_confusion_matrix(self, normalize:bool=False, title:str='Confusion matrix', cmap:Any="Blues", norm_dec:int=2,
slice_size:int=None, **kwargs)->None:
"""Plot the confusion matrix, with `title` and using `cmap`. If `normalize`, plots the percentages with
`norm_dec` digits. `slice_size` can be used to avoid out of memory error if your set is too big.
`kawrgs` are passed to `plt.figure`.
"""
# This function is mainly copied from the sklearn docs
        cm = self.confusion_matrix(slice_size=slice_size)
        # Normalize before plotting so the heatmap and the text annotations agree
        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        plt.figure(**kwargs)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        tick_marks = arange_of(self.data.classes)
        plt.xticks(tick_marks, self.data.classes, rotation=90)
        plt.yticks(tick_marks, self.data.classes, rotation=0)
        thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
coeff = f'{cm[i, j]:.{norm_dec}f}' if normalize else f'{cm[i, j]}'
plt.text(j, i, coeff, horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('Actual')
plt.xlabel('Predicted')
def most_confused(self, min_val:int=1, slice_size:int=None)->Collection[Tuple[str,str,int]]:
"Sorted descending list of largest non-diagonal entries of confusion matrix"
cm = self.confusion_matrix(slice_size=slice_size)
np.fill_diagonal(cm, 0)
res = [(self.data.classes[i],self.data.classes[j],cm[i,j])
for i,j in zip(*np.where(cm>min_val))]
return sorted(res, key=itemgetter(2), reverse=True)
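# Example (a sketch): with classes ['cat','dog'] and off-diagonal counts
# [[0,3],[1,0]], entries strictly greater than `min_val` are reported:
#   interp.most_confused(min_val=0)   # -> [('cat', 'dog', 3), ('dog', 'cat', 1)]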