In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

# Importing modules

In [2]:
import os
import zipfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import albumentations as A
import cv2
import json

from skimage.io import imread
from skimage.transform import resize
from pathlib import Path
from tensorflow import keras
from tensorflow.keras import layers
from keras.utils import to_categorical
from mlxtend.plotting import plot_confusion_matrix
from sklearn.metrics import confusion_matrix
from PIL import Image

In [3]:
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  1


# Preparing the dataset

In [None]:
# Sizes (in samples) of the held-out test and validation splits.
count_Test, count_Valid = 1164, 1164

In [None]:
# Root of the dataset, relative to the notebook's working directory.
# Colab location: '/content/drive/Othercomputers/Ноутбук/3/Code/chest-xray-pneumonia/chest_xray'
data_dir = Path('chest-xray-pneumonia/chest_xray')

# One sub-directory per predefined split: train, validation and test.
train_dir, val_dir, test_dir = (data_dir / split for split in ('train', 'val', 'test'))

# Forming Train, Val and Test

In [None]:
SEED = 42  # fixed seed so the Train/Val/Test split is the same on every run


def _collect_cases(split_dir):
    """Return (image_path, label, add) records for one split directory.

    Labels: 0 for images under NORMAL, 1 for images under PNEUMONIA whose path
    contains 'bacteria', 2 for every other PNEUMONIA image (treated as viral).
    The third field ('add') is always 0 here.
    """
    # glob() order is filesystem-dependent; sort so the seeded shuffle below is reproducible.
    records = [(img, 0, 0) for img in sorted((split_dir / 'NORMAL').glob('*.jpeg'))]
    for img in sorted((split_dir / 'PNEUMONIA').glob('*.jpeg')):
        label = 1 if 'bacteria' in str(img).lower() else 2
        records.append((img, label, 0))
    return records


# Pool the three predefined splits into one collection; it is re-split below.
records = []
for split_dir in (train_dir, val_dir, test_dir):
    records.extend(_collect_cases(split_dir))

# Delete duplicates
#'/content/drive/Othercomputers/Ноутбук/3/Code/dublicate.json'
with open('dublicate.json', 'r') as F:
    check_list = json.load(F)

# Each entry of check_list is a pair; the file name of its first element is dropped.
# Empty names are skipped because '' is a substring of every path.
duplicate_names = {first[first.rfind('/') + 1:] for first, _ in check_list}
duplicate_names.discard('')

# Filter in one pass instead of popping by index: popping shifts the positions of
# all later items, so indices computed beforehand would remove the wrong records.
records = [
    record for record in records
    if not any(name in str(record[0]) for name in duplicate_names)
]

# Get a pandas dataframe from the data we have in our list
all_data = pd.DataFrame(records, columns=['image', 'label', 'add'], index=None)

# Shuffle the data (seeded, so the split below is reproducible)
all_data = all_data.sample(frac=1., random_state=SEED).reset_index(drop=True)

# Divide into Test, Val and Train
test_data = all_data[:count_Test]
valid_data = all_data[count_Test : count_Test + count_Valid].reset_index(drop=True)
train_data = all_data[count_Test + count_Valid :].reset_index(drop=True)

train_data.head()

# Augmentation

In [None]:
# Per-class setting keyed by label (0, 1, 2); presumably the number of extra
# augmented copies to create per sample of that class — TODO confirm.
add_labels = {label: 0 for label in range(3)}

# Augmentation pipeline applied to images.
seq = A.Compose(
    [
        A.HorizontalFlip(),                  # horizontal flip
        A.Affine(rotate=(-20, 20), p=0.50),  # rotation between -20 and 20 degrees
        A.RandomBrightnessContrast(p=0.20),  # random brightness / contrast
    ]
)

# Visualization

In [None]:
# Number of training samples per class label
cases_count = train_data['label'].value_counts()
print(cases_count)

# Bar chart of the class distribution
plt.figure(figsize=(10, 8))
ax = sns.barplot(x=cases_count.index, y=cases_count.values)
ax.set_title('Число образцов', fontsize=14)
ax.set_xlabel('Тип образца', fontsize=12)
ax.set_ylabel('Число', fontsize=12)
ax.set_xticks(range(len(cases_count.index)))
ax.set_xticklabels(['Нормальное состояние (0)', 'Бактериальная пневмония (1)', 'Вирусная пневмония (2)'])
plt.show()

