In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell executes, so edits to them take effect without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import time
import tqdm
import torch
import warnings
import numpy as np
import torch.nn as nn
import chnet.cahn_hill as ch
import chnet.ch_tools as tools
import utilities as utils
import torch.nn.functional as F
from ipywidgets import interact
import matplotlib.pyplot as plt
from torchvision import transforms, utils
from toolz.curried import pipe, curry, compose
from torch.utils.data import Dataset, DataLoader

Optional: you can install PyFFTW for a speed-up with:
`conda install -c conda-forge pyfftw`


In [3]:
# Select matplotlib's interactive "notebook" backend for figures in this session.
%matplotlib notebook
# Suppress every warning category for the rest of the session.
warnings.filterwarnings('ignore')

In [4]:
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D

# Global matplotlib defaults: figure geometry and resolution first, then text sizes.
mpl.rcParams.update({
    'figure.figsize': [8.0, 6.0],
    'figure.dpi': 80,
    'savefig.dpi': 100,
    'font.size': 12,
    'legend.fontsize': 'large',
    'figure.titlesize': 'medium',
})

def draw_im(im, title=None):
    """Display an array as an image with a colorbar.

    Parameters
    ----------
    im : array-like
        Image data. Singleton dimensions are removed with ``np.squeeze``
        before plotting.
    title : str, optional
        Title placed above the axes; omitted when ``None``.

    Returns
    -------
    None
    """
    # Draw on a fresh figure. With an interactive backend the implicit
    # "current figure" stays active between cells, so successive calls would
    # otherwise pile images and colorbars onto the same figure.
    fig, ax = plt.subplots()
    image = ax.imshow(np.squeeze(im))
    fig.colorbar(image, ax=ax)
    if title is not None:
        ax.set_title(title)
    plt.show()
    
@curry
def return_slice(x_data, cutoff):
    """Crop ``x_data`` to a window centred on the middle index of every axis.

    The function is curried, so ``return_slice(cutoff=k)`` yields a callable
    that only needs the array.

    Parameters
    ----------
    x_data : numpy.ndarray
        Array to crop.
    cutoff : int or None
        Half-width of the window. Along each axis the indices
        ``center - cutoff`` through ``center + cutoff`` (inclusive) are kept,
        where ``center = size // 2``. ``None`` returns ``x_data`` unchanged.

    Returns
    -------
    numpy.ndarray
        The cropped view, or ``x_data`` itself when ``cutoff`` is ``None``.
    """
    if cutoff is None:
        return x_data

    centers = np.asarray(x_data.shape).astype(int) // 2
    # NumPy requires a *tuple* of slices for multi-dimensional indexing; a
    # list of slices is rejected by current releases. The lower bound is
    # clamped at 0 so an oversized cutoff keeps the axis start instead of
    # wrapping around to a negative index.
    window = tuple(slice(max(int(center) - cutoff, 0), int(center) + cutoff + 1)
                   for center in centers)
    return x_data[window]
    
# Pre-bind cutoff=5 on the curried helper; cropper(arr) then selects indices
# center-5 .. center+5 along every axis of arr.
cropper = return_slice(cutoff=5)

In [None]:
def init_unif(nsamples, dim_x, dim_y, seed=354875):
    """Draw ``nsamples`` fields of shape ``(dim_x, dim_y)`` uniformly from [-0.95, 0.95).

    The global NumPy random generator is re-seeded with ``seed`` before
    sampling, so the same arguments always give the same array.
    """
    np.random.seed(seed)
    shape = (nsamples, dim_x, dim_y)
    low, high = -0.95, 0.95
    return np.random.uniform(low, high, size=shape)


