In [None]:
!pip install tensorflow==2.10.1
!pip install patchify

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import random
import itertools
import tensorflow as tf
from matplotlib import gridspec
from PIL import Image
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from patchify import patchify
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras import callbacks
from sklearn.metrics import confusion_matrix

In [None]:
# Hyperparameters shared by the data pipeline and the ViT model.
hp = {
    'image_size': 512,   # images are resized to image_size x image_size
    'num_channels': 3,
    'patch_size': 64,    # side length of one square patch
}

# Values derived from the sizes above: how many patches one image yields and
# the (num_patches, flattened patch length) shape of one preprocessed image.
hp['num_patches'] = hp['image_size'] ** 2 // hp['patch_size'] ** 2
hp['flat_patches_shape'] = (
    hp['num_patches'],
    hp['patch_size'] * hp['patch_size'] * hp['num_channels'],
)

# Training and architecture settings.
hp.update({
    'batch_size': 32,
    'lr': 2e-5,
    'num_epochs': 30,
    'num_classes': 3,
    'num_layers': 12,
    'hidden_dim': 768,
    'mlp_dim': 3072,
    'num_heads': 12,
    'dropout_rate': 0.1,
    'class_names': ["lung_aca", "lung_n", "lung_scc"],
})

In [None]:
# Glob patterns for the three dataset splits. The trailing `*` matches the
# entries inside each split folder; load_data() appends "*.jpeg" to look for
# images inside them, and the name of that parent folder is used as the class.
train_path = "/kaggle/input/lung-and-colon-cancer-dataset-splitted/lung/lung/Train/*"
valid_path = "/kaggle/input/lung-and-colon-cancer-dataset-splitted/lung/lung/Val/*"
test_path = "/kaggle/input/lung-and-colon-cancer-dataset-splitted/lung/lung/Test/*"

# Output files: the checkpoint written by ModelCheckpoint and the CSVLogger log.
model_path = "/kaggle/working/ViT_for_lung_cancer_classification.h5"
csv_path = "/kaggle/working/ViT_for_lung_cancer_classification.csv"

# Helper Functions