In [None]:
# First five training images of each class
normal_samples = train_data.loc[train_data['label'] == 0, 'image'].head(5).tolist()
bacteria_samples = train_data.loc[train_data['label'] == 1, 'image'].head(5).tolist()
viral_samples = train_data.loc[train_data['label'] == 2, 'image'].head(5).tolist()

# Merge into one list (normal, bacterial, viral) and drop the per-class lists
samples = [*normal_samples, *bacteria_samples, *viral_samples]
del bacteria_samples, normal_samples, viral_samples

# 3 x 5 grid: one row per class, titled by the row's class
row_titles = ("Нормальное состояние", "Бактериальная пневмония", "Вирусная пневмония")
f, ax = plt.subplots(3, 5, figsize=(30, 10))
for idx in range(15):
    row, col = divmod(idx, 5)
    panel = ax[row, col]
    panel.imshow(imread(samples[idx]), cmap='gray')
    panel.set_title(row_titles[row])
    panel.axis('off')
    panel.set_aspect('auto')
plt.show()

# Loading and preparing Val

In [None]:
# Extend the validation frame via `increase`; presumably this adds rows flagged
# add == 1 that are meant to be augmented — TODO confirm against `increase`.
# NOTE: this cell rebinds valid_data from a DataFrame to a numpy array, so the
# split cell has to be re-run before this cell can be run again.
valid_data = increase(valid_data, add_labels=add_labels)

# Convert rows to (image, label) pairs with Conv
if 'seq' not in globals():
    # No augmentation pipeline defined: convert every row as-is.
    valid_pairs = list(map(Conv, valid_data['image'].values, valid_data['label'].values))
else:
    originals = valid_data[valid_data['add'] == 0]
    extras = valid_data[valid_data['add'] == 1]

    valid_pairs = list(map(Conv, originals['image'].values, originals['label'].values))
    # Rows flagged add == 1 are additionally passed through the augmentation pipeline.
    for img, label in map(Conv, extras['image'].values, extras['label'].values):
        valid_pairs.append((seq(image=img)['image'], label))

    # Temporaries exist only in this branch, so they are released here; an
    # unconditional del after the if/else would raise NameError in the other branch.
    del originals, extras

valid = []
valid_labels = []

for img, label in valid_pairs:
    valid.append(img.astype(np.float32) / 255.)  # scale pixel values by 1/255
    valid_labels.append(label)

# Convert the lists into numpy arrays
valid_data = np.array(valid)
valid_labels = np.array(valid_labels)
del valid, valid_pairs


def shuffle(x, y):
    """Return x and y reordered by one shared random permutation of len(y) indices."""
    p = np.random.permutation(len(y))
    return x[p], y[p]


valid_data, valid_labels = shuffle(valid_data, valid_labels)

print("Total number of validation examples: ", valid_data.shape)
print("Total number of labels:", valid_labels.shape)

# MODEL cGAN

In [None]:
# Training / data hyper-parameters
batch_size = 32
image_size = 224
num_channels = 3
num_classes = 3
latent_dim = 128

# Input widths once the class one-hot vector is appended: to the latent noise
# for the generator and (presumably) to the image channels for the discriminator.
generator_in_channels = num_classes + latent_dim
discriminator_in_channels = num_classes + num_channels

print(generator_in_channels, discriminator_in_channels)

# Models

In [None]:
# Print the summaries of both networks.
# NOTE(review): in the visible part of this notebook, Discriminator and Generator
# are first assigned by keras.models.load_model in a later cell, so on a fresh
# kernel this cell needs that cell (or the model definitions) to run first —
# confirm the intended cell order.
Discriminator.summary()
Generator.summary()

## Fit

In [None]:
# Batch generator over the training frame
dataset = data_gen(data=train_data, batch_size=batch_size, Aug=seq, add_labels=add_labels)

nb_epochs = 300

# Steps per epoch: every class contributes its sample count times
# (add_labels[label] + 1), and the total is floor-divided by the batch size.
nb_train_steps = sum((add_labels[label] + 1) * cases_count[label] for label in range(3))
nb_train_steps //= batch_size

print("Number of training and validation steps: {}".format(nb_train_steps))

