In [12]:
import torch
import torchvision
import torch.nn.functional as F

In [6]:
# Build a ResNet-18 whose classifier head has 128 outputs: `num_classes`
# sets the out_features of the final `fc` layer. No `pretrained` argument
# is passed here.
resnet18 = torchvision.models.resnet18(num_classes=128)

In [7]:
# Display the module tree to inspect the layer layout.
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [9]:
# Scratch: uncomment to dump every name exported by torchvision.models.
# torchvision.models.__dict__

In [8]:
# Inspect the final fully connected layer; its out_features should equal
# the num_classes=128 passed to the constructor.
resnet18.fc

Linear(in_features=512, out_features=128, bias=True)

In [10]:
# nn.Linear stores its weight as (out_features, in_features), so this is
# expected to be [128, 512].
resnet18.fc.weight.shape

torch.Size([128, 512])

In [11]:
# Toy 4x2 tensor of standard-normal samples, reused by later cells.
# No seed is set, so the values differ between runs.
vec = torch.randn(4, 2)
vec

tensor([[-0.3001,  0.7760],
        [ 2.3347, -1.1546],
        [ 0.3472, -1.0327],
        [-0.1852, -0.0680]])

In [13]:
# L2-normalise along dim 0 (F.normalize defaults to p=2): each *column*
# of `vec` is scaled to unit norm, not each row.
F.normalize(vec, dim=0)

tensor([[-0.1257,  0.4475],
        [ 0.9783, -0.6659],
        [ 0.1455, -0.5956],
        [-0.0776, -0.0392]])

In [15]:
# Manual cross-check of F.normalize(vec, dim=0): divide the first column by
# its own Euclidean norm, sqrt(sum(x^2)).
vec[:, 0] / vec[:, 0].pow(2).sum().sqrt()

tensor([-0.1257,  0.9783,  0.1455, -0.0776])

In [22]:
# Random permutation of the integers 0..7 (no seed set, so it changes
# from run to run).
idx_shuffle = torch.randperm(8)
idx_shuffle

tensor([6, 5, 0, 2, 3, 7, 4, 1])

In [23]:
# argsort of a permutation yields its inverse, so indexing
# idx_shuffle[idx_unshuffle] recovers 0, 1, ..., 7 in order.
idx_unshuffle = idx_shuffle.argsort()
idx_unshuffle

tensor([2, 7, 3, 4, 6, 1, 0, 5])

In [24]:
# View the 8 shuffled indices as 4 rows of 2 and pick row 1 (the second
# chunk).
# NOTE(review): presumably a stand-in for per-worker slicing of a shuffled
# batch -- confirm against the training code this was prototyped for.
idx_this = idx_shuffle.view(4, -1)[1]
idx_this

tensor([0, 2])

In [25]:
# Flat 1-D index vector with 8 entries.
idx_shuffle.shape

torch.Size([8])

In [26]:
# The same 8 entries viewed as 4 rows; -1 lets view() infer the row length (2).
idx_shuffle.view(4, -1).shape

torch.Size([4, 2])

In [27]:
import torchvision.models as models

# Collect every public, lower-case callable exported by torchvision.models
# and return the names in alphabetical order.
model_names = sorted(
    name
    for name, member in vars(models).items()
    if name.islower() and not name.startswith("__") and callable(member)
)

In [28]:
# Show the collected names.
model_names

['alexnet',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'googlenet',
 'inception_v3',
 'mobilenet_v2',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext50_32x4d',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'shufflenet_v2_x1_5',
 'shufflenet_v2_x2_0',
 'squeezenet1_0',
 'squeezenet1_1',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn']

In [29]:
import os

# WORLD_SIZE is absent from this kernel's environment, so indexing
# os.environ directly raised KeyError and halted a Restart & Run All.
# .get() returns None when the variable is unset and its string value
# otherwise.
# NOTE(review): the variable is presumably exported by a distributed
# launcher -- confirm which one this notebook is meant to mirror.
world_size = os.environ.get("WORLD_SIZE")
world_size

In [32]:
# Index of the largest entry in each row of `vec` (top-1 along dim 1).
# The values half of the (values, indices) pair is discarded.
_, pred = torch.topk(vec, k=1, dim=1, largest=True, sorted=True)
pred

tensor([[1],
        [0],
        [0],
        [1]])

In [33]:
# Transpose the (4, 1) column of indices into a (1, 4) row.
pred.t()

tensor([[1, 0, 0, 1]])

In [34]:
# Toy 8x10 tensor of standard-normal samples, used as input to the top-5
# example in a later cell. No seed is set, so the values differ between runs.
a = torch.randn(8, 10)
a

tensor([[-0.1722,  0.4292, -0.0413, -1.8943,  0.1002,  0.9494, -2.7804,  0.5531,
         -0.4744, -0.3473],
        [ 2.5331,  0.4536, -0.7446, -0.7028, -0.7376,  0.4304, -0.9855, -0.3406,
          0.6078,  1.5183],
        [ 0.1459,  2.0748, -0.4021, -0.3458, -1.3079, -0.4880,  0.0301, -0.2488,
         -0.6806, -0.6541],
        [ 2.0599, -1.9154, -0.3832, -0.5650, -0.4352, -0.7920,  0.6852, -0.5222,
         -1.3149,  0.2537],
        [ 0.7563, -0.3223, -0.5883, -0.2502, -0.3496,  0.4060, -0.6947, -0.6795,
         -0.4522,  0.0400],
        [-1.4382, -0.6075, -0.5095,  0.2217,  0.5244,  1.4736, -0.6967,  0.2123,
          1.2263, -0.6189],
        [ 0.1255,  0.8646, -0.2279, -1.2988, -0.9159,  1.4411, -0.3452, -0.0187,
         -1.1880,  0.7530],
        [-0.8967, -0.2529,  0.1857, -0.1258, -0.5474, -1.4879,  0.8720, -0.2896,
         -0.9136, -1.2627]])

In [36]:
# Indices of the 5 largest entries in each row of `a`, ordered from largest
# to smallest, then transposed to shape (5, 8): row k holds the rank-k
# index for every row of `a`.
_, pred = torch.topk(a, k=5, dim=1, largest=True, sorted=True)
pred = pred.transpose(0, 1)
pred

tensor([[5, 0, 1, 0, 0, 5, 5, 6],
        [7, 9, 0, 6, 5, 8, 1, 2],
        [1, 8, 6, 9, 9, 4, 9, 3],
        [4, 1, 7, 2, 3, 3, 0, 1],
        [2, 5, 3, 4, 1, 7, 7, 7]])