<a href="https://colab.research.google.com/github/Keshav-Sundar-4/RD_ML_Model/blob/main/Reaction_Diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Imports and Notebook Utilities
# Notebook-wide imports and setup. NOTE: some names are imported twice in this
# cell (`torch.nn as nn`, `clear_output`) and the bare name `Image` is bound
# twice — see the notes below.
import os
import io
import PIL.Image, PIL.ImageDraw  # use `PIL.Image` for PIL; the bare name `Image` is ambiguous in this notebook
import base64
import zipfile
import json
import requests
import numpy as np
import matplotlib.pylab as plt
import glob
from scipy import ndimage
# NOTE(review): `tnrange` is deprecated in recent tqdm releases (tqdm.notebook.trange
# replaces it) and is not used in the cells shown here.
from tqdm import tnrange
from IPython.display import clear_output
# NOTE: `Image` is rebound to IPython.display.Image under "Video and Image imports"
# below, so once this cell has run the bare name `Image` is IPython's display
# class, not PIL's. Use `PIL.Image` (imported above) for PIL functionality.
from PIL import Image

#Pytorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn  # duplicate of the import a few lines above; harmless
from torchvision.transforms import ToTensor, Resize
# Log the installed PyTorch version (shown in the cell output).
print(torch.__version__)

#Video and Image imports
# Rebinds `Image` (previously PIL's) to IPython's display class; `clear_output`
# is imported a second time here.
from IPython.display import Image, HTML, clear_output
import tqdm

# Point moviepy at the `ffmpeg` executable on PATH. It is set before the moviepy
# import, which presumably reads FFMPEG_BINARY at import time — confirm for the
# installed moviepy version.
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
# NOTE(review): `moviepy.editor` is the moviepy 1.x API (removed in moviepy 2.0);
# pin moviepy<2 or update this import.
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter


def imread(url, max_size=None, mode=None):
  """Load an image from a local path or an http(s) URL as a float32 array in [0, 1].

  Args:
    url: local file path, or a string starting with 'http:' / 'https:'.
    max_size: if given, the image is shrunk in place (aspect ratio kept) so
      that neither side exceeds this many pixels.
    mode: optional PIL mode to convert to (e.g. 'RGB', 'RGBA', 'L').

  Returns:
    np.float32 array of the pixel values divided by 255.

  Raises:
    requests.HTTPError: if the URL responds with an error status.
  """
  if url.startswith(('http:', 'https:')):
    # A timeout avoids hanging forever on a stalled connection; raise_for_status
    # turns an HTML error page into a clear exception instead of a PIL decode error.
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    f = io.BytesIO(r.content)
  else:
    f = url
  # Context manager closes the underlying file once the pixels have been read.
  with PIL.Image.open(f) as img:
    if max_size is not None:
      # `Image.ANTIALIAS` was removed in Pillow 10; LANCZOS is the same filter.
      img.thumbnail((max_size, max_size), PIL.Image.LANCZOS)
    if mode is not None:
      img = img.convert(mode)
    return np.float32(img)/255.0

def np2pil(a):
  """Convert a numpy image array into a PIL image.

  Float arrays (float32/float64) are treated as intensities in [0, 1]: they are
  clipped to that range and scaled to uint8. Any other dtype is handed to PIL
  unchanged.
  """
  is_float = a.dtype in (np.float32, np.float64)
  pixels = np.uint8(np.clip(a, 0, 1) * 255) if is_float else a
  return PIL.Image.fromarray(pixels)

def imwrite(f, a, fmt=None):
  """Save an image array to a file path or to a writable binary file object.

  Args:
    f: destination path (the format is inferred from its extension and
      overrides `fmt`; '.jpg' maps to 'jpeg'), or an open binary file-like object.
    a: image array; float arrays are clipped to [0, 1] and scaled to uint8 by np2pil.
    fmt: PIL format name, used when `f` is a file object.
  """
  img = np2pil(np.asarray(a))
  if isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1].lower()
    if fmt == 'jpg':
      fmt = 'jpeg'
    # Open in a context manager so the handle is flushed and closed
    # deterministically instead of waiting for garbage collection.
    # The image is converted first so a conversion error leaves no empty file.
    with open(f, 'wb') as fp:
      img.save(fp, fmt, quality=95)
  else:
    # quality=95 is the JPEG quality; presumably ignored by formats that don't use it.
    img.save(f, fmt, quality=95)

