In [None]:
# Seed the Python, NumPy and TensorFlow RNGs for reproducibility.
# NOTE(review): TensorFlow appears to be imported only for this seed call (the
# models below are PyTorch); `random` and `torch` are re-seeded with
# `manualSeed` in a later cell.
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import tensorflow as tf
tf.random.set_seed(0)

In [None]:
!pip install matplotlib-venn
!pip install tensorflow

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.animation as animation

In [None]:
# A __future__ import must be the first statement of the cell; after any other
# import Python raises "SyntaxError: from __future__ imports must occur at the
# beginning of the file".
from __future__ import print_function
from tqdm import tqdm_notebook as tqdm
import argparse
import pandas as pd
from PIL import Image
import glob
import imageio
import cv2
import pathlib
import sys
from skimage import io, transform
from IPython import display
from IPython.display import HTML

In [None]:
!pip install torchinfo

In [None]:
# The __future__ import must come first in the cell, otherwise it is a SyntaxError.
from __future__ import print_function
import os
import json
from zipfile import ZipFile
import time

In [None]:
import torch
from torch import nn, optim
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

In [None]:
import torchvision
import torchvision.datasets as dset
from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import torchvision.utils as vutils
from torchvision.datasets import MNIST
from torchvision.utils import save_image

In [None]:
# One seed for every RNG the notebook uses. The first cell seeded random, NumPy
# and TensorFlow with 0. Re-seeding only `random` and `torch` here left NumPy on
# a different seed, so NumPy is seeded with the same value too. Putting the
# NumPy call last also stops the cell from echoing torch's Generator object.
manualSeed = 999
random.seed(manualSeed)
torch.manual_seed(manualSeed)
np.random.seed(manualSeed)

In [None]:
def get_default_device():
    """Pick the compute device: CUDA when PyTorch can see a GPU, else the CPU.

    Prints which one was chosen and returns the matching `torch.device`.
    """
    gpu_ok = torch.cuda.is_available()
    print("CUDA is available. Using GPU." if gpu_ok else "CUDA is not available. Using CPU.")
    return torch.device('cuda' if gpu_ok else 'cpu')

device = get_default_device()