In [None]:
def create_dir(path):
    """Create the directory `path` (and any missing parents) if it does not exist.

    `exist_ok=True` makes the call idempotent without a separate existence
    check, so nothing can create the directory between a check and the creation.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If `path` already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [None]:
def load_data(path, split=0.1, random_state=None):
    """Return the shuffled list of ``*.jpeg`` files matched under `path`.

    Args:
        path: Directory or glob pattern; ``"*.jpeg"`` is joined onto it before
            globbing.
        split: Unused. Kept only so existing callers that pass it keep working.
        random_state: Optional seed forwarded to ``sklearn.utils.shuffle``.
            Pass an int for a reproducible order; the default ``None`` keeps
            the shuffle unseeded.

    Returns:
        list[str]: The matching file paths in shuffled order.
    """
    # glob returns files in arbitrary, filesystem-dependent order; sorting first
    # makes a seeded shuffle reproducible across machines.
    images = sorted(glob(os.path.join(path, "*.jpeg")))
    return shuffle(images, random_state=random_state)
    

In [None]:
def process_image_label(path):
    """Load one image, cut it into flattened patches and derive its class index.

    Args:
        path: Image file path, either `bytes` (what `tf.numpy_function` passes)
            or `str` (direct calls). The name of the file's parent directory
            must be one of ``hp['class_names']``.

    Returns:
        tuple: ``(patches, class_idx)`` where `patches` is a float32 array of
        shape ``hp['flat_patches_shape']`` holding pixel values divided by 255,
        and `class_idx` is an int32 scalar array indexing ``hp['class_names']``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
        ValueError: If the parent directory name is not in ``hp['class_names']``.
    """
    # tf.numpy_function hands over bytes, while direct callers usually pass str.
    if isinstance(path, bytes):
        path = path.decode()

    image = cv2.imread(path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")
    image = cv2.resize(image, (hp['image_size'], hp['image_size']))
    image = image / 255.0

    # Step == patch size, so the patches tile the image without overlap.
    patch_shape = (hp['patch_size'], hp['patch_size'], hp['num_channels'])
    patches = patchify(image, patch_shape, hp['patch_size'])

    # One row per patch. Reshaping a row back to `patch_shape` recovers the
    # patch image, which is handy for visual inspection.
    patches = np.reshape(patches, hp['flat_patches_shape'])
    patches = patches.astype(np.float32)

    # The class is encoded in the name of the directory that holds the file.
    class_name = path.split("/")[-2]
    class_idx = np.array(hp['class_names'].index(class_name), dtype=np.int32)

    return patches, class_idx

In [None]:
def parse(path):
    """Turn an image-path tensor into ``(patches, one_hot_label)`` tensors.

    The NumPy/OpenCV preprocessing in `process_image_label` is wrapped with
    `tf.numpy_function` so it can run inside a `tf.data` pipeline.
    """
    output_types = [tf.float32, tf.int32]
    patches, class_idx = tf.numpy_function(process_image_label, [path], output_types)
    one_hot_label = tf.one_hot(class_idx, hp['num_classes'])

    # Declare the static shapes explicitly for the tensors returned above.
    patches.set_shape(hp['flat_patches_shape'])
    one_hot_label.set_shape(hp['num_classes'])

    return patches, one_hot_label

In [None]:
def tf_dataset(images, batch_size=32):
    """Build a batched, prefetching `tf.data.Dataset` from a list of image paths.

    Args:
        images: Sequence of image file paths. Element order is preserved, so
            predictions can be matched back to the input list.
        batch_size: Number of samples per batch.

    Returns:
        tf.data.Dataset yielding ``(patches, one_hot_labels)`` batches as
        produced by `parse`.
    """
    ds = tf.data.Dataset.from_tensor_slices(images)
    ds = ds.map(parse)
    # Batch with the caller-supplied size rather than a fixed constant.
    ds = ds.batch(batch_size).prefetch(8)
    return ds

# Data Preprocessing

In [None]:
# Collect the shuffled image paths of each split (train, validation, test).
train_x, valid_x, test_x = (
    load_data(split_path) for split_path in (train_path, valid_path, test_path)
)

In [None]:
# Report how many image paths were found for each split.
print(f'Train:{len(train_x)}  Valid:{len(valid_x)}  Test:{len(test_x)}')

In [None]:
# Folder that the patch-visualisation section below reads its images from.
create_dir("/kaggle/working/file")

In [None]:
# Preprocess one training image and write its patches as PNG files so the
# visualisation section below has something to display.
# The path is encoded to bytes because that is the form tf.numpy_function
# passes to process_image_label inside the data pipeline.
img = process_image_label(train_x[0].encode())

patch_dir = "/kaggle/working/file"
create_dir(patch_dir)

# Undo the flattening: one (patch_size, patch_size, num_channels) image per row.
sample_patches = img[0].reshape(-1, hp['patch_size'], hp['patch_size'], hp['num_channels'])
for patch_idx, patch in enumerate(sample_patches):
    # Patches hold pixel values divided by 255; rescale to 8-bit before writing.
    patch_u8 = np.clip(np.rint(patch * 255.0), 0, 255).astype(np.uint8)
    cv2.imwrite(os.path.join(patch_dir, f"image_{patch_idx}.png"), patch_u8)

## Visualize the patches

In [None]:
# Folder expected to hold the per-patch images (image_<index>.png).
image_folder = "/kaggle/working/file"

# os.listdir returns names in arbitrary order. Sorting by (length, name) puts
# image_2.png before image_10.png, so the grid follows the patch order.
image_files = sorted(os.listdir(image_folder), key=lambda name: (len(name), name))

# One grid cell per patch: image_size // patch_size patches along each side.
grid_side = hp['image_size'] // hp['patch_size']
fig = plt.figure(figsize=(10, 10))
gs = gridspec.GridSpec(grid_side, grid_side)

for i, img_file in enumerate(image_files[:grid_side * grid_side]):
    ax = fig.add_subplot(gs[i])
    # The context manager closes the file handle once the pixels are drawn.
    with Image.open(os.path.join(image_folder, img_file)) as patch_img:
        ax.imshow(patch_img)
    ax.axis("off")

plt.tight_layout()
plt.show()


## Data pipeline

In [None]:
# Wrap each split's path list in a batched tf.data pipeline.
train_ds, valid_ds, test_ds = (
    tf_dataset(split_paths, batch_size=hp['batch_size'])
    for split_paths in (train_x, valid_x, test_x)
)

In [None]:
# Display the dataset object as the cell output.
train_ds

In [None]:
# Pull a single batch and print the shapes of its patch and label tensors.
for batch_patches, batch_labels in train_ds.take(1):
    print(batch_patches.shape, batch_labels.shape)

# Model | ViT

In [None]:
class ClassToken(layers.Layer):
    """Learnable classification token, repeated once per batch item.

    The layer only produces the token tensor of shape ``(batch, 1, hidden_dim)``;
    the caller is responsible for concatenating it with the patch embeddings.
    """

    def __init__(self, **kwargs):
        # Forward the standard Keras layer arguments (name, dtype, ...).
        super().__init__(**kwargs)

    def build(self, input_shape):
        # add_weight registers the variable with the layer under a stable name.
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, input_shape[-1]),
            initializer=tf.random_normal_initializer(),
            dtype=tf.float32,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # The batch size is only known at run time, so read it dynamically.
        batch_size = tf.shape(inputs)[0]
        hidden_dim = self.w.shape[-1]

        # Repeat the single token along the batch axis.
        cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
        # Match the dtype of the incoming embeddings so they can be concatenated.
        return tf.cast(cls, dtype=inputs.dtype)

In [None]:
def mlp(x, cf):
    """Feed-forward sub-block of a transformer encoder layer.

    Applies Dense(``mlp_dim``, GELU) -> Dropout -> Dense(``hidden_dim``) -> Dropout,
    reading the sizes and the dropout rate from the config mapping `cf`.
    """
    rate = cf['dropout_rate']
    hidden = layers.Dense(cf['mlp_dim'], activation='gelu')(x)
    hidden = layers.Dropout(rate)(hidden)
    projected = layers.Dense(cf['hidden_dim'])(hidden)
    return layers.Dropout(rate)(projected)

In [None]:
def transformer_encoder(x, cf):
    """One pre-norm transformer encoder layer: self-attention then MLP, each with a residual.

    Args:
        x: Input sequence tensor whose last dimension is ``cf['hidden_dim']``.
        cf: Config mapping. Reads ``num_heads``, ``hidden_dim`` and the keys used
            by `mlp`. An optional ``key_dim`` entry sets the per-head size of
            the attention layer.

    Returns:
        Tensor with the same shape as `x`.
    """
    # Keras interprets `key_dim` as the size of EACH attention head. When the
    # config has no 'key_dim', the full hidden_dim is used per head, which keeps
    # existing checkpoints loadable.
    # NOTE(review): the usual ViT choice is hidden_dim // num_heads; set
    # cf['key_dim'] to that value to get the standard, much smaller attention block.
    key_dim = cf.get('key_dim', cf['hidden_dim'])

    skip_1 = x
    x = layers.LayerNormalization()(x)
    x = layers.MultiHeadAttention(num_heads=cf['num_heads'], key_dim=key_dim)(x, x)
    x = layers.Add()([x, skip_1])

    skip_2 = x
    x = layers.LayerNormalization()(x)
    x = mlp(x, cf)
    x = layers.Add()([x, skip_2])

    return x

In [None]:
class PatchPositionEmbedding(layers.Layer):
    """Adds a learned position embedding to a sequence of patch embeddings.

    Doing the lookup inside a layer keeps the embedding table part of the model,
    so it is trained with it and stored in its weights.
    """

    def __init__(self, num_patches, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.hidden_dim = hidden_dim
        self.embedding = layers.Embedding(input_dim=num_patches, output_dim=hidden_dim)

    def call(self, inputs):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # (num_patches, hidden_dim) broadcasts over the batch axis of `inputs`.
        return inputs + self.embedding(positions)

    def get_config(self):
        # Required so the model can be serialised (e.g. by ModelCheckpoint),
        # because __init__ takes arguments.
        config = super().get_config()
        config.update({"num_patches": self.num_patches, "hidden_dim": self.hidden_dim})
        return config


def ViT(cf):
    """Build the Vision Transformer classifier described by the config `cf`.

    Args:
        cf: Config mapping. Reads ``num_patches``, ``patch_size``,
            ``num_channels``, ``hidden_dim``, ``num_layers``, ``num_classes``
            and the keys used by `transformer_encoder`.

    Returns:
        A Keras `Model` mapping flattened patches of shape
        ``(batch, num_patches, patch_size * patch_size * num_channels)`` to
        softmax class probabilities of shape ``(batch, num_classes)``.
    """
    patch_dim = cf['patch_size'] * cf['patch_size'] * cf['num_channels']
    inputs = layers.Input((cf['num_patches'], patch_dim))  # (None, num_patches, patch_dim)

    # Linear projection of every flattened patch.
    patch_embed = layers.Dense(cf['hidden_dim'])(inputs)  # (None, num_patches, hidden_dim)

    # The position embedding is applied through a layer on the symbolic tensor.
    # Calling Embedding directly on a constant tf.range while building a
    # functional model runs eagerly, which would bake the randomly initialised
    # table into the graph as a constant instead of a trainable, saved weight.
    embed = PatchPositionEmbedding(cf['num_patches'], cf['hidden_dim'])(patch_embed)

    # Prepend the learnable class token.
    token = ClassToken()(embed)
    x = layers.Concatenate(axis=1)([token, embed])  # (None, num_patches + 1, hidden_dim)

    for _ in range(cf['num_layers']):
        x = transformer_encoder(x, cf)

    # Classify from the final representation of the class token (position 0).
    x = layers.LayerNormalization()(x)
    x = x[:, 0, :]
    x = layers.Dense(cf['num_classes'], activation='softmax')(x)

    return Model(inputs, x)

In [None]:
# Build the model from the hyperparameter dict and print its layer summary.
model = ViT(hp)
model.summary()

In [None]:
# Labels are one-hot encoded, hence categorical cross-entropy. Adam runs with
# the configured learning rate and clips each gradient value at 1.0.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=hp['lr'], clipvalue=1.0),
    loss='categorical_crossentropy',
    metrics=['acc'],
)

In [None]:
# Training callbacks; all of the monitoring ones watch the validation loss.
cbacks = [
    # Write the model to `model_path` whenever val_loss improves.
    callbacks.ModelCheckpoint(
        model_path, monitor='val_loss', verbose=1, save_best_only=True
    ),
    # Multiply the learning rate by 0.1 after 7 epochs without improvement.
    callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.1, patience=7, min_lr=2e-7
    ),
    # Append per-epoch metrics to `csv_path`.
    callbacks.CSVLogger(csv_path),
    # NOTE(review): with patience=20 this can only trigger in runs longer than 20 epochs.
    callbacks.EarlyStopping(
        monitor='val_loss', patience=20, restore_best_weights=False
    ),
]

# Training

In [None]:
# Train for the number of epochs set in the hyperparameter dict, so the config
# cell is the single place to change it. The History object is kept so the
# training curves can be inspected afterwards.
history = model.fit(
    train_ds,
    epochs=hp['num_epochs'],
    validation_data=valid_ds,
    callbacks=cbacks,
)

# Testing the Model

In [None]:
# Rebuild the architecture and load the weights from the checkpoint file that
# ModelCheckpoint writes. `model_path` is reused so the save and load locations
# cannot drift apart.
saved_model = ViT(hp)
saved_model.load_weights(model_path)
# Compiling is required before evaluate(); the optimizer is not used for inference.
saved_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    optimizer=tf.keras.optimizers.Adam(hp["lr"]),
    metrics=["acc"],
)

In [None]:
# Loss and accuracy of the reloaded model on the test split.
saved_model.evaluate(test_ds)

In [None]:
#plot confusion matrix
def plt_confusion_matrix(cm, classes, normalize=False, title="Confusion Matrix", cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on the current figure.

    Args:
        cm: 2-D confusion matrix with true labels on rows and predicted labels
            on columns.
        classes: Class names used as tick labels, in matrix order.
        normalize: If True, every row is divided by its sum before plotting, so
            the colours and the annotations both show per-class fractions.
        title: Plot title.
        cmap: Matplotlib colormap for the heat map.
    """
    cm = np.asarray(cm)

    # Normalise BEFORE drawing so the image and the text show the same values.
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # Rows without any samples stay at zero instead of dividing by zero.
        cm = np.divide(
            cm.astype('float'),
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        print("normalized confusion matrix")
    else:
        print("confusion matrix without normalization")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_mark = np.arange(len(classes))
    plt.xticks(tick_mark, classes, rotation=45)
    plt.yticks(tick_mark, classes)

    # White text on dark cells, black text on light cells.
    thresh = cm.max() / 2 if cm.size else 0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_text = f"{cm[i, j]:.2f}" if normalize else cm[i, j]
        plt.text(j, i, cell_text, horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    # Label the axes once, then let tight_layout make room for the labels.
    plt.xlabel("predicted label")
    plt.ylabel("True label")
    plt.tight_layout()

In [None]:
# Class probabilities predicted by the reloaded model for every test image.
prediction = saved_model.predict(test_ds, verbose=0)

In [None]:
# Round every predicted probability to the nearest integer for a quick look.
# The result is only displayed as the cell output; `prediction` is unchanged.
np.around(prediction)

In [None]:
# Index of the highest-probability class for each test sample.
y_pred_classes = np.argmax(prediction, axis=1)

In [None]:
# Derive the ground-truth class index of every image from its file path.
def get_test_data_class(test_path, class_names=None):
    """Return the class index encoded in the parent folder name of each path.

    Args:
        test_path: Iterable of image paths of the form ``.../<class_name>/<file>``.
        class_names: Ordered list of class names; the position of a name is its
            class index. Defaults to ``hp['class_names']``.

    Returns:
        np.ndarray: int32 array with one class index per path, in input order.

    Raises:
        ValueError: If a path's parent folder name is not in `class_names`.
    """
    if class_names is None:
        class_names = hp['class_names']
    indices = [class_names.index(image_path.split("/")[-2]) for image_path in test_path]
    return np.array(indices, dtype=np.int32)

In [None]:
# Ground-truth class indices in the order of test_x, which is the same list
# test_ds was built from, so they line up with the predictions.
classes = get_test_data_class(test_x)

In [None]:
# Confusion matrix of true class indices against predicted class indices.
cm = confusion_matrix(y_true=classes, y_pred=y_pred_classes)

In [None]:
# Plot the raw (non-normalised) counts with the class names as tick labels.
plt_confusion_matrix(cm=cm, classes=hp['class_names'], title="confusion matrix", )