def imencode(a, fmt='jpeg'):
  """Encode an image array to in-memory bytes.

  Args:
    a: image array (anything np.asarray accepts).
    fmt: requested PIL format. It is overridden with 'png' for (H, W, 4)
      arrays so the alpha channel survives.

  Returns:
    The encoded image as `bytes`.
  """
  arr = np.asarray(a)
  has_alpha = arr.ndim == 3 and arr.shape[-1] == 4
  buffer = io.BytesIO()
  imwrite(buffer, arr, 'png' if has_alpha else fmt)
  return buffer.getvalue()

def im2url(a, fmt='jpeg'):
  """Encode an image array as a base64 `data:` URL.

  `imencode` switches to PNG for 4-channel (RGBA) arrays regardless of `fmt`.
  The same check is mirrored here so the MIME type in the URL matches the bytes
  actually produced.

  Args:
    a: image array.
    fmt: requested format ('jpeg' by default).

  Returns:
    A string of the form 'data:image/<FMT>;base64,<payload>'.
  """
  a = np.asarray(a)
  if a.ndim == 3 and a.shape[-1] == 4:
    fmt = 'png'
  encoded = imencode(a, fmt)
  base64_byte_string = base64.b64encode(encoded).decode('ascii')
  return 'data:image/' + fmt.upper() + ';base64,' + base64_byte_string

def imshow(a, fmt='jpeg'):
  """Display an image array inline in the notebook.

  Args:
    a: image array accepted by `imencode` (float arrays are clipped to [0, 1]).
    fmt: preferred encoding; `imencode` switches to PNG for RGBA arrays.
  """
  # Import IPython's Image explicitly. In this notebook the bare name `Image`
  # is bound to PIL's Image or IPython's, depending on which import ran last.
  from IPython.display import Image as IPyImage, display
  display(IPyImage(data=imencode(a, fmt)))

def tile2d(a, w=None):
  """Arrange a stack of equally sized images into one 2-D mosaic.

  Args:
    a: array-like of shape (n, th, tw, ...) — n tiles of size th x tw.
    w: number of tiles per row; defaults to ceil(sqrt(n)).

  Returns:
    Array of shape (th*rows, tw*w, ...). Missing slots in the last row are
    zero-filled.
  """
  tiles = np.asarray(a)
  count = len(tiles)
  cols = int(np.ceil(np.sqrt(count))) if w is None else w
  tile_h, tile_w = tiles.shape[1:3]
  # Append blank tiles so the count becomes a multiple of `cols`.
  missing = (cols - count) % cols
  pad_spec = [(0, missing)] + [(0, 0)] * (tiles.ndim - 1)
  tiles = np.pad(tiles, pad_spec, 'constant')
  rows = len(tiles) // cols
  grid = tiles.reshape((rows, cols) + tiles.shape[1:])
  # (rows, cols, th, tw, ...) -> (rows, th, cols, tw, ...), then flatten to pixels.
  grid = grid.swapaxes(1, 2)
  return grid.reshape((tile_h * rows, tile_w * cols) + grid.shape[4:])

def zoom(img, scale=4):
  """Nearest-neighbour upscale: repeat every pixel `scale` times along axes 0 and 1."""
  taller = np.repeat(img, scale, axis=0)
  return np.repeat(taller, scale, axis=1)

class VideoWriter:
  """Write frames to a video file incrementally through moviepy's FFMPEG_VideoWriter.

  The ffmpeg writer is created lazily on the first `add`, sized from that frame.
  The class can be used as a context manager. On exit it closes the file and,
  when the output is the default '_tmp.mp4', shows the clip inline.

  Args:
    filename: output path (default '_tmp.mp4').
    fps: frame rate.
    **kw: extra keyword arguments forwarded to FFMPEG_VideoWriter.
  """

  def __init__(self, filename='_tmp.mp4', fps=30.0, **kw):
    self.params = {'filename': filename, 'fps': fps, **kw}
    self.writer = None

  def add(self, img):
    """Append one frame: (H, W) grayscale or (H, W, C). Floats are clipped to [0, 1] and scaled to uint8."""
    frame = np.asarray(img)
    if self.writer is None:
      height, width = frame.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(width, height), **self.params)
    if frame.dtype in (np.float32, np.float64):
      frame = (frame.clip(0, 1) * 255).astype(np.uint8)
    if frame.ndim == 2:
      # Grayscale -> three identical colour channels.
      frame = np.stack((frame, frame, frame), axis=-1)
    self.writer.write_frame(frame)

  def close(self):
    """Finalize the output file if a writer was ever opened."""
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()
    if self.params['filename'] == '_tmp.mp4':
      self.show()

  def show(self, **kw):
    """Close the writer and display the video inline via moviepy's ipython_display."""
    self.close()
    display(mvp.ipython_display(self.params['filename'], **kw))


class LoopWriter(VideoWriter):
  """VideoWriter that cross-fades the end of the clip into its start so it loops.

  The first `fade_len` seconds of frames are held back, and so are the most
  recent `fade_len` seconds. On close, the held-back tail is blended into the
  held-back head and appended after the frames already written.

  Args:
    *a, **kw: forwarded to VideoWriter (filename, fps, extra ffmpeg options).
    fade_len: cross-fade duration in seconds (keyword-only). It is consumed here
      and NOT forwarded, because FFMPEG_VideoWriter does not accept it.
  """

  def __init__(self, *a, fade_len=1.0, **kw):
    super().__init__(*a, **kw)
    self._intro = []
    self._outro = []
    self.fade_len = int(fade_len*self.params['fps'])

  def add(self, img):
    if len(self._intro) < self.fade_len:
      self._intro.append(img)
      return
    self._outro.append(img)
    if len(self._outro) > self.fade_len:
      super().add(self._outro.pop(0))

  def close(self):
    if len(self._outro) < len(self._intro):
      # Too few frames for a full cross-fade. In this case nothing has reached
      # the writer yet (frames are only written once _outro overflows), so emit
      # the buffered frames unblended, in their original order.
      frames = self._intro + self._outro
    else:
      weights = np.linspace(0, 1, len(self._intro))
      frames = [self._blend(head, tail, t)
                for head, tail, t in zip(self._intro, self._outro, weights)]
    self._intro, self._outro = [], []
    for frame in frames:
      super().add(frame)
    super().close()

  @staticmethod
  def _blend(head, tail, t):
    """Return head*t + tail*(1-t), keeping integer frames in their original dtype."""
    head = np.asarray(head)
    mixed = head*t + np.asarray(tail)*(1.0-t)
    if np.issubdtype(head.dtype, np.integer):
      # uint8 * float gives float64 in [0, 255]. VideoWriter.add would treat that
      # as a [0, 1] float image and saturate it, so round back to the input dtype.
      mixed = np.rint(mixed).astype(head.dtype)
    return mixed





2.4.0+cu121


In [1]:
!wget https://raw.githubusercontent.com/Keshav-Sundar-4/RD_ML_Model/main/Turing_Pattern.png


--2024-09-07 21:24:03--  https://raw.githubusercontent.com/Keshav-Sundar-4/RD_ML_Model/main/Turing_Pattern.png
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1440934 (1.4M) [image/png]
Saving to: ‘Turing_Pattern.png’


2024-09-07 21:24:03 (51.7 MB/s) - ‘Turing_Pattern.png’ saved [1440934/1440934]



In [None]:
#@title RD Model definition {vertical-output: false}
# Import from the public `scipy.ndimage` namespace; `scipy.ndimage.filters` is a
# deprecated alias that newer SciPy releases warn about.
from scipy.ndimage import gaussian_filter
import torch.nn.functional as F
import torch.nn as nn


# Number of grid "chemical" channels (per-cell state size; to_rgb shows channels 0-2)
CHN = 32

# Build starting ("seed") states made of sparse, Gaussian-blurred spots.
def seed_f(n, sz=96, spot_prob=0.005, spread=3.0, chn=None, generator=None):
    '''Create seed states with scattered Gaussian blobs.

    Each grid cell becomes a spot with probability `spot_prob`. The spot map is
    blurred with a periodic (wrap-around) Gaussian of std `spread` and scaled
    by `spread**2`. The result fills channels 0-2; the remaining channels are zero.

    Args:
        n: number of seeds (batch size).
        sz: grid height and width.
        spot_prob: per-cell probability of starting a spot.
        spread: Gaussian std, in cells, used for blurring.
        chn: total channel count; defaults to the module-level `CHN`.
        generator: optional `torch.Generator` for reproducible seeds.

    Returns:
        float32 tensor of shape (n, sz, sz, max(chn, 3)), channels-last.
        The CA model's Conv2d layers expect channels-first, so permute
        (0, 3, 1, 2) before feeding it to the model.
    '''
    if chn is None:
        chn = CHN
    spots = (torch.rand(n, sz, sz, 1, generator=generator) < spot_prob).float()

    # SciPy handles periodic boundaries directly via mode='wrap'. Sigma 0 on the
    # batch and channel axes keeps those axes independent.
    blurred = gaussian_filter(spots.numpy(), [0.0, spread, spread, 0.0], mode='wrap')
    x = torch.from_numpy(blurred) * spread**2

    # Same pattern in the first three channels, zero-padded up to `chn` channels.
    x = x.repeat(1, 1, 1, 3)
    if x.shape[-1] < chn:
        x = F.pad(x, (0, chn - x.shape[-1]))  # pads only the last (channel) dim

    return x


