In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [19]:
x1 = 3*torch.ones((1, 2, 3, 3))
x2 = 2*torch.ones((1, 1, 3, 3))

In [20]:
class ReversibleConv2d(nn.Module):
    def __init__(self, in_channels, activation):
        super().__init__()
        self.in_channels = in_channels
        self.act = activation
        
        self.f = nn.Conv2d(in_channels, in_channels, 3)
        self.g = nn.Conv2d(in_channels, in_channels, 3)
        
    def forward(self, x, requires_grad=False):
        if requires_grad:
            return self._forward(x)
        else:
            with torch.no_grad():
                return self._forward(x)
            
    def _forward(self, x):
        n, c, h, w = x.shape
        padding = torch.zeros((n, 2*self.in_channels - c, h, w))
        x = torch.cat([x, padding], axis=1)

        x1, x2 = x[:, :self.in_channels], x[:, self.in_channels:]

        y2 = x2 + self.act(self.f(x1))
        y1 = x1 + self.act(self.g(y2))

        y = torch.cat([y1, y2], axis=1)
        return y
    
    def reverse(self, y):
        with torch.no_grad():
            n, c, h, w = y.shape
            y1, y2 = y[:, :self.in_channels], y[:, self.in_channels:]

            x1 = y1 - self.act(self.g(y2))
            x2 = y2 - self.act(self.f(x1))

            return torch.cat([x1, x2], axis=1)
    
rev = ReversibleConv2d(2, F.relu)

with torch.no_grad():
    y = rev(torch.cat([x1, x2], axis=1))
    x = rev.reverse(y)
    
x

