Q&A #14
Hi @mjack3,
I think the paper wants us to build three different models and average their anomaly scores. But how do we compute this anomaly score? That's the question I can't solve. In the introduction we can find that:
But I can't find how this two-dimensional loss is defined. If you have an idea for a good two-dimensional loss for this problem, I'm all ears.
Hmmm, yes, you are right. Definitely we need to create 3 FastFlow models. I will try. By the way, you can find my implementation here: https://github.com/mjack3/EasyFastFlow. Feel free to use whatever you want.
Have you tried contacting any of the main authors of the paper? I googled them but didn't find an email.
@mjack3 Hi, have you taken a look at CFLOW-AD? It is also implemented with a flow model; maybe it can help you understand how the 3 FastFlow modules work. I'm trying to implement FastFlow by modifying CFlow-AD. If you need any help or want to discuss, I would be glad to help (if I can).
@Howeng98 you are welcome =) Yes, I also looked at the CSFLOW-AD code, but I am not sure whether here we need to create 3 individual FastFlow models and train with 3 optimizers (one per FastFlow), or do something similar to CSFLOW-AD.
@mjack3 I tried to contact Yushuang Wu through a university e-mail I found, but I got no answer. I haven't found the e-mails of the other authors.
When did you contact them, @AlessioGalluccio?
Hi @mjack3,
Currently I am obliged to keep the code private because of my job contract. I hope to open it soon. Anyway, I will share information in this same thread if needed :)
@AlessioGalluccio just a small remark: for the anomaly score calculation (global and pixel-wise) you need to use p(z) and not z, which you are currently using. You can estimate log p(z) (and therefore p(z)) analogously to the PyTorch implementation of CFlow-AD.
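A rough sketch of what that estimate could look like, assuming a standard-normal base distribution and distributing the per-sample log|det J| uniformly over pixels (both are assumptions, not the paper's exact formulation):

```python
import math
import torch

def log_prob_map(z: torch.Tensor, log_jac_det: torch.Tensor) -> torch.Tensor:
    """Per-pixel log p(z) under a standard-normal base distribution.

    z: latent tensor of shape (B, C, H, W) from the flow head
    log_jac_det: per-sample log |det J|, shape (B,), spread over pixels
    """
    c = z.shape[1]
    # log N(z; 0, I), summed over channels but kept per pixel -> (B, H, W)
    logp_z = -0.5 * (z ** 2).sum(dim=1) - 0.5 * c * math.log(2 * math.pi)
    n_pixels = z.shape[2] * z.shape[3]
    return logp_z + log_jac_det.view(-1, 1, 1) / n_pixels

# anomaly score is then high where p(z) is low, e.g. -log_prob_map(z, ljd)
```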
Hi @maaft, did you manage to achieve results similar to the claimed ones? I tried both the CFlow way and the DifferNet way, but I'm still far below the performance in the paper. Another confusion for me is that I cannot get the same A.d param#:
Here's the code I used to compute the parameter count:

```python
def count_params_per_flow_step(k, cin, ratio):
    cout = 2 * cin
    cmed = int(cin * ratio)
    w1 = k * k * cin * cmed
    b1 = cmed
    w2 = k * k * cmed * cout
    b2 = cout
    return w1 + w2 + b1 + b2


def count_total_params(num_steps, conv3x3_only, feature_channels, ratio):
    s = 0
    for channels in feature_channels:
        for i in range(num_steps):
            k = 1 if (i % 2 == 1 and not conv3x3_only) else 3
            s += count_params_per_flow_step(k, channels // 2, ratio)
    return s


print("CaiT: ", count_total_params(20, False, [768], 0.16))
print("DeiT: ", count_total_params(20, False, [768], 0.16))
print("Resnet18: ", count_total_params(8, True, [64, 128, 256], 1.0))
print("WideResnet50: ", count_total_params(8, False, [256, 512, 1024], 1.0))
```
@gathierry no, I don't think I can match the scores in the paper (I haven't evaluated it yet, only visually). In particular, the transistor class (broken legs) does not learn at all. I'll evaluate AUROC etc. next week and report back. I also tried backbones other than ResNet18 that achieve higher accuracy on ImageNet (e.g. EfficientNet) and noticed that training has a very hard time converging at all. No idea why this is the case.
My tests of this code (24 epochs) show acceptable results only for ResNet18, and only for three MVTec classes.
Other classes performed badly.
That's weird, right? Do the ResNet18 features follow some kind of special/nice distribution that the other architectures don't have? Has anyone tried different feature extractors with CFLOW-AD or other flow-based approaches?
I found out that ResNet18 works well because the extracted features have a low magnitude. When I use e.g. EfficientNet and just scale the features by 0.1, the NF head seems to learn quite well.
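A minimal sketch of that scaling trick, assuming a frozen feature extractor and the 0.1 factor mentioned above (`ScaledFeatures` is a hypothetical wrapper name, not from any repo in this thread):

```python
import torch
import torch.nn as nn

class ScaledFeatures(nn.Module):
    """Wrap a feature extractor and damp feature magnitude before the NF head."""

    def __init__(self, backbone: nn.Module, scale: float = 0.1):
        super().__init__()
        self.backbone = backbone
        self.scale = scale

    def forward(self, x):
        with torch.no_grad():  # the extractor stays frozen
            feats = self.backbone(x)
        return feats * self.scale
```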
Another thing: According to the architecture image (fig. 2) from the paper, I think we should use … Furthermore, the …

Does anyone know if the permutation indeed needs to be fixed during training? Or do we need to use a different permutation at every training step? I'm asking because … The …

Edit: Does it really matter, though? I think the reason for alternating coupling networks in RealNVP was to also train the upper half of channels. But when permuting randomly multiple times, we also train every channel. Hm, I'm a bit clueless here.

Edit2: …
I think … Based on our experiments, …
Yes, I think you are right. To match parameters: Which layers are you using from the ResNet? Per paper:
the only free variable to play with in this case is the number of mid-channels for both subnets. Unfortunately my GPU memory is too small to use the first 3 image features. Please let me know if you can achieve any good results with the above configuration. Count parameters with:

```python
nf_params = sum(p.numel() for p in self.nf.parameters() if p.requires_grad)  # self.nf is the flow head
```
I opened a question in the FrEIA GitHub.
I tried to move Permute and ActNorm from the end to the beginning of the block, as you suggested, but I didn't see a significant improvement. Maybe there are some other issues in my code.
I am getting NaNs when ActNorm is at the beginning of the block in an innSeq. Could you share an image?
I guess I could share my model later. No idea why you get NaNs. Maybe your data is already bad and contains NaNs? Are you normalizing your images?
Btw, for ResNet the outputs of layer1, layer2 and layer3 are used. Currently, I have a model that achieves [0.98, 1.0] in classification for every class, with 25 epochs instead of 500. It needs some adjustments, but I hope to open the code soon for community participation. Note: the code in this repo is wrong (sorry).
For that, I am emulating the process:

```python
x = torch.rand(16, 3, 256, 256)
o = model(x)
```

But yes, I tested with a real image normalized in the standard PyTorch way.
The permutation of channels must be fixed during training. As @gathierry mentioned, it's necessary for normalizing flows.
For the anomaly score I apply …
Looking at CFlow-AD, it does this in utils.py:
In this way it gets the anomaly score. So, it's basically the same.
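For reference, converting a log-likelihood map into a bounded anomaly score might look like this minimal sketch (not CFlow-AD's exact code; the normalisation there may differ):

```python
import torch

def likelihood_to_score(logp: torch.Tensor) -> torch.Tensor:
    """Turn a per-pixel log-likelihood map into an anomaly score in [0, 1).

    Exponentiate (shifted by the max for numerical stability), then invert
    so that low-likelihood pixels get high scores.
    """
    prob = torch.exp(logp - logp.max())  # in (0, 1], max element maps to 1
    return 1.0 - prob                    # high score = low likelihood
```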
I'll share my model, loss function and anomaly map generation tomorrow |
@AlessioGalluccio In CFlow, there's an exponential converting …
And for …
@mjack3 As for ActNorm, I simply moved the … :

```python
def forward(self, x, c=[], rev=False, jac=True):
    '''See base class docstring'''
    if self.householder:
        self.w_perm = self._construct_householder_permutation()
        if rev or self.reverse_pre_permute:
            self.w_perm_inv = self.w_perm.transpose(0, 1).contiguous()

    # ==== ActNorm ====
    x0, global_scaling_jac = self._permute(x[0], rev=False)
    # ==== ActNorm end ====

    x1, x2 = torch.split(x0, self.splits, dim=1)

    if self.conditional:
        x1c = torch.cat([x1, *c], 1)
    else:
        x1c = x1

    if not rev:
        a1 = self.subnet(x1c)
        x2, j2 = self._affine(x2, a1)
    else:
        a1 = self.subnet(x1c)
        x2, j2 = self._affine(x2, a1, rev=True)

    log_jac_det = j2
    x_out = torch.cat((x1, x2), 1)

    # add the global scaling Jacobian to the total.
    # trick to get the total number of non-channel dimensions:
    # number of elements of the first channel of the first batch member
    n_pixels = x_out[0, :1].numel()
    log_jac_det += (-1)**rev * n_pixels * global_scaling_jac

    return (x_out,), log_jac_det
```
My model: my FastFlow head is based on AllInOne:
The loss function I am using in training is just the same as yours (based on CS-FLOW, not CFLOW-AD),
and then I am trying to build the anomaly scores directly from Z, similar to what CS-FLOW does. Reading the paper, we can see that the authors build an estimated anomaly map, but because the NF is built with dense layers instead of convolutions, it's not great for anomaly localization. Currently, reading eq. 8 from CSFLOW-AD, your piece of code makes sense, but we would be mixing different training logic, because I think the loss function in CSFLOW-AD is different from what you are using (I didn't read the paper in detail yet).
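That maximum-likelihood objective, as used in CS-Flow-style training, might be sketched as follows (a sketch; the flattened latent shape is an assumption):

```python
import torch

def nf_loss(z: torch.Tensor, log_jac_det: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood loss for a flow with standard-normal base.

    z: flattened latents of shape (B, D)
    log_jac_det: log |det J| per sample, shape (B,)
    """
    # -log p(x) up to a constant: 0.5 * ||z||^2 - log |det J|
    return torch.mean(0.5 * torch.sum(z ** 2, dim=1) - log_jac_det)
```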
Thank you very much. Just some notes/questions:
Thanks for the other points! :) |
Hmm, thanks for your point. I don't have much experience with normalizing flows, so probably you are right and I'm wrong using …
I don't think it matters in the end. Your performance is great. By the way, I think RealNVP introduced the double-sided blocks only to make sure that all channels are trained. They did this before …
I think there's a conflict between the definition of "one step" in Eq. (3) and Fig. 2.
I don't think so. The figure says "flow step" while eq. 3 only says "step" and references Dinh et al.
If only the authors would publish their f** code. It should be mandatory, lol.
I don't get the results of the paper :(
The code in my repo is now open. I welcome any contributions there.
I managed to achieve comparable performance in a tricky way: I add LayerNorm layers before sending each feature map to the NFs.
I opened my code as well https://github.com/gathierry/FastFlow |
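A sketch of that LayerNorm trick, assuming WideResNet-50 channel counts and channel-wise normalization (details may differ from the linked repo):

```python
import torch
import torch.nn as nn

# one LayerNorm per feature scale; channel counts are the WideResNet-50
# values mentioned earlier in this thread
feature_channels = [256, 512, 1024]
norms = nn.ModuleList(nn.LayerNorm(c) for c in feature_channels)

def normalize(features):
    """features: list of (B, C, H, W) maps, one per scale."""
    out = []
    for f, ln in zip(features, norms):
        # LayerNorm over the channel dim: move C last, normalize, move back
        out.append(ln(f.permute(0, 2, 3, 1)).permute(0, 3, 1, 2))
    return out
```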
@gathierry wow, LayerNorm really boosts the performance enormously. Thank you very much for that and for opening your code! :) |
Does anyone have a clue how we could measure model performance without having labeled ground-truth test data? |
@mjack3 w/o = without w = with |
@gathierry I'm wondering why you do this in lines https://github.com/gathierry/FastFlow/blob/master/fastflow.py#L148-L151
@mjack3 why am I doing the exponential? Because the values at different levels should be converted to a probability before being merged. I want to project the …
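The multi-level merge described here might look like this sketch: exponentiate each level so the values are comparable probabilities, upsample to a common resolution, then average (shapes and `out_size` are assumptions, not the repo's exact code):

```python
import torch
import torch.nn.functional as F

def merge_level_maps(logp_maps, out_size=(256, 256)):
    """Merge per-level log-likelihood maps into one anomaly map.

    logp_maps: list of tensors, each of shape (B, H_i, W_i)
    """
    scores = []
    for logp in logp_maps:
        prob = torch.exp(logp - logp.max())  # stabilised exp -> (0, 1]
        prob = F.interpolate(prob.unsqueeze(1), size=out_size,
                             mode="bilinear", align_corners=False)
        scores.append(1.0 - prob)            # anomaly = 1 - probability
    return torch.stack(scores).mean(dim=0).squeeze(1)
```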
@gathierry it's more about the -torch.mean and the interpolation of the negative …
@mjack3 oh, about that.
@gathierry thanks for your answer. Why are you using that loss function instead of something like this? Am I missing something?
@mjack3 I am using the same loss function … What you mentioned before …
@gathierry From your implementation, your AUROC refers to segmentation performance, right?
@Howeng98 yes, the pixel level AUROC |
Hi @maaft, IMO, if you are trying to measure model performance in accuracy/percentage (%), … Note: very cool progress and results, you guys, considering how little information the paper gives.
This is a great discussion. I also performed some experiments using wide_resnet_50_2.
y' is the output of the coupling flow, while y is the input. Coupling flows work by splitting the input, keeping one part unchanged ("a" in this example) and modifying the other ("b" in this example). This is what makes the flow invertible with an easy-to-compute Jacobian determinant, since the Jacobian matrix becomes triangular.
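The split-transform structure described above can be sketched as a minimal affine coupling layer (RealNVP-style; the dimensions and subnet here are illustrative assumptions):

```python
import torch
import torch.nn as nn

class AffineCoupling(nn.Module):
    """Minimal affine coupling block: x_a passes through unchanged,
    x_b is scaled and shifted by functions of x_a, so the Jacobian is
    triangular and the block is trivially invertible."""

    def __init__(self, dim: int, hidden: int = 64):
        super().__init__()
        self.half = dim // 2
        self.subnet = nn.Sequential(
            nn.Linear(self.half, hidden), nn.ReLU(),
            nn.Linear(hidden, 2 * (dim - self.half)))

    def forward(self, x):
        xa, xb = x[:, :self.half], x[:, self.half:]
        s, t = self.subnet(xa).chunk(2, dim=1)
        yb = xb * torch.exp(s) + t   # y_b = x_b * exp(s(x_a)) + t(x_a)
        log_det = s.sum(dim=1)       # log |det J| = sum of s
        return torch.cat([xa, yb], dim=1), log_det

    def inverse(self, y):
        ya, yb = y[:, :self.half], y[:, self.half:]
        s, t = self.subnet(ya).chunk(2, dim=1)
        return torch.cat([ya, (yb - t) * torch.exp(-s)], dim=1)
```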
I am very grateful for your GitHub source. |
Maybe you can check there.
Hello.
I would like to open this issue to talk about this project. I am also interested in developing it, and it would be great to share information, since the paper doesn't give much detail about the implementation and official code is not available.
If you agree with this initiative, we could first simplify the project to use Wide-ResNet50 in order to get results comparable with previous research. I would like to start from the beginning of the paper, where it says:
This makes me think that in the implementation we need to use the features after the input layer, layer 1 and layer 2. In this way table 6 makes sense.
But I cannot imagine how to concatenate this information to make sense of the next …
Depending on which part you read, it seems that either just one feature map or 3 are taken.