# Custom Layers and Activations

In [1]:
# fully connected `nn.Linear` layer functional definition
import torch

def linear(input, weight, bias=None):
    """Apply the affine map ``input @ weight.T + bias`` (functional ``nn.Linear``).

    Args:
        input: Tensor whose last dimension matches ``weight.shape[1]``.
        weight: 2-D weight tensor; it is transposed before the product.
        bias: Optional tensor added to the product (broadcast over the output).

    Returns:
        The transformed tensor.
    """
    weight_t = weight.t()

    # A 2-D input with a bias can use the fused multiply-add kernel,
    # which is marginally faster than a separate matmul followed by an add.
    if bias is not None and input.dim() == 2:
        return torch.addmm(bias, input, weight_t)

    result = input.matmul(weight_t)
    if bias is not None:
        result += bias
    return result

In [9]:
# we derive the `nn.Linear` class from `nn.Module`

import torch.nn as nn
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.nn.functional as F

class Linear(nn.Module):
    """Re-implementation of ``nn.Linear``: ``y = x @ weight.T + bias``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If true (the default, matching ``nn.Linear``), learn an
            additive bias of length ``out_features``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty allocates *uninitialized* memory; reset_parameters()
        # below fills it, so garbage values are never used as weights.
        self.weight = Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            # Registering None keeps ``self.bias`` defined for forward() and
            # lets the parameter bookkeeping know there is no bias.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias the same way ``nn.Linear`` does."""
        import math

        # kaiming_uniform_ with a=sqrt(5) samples U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            # Guard against zero-sized inputs, which would divide by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    # Backward-compatible alias for the original (misspelled) method name.
    reset_paramet = reset_parameters

    def forward(self, input: Tensor) -> Tensor:
        """Apply the affine map; the last dim of ``input`` must be ``in_features``."""
        return F.linear(input, self.weight, self.bias)

In [3]:
# Given an existing Tensor, torch.Tensor() does not copy the data: the new
# object shares the same underlying memory as the source tensor.
original_data = torch.Tensor([1])
new_data = torch.Tensor(original_data)
print("original : {} new : {}".format(original_data, new_data))

# Modify original_data in place; the change is visible through new_data too.
original_data[0] = 2
print("original : {} new : {}".format(original_data, new_data))

original : tensor([1.]) new : tensor([1.])
original : tensor([2.]) new : tensor([2.])


In [4]:
# Given a Python list, torch.Tensor() copies the values, so later edits to
# the list do not affect the tensor.
# NOTE(review): for a NumPy array a copy is only guaranteed when its dtype
# differs from float32; a float32 array may share memory — verify.
original_data = [1]
new_data = torch.Tensor(original_data)
print("original : {} new : {}".format(original_data, new_data))

# Modify the source list; new_data keeps its own copy of the old value.
original_data[0] = 2
print("original : {} new : {}".format(original_data, new_data))

original : [1] new : tensor([1.])
original : [2] new : tensor([1.])


In [8]:
# torch.tensor() always copies the data it is given, so the new tensor is
# independent of its source. (Passing an existing Tensor works, but PyTorch
# warns and recommends sourceTensor.clone().detach() for that case.)
original_data = torch.tensor([1])
new_data = torch.tensor(original_data)
print("original : {} new : {}".format(original_data, new_data))

# Modify the source in place; new_data still holds the old value.
original_data[0] = 2
print("original : {} new : {}".format(original_data, new_data))

original : tensor([1]) new : tensor([1])
original : tensor([2]) new : tensor([1])


  new_data = torch.tensor(original_data)


In [10]:
# we will create a functional version of our complex linear layer
def complex_linear(in_r, in_i, w_r, w_i, b_i, b_r):
    """Complex-valued affine map on split real/imaginary tensors.

    Computes ``(in_r + i*in_i) @ (w_r + i*w_i).T + (b_r + i*b_i)``.

    NOTE: the bias arguments come in the order ``b_i, b_r`` (imaginary
    first); pass them by keyword to avoid mixing them up.

    Returns:
        Tuple ``(out_r, out_i)`` holding the real and imaginary parts.
    """
    wr_t, wi_t = w_r.t(), w_i.t()
    real = in_r.matmul(wr_t) - in_i.matmul(wi_t) + b_r
    imag = in_r.matmul(wi_t) + in_i.matmul(wr_t) + b_i
    return real, imag

