In [1]:
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
#import h5py
import scipy.io

matplotlib.rcParams['figure.facecolor'] = '#ffffff'

In [2]:
# Show the installed PyTorch version.
print(torch.__version__)

1.8.1


In [3]:
#f = h5py.File('somefile.mat','r')
# Load the MATLAB file; the result is indexed by variable name in the cells below.
# The path is relative, so the file must be in the notebook's working directory.
matdata=scipy.io.loadmat('dataForPythonBIMcmplx25dB.mat')

In [4]:
# Unpack the arrays and scalar sizes stored in the MATLAB file.
Hatm=matdata['Hatm']
Atm=matdata['Atm']
# The scalar entries are size-1 arrays; .item() extracts the single element
# explicitly, because int() on an array with ndim > 0 is deprecated in NumPy.
N=int(matdata['N'].item())
targetsR=matdata['d_epsaM']
# The imaginary part of the targets is all zeros, with the same shape as the real part.
targetsI=np.zeros_like(targetsR)
nMeas=int(matdata['nMeas'].item())
nRx=int(matdata['nRx'].item())
nTx=int(matdata['nTx'].item())
Np=int(matdata['Np'].item())
Emea=matdata['Emea']
Einc=matdata['Ez_inc']
#a1tm=matdata['a1tm']
#a1tmC=matdata['a1tmC']

In [5]:
# Place the real and the (all-zero) imaginary targets side by side along axis 1.
targets = np.concatenate([targetsR, targetsI], axis=1)
targets.shape

(70000, 4608)

In [6]:
# The Emea array is used directly as the model input.
inputs=Emea
Emea.shape

(70000, 28, 16)

In [7]:
# Convert the NumPy arrays to torch tensors.
# inputs and Hatm are stored as complex64, targets as float32,
# while Atm and Einc are kept in double precision (complex128).
# NOTE(review): this rebinds `inputs`, `targets`, `Hatm`, `Atm` and `Einc`,
# so the cell is not idempotent with respect to the NumPy originals.
inputs=torch.tensor(inputs,dtype = torch.complex64)
targets=torch.tensor(targets,dtype = torch.float32)
Hatm=torch.tensor(Hatm,dtype = torch.complex64)
Atm=torch.tensor(Atm,dtype = torch.complex128)
Einc=torch.tensor(Einc,dtype = torch.complex128)
#a1tm=torch.tensor(a1tm,dtype = torch.complex64)
#a1tmC=torch.tensor(a1tmC,dtype = torch.complex64)

In [8]:
# Take the first row of Atm and view it as an Np x Np array.
a1tm=Atm[0,:].reshape(Np,Np)
# Mirror along the rows: reversed rows followed by rows 1..Np-1 -> (2*Np-1, Np).
a1tm=torch.cat((a1tm[range(Np-1,-1,-1),:],a1tm[range(1,Np,1),:]),dim=0)
# Mirror along the columns in the same way -> (2*Np-1, 2*Np-1).
a1tm=torch.cat((a1tm[:,range(Np-1,-1,-1)],a1tm[:,range(1,Np,1)]),dim=1)
# 2-D FFT of the mirrored kernel; AmultFFT multiplies by this in the frequency domain.
a1tmC=torch.fft.fft2(a1tm)

In [9]:
# Hold out the last `valSize + testingSize` samples; everything before them is training data.
valSize=2000
testingSize=2000
trainingSize=len(inputs)-(valSize+testingSize)
split_sizes=[trainingSize, valSize, testingSize]

# The split is contiguous (no shuffling): train first, then validation, then test.
train_in, val_in, test_in = torch.split(inputs, split_sizes)
train_tr, val_tr, test_tr = torch.split(targets, split_sizes)

train_ds, val_ds, test_ds = (
    TensorDataset(features, labels)
    for features, labels in ((train_in, train_tr), (val_in, val_tr), (test_in, test_tr))
)

In [10]:
#print(len(dataset[0:10]))
# Report the number of samples in the training, validation and test sets.
print(len(train_ds))
print(len(val_ds))
print(len(test_ds))

