
A module for making weight initialization easier in PyTorch.


o-tawab/Weights-Initializer-pytorch


Weights Initializer for PyTorch Models

This is a class that makes initializing the weights of PyTorch models easier.

How to use

First, a few imports:

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from weight_initializer import Initializer
```

Then, we can define a simple model:

```python
# Simple model: two 5x5 convolutions, each followed by a ReLU
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)   # 1 -> 20 channels
        self.conv2 = nn.Conv2d(20, 20, 5)  # 20 -> 20 channels

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
```
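As a quick sanity check, the sketch below runs a dummy batch through the model. The 28×28 single-channel input is an assumption chosen only for illustration; since each 5×5 convolution has no padding, it shrinks every spatial dimension by 4, so the size goes 28 → 24 → 20.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)   # 1 -> 20 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(20, 20, 5)  # 20 -> 20 channels, 5x5 kernel

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))


net = Model()
# Assumed input: a batch of one grayscale 28x28 image, shaped (N, C, H, W).
x = torch.randn(1, 1, 28, 28)
out = net(x)
# Each unpadded 5x5 conv removes 4 pixels per spatial dim: 28 -> 24 -> 20.
print(out.shape)  # torch.Size([1, 20, 20, 20])
```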

After that, all we need to do is instantiate the model and call the weight initializer. You can pass whatever extra keyword arguments the chosen initialization function expects, such as gain for Xavier initialization, or mean and std for a normal distribution.

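```python
net = Model()  # instantiate the model

# to apply Xavier (Glorot) uniform initialization with the ReLU gain:
# (init.xavier_uniform and init.normal, without the trailing underscore, are deprecated aliases)
Initializer.initialize(model=net, initialization=init.xavier_uniform_, gain=init.calculate_gain('relu'))

# or maybe a normal distribution:
Initializer.initialize(model=net, initialization=init.normal_, mean=0, std=0.2)
```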
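The source of `Initializer` is not reproduced here, so the following is only a hypothetical sketch of how such a helper could work, not the repository's actual implementation. It assumes the helper walks every submodule with `nn.Module.apply`, calls the given in-place initializer on the `weight` of each `nn.Conv2d` or `nn.Linear` layer, forwards any extra keyword arguments, and leaves biases untouched. The function name `initialize_weights` and its `layer_types` parameter are made up for illustration.

```python
import math

import torch.nn as nn
import torch.nn.init as init


def initialize_weights(model, initialization, layer_types=(nn.Conv2d, nn.Linear), **kwargs):
    """Hypothetical helper: call `initialization(weight, **kwargs)` on every
    layer of `model` whose type is in `layer_types`. Biases are not modified."""

    def _init(module):
        if isinstance(module, layer_types):
            initialization(module.weight, **kwargs)

    model.apply(_init)  # visits every submodule, then the model itself
    return model


# Stand-in model with the same conv layers as the Model above.
net = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 20, 5), nn.ReLU())

# Xavier/Glorot uniform with the ReLU gain; calculate_gain('relu') is sqrt(2).
gain = init.calculate_gain('relu')
initialize_weights(net, init.xavier_uniform_, gain=gain)

# For conv1: fan_in = 1*5*5 = 25 and fan_out = 20*5*5 = 500, so every weight
# lies in [-bound, bound] with bound = gain * sqrt(6 / (fan_in + fan_out)).
bound = gain * math.sqrt(6.0 / (25 + 500))
within_bound = net[0].weight.abs().max().item() <= bound + 1e-6  # tolerance for float32 rounding
print(within_bound)  # True

# Normal distribution instead: mean and std are forwarded as keyword arguments.
initialize_weights(net, init.normal_, mean=0.0, std=0.2)
```

Selecting layers by type through `model.apply`, rather than looping over `model.parameters()`, keeps 1-D tensors such as biases away from `xavier_uniform_`, which raises an error for tensors with fewer than two dimensions.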
