<a href="https://colab.research.google.com/github/NikNord174/Ceramic_grains_segmentation/blob/master/Ceramic_grain_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ceramic Microstructure Processing

### Python packages

In [None]:
import random
import numpy as np
import numba
from skimage import morphology
from skimage.color import label2rgb
import matplotlib.pyplot as plt
import h5py

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, Grayscale, RandomRotation, Resize, ToPILImage, RandomGrayscale
import os
from PIL import Image
import torch.optim as optim
from itertools import product
from torchvision.utils import make_grid

### Pure Iron Grain Dataset

In [None]:
# Load the pure-iron grain data set: SEM images, grain-label maps and boundary maps.
# Each stack is stored with the sample axis last; np.moveaxis brings it to the
# front so that images[i] is the i-th 2-D picture.
# The file is opened in a context manager so it is closed once the arrays have
# been read into memory (the previous version left the handle open).
with h5py.File("pure_iron_grain_data_sets.hdf5", "r") as f:
    images = np.moveaxis(f['image'][()], 2, 0)
    labels = np.moveaxis(f['label'][()], 2, 0)
    boundaries = np.moveaxis(f['boundary'][()], 2, 0)
print(images.shape, labels.shape, boundaries.shape)

(296, 1024, 1024) (296, 1024, 1024) (296, 1024, 1024)


In [None]:
# Show the first sample as three side-by-side panels: raw SEM picture,
# boundary map and grain-label map.
_panels = [
    (images[0], "SEM"),
    (boundaries[0], "detected boundaries"),
    (labels[0], "different grains (labels)"),
]
plt.figure(figsize=(20, 20))
for _position, (_panel, _title) in enumerate(_panels, start=1):
    plt.subplot(1, 3, _position)
    plt.imshow(_panel, cmap='gray')
    plt.axis('off')
    plt.title(_title, fontsize=20)
plt.show()

In [None]:
# Preview the first full-size SEM image.
# Uncomment the next line to export every image of the stack as a JPEG into Im_Iron/.
# [plt.imsave(f'Im_Iron/Im_{i}.jpeg', images[i],cmap='gray') for i in range(len(images))]
first_image = images[0]
plt.imshow(first_image, cmap='gray')
# Presumably here so the cell ends on a statement and Jupyter does not echo
# the AxesImage returned by imshow.
print()

### Cropping source images into tiles

In [None]:
def crop(img, d):
    """Split a 2-D image into non-overlapping ``d`` x ``d`` square tiles.

    Tiles are emitted column-block by column-block: for each horizontal offset
    (left to right) every vertical offset is visited top to bottom. Trailing
    rows/columns that cannot fill a whole tile are discarded.

    :param img: 2-D array of shape (rows, cols).
    :param d: tile side length in pixels.
    :return: float64 array of shape (n_tiles, d, d), with
        n_tiles = (rows // d) * (cols // d).
    """
    rows, cols = img.shape
    n_row_tiles, n_col_tiles = rows // d, cols // d
    # The tile shape follows ``d``; it used to be hard-coded to 64 x 64.
    tiles = np.zeros((n_row_tiles * n_col_tiles, d, d))
    k = 0
    # Stop at the last *complete* tile so the preallocated array is never
    # overrun and no partial (smaller) tile is assigned into a d x d slot.
    for col in range(0, n_col_tiles * d, d):
        for row in range(0, n_row_tiles * d, d):
            tiles[k] = img[row:row + d, col:col + d]
            k += 1
    return tiles
        
# Tile every full-size image into squares of side image.shape[1] / 16 and
# stack all tiles into one (n_tiles, side, side) array. The per-image tile
# arrays are collected first and concatenated once; concatenating inside the
# loop re-copied the ever-growing array on every iteration (quadratic cost).
croped_images = np.concatenate(
    [crop(image, int(image.shape[1] / 16)) for image in images]
)

print(croped_images.shape)

In [None]:
# Sanity check: display a single cropped tile.
tile_preview = croped_images[75774]
plt.imshow(tile_preview, cmap='gray')

