-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
How to train Propos on STL-10 dataset #7
Comments
# Hi, please refer to #5 (comment). There are a few changes in my codebase; you can take the code below as a reference:
def set_loader(self):
    """Build the train, unlabeled and memory dataloaders and size the buffers.

    Reads ``self.opt`` (``dataset``, ``batch_size``, ``img_size``,
    ``num_cluster``) and sets the following attributes on ``self``:
    ``train_loader``, ``memory_loader``, ``unlabeled_loader``,
    ``train_sampler``, ``unlabeled_sampler``, ``iter_per_epoch``,
    ``num_classes``, ``num_samples``, ``indices``, ``num_cluster`` and
    ``psedo_labels``.

    ``self.build_dataloader`` is used as returning a
    ``(loader, labels, sampler)`` triple; its exact semantics (in particular
    what ``train=False`` selects) are defined elsewhere -- presumably the
    unlabeled split for STL-10, TODO confirm.
    """
    opt = self.opt
    dataset_name = opt.dataset

    # Per-dataset normalization statistics, with a fixed fallback for
    # datasets that are not registered in ``dataset_dict``.
    if dataset_name in dataset_dict:
        mean, std = dataset_dict[dataset_name]
    else:
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        self.logger.msg_str(f'Dataset {dataset_name} does not exist in dataset_dict,'
                            f' use default normalizations: mean {mean}, std {std}.')
    normalize = transforms.Normalize(mean=mean, std=std, inplace=True)

    train_transform = self.train_transform(normalize)
    self.logger.msg_str('set transforms...')
    self.logger.msg_str(train_transform)

    # The two training loaders share the same augmentation and batch size and
    # differ only in the ``train`` flag passed to ``build_dataloader``.
    self.logger.msg_str('set train and unlabeled dataloaders...')
    train_loader, labels, train_sampler = self.build_dataloader(
        transform=train_transform,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        sampler=True,
        train=True)
    unlabeled_loader, _, unlabeled_sampler = self.build_dataloader(
        transform=train_transform,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
        sampler=True,
        train=False)

    # Deterministic (augmentation-free) transform used by the memory loader.
    test_transform = []
    if 'imagenet' in dataset_name:
        test_transform += [transforms.Resize(256), transforms.CenterCrop(224)]
    test_transform += [
        transforms.Resize(opt.img_size),
        transforms.ToTensor(),
        normalize
    ]
    test_transform = transforms.Compose(test_transform)

    # The test loader is intentionally disabled in this variant:
    # self.logger.msg_str('set test dataloaders...')
    # test_loader = self.build_dataloader(test_transform, train=False, batch_size=opt.batch_size)[0]

    self.logger.msg_str('set memory dataloaders...')
    memory_loader = self.build_dataloader(test_transform, train=True, batch_size=opt.batch_size, sampler=True)[0]

    # self.test_loader = test_loader
    self.train_loader = train_loader
    self.memory_loader = memory_loader
    self.unlabeled_loader = unlabeled_loader
    self.train_sampler = train_sampler
    self.unlabeled_sampler = unlabeled_sampler

    # One "epoch" in ``fit`` is a full pass over both loaders.
    self.iter_per_epoch = len(train_loader) + len(unlabeled_loader)
    self.num_classes = len(np.unique(labels))
    self.num_samples = len(labels)

    # ``indices`` has one slot per sample the train sampler yields;
    # ``psedo_labels`` has one slot per sample in ``labels``.
    self.indices = torch.zeros(len(self.train_sampler), dtype=torch.long).cuda()
    self.num_cluster = self.num_classes if opt.num_cluster is None else opt.num_cluster
    self.psedo_labels = torch.zeros((self.num_samples,)).long().cuda()
    self.logger.msg_str('load {} images...'.format(self.num_samples))
