In [None]:
import fastai
from fastai.vision.all import *
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm
import cv2

In [None]:
# Dataset layout used below: <root>/<category>/train/good holds the non-anomalous training
# images; <root>/<category>/test has a "good" folder plus anomaly folders (e.g. "hole").
root_path = Path("mvtec_anomaly_detection")
task_path = root_path / "hazelnut"
train_path = task_path / "train" / "good"

In [None]:
# Batch transforms: fastai's default augmentation list followed by ImageNet normalisation
# (presumably chosen to match the pretrained VGG used by the feature loss).
btfms = [*aug_transforms(), Normalize.from_stats(*imagenet_stats)]
btfms

In [None]:
class UpsampleBlock(Module):
    "Decoder stage: pixel-shuffle upsample to `up_in_c//2` channels, ReLU, then two `ConvLayer`s."
    def __init__(self, up_in_c:int, final_div:bool=True, blur:bool=True, leaky:float=None, **kwargs):
        half_c = up_in_c//2
        # final_div=True keeps the halved channel count; otherwise halve once more.
        out_c = half_c if final_div else half_c//2
        # Attribute names (shuf/conv1/conv2/relu) are the state_dict keys written by
        # `learn.save`, so they must not change or saved checkpoints stop loading.
        self.shuf = PixelShuffle_ICNR(up_in_c, half_c, blur=blur, **kwargs)
        self.conv1 = ConvLayer(half_c, out_c, **kwargs)
        self.conv2 = ConvLayer(out_c, out_c, **kwargs)
        self.relu = nn.ReLU()
        # NOTE(review): `leaky` is accepted for signature compatibility but never used.

    def forward(self, up_in:Tensor) -> Tensor:
        "Upsample `up_in`, apply ReLU, then run both convolutions."
        upsampled = self.relu(self.shuf(up_in))
        refined = self.conv1(upsampled)
        return self.conv2(refined)
    
