<a href="https://colab.research.google.com/github/Joonqi/GCAD/blob/main/train_GCAD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Global Context Anomaly Detection
An anomaly-detection method for the MVTec LOCO dataset that works by extracting logically constrained features.

In [None]:
# '''
# https://github.com/denguir/student-teacher-anomaly-detection
# https://github.com/erezposner/Fast_Dense_Feature_Extraction
# https://discuss.pytorch.org/t/unet-implementation/426
# '''

In [None]:
# !pip install einops
# !pip install torchsummary

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
import torchvision.models as models
import numpy as np
import pandas as pd
import torchsummary
# import pytorch_lightning as pl
from tqdm import tqdm
import PIL
from PIL import Image
torch.manual_seed(42)

import warnings
warnings.filterwarnings(action='ignore')

In [None]:
# Colab only: mount Google Drive at /content/drive, where this notebook reads the
# dataset and pretrained weights from and writes checkpoints / figures to.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import time

# Run identifier: current Unix time in seconds with the two leading digits dropped.
time_stamp = str(int(time.time()))[2:]
print(time_stamp)

# Dataset category; also the name of the dataset folder and of the run folder prefix.
target_dataset = 'juice_bottle'

# d_glo: output channels of the global teacher decoder / global student (UNet).
# d_loc: output channels of the 1x1 projections applied to local / upsampled features.
d_glo = 10
d_loc = 128

# epochs: maximum number of training epochs.
# early_stopping: training stops once more than this many epochs pass without improvement.
epochs = 125
early_stopping = 30

run_dir = f'/content/drive/MyDrive/LOCO_AD/gcad/{target_dataset}_{time_stamp}'
os.makedirs(run_dir, exist_ok=True)
os.makedirs(f'{run_dir}/results', exist_ok=True)
for mode in ['both', 'global', 'local']:
    os.makedirs(f'{run_dir}/{mode}/results', exist_ok=True)
    # visualize() writes its figures to <run>/results/<mode>/, so that layout
    # has to exist as well or plt.savefig fails with FileNotFoundError.
    os.makedirs(f'{run_dir}/results/{mode}', exist_ok=True)

mode = 'both'
assert (mode in ['both', 'local', 'global'])

69601607


In [None]:
# Sanity check: later cells move models to the GPU unconditionally via .cuda().
print(torch.cuda.is_available())

True


In [None]:
# NOTE(review): these variables are set after torch has been imported and after the
# previous cell queried torch.cuda; CUDA environment variables are normally expected
# to be set before CUDA is first touched — confirm they actually take effect here.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "0,1"
# Debugging aid: makes kernel launches synchronous.
os.environ["CUDA_LAUNCH_BLOCKING"]= '1'

In [None]:
# Source:
# https://github.com/erezposner/Fast_Dense_Feature_Extraction

from torch import nn
import torch
import numpy as np
import torch.nn.functional as F

# (N,C,H,W)


class multiPoolPrepare(nn.Module):
    """Zero-pad the two trailing dimensions by ``patch - 1`` in total.

    The padding is split between both sides of each axis; when the amount is
    odd, the extra row / column goes to the top / left.

    Args:
        patchY: patch height used to derive the vertical padding.
        patchX: patch width used to derive the horizontal padding.
    """

    def __init__(self, patchY, patchX):
        super().__init__()
        half_y = (patchY - 1) / 2
        half_x = (patchX - 1) / 2

        self.pad_top, self.pad_bottom = (np.ceil(half_y).astype(int),
                                         np.floor(half_y).astype(int))
        self.pad_left, self.pad_right = (np.ceil(half_x).astype(int),
                                         np.floor(half_x).astype(int))

    def forward(self, x):
        # F.pad takes the last dimension first: (left, right, top, bottom).
        padding = [self.pad_left, self.pad_right, self.pad_top, self.pad_bottom]
        return F.pad(x, padding, value=0)


