
Add ViTs #3

Merged
merged 16 commits into main from vits
May 30, 2023
Conversation

@nps1ngh (Contributor) commented May 27, 2023

@moboehle take a look please! :)

(You can ignore the pretrained.py and README.md files.)

I'll merge it if everything looks good and then finish it up as the v0.1.0 release.

@nps1ngh requested a review from moboehle May 27, 2023 19:44
README.md (outdated; resolved)
Comment on lines 84 to 183
]

baseline = {
f"{name}": update_default(
dict(
data=dict(
batch_size=DEFAULT_BATCH_SIZE
if "_l_" not in name and "simple_vit_b" not in name
else DEFAULT_BATCH_SIZE // 2,
train_transform=ImageNetClassificationPresetTrain(
crop_size=DEFAULT_CROP_SIZE,
auto_augment_policy="ra",
ra_magnitude=10,
is_bcos=False,
),
test_transform=ImageNetClassificationPresetEval(
crop_size=DEFAULT_CROP_SIZE,
is_bcos=False,
),
),
model=dict(
is_bcos=False,
name=name,
args=dict(
# linear_layer and conv2d_layer set by model.py
norm_layer=nn.LayerNorm,
norm2d_layer=DetachableGNLayerNorm2d,
act_layer=nn.GELU,
channels=3,
),
),
criterion=nn.CrossEntropyLoss(),
test_criterion=nn.CrossEntropyLoss(),
optimizer=OptimizerFactory(
"AdamW",
lr=DEFAULT_LR,
weight_decay=0.0001,
),
use_agc=False,
lr_scheduler=DEFAULT_LR_SCHEDULE
if "_l_" not in name and "simple_vit_b" not in name
else LONG_WARM_SCHEDULE,
trainer=dict(
gradient_clip_val=1.0,
),
)
)
for name in SIMPLE_VIT_ARCHS
}


bcos = {
f"bcos_{name}": update_default(
dict(
data=dict(
batch_size=DEFAULT_BATCH_SIZE
if "_l_" not in name and "simple_vit_b" not in name
else DEFAULT_BATCH_SIZE // 2,
train_transform=ImageNetClassificationPresetTrain(
crop_size=DEFAULT_CROP_SIZE,
auto_augment_policy="ra",
ra_magnitude=10,
is_bcos=True,
),
test_transform=ImageNetClassificationPresetEval(
crop_size=DEFAULT_CROP_SIZE,
is_bcos=True,
),
num_workers=10,
),
model=dict(
is_bcos=True,
name=name,
args=dict(
# linear_layer and conv2d_layer set by model.py
norm_layer=norms.NoBias(norms.DetachableLayerNorm),
act_layer=nn.Identity,
channels=6,
norm2d_layer=norms.NoBias(DetachableGNLayerNorm2d),
),
bcos_args=dict(
b=2,
max_out=1,
),
logit_bias=math.log(1 / (NUM_CLASSES - 1)),
),
criterion=UniformOffLabelsBCEWithLogitsLoss(),
lr_scheduler=DEFAULT_LR_SCHEDULE
if "_l_" not in name and "simple_vit_b" not in name
else LONG_WARM_SCHEDULE,
test_criterion=BinaryCrossEntropyLoss(),
optimizer=OptimizerFactory(
"Adam",
lr=DEFAULT_LR,
),
)
)
for name in SIMPLE_VIT_ARCHS
}
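The excerpt above builds each per-model config by passing an override dict to `update_default`, whose definition lives elsewhere in the repo. A minimal sketch of such a helper, under the assumption that it deep-merges overrides into a shared defaults dict (the `DEFAULTS` values here are illustrative, not the repo's actual settings):

```python
import copy

# Assumed shared settings; illustrative only.
DEFAULTS = {
    "use_agc": True,
    "data": {"num_workers": 8},
}

def update_default(overrides: dict, defaults: dict = DEFAULTS) -> dict:
    """Deep-merge `overrides` over `defaults`, leaving `defaults` untouched."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = update_default(value, merged[key])  # recurse into nested dicts
        else:
            merged[key] = value
    return merged
```

With a helper like this, each entry in `baseline` and `bcos` only needs to spell out what differs from the shared defaults.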

Contributor
What do you think about using something like the following in the configs? It might be a bit cleaner:

is_big_model = lambda model_name: "_l_" in model_name or "simple_vit_b" in model_name

Contributor Author

Good idea, much more readable. I'll push a change. Thanks!
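The agreed-upon refactor might look like the following sketch; the predicate follows the reviewer's suggestion, and the conditionals mirror the config excerpt above (actual names in `SIMPLE_VIT_ARCHS` may differ):

```python
def is_big_model(model_name: str) -> bool:
    """True for models that need the halved batch size and longer warmup."""
    return "_l_" in model_name or "simple_vit_b" in model_name

# The repeated conditionals in both config dicts then collapse to:
#   batch_size=DEFAULT_BATCH_SIZE // 2 if is_big_model(name) else DEFAULT_BATCH_SIZE,
#   lr_scheduler=LONG_WARM_SCHEDULE if is_big_model(name) else DEFAULT_LR_SCHEDULE,
```

Using a named `def` rather than an assigned lambda also keeps the helper PEP 8-friendly and gives it a docstring.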

bcos/models/vit.py (outdated; resolved)
@moboehle (Contributor) left a comment

Looks good to me. Minor questions/comments, what do you think?

nps1ngh and others added 3 commits May 30, 2023 17:55
@nps1ngh (Contributor Author) commented May 30, 2023

Great suggestions and nice catch!

I'll merge and then finish the release. Thanks for taking a look! 😃

@nps1ngh merged commit 0f6389e into main May 30, 2023
@nps1ngh deleted the vits branch May 30, 2023 16:06