In [None]:
from PIL import Image, ImageFilter
import tensorflow as tf
from tensorflow.keras import datasets, layers, models, optimizers
import matplotlib.pyplot as plt
from pathlib import Path
import random
import os
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

In [None]:
# Get image file paths and shuffle
SHUFFLE_SEED = 42


def list_shuffled_files(directory, rng):
    """Return the paths of all entries in `directory`, shuffled with `rng`.

    Parameters
    ----------
    directory : str
        Directory to list; each entry name is joined onto it.
    rng : random.Random
        Generator used for the in-place shuffle.

    Returns
    -------
    list of str
        Shuffled paths, one per directory entry.
    """
    # os.listdir returns names in arbitrary order, so sort first: together
    # with a seeded generator this makes the shuffled order reproducible.
    paths = [os.path.join(directory, name) for name in sorted(os.listdir(directory))]
    rng.shuffle(paths)
    return paths


# A dedicated generator keeps the shuffles reproducible without reseeding
# the global `random` module state.
_file_rng = random.Random(SHUFFLE_SEED)

tumour_files = list_shuffled_files("data/100/Invasive_Tumor/", _file_rng)
immune_files = list_shuffled_files("data/100/CD8+_T_Cells/", _file_rng)
other_files = list_shuffled_files("data/100/Myoepi_ACTA2+/", _file_rng)

print(len(tumour_files))
print(len(immune_files))
print(len(other_files))

In [None]:
def load_resize(img_path, size=(50,50)):
    """Read an image file and return it as a resized RGB array.

    Parameters
    ----------
    img_path : path of the image file, handed to PIL's Image.open.
    size : target size passed to Image.resize; defaults to (50, 50).

    Returns
    -------
    numpy array built from the RGB-converted, resized image.
    """
    rgb_image = Image.open(img_path).convert('RGB')
    return np.array(rgb_image.resize(size))

In [None]:
# Load every image of each class as a resized RGB array.
tumour_imgs = [load_resize(f) for f in tumour_files]
immune_imgs = [load_resize(f) for f in immune_files]
other_imgs = [load_resize(f) for f in other_files]

# Per-class split sizes: the first N_TRAIN_PER_CLASS images of each class go
# to training, the next N_TEST_PER_CLASS to testing.
N_TRAIN_PER_CLASS = 5000
N_TEST_PER_CLASS = 1000

imgs_train, y_train = [], []
imgs_test, y_test = [], []

# Class order (Immune, Tumour, Other) matches the original concatenation order.
for class_label, class_imgs in (
    ('Immune', immune_imgs),
    ('Tumour', tumour_imgs),
    ('Other', other_imgs),
):
    train_part = class_imgs[:N_TRAIN_PER_CLASS]
    test_part = class_imgs[N_TRAIN_PER_CLASS:N_TRAIN_PER_CLASS + N_TEST_PER_CLASS]

    imgs_train += train_part
    imgs_test += test_part

    # Slicing silently returns fewer items when a class has too few images,
    # so the labels are sized from the actual slices to stay aligned.
    y_train += [class_label] * len(train_part)
    y_test += [class_label] * len(test_part)

Xmat_train = np.stack(imgs_train, axis=0)
Xmat_test = np.stack(imgs_test, axis=0)

# Every image must have exactly one label.
assert Xmat_train.shape[0] == len(y_train)
assert Xmat_test.shape[0] == len(y_test)

In [None]:
# Define the input shape
input_shape = (50, 50, 3)