def init_norm(nsamples, dim_x, dim_y, seed=354875):
    """Draw ``nsamples`` Gaussian-noise fields of shape ``(dim_x, dim_y)``.

    Each sample gets its own mean, uniform in [-0.1, 0.1), and its own standard
    deviation, uniform in [0.1, 0.5). Values are clipped to [-0.95, 0.95].
    The global NumPy random generator is re-seeded with ``seed`` first.
    """
    np.random.seed(seed)
    means = np.random.uniform(-0.1, 0.1, size=nsamples)
    scales = np.random.uniform(0.1, 0.5, size=nsamples)

    # One draw per (mean, scale) pair, in order, each with a leading sample axis.
    fields = []
    for mean, scale in zip(means, scales):
        fields.append(np.random.normal(loc=mean, scale=scale, size=(1, dim_x, dim_y)))
    stacked = np.concatenate(fields, axis=0)

    # Clip in place to the admissible value range.
    stacked.clip(-0.95, 0.95, out=stacked)
    return stacked

In [None]:
def mse_loss(y1, y2, scale=1.):
    """Mean of the squared element-wise differences of two tensors, times ``scale``.

    Both tensors must have the same shape.
    """
    assert y1.shape == y2.shape
    diff = y1 - y2
    total = (diff ** 2).sum()
    return total / y1.numel() * scale

@curry
def rmse_loss(y1, y2, scale=1.):
    """Square root of the mean squared element-wise difference, times ``scale``.

    Both tensors must have the same shape.
    """
    assert y1.shape == y2.shape
    diff = y1 - y2
    mean_squared = (diff ** 2).sum() / y1.numel()
    return mean_squared.sqrt() * scale


def mse_loss_npy(y1, y2):
    """Mean squared element-wise difference of two equally shaped NumPy arrays."""
    assert y1.shape == y2.shape
    diff = y1 - y2
    squared = diff ** 2
    return np.sum(squared) / y1.size

## Training Data

