In [None]:
import torch
import numpy as np
import importlib

import os

# Move two directories up so that the `PyTorch_VR` package is importable.
# Guarded so that re-running this cell does not keep walking further up the
# directory tree (unconditional chdir('..') calls are not idempotent).
if '_PROJECT_ROOT' not in globals():
    _PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
    os.chdir(_PROJECT_ROOT)
import PyTorch_VR.prev_grads as prev_grads

# Fixed seeds so the synthetic data, initial weights and sampled mini-batches
# are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

## Linear Layer

In [2]:
# Synthetic logistic-regression data: 1000 samples with 100 Gaussian features,
# labelled by thresholding a random linear model's sigmoid output at 0.5.
X = torch.randn(1000, 100)
beta = torch.randn(100)
bias = torch.randn(1)
y = (torch.sigmoid(X @ beta + bias) > 0.5).float()


class Logistic_regression(torch.nn.Module):
    """Single linear layer followed by a sigmoid (logistic regression).

    forward() caches the layer input (A0), the linear pre-activation (Z1, whose
    gradient is retained after backward) and the output probabilities (A1).
    """

    def __init__(self, in_features=100):
        """
        Args:
            in_features: number of input features (default 100, matching X).
        """
        super().__init__()
        self.lin = torch.nn.Linear(in_features, 1, bias=True)

        # Cached tensors, populated by forward(); initialized here so the
        # attributes exist before the first forward pass.
        self.A0 = None
        self.Z1 = None
        self.A1 = None

    def forward(self, X):
        """Return sigmoid probabilities, one per sample.

        Assumes X is 2-D (batch, in_features); the output then has shape (batch,).
        """
        self.A0 = X
        self.Z1 = self.lin(self.A0)
        # Keep dLoss/dZ1 after backward(); it is what the PrevGrad objects consume.
        self.Z1.retain_grad()
        self.A1 = torch.sigmoid(self.Z1.squeeze(1))
        return self.A1


# Summed (not averaged) binary cross-entropy: the loss, and hence its gradient,
# is a sum over samples.
loss_func = torch.nn.BCELoss(reduction='sum')
model = Logistic_regression()

In [None]:
# One full-batch forward/backward pass makes the pre-activation gradient
# (model.Z1.grad) and the layer input (model.A0) available for every sample;
# these seed the stored-gradient object for the linear layer.
model.zero_grad()
loss = loss_func(input=model(X), target=y)
loss.backward()
prev_grads_linear = prev_grads.layers.PrevGradLinearLayer(
    model.Z1.grad,
    model.A0,
    bias=True,
)

In [None]:
# Checks PrevGradLinearLayer against gradients computed directly by autograd.
# Errors are the maximum *absolute* difference: a signed max would let
# arbitrarily large negative discrepancies pass the threshold.
threshold = 1e-04
n_samples = len(X)
batch_size = 74

# test 1: batch_gradient over all samples matches the full-data gradient
# computed when prev_grads_linear was initialized.
dW, db = prev_grads_linear.batch_gradient(torch.arange(n_samples))
max_error_dW = torch.max(torch.abs(dW - model.lin.weight.grad))
max_error_db = torch.max(torch.abs(db - model.lin.bias.grad))
if max(max_error_dW, max_error_db) < threshold:
    print('test 1: Success.')
else:
    print('test 1: Failed.')

# test 2: after updating the stored quantities with a fresh mini-batch,
# batch_gradient on that mini-batch matches autograd's mini-batch gradient.
model.zero_grad()
indices = np.random.choice(n_samples, batch_size, replace=False)
loss = loss_func(model(X[indices]), y[indices])
loss.backward()
prev_grads_linear.update(model.Z1.grad, model.A0, indices)

dW, db = prev_grads_linear.batch_gradient(indices)
max_error_dW = torch.max(torch.abs(dW - model.lin.weight.grad))
max_error_db = torch.max(torch.abs(db - model.lin.bias.grad))
if max(max_error_dW, max_error_db) < threshold:
    print('test 2: Success.')
else:
    print('test 2: Failed.')

