In [9]:
import utils 
import time 
import pickle

import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

import matplotlib.pyplot as plt

import gym
from gym import wrappers

In [13]:
import torchvision.models as models
mobilenet = models.mobilenet_v2(pretrained=True)

# features is Sequence object, it is full of Conv and related layers. We don't touch that
# classifier is Sequence object with a nn.Dropout object at the first index
# and a nn.Linear object at the second/last index
num_ftrs = mobilenet.classifier[-1].in_features
# replace nn.Linear
mobilenet.classifier[-1] = nn.Linear(num_ftrs, 6)

t_in = T.randn(64, 3, 256, 256)
print("MobileNet v2:", mobilenet.classifier, mobilenet(t_in).shape, sep="\n", end="\n\n")

MobileNet v2:
Sequential(
  (0): Dropout(p=0.2, inplace=False)
  (1): Linear(in_features=1280, out_features=6, bias=True)
)
torch.Size([64, 6])



In [16]:
mobilenet.features[0]

ConvBNReLU(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU6(inplace=True)
)

In [20]:
# mobilenet.features

In [39]:
import torch


def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to bmm'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))


import torch.utils.benchmark as benchmark
import timeit
x = torch.randn(100000, 2048, device='cuda')

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Ran each twice to show difference before/after warmup
print(f'mul_sum(x, x):  {t0.timeit(1000) / 1000 * 1e6:>5.1f} us')
print(f'mul_sum(x, x):  {t0.timeit(1000) / 1000 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(1000) / 1000 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(1000) / 1000 * 1e6:>5.1f} us')

mul_sum(x, x):   31.4 us
mul_sum(x, x):  3750.3 us
bmm(x, x):      12461.1 us
bmm(x, x):      5955.6 us


Collecting psutil
  Downloading psutil-5.8.0-cp37-cp37m-win_amd64.whl (244 kB)
Installing collected packages: psutil
Successfully installed psutil-5.8.0