def to_device(data, device):
    """Move `data` to `device`, recursing into lists and tuples.

    Sequences are returned as lists. Any other object is moved with
    `.to(device, non_blocking=True)`.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader():
    """Wrap an iterable of batches so each batch is moved to `device` as it is yielded."""

    def __init__(self, dl, device):
        # `dl` is the wrapped batch iterable; `device` is where batches are sent.
        self.dl, self.device = dl, device

    def __iter__(self):
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        return len(self.dl)

In [None]:
# Colab-only: mount Google Drive at /content/drive. Later cells read the Kaggle
# credentials from it and write the processed images and metrics to it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load the Kaggle API credentials stored on Drive. The `with` block closes the
# file handle; json.load(open(...)) left it open. The variable keeps its
# original spelling because the next cell reads it under that name.
with open("/content/drive/MyDrive/kaggle.json") as creds_file:
    kaggle_credentails = json.load(creds_file)

In [None]:
# Export the credentials as environment variables; presumably picked up by the
# `kaggle` CLI in the next cell.
for env_name, field in (('KAGGLE_USERNAME', 'username'), ('KAGGLE_KEY', 'key')):
    os.environ[env_name] = kaggle_credentails[field]

In [None]:
!kaggle datasets download -d ashfakyeafi/brain-mri-images

In [None]:
from zipfile import ZipFile

# Unpack the downloaded Kaggle archive into the current working directory.
archive_path = "/content/brain-mri-images.zip"
with ZipFile(archive_path, mode='r') as archive:
    archive.extractall()

In [None]:
# Sanity check: number of entries in the raw training-image folder
# (presumably created by the archive extraction above).
print("Length of directory : ",len(os.listdir("/content/GAN-Traning Images")))

In [None]:
def preprocess_image(img_path):
    """Load an image and return a denoised, sharpened grayscale version.

    Pipeline: BGR -> grayscale, non-local-means denoising (positional args
    20, 3, 21 per OpenCV's fastNlMeansDenoising signature), then a 3x3
    sharpening convolution.

    Returns the processed image array. Returns None, after printing an error,
    when OpenCV cannot read `img_path`.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        print(f"Error: Could not load image at {img_path}. Please check the file path.")
        return None

    # Sharpening kernel: centre weight 9, all eight neighbours -1.
    sharpen = np.array([[-1, -1, -1],
                        [-1, 9, -1],
                        [-1, -1, -1]])
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    clean = cv2.fastNlMeansDenoising(gray, None, 20, 3, 21)
    return cv2.filter2D(clean, -1, sharpen)

In [None]:
def preprocess_dataset(image_dir, output_dir):
    """Run `preprocess_image` on every file in `image_dir` and save the results to `output_dir`.

    Output files keep their original names. Unreadable inputs are skipped
    (preprocess_image prints an error for them). Files OpenCV fails to write
    are reported instead of being silently dropped.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(output_dir, exist_ok=True)

    # Sorted so runs process files in a reproducible order.
    for img_name in sorted(os.listdir(image_dir)):
        img_path = os.path.join(image_dir, img_name)
        processed_img = preprocess_image(img_path)
        if processed_img is None:
            continue

        output_path = os.path.join(output_dir, img_name)
        # cv2.imwrite reports failure through its boolean return value.
        if not cv2.imwrite(output_path, processed_img):
            print(f"Error: Could not write processed image to {output_path}.")

In [None]:
# Side-by-side: raw slice vs. the output of preprocess_image.
fig, ax = plt.subplots(1, 2, figsize=(9, 9))
img_path = '/content/GAN-Traning Images/OAS2_0001_MR2_y_slice_105.jpg'
img1 = mpimg.imread(img_path)
ax[0].imshow(img1, cmap='gray')
ax[0].set_title('Original MRI Image')
ax[0].axis('on')

preprocessed_img = preprocess_image(img_path)
ax[1].imshow(preprocessed_img, cmap='gray')
ax[1].set_title('Enhanced MRI Image')
ax[1].axis('on')

# A single show(). The previous version repeated the imshow/title/axis calls
# and a second plt.show() after the figure had already been displayed, which
# was redundant.
plt.show()


In [None]:
# Side-by-side: raw slice vs. the output of preprocess_image.
fig, ax = plt.subplots(1, 2, figsize=(8, 8))
img_path = '/content/GAN-Traning Images/OAS2_0023_MR2_z_slice_137.jpg'
img1 = mpimg.imread(img_path)
ax[0].imshow(img1, cmap='gray')
ax[0].set_title('Original MRI Image')
ax[0].axis('on')

preprocessed_img = preprocess_image(img_path)
ax[1].imshow(preprocessed_img, cmap='gray')
ax[1].set_title('Enhanced MRI Image')
ax[1].axis('on')

# A single show(); the repeated plotting calls after the first show() were redundant.
plt.show()

In [None]:
# Side-by-side: raw slice vs. the output of preprocess_image.
fig, ax = plt.subplots(1, 2, figsize=(9, 9))
img_path = '/content/GAN-Traning Images/OAS2_0008_MR1_x_slice_134.jpg'
img1 = mpimg.imread(img_path)
ax[0].imshow(img1, cmap='gray')
ax[0].set_title('Original MRI Image')
ax[0].axis('on')

preprocessed_img = preprocess_image(img_path)
ax[1].imshow(preprocessed_img, cmap='gray')
ax[1].set_title('Enhanced MRI Image')
ax[1].axis('on')

# A single show(); the repeated plotting calls after the first show() were redundant.
plt.show()

In [None]:
# Raw images in, denoised/sharpened copies out.
image_dir, output_dir = "/content/GAN-Traning Images", "/content/processed_images"

In [None]:
# Write a denoised/sharpened copy of every raw image into output_dir.
preprocess_dataset(image_dir, output_dir)

In [None]:
# Local folder filled by preprocess_dataset (same path as output_dir).
source_dir = '/content/processed_images'

In [None]:
# Drive copy of the processed images. It sits inside DATA_DIR
# ('/content/drive/MyDrive/out'), so ImageFolder treats it as one class folder.
destination_dir = '/content/drive/MyDrive/out/output_brain'

In [None]:
# On a fresh run the Drive destination does not exist until the copytree cell
# below creates it. Guard the listing instead of raising FileNotFoundError.
if os.path.isdir(destination_dir):
    print("Length of destination directory : ", len(os.listdir(destination_dir)))
else:
    print("Destination directory does not exist yet: ", destination_dir)

In [None]:
import shutil

# Copy the processed images to Drive. Without dirs_exist_ok (Python 3.8+),
# copytree raises FileExistsError whenever the destination already exists,
# e.g. on every re-run of the notebook.
shutil.copytree(source_dir, destination_dir, dirs_exist_ok=True)

In [None]:
def convert_to_tensor(img_path, size=64):
    """Load an image as single-channel grayscale and return a normalized tensor.

    Args:
        img_path: path of the image file.
        size: output height and width in pixels (default 64, as before).

    Returns:
        The image resized to (size, size), converted with ToTensor and
        normalized with mean 0.5 / std 0.5.
    """
    # Local imports so the function does not depend on cell order. `T` was
    # otherwise only imported in a later cell, and the global `Image` is rebound
    # to IPython.display.Image further down, which would break Image.open on a
    # re-run.
    from PIL import Image as PILImage
    import torchvision.transforms as T

    transform = T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize((0.5,), (0.5,)),
    ])
    # The context manager closes the file handle once the converted copy exists.
    with PILImage.open(img_path) as img:
        gray = img.convert('L')
    return transform(gray)

In [None]:
def load_and_convert_images(image_dir):
    """Convert every regular file in `image_dir` to a tensor with `convert_to_tensor`.

    Files are visited in sorted name order so the resulting list (and anything
    stacked from it) is reproducible, since os.listdir order is arbitrary.
    Sub-directories are skipped because they cannot be opened as images.

    Returns:
        A list of tensors, one per file.
    """
    tensors = []
    for img_name in sorted(os.listdir(image_dir)):
        img_path = os.path.join(image_dir, img_name)
        if not os.path.isfile(img_path):
            continue
        tensors.append(convert_to_tensor(img_path))
    return tensors

In [None]:
import torchvision.transforms as T
# Load every processed image from the Drive copy as a 64x64 grayscale tensor.
# NOTE(review): these tensors only feed `preprocessed_dataset`; training reads
# images through ImageFolder/train_dl instead. Confirm this step is needed.
preprocessed_tensors = load_and_convert_images(destination_dir)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# Stack the per-image tensors into one batch tensor wrapped as a dataset.
# NOTE(review): not referenced by the training cells in this notebook.
preprocessed_dataset = TensorDataset(torch.stack(preprocessed_tensors))

In [None]:
# Root folder for ImageFolder. Each sub-folder (e.g. output_brain) becomes one class.
DATA_DIR = '/content/drive/MyDrive/out'
print(os.listdir(DATA_DIR))

In [None]:
# Training image size and batch size.
image_size = 64
batch_size = 128
# (mean, std) per RGB channel for T.Normalize: maps ToTensor's [0, 1] range to
# [-1, 1], the same range as the generator's Tanh output.
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

In [None]:
import torchvision.transforms as T

In [None]:
from torchvision.datasets import ImageFolder

# Real training images: resize the short side, centre-crop to a square, and
# normalize to [-1, 1] with `stats`.
train_ds = ImageFolder(DATA_DIR, transform=T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.ToTensor(),
    T.Normalize(*stats),
]))
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)

In [None]:
def denorm(img_tensors):
    """Undo T.Normalize(*stats), mapping normalized tensors back to the [0, 1] range.

    The first channel's mean/std are used for every channel. This is exact only
    while all entries of `stats` are equal, as they are (0.5) in this notebook.
    """
    return img_tensors * stats[1][0] + stats[0][0]

In [None]:
def show_images(images, nmax=64):
    """Display up to `nmax` images from a normalized batch as an 8-column grid.

    `images` is a batch tensor as accepted by make_grid. It is detached, moved
    to the CPU (matplotlib cannot read GPU tensors) and denormalized before
    plotting.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(denorm(images.detach().cpu()[:nmax]), nrow=8)
    ax.imshow(grid.permute(1, 2, 0))


def show_batch(dl, nmax=64):
    """Show the first batch yielded by `dl`, which yields (images, labels) pairs."""
    for images, _ in dl:
        show_images(images, nmax)
        break

In [None]:
# Preview one batch of real (denormalized) training images.
show_batch(train_dl)

In [None]:
# DCGAN-style discriminator: a (N, 3, 64, 64) image batch -> (N, 1) probabilities
# of being real. Spatial size per stage: 64 -> 32 -> 16 -> 8 -> 4 -> 1.
discriminator = nn.Sequential(
    # 3 x 64 x 64 -> 64 x 32 x 32
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),

    # -> 128 x 16 x 16
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),

    # -> 256 x 8 x 8
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),

    # -> 512 x 4 x 4
    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True),

    # -> 1 x 1 x 1, then flatten to (N, 1) and squash to a probability.
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),

    nn.Flatten(),
    nn.Sigmoid()
)

