PyTorch integration: Remat #10

Open · wants to merge 5 commits into main
59 changes: 59 additions & 0 deletions poet/architectures/graph_transformation.py
@@ -0,0 +1,59 @@
from poet.power_computation import (
    LinearLayer,
    ReLULayer,
    Conv2dLayer,
    FlattenLayer,
    TanHLayer,
    SigmoidLayer,
    SkipAddLayer,
    DropoutLayer,
    GradientLayer,
    InputLayer,
    CrossEntropyLoss,
    BatchNorm2d,
    MaxPool2d,
    AvgPool2d,
    GlobalAvgPool,
    get_net_costs,
)


# maps substrings of traced node targets to their POET layer equivalents
POET_LAYER_MAP = {
    "fc": LinearLayer,
    "flatten": FlattenLayer,
    "relu": ReLULayer,
    "conv": Conv2dLayer,
    "bn": BatchNorm2d,
    "maxpool": MaxPool2d,
}


# transforms the input model's graph into an output graph with POET layer nodes
def graph_transform(traced):
    for n in traced.graph.nodes:
        target = str(n.target)
        if "<built-in function" in target:
            continue
        for key, poet_layer in POET_LAYER_MAP.items():
            if key in target:
                # insert the POET layer node, reroute users, and drop the original
                with traced.graph.inserting_after(n):
                    new_node = traced.graph.call_function(poet_layer, n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
                traced.graph.erase_node(n)
                break
    traced.recompile()
    return traced
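
A minimal usage sketch of graph_transform on a toy model (TinyNet is hypothetical, used only for illustration; any torch.fx-traceable module works):

import torch.nn as nn
from torch.fx import symbolic_trace

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        x = self.relu(self.conv(x))
        x = x.mean(dim=(2, 3))  # global average pool, traced as a call_method node
        return self.fc(x)

traced = symbolic_trace(TinyNet())
poet_traced = graph_transform(traced)
for n in poet_traced.graph.nodes:
    print(n.op, n.target)

The "conv", "relu", and "fc" call_module nodes are rewritten into Conv2dLayer, ReLULayer, and LinearLayer call_function nodes; the mean call falls through untouched.
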
73 changes: 73 additions & 0 deletions poet/architectures/network_transform_pytorch.py
@@ -0,0 +1,73 @@
from poet.power_computation import (
    LinearLayer,
    ReLULayer,
    Conv2dLayer,
    FlattenLayer,
    TanHLayer,
    SigmoidLayer,
    SkipAddLayer,
    DropoutLayer,
    GradientLayer,
    InputLayer,
    CrossEntropyLoss,
    BatchNorm2d,
    MaxPool2d,
    AvgPool2d,
    GlobalAvgPool,
    get_net_costs,
)
from poet.power_computation_transformer import QueryKeyValueMatrix, QKTMatrix, QKTVMatrix
import torch.nn as nn
import torchvision.models
from torchvision.models.resnet import BasicBlock, Bottleneck


# transforms the input model's network layers into a list of POET layers for PyTorch models
def network_transform(net, layers, batch_size, num_classes, input_shape):
    if isinstance(net, torchvision.models.resnet.ResNet):
        modules = nn.Sequential(*list(net.children()))
    elif isinstance(net, torchvision.models.vgg.VGG):
        modules = list(net.children())
    else:
        modules = net
    for module in modules:
        if isinstance(module, nn.Sequential):
            # recurse into the container, seeded with the current last layer
            sub_layers = network_transform(list(module), [layers[-1]], batch_size, num_classes, input_shape)
            layers.extend(sub_layers[1:])
        if isinstance(module, (BasicBlock, Bottleneck)):
            sub_layers = network_transform(nn.Sequential(*list(module.children())), [layers[-1]], batch_size, num_classes, input_shape)
            layers.extend(sub_layers[1:])
        if isinstance(module, nn.Linear):
            # each Linear is paired with a ReLU activation
            lin_layer = LinearLayer(module.in_features, module.out_features, layers[-1])
            act_layer = ReLULayer(lin_layer)
            layers.extend([lin_layer, act_layer])
        if isinstance(module, nn.ReLU):
            layers.append(ReLULayer(layers[-1]))
        if isinstance(module, nn.Conv2d):
            conv_layer = Conv2dLayer(
                module.in_channels, module.out_channels, module.kernel_size, module.stride[0], module.padding, layers[-1]
            )
            layers.append(conv_layer)
        if isinstance(module, nn.BatchNorm2d):
            layers.append(BatchNorm2d(layers[-1]))
        if isinstance(module, nn.MaxPool2d):
            layers.append(MaxPool2d((module.kernel_size, module.kernel_size), module.stride, layers[-1]))
        if isinstance(module, nn.AvgPool2d):
            layers.append(AvgPool2d(module.kernel_size, module.stride, layers[-1]))
        if isinstance(module, nn.Tanh):
            layers.append(TanHLayer(layers[-1]))
        if isinstance(module, nn.Sigmoid):
            layers.append(SigmoidLayer(layers[-1]))
        if isinstance(module, nn.Flatten):
            layers.append(FlattenLayer(layers[-1]))
        if isinstance(module, nn.Dropout):
            layers.append(DropoutLayer(layers[-1]))
    return layers
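
A minimal usage sketch (the InputLayer(shape) seed is an assumption here; check poet/power_computation.py for the actual constructor):

import torchvision
from poet.power_computation import InputLayer

net = torchvision.models.resnet18()
input_layer = InputLayer((1, 3, 224, 224))  # assumed constructor: InputLayer(shape)
layers = network_transform(net, [input_layer], batch_size=1, num_classes=1000, input_shape=(3, 224, 224))
print([type(layer).__name__ for layer in layers])
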
62 changes: 62 additions & 0 deletions poet/architectures/network_transform_tensorflow.py
@@ -0,0 +1,62 @@
from poet.power_computation import (
    LinearLayer,
    ReLULayer,
    Conv2dLayer,
    FlattenLayer,
    TanHLayer,
    SigmoidLayer,
    SkipAddLayer,
    DropoutLayer,
    GradientLayer,
    InputLayer,
    CrossEntropyLoss,
    BatchNorm2d,
    MaxPool2d,
    AvgPool2d,
    GlobalAvgPool,
    get_net_costs,
)
from poet.power_computation_transformer import QueryKeyValueMatrix, QKTMatrix, QKTVMatrix
import tensorflow as tf


# transforms the input model's network layers into a list of POET layers for TensorFlow models
def network_transform(net, layers, batch_size, num_classes, input_shape):
    for module in net:
        if isinstance(module, tf.keras.layers.Dense):
            # NOTE: Dense does not expose its input width here, so `units` is
            # used for both the in- and out-features of the LinearLayer
            lin_layer = LinearLayer(module.units, module.units, layers[-1])
            act_layer = ReLULayer(lin_layer)
            layers.extend([lin_layer, act_layer])
        # compare the wrapped activation function itself; the layer's
        # auto-generated name (e.g. "activation_1") never equals "relu"
        if isinstance(module, tf.keras.layers.Activation) and module.activation.__name__ == "relu":
            layers.append(ReLULayer(layers[-1]))
        if isinstance(module, tf.keras.layers.Conv2D):
            if module.padding == "valid":
                padding = (0, 0)
            elif module.padding == "same":
                # approximation: exact only for 3x3 kernels with stride 1
                padding = (1, 1)
            # NOTE: the input channel count is not tracked here, so 1 is passed
            conv_layer = Conv2dLayer(1, module.filters, module.kernel_size, module.strides[0], padding, layers[-1])
            layers.append(conv_layer)
        if isinstance(module, tf.keras.layers.BatchNormalization):
            layers.append(BatchNorm2d(layers[-1]))
        if isinstance(module, tf.keras.layers.MaxPool2D):
            layers.append(MaxPool2d(module.pool_size, module.strides[0], layers[-1]))
        if isinstance(module, tf.keras.layers.GlobalAveragePooling2D):
            if module.keepdims:
                layers.append(GlobalAvgPool(layers[-1]))
        if isinstance(module, tf.keras.layers.Activation) and module.activation.__name__ == "tanh":
            layers.append(TanHLayer(layers[-1]))
        if isinstance(module, tf.keras.layers.Activation) and module.activation.__name__ == "sigmoid":
            layers.append(SigmoidLayer(layers[-1]))
        if isinstance(module, tf.keras.layers.Flatten):
            layers.append(FlattenLayer(layers[-1]))
        if isinstance(module, tf.keras.layers.Dropout):
            layers.append(DropoutLayer(layers[-1]))
    return layers
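
A minimal usage sketch for the TensorFlow path (again assuming an InputLayer(shape) seed; the function iterates over the passed-in container, so the model's layer list is passed rather than the model itself):

import tensorflow as tf
from poet.power_computation import InputLayer

net = tf.keras.Sequential([
    tf.keras.layers.Conv2D(8, (3, 3), padding="same"),
    tf.keras.layers.Activation("relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])
input_layer = InputLayer((1, 3, 32, 32))  # assumed constructor: InputLayer(shape)
layers = network_transform(net.layers, [input_layer], batch_size=1, num_classes=10, input_shape=(3, 32, 32))
print([type(layer).__name__ for layer in layers])
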
45 changes: 45 additions & 0 deletions poet/architectures/remat_and_paging.py
@@ -0,0 +1,45 @@
import torch.nn as nn


# returns [parent_module, attribute_name] pairs for every layer in the model
# whose qualified name matches the requested name
def get_all_parent_layers(net, layer_name):
    layers = []
    # iterate over all (name, module) pairs in the model
    for name, module in net.named_modules():
        # check if the current module's qualified name matches
        if name == layer_name:
            parent = net
            # the qualified name encodes the traversal path to the layer,
            # with numeric parts indexing into sub-module containers
            attributes = name.strip().split(".")
            # walk down to the layer's direct parent (all parts but the last)
            for attr in attributes[:-1]:
                if not attr.isnumeric():
                    # retrieve a named sub-module
                    parent = getattr(parent, attr)
                else:
                    # index into a sub-module container such as nn.Sequential
                    parent = parent[int(attr)]
            # record the parent module and the layer's attribute name
            layers.append([parent, attributes[-1]])
    return layers


# implements the rematerialization technique on the given model:
# the passed-in node's layer is saved during the forward pass for later
# recomputation during the backward pass
def memory_saving(model_indexer, node, remat_list):
    # save the node's target and its layer for later recomputation
    remat_list.append([node.target, getattr(model_indexer[0], model_indexer[1])])
    # replace the layer with an Identity so it passes activations through unchanged
    setattr(model_indexer[0], model_indexer[1], nn.Identity())


# restores a rematerialized node's original layer back into the model
def reuse_layer(model_indexer, node, remat_list):
    # iterate over rematerialized nodes to find the matching saved layer
    for target, saved_layer in remat_list:
        if target == node.target:
            # set the layer back to its original state in the model
            setattr(model_indexer[0], model_indexer[1], saved_layer)
            return
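
A minimal sketch of the remat round-trip (assuming an FX-traced model, where call_module node targets are qualified names such as "layer1.0.conv1"):

import torchvision
from torch.fx import symbolic_trace

model = torchvision.models.resnet18()
traced = symbolic_trace(model)
remat_list = []
for node in traced.graph.nodes:
    if node.op == "call_module" and node.target == "layer1.0.conv1":
        for model_indexer in get_all_parent_layers(model, node.target):
            # forward pass: swap the layer for Identity, saving it for recomputation
            memory_saving(model_indexer, node, remat_list)
            # backward pass: restore the saved layer into the model
            reuse_layer(model_indexer, node, remat_list)
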
2 changes: 1 addition & 1 deletion poet/poet_solver.py
@@ -107,7 +107,7 @@ def _initialize_variables(self):

     def _create_correctness_constraints(self):
         # ensure all computations are possible
-        for (u, v) in self.g.edge_list:
+        for u, v in self.g.edge_list:
             for t in range(self.T):
                 self.m += self.R[t][v] <= self.R[t][u] + self.SRam[t][u]
         # ensure all checkpoints are in memory
3 changes: 2 additions & 1 deletion poet/poet_solver_gurobi.py
@@ -13,6 +13,7 @@

# noinspection PyPackageRequirements


# POET ILP defined using Gurobi
class POETSolverGurobi:
    def __init__(
@@ -124,7 +125,7 @@ def _disable_paging(self):

     def _create_correctness_constraints(self):
         # ensure all computations are possible
-        for (u, v) in self.g.edge_list:
+        for u, v in self.g.edge_list:
             for t in range(self.T):
                 self.m.addLConstr(self.R[t, v], GRB.LESS_EQUAL, self.R[t, u] + self.SRam[t, u])
         # ensure all checkpoints are in memory
19 changes: 19 additions & 0 deletions poet/transformation_testing/pytorch/test_resnet_graph_transform.py
@@ -0,0 +1,19 @@
from torch.fx import symbolic_trace
import torchvision

from poet.architectures.graph_transformation import graph_transform

# Transforms ResNet model graphs into graphs of POET layer nodes.
# Model definitions: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py,
# commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917


def print_poet_graph(model):
    # symbolically trace the model, rewrite its graph with POET layer nodes,
    # and print the resulting nodes
    traced = symbolic_trace(model)
    poet_traced = graph_transform(traced)
    for n in poet_traced.graph.nodes:
        print(n.target)
        print(n.name)


if __name__ == "__main__":
    # ResNet18 model transformation
    print_poet_graph(torchvision.models.resnet18(pretrained=True))
    # ResNet50 model transformation
    print_poet_graph(torchvision.models.resnet50(pretrained=True))