class unwrapPrepare(nn.Module):
    """Crop the last row and column, then flatten to a (features, N) matrix."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Negative padding crops: one column off the right, one row off the bottom.
        cropped = F.pad(x, [0, -1, 0, -1], value=0)
        n_items = cropped.shape[0]
        flat = cropped.contiguous().view(n_items, -1)
        # Put the flattened features on dim 0 and the leading (batch) index on dim 1.
        return flat.t().contiguous()


class unwrapPool(nn.Module):
    """Reshape a flat tensor to six dimensions and swap dims 2 and 3.

    All constructor arguments are cast to ``int`` (callers pass float
    quotients such as ``imH / 2``).
    """

    def __init__(self, outChans, curImgW, curImgH, dW, dH):
        super().__init__()
        self.outChans, self.curImgW, self.curImgH = int(outChans), int(curImgW), int(curImgH)
        self.dW, self.dH = int(dW), int(dH)

    def forward(self, x, ):
        # The trailing -1 absorbs whatever element count is left over.
        target_shape = (self.outChans, self.curImgW, self.curImgH, self.dH, self.dW, -1)
        return x.view(target_shape).transpose(2, 3).contiguous()


class multiMaxPooling(nn.Module):
    """Max-pool every stride phase of the input and stack the results on dim 0.

    For each of the ``dH * dW`` offsets the input is cropped symmetrically by
    that offset (negative padding), pooled, zero-padded up to the largest
    pooled size and finally concatenated along the batch dimension.
    """

    def __init__(self, kW, kH, dW, dH):
        super().__init__()
        # One (x, y) crop offset per phase, row-major over (dH, dW).
        self.padd = [(-j, -i) for i in range(dH) for j in range(dW)]
        self.max_layers = nn.ModuleList(
            [nn.MaxPool2d(kernel_size=(kW, kH), stride=(dW, dH)) for _ in self.padd]
        )
        self.s = dH

    def forward(self, x):
        pooled = []
        for (off_x, off_y), pool in zip(self.padd, self.max_layers):
            # The same (negative) amount is applied to both sides of each axis.
            shifted = F.pad(x, [off_x, off_x, off_y, off_y], value=0)
            pooled.append(pool(shifted))

        max_h = np.max([t.size()[2] for t in pooled])
        max_w = np.max([t.size()[3] for t in pooled])

        aligned = []
        for t in pooled:
            gap_h = max_h - t.size()[2]
            gap_w = max_w - t.size()[3]
            # An odd gap puts the extra zero row / column at the bottom / right.
            pad_top = np.floor(gap_h / 2).astype(int)
            pad_bottom = np.ceil(gap_h / 2).astype(int)
            pad_left = np.floor(gap_w / 2).astype(int)
            pad_right = np.ceil(gap_w / 2).astype(int)
            aligned.append(F.pad(t, [pad_left, pad_right, pad_top, pad_bottom], value=0))
        return torch.cat(aligned, 0)


class multiConv(nn.Module):
    """Apply a strided convolution to every stride phase and stack on dim 0.

    One ``nn.Conv2d`` is created per phase. The torch RNG is re-seeded with 10
    before each construction, so all phase convolutions start from identical
    weights (and the global RNG state is reset as a side effect).
    """

    def __init__(self, nInputPlane, nOutputPlane, kW, kH, dW, dH):
        super().__init__()
        self.padd = []
        convs = []
        for phase in range(dH * dW):
            i, j = divmod(phase, dW)  # row-major over (dH, dW)
            self.padd.append((-j, -i))
            torch.manual_seed(10)
            torch.cuda.manual_seed(10)
            convs.append(
                nn.Conv2d(nInputPlane, nOutputPlane, kernel_size=(kW, kH), stride=(dW, dH), padding=0)
            )
        self.max_layers = nn.ModuleList(convs)
        self.s = dW

    def forward(self, x):
        outputs = []
        for (off_x, off_y), conv in zip(self.padd, self.max_layers):
            # Negative padding crops the same amount from both sides of each axis.
            shifted = F.pad(x, [off_x, off_x, off_y, off_y], value=0)
            outputs.append(conv(shifted))

        max_h = np.max([t.size()[2] for t in outputs])
        max_w = np.max([t.size()[3] for t in outputs])

        aligned = []
        for t in outputs:
            gap_h = max_h - t.size()[2]
            gap_w = max_w - t.size()[3]
            # An odd gap puts the extra zero row / column at the top / left.
            pad_top = np.ceil(gap_h / 2).astype(int)
            pad_bottom = np.floor(gap_h / 2).astype(int)
            pad_left = np.ceil(gap_w / 2).astype(int)
            pad_right = np.floor(gap_w / 2).astype(int)
            aligned.append(F.pad(t, [pad_left, pad_right, pad_top, pad_bottom], value=0))
        return torch.cat(aligned, 0)

In [None]:
class LocalBranchEncoder(nn.Module):
    """Patch feature extractor that runs on 33x33 patches or densely on a whole image.

    Args:
        fdfe: when True, ``forward`` always takes the dense (whole image) path.

    The patch path returns a ``(b, 512)`` tensor and applies dropout; the dense
    path (``fdfe``) applies no dropout and returns 512-dimensional features in
    a channels-last layout.
    """

    def __init__(self, fdfe=False):
        super().__init__()
        self.pH = self.pW = 33
        self.bool_fdfe = fdfe
        self.multiPoolPrepare = multiPoolPrepare(self.pH, self.pW)

        # Layer names (and their order) are kept stable for checkpoint compatibility.
        self.conv1 = nn.Conv2d(3, 128, 5, 1)
        self.conv2 = nn.Conv2d(128, 256, 5, 1)
        self.conv3 = nn.Conv2d(256, 256, 2, 1)
        self.conv4 = nn.Conv2d(256, 128, 4, 1)
        self.outChans = self.conv4.out_channels
        self.decode = nn.Linear(128, 512)

        self.dropout_2d = nn.Dropout2d(0.2)
        self.dropout = nn.Dropout(0.2)

        self.max_pool = nn.MaxPool2d(2, 2)
        self.multiMaxPooling = multiMaxPooling(2, 2, 2, 2)
        self.unwrapPrepare = unwrapPrepare()

        self.l_relu = nn.LeakyReLU(5e-3)

    def fdfe(self, x):
        '''Use Fast Dense Feature Extraction to efficiently apply
        the patch-based CNN AnomalyNet33 on a whole image.'''
        imH, imW = x.size(2), x.size(3)

        # One unwrap step per pooling layer, applied in reverse order further down.
        undo_pool2 = unwrapPool(self.outChans, imH / (2 * 2), imW / (2 * 2), 2, 2)
        undo_pool1 = unwrapPool(self.outChans, imH / 2, imW / 2, 2, 2)

        feats = self.multiPoolPrepare(x)
        feats = self.multiMaxPooling(self.l_relu(self.conv1(feats)))
        feats = self.multiMaxPooling(self.l_relu(self.conv2(feats)))
        feats = self.l_relu(self.conv3(feats))
        feats = self.l_relu(self.conv4(feats))

        feats = undo_pool1(undo_pool2(self.unwrapPrepare(feats)))

        # (outChans, H, W, rest) -> (rest, H, W, outChans) so the Linear acts on channels.
        dense = feats.view(self.outChans, imH, imW, -1).permute(3, 1, 2, 0)
        return self.l_relu(self.decode(dense))

    def forward(self, x, fdfe=False):
        if fdfe or self.bool_fdfe:
            return self.fdfe(x)

        assert x.size(2) == self.pH and x.size(3) == self.pW, \
            f"This patch extractor only accepts input of size (b, 3, {self.pH}, {self.pW})"
        feats = self.max_pool(self.l_relu(self.conv1(x)))
        feats = self.max_pool(self.l_relu(self.conv2(feats)))
        feats = self.l_relu(self.conv3(feats))
        feats = self.l_relu(self.conv4(feats))
        feats = self.dropout_2d(feats)
        # conv4 leaves a 1x1 map for a 33x33 input, so this flattens to (b, outChans).
        feats = feats.view(-1, self.outChans)
        return self.dropout(self.l_relu(self.decode(feats)))
        
# Print layer-by-layer summaries of the local encoder in both operating modes.
# Each model is moved to the GPU for the summary and deleted again afterwards.
print("LocalBranchEncoder on Patches (33x33)")
model = LocalBranchEncoder().cuda()
torchsummary.summary(model, (3, 33, 33))
del model

# fdfe=True makes forward() take the dense, whole-image path.
print("LocalBranchEncoder on whole image (256x256)")
model = LocalBranchEncoder(fdfe=True).cuda()
torchsummary.summary(model, (3, 256, 256))
del model

LocalBranchEncoder on Patches (33x33)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 29, 29]           9,728
         LeakyReLU-2          [-1, 128, 29, 29]               0
         MaxPool2d-3          [-1, 128, 14, 14]               0
            Conv2d-4          [-1, 256, 10, 10]         819,456
         LeakyReLU-5          [-1, 256, 10, 10]               0
         MaxPool2d-6            [-1, 256, 5, 5]               0
            Conv2d-7            [-1, 256, 4, 4]         262,400
         LeakyReLU-8            [-1, 256, 4, 4]               0
            Conv2d-9            [-1, 128, 1, 1]         524,416
        LeakyReLU-10            [-1, 128, 1, 1]               0
        Dropout2d-11            [-1, 128, 1, 1]               0
           Linear-12                  [-1, 512]          66,048
        LeakyReLU-13                  [-1, 512]               0
 

In [None]:
# class GlobalBranchEncoder(pl.LightningModule):

class GlobalBranchEncoder(nn.Module):
    """Encoder-decoder with 1x1-convolved skip connections and a ``g_dim`` bottleneck.

    Args:
        decode: when True (and upsampling is off) the output goes through the
            ``d_glo``-channel 1x1 ``decoder`` head.
        upsample: when True the output always goes through the ``d_loc``-channel
            1x1 ``upsample`` head.

    Relies on the notebook globals ``d_glo`` and ``d_loc`` for the head widths.
    """

    def __init__(self, decode=True, upsample=False):
        super().__init__()
        self.g_dim = 32
        self.decode = decode
        self.bool_upsample = upsample

        # Contracting path: every convolution uses stride 2.
        self.conv1 = nn.Conv2d(3, 32, (4, 4), 2, 1)
        self.conv2 = nn.Conv2d(32, 32, (4, 4), 2, 1)
        self.conv3 = nn.Conv2d(32, 64, (4, 4), 2, 1)
        self.conv4 = nn.Conv2d(64, 64, (4, 4), 2, 1)
        self.conv5 = nn.Conv2d(64, 64, (4, 4), 2, 1)
        self.conv6 = nn.Conv2d(64, self.g_dim, (8, 8), 2, 1)

        # Expanding path: inputs are 64 wide (32 upsampled + 32 skip) after the first.
        self.upconv1 = nn.ConvTranspose2d(self.g_dim, 32, (8, 8), 2, 1)
        self.upconv2 = nn.ConvTranspose2d(64, 32, (4, 4), 2, 1)
        self.upconv3 = nn.ConvTranspose2d(64, 32, (4, 4), 2, 1)
        self.upconv4 = nn.ConvTranspose2d(64, 32, (4, 4), 2, 1)
        self.upconv5 = nn.ConvTranspose2d(64, 32, (4, 4), 2, 1)
        self.upconv6 = nn.ConvTranspose2d(64, 32, (4, 4), 2, 1)

        # 1x1 convolutions that bring every skip tensor to 32 channels.
        self.skip_conn1 = nn.Conv2d(32, 32, (1, 1), 1)
        self.skip_conn2 = nn.Conv2d(32, 32, (1, 1), 1)
        self.skip_conn3 = nn.Conv2d(64, 32, (1, 1), 1)
        self.skip_conn4 = nn.Conv2d(64, 32, (1, 1), 1)
        self.skip_conn5 = nn.Conv2d(64, 32, (1, 1), 1)

        self.decoder = nn.Conv2d(32, d_glo, (1, 1), 1)
        self.upsample = nn.Conv2d(32, d_loc, (1, 1), 1)
        self.l_relu = nn.LeakyReLU()

        # The members below are not used by forward().
        self.dropout_2d = nn.Dropout2d(0.2)
        self.dropout = nn.Dropout(0.2)

        self.max_pool = nn.MaxPool2d(2, 2)
        self.multiMaxPooling = multiMaxPooling(2, 2, 2, 2)
        self.unwrapPrepare = unwrapPrepare()
        self.outChans = 128

    def forward(self, input, fdfe=True, upsample=False):
        # `fdfe` is accepted but not used by this network.
        act = self.l_relu

        # Encoder: keep a 1x1-projected copy of each stage for the decoder.
        h = act(self.conv1(input))
        skip1 = act(self.skip_conn1(h))
        h = act(self.conv2(h))
        skip2 = act(self.skip_conn2(h))
        h = act(self.conv3(h))
        skip3 = act(self.skip_conn3(h))
        h = act(self.conv4(h))
        skip4 = act(self.skip_conn4(h))
        h = act(self.conv5(h))
        skip5 = act(self.skip_conn5(h))
        bottleneck = act(self.conv6(h))

        # Decoder: upsample, append the matching skip along channels, activate.
        h = act(torch.cat((self.upconv1(bottleneck), skip5), 1))
        h = act(torch.cat((self.upconv2(h), skip4), 1))
        h = act(torch.cat((self.upconv3(h), skip3), 1))
        h = act(torch.cat((self.upconv4(h), skip2), 1))
        h = act(torch.cat((self.upconv5(h), skip1), 1))
        h = act(self.upconv6(h))

        # The upsample head takes priority over the decoder head.
        if upsample or self.bool_upsample:
            return act(self.upsample(h))
        if self.decode:
            return act(self.decoder(h))
        return h

# Summarise the global encoder with the d_loc-channel upsample head selected
# (the commented line selects the default decoder head instead).
model = GlobalBranchEncoder(upsample=True)
# model = GlobalBranchEncoder()

torchsummary.summary(model.cuda(), (3, 256, 256))
del model

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 128, 128]           1,568
         LeakyReLU-2         [-1, 32, 128, 128]               0
            Conv2d-3         [-1, 32, 128, 128]           1,056
         LeakyReLU-4         [-1, 32, 128, 128]               0
            Conv2d-5           [-1, 32, 64, 64]          16,416
         LeakyReLU-6           [-1, 32, 64, 64]               0
            Conv2d-7           [-1, 32, 64, 64]           1,056
         LeakyReLU-8           [-1, 32, 64, 64]               0
            Conv2d-9           [-1, 64, 32, 32]          32,832
        LeakyReLU-10           [-1, 64, 32, 32]               0
           Conv2d-11           [-1, 32, 32, 32]           2,080
        LeakyReLU-12           [-1, 32, 32, 32]               0
           Conv2d-13           [-1, 64, 16, 16]          65,600
        LeakyReLU-14           [-1, 64,

In [None]:
# Adapted from https://discuss.pytorch.org/t/unet-implementation/426

import torch
from torch import nn
import torch.nn.functional as F


class UNet(nn.Module):
    """U-Net built from ``UNetConvBlock`` / ``UNetUpBlock``.

    Args:
        in_channels: channels of the input.
        n_classes: channels produced by the final 1x1 convolution.
        depth: number of levels on the contracting path.
        wf: width factor; level ``i`` has ``2 ** (wf + i)`` channels.
        padding: forwarded to the conv blocks.
        batch_norm: forwarded to the conv blocks.
        up_mode: ``'upconv'`` or ``'upsample'``, forwarded to the up blocks.
    """

    def __init__(
        self,
        in_channels=1,
        n_classes=2,
        depth=5,
        wf=6,
        padding=False,
        batch_norm=False,
        up_mode='upconv',
    ):
        super().__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth

        widths = [2 ** (wf + level) for level in range(depth)]

        self.down_path = nn.ModuleList()
        prev_channels = in_channels
        for width in widths:
            self.down_path.append(UNetConvBlock(prev_channels, width, padding, batch_norm))
            prev_channels = width

        # The expanding path mirrors every level except the deepest one.
        self.up_path = nn.ModuleList()
        for width in reversed(widths[:-1]):
            self.up_path.append(UNetUpBlock(prev_channels, width, up_mode, padding, batch_norm))
            prev_channels = width

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        skips = []
        deepest = len(self.down_path) - 1
        for level, down in enumerate(self.down_path):
            x = down(x)
            if level == deepest:
                continue
            # Remember the pre-pooling activation as the bridge for the way up.
            skips.append(x)
            x = F.max_pool2d(x, 2)

        # Bridges are consumed deepest-first.
        for up, bridge in zip(self.up_path, reversed(skips)):
            x = up(x, bridge)

        return self.last(x)


class UNetConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU and an optional BatchNorm2d.

    ``padding`` is cast with ``int()`` and used as the convolution padding, so
    ``True`` keeps the spatial size and ``False`` shrinks it by 2 per convolution.
    """

    def __init__(self, in_size, out_size, padding, batch_norm):
        super().__init__()
        layers = []
        for n_in in (in_size, out_size):
            layers.append(nn.Conv2d(n_in, out_size, kernel_size=3, padding=int(padding)))
            layers.append(nn.ReLU())
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class UNetUpBlock(nn.Module):
    """Upsample by 2, concatenate the centre-cropped bridge, then apply a conv block.

    ``up_mode`` selects a transposed convolution (``'upconv'``) or bilinear
    upsampling followed by a 1x1 convolution (``'upsample'``).
    """

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super().__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            upsample = nn.Upsample(mode='bilinear', scale_factor=2)
            project = nn.Conv2d(in_size, out_size, kernel_size=1)
            self.up = nn.Sequential(upsample, project)

        # Input is in_size wide: out_size upsampled channels plus the bridge.
        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Return the central ``target_size`` (height, width) window of ``layer``."""
        _, _, full_h, full_w = layer.size()
        want_h, want_w = target_size[0], target_size[1]
        top = (full_h - want_h) // 2
        left = (full_w - want_w) // 2
        return layer[:, :, top:top + want_h, left:left + want_w]

    def forward(self, x, bridge):
        upsampled = self.up(x)
        cropped = self.center_crop(bridge, upsampled.shape[2:])
        return self.conv_block(torch.cat([upsampled, cropped], 1))
        
# Summarise a padded, bilinear-upsampling UNet with d_loc output channels, then free it.
model = UNet(in_channels=3, n_classes=d_loc, padding=True, up_mode='upsample').cuda()
torchsummary.summary(model, (3, 256, 256))
del model

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           1,792
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 256, 256]          36,928
              ReLU-4         [-1, 64, 256, 256]               0
     UNetConvBlock-5         [-1, 64, 256, 256]               0
            Conv2d-6        [-1, 128, 128, 128]          73,856
              ReLU-7        [-1, 128, 128, 128]               0
            Conv2d-8        [-1, 128, 128, 128]         147,584
              ReLU-9        [-1, 128, 128, 128]               0
    UNetConvBlock-10        [-1, 128, 128, 128]               0
           Conv2d-11          [-1, 256, 64, 64]         295,168
             ReLU-12          [-1, 256, 64, 64]               0
           Conv2d-13          [-1, 256, 64, 64]         590,080
             ReLU-14          [-1, 256,

In [None]:
from PIL import Image

class AnomalyDataset(torch.utils.data.dataset.Dataset):
    """Image / ground-truth-mask dataset described by a CSV index file.

    Expected layout under ``root_dir`` (whose last path component is the
    dataset name)::

        <root_dir>/<name>.csv      index with at least the columns
                                   'image_name', 'gt_name' and 'label'
        <root_dir>/img/            images referenced by 'image_name'
        <root_dir>/ground_truth/   masks referenced by 'gt_name'

    Args:
        root_dir: dataset directory.
        transform: callable applied to the PIL image; if falsy, the sample
            has no 'image' entry.
        gt_transform: callable applied to the PIL mask; if falsy, the sample
            has no 'gt' entry.
        **constraint: ``column=value`` filters; only CSV rows matching all of
            them are kept (e.g. ``type='train', label=0``).
    """

    # Categories whose masks are stored as <gt_name minus its last 9 characters>/000.png.
    _NESTED_GT_DATASETS = ('juice_bottle', 'pushpins', 'screw_bag',
                           'splicing_connectors', 'breakfast_box')

    def __init__(self, root_dir, transform=transforms.ToTensor(), gt_transform=transforms.ToTensor(), **constraint):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.gt_transform = gt_transform
        self.img_dir = os.path.join(self.root_dir, 'img')
        self.gt_dir = os.path.join(self.root_dir, 'ground_truth')
        # normpath drops a trailing separator, so ".../juice_bottle/" still
        # yields "juice_bottle" instead of an empty name.
        self.dataset = os.path.basename(os.path.normpath(self.root_dir))
        self.csv_file = os.path.join(self.root_dir, self.dataset + '.csv')
        self.frame_list = self._get_dataset(self.csv_file, constraint)

    def _get_dataset(self, csv_file, constraint):
        '''Apply filter based on the constraint dict on the dataset'''
        # keep_default_na=False keeps empty cells as '' so a missing gt_name is falsy.
        df = pd.read_csv(csv_file, keep_default_na=False)
        df = df.loc[(df[list(constraint)] == pd.Series(constraint)).all(axis=1)]
        return df

    def __len__(self):
        return len(self.frame_list)

    def __getitem__(self, idx):
        """Return a dict with 'label' plus, if the transforms are set, 'image' and 'gt'."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        item = self.frame_list.iloc[idx]
        image = Image.open(os.path.join(self.img_dir, item['image_name']))

        if item['gt_name']:
            # Decide the mask layout from this dataset's own name rather than
            # from a notebook-level global, so the class is self-contained.
            if self.dataset in self._NESTED_GT_DATASETS:
                gt_path = os.path.join(self.gt_dir, item['gt_name'][:-9], '000.png')
            else:
                gt_path = os.path.join(self.gt_dir, item['gt_name'])
            gt = Image.open(gt_path)
        else:
            # No mask listed: use an all-zero mask of the image size.
            gt = Image.new('L', image.size, color=0)

        sample = {'label': item['label']}

        if self.transform:
            sample['image'] = self.transform(image)

        if self.gt_transform:
            sample['gt'] = self.gt_transform(gt)

        return sample