In [None]:
# Move the discriminator to the selected compute device.
discriminator = to_device(discriminator, device)

In [None]:
from torchsummary import summary

In [None]:
# Dimensionality of the generator's input noise vector.
latent_size = 128

In [None]:
# DCGAN-style generator: (N, latent_size, 1, 1) noise -> (N, 3, 64, 64) images
# in [-1, 1] (Tanh). Spatial size per stage: 1 -> 4 -> 8 -> 16 -> 32 -> 64.
generator = nn.Sequential(
    # latent_size x 1 x 1 -> 512 x 4 x 4
    nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),

    # -> 256 x 8 x 8
    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),

    # -> 128 x 16 x 16
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),

    # -> 64 x 32 x 32
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),

    # -> 3 x 64 x 64
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
)

In [None]:
# Random latent batch for a quick shape check of the untrained generator.
xb = torch.randn(batch_size, latent_size, 1, 1)
print(xb.shape)

In [None]:
# Forward pass of the untrained generator on CPU (it is moved to `device` in a
# later cell) to check the output shape and look at the initial noise images.
# no_grad: only the images are needed, not an autograd graph.
with torch.no_grad():
    fake_images = generator(xb)
print(fake_images.shape)
show_images(fake_images)

In [None]:
# Move the generator to the selected device before training; the sanity-check
# forward pass above ran on the CPU.
generator=to_device(generator,device)

