In [1]:
import torch
import torch.nn as nn
from torch import optim
import torchvision.datasets as dset
import torchvision.transforms as transforms

from utils import device_setting, seed_torch, image_check, Anomaly_score
from data_manager import AnomalyDataManager
from model import Generator, Discriminator
# from resnet import ResNet18
# from trainer import Trainer

import matplotlib.pyplot as plt
import torchvision.utils as v_utils

from sklearn.manifold import TSNE
import numpy as np
import matplotlib

import copy
from statistics import mean

In [2]:
data_dir = './data'
dataset = 'mnist'
batch_size = 25
lr = 0.001
momentum = 0.9
max_epoch = 200
criterion = nn.MSELoss(reduction='none')
feature_dim = 10
gpu = 0
device = device_setting(gpu=gpu)
anomaly_setting = 1
seed = 0

In [3]:
data_manager = AnomalyDataManager(dataset=dataset, data_dir=data_dir, trans=None, seed=seed, only_normal=False, anomaly_setting=anomaly_setting, data_num=2000)
num_classes = data_manager.get_num_classes()
dataloader_dict = data_manager.build_dataloader(batch_size)
seed_torch(0)
generator = nn.DataParallel(Generator(), device_ids=[0])
discriminator = nn.DataParallel(Discriminator(),device_ids=[0])
optimizer = optim.Adam(generator.parameters(), lr=lr)
criterion = nn.MSELoss(reduction='none')

# generator.eval()
# discriminator.eval()

train data: 1600


In [4]:
seed_torch(0)
load_epoch = 999
# model_path = './saved_model/anomaly_setting_{}/generator_epoch{}.pkl'.format(anomaly_setting, load_epoch)
gen_path = './saved_model/generator_epoch{}.pkl'.format(load_epoch)
generator.load_state_dict(torch.load(gen_path))

dis_path = './saved_model/discriminator_epoch{}.pkl'.format(load_epoch)
discriminator.load_state_dict(torch.load(dis_path))

<All keys matched successfully>

In [5]:
z_batch = 25
z = torch.randn(z_batch, 100, requires_grad=True, device=device)
z_optimizer = torch.optim.Adam([z],lr=0.1)


In [6]:
z.size()

torch.Size([25, 100])

In [7]:
gen_fake = generator(z)

In [8]:
v_utils.save_image(gen_fake.data[0:25],"./images/anomaly_setting_{}/gen_epoch{}_update0.png".format(anomaly_setting, 
        load_epoch+1), nrow=5)

In [9]:
def Anomaly_score(x, G_z, discriminator, Lambda=0.1):
    _,x_feature = discriminator(x)
    _,G_z_feature = discriminator(G_z)
    
    residual_loss = torch.sum(torch.abs(x - G_z))
    discrimination_loss = torch.sum(torch.abs(x_feature - G_z_feature))
    # residual_loss = torch.abs(x - G_z)
    # discrimination_loss = torch.abs(x_feature - G_z_feature)
    
    total_loss = (1 - Lambda) * residual_loss + Lambda*discrimination_loss
    # print('residual loss: ', residual_loss.item(), ' disloss: ', discrimination_loss.item())
    return total_loss

In [21]:
update_num = 3000

In [22]:
save_update = [1, 10, 30, 50]
s = [100 * i for i in range(1, 31)]
save_update = save_update + s

In [12]:
seed_torch(0)
for i, (data, label) in enumerate(dataloader_dict['train']):
    v_utils.save_image(data.cpu().reshape(-1, 1, 28, 28)[0:25],"./images/anomaly_setting_{}/normal_data.png".format(anomaly_setting, load_epoch+1), nrow=5)
    data = data.to(device)
    for updt in range(update_num):
        gen_fake = generator(z)
        loss = Anomaly_score(data, gen_fake, discriminator, Lambda=0.01)
        print("data_iter: {}, update_count: {}, loss: {}".format(i, updt, loss))
        loss.backward()
        z_optimizer.step()
        if updt+1 in save_update:
            print("save: {}".format(updt))
            v_utils.save_image(gen_fake.data[0:25],"./images/anomaly_setting_{}/normal_gen_epoch{}_update{}.png".format(anomaly_setting, 
                load_epoch+1, updt+1), nrow=5)

    break

data_iter: 0, update_count: 0, loss: 5757.28759765625
save: 0
data_iter: 0, update_count: 1, loss: 5145.619140625
data_iter: 0, update_count: 2, loss: 4731.65234375
data_iter: 0, update_count: 3, loss: 4574.86376953125
data_iter: 0, update_count: 4, loss: 4584.587890625
data_iter: 0, update_count: 5, loss: 4588.17822265625
data_iter: 0, update_count: 6, loss: 4550.76708984375
data_iter: 0, update_count: 7, loss: 4497.7470703125
data_iter: 0, update_count: 8, loss: 4431.73095703125
data_iter: 0, update_count: 9, loss: 4355.3701171875
save: 9
data_iter: 0, update_count: 10, loss: 4282.6025390625
data_iter: 0, update_count: 11, loss: 4217.294921875
data_iter: 0, update_count: 12, loss: 4168.609375
data_iter: 0, update_count: 13, loss: 4132.1240234375
data_iter: 0, update_count: 14, loss: 4109.54150390625
data_iter: 0, update_count: 15, loss: 4098.41064453125
data_iter: 0, update_count: 16, loss: 4089.138671875
data_iter: 0, update_count: 17, loss: 4081.162353515625
data_iter: 0, update_co

KeyboardInterrupt: 

In [13]:
# 異常データの生成
label = dataloader_dict['test'].dataset.targets
if anomaly_setting == 0:
    normal_mask = (label == 4)
    anomaly_mask = (label != 4)