# test 3: individual_gradients matches per-sample autograd gradients.
model.zero_grad()
ind_grads_dW = torch.zeros(len(indices), *model.lin.weight.shape)
ind_grads_db = torch.zeros(len(indices), *model.lin.bias.shape)
for i, j in enumerate(indices):
    loss = loss_func(model(X[j:j + 1]), y[j:j + 1])
    loss.backward()
    ind_grads_dW[i] = model.lin.weight.grad
    ind_grads_db[i] = model.lin.bias.grad
    model.zero_grad()

dW, db = prev_grads_linear.individual_gradients(indices)
max_error_dW = torch.max(torch.abs(dW - ind_grads_dW))
max_error_db = torch.max(torch.abs(db - ind_grads_db))
if max(max_error_dW, max_error_db) < threshold:
    print('test 3: Success.')
else:
    print('test 3: Failed.')

# test 4: the weighted variants match the weighted sum / weighted stack of the
# per-sample gradients from test 3.
model.zero_grad()
weights = torch.randn(len(indices))
# Build new tensors; the per-sample gradients from test 3 are not aliased or
# overwritten in place.
weighted_ind_grads_dW = weights.view(-1, 1, 1) * ind_grads_dW
weighted_ind_grads_db = weights.view(-1, 1) * ind_grads_db
weighted_batch_gradient_dW = weighted_ind_grads_dW.sum(dim=0)
weighted_batch_gradient_db = weighted_ind_grads_db.sum(dim=0)

dW, db = prev_grads_linear.batch_gradient(indices, weights)
dWs, dbs = prev_grads_linear.individual_gradients(indices, weights)
max_error_dW = torch.max(torch.abs(dW - weighted_batch_gradient_dW))
max_error_db = torch.max(torch.abs(db - weighted_batch_gradient_db))
max_error_dWs = torch.max(torch.abs(dWs - weighted_ind_grads_dW))
max_error_dbs = torch.max(torch.abs(dbs - weighted_ind_grads_db))
if max(max_error_dW, max_error_db, max_error_dWs, max_error_dbs) < threshold:
    print('test 4: Success.')
else:
    print('test 4: Failed.')

## Conv2D Layer

In [9]:
# Synthetic image data: 1000 random 3x16x16 inputs with coin-flip labels.
X = torch.randn(1000, 3, 16, 16)
y = (torch.rand(1000) > 0.5).float()


class Conv2D(torch.nn.Module):
    """Conv(3->5, 3x3) -> ReLU -> flatten -> Linear -> sigmoid binary classifier.

    forward() caches the input (A0), the conv pre-activation (Z1, whose gradient
    is retained after backward), the flattened activation (A1), the linear
    output (Z2) and the output probabilities (A2).
    """

    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(3, 5, (3, 3), bias=True)
        # A 3x3 convolution (stride 1, no padding) maps 16x16 inputs to 14x14,
        # so the flattened feature size is 5 * 14 * 14 = 980.
        self.lin = torch.nn.Linear(5 * 14 * 14, 1, bias=True)

        # Cached tensors, populated by forward(); initialized here so the
        # attributes exist before the first forward pass.
        self.A0 = None
        self.Z1 = None
        self.A1 = None
        self.Z2 = None
        self.A2 = None

    def forward(self, X):
        """Return sigmoid probabilities of shape (batch,).

        Assumes X is (batch, 3, 16, 16): the Linear layer's input size is fixed
        to the flattened conv output for that spatial size.
        """
        self.A0 = X
        self.Z1 = self.conv2d(self.A0)
        # Keep dLoss/dZ1 after backward(); it is what PrevGradConv2DLayer consumes.
        self.Z1.retain_grad()
        self.A1 = torch.nn.functional.relu(self.Z1).flatten(1, -1)
        self.Z2 = self.lin(self.A1)
        self.A2 = torch.sigmoid(self.Z2.squeeze(1))
        return self.A2


# Summed (not averaged) binary cross-entropy: the loss, and hence its gradient,
# is a sum over samples.
loss_func = torch.nn.BCELoss(reduction='sum')
model = Conv2D()

### Uncompressed

In [None]:
# One full-batch forward/backward pass makes the conv pre-activation gradient
# (model.Z1.grad) and the conv input (model.A0) available for every sample.
model.zero_grad()
loss = loss_func(model(X), y)
loss.backward()
# The leading positional arguments are read from the data and the model, so
# they cannot drift from the model's definition. With the current model they
# evaluate to (1000, (3, 3), 5, 3). Presumably these are (n_samples,
# kernel_size, out_channels, in_channels); TODO confirm against
# PrevGradConv2DLayer.
prev_grads_conv2d = prev_grads.layers.PrevGradConv2DLayer(
    len(X), model.conv2d.kernel_size,
    model.conv2d.out_channels, model.conv2d.in_channels,
    model.Z1.grad, model.A0,
)

