# Data Loading

In [2]:
import warnings

# Silence Python warnings for the session. This runs before the TensorFlow
# import so it also covers warnings emitted while importing.
warnings.filterwarnings('ignore')

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# Global TensorFlow seed, set before any dataset / model ops are created.
tf.random.set_seed(10)

# Malaria dataset from TFDS as three slices of its 'train' split:
# first 80%, next 10%, last 10% (unpacked later as train / valid / test).
# as_supervised=True yields (image, label) tuples; with_info=True also
# returns the DatasetInfo object.
dataset, dataset_info = tfds.load(
    name='malaria',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    as_supervised=True,
    shuffle_files=True,
    with_info=True,
)

# Callbacks

In [4]:
from tensorflow.keras.callbacks import Callback, CSVLogger, EarlyStopping

# Log training metrics to a CSV file; append=False starts a fresh file each
# run instead of appending to an existing one.
# NOTE(review): "renetLogs" looks like a typo for "resnetLogs" — left as-is
# so any existing tooling that reads this path keeps working.
csv_logger = CSVLogger(
    filename="renetLogs.csv",
    separator=",",
    append=False,
)

In [5]:
# Stop after 2 epochs without improvement and roll back to the best weights
# seen. No monitor= is passed, so Keras' default monitored quantity is used.
es_callback = EarlyStopping(patience=2, restore_best_weights=True)

In [6]:
from tensorflow.keras.callbacks import LearningRateScheduler

def scheduler(epoch, lr, warmup_epochs=3, decay_rate=0.1):
    """Learning-rate schedule: hold the LR for a few epochs, then decay it exponentially.

    Intended for ``LearningRateScheduler``, which calls it as
    ``scheduler(epoch, lr)``; the extra keyword arguments default to the
    original hard-coded values (3 and 0.1), so that call is unchanged.

    Args:
        epoch: epoch index passed in by the callback.
        lr: the optimizer's current learning rate.
        warmup_epochs: while ``epoch < warmup_epochs`` the LR is returned
            unchanged (default 3).
        decay_rate: after warm-up the LR is multiplied by ``exp(-decay_rate)``
            every epoch (default 0.1, i.e. a factor of about 0.905).

    Returns:
        ``lr`` unchanged during warm-up; otherwise the decayed LR as a
        Python float.
    """
    import math  # local import: the notebook's import cell does not import math

    if epoch < warmup_epochs:
        return lr
    # math.exp keeps this in plain Python floats: no eager tensor is created
    # and the factor is not rounded to float32 as tf.math.exp(-0.1) would be.
    return float(lr * math.exp(-decay_rate))

# Wrap `scheduler` in a Keras callback; verbose=1 makes it report the LR it sets.
sched = LearningRateScheduler(schedule=scheduler, verbose=1)

In [7]:
# Preprocessing function
def preprocess(image, label):
    """Resize one example to 224x224 and make sure the image is float32.

    Args:
        image: image tensor from the TFDS split (assumed HxWxC; exact dtype and
            channel count are not visible here — TODO confirm via dataset_info).
        label: class label, passed through unchanged.

    Returns:
        ``(image, label)`` with ``image`` resized to 224x224 as float32. Pixel
        values keep their original scale; ``normalise`` rescales them.
    """
    image = tf.image.resize(image, [224, 224])
    # tf.image.resize already returns float32 for its default (bilinear)
    # method; the cast pins the dtype in case the resize method is changed.
    image = tf.cast(image, tf.float32)
    return image, label


def normalise(image, label):
    """Divide pixel values by 255 (maps [0, 255] to [0, 1] for 8-bit sources).

    Args:
        image: float image tensor (output of ``preprocess``).
        label: class label, passed through unchanged.

    Returns:
        ``(image / 255.0, label)``.
    """
    return image / 255.0, label


def _build_pipeline(split, shuffle=False):
    """Resize, normalise, batch (100) and prefetch one ``(image, label)`` split.

    When ``shuffle`` is true, the raw examples are shuffled *before* the
    resize. The 1024-element buffer then holds un-resized source images
    instead of 1024 float32 224x224 images (~600 MB for 3-channel images).
    ``map`` is element-wise and order-preserving, so every example is still
    seen once per epoch in a buffer-shuffled order.

    Both ``map`` stages use ``num_parallel_calls=AUTOTUNE``. tf.data still
    emits results in input order unless determinism is disabled through
    ``tf.data.Options``.
    """
    if shuffle:
        split = split.shuffle(buffer_size=1024)
    return (
        split
        .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .map(normalise, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(100)
        .prefetch(tf.data.AUTOTUNE)
    )


# Unpack the three loaded slices and turn each into a batched input pipeline;
# only the training split is shuffled.
train_dataset, valid_dataset, test_dataset = dataset
train_dataset = _build_pipeline(train_dataset, shuffle=True)
valid_dataset = _build_pipeline(valid_dataset)
test_dataset = _build_pipeline(test_dataset)

# ResNet34

In [None]:
from keras.layers import Conv2D, MaxPooling2D, Dense, InputLayer, Flatten, BatchNormalization, GlobalAveragePooling2D
from keras.models import Model
class ResNet34(Model):
    def __init__(self):
        super(ResNet34,self).__init__(name="resent_34")
        self.conv_1 = CustomConv2D(64,7,2,padding='same')
        self.max_pool = MaxPooling2D(3,2)

        self.conv_2_1 = ResidualBlock(64)
        self.conv_2_2 = ResidualBlock(64)
        self.conv_2_3 = ResidualBlock(64)
        
        self.conv_3_1 = ResidualBlock(128,2)
        self.conv_3_2 = ResidualBlock(128)
        self.conv_3_3 = ResidualBlock(128)
        self.conv_3_3 = ResidualBlock(128)
        
        self.conv_3_1 = ResidualBlock(256,2)
        self.conv_3_2 = ResidualBlock(256)
        self.conv_3_3 = ResidualBlock(256)
        self.conv_3_3 = ResidualBlock(256)
        
        self.conv_3_1 = ResidualBlock(512,2)
        self.conv_3_2 = ResidualBlock(512)
        self.conv_3_3 = ResidualBlock(512)
        self.conv_3_3 = ResidualBlock(512)

        self.global_pool = GlobalAveragePooling2D()
        self.fc_3 = Dense(2,activation='softmax')