diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 0d78ea47d2..7c97e0b0c2 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
         self.batch_size = batch_size
         self.seed = seed
         self.num_unlabeled_samples = 100000 - unlabeled_val_split
-        self.num_labeled_samples = 5000 - train_val_split
+        self.labeled_val_split = 200
 
     @property
     def num_classes(self):
@@ -240,7 +240,7 @@ def train_dataloader_labeled(self):
         dataset = STL10(self.data_dir, split='train', download=False, transform=transforms)
         train_length = len(dataset)
         dataset_train, _ = random_split(dataset,
-                                        [train_length - self.num_labeled_samples, self.num_labeled_samples],
+                                        [train_length - self.labeled_val_split, self.labeled_val_split],
                                         generator=torch.Generator().manual_seed(self.seed))
         loader = DataLoader(
             dataset_train,
@@ -259,7 +259,7 @@ def val_dataloader_labeled(self):
                         transform=transforms)
         labeled_length = len(dataset)
         _, labeled_val = random_split(dataset,
-                                      [labeled_length - self.num_labeled_samples, self.num_labeled_samples],
+                                      [labeled_length - self.labeled_val_split, self.labeled_val_split],
                                       generator=torch.Generator().manual_seed(self.seed))
 
         loader = DataLoader(
diff --git a/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py b/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py
index 44c9baa9c3..1cb7c99468 100644
--- a/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py
+++ b/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py
@@ -31,8 +31,8 @@ def cli_main():  # pragma: no-cover
 
     elif args.dataset == 'stl10':
         dm = STL10DataModule.from_argparse_args(args)
-        dm.train_dataloader = dm.train_dataloader_mixed
-        dm.val_dataloader = dm.val_dataloader_mixed
+        dm.train_dataloader = dm.train_dataloader_labeled
+        dm.val_dataloader = dm.val_dataloader_labeled
         dm.train_transforms = CPCTrainTransformsSTL10()
         dm.val_transforms = CPCEvalTransformsSTL10()
         dm.test_transforms = CPCEvalTransformsSTL10()
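
For context, a minimal standalone sketch of the split that the renamed labeled_val_split attribute produces in train_dataloader_labeled / val_dataloader_labeled. The root path and seed value below are placeholders, not values from the patch; only the 5000-image STL10 labeled train split and the 200-sample hold-out come from the change itself.

import torch
from torch.utils.data import random_split
from torchvision.datasets import STL10

labeled_val_split = 200   # fixed validation size carved out of the 5000 labeled images
seed = 1234               # placeholder; the datamodule uses self.seed

dataset = STL10('.', split='train', download=True)   # 5000 labeled training images
train_length = len(dataset)

# train_dataloader_labeled keeps the first chunk (train_length - 200 samples);
# val_dataloader_labeled repeats the same seeded split and keeps the 200-sample
# remainder, so the two dataloaders never overlap.
dataset_train, dataset_val = random_split(
    dataset,
    [train_length - labeled_val_split, labeled_val_split],
    generator=torch.Generator().manual_seed(seed),
)
assert len(dataset_train) == 4800 and len(dataset_val) == 200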