# Example of Graph Neural Network

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
import yaml

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
import contextlib

sys.path.append("..")
device = "cuda" if torch.cuda.is_available() else "cpu"

from fairscale.experimental.nn.offload import OffloadModel

# GNN

### Roadmap

1. Make simple GNN (GCN)
2. Tweak to offloaded (checkpointing as HP)
3. Measure memory usage

## GCN

In [3]:
import torch
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

from fairscale.experimental.nn.offload import OffloadModel

device = torch.device("cuda")

In [4]:
num_graphs = 10
num_nodes = 100000
num_input = 3
num_edges = 500000
num_outputs = 1

num_hidden = 64
num_layers = 6

batch_size = 1

nodes = torch.rand(num_graphs, num_nodes, num_input)
edges = torch.randint(num_nodes, (num_graphs, 2, num_edges))
truth = torch.round(torch.rand(num_graphs, num_edges))
graph_data = torch.utils.data.TensorDataset(nodes, edges, truth)
dataloader = DataLoader(graph_data, batch_size=batch_size)

In [5]:
def _get_fp16_context(use_fp16=False):
    if use_fp16:
        return torch.cuda.amp.autocast()
    else:
        return contextlib.nullcontext()

In [12]:
from torch_scatter import scatter_add
from LightningModules.GNN.utils import make_mlp


class GCN(torch.nn.Module):
    def __init__(
        self,
        num_input,
        num_hidden,
        num_layers,
        n_iters,
        num_slices=3,
        offload=False,
        checkpoint=False,
        fp16=False,
    ):
        super().__init__()
        """
        Initialise the Lightning Module that can scan over different GNN training regimes
        """

        self.node_network = make_mlp(
            (num_hidden) * 2, [num_hidden] * num_layers, layer_norm=False
        )

        self.input_network = make_mlp(
            num_input, [num_hidden] * num_layers, layer_norm=False
        )

        self.edge_network = make_mlp(
            (num_hidden) * 2,
            [num_hidden] * num_layers + [1],
            output_activation=None,
            layer_norm=False,
        )

        self.output_network = make_mlp(
            (num_hidden) * 2,
            [num_hidden] * num_layers + [1],
            output_activation=None,
            layer_norm=False,
        )

        self.n_iters = n_iters

        if offload:

            self.node_network = OffloadModel(
                model=self.node_network,
                device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                num_slices=num_slices,
                checkpoint_activation=True,
                num_microbatches=1,
            )

            self.input_network = OffloadModel(
                model=self.input_network,
                device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                num_slices=num_slices,
                checkpoint_activation=checkpoint,
                num_microbatches=1,
            )

            self.edge_network = OffloadModel(
                model=self.edge_network,
                device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                num_slices=num_slices,
                checkpoint_activation=False,
                num_microbatches=1,
            )

            self.output_network = OffloadModel(
                model=self.output_network,
                device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                num_slices=num_slices,
                checkpoint_activation=checkpoint,
                num_microbatches=1,
            )

    def forward(self, x, edge_index):
        start, end = edge_index

        input_x = self.input_network(x)

        # Loop over iterations of edge and node networks
        for i in range(self.n_iters):

            edge_inputs = torch.cat([input_x[start], input_x[end]], dim=-1)
            e = self.edge_network(edge_inputs)
            e = torch.sigmoid(e)

            #             print(e.shape, input_x.shape, start.shape, input_x[start].shape)

            # Apply node network
            messages = scatter_add(
                e * input_x[start], end, dim=0, dim_size=input_x.shape[0]
            ) + scatter_add(e * input_x[end], start, dim=0, dim_size=input_x.shape[0])

            node_inputs = torch.cat([messages, input_x], dim=1)
            input_x = self.node_network(node_inputs)

        edge_inputs = torch.cat([input_x[start], input_x[end]], dim=-1)
        output = self.output_network(edge_inputs)

        return output

In [7]:
from torch_scatter import scatter_add
from LightningModules.GNN.utils import make_mlp
from torch.utils.checkpoint import checkpoint