In [None]:
class SEM_Dataset(Dataset):
    """Map-style dataset over 2-D SEM image tiles.

    Each item is a float32 tensor of shape (1, H, W), i.e. the tile with a
    leading channel axis, ready for the default DataLoader collation.

    NOTE(review): tiles are not rescaled here, while the GAN generator ends in
    Tanh (range [-1, 1]); confirm whether the tiles should be normalised to
    match.
    """

    def __init__(self, images):#, labels, boundaries):
        # Build the dataset from the tiles passed by the caller. The previous
        # version silently ignored this argument and read the global
        # ``croped_images`` instead.
        self.images = [torch.tensor(tile, dtype=torch.float32) for tile in images]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # (H, W) -> (1, H, W): add the channel axis expected by Conv2d.
        return self.images[idx].unsqueeze(0)


# Wrap the cropped tiles (not the full-size ``images``) so the dataset keeps
# serving the same small tiles it did before the argument was honoured.
SEM_images = SEM_Dataset(croped_images)

### GAN

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a single-channel tile.

    The (z_dim,) noise is viewed as a (z_dim, 1, 1) map and expanded by five
    transposed convolutions to (im_chan, 64, 64):
    1 -> 4 -> 8 -> 16 -> 32 -> 64. All stages except the last use
    BatchNorm + ReLU; the last uses Tanh, so outputs lie in [-1, 1].

    The previous layer stack (k3/s2, k4/s1, k3/s2, k3/s2, k4/s2, no padding)
    produced 56 x 56 images, which did not match the 64 x 64 training tiles.

    :param z_dim: length of the input noise vector.
    :param im_chan: number of output image channels.
    :param hidden_dim: base channel width (stages use 8x, 4x, 2x, 1x).
    """

    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Output side of ConvTranspose2d: (in - 1) * stride - 2 * padding + kernel_size.
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 8, kernel_size=4, stride=1),                  # 1 -> 4
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4, kernel_size=4, padding=1),        # 4 -> 8
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, padding=1),        # 8 -> 16
            self.make_gen_block(hidden_dim * 2, hidden_dim, kernel_size=4, padding=1),            # 16 -> 32
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, padding=1, final_layer=True), # 32 -> 64
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False,
                       padding=0):
        """Return one upsampling stage.

        Non-final stages are ConvTranspose2d -> BatchNorm2d -> ReLU. The final
        stage is ConvTranspose2d -> Tanh. ``padding`` is appended last with
        default 0, so existing positional calls behave as before.
        """
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels=input_channels,
                                   out_channels=output_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding),
                nn.BatchNorm2d(num_features=output_channels),
                nn.ReLU()
            )
        else: # Final Layer
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels=input_channels,
                                   out_channels=output_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   padding=padding),
                nn.Tanh()
            )

    def unsqueeze_noise(self, noise):
        """Reshape (n, z_dim) noise to (n, z_dim, 1, 1) for the transposed convolutions."""
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise):
        """Map (n, z_dim) noise to (n, im_chan, 64, 64) images."""
        x = self.unsqueeze_noise(noise)
        return self.gen(x)

def get_noise(n_samples, z_dim, device='cpu'):
    """Draw ``n_samples`` standard-normal latent vectors of length ``z_dim`` on ``device``."""
    shape = (n_samples, z_dim)
    return torch.randn(shape, device=device)

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that scores image patches with raw logits.

    There are two stride-2 4x4 convolutions. The first is followed by
    BatchNorm + LeakyReLU(0.2); the second is a bare convolution (no
    activation), so its outputs are logits.

    The final convolution emits ``hidden_dim * 2`` channels rather than a
    single logit map. ``forward`` flattens everything per sample, so each
    image yields ``hidden_dim * 2 * H' * W'`` logits.

    NOTE(review): the original kept a third, 1-channel block commented out;
    confirm whether a single-channel output was intended.
    """

    def __init__(self, im_chan=1, hidden_dim=16):
        super(Discriminator, self).__init__()
        first = self.make_disc_block(im_chan, hidden_dim)
        last = self.make_disc_block(hidden_dim, hidden_dim * 2, final_layer=True)
        self.disc = nn.Sequential(first, last)

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        """Conv2d, plus BatchNorm2d + LeakyReLU(0.2) unless ``final_layer`` is set."""
        layers = [nn.Conv2d(input_channels, output_channels, kernel_size, stride)]
        if final_layer:
            return nn.Sequential(*layers)
        layers += [nn.BatchNorm2d(output_channels), nn.LeakyReLU(0.2)]
        return nn.Sequential(*layers)

    def forward(self, image):
        """Return the flattened patch logits, shape (batch, n_logits)."""
        logits = self.disc(image)
        batch = len(logits)
        return logits.view(batch, -1)