In [None]:
def distillation_loss(output, target):
    """Mean of the squared L2 norm (taken along dim 1) of ``output - target``."""
    residual_norm = torch.norm(output - target, dim=1)
    return torch.mean(residual_norm ** 2)

def compactness_loss(output):
    """Sum of squared pairwise correlations between the rows of a 2-D ``output``.

    Every row is standardised with its own mean and (unbiased) standard
    deviation; only the strictly upper triangle of the resulting correlation
    matrix contributes, so each pair of rows is counted once.
    """
    _, n = output.size()
    row_mean = torch.mean(output, axis=1)
    row_std = torch.std(output, axis=1)
    # Work on the transpose so the per-row statistics broadcast along dim 1.
    zt = (output.T - row_mean) / row_std
    corr = torch.matmul(zt.T, zt) / (n - 1)
    return torch.triu(corr, diagonal=1).pow(2).sum()

def increment_mean_and_var(mu_N, var_N, N, batch):
    '''Merge the running mean / variance with the statistics of a new batch.

    Args:
        mu_N, var_N: running mean and (biased) variance so far.
        N: running weight so far (sum of the batch sizes seen).
        batch: tensor reduced over its first three dimensions, e.g.
            (batch, h, w, vector).

    Returns:
        (updated mean, updated variance, N + batch size)
    '''
    B = batch.size()[0]  # weight of the new batch = its batch size
    reduced_dims = [0, 1, 2]
    # one statistic per remaining (descriptor) entry
    mu_B = torch.mean(batch, dim=reduced_dims)
    S_B = B * torch.var(batch, dim=reduced_dims, unbiased=False)
    S_N = N * var_N
    total = N + B
    mu_NB = N / total * mu_N + B / total * mu_B
    # Combined sum of squared deviations around the merged mean.
    S_NB = S_N + S_B + B * mu_B**2 + N * mu_N**2 - total * mu_NB**2
    return mu_NB, S_NB / total, total