class CheckGCN(torch.nn.Module):
    def __init__(
        self,
        num_input,
        num_hidden,
        num_layers,
        n_iters,
        offload=False,
        checkpoint=False,
        fp16=False,
    ):
        super().__init__()
        """
        Initialise the Lightning Module that can scan over different GNN training regimes
        """

        self.node_network = make_mlp(
            (num_hidden) * 2, [num_hidden] * num_layers, layer_norm=False
        )

        self.input_network = make_mlp(
            num_input, [num_hidden] * num_layers, layer_norm=False
        )

        self.edge_network = make_mlp(
            (num_hidden) * 2,
            [num_hidden] * num_layers + [1],
            output_activation=None,
            layer_norm=False,
        )

        self.output_network = make_mlp(
            (num_hidden) * 2,
            [num_hidden] * num_layers + [1],
            output_activation=None,
            layer_norm=False,
        )

        self.n_iters = n_iters

        if offload:

            self.node_network = OffloadModel(
                model=self.node_network,
                device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                num_slices=3,
                checkpoint_activation=checkpoint,
                num_microbatches=1,
            )

            self.input_network = OffloadModel(
                model=self.input_network,
                device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                num_slices=3,
                checkpoint_activation=checkpoint,
                num_microbatches=1,
            )

            self.edge_network = OffloadModel(
                model=self.edge_network,
                device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                num_slices=3,
                checkpoint_activation=checkpoint,
                num_microbatches=1,
            )

            self.output_network = OffloadModel(
                model=self.output_network,
                device=torch.device("cuda"),
                offload_device=torch.device("cpu"),
                num_slices=3,
                checkpoint_activation=checkpoint,
                num_microbatches=1,
            )

    def forward(self, x, edge_index):
        start, end = edge_index

        input_x = self.input_network(x)

        # Loop over iterations of edge and node networks
        for i in range(self.n_iters):

            edge_inputs = torch.cat([input_x[start], input_x[end]], dim=-1)
            e = checkpoint(self.edge_network, edge_inputs)
            e = torch.sigmoid(e)

            print(e.shape, input_x.shape, start.shape, input_x[start].shape)

            # Apply node network
            messages = scatter_add(
                e * input_x[start], end, dim=0, dim_size=input_x.shape[0]
            ) + scatter_add(e * input_x[end], start, dim=0, dim_size=input_x.shape[0])

            node_inputs = torch.cat([messages, input_x], dim=1)
            input_x = checkpoint(self.node_network, node_inputs)

        edge_inputs = torch.cat([input_x[start], input_x[end]], dim=-1)
        output = self.output_network(edge_inputs)

        return output

## Un-offloaded Memory Usage

In [8]:
model = GCN(num_input, num_hidden, num_layers, 8).to(device)

torch.cuda.set_device(0)
device = torch.device("cuda")

In [9]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [10]:
torch.cuda.reset_max_memory_allocated()



In [15]:
%%time
for batch in dataloader:
    #     print(batch)
    optimizer.zero_grad()

    x, edges, y = (
        batch[0].squeeze().to(device),
        batch[1].squeeze().to(device),
        batch[2].squeeze().to(device),
    )

    with _get_fp16_context(use_fp16=False):
        output = model(x, edges)
        print(output.shape, y.shape)
        loss = criterion(output.float().squeeze(), target=y.float())
        loss.backward()
    optimizer.step()

    break

torch.Size([500000, 1]) torch.Size([500000])
CPU times: user 61.6 ms, sys: 18.2 ms, total: 79.8 ms
Wall time: 69.7 ms


Un-Checkpointed

In [16]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

9.519188404083252 Gb


Checkpointed

In [11]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

5.1315741539001465 Gb


Checkpointed & FP

In [12]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

2.5914177894592285 Gb


## Offloaded Memory Usage

In [13]:
model = GCN(
    num_input, num_hidden, num_layers, 8, num_slices=3, offload=True, checkpoint=False
)

torch.cuda.set_device(0)
device = torch.device("cuda")