In [None]:
# --- GAN hyper-parameters ---------------------------------------------------
# The discriminator's last layer has no sigmoid, so the logits go straight
# into a loss that fuses the sigmoid with binary cross-entropy.
criterion = nn.BCEWithLogitsLoss()
z_dim = 64                   # latent noise length fed to the generator
display_step = 500           # report / plot every this many batches
lr = 1e-4                    # Adam learning rate for both networks
beta_1, beta_2 = 0.5, 0.999  # Adam moment coefficients
# device = 'cuda'

batch_size = 8
train_loader = DataLoader(SEM_images, batch_size=batch_size, shuffle=True)

In [None]:
# Build both networks (generator first) and give each its own Adam optimiser
# with the shared settings above.
gen = Generator(z_dim)
gen_opt = torch.optim.Adam(params=gen.parameters(), lr=lr, betas=(beta_1, beta_2))

disc = Discriminator()
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    """Initialise a layer's parameters following the DCGAN recipe.

    Meant to be used as ``model.apply(weights_init)``:

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02).
    * ``BatchNorm2d`` scale (weight) is drawn from N(1, 0.02); its shift
      (bias) is set to 0.

    Any other module is left untouched.

    The BatchNorm scale must be centred on 1. The previous N(0, 0.02) draw
    scaled every normalised activation down by roughly 50x and gave about
    half of the channels a negative scale.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
# Module.apply visits every sub-module in place, so no rebinding is needed.
for _model in (gen, disc):
    _model.apply(weights_init)

In [None]:
n_epochs = 10
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
# Histories of the windowed mean losses, one entry per ``display_step`` batches.
# They need names distinct from the per-batch loss tensors ``gen_loss`` and
# ``disc_loss`` below. The original reused those names for both, so the
# first ``.append`` call raised AttributeError on a Tensor.
gen_loss_history = []
disc_loss_history = []
# NOTE(review): ``criterion`` must still be the BCEWithLogitsLoss from the
# hyper-parameter cell. The "Train block" cell rebinds ``criterion``, so
# re-running this cell after that one would break it.
for epoch in range(n_epochs):
    # Each DataLoader batch is a tensor of real tiles.
    for real in train_loader:
        cur_batch_size = len(real)
        # real = real.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()
        fake_noise = get_noise(cur_batch_size, z_dim)  # , device=device)
        fake = gen(fake_noise)
        # detach() keeps the discriminator update from back-propagating into the generator.
        disc_fake_pred = disc(fake.detach())
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_pred = disc(real)
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2

        # Keep track of the average discriminator loss over the display window
        mean_discriminator_loss += disc_loss.item() / display_step
        # retain_graph is not needed: this graph is never reused. The generator
        # step below builds its own graph from fresh noise.
        disc_loss.backward()
        disc_opt.step()

        ## Update generator ##
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim)  # , device=device)
        fake_2 = gen(fake_noise_2)
        disc_fake_pred = disc(fake_2)
        # The generator is rewarded when its samples are classified as real (target 1).
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the average generator loss over the display window
        mean_generator_loss += gen_loss.item() / display_step

        ## Visualization code ##
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Epoch {epoch}, step {cur_step}: "
                  f"Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fig, axs = plt.subplots(ncols=2)
            # Left: a generated sample; right: a real tile from the current batch.
            axs[0].imshow(fake[0][0].detach().numpy(), cmap='gray')
            axs[0].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
            axs[1].imshow(real[0][0].detach().numpy(), cmap='gray')
            axs[1].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
            plt.show()
            gen_loss_history.append(mean_generator_loss)
            disc_loss_history.append(mean_discriminator_loss)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1

### W-Net block

In [None]:
class W_Net(nn.Module):
    """Small U-Net, applied twice by ``w_net`` (encoding pass, then decoding pass).

    ``u_net`` has three resolution levels (full, 1/2, 1/4) joined by skip
    connections. Channel widths are 64 / 128 / 256, each multiplied by
    ``input_channels``, and the output has ``input_channels`` channels. Input
    height and width must be divisible by 4 so the upsampled maps line up
    with the skip connections.

    NOTE(review): ``w_net`` runs the *same* U-Net (shared weights) for both
    passes, and the "mask" has ``input_channels`` channels rather than K
    class probabilities. The W-Net paper uses two separate U-Nets with a
    K-way softmax encoder output. Confirm intent.
    """

    def __init__(self, input_channels = 1):
        super().__init__()

        # Channel counts in the comments below are for input_channels=1.
        # 1st floor: channels: 1 -> 64 -> 64
        self.floor_1 = nn.Sequential(self.simple_conv2d_block(input_channels, input_channels*64),
                                     self.simple_conv2d_block(input_channels*64, input_channels*64))
        # 1st downscale: im_size: 64 -> 32
        self.downscale_1 = nn.Sequential(self.simple_maxpool_block(kernel_size = 2, stride = 2))
        # 2nd floor: channels: 64 -> 128 -> 128
        self.floor_2 = nn.Sequential(self.simple_conv2d_block(input_channels*64, input_channels*128),
                                     self.simple_conv2d_block(input_channels*128, input_channels*128))
        # 2nd downscale: im_size: 32 -> 16
        self.downscale_2 = nn.Sequential(self.simple_maxpool_block(kernel_size = 2, stride = 2))
        # 3rd floor: channels: 128 -> 256 -> 256
        self.floor_3 = nn.Sequential(self.simple_conv2d_block(input_channels*128, input_channels*256),
                                     self.simple_conv2d_block(input_channels*256, input_channels*256))
        # 2nd upscale: im_size: 16 -> 32; channels: 256 -> 128
        self.upscale_2 = nn.Sequential(self.simple_convtranspose2d_block(input_channels*256, input_channels*128))
        # 2nd floor (up): channels: 256 (128 skip + 128 upscaled) -> 128 -> 128
        self.floor_2_up = nn.Sequential(self.simple_conv2d_block(input_channels*256, input_channels*128),
                                        self.simple_conv2d_block(input_channels*128, input_channels*128))
        # 1st upscale: im_size: 32 -> 64; channels: 128 -> 64
        self.upscale_1 = nn.Sequential(self.simple_convtranspose2d_block(input_channels*128, input_channels*64))
        # 1st floor (up): channels: 128 (64 skip + 64 upscaled) -> 64 -> 1
        self.floor_1_up = nn.Sequential(self.simple_conv2d_block(input_channels*128, input_channels*64),
                                        self.simple_conv2d_block(input_channels*64, input_channels))

    def simple_conv2d_block(self, input_channels, output_channels, kernel_size=3, padding = 1):
        """Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

        With the defaults (3x3, padding 1) the spatial size is preserved.
        ``padding`` is now honoured; it used to be ignored in favour of a
        hard-coded 1.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=input_channels,
                      out_channels=output_channels,
                      kernel_size=kernel_size,
                      padding=padding),
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2))

    def simple_maxpool_block(self, kernel_size = 2, stride = 2):
        """Max pooling; with the defaults it halves height and width."""
        return nn.MaxPool2d(kernel_size, stride)

    def simple_convtranspose2d_block(self, input_channels, output_channels, kernel_size=2,
                                     stride = 2, padding = 1):
        """ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2); the defaults double height and width.

        ``padding`` is accepted for backward compatibility but not passed on
        (the transposed convolution always uses padding 0).
        BatchNorm is sized by ``output_channels``. It was previously sized
        ``input_channels / 2``, which only worked when the block halved the
        channel count.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels = input_channels,
                               out_channels = output_channels,
                               kernel_size = kernel_size,
                               stride = stride),
            nn.BatchNorm2d(output_channels),
            nn.LeakyReLU(0.2))

    def concat(self, tensor1, tensor2):
        """Concatenate two feature maps along the channel axis."""
        return torch.cat((tensor1, tensor2),1)

    def u_net(self, image):
        """One U-Net pass: (N, C, H, W) -> (N, C, H, W), C = ``input_channels``."""
        # downsampling
        tensor_1 = self.floor_1(image) # 1st floor
        tensor_2 = self.downscale_1(tensor_1)
        tensor_2 = self.floor_2(tensor_2) # 2nd floor
        tensor_3 = self.downscale_2(tensor_2)
        # cellar
        tensor_3 = self.floor_3(tensor_3) # 3rd floor
        # upscale
        tensor_2_up = self.upscale_2(tensor_3)
        # concatenation with the 2nd-floor skip connection
        concat_2 = self.concat(tensor_2, tensor_2_up)
        tensor_2_up = self.floor_2_up(concat_2)
        # upscale
        tensor_1_up = self.upscale_1(tensor_2_up)
        # concatenation with the 1st-floor skip connection
        concat_1 = self.concat(tensor_1, tensor_1_up)
        mask = self.floor_1_up(concat_1)
        return mask

    def w_net(self, image):
        """Run ``u_net`` twice and return ``(mask, reconstruction)``."""
        mask = self.u_net(image)
        artificial_image = self.u_net(mask)
        return mask, artificial_image

    def forward(self, image):
        """Make the module callable (``wnet(x)``); same as ``w_net``."""
        return self.w_net(image)

wnet=W_Net()

In [None]:
# Smoke test: push one cropped tile through the (untrained) W-Net and show the
# input, the intermediate "mask" and the reconstruction.
# The tensor is no longer named ``input``, so the builtin is not shadowed.
sample = torch.tensor(croped_images[75734], dtype=torch.float)[None, None, :]  # (1, 1, H, W)
print(sample.size())
with torch.no_grad():  # inference only: no autograd graph is needed
    mask, output = wnet.w_net(sample)
print(mask.size())
print(output.size())
plt.imshow(sample[0][0].numpy())
plt.show()
plt.imshow(mask[0][0].numpy())
plt.show()
plt.imshow(output[0][0].numpy())

In [None]:
class NCutLoss2D(nn.Module):
    r"""Continuous (soft) N-Cut loss, after
    'W-Net: A Deep Model for Fully Unsupervised Image Segmentation', Xia & Kulis (2017).

    For each class k with soft assignment map p_k the loss accumulates
    ``sum(p_k * G(p_k * w_k)) / sum(p_k * G(w_k))``, where G is a Gaussian
    spatial filter and w_k an intensity affinity. It returns K minus that sum,
    so perfectly coherent segments give 0.
    """

    def __init__(self, radius: int = 4, sigma_1: float = 5, sigma_2: float = 1):
        r"""
        :param radius: Radius of the spatial interaction term
        :param sigma_1: Standard deviation of the spatial Gaussian interaction
        :param sigma_2: Standard deviation of the pixel value Gaussian interaction
        """
        super(NCutLoss2D, self).__init__()
        self.radius = radius
        self.sigma_1 = sigma_1  # Spatial standard deviation
        self.sigma_2 = sigma_2  # Pixel value standard deviation

    @staticmethod
    def _gaussian_kernel(radius: int, sigma: float, device=None, dtype=None) -> torch.Tensor:
        """Return a (1, 1, 2r+1, 2r+1) kernel exp(-(dx**2 + dy**2) / (2 * sigma**2)).

        The centre value is 1. This replaces the external ``gaussian_kernel``
        helper, which is not defined anywhere in this notebook.
        """
        offsets = torch.arange(-radius, radius + 1, device=device,
                               dtype=dtype if dtype is not None else torch.float32)
        sq_dist = offsets.pow(2).view(-1, 1) + offsets.pow(2).view(1, -1)
        kernel = torch.exp(-sq_dist / (2 * sigma ** 2))
        return kernel.view(1, 1, 2 * radius + 1, 2 * radius + 1)

    def forward(self, labels: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        r"""Compute the continuous N-Cut loss from class probabilities (labels) and raw images (inputs).

        For efficiency, pixel weights are computed relative to the class-wide
        mean rather than for every pixel pair.

        :param labels: Predicted class probabilities, (B, K, H, W) (dim 1 indexes classes)
        :param inputs: Raw images, (B, C, H, W) with the same B, H, W
        :return: 0-dim tensor, K minus the summed per-class association ratios
        """
        num_classes = labels.shape[1]
        kernel = self._gaussian_kernel(self.radius, self.sigma_1, device=labels.device, dtype=labels.dtype)
        conv2d = torch.nn.functional.conv2d
        loss = 0

        for k in range(num_classes):
            # Probability-weighted mean pixel value of this class, (B, C, 1, 1).
            class_probs = labels[:, k].unsqueeze(1)  # (B, 1, H, W)
            class_mean = torch.mean(inputs * class_probs, dim=(2, 3), keepdim=True) / \
                torch.add(torch.mean(class_probs, dim=(2, 3), keepdim=True), 1e-5)
            # Squared Euclidean distance (over channels) of each pixel to that mean, (B, 1, H, W).
            diff = (inputs - class_mean).pow(2).sum(dim=1).unsqueeze(1)

            # Intensity affinity exp(-||F - mean||^2 / sigma_2^2). ``diff`` is
            # already squared; the previous version squared it a second time.
            weights = torch.exp(-diff / self.sigma_2 ** 2)

            # Association ratio using the Gaussian spatial filter.
            numerator = torch.sum(class_probs * conv2d(class_probs * weights, kernel, padding=self.radius))
            denominator = torch.sum(class_probs * conv2d(weights, kernel, padding=self.radius))
            # Both sums are non-negative, so this equals the former L1Loss against zero.
            loss = loss + numerator / (denominator + 1e-6)

        return num_classes - loss

### CNN block

In [None]:
class SEM(nn.Module):
    """CNN regressor mapping an image batch to two values per sample.

    Three stride-2 4x4 convolutions take ``input_dim`` channels to
    ``input_dim * 8`` channels. Each halves height and width (floor), so
    inputs should be at least 8x8. The feature map is then global-average-
    pooled and fed to an MLP (``input_dim * 8`` -> 100 -> 10 -> 2).

    NOTE(review): the training cell reads the two outputs as (mean,
    dispersion); confirm.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.cnn = self.CNN_block(input_dim)
        # The CNN ends with input_dim * 8 channels (input_dim -> 2x -> 4x -> 8x).
        # After global average pooling that is the MLP's input width; it was
        # previously built with in_features=input_dim.
        self.linear = self.linear_block(input_dim * 8)

    def get_simple_block(self, input_dim, kernel_size=4, stride=2, padding=1, last_layer=False):
        """Conv2d doubling the channel count, followed by BatchNorm + LeakyReLU(0.2) unless last."""
        if not last_layer:
            return nn.Sequential(
                nn.Conv2d(input_dim,
                          input_dim*2,
                          kernel_size,
                          stride,
                          padding=padding),
                nn.BatchNorm2d(input_dim*2),
                nn.LeakyReLU(0.2)
            )
        else:
            return nn.Conv2d(input_dim,
                          input_dim*2,
                          kernel_size,
                          stride,
                          padding=padding)

    def CNN_block(self, input_dim, kernel_size=4, stride=2, padding=1):
        """Three downsampling stages: input_dim -> 2x -> 4x -> 8x channels."""
        return nn.Sequential(self.get_simple_block(input_dim, kernel_size, stride, padding),
                             self.get_simple_block(input_dim*2, kernel_size, stride, padding),
                             self.get_simple_block(input_dim*4, kernel_size, stride, padding,last_layer=True))

    def linear_block(self, input_dim):
        """MLP input_dim -> 100 -> 10 -> 2 with ReLU between layers."""
        return nn.Sequential(nn.Linear(input_dim,100),
                     nn.ReLU(),
                     nn.Linear(100,10),
                     nn.ReLU(),
                     nn.Linear(10,2))

    def forward(self, x):
        """Map ``x`` of shape (N, input_dim, H, W) to (N, 2)."""
        # self.cnn and self.linear are already-built modules. The previous
        # version called them like factories (``self.cnn(x.shape[1], ...)``),
        # which raised TypeError.
        x = self.cnn(x)
        # Global average pooling: (N, C, H', W') -> (N, C), independent of H and W.
        x = x.mean(dim=(2, 3))
        return self.linear(x)


