
Very Low Accuracy When Using Pretrained Model #8978

Closed
jwfromm opened this issue Dec 7, 2017 · 1 comment

Comments


jwfromm commented Dec 7, 2017

Description

Pretrained models don't seem to work well with Gluon, specifically with datasets built using DataLoader and ImageRecordDataset or ImageFolderDataset.

As an example, here I load the ImageNet validation dataset and feed it into AlexNet downloaded from the Gluon model zoo:

import numpy as np
import mxnet as mx
from mxnet import gluon, nd

ctx = mx.gpu()
batch_size = 64

def transformer(data, label):
    # Resize to 224x224, move channels first, and scale to [0, 1]
    data = mx.image.imresize(data, 224, 224)
    data = mx.nd.transpose(data, (2, 0, 1))
    data = data.astype(np.float32)
    return data / 255, label

test_data = gluon.data.DataLoader(
    gluon.data.vision.ImageFolderDataset(root="/data2/imagenet/val/",
                                         transform=transformer),
    batch_size, shuffle=False)

model = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=ctx)

def evaluate_accuracy(data_iterator, net, ctx=mx.cpu()):
    acc = mx.metric.Accuracy()
    for d, l in data_iterator:
        data = d.as_in_context(ctx)
        label = l.as_in_context(ctx)
        output = net(data)
        predictions = nd.argmax(output, axis=1)
        acc.update(preds=predictions, labels=label)
    return acc.get()[1]

evaluate_accuracy(test_data, model, ctx=ctx)
[0.12393999999999999]

The roughly 12% accuracy suggests that the transforms used to train the model don't exactly align with the transforms presented in the Gluon tutorial. It would be nice if an example showing how to do this properly with the new Gluon functions were added.

Environment info (Required)

Python 3.6

jwfromm (Author) commented Dec 7, 2017

The issue turned out to be that both scaling to [0, 1] and ImageNet mean/std normalization are needed to match the PyTorch transforms used to train the model.

def transformer(data, label):
    # Resize to 256x256, then take a 224x224 center crop
    data = mx.image.imresize(data, 256, 256)
    data, _ = mx.image.center_crop(data, (224, 224))
    # Scale to [0, 1] and normalize with the ImageNet mean and std
    data = data.astype(np.float32)
    data = data / 255
    data = mx.image.color_normalize(data,
                                    mean=mx.nd.array([0.485, 0.456, 0.406]),
                                    std=mx.nd.array([0.229, 0.224, 0.225]))
    # HWC -> CHW for the network input
    data = mx.nd.transpose(data, (2, 0, 1))
    return data, label
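
For comparison, the same preprocessing can be written with mxnet.gluon.data.vision.transforms available in later MXNet releases. This is a minimal sketch; the Compose/Resize/CenterCrop/ToTensor/Normalize pipeline and the transform_first call are assumptions about that API rather than something verified in this issue.

from mxnet import gluon
from mxnet.gluon.data.vision import transforms

# ToTensor converts HWC uint8 [0, 255] images to CHW float32 in [0, 1],
# so the explicit /255 and transpose are no longer needed.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

test_data = gluon.data.DataLoader(
    gluon.data.vision.ImageFolderDataset(root="/data2/imagenet/val/").transform_first(transform),
    batch_size=64, shuffle=False)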

@jwfromm jwfromm closed this as completed Dec 7, 2017