In [None]:
def student_loss(output, target):
    """Mean over (b, h, w) of the squared error summed across the last axis.

    Both tensors are laid out as (batch, h, w, vector).
    """
    squared_diff = (output - target) ** 2
    per_pixel = reduce(squared_diff, 'b h w vec -> b h w', 'sum')
    return torch.mean(per_pixel)

class Conv1x1(nn.Module):
    """1x1 convolution mapping ``input_size`` channels to ``d_loc`` channels.

    ``d_loc`` is read from the notebook globals when the layer is constructed.
    """

    def __init__(self, input_size):
        super().__init__()
        self.conv = nn.Conv2d(input_size, d_loc, kernel_size=(1, 1), stride=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

In [None]:
# Training Global Encoder with MVTec LOCO (target) dataset
import torchvision
from einops import reduce, rearrange
import pandas as pd

# Local teacher: weights are loaded from a checkpoint and the network is put in eval mode.
# It is not handed to the optimizer below.
localEncoder = LocalBranchEncoder()
# trained local encoder
localEncoder.load_state_dict(torch.load(f"/content/drive/MyDrive/LOCO_AD/local_encoder_fixed.pt"))
localEncoder.eval().to(torch.device("cuda:0"))
localEncoder = nn.DataParallel(localEncoder, output_device=1)

# Networks that are trained in this notebook (see the optimizer below).
globalEncoder = GlobalBranchEncoder()
localRegressor = LocalBranchEncoder()
globalRegressor = UNet(in_channels=3, n_classes=d_glo, padding=True, up_mode='upsample')
# 1x1 projection from 512 to d_loc channels, used inside gcad_loss.
# NOTE(review): conv11 is not passed to the optimizer, so it keeps its initial
# weights for the whole run — confirm that a fixed projection is intended.
conv11 = nn.DataParallel(Conv1x1(512).to(torch.device("cuda:0")), output_device=1)


globalEncoder = nn.DataParallel(globalEncoder.cuda(), output_device=1)

localRegressor = nn.DataParallel(localRegressor.cuda(), output_device=1) 

globalRegressor = nn.DataParallel(globalRegressor.cuda(), output_device=1)

# A single Adam optimizer updates the global encoder and both regressors.
optimizer = torch.optim.Adam([{'params': globalEncoder.parameters()},
                              {'params': localRegressor.parameters()},
                              {'params': globalRegressor.parameters()}],
                             lr = 1e-4,
                             weight_decay=1e-5)

# Only CSV rows with type == 'train' and label == 0 are kept (passed as **constraint).
dataset = AnomalyDataset(root_dir=f"/content/drive/MyDrive/LOCO_AD/{target_dataset}",
                         transform = transforms.Compose([
                            transforms.Resize((256, 256)),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.225, 0.225, 0.225))]),
                        type='train',
                        label=0                        
                        )

