Declining test results #4

Closed
xiaobai-marker opened this issue Nov 15, 2021 · 9 comments

Comments

@xiaobai-marker

xiaobai-marker commented Nov 15, 2021

Hello! I used the following code directly in my model:
`transforms.Compose([transforms.TrivialAugmentWide(), transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])`
Afterwards I found that both validation loss and training loss increased on MobileNetV2 and ResNet-50. Am I using it incorrectly? The training set is the flower dataset.

@SamuelGabriel
Contributor

What is the transform of the baseline? What is the complete diff to the baseline?

@xiaobai-marker
Author

My training set is a flower classification dataset and the model is MobileNetV2, initialized from PyTorch's pretrained weights. Before adding the line `transforms.TrivialAugmentWide()` the loss was about 0.9; with TA it is about 1.3.

@xiaobai-marker
Author

Is it wrong to use it in the following way? `transforms.TrivialAugmentWide()`

@SamuelGabriel
Contributor

Are you talking about training or validation loss? Training loss is supposed to go up, as classifying augmented images is generally harder. You should also turn off TrivialAugment during validation. For reference on how to use the torchvision version, see their references at https://github.com/pytorch/vision/tree/main/references/classification, which have TrivialAugment as an option. It is tested there and works for ImageNet. It is generally really hard for me to help with so little information, so if the above does not help, I could have a quick glance over your codebase if you put it online somewhere.
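To tell the two losses apart, it can help to measure validation loss in its own pass. The following is a minimal sketch of that idea; the helper name `evaluate` and its arguments are assumptions made for illustration and are not taken from torchvision or from any existing codebase.

```python
import torch
import torch.nn as nn


def evaluate(model, loader, loss_fn, device):
    """Return (mean loss, accuracy) over `loader` without updating the model.

    `loader` is any iterable of (images, labels) batches. This helper is a
    hypothetical illustration, not an existing library function.
    """
    model.eval()  # disable dropout, use running batch-norm statistics
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            # Weight each batch by its size so a short last batch
            # does not count as much as a full one.
            total_loss += loss_fn(logits, labels).item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen
```

Calling this on a validation loader whose transform contains no random augmentation gives a number that can be compared fairly between a run with TrivialAugment and a run without it.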

@xiaobai-marker
Author

xiaobai-marker commented Nov 18, 2021

```python
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model_v2 import MobileNetV2


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    batch_size = 16
    epochs = 5

    data_transform = {
        "train": transforms.Compose([transforms.TrivialAugmentWide(),
                                     transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.TrivialAugmentWide(),
                                   transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # create model
    net = MobileNetV2(num_classes=5)

    # load pretrain weights
    # download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
    model_weight_path = "./mobilenet_v2.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    pre_weights = torch.load(model_weight_path, map_location=device)

    # delete classifier weights (keep only tensors whose size matches the new model)
    pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}

    # pre_dict = {k: v for k, v in pre_weights.items() if "classifier" not in k}
    missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

    # freeze features weights
    for param in net.features.parameters():
        param.requires_grad = False

    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    best_acc = 0.0
    save_path = './MobileNetV2.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
```

@xiaobai-marker
Author

This is my code. After using TA, the loss value does go up compared to before.

@SamuelGabriel
Contributor

OK, it is indeed the problem I mentioned above. You use TrivialAugment during validation as well, but it should only be used during training.

@xiaobai-marker
Author

xiaobai-marker commented Nov 18, 2021

Yes, I've tried what you suggested; I just forgot to remove TA from the validation transform when I posted the code. But even with TA removed from validation, the loss still goes up compared to training without TA.

@SamuelGabriel
Contributor

One more thing: you are fine-tuning. This was not part of our experiments; we always trained from scratch. Generally, fine-tuning is known to need different, and usually fewer, augmentations. If you still want to try out standard augmentation methods, it might be good to start from something like the torchvision references and then try multiple augmentation methods there.
