<a href="https://colab.research.google.com/github/Renan-Domingues/PyTorchRecipes/blob/main/WarmstartingModelWithDifferentParameters.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Warmstarting model using parameters from a different model in PyTorch

Partially loading a model, or loading a partial model, is a common scenario in transfer learning and when training a new, complex model.

Leveraging trained parameters helps to warmstart the training process and will hopefully help the model converge much faster than training from scratch.

# Introduction
Whether we are loading from a partial state_dict (one that is missing some keys) or loading a state_dict with more keys than the model we are loading into, we can set the strict argument to False in load_state_dict() to ignore non-matching keys.
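
As a minimal sketch of both cases, the example below uses two hypothetical toy models (`Small` and `Large` are illustrative names, not part of this recipe). With `strict=False`, `load_state_dict()` returns an object whose `missing_keys` and `unexpected_keys` fields report what was skipped.

```python
import torch
import torch.nn as nn

# Hypothetical toy models, used only to illustrate the strict flag.
class Small(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)

class Large(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)
        self.fc2 = nn.Linear(3, 2)

small, large = Small(), Large()

# Case 1: partial state_dict. The source has no fc2.* entries.
result = large.load_state_dict(small.state_dict(), strict=False)
print(result.missing_keys)     # keys the model expects but the state_dict lacks
print(result.unexpected_keys)  # keys in the state_dict that the model does not have

# Case 2: state_dict with more keys than the model we load into.
result2 = small.load_state_dict(large.state_dict(), strict=False)
print(result2.missing_keys)
print(result2.unexpected_keys)
```

Note that `strict=False` only relaxes key matching. If a key matches but the tensor shapes differ, loading still raises an error.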

We will experiment with warmstarting a model using the parameters of a different model.

# Steps

1. Import all necessary libraries for loading our data
2. Define and initialize the neural network A and B
3. Save model A
4. Load into model B


### 1. Import necessary libraries for loading our data

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # provides F.relu, used in forward()
import torch.optim as optim

### 2. Define and initialize the neural network A and B

We will create two neural networks so that we can load the parameters of one (type A) into the other (type B).

In [2]:
class NetA(nn.Module):
  def __init__(self):
    super(NetA, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

netA = NetA()

class NetB(nn.Module):
  def __init__(self):
    super(NetB, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

netB = NetB()

### 3. Save model A


In [3]:
# Specify a path to save to
PATH = "model.pt"

torch.save(netA.state_dict(), PATH)

### 4. Load into model B

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
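
The renaming step could look like the sketch below. The models `Source` and `Target` and the layer names `hidden` and `fc1` are hypothetical, chosen only to show one way of rewriting keys before loading; they are not part of NetA or NetB.

```python
import torch
import torch.nn as nn

# Hypothetical models: same layer shape, different attribute names.
class Source(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 3)

class Target(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)

source, target = Source(), Target()
state = source.state_dict()  # keys: "hidden.weight", "hidden.bias"

# Rename the "hidden." prefix to "fc1." so the keys match the target model.
renamed = {
    key.replace("hidden.", "fc1.", 1): value
    for key, value in state.items()
}

load_result = target.load_state_dict(renamed, strict=False)
print(load_result)
```

Without the renaming, `strict=False` would silently skip both entries and leave `target.fc1` at its random initialization, so it is worth inspecting the returned result.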

In [5]:
netB.load_state_dict(torch.load(PATH), strict=False)

<All keys matched successfully>
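
To confirm that the parameters were actually copied, one option is to compare the two state_dicts tensor by tensor. The sketch below is self-contained and uses a hypothetical `make_net()` helper as a stand-in for NetA and NetB; it also writes to a temporary directory instead of the `PATH` used above.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_net():
    # Hypothetical stand-in for NetA / NetB: two models with identical layer names.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

net_a, net_b = make_net(), make_net()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")
    torch.save(net_a.state_dict(), path)                   # save model A
    net_b.load_state_dict(torch.load(path), strict=False)  # load into model B

# Every tensor in A's state_dict should now be equal to its counterpart in B.
state_b = net_b.state_dict()
all_equal = all(
    torch.equal(tensor_a, state_b[name])
    for name, tensor_a in net_a.state_dict().items()
)
print(all_equal)
```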