In [None]:
# Directory holding the saved networks. Like the dataset and dublicate.json paths
# earlier in the notebook, prefer a location relative to the working directory;
# fall back to the original Google Drive location (which needs the Drive mount).
models_dir = Path('train_models')
if not models_dir.is_dir():
    models_dir = Path('/content/drive/Othercomputers/Ноутбук/3/Code/train_models')

# Load both networks from their HDF5 files.
Discriminator = keras.models.load_model(str(models_dir / 'Discriminator.hdf5'))
Generator = keras.models.load_model(str(models_dir / 'Generator.hdf5'))

In [None]:
# Wrap the two networks in the conditional GAN trainer.
cond_gan = ConditionalGAN(discriminator=Discriminator, generator=Generator, latent_dim=latent_dim)

# Discriminator and generator get separate Adam optimizers with identical
# settings; both are trained against binary cross-entropy.
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

# Expose the Python batch generator as a tf.data pipeline. The declared shapes
# are derived from the hyper-parameter constants instead of repeating 224/3/3
# (assumes data_gen yields batches shaped by the same constants — TODO confirm).
data = tf.data.Dataset.from_generator(
    lambda: data_gen(data=train_data, batch_size=batch_size, Aug=seq, add_labels=add_labels),
    output_types=(tf.float32, tf.float32),
    output_shapes=(
        (batch_size, image_size, image_size, num_channels),
        (batch_size, num_classes),
    ),
)

In [None]:
# Train the conditional GAN and keep the returned history for the loss plots.
history_cGAN = cond_gan.fit(
    data,
    epochs=nb_epochs,
    steps_per_epoch=nb_train_steps,
    callbacks=[CustomCallback()],
)

In [None]:
# Per-epoch generator and discriminator losses recorded during training
loss_g = history_cGAN.history["g_loss"]
loss_d = history_cGAN.history["d_loss"]
epochs = range(1, len(loss_g) + 1)

# (values, line format, legend label) for each curve
curves = (
    (loss_g, "b", "Потери на этапе обучения генератора"),
    (loss_d, "g", "Потери на этапе обучения дискриминатора"),
)
for values, line_format, caption in curves:
    plt.plot(epochs, values, line_format, label=caption)

plt.title("Потери на этапах обучения")
plt.xlabel("Эпохи")
plt.ylabel("Потери")
plt.legend()
plt.show()

## Extract the trained generator and sample images

In [None]:
# Keep a direct handle to the generator held by the conditional GAN wrapper;
# it is used below to produce images.
GAN = cond_gan.generator

In [None]:
def interpolate(gen, examples, numbers):
    """Generate `numbers` images for every class label in `examples`.

    Despite the name, no latent-space interpolation is performed: each image
    gets its own independent noise sample.

    Args:
        gen: model whose `predict` receives rows of [noise | one-hot label],
            i.e. width latent_dim + num_classes (both read from module scope).
        examples: sequence of integer class labels, of any length.
        numbers: how many images to generate per entry of `examples`.

    Returns:
        The result of `gen.predict` for the len(examples) * numbers inputs,
        ordered label by label as given in `examples`.
    """
    # Sample one noise vector per requested image.
    interpolation_noise = tf.random.normal(shape=(numbers * len(examples), latent_dim))

    # Repeat every label `numbers` times (a scalar repeat count works for any
    # number of labels, not just three), then one-hot encode.
    labels = np.repeat(examples, numbers)
    one_hot_labels = keras.utils.to_categorical(labels, num_classes)
    # Make the dtype explicit so both concat inputs are float32.
    one_hot_labels = np.asarray(one_hot_labels, dtype=np.float32)

    # Combine the noise and the labels and run inference with the generator.
    noise_and_labels = tf.concat([interpolation_noise, one_hot_labels], 1)
    fake = gen.predict(noise_and_labels)
    return fake


# Generate five images for each of the class labels 0, 1 and 2.
examples = [0, 1, 2]

fake_images = interpolate(GAN, examples, 5)

In [None]:
print(fake_images.shape)

# 3 x 5 grid of generated images: one row per class, five images per row
class_titles = ("Normal", "Bacteria", "Viral")
f, ax = plt.subplots(3, 5, figsize=(30, 20))
for row, axes_row in enumerate(ax):
    for col, panel in enumerate(axes_row):
        panel.imshow(fake_images[row * 5 + col], cmap='gray')
        panel.set_title(class_titles[row])
        panel.axis('off')
        panel.set_aspect('auto')
plt.show()