
Q&A #14

Open · mjack3 opened this issue Feb 9, 2022 · 72 comments

@mjack3 commented Feb 9, 2022

Hello.

I would like to open this issue to talk about this project. I am also interested in developing this project, and it would be great to share information, since the paper doesn't give much detail about the implementation and official code is not available.

If you agree with this initiative, we could first simplify the project to use Wide-ResNet50 in order to get results comparable with previous research. I would like to start from the beginning of the paper, where it says:

For ResNet, we directly use the features of the last layer in the first three blocks, and put these features into three corresponding FastFlow model.

This makes me think that in the implementation we need to use the features after the input layer, layer1, and layer2. That way, Table 6 makes sense:

[image: Table 6 from the paper]

But I cannot see how to combine this information in a way that is consistent with the following:

In the forward process, it takes the feature map from the backbone network as input
[image: figure from the paper]

Depending on which part you read, it seems that either just one feature map or three are taken.

@AlessioGalluccio (Owner) commented Feb 10, 2022

Hi @mjack3,
I'm really glad to find some help in this project. Thank you very much for your proposal, I accept. This paper is quite obscure.
The problem you are addressing is explained in paragraph 4.7:

For ResNet18 and Wide-ResNet50-2, we directly use the features of the last layer in the first three blocks, put these features into the 2D flow model to obtain their respective anomaly detection and localization results, and finally take the average value as the final result.

I think that the paper wants us to build three different models and average their anomaly score. But how do we compute this anomaly score? This is the question that I can't solve. In the introduction we can find that:

We propose a 2D normalizing flow denoted as FastFlow for anomaly detection and localization with fully convolutional
networks and two-dimensional loss function to effectively model global and local distribution.

But I can't find how this two-dimensional loss is defined. If you have an idea for a good two-dimensional loss for this problem, I'm all ears.
Best,
Alessio

@mjack3 (Author) commented Feb 11, 2022

Hmm, yes, you are right; we definitely need to create 3 FastFlow models. I will try. By the way, you can find my implementation here: https://github.com/mjack3/EasyFastFlow. Feel free to use whatever you want.

@mjack3 (Author) commented Feb 11, 2022

Have you tried contacting some of the main authors of the paper? I googled them but didn't find their emails.

@Howeng98

@mjack3 Hi, have you taken a look at CFLOW-AD? It is also implemented with a flow model; maybe it can help you understand how the 3 FastFlow modules work. I'm trying to implement FastFlow by modifying CFlow-AD. If you need any help or want to discuss, I would be glad to help (if I can).

@mjack3 (Author) commented Feb 11, 2022

@Howeng98 you are welcome =)

Yes, I also looked at the CFLOW-AD code, but I am not sure whether here we need to create 3 individual FastFlow models and train them with 3 optimizers (one per FastFlow), or do something similar to CFLOW-AD.
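
For what it's worth, one common way to handle this with a single optimizer (a sketch under my own assumptions, not from the paper; build_flow_head is a hypothetical constructor for one per-level flow):

import torch
import torch.nn as nn

# Hypothetical setup: wrap the three per-level flow heads in a ModuleList
# and train all of them jointly with one optimizer.
flows = nn.ModuleList([build_flow_head(c) for c in (256, 512, 1024)])  # build_flow_head is hypothetical
optimizer = torch.optim.Adam(flows.parameters(), lr=1e-3, weight_decay=1e-5)

The Adam settings (lr=1e-3, weight_decay=1e-5) are the ones the FastFlow paper reports.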

@AlessioGalluccio (Owner)

@mjack3 I tried to contact Yushuang Wu through a university e-mail I found, but I got no answer. I haven't found the e-mails of the other authors.

@mjack3 (Author) commented Feb 11, 2022

When did you contact them, @AlessioGalluccio?

@rafalfirlejczyk

Hi @mjack3,
Can you please share your implementation of FastFlow? The link seems to be deactivated.
Thanks

@mjack3 (Author) commented Feb 14, 2022

Currently I am obliged to keep the code private because of my job contract. I hope to open it soon. Anyway, I will share information in this same thread if needed :)

@maaft commented Mar 4, 2022

@AlessioGalluccio just a small remark: for the anomaly score calculation (global and pixel-wise) you need to use p(z), and not z, which you are currently using.

You can estimate log p(z) (and therefore p(z)) analogously to the PyTorch implementation of CFlow-AD.
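
A minimal sketch of what that could look like (my own illustration, assuming a standard Gaussian base distribution and a per-sample log-det-Jacobian from the flow; names are illustrative):

import math
import torch

def log_prob(z: torch.Tensor, log_jac_det: torch.Tensor) -> torch.Tensor:
    # z: (B, C, H, W) latent from the flow; log_jac_det: (B,) from the flow
    C = z.shape[1]
    # log N(z; 0, I), summed over channels, per pixel
    log_pz = -0.5 * (C * math.log(2 * math.pi) + (z ** 2).sum(dim=1))  # (B, H, W)
    return log_pz + log_jac_det.view(-1, 1, 1)  # broadcast log-det over pixels

Exponentiating this (or just working in log-space) gives the p(z) needed for the anomaly score.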

@gathierry

Hi @maaft, did you manage to achieve results similar to the claimed ones? I tried both the CFlow way and the DifferNet way, but I am still far below the performance in the paper.

Another point of confusion for me is that I cannot match the A.D. param # from the paper:
I take each FlowStep to be one AllInOneBlock from FrEIA, with 2 convolution layers.
This is my counting result (the paper's count in parentheses):

CaiT:  7,043,780 (14.8M)
DeiT:  7,043,780 (14.8M)
Resnet18:  4,650,240 (4.9M)
WideResnet50:  41,309,184 (41.3M) -> this one matches

Here's the code I used to compute the param #:

def count_params_per_flow_step(k, cin, ratio):
    # one flow step = two conv layers (k x k): cin -> cmed -> cout, weights + biases
    cout = 2 * cin
    cmed = int(cin * ratio)
    w1 = k * k * cin * cmed
    b1 = cmed
    w2 = k * k * cmed * cout
    b2 = cout
    return w1 + w2 + b1 + b2

def count_total_params(num_steps, conv3x3_only, feature_channels, ratio):
    s = 0
    for channels in feature_channels:
        for i in range(num_steps):
            # alternate 3x3 and 1x1 kernels unless conv3x3_only is set
            k = 1 if (i % 2 == 1 and not conv3x3_only) else 3
            s += count_params_per_flow_step(k, channels // 2, ratio)
    return s

print("CaiT: ", count_total_params(20, False, [768], 0.16))
print("DeiT: ", count_total_params(20, False, [768], 0.16))
print("Resnet18: ", count_total_params(8, True, [64, 128, 256], 1.0))
print("WideResnet50: ", count_total_params(8, False, [256, 512, 1024], 1.0))

@maaft commented Mar 6, 2022

@gathierry no, I don't think I can match the scores in the paper (I haven't evaluated it yet, only visually). In particular, the transistor class (broken legs) does not learn at all.

I'll evaluate AUROC etc. next week and report back.

I also tried backbones other than resnet18 that achieve higher accuracy on ImageNet (e.g. EfficientNet) and noticed that training has a very hard time converging at all. No idea why this is the case.

@brm738 commented Mar 6, 2022

My tests of this code (24 epochs) show acceptable results only for Resnet18, and only for three MVTec classes:

Class    AUROC-MAX   AUCPR-MAX
Bottle   0.9849      0.9955
Screw    0.9859      0.9959
Wood     0.9956      0.9987

Other classes performed badly.
I have not tested WideResnet50 yet.
Feature extractors based on Vision Transformers like DeiT or CaiT do not learn at all.

@maaft commented Mar 6, 2022

That's weird, right? Do the resnet18 features follow some kind of special/nice distribution that the other architectures' features don't have?

Has anyone tried different feature extractors with CFLOW-AD or other flow-based approaches?

@maaft commented Mar 7, 2022

I found out that resnet18 works well because the extracted features have a low magnitude. When I use e.g. EfficientNet and just scale the features by 0.1, the NF head seems to learn quite well.

I'll try to add a learnable scaling parameter to make my model backbone-agnostic. Update: that doesn't work; the features collapse to 0.

@maaft commented Mar 8, 2022

Another thing: according to the architecture image (Fig. 2) from the paper, I think we should use RNVPCouplingBlock and not AllInOneBlock. The former includes two alternating coupling networks, while the latter is only single-sided.

Furthermore, the AllInOneBlock applies ActNorm and PermuteRandom at the end of the coupling block and not at the beginning. We therefore need to add those manually before every RNVPCouplingBlock.

Does anyone know if the permutation indeed needs to be fixed during training? Or do we need to use a different permutation at every training step? I'm asking because the PermuteRandom module from FrEIA is fixed during training.

Edit: Does it really matter, though? I think the reason for alternating coupling networks in RealNVP was to also train the upper half of the channels. But when permuting randomly multiple times, we also train every channel. Hm, I'm a bit clueless here.

Edit 2: ActNorm at the beginning is paramount. When you do this, all backbones work like magic. No manual scaling needed.

@gathierry commented Mar 10, 2022

I think PermuteRandom is actually a more flexible version of the lower-half/upper-half alternation, so essentially I don't feel a big difference.
I was also trying to figure out whether it's a coupling block or an AllInOneBlock from the A.D. params in Table 1. But as mentioned earlier, I can never match all of them.

Based on our experiments, PermuteRandom must stay fixed from initialization. Otherwise, the NF cannot learn anything useful.


@maaft commented Mar 10, 2022

Yes, I think you are right.

To match parameters: which layers are you using from the ResNet?

Per the paper:

  • use the first three block outputs ((64, 64, 128) channels for resnet18)
  • use RNVPCouplingBlock (or GLOW; I think parameter-wise it shouldn't matter)
  • use ActNorm followed by PermuteRandom before every block
  • use a total of 8 coupling blocks per layer output (3x3 and 1x1 alternating)

The only free variable to play with in this case is the number of mid-channels of the two subnets.

Unfortunately my GPU memory is too small to use the first 3 feature maps. Please let me know if you can achieve any good results with the above configuration.

Count Parameters with:

nf_params = sum(p.numel() for p in self.nf.parameters() if p.requires_grad) # self.nf is the flow head

@mjack3 (Author) commented Mar 10, 2022

I opened a question in the FrEIA GitHub:

vislearn/FrEIA#113

@gathierry

@maaft

  • I think the "first 3 blocks" for resnet18 means strides 4x, 8x, and 16x, so the channel numbers should be (64, 128, 256). See Table 6.
  • In fact, in section 6.1 and the caption of Table 7, the paper indicates the mid-channel numbers of the subnets.

I tried to move Permute and ActNorm from the end to the beginning of the block, as you suggested, but I didn't see a significant improvement. Maybe there are some other issues in my code.

@mjack3 (Author) commented Mar 10, 2022


I am getting NaN when ActNorm is at the beginning of the block in a SequenceINN. Could you share an image?

@maaft commented Mar 10, 2022

I guess I could share my model later. No idea why you get NaNs.

Maybe your data is already bad and contains NaNs? Are you normalizing your images?

@mjack3 (Author) commented Mar 10, 2022

Btw, for ResNet the outputs of layer1, layer2, and layer3 are used.

Currently, I have a model that achieves [0.98, 1.0] classification AUROC for every class, with 25 epochs instead of 500.

It needs some adjustments, but I hope to open the code soon for community participation.

Note: the code in this repo is wrong (sorry).

@mjack3 (Author) commented Mar 10, 2022


For that, I am emulating the process:

x = torch.rand(16, 3, 256, 256)
o = model(x)

But yes, I also tested with a real image normalized in the standard PyTorch way.

@AlessioGalluccio (Owner)

The permutation of channels must be fixed during training. As @gathierry mentioned, it's necessary for normalizing flows.


@maaft, for the anomaly score I apply

anomaly_score.append(t2np(torch.mean(z_grouped_temp ** 2, dim=(-2, -1))))

as it is used in DifferNet. Do you mean that I should add a /2 to it to match the negative log-likelihood of a normal distribution?

@AlessioGalluccio (Owner) commented Mar 10, 2022

Looking at CFlow-AD, it does this in utils.py:

logp = C * _GCONST_ - 0.5*torch.sum(z**2, 1) + logdet_J

It computes the positive log-likelihood instead of the negative one; in other words, it calculates a normality score, not an anomaly score. Then in train.py it computes:

# invert probs to anomaly scores
super_mask = score_mask.max() - score_mask

In this way it gets the anomaly score. So, it's basically the same.
I think that adding the Jacobian to the anomaly score is useless, since it is the same for every output. The Jacobian depends on the weights of the net, not on the input image.

@maaft commented Mar 10, 2022

I'll share my model, loss function and anomaly map generation tomorrow

@gathierry commented Mar 11, 2022

@AlessioGalluccio In CFlow there's an exponential converting logp to p as well. It's the same if there's only one feature level (as with DeiT and CaiT). But if there are 3 feature levels (ResNet), it is different, since exp is applied before the three score maps of the three levels are summed. logp is in (-inf, 0] but p is in [0, 1], so sum(logp) and sum(p) can produce totally different values.
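
A quick numeric illustration of this point (my own, not from the thread): suppose at pixel A the two levels give logp = (-1, -9), and at pixel B they give (-5, -5). The sums of logp are equal (-10 for both), but the sums of p are e^(-1) + e^(-9) ≈ 0.368 versus 2·e^(-5) ≈ 0.013, so the two merging strategies rank the pixels very differently.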

@gathierry

And for logp = C * _GCONST_ - 0.5*torch.sum(z**2, 1) + logdet_J: does it make sense to also reduce only dim=1 when doing the sum for logdet_J?
I subclassed AllInOneBlock to keep the H and W axes:

class AllInOneBlock2D(Fm.AllInOneBlock):
    def __init__(self, dims_in, **kwargs):
        super().__init__(dims_in, **kwargs)
        # sum the log-det-Jacobian over the channel axis only, keeping (H, W)
        self.sum_dims = (1,)

@gathierry commented Mar 11, 2022

@mjack3 As for ActNorm, I simply moved the _permute of AllInOneBlock to the beginning of forward and removed the original one. I don't think this is the root cause of the NaNs, but it might somehow amplify your gradients.

def forward(self, x, c=[], rev=False, jac=True):
    '''See base class docstring'''
    if self.householder:
        self.w_perm = self._construct_householder_permutation()
        if rev or self.reverse_pre_permute:
            self.w_perm_inv = self.w_perm.transpose(0, 1).contiguous()

    # ==== ActNorm (moved to the beginning) ====
    x0, global_scaling_jac = self._permute(x[0], rev=False)
    # ==== ActNorm end ====
    x1, x2 = torch.split(x0, self.splits, dim=1)

    if self.conditional:
        x1c = torch.cat([x1, *c], 1)
    else:
        x1c = x1

    if not rev:
        a1 = self.subnet(x1c)
        x2, j2 = self._affine(x2, a1)
    else:
        a1 = self.subnet(x1c)
        x2, j2 = self._affine(x2, a1, rev=True)

    log_jac_det = j2
    x_out = torch.cat((x1, x2), 1)

    # add the global scaling Jacobian to the total.
    # trick to get the total number of non-channel dimensions:
    # number of elements of the first channel of the first batch member
    n_pixels = x_out[0, :1].numel()
    log_jac_det += (-1)**rev * n_pixels * global_scaling_jac

    return (x_out,), log_jac_det

@mjack3 (Author) commented Mar 15, 2022

My model:

My FastFlow head is based on AllInOneBlock:

import FrEIA.framework as Ff
import FrEIA.modules as Fm

def fastflow_head(dims: tuple) -> Ff.SequenceINN:
    inn = Ff.SequenceINN(*dims)
    for k in range(4):
        # alternate 3x3 and 1x1 conv subnets
        inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_conv_3x3, permute_soft=True)
        inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_conv_1x1, permute_soft=True)
    return inn
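
For reference, a minimal sketch of what such subnet constructors could look like (my assumption of a Conv-ReLU-Conv design with 2x hidden channels, matching the "x2 input channels in the middle" mentioned below, not necessarily the exact code):

import torch.nn as nn

def subnet_conv_3x3(cin: int, cout: int) -> nn.Module:
    # Conv-ReLU-Conv with 3x3 kernels; hidden width 2*cin is an assumption
    return nn.Sequential(
        nn.Conv2d(cin, 2 * cin, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(2 * cin, cout, kernel_size=3, padding=1),
    )

def subnet_conv_1x1(cin: int, cout: int) -> nn.Module:
    # same design with 1x1 kernels
    return nn.Sequential(
        nn.Conv2d(cin, 2 * cin, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(2 * cin, cout, kernel_size=1),
    )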

The loss function I am using in training is the same as yours (based on CS-Flow, not CFlow-AD):

loss = torch.mean(0.5 * torch.sum(z ** 2, dim=(1, 2, 3)) - log_j) / z.shape[1]

Then I am trying to build the anomaly scores directly from z, similar to what CS-Flow does. Reading the paper, we can see that the authors can build an estimated anomaly map, but an NF built with dense layers instead of convolutions is not great for anomaly localization.

Currently, reading eq. 8 from CS-Flow, your piece of code makes sense, but we would be mixing different training logics, because I think the loss function in CS-Flow is different from what you are using (I haven't read the paper in detail yet).

@maaft commented Mar 15, 2022

Thank you very much.

Just some notes/questions:

  • when using 8 AllInOneBlocks, you're only using half the number of convolution layers they state in the paper (it shouldn't matter, because you get good results anyway)
  • I tried the loss function from CFLOW-AD but it didn't work at all
  • could you elaborate on your subnet config? How many mid-channels are you using?
  • are you using all 3 layers from resnet18?
  • what optimizer config are you using?

@mjack3 (Author) commented Mar 15, 2022

  • Maybe I misunderstood you. Figure 2(b) shows one flow step of FastFlow, so in one iteration of this loop you get 2 flow steps (conv3x3 + conv1x1). In the end there are 8 flow steps. I execute this loop once for each ResNet layer.
  • I think that is normal, because the CS-Flow loss is the equivalent objective to minimize.
  • I'm using 2x the input channels in the middle.
  • I'm using the 3 layers from wide_resnet50.
  • I'm using the same optimizer as the FastFlow paper.

@maaft commented Mar 15, 2022

  • According to Fig. 2, "one flow step" is represented by double-sided blocks (which can be GLOW or RNVP). These blocks have two subnets and operate on both halves of the input. Because the paper says "s(y_a) and b(y_a) are outputs of two neural networks", it must be the RNVP block, since the GLOW block uses one network to predict both. In practice it doesn't matter which one you use, and parameter-wise they are the same. On the other hand, the AllInOneBlock that you are using is only single-sided; that means it only represents the first half of Fig. 2. Therefore you need to double your number of AllInOneBlocks to match the number of conv layers used in the paper.

Thanks for the other points! :)

@mjack3 (Author) commented Mar 15, 2022

Hmm, thanks for your point. I don't have much experience with normalizing flows, so you are probably right and I'm wrong to use AllInOneBlock.

@maaft commented Mar 15, 2022

I don't think it matters in the end. Your performance is great.

By the way, I think RealNVP introduced the double-sided blocks only to make sure that all channels are trained. They did this before random permutations became a thing. Both are equivalent IMHO. Just the number of blocks (when using FrEIA) needs to be adapted to match a certain number of conv layers.

@gathierry

I think there's a conflict between the definitions of "one step" in Eq. (3) and Fig. 2.

@maaft commented Mar 15, 2022

I don't think so. The figure says "flow step" while Eq. (3) only says "step" and references Dinh et al.

@maaft commented Mar 15, 2022

If only authors would publish their f*** code. It should be mandatory, lol.

@mjack3 (Author) commented Mar 15, 2022

I don't get the results of the paper :(

@mjack3 (Author) commented Mar 17, 2022

The code in my repo is now open. I welcome any contributions there.

https://github.com/mjack3/FastFlow-AD

@gathierry commented Mar 18, 2022

I managed to achieve comparable performance in a tricky way: I add LayerNorm layers before sending each feature map to the NFs.
It is "tricky" because the usage differs between backbones, but it's the only way that works for me:

  • resnet18 and wide-resnet-50: use a trainable LayerNorm
  • CaiT and DeiT: use the final norm from the pre-trained model and fix its affine parameters

I opened my code as well https://github.com/gathierry/FastFlow
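
To illustrate the ResNet case, a minimal sketch (shapes assumed for wide_resnet50 at 256x256 input; see the repo for the actual code):

import torch.nn as nn

# one trainable LayerNorm per feature level, applied before each NF
norms = nn.ModuleList([
    nn.LayerNorm([256, 64, 64], elementwise_affine=True),   # layer1 output
    nn.LayerNorm([512, 32, 32], elementwise_affine=True),   # layer2 output
    nn.LayerNorm([1024, 16, 16], elementwise_affine=True),  # layer3 output
])

# features: assumed list of the three backbone feature maps
features = [norm(f) for norm, f in zip(norms, features)]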

@maaft commented Mar 18, 2022

@gathierry wow, LayerNorm really boosts the performance enormously. Thank you very much for that and for opening your code! :)

@maaft commented Mar 18, 2022

Does anyone have a clue how we could measure model performance without having labeled ground-truth test data?

@mjack3 (Author) commented Mar 24, 2022

[image: table from the paper]

Hello guys. Does someone know what "w/o" and "w" mean in the first column of this table?

@Howeng98

@mjack3 w/o = without, w = with

@mjack3 (Author) commented Mar 25, 2022

@gathierry I'm wondering why you do this in these lines: https://github.com/gathierry/FastFlow/blob/master/fastflow.py#L148-L151

@gathierry commented Mar 25, 2022

@mjack3 Why am I applying the exponential? Because the values from the different levels should be converted to probabilities before being merged. I want to project the output to its probability under a Gaussian distribution.
And I guess this also normalizes the values of the different levels to the same range; otherwise one level can dominate the result.

@mjack3 (Author) commented Mar 25, 2022

@gathierry it's more about the -torch.mean and the interpolation of the negative

@gathierry

@mjack3 oh, about that:

  • -torch.mean: a Gaussian density should look like torch.exp(-0.5 * torch.sum(output ** 2)). However, the sum makes the value far too small, and it's not comparable between levels, since their channel numbers differ. I saw that CFlow divides the squared magnitude of the output by the channel number, so I do the same, which means replacing sum with mean.
  • the interpolation of the negative: the result is a probability, and the anomaly score should be its opposite. The resolutions of the different levels are not identical, so we have to resize them before merging (see the sketch after this list).
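
Putting both points together, a minimal sketch of this merging logic (shapes and names assumed, not a verbatim copy of the repo):

import torch
import torch.nn.functional as F

anomaly_map = 0.0
for z in level_outputs:  # assumed: one latent (B, C, Hi, Wi) per feature level
    log_prob = -0.5 * torch.mean(z ** 2, dim=1, keepdim=True)  # mean, not sum, over channels
    prob = torch.exp(log_prob)                                 # probability in [0, 1]
    # negate (high probability = low anomaly) and resize to a common resolution
    anomaly_map = anomaly_map + F.interpolate(-prob, size=(256, 256), mode="bilinear", align_corners=False)
anomaly_map = anomaly_map / len(level_outputs)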

@mjack3 (Author) commented Mar 29, 2022

@gathierry thanks for your answer.

Why are you using that loss function instead of something like this?
[image: loss equation]
link

Am I missing something?

@gathierry

@mjack3 I am using the same loss function
https://github.com/gathierry/FastFlow/blob/d275b79d47d6e272115d45fd7fc0f29cca0f5107/fastflow.py#L139

What you mentioned before
https://github.com/gathierry/FastFlow/blob/master/fastflow.py#L148-L151
was just for inference.

@Howeng98

@gathierry In your implementation, the AUROC refers to the segmentation performance, right?

@gathierry

@Howeng98 yes, the pixel-level AUROC

@briliantnugraha

Does anyone have a clue how we could measure model performance without having labeled ground-truth test data?

Hi @maaft, IMO measuring model performance as accuracy/percentage (%) is not really possible without labeled data.
However, if you are using distance measurements (kNN/aNN/Euclidean etc.), then it should be possible, with the caveat that you will still need to verify the measurements in the end (manual labeling by human eyes) in order to map distance measurement -> accuracy.
Hope this helps :).

Note: very cool progress and results, guys, considering how little information the paper gives.

@questionstorer commented Apr 18, 2022

This is a great discussion. I also performed some experiments using wide_resnet_50_2.

  1. I wrote the various modules of FastFlow from scratch myself, such as actnorm, affine coupling, channel permute, and split/merge, and combined them in a module fastflow_head (following the framework here). I didn't use the FrEIA framework just because the framework I follow is more transparent to me.

  2. Model details. I largely follow the paper's setting by using a flow step (called fastflow_head in my code) consisting of actnorm -> channel random shuffling -> two affine couplings acting on different halves of the channels, with Conv-ReLU-Conv subnets. There are things that are not clearly specified in the paper, so let me try to make them explicit here.

    • I use 3 feature maps from 3 layers of wide_resnet_50_2. On MVTec, they have sizes (256, 64, 64), (512, 32, 32), and (1024, 16, 16) respectively. I followed the paper and built 3 fastflow models using these 3 feature maps.
    • Hidden channels in the Conv-ReLU-Conv block: I use the same number as the number of input channels to the first conv, that is, half the total channels of the original input from the feature extractor. For wide_resnet_50_2, the hidden channels are respectively 128, 256, 512. I think there is some ambiguity in the original paper: they say the numbers of input and output channels of the subnet are the same, but that's not the case.
    • For actnorm, I initialize the log-scale and bias with all zeros (see the sketch after this list). This is different from the initialization in the GLOW paper, but I think there is no reason to initialize with the mean and std of the activations in our case. In experiments, I found the trivial initialization makes training easier.
    • For the optimizer, I follow the paper.
    • For the loss, I use the same loss as the CFlow paper.
    • For the number of flow steps, I use 4 flow steps with alternating 3x3 convs and 1x1 convs instead of the 8 flow steps specified in the paper. I found this makes a very big difference: I observed very unstable training with 8 flow steps, and sometimes NaN losses or gradients. 4 flow steps is easier to train, and it matches the parameter count specified in the paper.
    • For the number of epochs, I use 500. I also observed that the model stabilizes or reaches its best performance within a few steps. In my own project (where there is no labeled ground truth for testing), I use some good objects as a validation dataset and observe that the validation loss reaches its minimum at around 250 epochs.
    • For the final anomaly map, I average the anomaly maps from the 3 fastflow models. For each fastflow model, I use the same method as gathierry's approach. This can be understood as taking the product of the probabilities along the channels.
    • Under the above specifications, the number of parameters is 41.3M. Using 8 flow steps doubles this number to 82.6M. Doubling the number of hidden channels also doubles it to 82.6M.

  3. I did not experiment on all categories of the MVTec dataset. I only experimented on the bottle category and saw that its pixel-wise AUROC matches the result from the paper in just a few steps. Maybe other categories require more training epochs? Of that I'm not sure.
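
As an illustration of the actnorm initialization described above, a minimal sketch (my own code, not questionstorer's):

import torch
import torch.nn as nn

class ActNorm(nn.Module):
    def __init__(self, num_channels: int):
        super().__init__()
        # trivial (all-zeros) init: identity transform at the start of training
        self.log_scale = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))

    def forward(self, x: torch.Tensor):
        # y = (x + bias) * exp(log_scale); per-sample log-det scales with H*W
        y = (x + self.bias) * torch.exp(self.log_scale)
        log_det = self.log_scale.sum() * x.shape[2] * x.shape[3]
        return y, log_det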

@mjack3 (Author) commented Apr 18, 2022

Guys, what does this mean in Figure 3? What is y'?

[image: Figure 3 from the paper]

@AlessioGalluccio (Owner)


y' is the output of the coupling flow, while y is the input. Coupling flows work by splitting the input, keeping one part unchanged ("a" in this example) and modifying the other ("b" in this example). This is needed to have invertibility and an easy-to-compute Jacobian determinant, since the Jacobian matrix becomes triangular.
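
In the standard affine-coupling notation (as in RealNVP; the symbols s and b below follow the paper's wording that "s(y_a) and b(y_a) are outputs of two neural networks"):

y'_a = y_a
y'_b = y_b * exp(s(y_a)) + b(y_a)

and the inverse is simply y_b = (y'_b - b(y_a)) * exp(-s(y_a)), which is why the subnets themselves never need to be inverted.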

@HanMinung

I am very grateful for your GitHub source.
However, I have a few questions regarding the use of the model. As the epochs progress, the loss values appear as follows, and I wonder if this is correct.
Secondly, I placed the MVTec-AD data folder in the path where the source files are gathered, but the following error occurs. Could you tell me what the problem is?

[image: loss output and error message]

@lil-wayne-0319


Maybe you can check here, dataset.py line 93:

target = Image.open(
    # image_file.replace("/test/", "/ground_truth/").replace(
    image_file.replace("test", "ground_truth").replace(
        ".png", "_mask.png"
    )
)
