In [2]:
import torch, torchvision

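## torch.nn.DataParallel implements data parallelism at the module level: it splits each input batch across the listed GPUs (device_ids) and gathers the outputs.
## module (Module): the module to be parallelized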

In [None]:
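model = torchvision.models.resnet18()  # the module to parallelize (any nn.Module works)
net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])  # replicate across GPUs 0, 1 and 2 (requires 3 visible GPUs)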
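A minimal sketch of how DataParallel behaves, assuming a small stand-in `nn.Linear` model instead of the notebook's network. It runs on any machine: without `device_ids`, DataParallel uses every visible GPU, and on a CPU-only machine it simply calls the wrapped module.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for any nn.Module

# No device_ids: use all visible GPUs; with no GPUs, DataParallel just forwards to model
net = nn.DataParallel(model)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = net.to(device)  # parameters must live on device_ids[0] before the forward pass

x = torch.randn(8, 10, device=device)  # the batch of 8 is split along dim 0 across GPUs
out = net(x)                            # per-GPU outputs are gathered back on device_ids[0]
print(out.shape)  # torch.Size([8, 2])

# The original model stays reachable as net.module, e.g. for saving its state_dict
assert net.module is model
```

Saving `net.module.state_dict()` rather than `net.state_dict()` avoids the extra `module.` prefix that DataParallel adds to every key.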

In [None]:
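## A state_dict is a Python dictionary that maps each parameter and buffer name of the model to its tensor.
# Load the pretrained checkpoint; pretrained_encoder is the path to the checkpoint file
weights = torch.load(pretrained_encoder)

# The checkpoint stores the encoder's weights under the 'state_dict' key
encoder.load_state_dict(weights['state_dict'])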
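To make the state_dict idea concrete, here is a sketch using a hypothetical tiny encoder (`make_encoder` is illustrative only). It prints the keys, saves a checkpoint with the same `'state_dict'` layout the cell above assumes, and loads it into a fresh copy. An in-memory buffer stands in for the checkpoint file.

```python
import io
import torch
import torch.nn as nn

def make_encoder():
    # Hypothetical tiny encoder used only for illustration
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

encoder = make_encoder()
state = encoder.state_dict()
print(list(state.keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias',
#  '1.running_mean', '1.running_var', '1.num_batches_tracked']

# Save a checkpoint with the weights under a 'state_dict' key, then load it back
buffer = io.BytesIO()
torch.save({"state_dict": state}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

restored = make_encoder()
result = restored.load_state_dict(checkpoint["state_dict"])
print(result.missing_keys, result.unexpected_keys)  # [] []: every key matched
```

Note that the BatchNorm running statistics are buffers, not parameters, yet they are still part of the state_dict.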

## encoder.eval() puts the model in evaluation (inference) mode:

• normalisation layers (e.g. BatchNorm) use their running statistics instead of per-batch statistics

• Dropout layers are deactivated
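A small sketch of the difference, assuming a toy model with one BatchNorm and one Dropout layer: in training mode the output changes from call to call because Dropout samples a new mask, while in eval mode the same input always gives the same output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(p=0.5))
x = torch.randn(16, 4)

model.train()  # training mode: BatchNorm uses batch statistics, Dropout zeroes random units
train_a, train_b = model(x), model(x)

model.eval()   # evaluation mode: BatchNorm uses running statistics, Dropout does nothing
eval_a, eval_b = model(x), model(x)

print(torch.equal(train_a, train_b))  # False: a new dropout mask is sampled on every call
print(torch.equal(eval_a, eval_b))    # True: eval mode is deterministic
print(model.training, model[2].training)  # False False: eval() propagates to submodules
```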

## torch.no_grad() disables gradient tracking in the autograd engine. This reduces memory usage and speeds up computation, which is what you want during inference.
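A minimal sketch of what "disabling gradient tracking" means in practice: outside `torch.no_grad()` the result of an operation on a tensor with `requires_grad=True` carries a `grad_fn`; inside it, nothing is recorded. The last lines show the usual inference pattern combining `eval()` and `no_grad()`, with a toy `nn.Linear` model as a stand-in.

```python
import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)

y = (w * x).sum()
print(y.requires_grad, y.grad_fn is not None)  # True True: the operations were recorded

with torch.no_grad():  # autograd records nothing inside this block
    z = (w * x).sum()
print(z.requires_grad, z.grad_fn)  # False None: no graph, so z cannot be backpropagated

# Typical inference pattern: eval mode plus no gradient tracking
model = torch.nn.Linear(3, 1)
model.eval()
with torch.no_grad():
    out = model(torch.randn(5, 3))
print(out.requires_grad)  # False
```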

# A Gentle Introduction to torch.autograd

## Training a neural network happens in two steps
## Step 1: Forward propagation
## Step 2: Backward propagation
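The two steps can be checked by hand on a one-parameter "network". This is a sketch assuming prediction = w·x with a squared-error loss against a target t; autograd's gradient should match the chain-rule result 2(w·x − t)·x.

```python
import torch

# One-parameter model: prediction = w * x, squared-error loss against target t
w = torch.tensor(2.0, requires_grad=True)
x, t = torch.tensor(3.0), torch.tensor(5.0)

# Step 1: forward propagation computes the prediction and the loss
prediction = w * x             # 6.0
loss = (prediction - t) ** 2   # (6 - 5)^2 = 1.0

# Step 2: backward propagation computes d(loss)/dw via the chain rule
loss.backward()
# Analytically: d(loss)/dw = 2 * (w*x - t) * x = 2 * 1 * 3 = 6
print(w.grad)  # tensor(6.)
```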

- A pretrained ResNet18 model is loaded from torchvision.
- A single random image tensor with 3 channels and a height and width of 64 is created.
- Random labels are created as a tensor of shape (1, 1000), one value per ImageNet class.

In [3]:
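# torchvision >= 0.13 prefers weights=torchvision.models.ResNet18_Weights.DEFAULT over pretrained=True
model = torchvision.models.resnet18(pretrained=True)
data = torch.rand(1, 3, 64, 64)  # one random image: (batch, channels, height, width)
labels = torch.rand(1, 1000)     # random targets, one value per ImageNet class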

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /home/manojkl/.cache/torch/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:05<00:00, 8.68MB/s]


Run the input data through each layer of the model, i.e. the forward pass.

In [4]:
prediction = model(data) # forward pass

- Compute the loss, then backpropagate it by calling .backward() on the loss tensor.
- Autograd then calculates the gradient of the loss with respect to each model parameter and stores it in that parameter’s .grad attribute.

In [6]:
loss = (prediction - labels).sum()
loss.backward() # backward pass
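A short sketch of where the gradients end up, assuming a toy `nn.Linear` model and the same kind of summed-output loss: before `backward()` the `.grad` attributes are `None`; afterwards each one has the shape of its parameter. A second `backward()` adds to the existing gradients instead of replacing them.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)
print(model.weight.grad)  # None: no backward pass has run yet

x = torch.rand(1, 4)
model(x).sum().backward()  # toy loss: the sum of the outputs

# Every parameter now holds a gradient with the same shape as the parameter itself
for name, param in model.named_parameters():
    print(name, tuple(param.shape), tuple(param.grad.shape))
# weight (2, 4) (2, 4)
# bias (2,) (2,)

# A second backward pass accumulates into .grad rather than overwriting it,
# which is why training loops clear gradients between iterations
first = model.weight.grad.clone()
model(x).sum().backward()
print(torch.allclose(model.weight.grad, 2 * first))  # True
```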

- Create an SGD optimizer with a learning rate of 0.01 and momentum of 0.9, registering all of the model's parameters with it.

In [7]:
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

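- Call .step() to perform one gradient-descent update: the optimizer adjusts each registered parameter using the gradient stored in its .grad attribute.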

In [8]:
optim.step() #gradient descent
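To see what `.step()` does with momentum, here is a sketch that runs two SGD steps on a single parameter with a constant gradient and compares the result to the update rule worked out by hand. It assumes PyTorch's default settings (`dampening=0`, `nesterov=False`), where the momentum buffer starts as the first gradient.

```python
import torch

torch.manual_seed(0)
p = torch.nn.Parameter(torch.randn(3))
x = torch.tensor([1.0, 2.0, 3.0])
lr, momentum = 1e-2, 0.9
opt = torch.optim.SGD([p], lr=lr, momentum=momentum)

p0 = p.detach().clone()
for _ in range(2):
    opt.zero_grad()       # clear the gradient from the previous iteration
    loss = (p * x).sum()  # d(loss)/dp = x, a constant gradient g
    loss.backward()
    opt.step()

# By hand: step 1: buf = g             -> p1 = p0 - lr * g
#          step 2: buf = momentum*g + g -> p2 = p1 - lr * (1 + momentum) * g
expected = p0 - lr * x - lr * (1 + momentum) * x
print(torch.allclose(p.detach(), expected))  # True
```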

- We create two tensors a and b with requires_grad=True. This signals to autograd that every operation on them should be tracked.
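A sketch of that a and b setup: build Q = 3a³ − b² from both tensors, backpropagate, and compare the stored gradients with the analytical derivatives ∂Q/∂a = 9a² and ∂Q/∂b = −2b. Because Q is a vector, `backward()` needs an explicit `gradient` argument of ones.

```python
import torch

a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([6.0, 4.0], requires_grad=True)

# Q is built from a and b, so autograd records every operation involved
Q = 3 * a**3 - b**2

# Q is a vector, so pass dQ/dQ explicitly (one per element)
Q.backward(gradient=torch.tensor([1.0, 1.0]))

# Analytically: dQ/da = 9 * a^2 = [36, 81] and dQ/db = -2 * b = [-12, -8]
print(torch.allclose(a.grad, 9 * a.detach() ** 2))  # True
print(torch.allclose(b.grad, -2 * b.detach()))      # True
```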

- tqdm wraps any iterable and displays a progress bar while it is being consumed.

In [11]:
from tqdm import tqdm

In [12]:
for i in tqdm(range(1, 5)):
    print(i)

100%|██████████| 4/4 [00:00<00:00, 12122.27it/s]

1
2
3
4
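A sketch of the more typical use of tqdm: wrapping the loop of a toy training run (hypothetical random regression data, reusing the zero_grad, forward, backward and step pattern from above) and showing the current loss with `set_postfix`. The bar is written to an in-memory buffer (`file=log`) only so the example is self-contained; by default tqdm writes to stderr.

```python
import io
import torch
from tqdm import tqdm

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
x, y = torch.randn(32, 4), torch.randn(32, 1)  # hypothetical toy data

log = io.StringIO()  # capture the bar instead of printing it
losses = []
pbar = tqdm(range(5), desc="epochs", file=log)
for epoch in pbar:
    optim.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optim.step()
    losses.append(loss.item())
    pbar.set_postfix(loss=loss.item())  # show the latest loss next to the bar
pbar.close()
```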