def model_function(learning_rate=0.00001, activation=None):
    """Build and compile the 3-class CNN classifier.

    Architecture: two Conv2D(32, 5x5) + MaxPooling2D(2x2) stages, dropout,
    then Dense(128) and Dense(64) (each followed by dropout) and a 3-unit
    softmax output. The model expects inputs of the module-level
    `input_shape`.

    Parameters
    ----------
    learning_rate : float
        Learning rate passed to the Adam optimizer.
    activation : str, callable or None
        Activation applied to the convolutional and hidden dense layers.
        The default None keeps those layers linear, as before.
        NOTE(review): with None the dense stack has no non-linearity between
        layers; consider passing 'relu' - confirm intent.

    Returns
    -------
    A compiled Sequential model using categorical cross-entropy loss and
    reporting accuracy.
    """
    # Drop state left over from previously built models.
    tf.keras.backend.clear_session()

    model = models.Sequential([
        layers.Conv2D(32, (5, 5), activation=activation, input_shape=input_shape),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(32, (5, 5), activation=activation),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(128, activation=activation),
        layers.Dropout(0.5),
        layers.Dense(64, activation=activation),
        layers.Dropout(0.5),
        layers.Dense(3, activation='softmax'),
    ])

    model.compile(
        loss='categorical_crossentropy',
        optimizer=optimizers.Adam(learning_rate=learning_rate),
        metrics=['accuracy']
    )

    return model

In [None]:
print("Begin Model Construction")

# Create the model
model = model_function()
model.summary()

# Training hyper-parameters
batch_size = 64
epochs = 100
num_images = Xmat_train.shape[0]

# One-hot encode the string labels (one column per distinct label).
# The dtype is set explicitly because the default differs between pandas
# versions (bool on pandas >= 2.0, uint8 before); float32 gives the loss
# numeric targets either way.
yMat = pd.get_dummies(y_train, dtype=np.float32).values

In [None]:
print("Begin Model Training")

# validation_split holds out the LAST 10% of the arrays exactly as passed,
# before Keras does any shuffling. Xmat_train is ordered class by class, so
# without this permutation the validation set would contain a single class.
# The same permutation is applied to images and labels to keep them aligned;
# the seed makes the train/validation partition reproducible.
shuffle_idx = np.random.default_rng(42).permutation(len(Xmat_train))

# Fit the model
history = model.fit(
    x=Xmat_train[shuffle_idx],
    y=yMat[shuffle_idx],
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.1,
    verbose=2
)

In [None]:
# Plot training & validation curves side by side
metrics_log = history.history
plt.figure(figsize=(12, 5))

# One entry per panel:
# (subplot position, train key, validation key, train legend, validation legend, title, y-axis label)
panels = [(1, 'loss', 'val_loss', 'Train Loss', 'Val Loss', 'Model Loss', 'Loss')]

# The accuracy panel is drawn only if the accuracy metric was recorded
if 'accuracy' in metrics_log:
    panels.append(
        (2, 'accuracy', 'val_accuracy', 'Train Acc', 'Val Acc', 'Model Accuracy', 'Accuracy')
    )

for position, train_key, val_key, train_label, val_label, panel_title, y_label in panels:
    plt.subplot(1, 2, position)
    plt.plot(metrics_log[train_key], label=train_label)
    plt.plot(metrics_log[val_key], label=val_label)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Predict class probabilities
pred_CNN_prob = model.predict(Xmat_test)

# Convert probabilities to predicted class labels.
# Output unit i corresponds to column i of the one-hot training targets, and
# pd.get_dummies orders its columns by sorted label (Immune, Other, Tumour),
# not in the order the labels were listed. Decode with that same column order,
# otherwise 'Tumour' and 'Other' predictions are swapped.
class_names = np.asarray(pd.get_dummies(y_train).columns, dtype=str)
pred_CNN_indices = np.argmax(pred_CNN_prob, axis=1)
pred_CNN = class_names[pred_CNN_indices]

# Create confusion matrix (rows = true class, columns = predicted class;
# `labels` only fixes the row/column display order)
tab = confusion_matrix(y_test, pred_CNN, labels=['Immune', 'Tumour', 'Other'])

print("Confusion Matrix:")
print(tab)

# Overall accuracy = correctly classified (diagonal) / all test samples
accuracy = np.trace(tab) / np.sum(tab)
print(f"Accuracy: {accuracy:.4f}")