In [1]:
import torch
import math
import os

# This notebook appears to live one level below the directory containing the
# torch_vr package (the original cell unconditionally chdir'd to '..' before
# importing it). Only step up when torch_vr is not already visible, so that
# re-running this cell does not keep climbing the directory tree.
if not os.path.isdir('torch_vr'):
    os.chdir('..')
from torch_vr.prev_grads_layers import PrevGradConv2dLayer

In [2]:
# parameters of test
N = 1000  # number of synthetic samples (rows of X / y)

# conv layer hyper-parameters; stride, padding and dilation apply to both spatial dims
c_in = 3
h_in = 16
w_in = 12
c_out = 5
kernel_size = (3,3)
stride = 1
padding = 0
dilation = 1

# Spatial size of the conv output (output-shape formula from the torch.nn.Conv2d docs).
# Integer floor division keeps everything exact ints.
h_out = (h_in + 2 * padding - dilation * (kernel_size[0] - 1) - 1) // stride + 1
w_out = (w_in + 2 * padding - dilation * (kernel_size[1] - 1) - 1) // stride + 1
# d_in is the input size of the *linear* layer: the flattened conv output.
d_in = h_out * w_out * c_out
d_out = 10  # number of classes

batch_size = 74
iterations = 10
# Double precision: an absolute tolerance of 1e-10 is far below float32 resolution.
# (Replaces the deprecated torch.set_default_tensor_type(torch.DoubleTensor).)
torch.set_default_dtype(torch.float64)
threshold = 1e-10

# Seed the global RNG so the data, labels, weights and sampled mini-batches are
# reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed)

In [3]:
# Synthetic data set: N standard-normal inputs of shape (c_in, h_in, w_in), and
# N integer labels drawn uniformly, with replacement, from the d_out classes.
X = torch.randn((N, c_in, h_in, w_in))
class_weights = torch.ones(d_out)
y = torch.multinomial(class_weights, num_samples=N, replacement=True)

# defines the model
class Conv2d(torch.nn.Module):
    """Conv -> ReLU -> flatten -> linear -> log-softmax network used by the tests.

    Layer sizes are read from the notebook-level globals c_in, c_out,
    kernel_size, stride, padding, dilation, d_in and d_out. forward() stores
    every intermediate tensor on the instance (A0, Z1, A1, Z2, A2) and keeps the
    gradient of the conv output Z1. After backward(), model.Z1.grad and model.A0
    can therefore be handed to PrevGradConv2dLayer.update().
    """

    def __init__(self):
        super().__init__()
        # The conv layer is created before the linear layer; this fixes the
        # order in which the default initialisers draw from the RNG.
        self.conv2d = torch.nn.Conv2d(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
        )
        self.lin = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=True)

        # Cached activations; filled in by forward().
        self.A0 = self.Z1 = self.A1 = self.Z2 = self.A2 = None

    def forward(self, X):
        """Run the network on X, caching each intermediate, and return log-probabilities."""
        self.A0 = X
        conv_out = self.conv2d(X)
        # conv_out is a non-leaf tensor; retain_grad() makes autograd keep its .grad.
        conv_out.retain_grad()
        self.Z1 = conv_out
        self.A1 = torch.relu(conv_out).flatten(start_dim=1)
        self.Z2 = self.lin(self.A1)
        self.A2 = self.Z2.log_softmax(dim=1)
        return self.A2
    
# Instantiate the network, then the loss. reduction='sum' (instead of the default
# mean) makes the batch loss the plain sum of the per-sample losses. Its gradient
# is therefore the sum of the per-sample gradients, which is what the weighted
# test compares batch_gradient() against.
model = Conv2d()
loss_func = torch.nn.NLLLoss(reduction='sum')

In [4]:
# Initialise the PrevGradConv2dLayer store for the model's conv layer, passing the
# data-set size N. rank=5 is forwarded as-is (presumably the rank of the layer's
# stored per-sample representation — confirm in torch_vr.prev_grads_layers).
conv_layer = model.conv2d
prev_grads = PrevGradConv2dLayer(N, conv_layer, rank=5)