tensor([[[[3., 3., 3.],
          [3., 3., 3.],
          [3., 3., 3.]],

         [[3., 3., 3.],
          [3., 3., 3.],
          [3., 3., 3.]],

         [[2., 2., 2.],
          [2., 2., 2.],
          [2., 2., 2.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]])

In [21]:
y

tensor([[[[3.0535, 3.0535, 3.0535],
          [3.0535, 3.0535, 3.0535],
          [3.0535, 3.0535, 3.0535]],

         [[3.0000, 3.0000, 3.0000],
          [3.0000, 3.0000, 3.0000],
          [3.0000, 3.0000, 3.0000]],

         [[2.4886, 2.4886, 2.4886],
          [2.4886, 2.4886, 2.4886],
          [2.4886, 2.4886, 2.4886]],

         [[0.0408, 0.0408, 0.0408],
          [0.0408, 0.0408, 0.0408],
          [0.0408, 0.0408, 0.0408]]]])

In [22]:
y.shape

torch.Size([1, 4, 3, 3])

In [72]:
nn.View?

Object `nn.View` not found.


In [74]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.features = nn.Sequential(
            ReversibleConv2d(2, F.relu),
            ReversibleConv2d(2, F.relu),
            ReversibleConv2d(2, F.relu),
        )
        self.classifier = nn.Sequential(
            nn.Linear(36, 10),
            nn.Softmax(dim=1),
        )
        #self.rconv1 = ReversibleConv2d(2, F.relu)
        #self.rconv2 = ReversibleConv2d(2, F.relu)
        #self.rconv3 = ReversibleConv2d(2, F.relu)
        
        #self.fc = nn.Linear(36, 10)
        
    def forward(self, x):
        #x = self.rconv1(x)
        #x = self.rconv2(x)
        #x = self.rconv3(x)
        
        act = self.features(x)
        x = act.view(act.shape[0], -1)
        
        return act, self.classifier(x)
        
    
model = Model()
y, logits = model(torch.cat([x1, x2], axis=1))
    
y

tensor([[[[3.4771, 3.4771, 3.4771],
          [3.4771, 3.4771, 3.4771],
          [3.4771, 3.4771, 3.4771]],

         [[4.3360, 4.3360, 4.3360],
          [4.3360, 4.3360, 4.3360],
          [4.3360, 4.3360, 4.3360]],

         [[2.7324, 2.7324, 2.7324],
          [2.7324, 2.7324, 2.7324],
          [2.7324, 2.7324, 2.7324]],

         [[3.9817, 3.9817, 3.9817],
          [3.9817, 3.9817, 3.9817],
          [3.9817, 3.9817, 3.9817]]]])

In [49]:
model = Model()
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())
model(torch.cat([x1, x2], axis=1))

(tensor([[[[5.9277, 5.9277, 5.9277],
           [5.9277, 5.9277, 5.9277],
           [5.9277, 5.9277, 5.9277]],
 
          [[3.5798, 3.5798, 3.5798],
           [3.5798, 3.5798, 3.5798],
           [3.5798, 3.5798, 3.5798]],
 
          [[5.4315, 5.4315, 5.4315],
           [5.4315, 5.4315, 5.4315],
           [5.4315, 5.4315, 5.4315]],
 
          [[0.2333, 0.2333, 0.2333],
           [0.2333, 0.2333, 0.2333],
           [0.2333, 0.2333, 0.2333]]]]),
 tensor([[9.3236e-01, 3.2724e-02, 2.4073e-03, 6.9962e-03, 1.5547e-03, 2.7069e-03,
          3.6154e-03, 1.6311e-02, 1.0564e-03, 2.6527e-04]],
        grad_fn=<SoftmaxBackward>))

In [50]:
for e in range(50):
    model.zero_grad()
    y, logits = model(torch.cat([x1, x2], axis=1))
    #loss = criterion(logits, torch.LongTensor([0]))
    #loss.backward()
    #optimizer.step()
    
    #print(model.fc.weight.grad.shape)
    x = model.rconv3.reverse(y)
    y_ = model.rconv3(x, requires_grad=True)
    #loss = criterion(logits, torch.LongTensor([0]))
    logits = F.softmax(model.fc(y_.view(x.shape[0], -1)), dim=1)
    #print(logits)
    loss = criterion(logits, torch.LongTensor([0]))
    loss.backward(retain_graph=True)
    #print(model.rconv3.f.data)
    
    for param in model.rconv3.f.parameters():
        print(param)
        print(param.grad)
    #print(model.rconv3.f.grad)
    #print(loss.grad)
    #print(y_.shape)
    #print(model.rconv3.g.weight.shape)
    #print(out_with_grad)
    #x = model.rconv2.reverse(x)
    #out_with_grad = model.rconv2(x, requires_grad=True)
    break
    #optimizer.step()

#model(torch.cat([x1, x2], axis=1))

AttributeError: 'Model' object has no attribute 'rconv3'

In [40]:
for param in model.rconv2.f.parameters():
    print(param)
    print(param.grad)

Parameter containing:
tensor([[[[ 0.0019, -0.0131,  0.0202],
          [ 0.1865, -0.0143,  0.0117],
          [ 0.0336, -0.2099, -0.1783]],

         [[ 0.1966,  0.1406,  0.0633],
          [ 0.1767, -0.0057,  0.0352],
          [ 0.0332, -0.1323, -0.0047]]],


        [[[ 0.0219,  0.0593,  0.0134],
          [-0.2146,  0.1347,  0.2296],
          [-0.0963,  0.0125, -0.1721]],

         [[-0.0638, -0.1680,  0.0737],
          [-0.0928, -0.1373,  0.0640],
          [-0.0246,  0.0871,  0.0221]]]], requires_grad=True)
None
Parameter containing:
tensor([-0.0235, -0.0354], requires_grad=True)
None


In [42]:
x = model.rconv3.reverse(y)
ygrad = model.rconv3(x, requires_grad=True)
x

tensor([[[[4.3842, 4.3842, 4.3842],
          [4.3842, 4.3842, 4.3842],
          [4.3842, 4.3842, 4.3842]],

         [[4.8098, 4.8098, 4.8098],
          [4.8098, 4.8098, 4.8098],
          [4.8098, 4.8098, 4.8098]],

         [[2.0000, 2.0000, 2.0000],
          [2.0000, 2.0000, 2.0000],
          [2.0000, 2.0000, 2.0000]],

         [[3.8432, 3.8432, 3.8432],
          [3.8432, 3.8432, 3.8432],
          [3.8432, 3.8432, 3.8432]]]])

In [43]:
ygrad

tensor([[[[4.8032, 4.8032, 4.8032],
          [4.8032, 4.8032, 4.8032],
          [4.8032, 4.8032, 4.8032]],

         [[8.0591, 8.0591, 8.0591],
          [8.0591, 8.0591, 8.0591],
          [8.0591, 8.0591, 8.0591]],

         [[4.4444, 4.4444, 4.4444],
          [4.4444, 4.4444, 4.4444],
          [4.4444, 4.4444, 4.4444]],

         [[4.5945, 4.5945, 4.5945],
          [4.5945, 4.5945, 4.5945],
          [4.5945, 4.5945, 4.5945]]]], grad_fn=<CatBackward>)

In [34]:
x = model.rconv2.reverse(x)
x

tensor([[[[3.0000, 3.0000, 3.0000],
          [3.0000, 3.0000, 3.0000],
          [3.0000, 3.0000, 3.0000]],

         [[3.0000, 3.0000, 3.0000],
          [3.0000, 3.0000, 3.0000],
          [3.0000, 3.0000, 3.0000]],

         [[2.0000, 2.0000, 2.0000],
          [2.0000, 2.0000, 2.0000],
          [2.0000, 2.0000, 2.0000]],

         [[0.7503, 0.7503, 0.7503],
          [0.7503, 0.7503, 0.7503],
          [0.7503, 0.7503, 0.7503]]]], grad_fn=<CatBackward>)

In [35]:
x = model.rconv1.reverse(x)
x

tensor([[[[ 3.0000e+00,  3.0000e+00,  3.0000e+00],
          [ 3.0000e+00,  3.0000e+00,  3.0000e+00],
          [ 3.0000e+00,  3.0000e+00,  3.0000e+00]],

         [[ 3.0000e+00,  3.0000e+00,  3.0000e+00],
          [ 3.0000e+00,  3.0000e+00,  3.0000e+00],
          [ 3.0000e+00,  3.0000e+00,  3.0000e+00]],

         [[ 2.0000e+00,  2.0000e+00,  2.0000e+00],
          [ 2.0000e+00,  2.0000e+00,  2.0000e+00],
          [ 2.0000e+00,  2.0000e+00,  2.0000e+00]],

         [[-5.9605e-08, -5.9605e-08, -5.9605e-08],
          [-5.9605e-08, -5.9605e-08, -5.9605e-08],
          [-5.9605e-08, -5.9605e-08, -5.9605e-08]]]], grad_fn=<CatBackward>)

In [12]:
y.shape

torch.Size([1, 6, 3, 3])

In [17]:
criterion = nn.NLLLoss()

f = nn.Conv2d(3, 3, (1, 1))
g = nn.Linear(27, 10)

x1 = torch.randn(1, 3, 3, 3, requires_grad=True)
x2 = f(x1)
x3 = x2.view(1, 27)
x4 = g(x3)
x5 = F.softmax(x4, dim=1)

loss = criterion(x5, torch.LongTensor([0]))
loss.backward(retain_graph=True)

lr = 0.001

for param in f.parameters():
    print(param)
    param = param - lr * param.grad
    print(param)
    print()
    
for param in g.parameters():
    param = param - lr * param.grad

Parameter containing:
tensor([[[[-0.1442]],

         [[ 0.4904]],

         [[-0.1310]]],


        [[[-0.0747]],

         [[-0.1755]],

         [[ 0.1917]]],


        [[[ 0.4958]],

         [[-0.0673]],

         [[ 0.1109]]]], requires_grad=True)
tensor([[[[-0.1442]],

         [[ 0.4903]],

         [[-0.1309]]],


        [[[-0.0747]],

         [[-0.1755]],

         [[ 0.1917]]],


        [[[ 0.4958]],

         [[-0.0673]],

         [[ 0.1109]]]], grad_fn=<SubBackward0>)

Parameter containing:
tensor([-0.1685, -0.5070,  0.4180], requires_grad=True)
tensor([-0.1684, -0.5070,  0.4180], grad_fn=<SubBackward0>)



In [15]:
for param in g.parameters():
    print('param', param)
    print('grad', param.grad)

param Parameter containing:
tensor([[-0.1507,  0.0685,  0.1473, -0.1457, -0.0130, -0.0834,  0.0090, -0.1046,
          0.0056,  0.1236,  0.0417,  0.0204, -0.0993,  0.0575,  0.0794,  0.1457,
         -0.1321, -0.1521,  0.1724,  0.0097,  0.1883, -0.0710, -0.1310,  0.0480,
          0.0698,  0.0086,  0.0032],
        [ 0.1253, -0.1064,  0.1887,  0.1920, -0.0440,  0.1443,  0.0686, -0.0352,
         -0.0421,  0.1082,  0.1216, -0.0529, -0.0940, -0.0402, -0.0366, -0.0019,
          0.1129, -0.1870,  0.0554,  0.1879, -0.0155,  0.0100, -0.0096, -0.0840,
         -0.1593, -0.1181, -0.1414],
        [-0.0529,  0.1347,  0.1749, -0.1731, -0.1388,  0.1180, -0.1436, -0.1313,
         -0.0232,  0.0972,  0.1711,  0.0590, -0.1325, -0.1796, -0.0495,  0.1057,
          0.1130,  0.1734, -0.1355,  0.0049, -0.1907, -0.1782,  0.0547, -0.1220,
         -0.1751, -0.1861, -0.0442],
        [-0.0795, -0.0977,  0.1117, -0.0756,  0.0644,  0.0334, -0.1375,  0.0481,
          0.0871, -0.0133, -0.1403, -0.1398, -0.189

In [57]:
for layer in model.features:
    for fn in [layer.f, layer.g]:
        for param in fn.parameters():
            print(param)
            print(param.grad)


Parameter containing:
tensor([[[[ 0.1631, -0.1311,  0.1842],
          [ 0.0904,  0.1047,  0.1307],
          [ 0.1567,  0.1591,  0.1095]],

         [[ 0.2218, -0.0840,  0.1019],
          [-0.1328,  0.0155,  0.0971],
          [-0.0611,  0.0068, -0.1498]]],


        [[[-0.0038, -0.2094,  0.1477],
          [ 0.1900,  0.0107, -0.1863],
          [-0.0667, -0.0774, -0.0917]],

         [[-0.0670, -0.1002, -0.1774],
          [-0.1627, -0.0823,  0.1866],
          [-0.0135, -0.1808,  0.0979]]]], requires_grad=True)
None
Parameter containing:
tensor([-0.2074,  0.1376], requires_grad=True)
None
Parameter containing:
tensor([[[[-0.1708,  0.1849, -0.0056],
          [-0.0337,  0.1618,  0.0254],
          [-0.0699, -0.1940, -0.0088]],

         [[ 0.1161,  0.0454, -0.0059],
          [-0.0762, -0.2145,  0.0576],
          [-0.1632,  0.0411, -0.1934]]],


        [[[-0.1760, -0.1316, -0.1121],
          [-0.1558,  0.0052,  0.0089],
          [-0.1978,  0.0790,  0.2126]],

         [[ 0.2056,

In [47]:
model

Model(
  (rconv1): ReversibleConv2d(
    (f): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1))
    (g): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1))
  )
  (rconv2): ReversibleConv2d(
    (f): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1))
    (g): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1))
  )
  (rconv3): ReversibleConv2d(
    (f): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1))
    (g): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1))
  )
  (fc): Linear(in_features=36, out_features=10, bias=True)
)