# One pass over the training set to accumulate the running mean / variance of the
# local teacher's dense (fdfe) output; the training loop uses t_mu and t_var.
dataloader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
print("Preprocessing of training dataset")
with torch.no_grad():
    t_mu, t_var, N = 0, 0, 0
    for i, batch in tqdm(enumerate(dataloader)):
        inputs = batch['image'].cuda()
        t_out = localEncoder(inputs,fdfe=True)
        t_mu, t_var, N = increment_mean_and_var(t_mu, t_var, N, t_out)

Preprocessing of training dataset


42it [03:54,  5.59s/it]


In [None]:
def gcad_loss(loc_tch_output, loc_tch_output_patch, 
              glo_tch_output_up, glo_tch_output, 
              loc_stu_output, glo_stu_output):
    """Sum of three dimension-normalised mean squared distances.

    * ``l_kd``  - projected dense local-teacher features vs. the upsampled
      global-teacher output, divided by ``d_loc``;
    * ``l_loc`` - local teacher vs. local student on patches, divided by ``d_loc``;
    * ``l_glo`` - global teacher vs. global student, divided by ``d_glo``.

    Uses the notebook globals ``conv11``, ``d_loc`` and ``d_glo``.
    """
    def mean_sq_dist(a, b):
        # squared L2 norm along dim 1, averaged over everything else
        return torch.mean(torch.norm(a - b, dim=1)**2)

    # Channels-last -> channels-first, then project with the 1x1 convolution.
    projected = conv11(rearrange(loc_tch_output, 'b h w vec -> b vec h w'))

    l_kd = mean_sq_dist(projected, glo_tch_output_up)
    l_loc = mean_sq_dist(loc_tch_output_patch, loc_stu_output)
    l_glo = mean_sq_dist(glo_tch_output, glo_stu_output)

    return (l_kd / d_loc) + (l_loc / d_loc) + (l_glo / d_glo)

In [None]:
%%time
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
random_crop = transforms.RandomCrop((33, 33))
min_running_loss = np.inf
stop_count = 0

for epoch in range(epochs):
    running_loss = 0.0
    length = len(dataloader)
    for i, batch in tqdm(enumerate(dataloader)):
        optimizer.zero_grad()
        inputs0 = batch['image'].cuda()
        inputs1 = batch['image'].cuda()
        cropped_inputs = random_crop(batch['image']).cuda()
        with torch.no_grad():
            targets = localEncoder(inputs0, fdfe=True) - t_mu / torch.sqrt(t_var)
        loc_tch_patch = localEncoder(cropped_inputs)    
        loc_stu = localRegressor(cropped_inputs)
        
        glo_tch = globalEncoder(inputs1)
        glo_tch_up = globalEncoder(inputs1, upsample=True)
        glo_stu = globalRegressor(inputs1)
        
        loss = gcad_loss(targets, loc_tch_patch, glo_tch_up, glo_tch, loc_stu, glo_stu)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
    if running_loss < min_running_loss and epoch > 0:
        torch.save(globalEncoder.module.state_dict(), 
                   f"/content/drive/MyDrive/LOCO_AD/gcad/{target_dataset}_{time_stamp}/global_encoder.pt")
        torch.save(localRegressor.module.state_dict(), 
                   f"/content/drive/MyDrive/LOCO_AD/gcad/{target_dataset}_{time_stamp}/local_regression.pt")
        torch.save(globalRegressor.module.state_dict(), 
                   f"/content/drive/MyDrive/LOCO_AD/gcad/{target_dataset}_{time_stamp}/global_regression.pt")
        print(f'Epoch #{epoch}  Loss decreased : {round(running_loss, 6)},  Model saved')
        min_running_loss = running_loss
        stop_count = 0
    else:
        stop_count += 1
    if stop_count > early_stopping:
        break
torch.cuda.empty_cache()

In [None]:
# Drop the training DataLoader now that training is finished.
del dataloader

In [None]:
# Anomaly scoring

import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from cv2 import GaussianBlur
import cv2


def get_error_map(students_pred, teacher_pred):
    """Per-pixel squared distance between the mean student and the teacher.

    students_pred: (batch, student_id, h, w, vector)
    teacher_pred:  (batch, h, w, vector)
    Returns a (batch, h, w) tensor.
    """
    # Average the ensemble of students.
    ensemble_mean = reduce(students_pred, 'b id h w vec -> b h w vec', 'mean')
    squared_diff = (ensemble_mean - teacher_pred) ** 2
    # Sum over the descriptor axis: squared teacher~student distance per pixel.
    return reduce(squared_diff, 'b h w vec -> b h w', 'sum')

def get_error_map_glo(students_pred, teacher_pred):
    """Channels-first variant of ``get_error_map``.

    students_pred: (batch, student_id, c, h, w)
    teacher_pred:  (batch, c, h, w)
    Returns a (batch, h, w) tensor with the squared error summed over channels.
    """
    # Average the ensemble of students (axis 1).
    ensemble_mean = reduce(students_pred, 'b id c h w -> b c h w', 'mean')
    squared_diff = (ensemble_mean - teacher_pred) ** 2
    return reduce(squared_diff, 'b c h w -> b h w', 'sum')


@torch.no_grad()
def calibrate(teacher, students, dataloader, device):
    """Estimate the normalisation statistics used for local anomaly scoring.

    The original definition had no function name (a SyntaxError); it is named
    ``calibrate`` to pair with ``calibrate_glo``.

    Args:
        teacher: network called as ``teacher(inputs, fdfe=True)``.
        students: iterable of networks called the same way.
        dataloader: yields dicts with an 'image' tensor.
        device: device the images are moved to.

    Returns:
        {"teacher": {"mu", "var"}, "students": {"mu", "var", "max"}} where the
        teacher entries are the running mean / variance of its output and the
        student entries describe the teacher~student error map.
    """
    # Pass 1: running mean / variance of the raw teacher output.
    print('calibrating teacher on Student dataset.')
    t_mu, t_var, t_N = 0, 0, 0
    for _, batch in tqdm(enumerate(dataloader)):
        inputs = batch['image'].to(device)
        t_out = teacher(inputs, fdfe=True)
        t_mu, t_var, t_N = increment_mean_and_var(t_mu, t_var, t_N, t_out)

    # Pass 2: statistics of the error between the standardised teacher and the students.
    print('calibrating scoring parameters on Student dataset.')
    max_err = 0
    mu_err, var_err, N_err = 0, 0, 0

    for _, batch in tqdm(enumerate(dataloader)):
        inputs = batch['image'].to(device)

        t_out = (teacher(inputs, fdfe=True) - t_mu) / torch.sqrt(t_var)
        s_out = torch.stack([student(inputs, fdfe=True) for student in students], dim=1)

        s_err = get_error_map(s_out, t_out)
        mu_err, var_err, N_err = increment_mean_and_var(mu_err, var_err, N_err, s_err)

        max_err = max(max_err, torch.max(s_err))

    return {"teacher" : {"mu": t_mu, "var": t_var},
            "students": {"mu": mu_err, "var": var_err, "max": max_err}
           }

@torch.no_grad()
def calibrate_glo(teacher, students, dataloader, device):
    """Estimate the normalisation statistics used for global anomaly scoring.

    Counterpart of the local calibration for channels-first networks that are
    called as ``net(inputs)`` and return (batch, c, h, w).

    Returns:
        {"teacher": {"mu", "var"}, "students": {"mu", "var", "max"}}. The
        teacher statistics are per channel with shape (c, 1, 1) so that they
        broadcast against a (batch, c, h, w) output.
    """
    # Pass 1: per-channel running mean / variance of the teacher output.
    print('calibrating teacher on Student dataset.')
    t_mu, t_var, t_N = 0, 0, 0
    for _, batch in tqdm(enumerate(dataloader)):
        inputs = batch['image'].to(device)
        # increment_mean_and_var reduces dims 0-2, so move channels last first;
        # on a channels-first tensor it would reduce batch, channel and height
        # and leave one statistic per image column.
        t_out = teacher(inputs).permute(0, 2, 3, 1)
        t_mu, t_var, t_N = increment_mean_and_var(t_mu, t_var, t_N, t_out)
    if torch.is_tensor(t_mu):  # still the int 0 when the dataloader was empty
        t_mu = t_mu.view(-1, 1, 1)
        t_var = t_var.view(-1, 1, 1)

    # Pass 2: statistics of the teacher~student error map.
    # NOTE(review): the training loop compares the un-normalised teacher and
    # student outputs, whereas scoring standardises the teacher only — confirm.
    print('calibrating scoring parameters on Student dataset.')
    max_err = 0
    mu_err, var_err, N_err = 0, 0, 0

    for _, batch in tqdm(enumerate(dataloader)):
        inputs = batch['image'].to(device)

        t_out = (teacher(inputs) - t_mu) / torch.sqrt(t_var)
        s_out = torch.stack([student(inputs) for student in students], dim=1)

        # Use the channels-first error map, i.e. the same quantity that
        # get_score_map_glo normalises with these statistics.
        s_err = get_error_map_glo(s_out, t_out)
        mu_err, var_err, N_err = increment_mean_and_var(mu_err, var_err, N_err, s_err)

        max_err = max(max_err, torch.max(s_err))

    return {"teacher" : {"mu": t_mu, "var": t_var},
            "students": {"mu": mu_err, "var": var_err, "max": max_err}
           }


@torch.no_grad()
def get_score_map(inputs, teacher, students, params):
    """Standardised local anomaly map for ``inputs``.

    ``params`` is a calibration dict with 'teacher' and 'students' entries,
    each holding 'mu' and 'var'. Teacher and students are run through their
    dense ``fdfe`` path.
    """
    teacher_feats = teacher.fdfe(inputs)
    t_out = (teacher_feats - params['teacher']['mu']) / torch.sqrt(params['teacher']['var'])

    student_feats = [student.fdfe(inputs) for student in students]
    s_out = torch.stack(student_feats, dim=1)

    s_err = get_error_map(s_out, t_out)
    # z-score the error map with the calibration statistics
    return (s_err - params['students']['mu']) / torch.sqrt(params['students']['var'])

@torch.no_grad()
def get_score_map_glo(inputs, teacher, students, params):
    """Standardised global anomaly map for ``inputs``.

    Same as the local score map, but the networks are called directly and the
    channels-first error map is used. ``params`` holds 'teacher' and
    'students' entries with 'mu' and 'var'.
    """
    teacher_out = teacher(inputs)
    t_out = (teacher_out - params['teacher']['mu']) / torch.sqrt(params['teacher']['var'])

    student_outs = [student(inputs) for student in students]
    s_out = torch.stack(student_outs, dim=1)

    s_err = get_error_map_glo(s_out, t_out)
    # z-score the error map with the calibration statistics
    return (s_err - params['students']['mu']) / torch.sqrt(params['students']['var'])


def visualize(img, gt, score_map, max_score, i, mode):
    """Plot image, ground truth and anomaly map side by side and save the figure.

    Args:
        img: image accepted by ``plt.imshow`` (shown in the first panel).
        gt: ground-truth tensor; it is rounded with ``torch.round`` before
            display.
        score_map: CPU tensor of anomaly scores; it is min-max normalised,
            converted to NumPy and Gaussian-blurred before display.
        max_score: currently unused (kept for interface compatibility; it
            was only referenced by a disabled ``plt.clim`` call).
        i: sample index, used in the output file name.
        mode: sub-directory name under ``results`` for the saved figure.

    Side effects:
        Writes ``{target_dataset}_{i}.png`` under
        ``.../{target_dataset}_{time_stamp}/results/{mode}/`` (creating the
        directory if needed) and shows the figure.
    """
    fig = plt.figure(figsize=(13, 3))
    plt.subplot(1, 3, 1)
    plt.imshow(img)
    plt.title('Original image')

    plt.subplot(1, 3, 2)
    plt.imshow(torch.round(gt), cmap='gray')
    plt.title('Ground truth anomaly')

    plt.subplot(1, 3, 3)

    # Min-max normalise; a constant map would divide by zero, so map it to 0.
    lowest, highest = score_map.min(), score_map.max()
    span = highest - lowest
    if span > 0:
        score_map = (score_map - lowest) / span
    else:
        score_map = torch.zeros_like(score_map)
    score_map = score_map.numpy()
    # The third positional argument of cv2.GaussianBlur is sigmaX, not the
    # border type: sigmaX=0 lets OpenCV derive sigma from the 7x7 kernel and
    # the border mode is passed by keyword.
    score_map = GaussianBlur(score_map, (7, 7), 0, borderType=cv2.BORDER_DEFAULT)
    plt.imshow(score_map, cmap='jet', alpha=0.5, interpolation='none')
    plt.colorbar(extend='both')
    plt.title('Anomaly map')

    # The configuration cell does not create this directory, so make sure it
    # exists before saving.
    out_dir = f'/content/drive/MyDrive/LOCO_AD/gcad/{target_dataset}_{time_stamp}/results/{mode}'
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f'{out_dir}/{target_dataset}_{i}.png')
    plt.show(block=True)
    # Release the figure: this function is called once per test image.
    plt.close(fig)

In [None]:
device = torch.device("cuda:0")

# Teacher network (local branch)
localEncoder = LocalBranchEncoder()
localEncoder.load_state_dict(torch.load("/content/drive/MyDrive/LOCO_AD/local_encoder_fixed.pt", map_location=device))
localEncoder.eval().to(device)

# Students networks
localRegressor = [LocalBranchEncoder()]
localRegressor[0].load_state_dict(torch.load(f"/content/drive/MyDrive/LOCO_AD/gcad/{target_dataset}_{time_stamp}/local_regression.pt", map_location=device))
localRegressor[0].eval().to(device)

globalEncoder = GlobalBranchEncoder()
globalEncoder.load_state_dict(torch.load(f"/content/drive/MyDrive/LOCO_AD/gcad/{target_dataset}_{time_stamp}/global_encoder.pt", map_location=device))
globalEncoder.eval().to(device)