# Periodic (wrap-around) padding for channels-last tensors.
def pad_repeat(x, pad=1):
    """Pad dims 1 and 2 (height, width) of `x` periodically by `pad` cells per side.

    The last `pad` rows are prepended and the first `pad` rows appended, then the
    same is done for columns. This emulates periodic boundary conditions.

    Args:
        x: tensor laid out as (batch, height, width, channels).
        pad: cells to add on each side; 0 returns `x` unchanged.

    Returns:
        Tensor of shape (N, H + 2*pad, W + 2*pad, C).
    """
    if pad == 0:
        # x[:, -0:] selects the *whole* axis, which would triple the tensor.
        return x
    x = torch.cat([x[:, -pad:], x, x[:, :pad]], dim=1)        # wrap height
    x = torch.cat([x[:, :, -pad:], x, x[:, :, :pad]], dim=2)  # wrap width
    return x

# Render the first three state channels as an RGB image.
def to_rgb(x):
    """Return channels 0-2 of a channels-last array/tensor, shifted by +0.5.

    NOTE(review): the +0.5 offset presumably maps a zero-centred state to
    mid-grey, so negative values stay visible once the image is clipped to
    [0, 1] — confirm intent.
    """
    rgb = x[..., 0:3]
    return rgb + 0.5

# Discrete Laplacian with periodic boundaries, applied to each channel independently.
def laplacian(x):
    """Apply a symmetric 3x3 Laplacian stencil to every channel of `x`.

    Args:
        x: channels-first tensor of shape (N, C, H, W), the layout used by the
            CA model's Conv2d layers.

    Returns:
        Tensor with the same shape, dtype and device as `x`. Borders wrap
        around (periodic boundary), so spatial size is preserved.
    """
    n_ch = x.shape[1]
    # Weights sum to zero, so a constant field has zero Laplacian. The kernel is
    # built in x's dtype/device so conv2d doesn't fail on float64 or GPU inputs.
    lap = torch.tensor([[1.0, 2.0, 1.0], [2.0, -12.0, 2.0], [1.0, 2.0, 1.0]],
                       dtype=x.dtype, device=x.device) / 16.0
    # One kernel copy per channel -> depthwise conv weight of shape (C, 1, 3, 3).
    lap = lap.view(1, 1, 3, 3).repeat(n_ch, 1, 1, 1)
    # Wrap-pad H and W (the last two dims). The channels-last helper pad_repeat
    # would pad C and H of an NCHW tensor, making the weight and groups disagree.
    x = F.pad(x, (1, 1, 1, 1), mode='circular')
    return F.conv2d(x, lap, groups=n_ch)

# Cellular automaton: fixed per-channel diffusion plus a learned reaction term.
class CA(nn.Module):
    """One update is x + d * diff_coef * laplacian(x) + r * w2(sigmoid(5 * w1(x))).

    Args:
        CHN: number of state channels.

    `forward` takes and returns channels-first tensors (N, CHN, H, W), as
    required by the 1x1 `nn.Conv2d` layers.
    """

    def __init__(self, CHN):
        super().__init__()
        self.w1 = nn.Conv2d(CHN, 128, 1)
        self.w2 = nn.Conv2d(128, CHN, 1, bias=False)
        # Zero-init the output layer so the reaction term starts at exactly zero.
        nn.init.zeros_(self.w2.weight)
        rates = torch.tensor([0.125, 0.25, 0.5, 1.0])
        # Spread the four rates over consecutive, near-equal channel groups. For CHN
        # divisible by 4 this equals rates.repeat_interleave(CHN // 4); otherwise
        # it still yields exactly CHN coefficients.
        group = torch.arange(CHN) * len(rates) // CHN
        # Buffer (not a plain attribute) so it follows .to()/.cuda()/.double().
        # Non-persistent, so state_dict keys and the parameter count are unchanged.
        self.register_buffer('diff_coef', rates[group], persistent=False)

    def get_diff_coef(self):
        """Per-channel diffusion coefficients, shape (CHN,)."""
        return self.diff_coef

    def forward(self, x, r=1.0, d=1.0):
        """Apply one CA step; `d` scales diffusion and `r` scales the reaction term."""
        diff = laplacian(x) * self.get_diff_coef().view(1, -1, 1, 1)
        y = torch.sigmoid(self.w1(x) * 5.0)
        react = self.w2(y)
        return x + diff * d + react * r

