
About Checkpoints #9

Closed
WY-2022 opened this issue Apr 12, 2022 · 1 comment
WY-2022 commented Apr 12, 2022

Hi! I have another question. If I just pip install the package and then do:

import torch.nn as nn
from swin_transformer_v2 import SwinTransformerV2, swin_transformer_v2_t

class SWIN(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        self.num_classes = num_classes
        # self.pool = nn.MaxPool2d(2, 2)
        self.encoder: SwinTransformerV2 = swin_transformer_v2_t(in_channels=3,
                                                                window_size=8,
                                                                input_resolution=(1024, 1280),
                                                                sequential_self_attention=False,
                                                                use_checkpoint=True)
        self.p = self.encoder.patch_embedding
        self.encoder0 = self.encoder.stages[0]
        ... ...

How do I use the checkpoints now? And is there a pre-trained model for v2_base?
(Also, when I just run the code above, a weird problem arises: `warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")`.)
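For context on that warning: it is emitted by `torch.utils.checkpoint` when none of the tensors entering a checkpointed segment have `requires_grad=True`, in which case no gradients can flow through the segment. A minimal sketch of the cause and a fix (standalone `nn.Linear` stand-in, not the Swin code itself):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(4, 4)

# A plain input tensor has requires_grad=False; feeding it through
# torch.utils.checkpoint is what triggers
# "None of the inputs have requires_grad=True. Gradients will be None".
x = torch.randn(2, 4)

# Fix: make sure the tensor entering the checkpointed segment requires
# grad (and/or use the non-reentrant implementation).
x.requires_grad_()
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(x.grad is not None)  # True
```

During normal training this is usually harmless noise from the earliest checkpointed block, since the raw input images do not require gradients.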

@ChristophReich1996 (Owner) commented

To load the provided checkpoints you need to initialize the network in the training configuration, load the state dict, and then change the resolution/window size to your needs. Here is an example for the CIFAR10 checkpoint:

import torch
from swin_transformer_v2 import swin_transformer_v2_t, SwinTransformerV2

swin_transformer: SwinTransformerV2 = swin_transformer_v2_t(input_resolution=(32, 32),
                                                            window_size=8,
                                                            sequential_self_attention=False,
                                                            use_checkpoint=True)
swin_transformer.load_state_dict(torch.load("path_to_weights/cifar10_swin_t_best_model_backbone.pt"))
swin_transformer.update_resolution(new_window_size=8, new_input_resolution=(1024, 1280))

Here is an example for the Places365 checkpoint:

import torch
from swin_transformer_v2 import swin_transformer_v2_b, SwinTransformerV2

swin_transformer: SwinTransformerV2 = swin_transformer_v2_b(input_resolution=(256, 256),
                                                            window_size=8,
                                                            sequential_self_attention=False,
                                                            use_checkpoint=True)
swin_transformer.load_state_dict(torch.load("path_to_weights/places365_swin_b_best_model_backbone.pt"))
swin_transformer.update_resolution(new_window_size=8, new_input_resolution=(1024, 1280))

The CIFAR10 checkpoint is for the tiny model and the Places365 checkpoint is for the base model.

Please note that pre-trained ImageNet-1k weights are available in the timm library!