66000
2000
2000


In [11]:
batch_size = 128
# Only the training data is shuffled. Validation and test loaders keep a fixed
# order so that batch-averaged metrics are reproducible from run to run
# (the last batch is smaller whenever the set size is not a multiple of batch_size).
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size, shuffle=False)
test_dl = DataLoader(test_ds, batch_size, shuffle=False)

In [12]:
def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to `device`.

    Lists and tuples are processed recursively and always come back as lists.
    Anything else is moved with `.to(device, non_blocking=True)`.
    """
    if not isinstance(data, (list,tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader():
    """Thin wrapper that moves every batch of an underlying loader to a device."""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Iterate over the wrapped loader, yielding each batch after `to_device`."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Return the length of the wrapped loader (its number of batches)."""
        return len(self.dl)

In [13]:
# Select the device once; the rest of the notebook reads this global.
device = get_default_device()
device

device(type='cpu')

In [14]:
# Wrap the loaders so every batch is moved to `device` as it is yielded.
# NOTE(review): this rebinds train_dl/val_dl/test_dl, so re-running the cell
# wraps the already-wrapped loaders again.
train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)
test_dl = DeviceDataLoader(test_dl, device)
# Move the global tensors read by the operator functions (Hatm, Einc, a1tmC, ...)
# and the full input/target tensors to the same device.
Hatm=to_device(Hatm, device)
Atm=to_device(Atm, device)
Einc=to_device(Einc, device)
a1tm=to_device(a1tm, device)
a1tmC=to_device(a1tmC, device)
inputs=to_device(inputs, device)
targets=to_device(targets, device)

In [15]:
class ImageRegressionBase(nn.Module):
    """Base module providing MSE-based training/validation steps and epoch logging.

    Subclasses supply `forward`; a batch is an `(inputs, targets)` pair.
    """
    def training_step(self, batch):
        """Return the MSE loss between the model prediction and the batch targets."""
        inputs, expected = batch
        return F.mse_loss(self(inputs), expected)

    def validation_step(self, batch):
        """Return `{'val_loss': loss}` with the MSE loss detached from the graph."""
        inputs, expected = batch
        predicted = self(inputs)
        return {'val_loss': F.mse_loss(predicted, expected).detach()}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses into one Python float."""
        stacked = torch.stack([entry['val_loss'] for entry in outputs])
        return {'val_loss': stacked.mean().item()}

    def epoch_end(self, epoch, result):
        """Print the last learning rate, training loss and validation loss for an epoch."""
        template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}"
        print(template.format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss']))

In [16]:
negSlop=0.1
def conv_block(in_channels, out_channels):
    """Two 3x3 convolutions (padding 1), each followed by an in-place LeakyReLU.

    The first convolution maps `in_channels` to `out_channels`; the second keeps
    `out_channels`. Spatial size is unchanged. The slope is the global `negSlop`.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.LeakyReLU(negSlop,inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.LeakyReLU(negSlop,inplace=True),
    )

def conv_block_downsampling(in_channels, out_channels):
    """2x2 max-pooling (stride 2) followed by `conv_block(in_channels, out_channels)`."""
    return nn.Sequential(
        nn.MaxPool2d(2, stride=2),
        conv_block(in_channels, out_channels),
    )

def conv_block_upsampling(in_channels, out_channels):
    """Decoder stage: conv block, 2x transposed convolution, then LeakyReLU.

    The conv block takes `2*in_channels` channels (the callers pass a
    concatenation of two feature maps) and reduces them to `in_channels`;
    the bias-free transposed convolution (kernel 2, stride 2) maps these
    to `out_channels`.
    """
    return nn.Sequential(
        conv_block(2*in_channels, in_channels),
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, bias=False),
        nn.LeakyReLU(negSlop,inplace=True),
    )

def conv_block_upsamplingF(in_channels, out_channels):
    """Final decoder stage: like `conv_block_upsampling` but without upsampling.

    A conv block reduces `2*in_channels` to `in_channels`, then a 3x3
    convolution (padding 1) maps to `out_channels`, followed by LeakyReLU.
    """
    return nn.Sequential(
        conv_block(2*in_channels, in_channels),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.LeakyReLU(negSlop,inplace=True),
    )

