From b746be06a7ef59c52fb7911b985a1a73821d1a74 Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Mon, 16 Nov 2020 07:23:13 -0500 Subject: [PATCH] fix stl10 datamodule (#369) --- docs/source/self_supervised_models.rst | 2 +- pl_bolts/datamodules/stl10_datamodule.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/self_supervised_models.rst b/docs/source/self_supervised_models.rst index 5f13caed4c..6856efdbf4 100644 --- a/docs/source/self_supervised_models.rst +++ b/docs/source/self_supervised_models.rst @@ -485,7 +485,7 @@ STL-10 pretrained model:: from pl_bolts.models.self_supervised import SwAV - weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/epoch%3D96.ckpt' + weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/swav_stl10.pth.tar' swav = SwAV.load_from_checkpoint(weight_path, strict=False) swav.freeze() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 2ca58b3268..b1ee3058a8 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -88,7 +88,6 @@ def __init__( self.batch_size = batch_size self.seed = seed self.num_unlabeled_samples = 100000 - unlabeled_val_split - self.labeled_val_split = 200 @property def num_classes(self): @@ -257,7 +256,7 @@ def train_dataloader_labeled(self): dataset = STL10(self.data_dir, split='train', download=False, transform=transforms) train_length = len(dataset) dataset_train, _ = random_split(dataset, - [train_length - self.labeled_val_split, self.labeled_val_split], + [train_length - self.train_val_split, self.train_val_split], generator=torch.Generator().manual_seed(self.seed)) loader = DataLoader( dataset_train, @@ -276,7 +275,7 @@ def val_dataloader_labeled(self): transform=transforms) labeled_length = len(dataset) _, labeled_val = random_split(dataset, - [labeled_length - self.labeled_val_split, self.labeled_val_split], + [labeled_length - self.train_val_split, self.train_val_split], generator=torch.Generator().manual_seed(self.seed)) loader = DataLoader(