In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
class Blk(nn.Module):
    """Fully connected residual block with a linear projection shortcut.

    The main branch runs Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d,
    mapping ``in_feats`` features to ``out_feats``. The shortcut branch is a
    single Linear layer with the same input/output widths so both branches can
    be added. A ReLU is applied to the sum.

    Args:
        in_feats: number of input features (also the hidden width).
        out_feats: number of output features.
    """

    def __init__(self, in_feats, out_feats):
        super().__init__()
        # Layers are listed in execution order; their positions become the
        # indices used in the Sequential container (main.0 ... main.4).
        branch_layers = [
            nn.Linear(in_feats, in_feats),
            nn.BatchNorm1d(in_feats),
            nn.ReLU(inplace=True),
            nn.Linear(in_feats, out_feats),
            nn.BatchNorm1d(out_feats),
        ]
        self.main = nn.Sequential(*branch_layers)
        # Projection that brings the input to out_feats for the residual sum.
        self.skip = nn.Linear(in_feats, out_feats)

    def forward(self, x):
        """Return ``relu(main(x) + skip(x))`` for a 2-D input batch."""
        branch_out = self.main(x)
        shortcut_out = self.skip(x)
        return F.relu(branch_out + shortcut_out)
        
        

# Feature width of the widest layers in the stack.
K = 5000
# Eight residual blocks whose widths alternate K -> K // 2 -> K -> ...,
# built in order and then moved to the GPU.
net = nn.Sequential(
    *(
        Blk(K, K // 2) if idx % 2 == 0 else Blk(K // 2, K)
        for idx in range(8)
    )
).cuda()
# Adam over every parameter of the stack.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

In [3]:
# GPU memory snapshot taken after the model was moved to the GPU and the optimizer was created.
!nvidia-smi

Fri Jan 15 20:34:57 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 2070    Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   47C    P2    48W / 175W |   2182MiB /  7982MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
# All-ones batch of shape (2, K), copied to the GPU.
x = torch.ones((2, K)).cuda()
# Enable gradient tracking in place; left as the trailing expression so the
# notebook echoes the tensor.
x.requires_grad_()

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', requires_grad=True)

In [5]:
# GPU memory snapshot taken after allocating the input tensor x on the GPU.
!nvidia-smi

Fri Jan 15 20:34:57 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 2070    Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   47C    P2    48W / 175W |   2182MiB /  7982MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
# Repeated GPU memory snapshot; no code cell is shown between this and the previous reading.
!nvidia-smi

Fri Jan 15 20:34:58 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 2070    Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   47C    P2    48W / 175W |   2182MiB /  7982MiB |     14%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [7]:
# Forward pass: push a fresh all-ones (2, K) GPU batch through the stack and
# keep the output in z for the backward pass that follows.
u = torch.ones(size=(2, K)).cuda()
z = net(u)

In [8]:
# GPU memory snapshot taken after the forward pass.
!nvidia-smi

Fri Jan 15 20:34:58 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 2070    Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   47C    P2    48W / 175W |   2204MiB /  7982MiB |     14%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [9]:
# Tensor.backward() returns None, so chaining it into the assignment would bind
# `l` to None. Compute the scalar loss first, then backpropagate from it.
l = z.sum()
l.backward()

In [10]:
# GPU memory snapshot taken after the backward pass.
! nvidia-smi

Fri Jan 15 20:34:58 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 2070    Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   47C    P2    48W / 175W |   3468MiB /  7982MiB |     14%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [12]:
# Apply one Adam update using the gradients produced by the backward pass.
# NOTE(review): the nvidia-smi readings around this cell go from 3468MiB to 6158MiB;
# presumably Adam allocates its per-parameter state on the first step — confirm.
optimizer.step()

In [13]:
# GPU memory snapshot taken after the optimizer step.
! nvidia-smi

Fri Jan 15 20:35:16 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 2070    Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   48C    P2    48W / 175W |   6158MiB /  7982MiB |      7%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces