In [None]:
!pip install pytorch-msssim
from pytorch_msssim import ssim
import glob
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader , random_split
from torchsummary import summary
# from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam , lr_scheduler
import matplotlib.pyplot as plt
from random import randint
from tqdm import tqdm
from tqdm.notebook import trange, tqdm
!pip install torchmetrics
from torchmetrics import StructuralSimilarityIndexMeasure
import time
import itertools
import pickle
from pathlib import Path
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision.utils import make_grid
from torchvision.transforms.functional import center_crop

In [None]:
# Mount Google Drive inside this Colab runtime. Model checkpoints are written to
# and read from '/content/drive/MyDrive/weights/', which only exists once Drive is mounted.
from google.colab import drive  # mounting the drive to access the data stored on it
drive.mount('/content/drive')

In [None]:
class EncoderBlock(nn.Module):
    """Downsampling unit: optional LeakyReLU -> Conv2d -> optional normalisation.

    With the default kernel_size=4, stride=2, padding=1 the spatial size is halved.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride, padding, dilation, groups, bias: forwarded to
            ``nn.Conv2d`` (``bias`` defaults to False here, unlike ``nn.Conv2d``).
        do_norm: apply a normalisation layer after the convolution.
        norm: 'batch' (BatchNorm2d), 'instance' (InstanceNorm2d) or 'none'
            (disables normalisation). Any other value raises
            NotImplementedError, but only when ``do_norm`` is True.
        do_activation: apply LeakyReLU(0.2) to the input *before* the convolution.

    Note:
        The LeakyReLU is in-place. When ``do_activation`` is True, the tensor
        passed to ``forward`` is overwritten with its activated version.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, dilation=1, groups=1, bias=False,
                 do_norm=True, norm='batch', do_activation=True):
        super().__init__()
        # Submodules are registered in the order conv, leakyRelu, norm so that
        # state_dict keys (and therefore saved checkpoints) stay the same.
        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias,
        )
        self.leakyRelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.do_norm = do_norm
        self.do_activation = do_activation

        if not do_norm:
            return
        if norm == 'none':
            self.do_norm = False
        elif norm == 'batch':
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == 'instance':
            self.norm = nn.InstanceNorm2d(out_channels)
        else:
            raise NotImplementedError("norm error")

    def forward(self, x):
        out = self.leakyRelu(x) if self.do_activation else x
        out = self.conv(out)
        return self.norm(out) if self.do_norm else out

class DecoderBlock(nn.Module):
    """Upsampling unit: optional ReLU -> ConvTranspose2d -> optional norm -> optional Dropout2d.

    With the default kernel_size=4, stride=2, padding=1 the spatial size is doubled.

    Args:
        in_channels, out_channels: channel counts of the transposed convolution.
        kernel_size, stride, padding, bias: forwarded to ``nn.ConvTranspose2d``.
        do_norm: apply a normalisation layer after the transposed convolution.
        norm: 'batch', 'instance' or 'none'. Any other value raises
            NotImplementedError, but only when ``do_norm`` is True.
        do_activation: apply ReLU (not in-place) to the input first.
        dropout_prob: channel-wise Dropout2d probability. 0 skips dropout entirely.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False,
                 do_norm=True, norm='batch', do_activation=True, dropout_prob=0.2):
        super().__init__()
        # Registration order convT, relu, drop, norm keeps state_dict keys unchanged.
        self.convT = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding, bias=bias,
        )
        self.relu = nn.ReLU()
        self.dropout_prob = dropout_prob
        self.drop = nn.Dropout2d(dropout_prob)
        self.do_norm = do_norm
        self.do_activation = do_activation

        if not do_norm:
            return
        if norm == 'none':
            self.do_norm = False
        elif norm == 'batch':
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == 'instance':
            self.norm = nn.InstanceNorm2d(out_channels)
        else:
            raise NotImplementedError("norm error")

    def forward(self, x):
        out = self.relu(x) if self.do_activation else x
        out = self.convT(out)
        if self.do_norm:
            out = self.norm(out)
        # A zero probability bypasses the Dropout2d module altogether.
        return out if self.dropout_prob == 0 else self.drop(out)

