# PyLops - Torch operator

### Author: M.Ravasi

In this notebook I will show how to use `TorchOperator` to mix and match PyLops and PyTorch operators into a chain of operations that is friendly to automatic differentiation (AD).

In [4]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

#import warnings
#warnings.filterwarnings('ignore')

import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import torch
import torch.nn as nn

from torch.autograd import gradcheck
from pylops.torchoperator import TorchOperator
from pylops.basicoperators import *
from pylops.signalprocessing import Convolve2D

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


ModuleNotFoundError: No module named 'pylops.torchoperator'

## AD

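Single batch (one input vector, no batch dimension). For $y = A \sin(x)$ the Jacobian is $J = A\,\mathrm{diag}(\cos x)$, so back-propagating a vector $v$ gives the gradient $J^T v$, which we compare with the AD result.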

In [None]:
nx, ny = 10, 6
rng = np.random.default_rng(0)  # fixed seed so the printed gradients are reproducible
x0 = torch.arange(nx, dtype=torch.double, requires_grad=True)

# Forward: y = A sin(x0), with A wrapped as a PyLops operator exposed to torch
A = rng.normal(0., 1., (ny, nx))
Aop = TorchOperator(MatrixMult(A))
y = Aop.apply(torch.sin(x0))

# AD: back-propagate v, which accumulates the vector-Jacobian product into x0.grad
v = torch.ones(ny, dtype=torch.double)
y.backward(v, retain_graph=True)
adgrad = x0.grad.clone()  # clone so later changes to x0.grad cannot alias this copy

# Analytical: J = A diag(cos(x0)), gradient = J^T v (no autograd tracking needed)
with torch.no_grad():
    At = torch.from_numpy(A)
    J = At * torch.cos(x0)  # cos(x0) broadcasts along the last axis of A
    anagrad = J.T @ v

print('Input: ', x0)
print('AD gradient: ', adgrad)
print('Analytical gradient: ', anagrad)
assert torch.allclose(adgrad, anagrad), "AD and analytical gradients disagree"

# Grad check: compare the operator's backward pass against finite differences
gradcheck_inputs = (torch.arange(nx, dtype=torch.double, requires_grad=True),
                    Aop.matvec, Aop.rmatvec, Aop.device, 'cpu')
test = gradcheck(Aop.Top, gradcheck_inputs, eps=1e-6, atol=1e-4)
print(test)

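Multi batch: here we should get the sum of the gradients (the vector-Jacobian product with an all-ones `v`), which we verify against an equivalent bias-free `nn.Linear` layer.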

In [None]:
nbatch, nx, ny = 5, 3, 6
rng = np.random.default_rng(0)  # fixed seed so the printed gradients are reproducible
x0 = torch.arange(nbatch * nx, dtype=torch.float).reshape(nbatch, nx).requires_grad_()

# Forward: rows of x0 are samples; batch=True makes TorchOperator apply A to each row
A = rng.normal(0., 1., (ny, nx)).astype(np.float32)
Aop = TorchOperator(MatrixMult(A), batch=True)
y = Aop.apply(torch.sin(x0))

# AD
v = torch.ones((nbatch, ny), dtype=torch.float32)
y.backward(v, retain_graph=True)
# clone: x0.grad is reset below, and a plain reference would be reset along with it
adgrad = x0.grad.clone()
print('AD gradient: ', adgrad)

# Analytical: same computation through an equivalent bias-free nn.Linear layer
x0.grad = None  # discard the AD gradient so it does not accumulate into the reference one
At = torch.from_numpy(A)
Lin = nn.Linear(nx, ny, bias=False)
with torch.no_grad():
    Lin.weight.copy_(At)
y1 = Lin(torch.sin(x0))
y1.backward(v, retain_graph=True)
anagrad = x0.grad.clone()

print('Analytical gradient: ', anagrad)
assert torch.allclose(adgrad, anagrad, rtol=1e-4, atol=1e-5), "AD and nn.Linear gradients disagree"

In [None]:
nbatch, nx, ny = 5, 3, 6
rng = np.random.default_rng(0)  # fixed seed so the printed gradients are reproducible
x0 = torch.arange(nbatch*nx, dtype=torch.float).reshape(nbatch, nx).requires_grad_()

# Forward: scalar loss mean(y**2) through the batched PyLops operator
A = rng.normal(0., 1., (ny, nx)).astype(np.float32)
Aop = TorchOperator(MatrixMult(A), batch=True)
y = Aop.apply(torch.sin(x0))
loss_ad = torch.mean(y**2)
loss_ad.backward()
adgrad = x0.grad.clone()
print('AD gradient: ', adgrad)

# Analytical: same loss via an equivalent bias-free nn.Linear on a fresh leaf tensor x1
x1 = torch.arange(nbatch*nx, dtype=torch.float).reshape(nbatch, nx).requires_grad_()
At = torch.from_numpy(A)
Lin = nn.Linear(nx, ny, bias=False)
with torch.no_grad():
    Lin.weight.copy_(At)
y1 = Lin(torch.sin(x1))
loss_ref = torch.mean(y1**2)
loss_ref.backward()
anagrad = x1.grad.clone()

print('Analytical gradient: ', anagrad)
assert torch.allclose(adgrad, anagrad, rtol=1e-4, atol=1e-6), "AD and nn.Linear gradients disagree"

## Mixing NN and Physics

In [None]:
# Use the first CUDA GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Backend name of the selected device ('cuda' or 'cpu'); this string form is what the
# mixed CPU/GPU cell passes to TorchOperator as `devicetorch`
device.type

In [None]:
class Network(nn.Module):
    """Small fully-convolutional network whose output feeds a PyLops operator.

    Four 3x3 convolutions with padding 1 (spatial size preserved), each followed
    by LeakyReLU(0.2), reduce the channel count as
    ``C -> C//2 -> C//4 -> C//8 -> C//32``.

    Parameters
    ----------
    input_channels : int
        Number of input channels ``C``. Must be at least 32 so that every
        convolution has at least one output channel.

    Raises
    ------
    ValueError
        If ``input_channels < 32``.
    """

    def __init__(self, input_channels):
        super().__init__()
        # C//32 == 0 for C < 32 would give conv4 zero output channels
        if input_channels < 32:
            raise ValueError(
                "input_channels must be >= 32 so that the last layer has at least "
                f"one output channel, got {input_channels}")
        self.conv1 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(input_channels // 2, input_channels // 4, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(input_channels // 4, input_channels // 8, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(input_channels // 8, input_channels // 32, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        # Not used in forward(); kept so the module's attributes remain unchanged
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply the four conv + LeakyReLU stages; (N, C, H, W) -> (N, C//32, H, W)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.activation(conv(x))
        return x

In [None]:
# Two independently initialised copies of the same 32-channel network:
# one stays on the CPU, the other is moved to `device` (Module.to returns the module itself)
net_cpu, net_gpu = Network(32), Network(32).to(device)
net_gpu

In [None]:
# CPU
h = np.ones((4, 4))
Pop = Convolve2D(dims=(128, 128), h=h)
Pop_torch_cpu = TorchOperator(Pop, device='cpu')

# forward
%timeit -n2 -r2 Pop_torch_cpu.apply(net_cpu(torch.ones((1, 32, 128, 128))).view(-1))

# backward
y = Pop_torch_cpu.apply(net_cpu(torch.ones((1, 32, 128, 128))).view(-1))
loss = y.sum()
%timeit -n1 -r1 loss.backward()

In [None]:
# GPU
h = np.ones((4, 4))
Pop = Convolve2D(dims=(128, 128), h=h)
Pop_torch_gpu = TorchOperator(Pop, device=device)
Pop_torch_gpu.apply(net_gpu(torch.ones((1, 32, 128, 128)).to(device)).view(-1)) # dry run

%timeit -n2 -r2 Pop_torch_gpu.apply(net_gpu(torch.ones((1, 32, 128, 128)).to(device)).view(-1))

# backward
y = Pop_torch_gpu.apply(net_gpu(torch.ones((1, 32, 128, 128)).to(device)).view(-1))
loss = y.sum()
%timeit -n1 -r1 loss.backward()

In [None]:
# Mixed (currently not allowed!)
h = np.ones((4, 4))
Pop = Convolve2D(dims=(128, 128), h=h)
Pop_torch_cpu = TorchOperator(Pop, device='cpu', devicetorch=device.type)

# forward
Pop_torch_cpu.apply(net_gpu(torch.ones((1, 32, 128, 128)).to(device)).view(-1)) # dry run
%timeit -n2 -r2 Pop_torch_cpu.apply(net_gpu(torch.ones((1, 32, 128, 128)).to(device)).view(-1))

# backward
y = Pop_torch_cpu.apply(net_gpu(torch.ones((1, 32, 128, 128)).to(device)).view(-1))
loss = y.sum()
%timeit -n1 -r1 loss.backward()