# Instantiate the model with the module-level channel count (no duplicated literal).
ca = CA(CHN=CHN)

# Count trainable parameters
param_n = sum(p.numel() for p in ca.parameters() if p.requires_grad)
print('Parameter count:', param_n)

# Print the seed state examples
print('Seed state examples:')

# Generate 4 seeds on a 128x128 grid and take channels 0-2 as RGB.
# Convert to numpy *before* padding: np.pad returns an ndarray, so calling
# `.numpy()` on its result would raise AttributeError.
img = to_rgb(seed_f(4, 128)).numpy()

# Pad axis 2 by 2 px per side with 1.0 (white) to separate adjacent seeds visually.
img = np.pad(img, [(0, 0), (0, 0), (2, 2), (0, 0)], constant_values=1.0)
# Lay the seeds side by side: (4, H, W+4, 3) -> (H, 4*(W+4), 3).
combined_img = np.hstack(img)

# Use imshow to display the images
imshow(combined_img)


In [None]:
#@title Image Dataset {vertical-output: false}

# Alias IPython's Image so this cell doesn't rebind the bare name `Image`. The
# notebook also imports PIL's Image under that name, and load_image needs PIL.
from IPython.display import Image as IPyImage, display

# Display the downloaded image
display(IPyImage('Turing_Pattern.png'))

def load_image(image_path, size=(96, 96)):
    """Load an image file as a (1, 3, H, W) float tensor with values in [0, 1].

    Args:
        image_path: path (or binary file object) readable by `PIL.Image.open`.
        size: (width, height) to resize to, in PIL's argument order.

    Returns:
        Tensor of shape (1, 3, size[1], size[0]).
    """
    # Use the fully qualified PIL module. In this notebook the bare name `Image`
    # is rebound to IPython.display.Image, which has no `open` method.
    with PIL.Image.open(image_path) as img:
        img = img.convert('RGB')  # returns a new image, so closing the file is safe
    # `Image.ANTIALIAS` was removed in Pillow 10; LANCZOS is the same filter.
    img = img.resize(size, PIL.Image.LANCZOS)
    tensor = ToTensor()(img)  # (C, H, W), scaled to [0, 1]
    return tensor.unsqueeze(0)  # Add a batch dimension

# Load the training target from the downloaded file. Pass the *path*: an
# IPython Image object is not something PIL can open.
target = load_image('Turing_Pattern.png')


In [None]:
#@title Training Module {vertical-output: false}

# Initialize the model and optimizer
model = CA(CHN=CHN)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Starting state. seed_f returns channels-last (N, H, W, C), while the CA's Conv2d
# layers expect channels-first (N, C, H, W). Its grid is square, so it is sized to
# the target width (assumes a square target, as load_image's default size gives).
seed = seed_f(1, sz=target.shape[-1]).permute(0, 3, 1, 2).contiguous()

# CA updates applied to the seed per epoch (1 keeps the single-step setup).
ca_steps = 1

# Prepare for loss tracking and video writing
losses = []
video_writer = VideoWriter(filename='training_process.mp4', fps=2)

try:
    # Training loop
    epochs = 500

    for epoch in range(epochs):
        optimizer.zero_grad()
        output = seed
        for _ in range(ca_steps):
            output = model(output)
        # Only channels 0-2 are compared with the 3-channel RGB target; the
        # remaining CHN-3 channels are free hidden state.
        rgb = output[:, :3]
        loss = criterion(rgb, target)  # Compute loss against the target
        loss.backward()  # Backpropagation
        optimizer.step()  # Optimizer step

        # Record the loss
        losses.append(loss.item())

        # Periodically update the loss plot and output image
        if epoch % 10 == 0:
            clear_output(wait=True)
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
            # Show the loss plot
            ax1.set_title("Training Loss")
            ax1.plot(losses, label='Loss')
            ax1.set_xlabel('Epochs')
            ax1.set_ylabel('Loss')
            ax1.legend()

            # (C, H, W) -> (H, W, C) with C == 3, as matplotlib and ffmpeg expect
            output_image = rgb.detach().cpu().numpy()[0].transpose(1, 2, 0)
            ax2.set_title("Current Output")
            ax2.imshow(np.clip(output_image, 0, 1))
            ax2.axis('off')

            plt.show()
            plt.close(fig)  # avoid accumulating figures across epochs

            # Add current frame to the video
            video_writer.add(output_image)
finally:
    video_writer.close()  # Ensure the video is saved properly