In [None]:
class Generator(nn.Module):
    """U-Net style generator that maps a speckle image to a reconstructed image.

    ``forward`` uses encoder1-6 and decoder1-5 plus decoder8. For a 1x64x64 input the
    encoder path goes 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1.
    - Skip connections concatenate encode5/4/3 onto the outputs of decoder1/2/3.
    - decoder4 and decoder5 have no skip input.
    - decoder8 restores 64x64 and is followed by tanh, so outputs lie in (-1, 1).

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        bias: passed to every (transposed) convolution.
        dropout_prob: Dropout2d probability used by decoder2-4.
        norm: 'batch', 'instance' or 'none' for the normalised blocks.
    """

    def __init__(self, in_channels=1, out_channels=1, bias=False, dropout_prob=0.2, norm='batch'):
        super(Generator, self).__init__()

        # Encoder: each block halves the spatial size.
        self.encoder1 = EncoderBlock(in_channels, 64, bias=bias, do_norm=False, do_activation=False)
        self.encoder2 = EncoderBlock(64, 128, bias=bias, norm=norm)
        self.encoder3 = EncoderBlock(128, 256, bias=bias, norm=norm)
        self.encoder4 = EncoderBlock(256, 512, bias=bias, norm=norm)
        self.encoder5 = EncoderBlock(512, 512, bias=bias, norm=norm)
        self.encoder6 = EncoderBlock(512, 512, bias=bias, norm=norm)
        # encoder7/encoder8 are not used by forward(). They are kept so that the
        # parameter names of existing checkpoints still match for a strict
        # load_state_dict().
        self.encoder7 = EncoderBlock(512, 512, bias=bias, norm=norm)
        self.encoder8 = EncoderBlock(512, 512, bias=bias, do_norm=False)

        # Decoder: each block doubles the spatial size.
        # NOTE(review): decoder1 and decoder5 get DecoderBlock's default dropout
        # of 0.2 because no dropout_prob is passed. Confirm this is intended.
        self.decoder1 = DecoderBlock(512, 512, bias=bias, norm=norm)
        self.decoder2 = DecoderBlock(1024, 512, bias=bias, norm=norm, dropout_prob=dropout_prob)
        self.decoder3 = DecoderBlock(1024, 512, bias=bias, norm=norm, dropout_prob=dropout_prob)
        self.decoder4 = DecoderBlock(768, 256, bias=bias, norm=norm, dropout_prob=dropout_prob)
        self.decoder5 = DecoderBlock(256, 128, bias=bias, norm=norm)
        # Output layer has no dropout. Dropout2d zeroes whole channels; with the
        # default single output channel it would blank the entire generated image
        # with probability p during training and rescale the rest by 1/(1-p).
        self.decoder8 = DecoderBlock(128, out_channels, bias=bias, do_norm=False, dropout_prob=0)
        self.tanh = nn.Tanh()

    def forward(self, x):
        # Encoder path (encode1 and encode2 are not reused as skip inputs).
        encode1 = self.encoder1(x)
        encode2 = self.encoder2(encode1)
        encode3 = self.encoder3(encode2)
        encode4 = self.encoder4(encode3)
        encode5 = self.encoder5(encode4)
        encode6 = self.encoder6(encode5)

        # Decoder path with skip connections to encode5, encode4 and encode3.
        # NOTE: EncoderBlock's LeakyReLU is in-place. By the time they are
        # concatenated here, encode3-5 already hold their LeakyReLU-activated values.
        decode1 = torch.cat([self.decoder1(encode6), encode5], 1)
        decode2 = torch.cat([self.decoder2(decode1), encode4], 1)
        decode3 = torch.cat([self.decoder3(decode2), encode3], 1)
        decode4 = self.decoder4(decode3)
        decode5 = self.decoder5(decode4)
        decode8 = self.decoder8(decode5)
        return self.tanh(decode8)