def decoder_resnet(y_range, n_out=3, in_c=512, n_up=5):
    """Build the decoder half of the autoencoder.

    Stacks `n_up` `UpsampleBlock`s, each halving the channel count, then a 1x1 conv
    to `n_out` channels and a `SigmoidRange` that squashes outputs into `y_range`.

    y_range: (low, high) output range passed to `SigmoidRange`.
    n_out:   number of output channels (3 for RGB).
    in_c:    channels produced by the encoder body; the default 512 matches the resnet18
             body this notebook pairs it with. Pass e.g. 2048 for a larger body.
    n_up:    number of upsampling stages (default 5, as originally hard-coded).

    The defaults build exactly the original module layout (same Sequential indices),
    so previously saved checkpoints still load.

    Raises ValueError if `in_c` cannot be halved `n_up` times without a remainder.
    """
    if in_c % (2**n_up):
        raise ValueError(f"in_c={in_c} must be divisible by 2**n_up={2**n_up}")
    ups = [UpsampleBlock(in_c // 2**i) for i in range(n_up)]
    return nn.Sequential(*ups,
                         nn.Conv2d(in_c // 2**n_up, n_out, 1),
                         SigmoidRange(*y_range)
                        )
                         
def autoencoder(encoder, y_range):
    "Return `encoder` followed by a fresh `decoder_resnet(y_range)`, wrapped in one `nn.Sequential`."
    decoder = decoder_resnet(y_range)
    return nn.Sequential(encoder, decoder)

In [None]:
def image2image(x):
    "Identity `get_y`: each item (an image file from `get_image_files`) is also its own target."
    target = x
    return target

In [None]:
# Autoencoder data: input and target are the same image (`image2image` is the identity `get_y`).
# A fixed seed makes the 90/10 train/valid split reproducible across kernel restarts. Reloaded
# checkpoints (`learn.load`) are then evaluated on the same held-out images they never trained on.
SPLIT_SEED = 42
block = DataBlock(blocks=(ImageBlock(cls=PILImage), ImageBlock(cls=PILImage)),
                  get_items=get_image_files,
                  get_y=image2image,
                  splitter=RandomSplitter(valid_pct=0.1, seed=SPLIT_SEED),
                  item_tfms=Resize(256),
                  batch_tfms=btfms,
)

In [None]:
# Build train/valid DataLoaders from the images under train/good (batches of 16).
dls = block.dataloaders(train_path, batch_size=16)

In [None]:
# Visual sanity check of a batch; input and target come from the same file (`image2image`).
dls.show_batch()

In [None]:
# Encoder: body of a torchvision resnet18 built with no weights argument.
# NOTE(review): confirm random initialisation is intended (no pretrained weights requested here).
# The `.cuda()` on `arch` is redundant: the whole autoencoder is moved to the GPU on the next line.
arch = create_body(resnet18(), n_in=3).cuda()
# Output range for the decoder's SigmoidRange; presumably chosen to cover ImageNet-normalised
# pixel values (roughly [-2.1, 2.6]) — TODO confirm.
y_range = (-3.,3.)
ac_resnet = autoencoder(arch, y_range).cuda()

In [None]:
def gram_matrix(x):
    """Per-sample Gram matrix of a feature map, normalised by the map's size.

    x: tensor of shape (n, c, h, w).
    Returns a tensor of shape (n, c, c) whose [b, i, j] entry is the dot product of
    channels i and j of sample b, divided by c*h*w.
    """
    n, c, h, w = x.size()
    # `reshape` instead of `view`: `view` raises on non-contiguous inputs (e.g. transposed
    # or sliced feature maps), while `reshape` only copies when it has to.
    feats = x.reshape(n, c, h * w)
    return (feats @ feats.transpose(1, 2)) / (c * h * w)

In [None]:
# Distance used by FeatureLoss for its pixel, feature and Gram terms (L1).
base_loss = F.l1_loss

In [None]:
# Frozen VGG16-BN feature stack used only as the loss network: eval mode, gradients disabled, on GPU.
# NOTE(review): positional `vgg16_bn(True)` is presumably the legacy "pretrained" flag; confirm
# against the installed torchvision/fastai version.
classificator = vgg16_bn(True).features.cuda().eval().requires_grad_(False)
# Displayed as a quick check of the frozen state (presumably reports False).
requires_grad(classificator)

In [None]:
# Indices of the layers immediately preceding each MaxPool2d in the VGG feature stack
# (presumably the activation that closes each conv stage); displayed with the layers themselves.
blocks = [i-1 for i,o in enumerate(classificator.children()) if isinstance(o,nn.MaxPool2d)]
blocks, [classificator[i] for i in blocks]

In [None]:
class FeatureLoss(nn.Module):
    """Perceptual reconstruction loss: pixel term + weighted feature and Gram-matrix terms.

    m_feat:     feature extractor (here the frozen VGG `classificator`); it is run on both
                prediction and target, and the outputs of `layer_ids` are captured by hooks.
    layer_ids:  indices into `m_feat` whose outputs are compared.
    layer_wgts: one weight per entry of `layer_ids`. Applied linearly to the feature terms,
                and squared (times `gram_wgt`) to the Gram terms.
    gram_wgt:   extra scale on the Gram terms (default 5e3, the original hard-coded value).

    After each call, `self.feat_losses` holds the individual terms (pixel, feat_*, gram_*),
    and `self.metrics` maps `self.metric_names` to them.

    Raises ValueError if `layer_ids` and `layer_wgts` differ in length.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts, gram_wgt=5e3):
        super().__init__()
        layer_ids, layer_wgts = list(layer_ids), list(layer_wgts)
        # `zip` in forward() would silently drop unmatched layers/weights; fail loudly instead.
        if len(layer_ids) != len(layer_wgts):
            raise ValueError(f"got {len(layer_ids)} layer_ids but {len(layer_wgts)} layer_wgts")
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False keeps the autograd graph on the hooked outputs, so gradients from the
        # feature/Gram terms can flow back to the prediction.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        self.gram_wgt = gram_wgt

        self.metric_names = (['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])

    def make_features(self, x, clone=False):
        "Run `m_feat` on `x` and return the hooked activations, cloned if `clone`."
        self.m_feat(x)
        # `hooks.stored` reflects the latest forward pass, so the target's features are cloned
        # before the prediction is pushed through the same network.
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        "Return the summed loss; the individual terms are kept in `feat_losses` / `metrics`."
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * self.gram_wgt
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # If __init__ failed before the hooks were registered there is nothing to remove.
        # Read __dict__ directly so nn.Module.__getattr__ cannot raise here.
        hooks = self.__dict__.get('hooks')
        if hooks is not None: hooks.remove()

In [None]:
# Perceptual loss on the 3rd-5th pre-pool VGG layers found above, weighted 5, 15 and 2.
feat_loss = FeatureLoss(classificator, blocks[2:5], [5,15,2])

In [None]:
# Learner over the autoencoder; no metrics are passed, so the loss is the only training signal.
learn = Learner(dls, ac_resnet, loss_func=feat_loss, wd = 1e-3)

In [None]:
# NOTE(review): no `splitter` is given to Learner, so there is presumably a single parameter
# group and `unfreeze` has no practical effect — confirm.
learn.unfreeze()

In [None]:
# Stage 1: 10 epochs of the one-cycle schedule at fastai's default learning rate.
learn.fit_one_cycle(10)

In [None]:
# Checkpoint after stage 1.
learn.save("feature_loss_stage_VGG_32_1")

In [None]:
learn.show_results()

In [None]:
# NOTE(review): no freeze happened since the previous unfreeze, so this is presumably a no-op.
learn.unfreeze()

In [None]:
learn.lr_find()

In [None]:
# Stage 2: 100 epochs; the slice has equal endpoints, so it is effectively a single max LR of 1e-3.
learn.fit_one_cycle(100, slice(1e-3,1e-3))

In [None]:
learn.save("feature_loss_stage_VGG_32_2")

In [None]:
learn.show_results()

In [None]:
learn.lr_find()

In [None]:
# 200 more epochs at fastai's default learning rate.
learn.fit_one_cycle(200)

In [None]:
# Another 100 epochs with an LR slice from 1e-4 to 1e-3.
learn.fit_one_cycle(100, slice(1e-4,1e-3))

In [None]:
learn.show_results()

In [None]:
learn.save("feature_loss_stage_VGG_32_3")

In [None]:
x = TensorImage(PILImage.create(np.array(Image.open(task_path/"test"/"hole"/"001.png").resize((256,256))))).permute(2,0,1)
prediction = learn.predict(x)
y_hat = prediction[0]

In [None]:
plt.imshow(x.permute(1,2,0))

In [None]:
plt.imshow(y_hat.permute(1,2,0))

In [None]:
diff = np.absolute(x-y_hat)

In [None]:
diff_all = np.sum(np.array(np.absolute(diff)),axis=0)

In [None]:
plt.imshow(diff_all)

In [None]:
diff = np.absolute(x-y_hat)

In [None]:
x = TensorImage(PILImage.create(np.array(Image.open(task_path/"test"/"hole"/"001.png").resize((256,256))))).permute(2,0,1)
prediction = learn.predict(x)
y_hat = prediction[0]

In [None]:
plt.imshow(x.permute(1,2,0))

In [None]:
plt.imshow(y_hat.permute(1,2,0))

In [None]:
diff = np.absolute(x-y_hat)

In [None]:
plt.imshow(diff[1])

In [None]:
# NOTE(review): this name breaks the "feature_loss_stage_VGG_32_N" pattern of the neighbouring saves.
learn.save("feature_loss_stage_3")

In [None]:
learn.lr_find()

In [None]:
# 400 epochs with an LR slice from 1e-4 to 1e-3.
learn.fit_one_cycle(400, slice(1e-4,1e-3))

In [None]:
learn.show_results()

In [None]:
learn.save("feature_loss_stage_VGG_32_4")

In [None]:
learn.lr_find()

In [None]:
# Reconstruct anomalous test image hole/003 at the 256 training resolution.
x = TensorImage(PILImage.create(np.array(Image.open(task_path/"test"/"hole"/"003.png").resize((256,256))))).permute(2,0,1)
prediction = learn.predict(x)
y_hat = prediction[0]

In [None]:
plt.imshow(x.permute(1,2,0))

In [None]:
plt.imshow(y_hat.permute(1,2,0))

In [None]:
# Per-pixel, per-channel absolute error. NOTE(review): x is presumably uint8; confirm y_hat's dtype
# promotes the subtraction so it cannot wrap around.
diff = np.absolute(x-y_hat)

In [None]:
# Channel-summed error map (the inner np.absolute is a no-op: diff is already non-negative).
diff_all = np.sum(np.array(np.absolute(diff)),axis=0)

In [None]:
plt.imshow(diff_all)

In [None]:
# Stage 5: 100 epochs, LR slice 1e-4..1e-3.
learn.fit_one_cycle(100, slice(1e-4,1e-3))

In [None]:
learn.show_results()

In [None]:
learn.save("feature_loss_stage_5")

In [None]:
learn.lr_find()

In [None]:
# Stage 6: 100 epochs with a wider LR slice, 1e-5..1e-3.
learn.fit_one_cycle(100, slice(1e-5,1e-3))

In [None]:
learn.show_results()

In [None]:
learn.save("feature_loss_stage_6")

In [None]:
learn.lr_find()

In [None]:
# Stage 7: 300 epochs, LR slice 1e-5..1e-3.
learn.fit_one_cycle(300, slice(1e-5,1e-3))

In [None]:
learn.show_results()

In [None]:
learn.save("feature_loss_stage_7")

In [None]:
x = TensorImage(PILImage.create(np.array(Image.open(task_path/"test"/"hole"/"000.png").resize((512,512))))).permute(2,0,1)
y_hat = learn.predict(x)[0]


In [None]:
plt.imshow(x.permute(1,2,0))

In [None]:
plt.imshow(y_hat.permute(1,2,0))

In [None]:
# Display the validation DataLoader object.
learn.dls.valid

In [None]:
# Reconstructions on training batches.
learn.show_results(dl=learn.dls.train)

In [None]:
learn.lr_find()

In [None]:
# Stage 8: 500 epochs with a lower LR slice, 1e-5..1e-4.
learn.fit_one_cycle(500, slice(1e-5,1e-4))

In [None]:
learn.show_results(dl=learn.dls.train)

In [None]:
learn.save("feature_loss_stage_8")

In [None]:
# Reconstructions on held-out validation batches (the stray `v` cell that followed raised NameError and was removed).
learn.show_results(dl=learn.dls.valid)

In [None]:
x = TensorImage(PILImage.create(np.array(Image.open(task_path/"test"/"crack"/"000.png").resize((512,512))))).permute(2,0,1)
y_hat = learn.predict(x)[0]


In [None]:
plt.imshow(x.permute(1,2,0))

In [None]:
plt.imshow(y_hat.permute(1,2,0))

In [None]:
diff = np.absolute(x-y_hat)

In [None]:
plt.imshow(diff[0])

In [None]:
learn.lr_find()

In [None]:
5e-4*2, 1e-3

In [None]:
learn.fit_one_cycle(500, slice(1e-5,5e-4))

In [None]:
learn.save("feature_loss_stage_9")

In [None]:
learn.show_results(dl=learn.dls.train)

In [None]:
good_test_images = get_image_files(task_path/"test", folders=["good"])
losses_good = []
for img_file in good_test_images:
    x = TensorImage(PILImage.create(np.array(Image.open(img_file).resize((512,512))))).permute(2,0,1)    
    y_hat = learn.predict(x)[0]    
    losses_good.append(np.absolute(np.array(y_hat - x)).mean())

In [None]:
anomalie_images = get_image_files(task_path/"test", folders=["crack","cut","hole","print"])
losses_anomaly = []
for img_file in anomalie_images:
    x = TensorImage(PILImage.create(np.array(Image.open(img_file).resize((512,512))))).permute(2,0,1)    
    y_hat = learn.predict(x)[0]    
    losses_anomaly.append(np.absolute(np.array(y_hat - x)).mean())

In [None]:

# Mean absolute reconstruction error for every item of the training split.
# NOTE(review): these images were used for fitting, so this is an optimistic "good" baseline;
# consider dls.valid_ds for an unbiased one.
losses = []
for idx in range(len(dls.train_ds)):
    # Presumably returns the item after item transforms (Resize(256)) as an (input, target) pair — confirm.
    x,y = dls.do_item(idx)
    y_hat = learn.predict(x)[0]    
    losses.append(np.absolute(np.array(y_hat - y)).mean())

In [None]:
# Compare reconstruction-error distributions: training images vs. held-out good vs. anomalous test images.
# Shared bin edges make the histograms directly comparable; alpha keeps overlapping bars visible.
all_losses = np.concatenate([np.ravel(losses), np.ravel(losses_good), np.ravel(losses_anomaly)])
bins = np.histogram_bin_edges(all_losses, bins=100)
fig, ax = plt.subplots()
ax.hist(losses, bins=bins, alpha=0.5, label="train (good)")
ax.hist(losses_good, bins=bins, alpha=0.5, label="test good")
ax.hist(losses_anomaly, bins=bins, alpha=0.5, label="test anomaly")
ax.set(xlabel="mean absolute reconstruction error", ylabel="count")
ax.legend()
plt.show()

In [None]:
x = TensorImage(PILImage.create(np.array(Image.open(task_path/"test"/"hole"/"001.png").resize((512,512))))).permute(2,0,1)
y_hat = learn.predict(x)[0]
#learn.loss_func(y_hat,x)
diff = np.absolute(np.array(y_hat - x))
diff.shape

In [None]:
plt.imshow(x.permute(1,2,0))

In [None]:
plt.imshow(y_hat.permute(1,2,0))

In [None]:
# Channel-summed absolute error map for the current (x, y_hat) pair.
heatmap = np.sum(np.absolute(np.array(y_hat - x)),axis=0)
kernel = np.ones((5, 5), 'uint8')
# Keep pixels whose error exceeds half of the maximum error (0.5 is a hand-picked threshold).
mask = heatmap > heatmap.max()*0.5
# Dilate then erode with the same kernel and iteration count (a morphological closing)
# to merge nearby above-threshold pixels into contiguous regions.
mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=6)
mask = cv2.erode(mask, kernel, iterations=6).astype(bool)
# Mask out the False pixels so imshow draws only the detected region.
mask = np.ma.masked_where(mask==False, mask)
plt.imshow(mask)

In [None]:
# Outline the detected anomaly region on the input image.
# `np.asarray(mask)` drops the display mask and keeps the underlying boolean data.
contours, _ = cv2.findContours(image=np.asarray(mask).astype(np.uint8), mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_NONE)

f,axes = plt.subplots(ncols=3,nrows=1,figsize=(20,8))
axes[0].imshow(x.permute(1,2,0))
axes[0].set_title("input")
axes[1].imshow(y_hat.permute(1,2,0))
axes[1].set_title("reconstruction")
axes[2].imshow(x.permute(1,2,0))
axes[2].set_title("detected anomaly")
for cont in contours:
    # A single-point contour cannot be drawn as a line.
    if len(cont) == 1:
        continue
    pts = cont.reshape(-1, 2)
    # Repeat the first point so the plotted outline is a closed loop.
    pts = np.vstack([pts, pts[:1]])
    axes[2].plot(pts[:,0], pts[:,1], "r-")
plt.show()

In [None]:
# Raw channel-summed error map.
plt.imshow(heatmap)

In [None]:
# Error map of channel 2 alone.
plt.imshow(np.absolute(np.array(y_hat - x))[2])

In [None]:
# Thicken the error map with the 5x5 kernel (2 iterations) before blurring; uint16 holds the summed errors.
dilated = cv2.dilate(heatmap.astype(np.uint16), kernel, iterations=2)


In [None]:
# Show the dilated error map (plt.imshow() with no argument raised TypeError).
plt.imshow(dilated)

In [None]:
# Smooth the dilated error map with an 11x11 Gaussian (sigma derived from the kernel size since 0 is passed).
blur = cv2.GaussianBlur(dilated,(11,11),0)

In [None]:
plt.imshow(blur)

In [None]:
# Peak value of the smoothed error map.
blur.max()

In [None]:
learn.fit_flat_cos?

In [None]:
learn.fit_flat_cos(500, slice(1e-6,1e-4))

In [None]:
learn.show_results(dl=learn.dls.valid)