In [130]:
import numpy as np

In [149]:
torch.manual_seed(42)
x = torch.randn(2, 3, 3, 3)
y = torch.LongTensor([5, 7])

print(x)
model = Model()
criterion = nn.NLLLoss()
lr = 0.01
n = len(model.features)

epochs = 100
for epoch in range(epochs):
    act, logits = model(x)
    loss = criterion(logits, y)
    loss.backward()

    for param in model.classifier.parameters():
        param.data -= lr*param.grad
        
    for i, layer in enumerate(model.features[::-1]):
        model.zero_grad()
        act = layer.reverse(act)

        for j in range(n-i-1, n):
            if j == n-i-1:
                out = model.features[j].requires_grad_(True)(act, True)
            else:
                out = model.features[j].requires_grad_(False)(out, True)

        out = out.view(out.shape[0], -1)
        logits = model.classifier(out)

        loss = criterion(logits, y)
        loss.backward()

        for fn in [layer.f, layer.g]:
            for param in fn.parameters():
                param.data -= lr*param.grad


tensor([[[[ 1.9269,  1.4873,  0.9007],
          [-2.1055,  0.6784, -1.2345],
          [-0.0431, -1.6047, -0.7521]],

         [[ 1.6487, -0.3925, -1.4036],
          [-0.7279, -0.5594, -0.7688],
          [ 0.7624,  1.6423, -0.1596]],

         [[-0.4974,  0.4396, -0.7581],
          [ 1.0783,  0.8008,  1.6806],
          [ 1.2791,  1.2964,  0.6105]]],


        [[[ 1.3347, -0.2316,  0.0418],
          [-0.2516,  0.8599, -1.3847],
          [-0.8712, -0.2234,  1.7174]],

         [[ 0.3189, -0.4245, -0.8140],
          [-0.7360, -0.8371, -0.9224],
          [ 1.8113,  0.1606,  0.3672]],

         [[ 0.1754, -1.1845,  1.3835],
          [-1.2024,  0.7078, -1.0759],
          [ 0.5357,  1.1754,  0.5612]]]])