globalRegressor = [UNet(in_channels=3, n_classes=d_glo, padding=True, up_mode='upsample')]
# The path previously started with a stray "C" ("C/content/..."), which made
# the checkpoint impossible to find.
globalRegressor[0].load_state_dict(torch.load(f"/content/drive/MyDrive/LOCO_AD/gcad/{target_dataset}_{time_stamp}/global_regression.pt", map_location=device))
# Put the global regressor in inference mode like the other three networks.
globalRegressor[0].eval().to(device)

# calibration on anomaly-free dataset
calib_dataset = AnomalyDataset(
                root_dir=f'/content/drive/MyDrive/LOCO_AD/{target_dataset}',
                transform=transforms.Compose([
                    transforms.Resize((256, 256)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.225, 0.225, 0.225))]),
                type='train',
                label=0)

calib_dataloader = DataLoader(calib_dataset, 
                               batch_size=8, 
                               shuffle=False)

# Normalisation statistics for the local and the global branch respectively.
params_loc = calibrate(localEncoder, localRegressor, calib_dataloader, device)
params_glo = calibrate_glo(globalEncoder, globalRegressor, calib_dataloader, device)


In [None]:
# Load testing data
test_dataset = AnomalyDataset(
                root_dir=f"/content/drive/MyDrive/LOCO_AD/{target_dataset}",
                transform=transforms.Compose([
                    transforms.Resize((256, 256)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.225, 0.225, 0.225))]),
                gt_transform=transforms.Compose([
                    transforms.Resize((256, 256)),
                    transforms.ToTensor()]),
                type='test')

test_dataloader = DataLoader(test_dataset, 
                             batch_size=1, 
                             shuffle=False)

# Approximate inverse of Normalize(mean=0.5, std=0.225):
# x = x_norm * 0.225 + 0.5  ==  (x_norm - (-2.22)) / 4.44
unorm = transforms.Normalize((-2.22, -2.22, -2.22), (4.44, 4.44, 4.44)) # get back to original image

# Sum of the standardised maximum calibration errors of both branches.
# The parentheses matter: without them the second term is parsed as a
# separate (discarded) statement and only the local term is kept.
max_score = (
    (params_loc['students']['max'] - params_loc['students']['mu']) / torch.sqrt(params_loc['students']['var'])
    + (params_glo['students']['max'] - params_glo['students']['mu']) / torch.sqrt(params_glo['students']['var'])
)

# Build anomaly map
# Per-image flattened scores / labels are collected in lists and concatenated
# once at the end (repeated np.concatenate in the loop is quadratic).
score_chunks_loc, score_chunks_glo, gt_chunks = [], [], []

# The test set is traversed exactly once; each image is scored once and then
# visualised for every mode. (A single shared iterator re-used across modes
# is exhausted after the first mode and would also duplicate scores/labels.)
for i, batch in enumerate(test_dataloader):
    inputs = batch['image'].to(device)
    gt = batch['gt'].cpu()

    score_map_loc = get_score_map(inputs, localEncoder, localRegressor, params_loc).cpu()
    score_map_glo = get_score_map_glo(inputs, globalEncoder, globalRegressor, params_glo).cpu()

    score_chunks_loc.append(rearrange(score_map_loc, 'b h w -> (b h w)').numpy())
    score_chunks_glo.append(rearrange(score_map_glo, 'b h w -> (b h w)').numpy())
    gt_chunks.append(rearrange(gt, 'b c h w -> (b c h w)').numpy())

    # 'local' shows the local-branch map and 'global' the global-branch map.
    score_maps = {
        'both': score_map_loc + score_map_glo,
        'local': score_map_loc,
        'global': score_map_glo,
    }

    img_in = rearrange(unorm(inputs).cpu(), 'b c h w -> b h w c')
    gt_in = rearrange(gt, 'b c h w -> b h w c')

    b = 0  # batch_size is 1, so only the first sample exists
    # A dedicated loop variable avoids clobbering the notebook-level `mode`.
    for vis_mode, score_map in score_maps.items():
        visualize(img_in[b, :, :, :].squeeze(), 
                  gt_in[b, :, :, :].squeeze(), 
                  score_map[b, :, :].squeeze(), 
                  max_score, i, vis_mode)

# Flat per-pixel arrays consumed by the ROC cells (float64, as before).
y_score_loc = np.concatenate(score_chunks_loc).astype(np.float64) if score_chunks_loc else np.array([])
y_score_glo = np.concatenate(score_chunks_glo).astype(np.float64) if score_chunks_glo else np.array([])
y_true = np.concatenate(gt_chunks).astype(np.float64) if gt_chunks else np.array([])

In [None]:
print("*********** loop done ***********")
# Combined per-pixel score: local + global, then min-max scaled to [0, 1].
y_score = y_score_loc + y_score_glo
y_score = (y_score - y_score.min()) / (y_score.max() - y_score.min()) # min max norm (0~1)

# AUC ROC
# NOTE(review): astype(int) truncates, so only ground-truth pixels equal to
# 1.0 count as positive, whereas visualize() rounds the mask -- confirm which
# binarisation is intended.
fpr, tpr, thresholds = roc_curve(y_true.astype(int), (y_score))
plt.figure(figsize=(13, 3))
plt.plot(fpr, tpr, 'r', label="ROC")
plt.plot(fpr, fpr, 'b', label="random")
plt.title(f'ROC AUC: {auc(fpr, tpr)}')
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.legend()
plt.grid()
# Save before showing: once the figure has been shown, savefig would write an
# empty canvas.
plt.savefig(f"/content/drive/MyDrive/LOCO_AD/gcad/{target_dataset}_{time_stamp}/results/{target_dataset}_roc.png")
plt.show()

In [None]:
# ROC curve for the inverted score (1 - y_score), drawn on an explicit Axes.
fpr, tpr, thresholds = roc_curve(y_true.astype(int), 1-y_score)

roc_fig, roc_ax = plt.subplots(figsize=(13, 3))
roc_ax.plot(fpr, tpr, 'r', label="ROC")
roc_ax.plot(fpr, fpr, 'b', label="random")  # diagonal reference line
roc_ax.set_title(f'ROC AUC: {auc(fpr, tpr)}')
roc_ax.set_xlabel('FPR')
roc_ax.set_ylabel('TPR')
roc_ax.legend()
roc_ax.grid()
plt.show()

# local_branch_anomaly_score()