In [None]:
from torchinfo import summary

In [None]:
def train_discriminator(real_images, opt_d):
    """Run one optimisation step of the discriminator.

    Real images are labelled 1 and generated images 0. The loss is the sum of
    the two binary cross-entropy terms.

    Args:
        real_images: batch of real images (moved to `device` here).
        opt_d: optimizer over the discriminator's parameters.

    Returns:
        (loss, real_score, fake_score) as Python floats. The scores are the
        mean discriminator outputs on the real and the generated batch.
    """
    opt_d.zero_grad()

    real_images = real_images.to(device)
    n_real = real_images.size(0)

    real_preds = discriminator(real_images)
    real_targets = torch.ones(n_real, 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Match the fake batch size to the real one, since the last DataLoader batch
    # can be smaller than batch_size. Build no graph through the generator: only
    # the discriminator is updated here, so backpropagating into the generator
    # was wasted work.
    latent = torch.randn(n_real, latent_size, 1, 1, device=device)
    with torch.no_grad():
        fake_images = generator(latent)

    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score

In [None]:
def train_generator(opt_g):
    """Run one optimisation step of the generator.

    A fresh batch of `batch_size` latent vectors is decoded and scored by the
    discriminator. The generator is pushed to make those scores look real
    (BCE against target 1).

    Returns:
        The generator loss as a Python float.
    """
    opt_g.zero_grad()
    noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
    scores = discriminator(generator(noise))
    wanted = torch.ones(batch_size, 1, device=device)
    g_loss = F.binary_cross_entropy(scores, wanted)
    g_loss.backward()
    opt_g.step()
    return g_loss.item()

In [None]:
# Output folder for the sample grids written by save_samples.
sample_dir = 'generated'
pathlib.Path(sample_dir).mkdir(parents=True, exist_ok=True)

In [None]:
def save_samples(index, latent_tensors, show=True):
    """Generate images from `latent_tensors`, save them as a grid PNG, and optionally display it.

    The file goes to `sample_dir` as generated-images-XXXX.png, where XXXX is
    `index` zero-padded to four digits.

    Args:
        index: number used in the file name (fit passes epoch + start_idx).
        latent_tensors: latent batch fed to the generator.
        show: if True, also display the grid inline.
    """
    # Inference only; no autograd graph is needed for the samples.
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([])
        ax.set_yticks([])
        # Denormalize for display as the saved file does. Otherwise the Tanh range
        # [-1, 1] is clipped by imshow and the preview differs from the PNG.
        ax.imshow(make_grid(denorm(fake_images.cpu()), nrow=8).permute(1, 2, 0))

In [None]:
# 64 latent vectors (one 8x8 grid) reused for every saved sample, so the grids
# are comparable across epochs.
fixed_latent= torch.randn(64,latent_size,1,1,device=device)

In [None]:
# Baseline samples from the untrained generator (index 0).
save_samples(0,fixed_latent)

In [None]:
# NOTE(review): fit() saves with index epoch + start_idx (1..epochs with the
# defaults), so this pre-training file is overwritten once training reaches
# epoch 10.
save_samples(10,fixed_latent)

In [None]:
from tqdm.notebook import tqdm
import torch.nn.functional as F

In [None]:
def fit(epochs, lr, start_idx=1):
    """Train the GAN and return per-epoch training curves.

    Both networks use Adam with betas (0.5, 0.999) and the same learning rate.
    For every batch of `train_dl`, one discriminator step is followed by one
    generator step. After each epoch, a sample grid from `fixed_latent` is
    saved with index `epoch + start_idx`.

    Args:
        epochs: number of passes over `train_dl`.
        lr: learning rate for both optimizers.
        start_idx: offset added to the epoch number in sample file names.

    Returns:
        (losses_g, losses_d, real_scores, fake_scores): lists with one entry
        per epoch, each the mean over that epoch's batches.

    Raises:
        ValueError: if `train_dl` yields no batches.
    """
    torch.cuda.empty_cache()
    losses_g, losses_d, real_scores, fake_scores = [], [], [], []

    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        # Collect every batch so the epoch summary is a mean. Recording only the
        # last batch was a very noisy estimate and undefined for an empty loader.
        epoch_g, epoch_d, epoch_real, epoch_fake = [], [], [], []
        for real_images, _ in tqdm(train_dl):
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            loss_g = train_generator(opt_g)
            epoch_g.append(loss_g)
            epoch_d.append(loss_d)
            epoch_real.append(real_score)
            epoch_fake.append(fake_score)

        if not epoch_g:
            raise ValueError("train_dl produced no batches; check DATA_DIR and the dataset.")

        loss_g = sum(epoch_g) / len(epoch_g)
        loss_d = sum(epoch_d) / len(epoch_d)
        real_score = sum(epoch_real) / len(epoch_real)
        fake_score = sum(epoch_fake) / len(epoch_fake)

        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)
        print("Epoch[{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch + 1, epochs, loss_g, loss_d, real_score, fake_score))

        save_samples(epoch + start_idx, fixed_latent, show=False)
    return losses_g, losses_d, real_scores, fake_scores


In [None]:
# Adam learning rate shared by both networks, and number of training epochs.
lr, epochs = 0.0002, 25

In [None]:
# Train for `epochs` epochs; fit returns (losses_g, losses_d, real_scores, fake_scores).
history= fit(epochs,lr)

In [None]:
# Unpack the per-epoch training curves returned by fit().
losses_g,losses_d,real_scores,fake_scores=history

In [None]:
# Save both models' state_dicts to the current working directory.
# NOTE(review): this is not the mounted Drive, so the files are presumably lost
# when the Colab runtime resets.
torch.save(generator.state_dict(),'Generator.pth')
torch.save(discriminator.state_dict(),'Discriminator.pth')

In [None]:
from IPython.display import Image

In [None]:
# Samples after the first epoch. `display.Image` is IPython.display.Image,
# referenced through the module so it cannot be confused with PIL's Image.
display.Image('./generated/generated-images-0001.png')

In [None]:
import json
import os

drive_dir = '/content/drive/MyDrive/GAN_metrics2/'
os.makedirs(drive_dir, exist_ok=True)

# Persist each training curve as its own JSON file on Drive.
metric_files = {
    'losses_g.json': losses_g,
    'losses_d.json': losses_d,
    'real_scores.json': real_scores,
    'fake_scores.json': fake_scores,
}
for file_name, values in metric_files.items():
    with open(drive_dir + file_name, 'w') as out_file:
        json.dump(values, out_file)

In [None]:
import json
import matplotlib.pyplot as plt

drive_dir = '/content/drive/MyDrive/GAN_metrics2/'


def _read_metric(file_name):
    """Load one JSON-encoded metric list from `drive_dir`."""
    with open(drive_dir + file_name, 'r') as handle:
        return json.load(handle)


losses_g = _read_metric('losses_g.json')
losses_d = _read_metric('losses_d.json')
real_scores = _read_metric('real_scores.json')
fake_scores = _read_metric('fake_scores.json')

epochs = len(losses_g)
def plot_training_metrics(losses_g, losses_d, real_scores, fake_scores, epochs):
    """Plot the four per-epoch GAN training curves on one set of axes.

    Args:
        losses_g, losses_d: generator / discriminator losses, one value per epoch.
        real_scores, fake_scores: mean discriminator outputs on real / fake images.
        epochs: number of recorded epochs; sets the x-axis range.
    """
    plt.figure(figsize=(10, 5))
    plt.title("GAN Training Metrics")
    plt.plot(losses_g, label="Generator Loss")
    plt.plot(losses_d, label="Discriminator Loss")
    plt.plot(real_scores, label="Real Image Scores")
    plt.plot(fake_scores, label="Fake Image Scores")
    plt.xlabel("Epoch")
    plt.ylabel("Score/Loss")
    # The x-range follows `epochs`; it previously ignored the parameter in
    # favour of a hard-coded 25. The y-range autoscales instead of using fixed
    # limits tuned to one particular run.
    plt.xlim(-1, epochs)
    plt.legend()
    plt.show()
plot_training_metrics(losses_g, losses_d, real_scores, fake_scores, epochs)

In [None]:
# Samples after the final (25th) epoch, shown via IPython.display.Image.
display.Image('./generated/generated-images-0025.png')