In [None]:
#############################################################
# patchGAN
#############################################################
class Discriminator(nn.Module):
    """Conditional discriminator that scores an (image, reference) pair.

    It uses a PatchGAN-style convolutional stack (pix2pix layout). The resulting
    patch map is flattened and reduced to a single score per sample by a linear
    layer. For a 64x64 input the spatial size goes 64 -> 32 -> 16 -> 8 -> 7 -> 6,
    so by default the linear layer sees out_channels * 6 * 6 = 36 features.

    Args:
        in_channels: channels of each of the two inputs; they are concatenated,
            so the first conv sees ``2 * in_channels``.
        out_channels: channels of the final convolutional map.
        bias: passed to every convolution.
        norm: normalisation for disc2-disc4.
        sigmoid: if True (default), squash the score to (0, 1). This is what the
            BCELoss-based discriminator loss expects. If False, return the raw score.
        input_size: height/width of the square inputs (default 64). It determines
            the width of the final linear layer.

    Raises:
        ValueError: if ``input_size`` is too small for the conv stack (< 24).
    """

    def __init__(self, in_channels=1, out_channels=1, bias=False, norm='batch', sigmoid=True, input_size=64):
        super(Discriminator, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.use_sigmoid = sigmoid

        self.disc1 = EncoderBlock(in_channels * 2, 64, bias=bias, do_norm=False, do_activation=False)
        self.disc2 = EncoderBlock(64, 128, bias=bias, norm=norm)
        self.disc3 = EncoderBlock(128, 256, bias=bias, norm=norm)
        self.disc4 = EncoderBlock(256, 512, bias=bias, norm=norm, stride=1)
        self.disc5 = EncoderBlock(512, out_channels, bias=bias, stride=1, do_norm=False)
        # Three stride-2 blocks (k=4, p=1) floor-halve the size. Two stride-1
        # blocks (k=4, p=1) each remove one pixel. So side = input_size // 8 - 2.
        side = input_size // 8 - 2
        if side < 1:
            raise ValueError(f"input_size={input_size} is too small for the discriminator (need >= 24)")
        self.linear = nn.Linear(out_channels * side * side, 1)
        self.flat = nn.Flatten()

    def forward(self, x, ref):
        # Condition on the reference by stacking it with x along the channel axis.
        d1 = self.disc1(torch.cat([x, ref], 1))
        d2 = self.disc2(d1)
        d3 = self.disc3(d2)
        d4 = self.disc4(d3)
        d5 = self.disc5(d4)
        score = self.linear(self.flat(d5))
        return self.sigmoid(score) if self.use_sigmoid else score

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
# Layer-by-layer summaries for a single-channel 64x64 input. The discriminator
# takes two inputs of that shape: the candidate image and the conditioning image.
# NOTE(review): summary() is called without a device argument and relies on
# torchsummary's own default. Confirm that it matches `device` above.
summary(Generator().to(device),(1,64,64))
summary(Discriminator().to(device),[(1,64,64),(1,64,64)])  # discriminator: (image, reference) pair

In [None]:
# This is the custom Data Loader class

class Speckle(Dataset):
    """Paired (speckle pattern, target image) dataset used for training.

    Both arrays are loaded eagerly with ``np.load``. ``__getitem__`` divides each
    image by 255, resizes it to 64x64 with Lanczos interpolation and returns two
    float32 tensors of shape (1, 64, 64).

    Args:
        data_path: path of the speckle array (indexed along axis 0).
        label_path: path of the label array. Entry ``idx`` is paired with speckle
            ``idx``. Labels go through ``cv2.COLOR_BGR2GRAY``, so they are
            presumably stored as 3-channel images (TODO: confirm).

    Raises:
        ValueError: if there are fewer labels than speckle images.
    """

    # Every image is resized to IMG_SIZE x IMG_SIZE; the networks are built for 64x64.
    IMG_SIZE = 64

    def __init__(self, data_path='speckle training data path', label_path='MNIST label path'):
        # SECURITY: allow_pickle=True lets np.load unpickle arbitrary objects.
        # Only load files from a trusted source.
        self.data = np.load(data_path, allow_pickle=True)
        self.label = np.load(label_path, allow_pickle=True)
        print(self.data.shape)
        # Fail fast here rather than with an IndexError part-way through an epoch.
        if len(self.label) < len(self.data):
            raise ValueError(
                f"only {len(self.label)} labels for {len(self.data)} speckle images"
            )

    def __len__(self):
        return self.data.shape[0]

    @classmethod
    def _to_tensor(cls, img):
        """Divide a single-channel HxW image by 255, resize it and return a float32 (1, W, H) tensor."""
        img = cv2.resize(img / 255.0, (cls.IMG_SIZE, cls.IMG_SIZE), interpolation=cv2.INTER_LANCZOS4)
        img = img.reshape(cls.IMG_SIZE, cls.IMG_SIZE, 1).astype(np.float32)
        # `.T` reverses *all* axes: (H, W, 1) -> (1, W, H). This puts the channel
        # first as PyTorch expects, but it also swaps the two spatial axes. The
        # swap is applied identically to inputs and targets.
        return torch.from_numpy(img.T)

    def __getitem__(self, idx):
        speckle_img_tensor = self._to_tensor(self.data[idx])
        label_img = cv2.cvtColor(self.label[idx], cv2.COLOR_BGR2GRAY)
        label_img_tensor = self._to_tensor(label_img)
        return speckle_img_tensor, label_img_tensor  # (input, target) pair

In [None]:
# Load the full paired training dataset (speckle inputs + target labels) and report its size.
# NOTE(review): despite the name, this is the training pool, not a test set.
speckle_test = Speckle()
print(len(speckle_test))
# The next cell splits it randomly into 90% training and 10% validation; the
# held-out test set is loaded separately via TestData.


In [None]:
# Random 90% / 10% train/validation split.
# - The validation length is the remainder, so the two lengths always sum to
#   len(speckle_test); random_split raises ValueError otherwise.
# - The seeded generator makes the split reproducible across kernel restarts.
SPLIT_SEED = 42
n_train = int(len(speckle_test) * 0.9)
train, valid = random_split(
    speckle_test,
    lengths=[n_train, len(speckle_test) - n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(len(train), len(valid))

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# Images are fed in batches; tune batch_size to the available GPU memory.
batch_size = 40
# Training batches are reshuffled every epoch. drop_last=True matters because the
# generator's innermost BatchNorm (encoder6) sees a 1x1 map for 64x64 inputs.
# A trailing batch holding a single sample would make BatchNorm raise in train
# mode ("Expected more than 1 value per channel when training").
train_dataloader = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last=True)
# Validation order does not affect any metric, so keep it deterministic.
valid_dataloader = DataLoader(valid, batch_size=batch_size, shuffle=False)
print(len(train_dataloader))

In [None]:
# Discriminator ("desc") adversarial loss: binary cross-entropy on D's sigmoid output.
desc_loss_fn = torch.nn.BCELoss()
# Generator adversarial loss: least-squares (MSE) against the "real" target.
gen_loss_fn = torch.nn.MSELoss()
# Pixel-wise reconstruction loss for the generator.
L1_loss = torch.nn.L1Loss()


def gan_loss(out, label, mode):
    """Adversarial loss of discriminator scores ``out`` against a constant target.

    Args:
        out: discriminator output tensor. For 'desc' these must be probabilities
            in [0, 1], as required by BCELoss.
        label: truthy -> target of all ones ("real"); falsy -> all zeros ("fake").
        mode: 'desc' uses BCELoss (discriminator update); 'gen' uses MSELoss
            (generator update).

    Returns:
        Scalar loss tensor (mean reduction).

    Raises:
        ValueError: if ``mode`` is neither 'desc' nor 'gen'.
    """
    target = torch.ones_like(out) if label else torch.zeros_like(out)
    if mode == 'desc':
        return desc_loss_fn(out, target)
    if mode == 'gen':
        return gen_loss_fn(out, target)
    raise ValueError(f"unknown gan_loss mode: {mode!r} (expected 'desc' or 'gen')")

In [None]:
# Hyper-parameters of the conditional GAN.
lr = 0.0001           # Adam learning rate shared by D and G
scheduler_T_max = 5   # CosineAnnealingLR T_max, in scheduler steps
lambd_d = 0.5         # weight applied to each discriminator term (real and fake)
lambd = 100           # weight of the L1 reconstruction term in the generator loss

# D is built before G so the global RNG is consumed in the same order, which
# determines both networks' initial weights.
D = Discriminator().to(device)
G = Generator().to(device)

optimizer_D, optimizer_G = (torch.optim.Adam(net.parameters(), lr=lr) for net in (D, G))

# NOTE(review): past T_max, CosineAnnealingLR follows the cosine back up. With
# T_max=5 the learning rate therefore oscillates between lr and eta_min with a
# period of 10 steps instead of decaying once. Confirm this is intended.
scheduler_D, scheduler_G = (
    CosineAnnealingLR(opt, T_max=scheduler_T_max, eta_min=0) for opt in (optimizer_D, optimizer_G)
)

In [None]:
# Adversarial training (pix2pix-style). For each batch:
# 1. Update D on real (y, x) versus generated (G(x), x) pairs.
# 2. Update G with an adversarial term plus a lambd-weighted L1 reconstruction term.
# Per-epoch mean losses are recorded for the plots below, and both networks are
# checkpointed after every epoch.
NUM_EPOCHS = 120
SAMPLE_EVERY = 10  # epochs between sample-grid previews
N_SAMPLES = 5      # images per row of the preview grid
CKPT_DIR = '/content/drive/MyDrive/weights'

last_loss_g = 99999999999999999999.0  # NOTE(review): never read
last_loss_d = 99999999999999999999.0  # NOTE(review): never read

# Per-epoch mean loss histories.
ld_real = []
ld_fake = []
lg_gan = []
lg_l1 = []
ld = []
lg = []

# torch.save does not create missing directories.
os.makedirs(CKPT_DIR, exist_ok=True)

for epoch in tqdm(range(NUM_EPOCHS)):
    D.train()
    G.train()
    # Running sums of each loss term; divided by the batch count at epoch end.
    epoch_sums = dict.fromkeys(('G', 'G_gan', 'G_L1', 'D', 'D_real', 'D_fake'), 0.0)
    n_batches = 0

    for x, y in tqdm(train_dataloader):
        x = x.to(device)
        y = y.to(device)

        # ---- Discriminator step: real pair -> 1, generated pair -> 0 ----
        optimizer_D.zero_grad()
        gen = G(x)
        loss_D_real = gan_loss(D(y, x), 1, 'desc') * lambd_d
        # detach(): the discriminator step must not back-propagate into G.
        loss_D_fake = gan_loss(D(gen.detach(), x), 0, 'desc') * lambd_d
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # ---- Generator step: fool D and stay close to the target in L1 ----
        # This backward also fills D's .grad. Those gradients are cleared by
        # optimizer_D.zero_grad() at the start of the next batch.
        optimizer_G.zero_grad()
        loss_G_gan = gan_loss(D(gen, x), 1, 'gen')
        loss_G_L1 = L1_loss(gen, y) * lambd
        loss_G = loss_G_gan + loss_G_L1
        loss_G.backward()
        optimizer_G.step()

        for key, value in (('G', loss_G), ('G_gan', loss_G_gan), ('G_L1', loss_G_L1),
                           ('D', loss_D), ('D_real', loss_D_real), ('D_fake', loss_D_fake)):
            epoch_sums[key] += value.item()
        n_batches += 1

    scheduler_D.step()
    scheduler_G.step()

    if epoch % SAMPLE_EVERY == 0:
        # Preview on the last batch: generated images, speckle inputs, targets
        # (one row each when the batch holds at least N_SAMPLES images).
        # eval() + no_grad() ensure this pass does not update BatchNorm running
        # statistics, does not apply dropout, and does not build an autograd graph.
        G.eval()
        with torch.no_grad():
            generated_imgs = G(x[:N_SAMPLES])
        imgs = torch.cat([generated_imgs, x[:N_SAMPLES], y[:N_SAMPLES]], 0).cpu()
        grid = make_grid(imgs, nrow=N_SAMPLES).permute(1, 2, 0).numpy()
        plt.figure()
        plt.imshow(grid)
        plt.show()

    epoch_means = {key: total / max(n_batches, 1) for key, total in epoch_sums.items()}
    print(f"G {epoch_means['G']} G_gan {epoch_means['G_gan']} G_L1: {epoch_means['G_L1']} "
          f"D: {epoch_means['D']} D_real: {epoch_means['D_real']} D_fake: {epoch_means['D_fake']}")
    ld_real.append(epoch_means['D_real'])
    ld_fake.append(epoch_means['D_fake'])
    lg_gan.append(epoch_means['G_gan'])
    lg_l1.append(epoch_means['G_L1'])
    ld.append(epoch_means['D'])
    lg.append(epoch_means['G'])

    torch.save(G.state_dict(), os.path.join(CKPT_DIR, 'abc_gen.pt'))
    torch.save(D.state_dict(), os.path.join(CKPT_DIR, 'abc_dis.pt'))


In [None]:
# Total discriminator vs. generator loss, one point per epoch.
fig, ax = plt.subplots()
ax.plot(ld, 'k-', label="D")
ax.plot(lg, 'c-', label="G")
ax.legend()
plt.show()

In [None]:
# Individual loss components, one point per epoch. G_L1 is already multiplied by
# lambd, so it may sit on a much larger scale than the adversarial terms.
fig, ax = plt.subplots()
ax.plot(ld_real, 'g-', label="D_real")
ax.plot(ld_fake, 'r-', label="D_fake")
ax.plot(lg_gan, 'y-', label="G_gan")
ax.plot(lg_l1, 'b-', label="G_L1")
ax.set(xlabel="epoch", ylabel="loss", title="Loss components")
ax.legend()  # without this the per-curve labels are never shown
plt.show()

In [None]:

class TestData(Dataset):
    """Held-out (speckle pattern, target image) pairs used for evaluation.

    Preprocessing matches the training dataset: divide by 255, Lanczos-resize to
    64x64, and return float32 tensors of shape (1, 64, 64). Labels are taken from
    ``label_slice`` of the label array (rows 55000-59999 by default). Speckle
    ``idx`` is paired with the ``idx``-th label of that slice.

    Args:
        data_path: path of the test speckle array (indexed along axis 0).
        label_path: path of the full label array. Labels go through
            ``cv2.COLOR_BGR2GRAY``, so they are presumably 3-channel (TODO: confirm).
        label_slice: rows of the label array that correspond to the test speckles.

    Raises:
        ValueError: if the label slice is shorter than the speckle array.
    """

    # Every image is resized to IMG_SIZE x IMG_SIZE; the networks are built for 64x64.
    IMG_SIZE = 64

    def __init__(self, data_path='speckle test data path', label_path='MNIST label data path',
                 label_slice=slice(55000, 60000)):
        # SECURITY: allow_pickle=True lets np.load unpickle arbitrary objects.
        # Only load files from a trusted source.
        self.data = np.load(data_path, allow_pickle=True)
        self.label = np.load(label_path, allow_pickle=True)[label_slice]
        print(self.data.shape)
        print(self.label.shape)
        # Fail fast here rather than with an IndexError part-way through evaluation.
        if len(self.label) < len(self.data):
            raise ValueError(
                f"only {len(self.label)} labels for {len(self.data)} speckle images"
            )

    def __len__(self):
        return self.data.shape[0]

    @classmethod
    def _to_tensor(cls, img):
        """Divide a single-channel HxW image by 255, resize it and return a float32 (1, W, H) tensor."""
        img = cv2.resize(img / 255.0, (cls.IMG_SIZE, cls.IMG_SIZE), interpolation=cv2.INTER_LANCZOS4)
        img = img.reshape(cls.IMG_SIZE, cls.IMG_SIZE, 1).astype(np.float32)
        # `.T` reverses *all* axes: (H, W, 1) -> (1, W, H). This puts the channel
        # first but also swaps the spatial axes, identically for inputs and targets.
        return torch.from_numpy(img.T)

    def __getitem__(self, idx):
        speckle_img_tensor = self._to_tensor(self.data[idx])
        label_img = cv2.cvtColor(self.label[idx], cv2.COLOR_BGR2GRAY)
        label_img_tensor = self._to_tensor(label_img)
        return speckle_img_tensor, label_img_tensor  # (input, target) pair

In [None]:
# One pair per batch, in file order, so per-image SSIM values line up with dataset indices.
test_dataloader2 = DataLoader(dataset=TestData(), batch_size=1, shuffle=False)

In [None]:
# Evaluate the saved generator on the held-out set with SSIM and preview the
# first N_PREVIEW (target, prediction) pairs.
N_PREVIEW = 20
CKPT_PATH = '/content/drive/MyDrive/weights/abc_gen.pt'

speckle = None
predicted = None
lab = None
print(device)
total = []  # per-image SSIM values
# NOTE: this rebinds `ssim`, shadowing the pytorch_msssim `ssim` function imported at the top.
# NOTE(review): data_range is left at its default (inferred from the tensors).
# Targets are in [0, 1] while tanh predictions lie in (-1, 1); consider fixing it explicitly.
ssim = StructuralSimilarityIndexMeasure().to(device)
u_net = Generator()
u_net.load_state_dict(torch.load(CKPT_PATH, map_location=torch.device(device)))
u_net.to(device)
u_net.eval()

# Inference only: no_grad() avoids building an autograd graph for every test image.
with torch.no_grad():
    for count, (data, label) in enumerate(test_dataloader2):
        data = data.to(device)
        lab = label.to(device)
        predicted = u_net(data)
        # torchmetrics expects (preds, target); the SSIM value itself is symmetric.
        val = ssim(predicted, lab).item()
        total.append(val)
        if count < N_PREVIEW:
            # `.T` undoes the spatial transpose applied in TestData.__getitem__.
            plt.subplot(1, 2, 1)
            plt.imshow(lab[0].reshape(64, 64).T.cpu(), cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(predicted[0].reshape(64, 64).T.cpu(), cmap='gray')
            print("SSIM VALUE = ", val)
            plt.show()

if not total:
    raise ValueError("test_dataloader2 yielded no samples; cannot report SSIM")
print("maximum SSIM: = ", max(total))
print("Average SSIM Value = ", sum(total) / len(total))