Skip to content

Commit

Permalink
fix stl10 datamodule (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
ananyahjha93 committed Nov 16, 2020
1 parent 6409e84 commit b746be0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/self_supervised_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ STL-10 pretrained model::

from pl_bolts.models.self_supervised import SwAV

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/epoch%3D96.ckpt'
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar'
swav = SwAV.load_from_checkpoint(weight_path, strict=False)

swav.freeze()
Expand Down
5 changes: 2 additions & 3 deletions pl_bolts/datamodules/stl10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def __init__(
self.batch_size = batch_size
self.seed = seed
self.num_unlabeled_samples = 100000 - unlabeled_val_split
self.labeled_val_split = 200

@property
def num_classes(self):
Expand Down Expand Up @@ -257,7 +256,7 @@ def train_dataloader_labeled(self):
dataset = STL10(self.data_dir, split='train', download=False, transform=transforms)
train_length = len(dataset)
dataset_train, _ = random_split(dataset,
[train_length - self.labeled_val_split, self.labeled_val_split],
[train_length - self.train_val_split, self.train_val_split],
generator=torch.Generator().manual_seed(self.seed))
loader = DataLoader(
dataset_train,
Expand All @@ -276,7 +275,7 @@ def val_dataloader_labeled(self):
transform=transforms)
labeled_length = len(dataset)
_, labeled_val = random_split(dataset,
[labeled_length - self.labeled_val_split, self.labeled_val_split],
[labeled_length - self.train_val_split, self.train_val_split],
generator=torch.Generator().manual_seed(self.seed))

loader = DataLoader(
Expand Down

0 comments on commit b746be0

Please sign in to comment.