In [14]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [15]:
torch.cuda.reset_max_memory_allocated()

In [16]:
model.train()
for batch in dataloader:
    #     print(batch)
    optimizer.zero_grad()

    x, edges, y = (
        batch[0].squeeze().to(device),
        batch[1].squeeze().to(device),
        batch[2].squeeze().to(device),
    )

    with _get_fp16_context(use_fp16=False):
        output = model(x, edges)
        print(output.shape, y.shape)
        loss = criterion(output.float().squeeze(), target=y.float())
        loss.backward()
    optimizer.step()

    break

RuntimeError: CUDA out of memory. Tried to allocate 124.00 MiB (GPU 0; 15.78 GiB total capacity; 13.97 GiB already allocated; 109.75 MiB free; 14.32 GiB reserved in total by PyTorch)

Un-checkpointed, 3 slice

In [None]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

Un-checkpointed, 6 slice

In [11]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

4.9517316818237305 Gb


Un-checkpointed & FP

In [11]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

1.2558321952819824 Gb


Checkpointed

In [11]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

5.300224304199219 Gb


## FP16

In [7]:
model = GCN(num_input, num_hidden, num_layers, 8, offload=True, checkpoint=True)
scaler = torch.cuda.amp.GradScaler()

torch.cuda.set_device(0)
device = torch.device("cuda")

In [8]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

In [9]:
torch.cuda.reset_max_memory_allocated()



In [10]:
model.train()
for batch in dataloader:
    #     print(batch)
    optimizer.zero_grad()

    x, edges, y = (
        batch[0].squeeze().to(device),
        batch[1].squeeze().to(device),
        batch[2].squeeze().to(device),
    )

    print(x.shape, edges.shape, y.shape)

    with _get_fp16_context(use_fp16=True):
        output = model(x, edges)
        print(output.shape, y.shape)
        loss = criterion(output.float().squeeze(), target=y.float())

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    print(output)
    break

torch.Size([100000, 3]) torch.Size([2, 1000000]) torch.Size([1000000])
torch.Size([1000000, 1])
torch.Size([1000000, 1]) torch.Size([1000000])


NotImplementedError: Could not run 'aten::_amp_foreach_non_finite_check_and_unscale_' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_amp_foreach_non_finite_check_and_unscale_' is only available for these backends: [CUDA, BackendSelect, Named, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

CUDA: registered at aten/src/ATen/RegisterCUDA.cpp:20674 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:60 [backend fallback]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradMLC: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_0.cpp:9848 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_0.cpp:9750 [kernel]
Autocast: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:255 [backend fallback]
Batched: registered at ../aten/src/ATen/BatchingRegistrations.cpp:1019 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]


In [None]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

## Attention Mechanism

In [3]:
from LightningModules.GNN.Models.agnn import ResAGNN
from LightningModules.GNN.Models.vanilla_agnn import VanillaResAGNN
from LightningModules.GNN.utils import make_mlp

In [4]:
with open("example_gnn.yaml") as f:
    hparams = yaml.load(f, Loader=yaml.FullLoader)

In [5]:
model = VanillaResAGNN(hparams)

## Explore

In [6]:
from torch_scatter import scatter_add


class VanillaResAGNN(torch.nn.Module):
    def __init__(self, hparams):
        super().__init__()
        """
        Initialise the Lightning Module that can scan over different GNN training regimes
        """

        self.hparams = hparams
        self.edge_network = make_mlp(
            (hparams["in_channels"] + hparams["hidden"]) * 2,
            [hparams["in_channels"] + hparams["hidden"]] * hparams["nb_edge_layer"]
            + [1],
            hidden_activation=hparams["hidden_activation"],
            output_activation=None,
            layer_norm=hparams["layernorm"],
        )

        self.node_network = make_mlp(
            (hparams["in_channels"] + hparams["hidden"]) * 2,
            [hparams["hidden"]] * hparams["nb_node_layer"],
            hidden_activation=hparams["hidden_activation"],
            output_activation=None,
            layer_norm=hparams["layernorm"],
        )

        self.input_network = make_mlp(
            hparams["in_channels"],
            [hparams["hidden"]] * hparams["nb_node_layer"],
            output_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
        )

        self.edge_offload_model = OffloadModel(
            model=self.edge_network,
            device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            num_slices=10,
            checkpoint_activation=False,
            num_microbatches=1,
        )

        self.node_offload_model = OffloadModel(
            model=self.node_network,
            device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            num_slices=10,
            checkpoint_activation=False,
            num_microbatches=1,
        )

        self.input_offload_model = OffloadModel(
            model=self.input_network,
            device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            num_slices=10,
            checkpoint_activation=False,
            num_microbatches=1,
        )

    def forward(self, x, edge_index):
        start, end = edge_index
        input_x = x

        x = self.input_network(x)

        # Shortcut connect the inputs onto the hidden representation
        x = torch.cat([x, input_x], dim=-1)

        # Loop over iterations of edge and node networks
        for i in range(self.hparams["n_graph_iters"]):
            x_inital = x

            # Apply edge network
            edge_inputs = torch.cat([x[start], x[end]], dim=1)
            e = torch.sigmoid(self.edge_network(edge_inputs))

            # Apply node network
            messages = scatter_add(
                e * x[start], end, dim=0, dim_size=x.shape[0]
            ) + scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])
            node_inputs = torch.cat([messages, x], dim=1)
            x = self.node_network(node_inputs)

            # Shortcut connect the inputs onto the hidden representation
            x = torch.cat([x, input_x], dim=-1)

            # Residual connection
            x = x_inital + x

        edge_inputs = torch.cat([x[start], x[end]], dim=1)
        return self.edge_network(edge_inputs)

In [7]:
vanilla_model = VanillaResAGNN(hparams)
vanilla_model = vanilla_model.to(device)

In [8]:
all_offload_params = (
    list(vanilla_model.input_offload_model.parameters())
    + list(vanilla_model.node_offload_model.parameters())
    + list(vanilla_model.edge_offload_model.parameters())
)
optimizer = torch.optim.AdamW(all_offload_params, lr=0.001)

In [12]:
optimizer.zero_grad()

torch.cuda.reset_peak_memory_stats()
output = vanilla_model(sample.x.to(device), sample.edge_index.to(device))
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    output, torch.ones_like(output)
)
loss.backward()
optimizer.step()

In [13]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

4.194556713104248 Gb


### Memory Test

In [10]:
%%time
model.setup(stage="fit")

CPU times: user 994 ms, sys: 28.9 ms, total: 1.02 s
Wall time: 129 ms


In [11]:
sample = model.trainset[0].to(device)

In [8]:
model = model.to(device)

In [9]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
optimizer.zero_grad()

torch.cuda.reset_peak_memory_stats()
output = model(sample.x, sample.edge_index)
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    output, torch.ones_like(output)
)
loss.backward()
optimizer.step()

In [10]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

4.194556713104248 Gb


### Train GNN

In [None]:
logger = WandbLogger(project="ITk_1GeV_GNN", group="InitialTest")
trainer = Trainer(gpus=1, max_epochs=50, logger=logger)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Set SLURM handle signals.

  | Name                   | Type       | Params
------------------------------------------------------
0 | node_encoder           | Sequential | 7.1 K 
1 | edge_encoder           | Sequential | 19.3 K
2 | edge_network           | Sequential | 31.9 K
3 | node_network           | Sequential | 31.9 K
4 | output_edge_classifier | Sequential | 25.6 K
------------------------------------------------------
115 K     Trainable params
0         Non-trainable params
115 K     Total params
0.463     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  eff = torch.tensor(edge_true_positive / edge_true)
  pur = torch.tensor(edge_true_positive / edge_positive)
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Toy Test

In [None]:
import torch
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

from fairscale.experimental.nn.offload import OffloadModel

device = torch.device("cuda")

In [None]:
%%time
num_inputs = 100
num_outputs = 10
num_hidden = 1000
num_layers = 100
batch_size = 1000