# NOTE(review): input_dim is the number of input channels, but SEM_Dataset
# yields single-channel tiles; confirm that 100 is intended.
net=SEM(100)
print(net)#.parameters()

SEM(
  (cnn): Sequential(
    (0): Sequential(
      (0): Conv2d(100, 200, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (1): Sequential(
      (0): Conv2d(200, 400, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (2): Conv2d(400, 800, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
  (linear): Sequential(
    (0): Linear(in_features=100, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=10, bias=True)
    (3): ReLU()
    (4): Linear(in_features=10, out_features=2, bias=True)
  )
)


### Training the CNN regressor

In [None]:
learning_rate = 1e-2  # SGD step size for the CNN regressor
def criterion(return_mean, true_mean, return_disp, true_disp):
    """Squared-error loss on predicted mean and dispersion, averaged over the batch.

    Computes ``(true_mean - return_mean)**2 + (true_disp - return_disp)**2``
    element-wise and returns its mean as a 0-dim tensor that ``backward()``
    can be called on. Inputs may be tensors of mutually broadcastable shape,
    or plain numbers.

    NOTE(review): this rebinds the notebook-level ``criterion`` that the GAN
    cells use for BCEWithLogitsLoss.
    """
    squared_error = (true_mean - return_mean) ** 2 + (true_disp - return_disp) ** 2
    # The previous ``torch.Tensor(...)`` legacy-constructor wrap did not
    # reduce the loss, so backward() failed on batched inputs. ``as_tensor``
    # returns tensors unchanged (keeping autograd) and only converts plain numbers.
    return torch.as_tensor(squared_error).mean()
optimizer = torch.optim.SGD(params=net.parameters(), lr=learning_rate)  # plain SGD over the regressor

In [None]:
def train(criterion, model=net, optimizer=optimizer, num_epochs=2, loader=None):
    """Fit ``model`` to per-image (mean, dispersion) targets.

    :param criterion: ``criterion(pred_mean, true_mean, pred_disp, true_disp)``
        returning a scalar loss tensor.
    :param model: network mapping an image batch to an (N, 2) tensor. Column
        0 is used as the predicted mean, column 1 as the predicted dispersion.
    :param optimizer: optimiser over ``model``'s parameters.
    :param num_epochs: number of passes over ``loader``.
    :param loader: iterable of ``(image, mean, disp)`` batches; defaults to the
        global ``train_loader`` at call time.
    :return: list with the sample-weighted mean training loss of each epoch.

    NOTE(review): ``SEM_Dataset`` currently yields images only, so
    ``train_loader`` does not provide ``(image, mean, disp)`` triplets yet.
    Also assumes ``mean``/``disp`` have shape (N,) so they align with the
    output columns; TODO confirm.
    """
    if loader is None:
        loader = train_loader
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        running_loss = 0.0
        n_samples = 0
        for image, mean, disp in loader:
            optimizer.zero_grad()
            outputs = model(image)
            # Select output *columns*; outputs[0] / outputs[1] were the first two samples.
            loss = criterion(outputs[:, 0], mean, outputs[:, 1], disp)
            loss.backward()
            optimizer.step()
            batch_size = image.size(0)
            running_loss += loss.item() * batch_size
            n_samples += batch_size
        # Report once per epoch, weighted by batch size (the last batch may be smaller).
        epoch_loss = running_loss / max(n_samples, 1)
        epoch_losses.append(epoch_loss)
        print(f'train Loss: {epoch_loss:.4f}')
    return epoch_losses

train(criterion)