
Adapting to Different Image Size #7

Closed
tjdurant opened this issue Feb 6, 2018 · 8 comments

Comments

@tjdurant

tjdurant commented Feb 6, 2018

Hello,

Our team is interested in testing an implementation of the mean-teacher Resnet in PyTorch for a few image classification problems we are working on.

However, we are having difficulty adapting the network to our image dimensions.

If I resize our images to 32x32, it runs without error. But if I change the size to anything else, I get:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.5/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/mnt/data/nihxr/emo/mean-teacher/pytorch/experiments/rbc_test.py", line 76, in <module>
    run(**run_params)
  File "/mnt/data/nihxr/emo/mean-teacher/pytorch/experiments/rbc_test.py", line 71, in run
    main.main(context)
  File "/mnt/data/nihxr/emo/mean-teacher/pytorch/main.py", line 97, in main
    train(train_loader, model, ema_model, optimizer, epoch, training_log)
  File "/mnt/data/nihxr/emo/mean-teacher/pytorch/main.py", line 225, in train
    ema_model_out = ema_model(ema_input_var)
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 68, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 78, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 67, in parallel_apply
    raise output
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/parallel/parallel_apply.py", line 42, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/mnt/data/nihxr/emo/mean-teacher/pytorch/mean_teacher/architectures.py", line 158, in forward
    x = self.layer3(x)
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/container.py", line 67, in forward
    input = module(input)
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/mnt/data/nihxr/emo/mean-teacher/pytorch/mean_teacher/architectures.py", line 255, in forward
    residual = self.downsample(x)
  File "/opt/conda/lib/python3.5/site-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/mnt/data/nihxr/emo/mean-teacher/pytorch/mean_teacher/architectures.py", line 302, in forward
    x[:, :, 1::2, 1::2]), dim=1)
RuntimeError: inconsistent tensor sizes at /opt/conda/conda-bld/pytorch_1512382878663/work/torch/lib/THC/generic/THCTensorMath.cu:157

Which makes sense. We're just a little unfamiliar with PyTorch and, speaking for myself, ResNet. So I thought I would post this question while I keep looking into it, in case someone can offer a hint that is obvious to you but not easy for us to find.
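For reference, here is a minimal sketch of what I think is going on (my assumption): with an odd-sized feature map, the two strided slices that the downsample concatenates end up with different shapes:

import torch

# e.g. a 70x70 input can leave the network with an odd-sized feature map
x = torch.randn(1, 16, 35, 35)

a = x[:, :, ::2, ::2]    # shape (1, 16, 18, 18)
b = x[:, :, 1::2, 1::2]  # shape (1, 16, 17, 17)

# torch.cat requires matching sizes in all non-concatenated dimensions,
# so this raises the "inconsistent tensor sizes" error above
torch.cat((a, b), dim=1)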

Thank you in advance,
Tommy

@tarvaina
Contributor

tarvaina commented Feb 7, 2018

Hi,

The implementation contains two different architectures: cifar_shakeshake26 and resnext152. The former was used for CIFAR-10 experiments and the latter for ImageNet experiments in the paper. They expect different image sizes.

The architectures are defined in pytorch/mean_teacher/architectures.py and can be selected with the --arch command-line option. E.g., to train on ImageNet you would use something like this:

python main.py \
    --dataset imagenet \
    --arch resnext152 \
    --consistency 10.0 \
    --consistency-rampup 5 \
    --epochs 60

See experiments/imagenet_valid.py for a more complete set of good hyperparameters.

If I recall correctly, the resnext152 architecture works with image sizes equal to or larger than 224x224. One way to tweak it for different image sizes is to change its first convolutional layer in architectures.py to something that suits your needs. You may also want to reduce the number of layers at first, since the 152-layer model may take days or weeks to train.
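For example (a sketch of the idea, not the exact layer definitions in the repo):

import torch.nn as nn

# The standard ImageNet stem downsamples aggressively before the first block:
imagenet_stem = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# A 3x3/stride-1 stem keeps more spatial resolution for smaller inputs,
# e.g. 70x70; the rest of the network can stay unchanged as long as the
# later feature maps keep sizes the downsampling steps can handle.
small_input_stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)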

Alternatively, you can add your own architectures in architectures.py. It should also be fairly easy to plug in architectures from torchvision.models, for example:
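A rough sketch of what that could look like (this assumes the export registration decorator used elsewhere in architectures.py; the name tv_resnet18 is made up):

import torchvision.models

@export
def tv_resnet18(pretrained=False, num_classes=10, **kwargs):
    # Hypothetical wrapper: returns torchvision's ResNet-18 as-is.
    # It produces a single output, so it cannot be combined with
    # --logit-distance-cost, which expects two outputs per forward pass.
    return torchvision.models.resnet18(pretrained=pretrained, num_classes=num_classes)

You could then train with --arch tv_resnet18.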

I hope this helps.

@tarvaina tarvaina closed this as completed Feb 7, 2018
@tjdurant
Author

tjdurant commented Feb 20, 2018

@tarvaina

Thank you very much for the quick reply.

I tried adding in other models, and I am getting this error:

INFO:main:--- training epoch in 13.531290054321289 seconds ---
INFO:main:Evaluating the primary model:
Traceback (most recent call last):
  File "main.py", line 421, in <module>
    main(RunContext(__file__, 0))
  File "main.py", line 103, in main
    prec1 = validate(eval_loader, model, validation_log, global_step, epoch + 1)
  File "main.py", line 331, in validate
    output1, output2 = model(input_var)
ValueError: too many values to unpack (expected 2)

For more context: one of our classification problems involves 70x70 images with 10 classes; the other involves 224x224 images and is a binary classification problem. So with this example the mean-teacher implementation errors out when it attempts to calculate Prec@5, which is understandable. But I don't understand why model(input_var) is expected to return two values, or why changing the model architecture causes it to fail.

@tarvaina
Contributor

tarvaina commented Feb 21, 2018

Whoops, I think you stepped on a bug. Thanks for letting me know!

The code was meant to support cases where the model returns either one or two outputs. This is done on lines 225-235 of the train function but apparently I forgot to add support for one output in the validate function.

I'm not on my development machine, so I cannot fix it today. A quick fix for your case would be to change

output1, output2 = model(input_var)
softmax1, softmax2 = F.softmax(output1, dim=1), F.softmax(output2, dim=1)

to

output1 = model(input_var)
softmax1 = F.softmax(output1, dim=1)

on lines 327-328 of the validate function, since validate does not actually use the second output for anything. A more complete solution is to accept either one or two outputs.
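A rough sketch of that (same names as in main.py; untested):

model_out = model(input_var)
if isinstance(model_out, tuple):
    output1, _ = model_out   # two-headed model, as with --logit-distance-cost
else:
    output1 = model_out      # single-output model
softmax1 = F.softmax(output1, dim=1)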

Two outputs are needed when the --logit-distance-cost parameter is used (Figure 4(e) and Section 3.4 in the paper). It's a useful but not strictly necessary trick for making learning more robust early in training.

@tjdurant
Author

@tarvaina

Thank you very much for the quick reply again, and for being so responsive. It looks like a great implementation and we are having fun playing with it.

Since it sounds like you are still actively maintaining it, I thought I would pass along another thing I ran into, in case you are interested.

One of our classification problems involves 224x224 images that are either positive or negative. So, when I put that dataset into resnext152, I get:

Traceback (most recent call last):
  File "main.py", line 420, in <module>
    main(RunContext(__file__, 0))
  File "main.py", line 97, in main
    train(train_loader, model, ema_model, optimizer, epoch, training_log)
  File "main.py", line 270, in train
    prec1, prec5 = accuracy(class_logit.data, target_var.data, topk=(1, 5))
  File "main.py", line 406, in accuracy
    _, pred = output.topk(maxk, 1, True, True)
RuntimeError: invalid argument 5: k not in range for dimension at /opt/conda/conda-bld/pytorch_1512382878663/work/torch/lib/THC/generic/THCTensorTopK.cu:21

Which looks like it is complaining that there are fewer than 5 classes.
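If it's useful, a possible workaround (just an untested sketch on my part) would be to cap k at the number of classes where accuracy is called:

num_classes = class_logit.size(1)
prec1, prec5 = accuracy(class_logit.data, target_var.data,
                        topk=(1, min(5, num_classes)))

Prec@5 would then effectively become Prec@2 for a binary problem.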

So, not really an issue, but I just wanted to pass it along since your response suggested you are still working on this. You might feel it needs handling, or it may not be necessary for your use case. But, just in case! :)

Thank you again for your help so far,
Tommy

@liuajian

TypeError: softmax() got an unexpected keyword argument 'dim'

@JohnnyRisk

Thank you for the awesome implementation. I have been trying to use it for a binary classification task. However, I have come across something I find strange: it appears that softmax1 and softmax2 are not actually used.

softmax1, softmax2 = F.softmax(output1, dim=1), F.softmax(output2, dim=1)
class_loss = class_criterion(output1, target_var) / minibatch_size

Should output1 in fact be softmax1? I am curious because it appears we are passing the raw logits to the class criterion instead of the softmax. Is this correct?

@tarvaina
Contributor

tarvaina commented Aug 3, 2018

Hi Johnny,

Yeah, the first line is superfluous. I must have been using the softmaxes for debugging or something and then forgotten to remove them. Thanks for pointing it out.

The class_criterion is actually a CrossEntropyLoss, which takes logits as its input, so the second line is correct.
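To illustrate with a quick standalone sketch (independent of this repo):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)           # raw model outputs
target = torch.randint(0, 10, (4,))   # class indices

# cross_entropy = log_softmax + nll_loss, so it expects raw logits;
# applying softmax first would normalize twice and skew the loss.
loss_a = F.cross_entropy(logits, target)
loss_b = F.nll_loss(F.log_softmax(logits, dim=1), target)
assert torch.allclose(loss_a, loss_b)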

@JohnnyRisk

Great, that makes sense. Thank you for the quick reply! Again I appreciate the great implementation.