Nfltr=32
class Unet(nn.Module):
    """U-shaped convolutional network over an Np x Np grid with three skip connections.

    The encoder widens the channels Nfltr -> 2*Nfltr -> 4*Nfltr -> 8*Nfltr while
    pooling three times; the decoder upsamples back and concatenates the
    matching encoder feature map at each level. The output has `in_channels`
    channels again.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels=in_channels

        # Encoder path. `down4` also performs the first upsampling step so that
        # its output matches the resolution of `down3`.
        self.down1 = conv_block(in_channels,Nfltr)
        self.down2 = conv_block_downsampling(Nfltr, 2*Nfltr)
        self.down3 = conv_block_downsampling(2*Nfltr,4*Nfltr)
        self.down4 = nn.Sequential(
            conv_block_downsampling(4*Nfltr, 8*Nfltr),
            nn.ConvTranspose2d(8*Nfltr, 4*Nfltr, kernel_size=2, stride=2, bias=False),
            nn.LeakyReLU(negSlop,inplace=True),
        )

        # Decoder path; each stage consumes a skip connection concatenated on channels.
        self.up1=conv_block_upsampling(4*Nfltr, 2*Nfltr)
        self.up2=conv_block_upsampling(2*Nfltr, Nfltr)
        self.up3=conv_block_upsamplingF(Nfltr, in_channels)

    def forward(self, d_epsR,d_epsI):
        """Run the network on the real and imaginary parts and return both parts.

        The two inputs are concatenated along dim 1 and reshaped to
        (-1, in_channels, Np, Np). The output is flattened to
        (-1, in_channels*N) and returned as the column slices [0:N] and [N:2N].
        Assumes each input is (batch, N) with N == Np*Np and in_channels == 2
        (TODO confirm against the data file).
        """
        stacked=torch.cat((d_epsR,d_epsI),dim=1)
        stacked=torch.reshape(stacked,(-1,self.in_channels,Np,Np))

        enc1=self.down1(stacked)
        enc2=self.down2(enc1)
        enc3=self.down3(enc2)
        bottom=self.down4(enc3)

        dec3=self.up1(torch.cat((enc3,bottom),dim=1))
        dec2=self.up2(torch.cat((enc2,dec3),dim=1))
        dec1=self.up3(torch.cat((enc1,dec2),dim=1))

        flat=torch.reshape(dec1,(-1,self.in_channels*N))
        return flat[:,0:N], flat[:,N:2*N]

In [20]:
def HmatMult(Etot,d_eps):
    """Apply the forward operator for every transmitter and return its negative.

    For each transmitter index the field slice `Etot[:,:,tx]` is multiplied
    element-wise by `d_eps` and then left-multiplied by the global `Hatm`.
    The result is a complex64 tensor of shape (nRx, batch, nTx).
    """
    batch=d_eps.shape[1]
    result=to_device(torch.zeros((nRx,batch,nTx),dtype=torch.complex64), device)
    for tx in range(nTx):
        # Writing into the pre-allocated buffer casts each slice to complex64.
        result[:,:,tx]=Hatm@(Etot[:,:,tx]*d_eps)
    return -result

def HmatconjMult(Etot,xin):
    """Apply the adjoint-style operator, summed over transmitters, and negate it.

    For each transmitter, `xin[:,:,tx]` is multiplied by the conjugate transpose
    of the global `Hatm` and then element-wise by the conjugate of
    `Etot[:,:,tx]`. The contributions are accumulated into an (N, batch) tensor.
    """
    batch=xin.shape[1]
    # Starts as a real zero tensor; the first addition promotes it to complex.
    accum=to_device(torch.zeros((N,batch)), device)
    Hatm_H=Hatm.t().conj()
    for tx in range(nTx):
        accum=accum+Etot[:,:,tx].conj()*(Hatm_H@xin[:,:,tx])
    return -accum


def soft_thresholding(d_eps,delta):
    """Shrink the magnitude of every entry of `d_eps` by `delta`, keeping its phase.

    Entries with magnitude at most `delta` become zero. Every other entry is
    scaled by (|x| - delta) / |x|. Works for tensors of any shape, real or
    complex.

    Args:
        d_eps: tensor to threshold.
        delta: non-negative threshold; it must be positive if `d_eps` can
            contain exact zeros, otherwise those entries divide 0 by 0.

    Returns:
        A tensor with the same shape as `d_eps`.
    """
    magnitude=d_eps.abs()
    shrunk=torch.maximum(magnitude-delta,torch.zeros_like(magnitude))
    # Where shrunk > 0, shrunk + delta equals |d_eps|, so this rescales the
    # entry to magnitude `shrunk`; where shrunk == 0 the entry becomes 0.
    return (shrunk*d_eps)/(shrunk+delta)

  

def BiCGFFTtm(d_eps,y,nmax):
    """Iteratively solve `AmultFFT(d_eps, x) = y` for a batch of contrasts.

    Runs exactly `nmax` iterations of a stabilized bi-conjugate-gradient
    (BiCGSTAB-style) recurrence starting from x = 0; there is no convergence
    test. The scalar products are plain sums of element-wise products
    (no complex conjugation).

    Args:
        d_eps: contrast, indexed as (N, batch) (batch size is read from shape[1]).
        y: right-hand side with N elements; reshaped to (N, 1) and broadcast
            over the batch.
        nmax: number of iterations.

    Returns:
        The solution estimate with one column per batch element. If the whole
        batch of `d_eps` is numerically zero, `y` broadcast over the batch is
        returned without iterating.
    """
    y=torch.reshape(y,(N,1))
    nbatch=d_eps.shape[1]
    # Placeholder only; xout is reassigned in both branches below.
    xout=to_device(torch.zeros((N,nbatch),dtype= torch.complex64), device)
    # The shortcut is taken only when the sum of magnitudes over the WHOLE batch is ~0.
    if torch.sum(torch.abs(d_eps))<=0.0000000000001:
        temp=to_device(torch.zeros((N,nbatch),dtype= torch.complex64), device)
        # Subtracting zeros broadcasts y from (N, 1) to (N, nbatch).
        xout=y-temp
        del temp
    else:
        # Per-batch-element scalars of the recurrence (one entry per column).
        rhoi=1.0*to_device(torch.ones((nbatch),dtype= torch.complex128), device)
        alpha=1.0*to_device(torch.ones((nbatch),dtype= torch.complex128), device)
        w=1.0*to_device(torch.ones((nbatch),dtype= torch.complex128), device)
        # Work vectors and the solution estimate, all starting at zero.
        v=to_device(torch.zeros((N,nbatch),dtype= torch.complex128), device)
        p=to_device(torch.zeros((N,nbatch),dtype= torch.complex128), device)
        x=to_device(torch.zeros((N,nbatch),dtype= torch.complex128), device)
        temp=to_device(torch.zeros((N,nbatch)), device)  # overwritten on the next line
        temp=AmultFFT(d_eps,x)
        # Initial residual; rhat0 is the fixed shadow residual.
        r=y-temp
        rhat0=r
        for cg_iter in range(nmax):
            rhoi_1=rhoi
            rhoi=torch.sum(rhat0*r,dim=0)
            beta=(rhoi/rhoi_1)*(alpha/w)
            # Search direction update.
            p=r+beta*(p-w*v)
            v=AmultFFT(d_eps,p)
            alpha=rhoi/(torch.sum(rhat0*v,dim=0))
            # Intermediate residual after the first half-step.
            s=r-alpha*v
            t=AmultFFT(d_eps,s)
            # Stabilization weight.
            w=(torch.sum(t*s,dim=0))/(torch.sum(t*t,dim=0))
            x=x+alpha*p+w*s
            r=s-w*t
        xout=x
    return xout  



def AmultFFT(deps,xi):
    """Return `xi + K(deps * xi)`, with K applied by FFT using the global `a1tmC`.

    Both arguments are indexed as (N, batch). The element-wise product is
    reshaped to Np x Np per batch element, zero-padded to (2*Np-1) x (2*Np-1),
    multiplied by `a1tmC` in the frequency domain, transformed back, and the
    trailing Np x Np window is kept. The result is returned as (N, batch).
    """
    field=xi.permute(1,0)
    contrast=deps.permute(1,0)
    batch=contrast.shape[0]

    # Zero-padded buffer; the assignment casts the product to complex128.
    padded=to_device(torch.zeros((batch,2*Np-1,2*Np-1),dtype=torch.complex128),device)
    padded[:,0:Np,0:Np]=torch.reshape(contrast*field,(-1,Np,Np))

    spectrum=a1tmC*torch.fft.fft2(padded)
    window=torch.fft.ifft2(spectrum)[:,Np-1:2*Np-1,Np-1:2*Np-1]

    result=window.reshape(batch,N)+field
    return result.permute(1,0)

def BiCGtmFFTloop(d_eps,Einc,nmax):
    """Solve for the field of every transmitter and stack the results.

    Calls `BiCGFFTtm` once per transmitter with column `tx` of `Einc` as the
    right-hand side. The results are stored in a complex64 tensor of shape
    (N, batch, nTx).
    """
    batch=d_eps.shape[1]
    fields=to_device(torch.zeros((N,batch,nTx),dtype= torch.complex64), device)
    for tx in range(nTx):
        incident=Einc[:,tx]
        fields[:,:,tx]=BiCGFFTtm(d_eps,incident,nmax)
    return fields


def power(Etot,maxit):
    """Estimate a step size by power iteration on HmatconjMult(HmatMult(.)).

    Starts from a random complex vector per batch element, applies the composed
    operator `maxit` times with max-magnitude normalization, and returns the
    reciprocal of the real part of the final quotient, one value per batch
    element.

    NOTE(review): the quotient uses plain products (x*Ax, x*x) without complex
    conjugation, and the start vector is random with no seed set here.
    """
    batch=Etot.shape[1]
    vec=to_device(torch.rand((N,batch),dtype= torch.complex64), device)

    def apply_operator(v):
        return HmatconjMult(Etot,HmatMult(Etot,v))

    def quotient(v,image):
        return torch.sum(v*image,dim=0)/torch.sum(v*v,dim=0)

    def normalise(v):
        return v/torch.max(v.abs(),dim=0).values

    lammda=quotient(vec,apply_operator(vec))
    vec=normalise(vec)
    for _ in range(maxit):
        vec=apply_operator(vec)
        lammda=quotient(vec,apply_operator(vec))
        vec=normalise(vec)
    return 1/lammda.real

    
class ComputeBornSVD(ImageRegressionBase):
    """Unrolled Born-iterative reconstruction with one regularization network per outer iteration.

    Each outer (BIM) iteration runs `nLW` gradient-type updates of the contrast
    `d_eps`, each followed by the regularization network assigned to that
    iteration. Between outer iterations the total field is recomputed with
    `BiCGtmFFTloop`.

    Args:
        regNet1, regNet2, regNet3: networks used in outer iteration 0, 1 and 2.
        nLW: number of inner updates per outer iteration.
        nbim_iter: number of outer iterations (at most 3, one per network).
        nmax: iteration count passed to `BiCGtmFFTloop`.
        maxit: iteration count passed to `power` for the step size.

    Raises:
        ValueError: if `nbim_iter` exceeds the number of available networks.
    """
    def __init__(self,regNet1,regNet2,regNet3,nLW,nbim_iter,nmax,maxit):
        super().__init__()
        # There is exactly one network per outer iteration. Without this check a
        # 4th iteration would silently reuse the previous network output and
        # discard its own update.
        if nbim_iter>3:
            raise ValueError(
                "nbim_iter must be at most 3 (one regularization network per iteration), got {}".format(nbim_iter))
        self.regNet1=regNet1
        self.regNet2=regNet2
        self.regNet3=regNet3
        self.nLW=nLW
        self.nbim_iter=nbim_iter
        self.nmax=nmax
        self.maxit=maxit
        # Step size for the first outer iteration, computed once from the
        # incident field alone (batch of size 1).
        Etot=to_device(torch.zeros((N,1,nTx),dtype= torch.complex64), device)
        Etot[:,0,:]=Einc[:,:]
        self.gamma0=power(Etot,maxit)
        print(self.gamma0)

    def forward(self, Emea):
        """Reconstruct the contrast from measured fields.

        Args:
            Emea: measured data with the batch on dim 0; it is permuted to
                (dim1, batch, dim2) to match the output of `HmatMult`.

        Returns:
            Real tensor with the real and imaginary parts of the contrast
            concatenated along dim 1.
        """
        nbatch=Emea.shape[0]
        d_eps=to_device(torch.zeros((N,nbatch),dtype= torch.complex64), device)
        Etot=to_device(torch.zeros((N,nbatch,nTx),dtype= torch.complex64), device)
        # Start from the incident field for every batch element (broadcast copy,
        # cast to complex64 by the assignment).
        Etot[:,:,:]=Einc[:,None,:]

        regNets=(self.regNet1,self.regNet2,self.regNet3)
        for bim_itr in range(self.nbim_iter):
            # The precomputed step size is valid only while Etot == Einc.
            gamma=self.gamma0 if bim_itr==0 else power(Etot,self.maxit)
            regNet=regNets[bim_itr]

            for lw_itr in range(self.nLW):
                # Data misfit and its back-projection.
                xout = HmatMult(Etot,d_eps)
                misfit=Emea.permute(1,0,2)-xout
                xout = HmatconjMult(Etot,misfit)
                d_eps=d_eps+gamma*xout
                # Learned regularization on the (batch, N) real/imag parts.
                d_epsR,d_epsI=regNet(d_eps.permute(1,0).real,d_eps.permute(1,0).imag)
                #d_eps=soft_thresholding(d_eps,0.001)
                d_eps=torch.complex(d_epsR,d_epsI).permute(1,0)

            # The field is not needed after the last outer iteration.
            if bim_itr<(self.nbim_iter-1):
                Etot=BiCGtmFFTloop(d_eps,Einc, self.nmax)

        d_eps=d_eps.permute(1,0)
        return torch.cat((d_eps.real,d_eps.imag),dim=1)

In [None]:
# Instantiate a model that runs a single TBIM iteration (nbim_iter=1).
regNet1=Unet(in_channels=2)
regNet2=Unet(in_channels=2)
regNet3=Unet(in_channels=2)
model = to_device(
    ComputeBornSVD(
        regNet1=regNet1,
        regNet2=regNet2,
        regNet3=regNet3,
        nLW=6,
        nbim_iter=1,
        nmax=4,
        maxit=6,
    ),
    device,
)

In [26]:
def evaluate(model, val_loader):
    """Run the model's validation step on every batch and aggregate the results.

    Switches the model to eval mode and runs with gradient tracking disabled.
    Returns whatever `model.validation_epoch_end` returns for the list of
    per-batch outputs.
    """
    with torch.no_grad():
        model.eval()
        step_outputs = []
        for batch in val_loader:
            step_outputs.append(model.validation_step(batch))
        return model.validation_epoch_end(step_outputs)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    first_rates = (group['lr'] for group in optimizer.param_groups)
    return next(first_rates, None)

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader, 
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """Train `model` with a one-cycle learning-rate schedule.

    Args:
        epochs: number of passes over `train_loader`.
        max_lr: peak learning rate of the one-cycle schedule (also passed to
            the optimizer as its initial learning rate).
        model: module providing `training_step` and `epoch_end`, plus the
            validation hooks used by `evaluate`.
        train_loader: iterable of training batches; must support `len()`.
        val_loader: iterable of validation batches.
        weight_decay: weight decay passed to the optimizer.
        grad_clip: if truthy, gradients are clipped element-wise to
            [-grad_clip, grad_clip] before each optimizer step.
        opt_func: optimizer constructor.

    Returns:
        A list with one dict per epoch holding the validation result, the mean
        training loss ('train_loss') and the per-step learning rates ('lrs').
    """
    torch.cuda.empty_cache()
    history = []
    
    # Set up custom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs, 
                                                steps_per_epoch=len(train_loader))
    
    for epoch in range(epochs):
        # Training Phase 
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = model.training_step(batch)
            # Store only the detached value: keeping the attached loss tensor
            # would hold a reference to its autograd history for every batch
            # of the epoch.
            train_losses.append(loss.detach())
            loss.backward()
            
            # Gradient clipping
            if grad_clip: 
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            
            optimizer.step()
            optimizer.zero_grad()
            
            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()
        
        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history

In [None]:
# Accumulates the per-epoch results of the first training stage.
history1=[]

In [23]:
#To perform model training that runs over single TBIM iteration
%%time
epochs = 20
max_lr = 0.0005
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam
history1 += fit_one_cycle(epochs, max_lr, model, train_dl, val_dl, 
                             opt_func=opt_func)


Epoch [0], last_lr: 0.00005, train_loss: 0.0072, val_loss: 0.0022
Epoch [1], last_lr: 0.00014, train_loss: 0.0017, val_loss: 0.0015
Epoch [2], last_lr: 0.00026, train_loss: 0.0012, val_loss: 0.0011
Epoch [3], last_lr: 0.00038, train_loss: 0.0010, val_loss: 0.0011
Epoch [4], last_lr: 0.00047, train_loss: 0.0008, val_loss: 0.0008
Epoch [5], last_lr: 0.00050, train_loss: 0.0007, val_loss: 0.0007
Epoch [6], last_lr: 0.00049, train_loss: 0.0006, val_loss: 0.0006
Epoch [7], last_lr: 0.00048, train_loss: 0.0005, val_loss: 0.0005
Epoch [8], last_lr: 0.00045, train_loss: 0.0005, val_loss: 0.0005
Epoch [9], last_lr: 0.00041, train_loss: 0.0004, val_loss: 0.0005
Epoch [10], last_lr: 0.00036, train_loss: 0.0004, val_loss: 0.0005
Epoch [11], last_lr: 0.00031, train_loss: 0.0004, val_loss: 0.0004
Epoch [12], last_lr: 0.00025, train_loss: 0.0004, val_loss: 0.0004
Epoch [13], last_lr: 0.00019, train_loss: 0.0003, val_loss: 0.0004
Epoch [14], last_lr: 0.00014, train_loss: 0.0003, val_loss: 0.0004
Epoch

In [36]:
# Instantiate a model that runs two TBIM iterations (nbim_iter=2). The
# regularization networks are carried over from the previously trained model.
model = to_device(
    ComputeBornSVD(
        regNet1=model.regNet1,
        regNet2=model.regNet2,
        regNet3=model.regNet3,
        nLW=6,
        nbim_iter=2,
        nmax=4,
        maxit=5,
    ),
    device,
)

tensor([0.0001], device='cuda:0')


In [None]:
# Freeze the regularization network of the first TBIM iteration.
for frozen_net in (model.regNet1,):
    for weight in frozen_net.parameters():
        weight.requires_grad=False

In [None]:
# Accumulates the per-epoch results of the second training stage.
history2=[]

In [27]:
# Train the regulrization network for only the second TBIM, and use the prevoius trained network for the first TBIM iteration 
%%time
epochs = 10
max_lr = 0.00025
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam
history2 += fit_one_cycle(epochs, max_lr, model, train_dl, val_dl, 
                             opt_func=opt_func)

Epoch [0], last_lr: 0.00007, train_loss: 0.0006, val_loss: 0.0004
Epoch [1], last_lr: 0.00019, train_loss: 0.0003, val_loss: 0.0003
Epoch [2], last_lr: 0.00025, train_loss: 0.0003, val_loss: 0.0004
Epoch [3], last_lr: 0.00024, train_loss: 0.0003, val_loss: 0.0003
Epoch [4], last_lr: 0.00020, train_loss: 0.0002, val_loss: 0.0003
Epoch [5], last_lr: 0.00015, train_loss: 0.0002, val_loss: 0.0003
Epoch [6], last_lr: 0.00010, train_loss: 0.0002, val_loss: 0.0003
Epoch [7], last_lr: 0.00005, train_loss: 0.0002, val_loss: 0.0003
Epoch [8], last_lr: 0.00001, train_loss: 0.0002, val_loss: 0.0003
Epoch [9], last_lr: 0.00000, train_loss: 0.0002, val_loss: 0.0002
Wall time: 3h 4min 46s


In [37]:
# Instantiate a model that runs three TBIM iterations (nbim_iter=3). The
# regularization networks are carried over from the previously trained model.
model = to_device(
    ComputeBornSVD(
        regNet1=model.regNet1,
        regNet2=model.regNet2,
        regNet3=model.regNet3,
        nLW=6,
        nbim_iter=3,
        nmax=4,
        maxit=5,
    ),
    device,
)

tensor([0.0001], device='cuda:0')


In [None]:
# Freeze the regularization networks of the first and second TBIM iterations.
for frozen_net in (model.regNet1, model.regNet2):
    for weight in frozen_net.parameters():
        weight.requires_grad=False

In [42]:
# Train the regulrization network for only the third TBIM, and use the prevoius trained networks for the first and second TBIM iteration 
%%time
epochs = 5
max_lr = 0.0002
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam
history3=[]
history3 += fit_one_cycle(epochs, max_lr, model3, train_dl, val_dl, 
                             opt_func=opt_func)

Epoch [0], last_lr: 0.00015, train_loss: 0.0002, val_loss: 0.0003
Epoch [1], last_lr: 0.00019, train_loss: 0.0002, val_loss: 0.0003
Epoch [2], last_lr: 0.00012, train_loss: 0.0002, val_loss: 0.0003
Epoch [3], last_lr: 0.00004, train_loss: 0.0002, val_loss: 0.0003
Epoch [4], last_lr: 0.00000, train_loss: 0.0002, val_loss: 0.0003
Wall time: 2h 43min 30s


In [24]:
# Save the state_dict of each trained regularization network (first, second and
# third TBIM iteration) to a file named after the network.
for net_name in ("regNet1", "regNet2", "regNet3"):
    torch.save(getattr(model, net_name).state_dict(), net_name)

In [None]:
# Save the training histories, one file per training stage.
for history_file, history_data in (("history1", history1),
                                   ("history2", history2),
                                   ("history3", history3)):
    torch.save(history_data, history_file)

In [None]:
#To plot histories
def plot_losses(history):
    """Plot training and validation loss per epoch from a list of result dicts.

    Each entry must contain 'val_loss'; a missing 'train_loss' is passed on as None.
    """
    training_curve = [entry.get('train_loss') for entry in history]
    validation_curve = [entry['val_loss'] for entry in history]
    plt.plot(training_curve, '-bx')
    plt.plot(validation_curve, '-rx')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training', 'Validation'])
    plt.title('Loss vs. No. of epochs')
plot_losses(history1)

In [None]:
# Loss curves of the second training stage.
plot_losses(history2)

In [None]:
# Loss curves of the third training stage.
plot_losses(history3)

In [None]:
# To run the RNEs model
%%time
def RNE_batch(model, inputs):
    with torch.no_grad():
        RNEm=[]
        for imageI,imageO in inputs:
            nbatch=imageI.shape[0]
            xLW=model(imageI)
            RNE=[100*torch.linalg.norm(xLW[indx:indx+1,:]-imageO[indx:indx+1,:], ord='fro')/torch.linalg.norm(imageO[indx:indx+1,:], ord='fro') for indx in range(nbatch)]
            RNE=[torch.stack(RNE).mean()]
            RNEm=RNEm+RNE
        RNEM=torch.stack(RNEm).mean()
    return RNEM

RNE=RNE_batch(model, test_dl)
print(RNE)

In [None]:
# Compute the MSE (the loss function) on the test set. `evaluate` returns
# the average of the per-batch losses as {'val_loss': ...}.
resTest = evaluate(model, test_dl)
resTest