Skip to content

Commit

Permalink
Use correct type and fix documentation (#36)
Browse files Browse the repository at this point in the history
* Use correct type and fix documentation

* Fix PEP8
  • Loading branch information
Frédéric Branchaud-Charron committed Jun 15, 2020
1 parent 76e54c1 commit 4642b32
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,4 @@ Indices and tables
api/pl_bolts.loggers
api/pl_bolts.optimizers
api/pl_bolts.transforms
api/pl_bolts.utils
1 change: 1 addition & 0 deletions docs/source/readme.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# PyTorchLightning Bolts


![CI testing](https://github.com/PyTorchLightning/pytorch-lightning-bolts/workflows/CI%20testing/badge.svg?branch=master)



Expand Down
8 changes: 4 additions & 4 deletions pl_bolts/losses/self_supervised_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ class AmdimNceLoss(nn.Module):
def forward(self, anchor_representations, positive_representations, mask_mat):
"""
Compute the NCE scores for predicting r_src->r_trg.
Input:
r_src : (batch_size, emb_dim)
r_trg : (emb_dim, n_batch * w* h) (ie: nb_feat_vectors x embedding_dim)
mask_mat : (n_batch_gpu, n_batch)
Args:
anchor_representations : (batch_size, emb_dim)
positive_representations : (emb_dim, n_batch * w* h) (ie: nb_feat_vectors x embedding_dim)
mask_mat : (n_batch_gpu, n_batch)
Output:
raw_scores : (n_batch_gpu, n_locs)
Expand Down
2 changes: 0 additions & 2 deletions pl_bolts/metrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ def accuracy(preds, labels):

return acc

result.log('key', val)
result.to_pbar()

def precision_at_k(output, target, top_k=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/models/gans/basic/basic_gan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def discriminator_loss(self, x):

# how well can it label as fake?
fake = torch.zeros(x.size(0), 1)
fake = fake.type_as(fake)
fake = fake.type_as(x)

fake_loss = self.adversarial_loss(
self.discriminator(self.generated_imgs.detach()), fake)
Expand Down
9 changes: 5 additions & 4 deletions pl_bolts/models/self_supervised/resnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def forward(self, x):

return out


class BottleneckBN(nn.Module):
expansion = 4

Expand Down Expand Up @@ -434,12 +435,12 @@ def __init__(self, expansion=1, num_classes=1000, zero_init_residual=False,
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(BottleneckBN, 64*expansion, layers[0])
self.layer2 = self._make_layer(BottleneckBN, 128*expansion, layers[1], stride=2,
self.layer1 = self._make_layer(BottleneckBN, 64 * expansion, layers[0])
self.layer2 = self._make_layer(BottleneckBN, 128 * expansion, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(BottleneckBN, 256*expansion, layers[2], stride=2,
self.layer3 = self._make_layer(BottleneckBN, 256 * expansion, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(BottleneckBN, 512*expansion, layers[3], stride=2,
self.layer4 = self._make_layer(BottleneckBN, 512 * expansion, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * BottleneckBN.expansion * expansion, num_classes)
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/utils/ssl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ def torchvision_ssl_encoder(name, pretrained=False, return_all_feature_maps=Fals
pretrained_model = getattr(resnets, name)(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps)

pretrained_model.fc = Identity()
return pretrained_model
return pretrained_model

0 comments on commit 4642b32

Please sign in to comment.