transform = ToTensor()
dataloader = DataLoader(
    FakeData(
        image_size=(1, num_inputs, num_inputs),
        num_classes=num_outputs,
        transform=transform,
        size=batch_size,
    ),
    batch_size=batch_size,
)

model = torch.nn.Sequential(
    torch.nn.Linear(num_inputs * num_inputs, num_hidden),
    *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
    torch.nn.Linear(num_hidden, num_outputs),
)

## Before Offload

In [None]:
%%time
torch.cuda.reset_peak_memory_stats()
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

for batch_inputs, batch_outputs in dataloader:
    batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
    optimizer.zero_grad()

    inputs = batch_inputs.reshape(-1, num_inputs * num_inputs)
    output = model(inputs)
    loss = criterion(output, target=batch_outputs)
    loss.backward()
    optimizer.step()
    break

In [7]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

0.7831840515136719 Gb


In [None]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

In [7]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

4.269891262054443 Gb


## After Offload

In [5]:
offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=False,
    num_microbatches=1,
)

device = torch.device("cuda")
torch.cuda.set_device(0)

In [6]:
torch.cuda.reset_peak_memory_stats()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

# train(model, offload_model)

In [7]:
%%time
# To train 1 epoch.
offload_model.train()
for batch_inputs, batch_outputs in dataloader:
    batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")

    optimizer.zero_grad()
    inputs = batch_inputs.reshape(-1, num_inputs * num_inputs)

    output = offload_model(inputs)
    loss = criterion(output, target=batch_outputs)
    loss.backward()
    optimizer.step()

    break

CPU times: user 305 ms, sys: 150 ms, total: 454 ms
Wall time: 429 ms


In [8]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

0.9275264739990234 Gb


## From Benchmark Script

In [5]:
model = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=False,
    num_microbatches=1,
)

device = torch.device("cuda")
torch.cuda.set_device(0)

In [6]:
torch.cuda.reset_peak_memory_stats()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# train(model, offload_model)

In [7]:
%%time
# To train 1 epoch.
model.train()
for batch_inputs, batch_outputs in dataloader:
    batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")

    optimizer.zero_grad()

    inputs = batch_inputs.reshape(-1, num_inputs * num_inputs)

    output = model(inputs)
    loss = criterion(output, target=batch_outputs)
    loss.backward()
    optimizer.step()

    break

CPU times: user 155 ms, sys: 167 ms, total: 321 ms
Wall time: 306 ms


In [8]:
print(torch.cuda.max_memory_allocated() / 1024**3, "Gb")

0.5458278656005859 Gb


In [9]:
print(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] / 1024**3, "Gb")

0.5458278656005859 Gb


## Exact Reproduction

In [1]:
import torch
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

from fairscale.experimental.nn.offload import OffloadModel


num_inputs = 8
num_outputs = 8
num_hidden = 4
num_layers = 2
batch_size = 8

transform = ToTensor()
dataloader = DataLoader(
    FakeData(
        image_size=(1, num_inputs, num_inputs),
        num_classes=num_outputs,
        transform=transform,
    ),
    batch_size=batch_size,
)

model = torch.nn.Sequential(
    torch.nn.Linear(num_inputs * num_inputs, num_hidden),
    *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
    torch.nn.Linear(num_hidden, num_outputs),
)

In [2]:
offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=True,
    num_microbatches=1,
)

torch.cuda.set_device(0)
device = torch.device("cuda")

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

# To train 1 epoch.
offload_model.train()
for batch_inputs, batch_outputs in dataloader:
    batch_inputs, batch_outputs = batch_inputs.to("cuda"), batch_outputs.to("cuda")
    optimizer.zero_grad()
    inputs = batch_inputs.reshape(-1, num_inputs * num_inputs)
    with torch.cuda.amp.autocast():
        output = model(inputs)
        loss = criterion(output, target=batch_outputs)
        loss.backward()
    optimizer.step()

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking arugment for argument mat1 in method wrapper_addmm)