In [11]:
# we create our class version of ComplexLinear based on nn.Module
class ComplexLinear(nn.Module):
    """Complex-valued linear layer over separate real/imaginary tensors.

    Holds real and imaginary weight matrices of shape
    ``(out_features, in_features)`` and bias vectors of length
    ``out_features``, all initialised with ``torch.randn``.
    """

    def __init__(self, in_features, out_features):
        # super() must name *this* class; ``super(Linear, self)`` raises
        # TypeError because ComplexLinear is not a subclass of Linear.
        super(ComplexLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weights_r = Parameter(torch.randn(out_features, in_features))
        self.weights_i = Parameter(torch.randn(out_features, in_features))
        self.bias_r = Parameter(torch.randn(out_features))
        self.bias_i = Parameter(torch.randn(out_features))

    def forward(self, in_r, in_i):
        """Return ``(out_r, out_i)`` via the module-level ``complex_linear``."""
        # torch.nn.functional has no ``complex_linear``; use this notebook's
        # helper. Biases are passed by keyword because that helper takes
        # ``b_i`` before ``b_r``.
        return complex_linear(
            in_r, in_i, self.weights_r, self.weights_i,
            b_r=self.bias_r, b_i=self.bias_i,
        )

In [12]:
# we can also build it from PyTorch's existing `nn.Linear` layer

class ComplexLinearSimple(nn.Module):
    """Complex linear layer built from two stock ``nn.Linear`` layers.

    ``fc_r`` and ``fc_i`` act as the real and imaginary weights. Each carries
    its own bias, so the effective complex bias is
    ``(b_r - b_i) + i*(b_r + b_i)`` rather than ``b_r + i*b_i``.
    """

    def __init__(self, in_features, out_features):
        # super() must name this class; ``super(Linear, self)`` raises TypeError.
        super(ComplexLinearSimple, self).__init__()
        # Use PyTorch's nn.Linear explicitly: in this notebook the bare name
        # ``Linear`` refers to the hand-written re-implementation.
        self.fc_r = nn.Linear(in_features, out_features)
        self.fc_i = nn.Linear(in_features, out_features)

    def forward(self, in_r, in_i):
        """Return ``(out_r, out_i)`` for the complex input ``in_r + i*in_i``."""
        out_r = self.fc_r(in_r) - self.fc_i(in_i)
        out_i = self.fc_r(in_i) + self.fc_i(in_r)
        return out_r, out_i

In [14]:
def my_relu(input, thresh=0.0):
    """Thresholded ReLU: keep elements strictly greater than ``thresh``, zero the rest.

    Args:
        input: Input tensor.
        thresh: Cut-off value; elements ``<= thresh`` become 0.

    Returns:
        A tensor with the same shape and dtype as ``input``.
    """
    keep = input > thresh
    zeros = torch.zeros_like(input)
    return torch.where(keep, input, zeros)

In [15]:
class MyReLU(nn.Module):
    """Module wrapper around ``my_relu`` with a configurable threshold.

    Args:
        thresh: Elements of the input that are not greater than this become 0.
    """

    def __init__(self, thresh=0.0):
        super(MyReLU, self).__init__()
        self.thresh = thresh

    def forward(self, input):
        """Apply ``my_relu`` to ``input`` using ``self.thresh``."""
        return my_relu(input, self.thresh)

    # The method was originally misspelled ``foreward``. nn.Module.__call__
    # dispatches only to ``forward``, so calling the module raised
    # NotImplementedError. Keep the old name as an alias for compatibility.
    foreward = forward

In [16]:
# the functional version
# a common way to import the functional package
import torch.nn.functional as F 

class SimpleNet(nn.Module):
    """Two-layer MLP whose hidden ReLU comes from ``torch.nn.functional``.

    Args:
        D_in: Input feature size.
        H: Hidden layer size.
        D_out: Output feature size.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.fc1 = nn.Linear(D_in, H)
        self.fc2 = nn.Linear(H, D_out)

    def forward(self, x):
        # The functional version of ReLU is used here, so no nn.ReLU module
        # has to be registered in __init__.
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

In [17]:
# the class version
class SimpleNet(nn.Module):
    """Two-layer MLP expressed as a single ``nn.Sequential`` container.

    Args:
        D_in: Input feature size.
        H: Hidden layer size.
        D_out: Output feature size.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        layers = [
            nn.Linear(D_in, H),   # net.0
            nn.ReLU(),            # net.1 (module form of ReLU)
            nn.Linear(H, D_out),  # net.2
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

## Custom Activation Example

In [18]:
# Custom activation example - complex ReLU

# functional version
def complex_relu(in_r, in_i):
    """Apply ReLU independently to the real and imaginary parts.

    Returns:
        Tuple ``(relu(in_r), relu(in_i))``.
    """
    return tuple(F.relu(part) for part in (in_r, in_i))


# class version
class ComplexReLU(nn.Module):
    """Parameter-free module wrapper around ``complex_relu``."""

    def __init__(self):
        super().__init__()

    def forward(self, in_r, in_i):
        return complex_relu(in_r, in_i)
    

# Custom Model Architectures

In [19]:
# based on `torchvision.models.alexnet()`

class AlexNet(nn.Module):
    """AlexNet with the layer layout of ``torchvision.models.alexnet``.

    Expects a batch of 3-channel images shaped ``(N, 3, H, W)`` and returns
    logits shaped ``(N, num_classes)``.

    Args:
        num_classes: Number of output logits (1000 = ImageNet, as in torchvision).
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11,
                      stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Fifth conv layer; without it the layer indices and state_dict
            # keys do not match the reference AlexNet / pretrained weights.
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Feature maps are 2-D (N, C, H, W), so a 2-D adaptive pool is needed.
        # It fixes the classifier input at 256*6*6 regardless of image size.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
        

In [20]:
from torch.hub import load_state_dict_from_url
model_urls = {
    'alexnet': 'https://pytorch.tips/alexnet-download',
}

def alexnet(pretrained=True, progress=True, **kwargs):
    """Build an ``AlexNet`` and optionally load weights from ``model_urls``.

    Args:
        pretrained: If true, download the state dict from
            ``model_urls['alexnet']`` and load it into the model.
        progress: Forwarded to ``load_state_dict_from_url`` to show a
            download progress bar.
        **kwargs: Forwarded to the ``AlexNet`` constructor.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    model = AlexNet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['alexnet'], progress=progress)
        model.load_state_dict(state_dict)
    return model

# Custom Loss Function

In [21]:
# Create our own custom function.

def complex_mse_loss(input_r, input_i, target_r, target_i):
    """Mean squared error computed separately for real and imaginary parts.

    Returns:
        Tuple ``(mse_real, mse_imag)`` of 0-dim tensors.
    """
    return (((input_r - target_r) ** 2).mean(),
            ((input_i - target_i) ** 2).mean())


class ComplexMSELoss(nn.Module):
    """Module wrapper for ``complex_mse_loss``.

    Args:
        real_only: If true, compare only the real parts and return a single
            ``F.mse_loss`` value; otherwise return the ``(real, imag)`` pair.
    """

    def __init__(self, real_only=False):
        super(ComplexMSELoss, self).__init__()
        self.real_only = real_only

    # ``forward`` must be indented inside the class. At module level it was a
    # stray function, and calling the module raised NotImplementedError.
    def forward(self, input_r, input_i, target_r, target_i):
        if self.real_only:
            return F.mse_loss(input_r, target_r)
        return complex_mse_loss(input_r, input_i, target_r, target_i)

In [None]:
# NOTE(review): assumes n_epochs, model, optimizer, criterion and the
# train/val/test dataloaders are defined in an earlier cell.
for epoch in range(n_epochs):

    # Training: put layers such as dropout / batch-norm into training mode.
    model.train()
    for inputs, targets in train_dataloader:
        optimizer.zero_grad()
        output = model(inputs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()

    # Validation: eval mode disables dropout and makes batch-norm use its
    # running statistics; no_grad() avoids building an autograd graph.
    model.eval()
    with torch.no_grad():
        for inputs, targets in val_dataloader:
            output = model(inputs)
            val_loss = criterion(output, targets)

    # Testing: the model is still in eval mode from validation.
    with torch.no_grad():
        for inputs, targets in test_dataloader:
            output = model(inputs)
            test_loss = criterion(output, targets)

In [None]:
# add some additional capabilities to our loops
# this example demonstrates some simple tasks: printing information, reconfiguring a model, and adjusting a hyperparameter in the middle of training.

import torch.optim as optim  # needed for the mid-training optimizer swap below

# NOTE(review): assumes n_epochs, model, optimizer, criterion and the
# train/val/test dataloaders are defined in an earlier cell.
for epoch in range(n_epochs):

    # Printing information: running sums of per-batch losses for this epoch.
    total_train_loss = 0.0
    total_val_loss = 0.0

    # Adjusting a hyperparameter mid-training: halfway through, switch to
    # SGD with a small learning rate to fine-tune the parameters. Compare
    # against n_epochs (not epoch) so this fires exactly once, at the midpoint.
    if epoch == n_epochs // 2:
        optimizer = optim.SGD(model.parameters(), lr=0.001)

    # Reconfiguring the model: training mode for the training pass.
    model.train()
    for inputs, targets in train_dataloader:
        optimizer.zero_grad()
        output = model(inputs)
        train_loss = criterion(output, targets)
        train_loss.backward()
        optimizer.step()
        # .item() yields a Python float; summing the tensor itself would keep
        # every batch's autograd graph alive for the whole epoch.
        total_train_loss += train_loss.item()

    # Reconfiguring the model: evaluation mode for validation.
    model.eval()
    with torch.no_grad():
        for inputs, targets in val_dataloader:
            output = model(inputs)
            val_loss = criterion(output, targets)
            total_val_loss += val_loss.item()
    print("""Epoch: {}
          Train Loss: {}
          Val Loss: {}""".format(
              epoch, total_train_loss, total_val_loss
          ))

    # Testing: the model is still in eval mode from validation.
    with torch.no_grad():
        for inputs, targets in test_dataloader:
            output = model(inputs)
            test_loss = criterion(output, targets)