def fit(self):
    """Run the training loop, alternating unlabeled and pseudo-labeled passes.

    Each pass of the outer loop is one epoch of ``self.iter_per_epoch``
    iterations:

    1. every batch of ``self.unlabeled_loader`` goes through
       ``train_unlabeled``;
    2. when ``epoch % opt.reassign == 0``, ``self.psedo_labeling`` is called;
    3. ``self.indices`` is refreshed from ``self.train_sampler``;
    4. every batch of ``self.train_loader`` goes through ``train``, with
       ``adjust_learning_rate`` called first (the learning rate is only
       adjusted in this labeled pass).

    ``n_iter`` is a 1-based global iteration counter that starts after
    ``opt.resume_epoch`` completed epochs. The loop stops at the end of the
    first epoch for which ``n_iter`` exceeds ``opt.epochs * iter_per_epoch``.
    """
    opt = self.opt
    max_iter = opt.epochs * self.iter_per_epoch
    n_iter = self.iter_per_epoch * opt.resume_epoch + 1

    self.progress_bar = tqdm.tqdm(total=max_iter, disable=(self.rank != 0))
    # n_iter is the index of the *next* iteration, so only n_iter - 1
    # iterations are already done; advancing by n_iter would overshoot the
    # bar's total by one at the end of training.
    self.progress_bar.update(n_iter - 1)

    while True:
        epoch = int(n_iter // self.iter_per_epoch + 1)
        self.train_sampler.set_epoch(epoch)
        self.unlabeled_sampler.set_epoch(epoch)

        # Pass 1: unlabeled images, contrastive loss only.
        for inputs in self.unlabeled_loader:
            inputs = convert_to_cuda(inputs)
            self.train_unlabeled(inputs, n_iter)
            self.progress_bar.refresh()
            self.progress_bar.update()
            n_iter += 1

        apply_kmeans = epoch % opt.reassign == 0
        if apply_kmeans:
            self.psedo_labeling(n_iter)

        # Record the sample order for this epoch so ``train`` can map a batch
        # position to dataset indices. Build an integer tensor directly:
        # going through a float32 tensor would corrupt indices above 2**24.
        # NOTE(review): assumes iterating the sampler here yields the same
        # order the loader will use for this epoch -- confirm.
        self.indices.copy_(torch.tensor(list(self.train_sampler), dtype=torch.long))

        # Pass 2: images with pseudo labels.
        for inputs in self.train_loader:
            inputs = convert_to_cuda(inputs)
            self.adjust_learning_rate(n_iter)
            self.train(inputs, n_iter)
            self.progress_bar.refresh()
            self.progress_bar.update()
            n_iter += 1

        # Checkpointing and evaluation are disabled in this variant:
        # if epoch % opt.save_freq == 0:
        #     self.logger.checkpoints(int(epoch))
        # self.test(n_iter)

        # Checked once per epoch; a ``break`` inside the batch loops could
        # not leave the ``while``.
        if n_iter > max_iter:
            break
def train(self, inputs, n_iter):
    """Run one optimisation step on a batch from the pseudo-labeled loader.

    Args:
        inputs: a pair ``(images, labels)``; ``images`` unpacks into the two
            views ``(im_q, im_k)``. The loader's labels are not used.
        n_iter: 1-based global iteration counter.

    Logs ``[contrastive_loss, cluster_loss_batch, q_std]`` through
    ``self.logger.msg``. The cluster loss is added to the optimised loss only
    once ``(n_iter - 1) / iter_per_epoch`` exceeds ``opt.warmup_epochs``.
    """
    opt = self.opt
    images, _ = inputs
    self.byol.train()
    im_q, im_k = images

    # Position of this batch inside the labeled part of the epoch: ``fit``
    # runs the unlabeled loader first, so its length is subtracted from the
    # in-epoch step before converting to a sample offset.
    step_in_epoch = (n_iter - 1) % self.iter_per_epoch - len(self.unlabeled_loader)
    _start = step_in_epoch * opt.batch_size
    indices = self.indices[_start: _start + opt.batch_size]
    psedo_labels = self.psedo_labels[indices]

    # compute loss
    contrastive_loss, cluster_loss_batch, q = self.byol(
        im_q, im_k, psedo_labels)

    loss = contrastive_loss
    if ((n_iter - 1) / self.iter_per_epoch) > opt.warmup_epochs:
        # Out-of-place add on purpose: ``loss += ...`` would modify the
        # tensor that ``contrastive_loss`` also names, so the value logged
        # below would silently become the total loss.
        loss = contrastive_loss + cluster_loss_batch * opt.cluster_loss_weight

    self.optimizer.zero_grad()
    # SGD
    loss.backward()
    self.optimizer.step()

    with torch.no_grad():
        # Mean per-dimension std of q over the batch, logged as a diagnostic.
        q_std = torch.std(q.detach(), dim=0).mean()

    outputs = [contrastive_loss, cluster_loss_batch, q_std]
    self.logger.msg(outputs, n_iter)
def train_unlabeled(self, inputs, n_iter):
    """Run one optimisation step on a batch from the unlabeled loader.

    Args:
        inputs: a pair ``(images, labels)``; ``images`` unpacks into two
            views. The labels are ignored.
        n_iter: global iteration counter, forwarded to the logger.

    Only the first value returned by ``self.byol`` (the contrastive loss) is
    optimised. ``None`` is passed in place of pseudo labels and the second
    returned value is discarded.
    """
    views, _ = inputs
    self.byol.train()
    view_q, view_k = views

    # Forward pass without pseudo labels.
    unlabeled_loss, _, embedding = self.byol(view_q, view_k, None)

    # Plain gradient step on the contrastive loss.
    self.optimizer.zero_grad()
    unlabeled_loss.backward()
    self.optimizer.step()

    with torch.no_grad():
        # Mean per-dimension std of the embedding over the batch.
        embedding_std = torch.std(embedding.detach(), dim=0).mean()

    self.logger.msg([unlabeled_loss, embedding_std], n_iter)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Propos is a charming work, and thank you for sharing your code. I would like to ask how to train Propos on the STL-10 dataset, because I wonder whether the unlabeled images are used for training. I hope you can share the code. Thanks.
The text was updated successfully, but these errors were encountered: