In [1]:
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

In [206]:
from tensorflow.keras.metrics import Recall, Precision

In [14]:
from tensorflow.keras import Sequential, layers, losses, optimizers

Import processed Data

In [3]:
data = pd.read_csv('data/cleaned_data.csv')

In [4]:
data.shape

(31011, 4010)

In [5]:
data.head(2)

Unnamed: 0,SampleID,CellType,IGLV3-19,IGHV4-34,IGKC,IGHA1,IGLC3,S100A2,SCGB3A1,IGHG1,...,BLM,STARD4,CCDC171,ITSN1,PPM1H,AHR,HPS5,MEI1,PNMA1,MAP3K9
0,2.0,2,-0.0159,-0.010017,-0.047302,-0.02565,-0.017487,-0.040863,-0.041718,-0.041687,...,-0.172607,-0.354004,-0.103394,-0.24585,-0.088196,-0.386719,-0.427734,-0.190186,-0.260254,-0.070862
1,2.0,2,-0.0159,-0.010017,-0.047302,-0.02565,-0.017487,-0.040863,-0.041718,-0.041687,...,-0.172607,-0.354004,-0.103394,-0.24585,-0.088196,-0.386719,-0.427734,-0.190186,-0.260254,-0.070862


Split Data

In [169]:
X = data.iloc[:, 2:].to_numpy().astype('float32')
y = data.iloc[:, 1].to_numpy().astype('int32')

In [170]:
y = y - 1

In [171]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Data pipelining

In [9]:
train_ds = tf.data.Dataset.from_tensor_slices({'data':  X_train, 
                                               'label': y_train})

In [10]:
AUTOTUNE = tf.data.AUTOTUNE
batch_size = 32

In [78]:
train_ds = train_ds.cache() \
                   .shuffle(buffer_size=len(X_train)) \
                   .batch(batch_size) \
                   .prefetch(AUTOTUNE)

In [15]:
X_train.shape

(24808, 4008)

In [250]:
model = Sequential([
    layers.InputLayer(input_shape=(X_test.shape[1],)),
    layers.Dense(32, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax')
])



In [251]:
model.summary()

In [252]:
class CellType(tf.keras.Model):
    def __init__(self, model, *args, **kwargs):
        
        super().__init__(*args, **kwargs)

        self.model = model

    def compile(self, opt, loss, *args, **kwargs):
        super().compile(*args, **kwargs)
        self.opt = opt
        self.loss_fn = loss

        self.precision_metric = Precision()
        self.recall_metric = Recall()

    def train_step(self, batch):

        data, label = batch

        data = tf.convert_to_tensor(data, dtype=tf.float32)

        label = tf.convert_to_tensor(label, dtype=tf.int32) 

        with tf.GradientTape() as tape:

            predictions = self.model(data, training = True)

            total_loss = self.loss_fn(label, predictions)


        grad = tape.gradient(total_loss, self.model.trainable_variables)
    
        self.opt.apply_gradients(zip(grad, self.model.trainable_variables))


        #Calculate Precision 
        predicted_classes = tf.argmax(predictions, axis=-1, output_type=tf.int32)

        # Update the metric states
        self.precision_metric.update_state(label, predicted_classes)
        self.recall_metric.update_state(label, predicted_classes)

        precision_value = self.precision_metric.result()
        recall_value = self.recall_metric.result()

        epsilon = 1e-7  # Small constant to avoid division by zero
        f1_value = 2 * (precision_value * recall_value) / (precision_value + recall_value + epsilon)



        return {
            "loss": total_loss,
            "precision": precision_value,
            "recall": recall_value,
            "f1": f1_value
        }
    

    def test_step(self, batch):

        data, label = batch

        data = tf.convert_to_tensor(data, dtype=tf.float32)

        label = tf.convert_to_tensor(label, dtype=tf.int32) 


        predictions = self.model(data)

        total_loss = self.loss_fn(label, predictions)

        #Calculate Precision 
        predicted_classes = tf.argmax(predictions, axis=-1, output_type=tf.int32)

        # Update the metric states
        self.precision_metric.update_state(label, predicted_classes)
        self.recall_metric.update_state(label, predicted_classes)

        precision_value = self.precision_metric.result()
        recall_value = self.recall_metric.result()

        epsilon = 1e-7  # Small constant to avoid division by zero
        f1_value = 2 * (precision_value * recall_value) / (precision_value + recall_value + epsilon)



        return {
            "loss": total_loss,
            "precision": precision_value,
            "recall": recall_value,
            "f1": f1_value
        }

In [253]:
cell_type = CellType(model)

In [254]:
opt = optimizers.AdamW(learning_rate=1e-4)

loss = losses.SparseCategoricalCrossentropy(from_logits=False)

cell_type.compile(opt = opt, loss = loss)

In [255]:
cell_type.fit(X_train, y_train,
              validation_data = (X_test, y_test),
               batch_size=32, epochs=5)

Epoch 1/5
[1m776/776[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - f1: 0.9918 - loss: 0.5237 - precision: 0.9919 - recall: 0.9918 - val_f1: 0.9986 - val_loss: 0.1822 - val_precision: 0.9979 - val_recall: 0.9993
Epoch 2/5
[1m776/776[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - f1: 0.9990 - loss: 0.1868 - precision: 0.9988 - recall: 0.9992 - val_f1: 0.9992 - val_loss: 0.1108 - val_precision: 0.9992 - val_recall: 0.9992
Epoch 3/5
[1m776/776[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - f1: 0.9993 - loss: 0.1356 - precision: 0.9991 - recall: 0.9994 - val_f1: 0.9993 - val_loss: 0.1098 - val_precision: 0.9995 - val_recall: 0.9990
Epoch 4/5
[1m776/776[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - f1: 0.9993 - loss: 0.1109 - precision: 0.9990 - recall: 0.9995 - val_f1: 0.9993 - val_loss: 0.0939 - val_precision: 0.9995 - val_recall: 0.9990
Epoch 5/5
[1m776/776[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 

<keras.src.callbacks.history.History at 0x3531176d0>

In [256]:
predictions = model(X_test)
predicted_classes = tf.argmax(predictions, axis=-1, output_type=tf.int32)

In [257]:
predicted_classes.numpy()

array([1, 6, 1, ..., 1, 6, 1], dtype=int32)

In [258]:
y_test

array([2, 6, 1, ..., 1, 6, 1], dtype=int32)