In [5]:
# tests the batch_gradient functionality
for i in range(iterations):
    # generates the gradients
    model.zero_grad()
    indices = torch.multinomial(torch.ones(N), batch_size, replacement=False)
    loss = loss_func(model(X[indices]), y[indices])
    loss.backward()
    
    # stores generated gradients
    prev_grads.update(model.Z1.grad, model.A0, indices)
    
    # compares the stored gradients with the actual gradients
    dW, db = prev_grads.batch_gradient(indices)
    max_error_dW = torch.max(dW - model.conv2d.weight.grad)
    max_error_db = torch.max(db - model.conv2d.bias.grad)
    if max(max_error_dW, max_error_db) > threshold:
        print('Failed.')
        break
print('Succeeded.')

Succeeded.


In [6]:
# tests the individual gradients functionality
for i in range(iterations):
    # generates the gradients
    model.zero_grad()
    indices = torch.multinomial(torch.ones(N), batch_size, replacement=False)
    loss = loss_func(model(X[indices]), y[indices])
    loss.backward()
    model.zero_grad()
    
    # stores generated gradients
    prev_grads.update(model.Z1.grad, model.A0, indices)    
    
    # generates the individual gradients naively
    ind_grads_dW = torch.zeros(batch_size, c_out, c_in, kernel_size[0], kernel_size[1])
    ind_grads_db = torch.zeros(batch_size, c_out)
    for j in range(batch_size):
        loss = loss_func(model(X[indices[j]:(indices[j] + 1)]), y[indices[j]:(indices[j] + 1)])
        loss.backward()
        ind_grads_dW[j] = model.conv2d.weight.grad
        ind_grads_db[j] = model.conv2d.bias.grad
        model.zero_grad()
        
    # compares the stored individual gradients with the actual individual gradients
    dW, db = prev_grads.individual_gradients(indices)
    max_error_dW = torch.max(dW - ind_grads_dW)
    max_error_db = torch.max(db - ind_grads_db)
    if max(max_error_dW, max_error_db) > threshold:
        print('Failed.')
        break
print('Succeeded.')

Succeeded.


In [7]:
# tests the weighted version of batch_gradient and individual gradients functionalities.
for i in range(iterations):
    # generates one random weight per sample (either sign, scaled up by 100)
    weights = 100*torch.randn(batch_size)

    # generates the gradients on a mini-batch of distinct indices
    model.zero_grad()
    indices = torch.multinomial(torch.ones(N), batch_size, replacement=False)
    loss = loss_func(model(X[indices]), y[indices])
    loss.backward()
    # zero_grad() only clears parameter gradients, so model.Z1.grad survives for
    # update(); clearing here keeps the naive per-sample backward passes clean.
    model.zero_grad()

    # stores generated gradients
    prev_grads.update(model.Z1.grad, model.A0, indices)

    # generates the weighted individual gradients naively, one sample at a time
    ind_grads_dW = torch.zeros(batch_size, c_out, c_in, kernel_size[0], kernel_size[1])
    ind_grads_db = torch.zeros(batch_size, c_out)
    for j in range(batch_size):
        sample = indices[j:j + 1]  # 1-element index tensor keeps a batch dim of 1
        loss = loss_func(model(X[sample]), y[sample])
        loss.backward()
        ind_grads_dW[j] = weights[j] * model.conv2d.weight.grad
        ind_grads_db[j] = weights[j] * model.conv2d.bias.grad
        model.zero_grad()

    # compares the stored weighted individual gradients with the naive ones.
    # Absolute error catches deviations of either sign; `not (... <= ...)` also
    # treats a NaN error as a failure.
    dW, db = prev_grads.individual_gradients(indices, weights)
    max_error_dW = (dW - ind_grads_dW).abs().max()
    max_error_db = (db - ind_grads_db).abs().max()
    if not (max_error_dW <= threshold and max_error_db <= threshold):
        print('Failed.')
        break

    # compares the stored weighted batch gradient with the sum of the naive ones
    dW, db = prev_grads.batch_gradient(indices, weights)
    max_error_dW = (dW - ind_grads_dW.sum(dim=0)).abs().max()
    max_error_db = (db - ind_grads_db.sum(dim=0)).abs().max()
    if not (max_error_dW <= threshold and max_error_db <= threshold):
        print('Failed.')
        break
else:
    # for/else: only reached when no iteration failed (no break).
    print('Succeeded.')

Succeeded.