In [None]:
# Checks the uncompressed PrevGradConv2DLayer against gradients computed
# directly by autograd. Errors are the maximum *absolute* difference: a signed
# max would let arbitrarily large negative discrepancies pass the threshold.
threshold = 1e-04
n_samples = len(X)
batch_size = 74

# test 1: batch_gradient over all samples matches the full-data gradient
# computed when prev_grads_conv2d was initialized.
dW, db = prev_grads_conv2d.batch_gradient(torch.arange(n_samples))
max_error_dW = torch.max(torch.abs(dW - model.conv2d.weight.grad))
max_error_db = torch.max(torch.abs(db - model.conv2d.bias.grad))
if max(max_error_dW, max_error_db) < threshold:
    print('test 1: Success.')
else:
    print('test 1: Failed.')

# test 2: after updating the stored quantities with a fresh mini-batch,
# batch_gradient on that mini-batch matches autograd's mini-batch gradient.
model.zero_grad()
indices = np.random.choice(n_samples, batch_size, replace=False)
loss = loss_func(model(X[indices]), y[indices])
loss.backward()
prev_grads_conv2d.update(model.Z1.grad, model.A0, indices)

dW, db = prev_grads_conv2d.batch_gradient(indices)
max_error_dW = torch.max(torch.abs(dW - model.conv2d.weight.grad))
max_error_db = torch.max(torch.abs(db - model.conv2d.bias.grad))
if max(max_error_dW, max_error_db) < threshold:
    print('test 2: Success.')
else:
    print('test 2: Failed.')

# test 3: individual_gradients matches per-sample autograd gradients.
model.zero_grad()
ind_grads_dW = torch.zeros(len(indices), *model.conv2d.weight.shape)
ind_grads_db = torch.zeros(len(indices), *model.conv2d.bias.shape)
for i, j in enumerate(indices):
    loss = loss_func(model(X[j:j + 1]), y[j:j + 1])
    loss.backward()
    ind_grads_dW[i] = model.conv2d.weight.grad
    ind_grads_db[i] = model.conv2d.bias.grad
    model.zero_grad()

dW, db = prev_grads_conv2d.individual_gradients(indices)
max_error_dW = torch.max(torch.abs(dW - ind_grads_dW))
max_error_db = torch.max(torch.abs(db - ind_grads_db))
if max(max_error_dW, max_error_db) < threshold:
    print('test 3: Success.')
else:
    print('test 3: Failed.')

# test 4: the weighted variants match the weighted sum / weighted stack of the
# per-sample gradients from test 3.
model.zero_grad()
weights = torch.randn(len(indices))
# Build new tensors; the per-sample gradients from test 3 are not aliased or
# overwritten in place. Weight grads are (n, out, in, kh, kw); bias grads are (n, out).
weighted_ind_grads_dW = weights.view(-1, 1, 1, 1, 1) * ind_grads_dW
weighted_ind_grads_db = weights.view(-1, 1) * ind_grads_db
weighted_batch_gradient_dW = weighted_ind_grads_dW.sum(dim=0)
weighted_batch_gradient_db = weighted_ind_grads_db.sum(dim=0)

dW, db = prev_grads_conv2d.batch_gradient(indices, weights)
dWs, dbs = prev_grads_conv2d.individual_gradients(indices, weights)
max_error_dW = torch.max(torch.abs(dW - weighted_batch_gradient_dW))
max_error_db = torch.max(torch.abs(db - weighted_batch_gradient_db))
max_error_dWs = torch.max(torch.abs(dWs - weighted_ind_grads_dW))
max_error_dbs = torch.max(torch.abs(dbs - weighted_ind_grads_db))
if max(max_error_dW, max_error_db, max_error_dWs, max_error_dbs) < threshold:
    print('test 4: Success.')
else:
    print('test 4: Failed.')

### Compressed (`rank=5`)

In [None]:
# One full-batch forward/backward pass makes the conv pre-activation gradient
# (model.Z1.grad) and the conv input (model.A0) available for every sample.
model.zero_grad()
loss = loss_func(model(X), y)
loss.backward()
# Same constructor as the uncompressed case, plus rank=5. The leading
# positional arguments are read from the data and the model. With the current
# model they evaluate to (1000, (3, 3), 5, 3). Presumably these are
# (n_samples, kernel_size, out_channels, in_channels); TODO confirm against
# PrevGradConv2DLayer.
prev_grads_conv2d = prev_grads.layers.PrevGradConv2DLayer(
    len(X), model.conv2d.kernel_size,
    model.conv2d.out_channels, model.conv2d.in_channels,
    model.Z1.grad, model.A0, rank=5,
)

In [None]:
# Checks the compressed (rank=5) PrevGradConv2DLayer against gradients computed
# directly by autograd. Errors are the maximum *absolute* difference: a signed
# max would let arbitrarily large negative discrepancies pass the threshold.
threshold = 1e-04
n_samples = len(X)
batch_size = 74

# test 1: batch_gradient over all samples matches the full-data gradient
# computed when prev_grads_conv2d was initialized.
dW, db = prev_grads_conv2d.batch_gradient(torch.arange(n_samples))
max_error_dW = torch.max(torch.abs(dW - model.conv2d.weight.grad))
max_error_db = torch.max(torch.abs(db - model.conv2d.bias.grad))
if max(max_error_dW, max_error_db) < threshold:
    print('test 1: Success.')
else:
    print('test 1: Failed.')

# test 2: after updating the stored quantities with a fresh mini-batch,
# batch_gradient on that mini-batch matches autograd's mini-batch gradient.
model.zero_grad()
indices = np.random.choice(n_samples, batch_size, replace=False)
loss = loss_func(model(X[indices]), y[indices])
loss.backward()
prev_grads_conv2d.update(model.Z1.grad, model.A0, indices)

dW, db = prev_grads_conv2d.batch_gradient(indices)
max_error_dW = torch.max(torch.abs(dW - model.conv2d.weight.grad))
max_error_db = torch.max(torch.abs(db - model.conv2d.bias.grad))
if max(max_error_dW, max_error_db) < threshold:
    print('test 2: Success.')
else:
    print('test 2: Failed.')

# test 3: individual_gradients matches per-sample autograd gradients.
model.zero_grad()
ind_grads_dW = torch.zeros(len(indices), *model.conv2d.weight.shape)
ind_grads_db = torch.zeros(len(indices), *model.conv2d.bias.shape)
for i, j in enumerate(indices):
    loss = loss_func(model(X[j:j + 1]), y[j:j + 1])
    loss.backward()
    ind_grads_dW[i] = model.conv2d.weight.grad
    ind_grads_db[i] = model.conv2d.bias.grad
    model.zero_grad()

dW, db = prev_grads_conv2d.individual_gradients(indices)
max_error_dW = torch.max(torch.abs(dW - ind_grads_dW))
max_error_db = torch.max(torch.abs(db - ind_grads_db))
if max(max_error_dW, max_error_db) < threshold:
    print('test 3: Success.')
else:
    print('test 3: Failed.')

# test 4: the weighted variants match the weighted sum / weighted stack of the
# per-sample gradients from test 3.
model.zero_grad()
weights = torch.randn(len(indices))
# Build new tensors; the per-sample gradients from test 3 are not aliased or
# overwritten in place. Weight grads are (n, out, in, kh, kw); bias grads are (n, out).
weighted_ind_grads_dW = weights.view(-1, 1, 1, 1, 1) * ind_grads_dW
weighted_ind_grads_db = weights.view(-1, 1) * ind_grads_db
weighted_batch_gradient_dW = weighted_ind_grads_dW.sum(dim=0)
weighted_batch_gradient_db = weighted_ind_grads_db.sum(dim=0)

dW, db = prev_grads_conv2d.batch_gradient(indices, weights)
dWs, dbs = prev_grads_conv2d.individual_gradients(indices, weights)
max_error_dW = torch.max(torch.abs(dW - weighted_batch_gradient_dW))
max_error_db = torch.max(torch.abs(db - weighted_batch_gradient_db))
max_error_dWs = torch.max(torch.abs(dWs - weighted_ind_grads_dW))
max_error_dbs = torch.max(torch.abs(dbs - weighted_ind_grads_db))
if max(max_error_dW, max_error_db, max_error_dWs, max_error_dbs) < threshold:
    print('test 4: Success.')
else:
    print('test 4: Failed.')

## Full network

In [None]:
# Fresh synthetic image data: 1000 random 3x16x16 inputs with coin-flip labels.
X = torch.randn(1000, 3, 16, 16)
y = torch.rand(1000).gt(0.5).float()


# NOTE(review): this redefines (and shadows) the `Conv2D` class from the Conv2D
# section; this version additionally retains the gradient of Z2.
class Conv2D(torch.nn.Module):
    """Conv(3->5, 3x3) -> ReLU -> flatten -> Linear(980->1) -> sigmoid.

    Each forward pass caches the input A0, the pre-activations Z1 (conv) and
    Z2 (linear), and the activations A1/A2. Z1 and Z2 keep their gradients
    after backward so both layers can be fed to PrevGrads.
    """

    def __init__(self):
        super().__init__()
        self.conv2d = torch.nn.Conv2d(3, 5, (3, 3), bias=True)
        self.lin = torch.nn.Linear(980, 1, bias=True)

        # Cached tensors, populated by forward().
        self.A0 = self.Z1 = self.A1 = self.Z2 = self.A2 = None

    def forward(self, X):
        """Return sigmoid probabilities, one per sample."""
        self.A0 = X

        conv_out = self.conv2d(X)
        conv_out.retain_grad()
        self.Z1 = conv_out
        self.A1 = torch.relu(conv_out).flatten(start_dim=1)

        lin_out = self.lin(self.A1)
        lin_out.retain_grad()
        self.Z2 = lin_out
        self.A2 = torch.sigmoid(lin_out.squeeze(1))
        return self.A2


# Summed binary cross-entropy over the batch, then a freshly initialized model.
loss_func = torch.nn.BCELoss(reduction='sum')
model = Conv2D()

In [None]:
# One full-batch forward/backward pass makes, for every sample, each layer's
# pre-activation gradient (Z1.grad, Z2.grad) and input (A0, A1) available.
model.zero_grad()
loss = loss_func(model(X), y)
loss.backward()
# Order matters: the i-th layer is paired with the i-th gradient and input.
layers = [model.conv2d, model.lin]
prev_grads_full = prev_grads.prev_grads.PrevGrads(
    len(X), layers,
    [model.Z1.grad, model.Z2.grad],
    [model.A0, model.A1],
)

In [None]:
# Checks the multi-layer PrevGrads wrapper against autograd for every layer in
# `layers`. Errors are the maximum *absolute* difference across all layers: a
# signed max would let large negative discrepancies pass the threshold.
threshold = 1e-04
n_samples = len(X)
batch_size = 74

# test 1: batch_gradient over all samples matches the full-data gradients
# computed when prev_grads_full was initialized.
grads = prev_grads_full.batch_gradient(torch.arange(n_samples))
max_error_dW = max(torch.max(torch.abs(grad[0] - layer.weight.grad))
                   for grad, layer in zip(grads, layers))
max_error_db = max(torch.max(torch.abs(grad[1] - layer.bias.grad))
                   for grad, layer in zip(grads, layers))
if max(max_error_dW, max_error_db) < threshold:
    print('test 1: Success.')
else:
    print('test 1: Failed.')

# test 2: after updating with a fresh mini-batch, batch_gradient on that
# mini-batch matches autograd's mini-batch gradients for every layer.
model.zero_grad()
indices = np.random.choice(n_samples, batch_size, replace=False)
loss = loss_func(model(X[indices]), y[indices])
loss.backward()
prev_grads_full.update([model.Z1.grad, model.Z2.grad], [model.A0, model.A1], indices)

grads = prev_grads_full.batch_gradient(indices)
max_error_dW = max(torch.max(torch.abs(grad[0] - layer.weight.grad))
                   for grad, layer in zip(grads, layers))
max_error_db = max(torch.max(torch.abs(grad[1] - layer.bias.grad))
                   for grad, layer in zip(grads, layers))
if max(max_error_dW, max_error_db) < threshold:
    print('test 2: Success.')
else:
    print('test 2: Failed.')