elif anomaly_setting == 1:
    normal_mask = (label % 2 == 0)
    anomaly_mask = (label % 2 != 0)
elif anomaly_setting == 2:
    normal_mask = (label != 9)
    anomaly_mask = (label == 9)
else:
    normal_mask = (label >= 0)
    anomaly_mask = (label < 0)

In [14]:
label.unique()

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [15]:
label[anomaly_mask].unique()

tensor([0, 1, 2, 3, 5, 6, 7, 8, 9])

In [16]:
dataloader_dict['test'].dataset.data = dataloader_dict['test'].dataset.data[anomaly_mask]
dataloader_dict['test'].dataset.targets = dataloader_dict['test'].dataset.targets[anomaly_mask]

In [17]:
dataloader_dict['test'].dataset.targets.unique()

tensor([0, 1, 2, 3, 5, 6, 7, 8, 9])

In [19]:
z_batch = 25
z = torch.randn(z_batch, 100, requires_grad=True, device=device)
z_optimizer = torch.optim.Adam([z],lr=0.1)

In [20]:
seed_torch(0)
for i, (data, label) in enumerate(dataloader_dict['test']):
    v_utils.save_image(data.cpu().reshape(-1, 1, 28, 28)[0:25],"./images/anomaly_setting_{}/anomaly_data.png".format(anomaly_setting, load_epoch+1), nrow=5)
    data = data.to(device)
    for updt in range(update_num):
        gen_fake = generator(z)
        loss = Anomaly_score(data, gen_fake, discriminator, Lambda=0.01)
        print("data_iter: {}, update_count: {}, loss: {}".format(i, updt, loss))
        loss.backward()
        z_optimizer.step()
        if updt+1 in save_update:
            print("save: {}".format(updt))
            v_utils.save_image(gen_fake.data[0:25],"./images/anomaly_setting_{}/anomaly_gen_epoch{}_update{}.png".format(anomaly_setting, 
                load_epoch+1, updt+1), nrow=5)

    break

data_iter: 0, update_count: 0, loss: 6906.74560546875
save: 0
data_iter: 0, update_count: 1, loss: 6451.9560546875
data_iter: 0, update_count: 2, loss: 6233.271484375
data_iter: 0, update_count: 3, loss: 6013.10546875
data_iter: 0, update_count: 4, loss: 5900.2392578125
data_iter: 0, update_count: 5, loss: 5875.53759765625
data_iter: 0, update_count: 6, loss: 5825.6796875
data_iter: 0, update_count: 7, loss: 5779.04736328125
data_iter: 0, update_count: 8, loss: 5728.05517578125
data_iter: 0, update_count: 9, loss: 5656.86767578125
save: 9
data_iter: 0, update_count: 10, loss: 5586.87255859375
data_iter: 0, update_count: 11, loss: 5529.8974609375
data_iter: 0, update_count: 12, loss: 5486.5947265625
data_iter: 0, update_count: 13, loss: 5453.2080078125
data_iter: 0, update_count: 14, loss: 5421.69970703125
data_iter: 0, update_count: 15, loss: 5399.78759765625
data_iter: 0, update_count: 16, loss: 5388.17822265625
data_iter: 0, update_count: 17, loss: 5374.25439453125
data_iter: 0, upda

KeyboardInterrupt: 

In [None]:
gen = y_hat.cpu().reshape(-1, 1, 28, 28)
gen.shape

In [None]:
v_utils.save_image(data.cpu().reshape(-1, 1, 28, 28)[0:25],"../results/anomaly_setting_{}/emb_size_{}/anomaly_data.png".format(anomaly_setting, feature_dim, load_epoch+1), nrow=5)
v_utils.save_image(gen.data[0:25],"../results/anomaly_setting_{}/emb_size_{}/anomaly_gen_epoch{}.png".format(anomaly_setting, feature_dim, load_epoch+1), nrow=5)

In [None]:
# しきい値を変えた性能テスト
miss_data_dict = {}
miss_loss_dict = {}
for th in [0.01, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40]:
    normal_acc = 0
    losses = []
    miss_data = []
    miss_loss = []
    model.threshold = th
    for data, label in dataloader_dict['test']:
        data = data.reshape(-1, data_manager.input_dim)
        data = data.to(device)
        label = label.to(device)
        if anomaly_setting == 0:
            normal_mask = (label == 4)
            anomaly_mask = (label != 4)
        elif anomaly_setting == 1:
            normal_mask = (label == 0)
            anomaly_mask = (label != 0)
        elif anomaly_setting == 2:
            normal_mask = (label != 9)
            anomaly_mask = (label == 9)
        else:
            normal_mask = (label >= 0)
            anomaly_mask = (label < 0)
        
        label[normal_mask] = 0
        label[anomaly_mask] = 1
        
        # label = label.to(device)
        pred_label, loss = model.inference(data, criterion, device)
        losses.append(mean(loss.cpu().detach().tolist()))
        normal_acc += torch.sum(pred_label == label).item()
        for i, (pl, lbl) in enumerate(zip(pred_label, label)):
            if pl != lbl:
                miss_data.append(data[i].cpu())
                miss_loss.append(loss[i].cpu().detach().numpy())
                

    normal_acc = normal_acc / len(dataloader_dict['test'].dataset)
    # plot, print result
    print('Th.:{:.4f} Test Normal accuracy: {:4f} Test Loss: {:4f} '.format(th, normal_acc, mean(losses)))
    miss_loss_dict[th] = copy.deepcopy(miss_loss)
    miss_data_dict[th] = copy.deepcopy(miss_data)