In [None]:
# Simulation / data-generation settings used by the cells below.
nsamples = 1200 # no. of samples
dim_x = 101
dim_y = dim_x
sim_steps = 600 # simulation steps
dx = 0.25 # delta space_dim
dt = 0.01 # delta time
gamma = 1.0 # interface energy
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Training initial conditions: half uniform noise, half clipped Gaussian noise
# (different seeds), stacked along the sample axis.
init_data1 = init_unif(nsamples//2, dim_x, dim_y, seed=354875)
init_data2 = init_norm(nsamples//2, dim_x, dim_y, seed=982632)
init_data = np.concatenate([init_data1, init_data2], axis=0)

In [None]:
%%time
# NOTE(review): ch_run_torch is presumably the Cahn-Hilliard time stepper - confirm in chnet.cahn_hill.
# x_data: output of the solver started from init_data with sim_step=sim_steps (network input).
# y_data: output of the solver started from x_data with sim_step=100 (network target).
x_data = ch.ch_run_torch(init_data, dt=dt, gamma=gamma, dx=dx, sim_step=sim_steps, device=device)
y_data = ch.ch_run_torch(x_data, dt=dt, gamma=gamma, dx=dx, sim_step=100, device=device)

In [None]:
%%time
# Validation set: 250 uniform + 250 Gaussian initial fields drawn with seeds
# different from the training ones, pushed through the same two solver calls.
# Note: this rebinds init_data1 / init_data2 / init_data from the training cell.
init_data1 = init_unif(250, dim_x, dim_y, seed=438645)
init_data2 = init_norm(250, dim_x, dim_y, seed=234580)
init_data = np.concatenate([init_data1, init_data2], axis=0)
# x_val is the network input, y_val the target obtained from x_val with sim_step=100.
x_val = ch.ch_run_torch(init_data, dt=dt, gamma=gamma, dx=dx, sim_step=sim_steps, device=device)
y_val = ch.ch_run_torch(x_val, dt=dt, gamma=gamma, dx=dx, sim_step=100, device=device)

# CNN Model

In [None]:
from chnet.ch_net import CHnet
from chnet.ch_loader import CahnHillDataset

In [None]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ks = 5 # kernel size
in_channels = 1 # no. of input channels
cw = 64 # channel width
model = CHnet(ks=ks, in_channels=in_channels, cw=cw).double().to(device)

# Width of the periodic padding added on every side of the input field.
# NOTE(review): presumably CHnet stacks 5 unpadded convolutions, each trimming
# ks // 2 pixels per side - confirm against chnet.ch_net.
lx = (ks // 2) * 5
# compose applies right-to-left: wrap-pad the 2-D field first, then add a
# leading channel axis.
transformer_x = compose(lambda x: x[None],
                        lambda x: np.pad(x, pad_width=[[lx,lx],[lx,lx]], mode='wrap'))

transformer_y = lambda x: x[None]

dataset = CahnHillDataset(x_data, y_data, transform_x=transformer_x, transform_y=transformer_y)

# Smoke test: push one sample (with a batch axis added) through the untrained
# network and check that the output shape matches the target.
item = dataset[0]
x = item["x"][None].to(device)
y = item["y"][None].to(device)
# No gradients are needed for a shape check, so skip building the autograd graph.
with torch.no_grad():
    y_pred = model(x)

assert y.shape == y_pred.shape

print(x.shape, y.shape)
print(mse_loss(y, y_pred).data)

## Model Architecture

In [None]:
# Print the shape of every parameter tensor and accumulate the total element count.
nprod = 0
for weights in model.parameters():
    shape = weights.size()
    print(shape)
    nprod = nprod + np.prod(shape)
print("No. of Parameters: %d" % nprod)

## Model Parameters

In [None]:
@curry
def add_neighbors(x, radius=2):
    """Expand a 2-D field into a stack of neighbour-product channels.

    Channel 0 is ``x`` itself. Every following channel is ``x`` multiplied
    element-wise by a copy of ``x`` shifted periodically by one of the
    ``(2 * radius + 1) ** 2`` offsets in the square neighbourhood.

    Parameters
    ----------
    x : numpy.ndarray
        2-D array of shape ``(dim_x, dim_y)``; it does not need to be square.
    radius : int, optional
        Neighbourhood half-width. The default of 2 gives the 5x5
        neighbourhood, i.e. 1 + 25 = 26 output channels.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(1 + (2 * radius + 1) ** 2, dim_x, dim_y)``.
    """
    # Track both extents separately so rectangular inputs are windowed
    # correctly along the second axis as well.
    dim_x, dim_y = x.shape
    padded = np.pad(x, pad_width=[[radius, radius], [radius, radius]], mode="wrap")
    width = 2 * radius + 1

    out = [x[None]]
    for ix in range(width):
        for iy in range(width):
            # Window of the wrap-padded array with the same shape as x,
            # shifted by (ix - radius, iy - radius).
            shifted = padded[ix:ix + dim_x, iy:iy + dim_y]
            out.append((shifted * x)[None])
    return np.concatenate(out, axis=0)

In [None]:
# Small 5x5 example grid holding the integers 1..25. It lives under its own
# name so it does not overwrite `x_data`, the simulation array that the
# training dataset is built from further down.
dimx = 5
x_demo = np.arange(1, dimx**2 + 1).reshape(dimx, dimx)
x_demo

In [None]:
ks = 5 # kernel size
in_channels = 26 # no. of input channels: the field plus 5*5 neighbour products from add_neighbors
cw = 32 # channel width

train_batch_size = 2
val_batch_size = 2

# Padding width recomputed from this cell's `ks` instead of relying on the
# value left behind by an earlier cell.
lx = (ks // 2) * 5

# compose applies right-to-left: wrap-pad the field by lx on every side, then
# expand it into the neighbour-product channels.
transformer_x = compose(add_neighbors,
                        lambda x: np.pad(x, pad_width=[[lx,lx],[lx,lx]], mode='wrap'))

transformer_y = lambda x: x[None]

train_dataset = CahnHillDataset(x_data, y_data, transform_x=transformer_x, transform_y=transformer_y)
trainloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=1)

val_dataset = CahnHillDataset(x_val, y_val, transform_x=transformer_x, transform_y=transformer_y)
# Validation order does not need to be randomised; keep it deterministic.
valloader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=1)

total_step = len(trainloader)
print("No. of training steps: %d" % total_step)
total_val_step = len(valloader)
print("No. of validation steps: %d" % total_val_step)

In [None]:
# Fresh network built from the ks / in_channels / cw values currently in scope,
# cast to float64 and moved to the selected device.
model = CHnet(ks=ks, in_channels=in_channels, cw=cw).double().to(device)

num_epochs = 10
criterion = mse_loss
learning_rate = 5e-5
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Loss histories; the training loop appends to these lists.
train_losses = []
val_losses = []

In [None]:
import os

# Training loop.
# The network is optimised on fields multiplied by `loss_scale`; the recorded
# training loss is divided by the same factor again so that train_losses and
# val_losses are both plain RMSE values and can be compared directly.
loss_scale = 100.0

# torch.save does not create missing directories.
os.makedirs("weights", exist_ok=True)

for epoch in range(num_epochs):
    # Checkpoint at the start of every 5th epoch (including epoch 0).
    if epoch % 5 == 0:
        torch.save(model.state_dict(), "weights/CH_trial_1_%d" % (epoch))

    for i, item in enumerate(tqdm.tqdm_notebook(trainloader)):
        # Validation below switches to eval mode, so re-enable train mode each step.
        model.train()

        x = item['x'].to(device)
        target = item['y'].to(device)

        # Forward pass
        output = model(x)
        loss = criterion(output * loss_scale, target * loss_scale)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # sqrt(MSE of the scaled fields) / scale == RMSE of the unscaled fields.
        train_losses.append(np.sqrt(loss.item()) / loss_scale)

        if i % 100 == 0:
            # Spot-check 5 randomly chosen validation samples. Gradients are
            # not needed here, so no autograd graph is built.
            model.eval()
            with torch.no_grad():
                for indx in np.random.permutation(np.arange(0, len(val_dataset)))[:5]:
                    item1 = val_dataset[indx]
                    x1 = item1['x'][None].to(device)
                    y1 = item1['y'][None].to(device)
                    # Forward pass
                    y2 = model(x1)
                    val_losses.append(np.sqrt(criterion(y2, y1).item()))

            print('Epoch [{}/{}], Step [{}/{}], Training Loss: {:.11f}, Validation Loss: {:.11f}'.format(
                epoch + 1,
                num_epochs,
                i + 1,
                total_step,
                np.mean(train_losses[-50:]),
                np.mean(val_losses[-5:])))

In [None]:
# Open a new figure so the curve is not drawn into a previously active
# interactive figure.
fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set_title("training losses")
ax.set_xlabel("training steps")
# The recorded values are square roots of the MSE criterion, i.e. RMSE.
ax.set_ylabel("root mean squared error")
plt.show()

In [None]:
# Same curve with the first 100 training steps dropped. A new figure is opened
# so the curve is not drawn into a previously active interactive figure.
fig, ax = plt.subplots()
ax.plot(train_losses[100:])
ax.set_title("training losses")
ax.set_xlabel("training steps")
# The recorded values are square roots of the MSE criterion, i.e. RMSE.
ax.set_ylabel("root mean squared error")
plt.show()

In [None]:
# Validation history with the first 100 recorded entries dropped. A new figure
# is opened so the curve is not drawn into a previously active interactive figure.
fig, ax = plt.subplots()
ax.plot(val_losses[100:])
ax.set_title("Validation losses")
ax.set_xlabel("validation steps")
# The recorded values are square roots of the MSE criterion, i.e. RMSE.
ax.set_ylabel("root mean squared error")
plt.show()

In [None]:
# Pick the validation sample explicitly so this cell does not depend on a
# loop variable left over from a previously executed cell.
indx = 0

model.eval()
item1 = val_dataset[indx]
x1 = item1['x'][None].to(device)
y1 = item1['y'][None].to(device)
# Forward pass (inference only, so no autograd graph is built)
with torch.no_grad():
    y2 = model(x1)

In [None]:
# Target field (y1) of the selected validation sample.
draw_im(y1.detach().cpu().numpy(), "Ground Truth")

In [None]:
# Network output (y2) for the same input.
draw_im(y2.detach().cpu().numpy(), "CNN output")

In [None]:
# Point-wise error: target minus network output.
draw_im((y1 - y2).detach().cpu().numpy(), "diff")