In [150]:
act

tensor([[[[ 1.9269,  1.4873,  0.9007],
          [-2.1055,  0.6784, -1.2345],
          [-0.0431, -1.6047, -0.7521]],

         [[ 1.6487, -0.3925, -1.4036],
          [-0.7279, -0.5594, -0.7688],
          [ 0.7624,  1.6423, -0.1596]],

         [[-0.4974,  0.4396, -0.7581],
          [ 1.0783,  0.8008,  1.6806],
          [ 1.2791,  1.2964,  0.6105]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],


        [[[ 1.3347, -0.2316,  0.0418],
          [-0.2516,  0.8599, -1.3847],
          [-0.8712, -0.2234,  1.7174]],

         [[ 0.3189, -0.4245, -0.8140],
          [-0.7360, -0.8371, -0.9224],
          [ 1.8113,  0.1606,  0.3672]],

         [[ 0.1754, -1.1845,  1.3835],
          [-1.2024,  0.7078, -1.0759],
          [ 0.5357,  1.1754,  0.5612]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]]])

In [151]:
model(x)

(tensor([[[[ 4.0794,  3.6398,  3.0532],
           [ 0.0470,  2.8309,  0.9179],
           [ 2.1094,  0.5478,  1.4004]],
 
          [[ 2.4458,  0.4046, -0.6065],
           [ 0.0692,  0.2376,  0.0282],
           [ 1.5595,  2.4394,  0.6375]],
 
          [[ 2.8942,  3.8312,  2.6335],
           [ 4.4699,  4.1924,  5.0722],
           [ 4.6707,  4.6880,  4.0021]],
 
          [[ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000],
           [ 0.0000,  0.0000,  0.0000]]],
 
 
         [[[ 3.7319,  2.1655,  2.4389],
           [ 2.1456,  3.2570,  1.0125],
           [ 1.5259,  2.1738,  4.1145]],
 
          [[ 1.6706,  0.9272,  0.5377],
           [ 0.6157,  0.5146,  0.4293],
           [ 3.1630,  1.5123,  1.7190]],
 
          [[ 1.4297,  0.0698,  2.6378],
           [ 0.0519,  1.9621,  0.1784],
           [ 1.7900,  2.4297,  1.8155]],
 
          [[ 0.4925,  0.4925,  0.4925],
           [ 0.4925,  0.4925,  0.4925],
           [ 0.4925,  0.4925,  0.4925]]]]),
 tensor([[3.