# Transfer Learning - bacteria & virus (chest_xray)

## DenseNet121 (TensorFlow / Keras)


In [None]:
import os
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Directory holding the pneumonia chest X-rays (bacterial and viral cases mixed together).
image_dir = r"./data/chest_xray/mix/pneumonia"

In [None]:
# Seed the RNGs this TensorFlow/Keras section can draw from: Python's `random`,
# NumPy (presumably used by Keras' image iterators for shuffling - verify for
# your Keras version) and TensorFlow (weight initialisation of the new head).
# Seeding torch alone, as before, had no effect on the Keras model; torch is
# still seeded so nothing else in the notebook changes.
import random

import tensorflow as tf

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Step 1: Label images based on their filenames.
# The directory listing is sorted because os.listdir() returns entries in
# arbitrary (filesystem-dependent) order; without sorting, random_state=42
# below would not reproduce the same split on another machine.
# A file is kept only if its name contains "bacteria" or "virus" (the same
# rule the PyTorch split uses) instead of labelling every non-bacteria file
# as "virus".
image_paths = []
labels = []

for filename in sorted(os.listdir(image_dir)):
    if not filename.endswith(".jpeg"):
        continue
    if "bacteria" in filename:
        label = "bacteria"
    elif "virus" in filename:
        label = "virus"
    else:
        continue  # no class in the name: skip rather than silently call it "virus"
    image_paths.append(os.path.join(image_dir, filename))
    labels.append(label)

# Create a DataFrame to facilitate splitting
data_df = pd.DataFrame({"filename": image_paths, "label": labels})

# Step 2: Split into train / validation / test (8:1:1), stratified on the
# label so every split keeps the bacteria:virus ratio.
train_df, temp_df = train_test_split(
    data_df, test_size=0.2, stratify=data_df["label"], random_state=42
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.5, stratify=temp_df["label"], random_state=42
)

# Step 3: Use ImageDataGenerator to load images straight from their paths
# (no files are moved).
# ImageNet per-channel statistics for pixel values in [0, 1].
mean = [0.485, 0.456, 0.406]  # Mean for normalization
std = [0.229, 0.224, 0.225]   # Standard deviation for normalization


def custom_preprocessing_function(x):
    """Scale one image to [0, 1] and normalize it with the ImageNet mean/std.

    ImageDataGenerator runs ``preprocessing_function`` *before* ``rescale``
    (the Keras docs state rescale is applied after all other transformations).
    Combining ``rescale=1./255`` with a mean/std-only function therefore
    normalized raw 0-255 pixels and only then divided by 255, giving inputs of
    roughly 0 to 4.5 instead of zero-centred values. The division by 255 must
    happen here instead.

    Args:
        x: one image as a NumPy array with pixel values in [0, 255] and the
           channel axis last (assumed RGB, the generator default - confirm if
           color_mode changes).

    Returns:
        Array of the same shape: ``(x / 255 - mean) / std`` per channel.
    """
    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)
    return (x / 255.0 - mean_arr) / std_arr


# Create the ImageDataGenerators. There is deliberately no `rescale` argument:
# custom_preprocessing_function already divides by 255, and rescale would be
# applied a second time after it.
train_datagen = ImageDataGenerator(
    preprocessing_function=custom_preprocessing_function
)

val_test_datagen = ImageDataGenerator(
    preprocessing_function=custom_preprocessing_function
)

train_generator = train_datagen.flow_from_dataframe(
    train_df, x_col="filename", y_col="label", target_size=(224, 224),
    class_mode="binary", batch_size=32, shuffle=True
)  # shuffle on train data

# Validation/test order is kept fixed so predictions line up with .classes.
val_generator = val_test_datagen.flow_from_dataframe(
    val_df, x_col="filename", y_col="label", target_size=(224, 224),
    class_mode="binary", batch_size=32, shuffle=False
)

test_generator = val_test_datagen.flow_from_dataframe(
    test_df, x_col="filename", y_col="label", target_size=(224, 224),
    class_mode="binary", batch_size=32, shuffle=False
)

# Step 4: Build a DenseNet121 feature extractor (ImageNet weights, no
# classification top) and put a small binary-classification head on it.
base_model = DenseNet121(
    weights="imagenet",
    include_top=False,
    input_shape=(224, 224, 3),
)

# Freeze every backbone layer so only the new head is trained.
for backbone_layer in base_model.layers:
    backbone_layer.trainable = False

# Head: global average pooling -> 1024-unit ReLU layer -> single sigmoid unit
# (probability of the class mapped to index 1 in class_indices).
pooled = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(1024, activation="relu")(pooled)
output = Dense(1, activation="sigmoid")(hidden)

model = Model(inputs=base_model.input, outputs=output)
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

# 10 epochs of head-only training, validating after each epoch.
history = model.fit(train_generator, validation_data=val_generator, epochs=10)

# Final held-out evaluation.
test_loss, test_acc = model.evaluate(test_generator)
print(f"Test Accuracy: {test_acc:.2f}")

In [None]:
# Step 5: Visualization - one figure per metric, training vs. validation
# curves over the epochs recorded in the Keras History object.
for metric, pretty in (("loss", "Loss"), ("accuracy", "Accuracy")):
    plt.figure(figsize=(10, 5))
    plt.plot(history.history[metric], label=f"Training {pretty}")
    plt.plot(history.history[f"val_{metric}"], label=f"Validation {pretty}")
    plt.title(f"Training and Validation {pretty}")
    plt.xlabel("Epoch")
    plt.ylabel(pretty)
    plt.legend()
    plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Step 1: Predict on the test set.
# test_generator was created with shuffle=False, so after reset() the
# predictions come out in the same order as test_generator.classes.
test_generator.reset()
predictions = model.predict(test_generator)
# The model emits one sigmoid probability per image (shape (N, 1)). Threshold
# at 0.5 and flatten to the 1-D label vector sklearn expects.
predicted_classes = (predictions > 0.5).astype("int32").ravel()

# Step 2: Get the true labels and the class names in index order.
true_classes = test_generator.classes
class_indices = test_generator.class_indices  # e.g. {"bacteria": 0, "virus": 1}
class_labels = sorted(class_indices, key=class_indices.get)

# Step 3: Create the confusion matrix. Passing every class index keeps it
# square and aligned with class_labels even if a class never occurs.
conf_matrix = confusion_matrix(
    true_classes, predicted_classes, labels=list(range(len(class_labels)))
)

# Step 4: Plot on an explicit Axes. Without `ax`, ConfusionMatrixDisplay.plot
# opens its own new figure, which leaves a separately created
# plt.figure(figsize=(6, 6)) empty and ignores the requested size.
fig, ax = plt.subplots(figsize=(6, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=class_labels)
disp.plot(ax=ax, cmap="Blues", values_format="d")
ax.set_title("Confusion Matrix on Test Data")
plt.show()


## DenseNet121 (PyTorch)

In [None]:
import os
import shutil
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


In [None]:
torch.manual_seed(seed=42)  # fixes DataLoader shuffling, augmentation and head init

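### 1. Data Preprocessing

The commented-out cells below split each class 8:1:1 and *move* the images into `./data/chest_xray/tfl/{train,val,test}/{bacteria,virus}`, which is the folder layout `ImageFolder` expects. Run them once before the cells that follow. Because the files are moved rather than copied, `./data/chest_xray/mix/pneumonia` (the directory the TensorFlow section reads) is emptied afterwards.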

In [None]:
# source_dir = r'./data/chest_xray/mix/pneumonia'
# train_dir = r'./data/chest_xray/tfl/train'
# val_dir = r'./data/chest_xray/tfl/val'
# test_dir = r'./data/chest_xray/tfl/test'

In [1]:
# for category in ['bacteria', 'virus']:
#     os.makedirs(os.path.join(train_dir, category), exist_ok=True)
#     os.makedirs(os.path.join(val_dir, category), exist_ok=True)
#     os.makedirs(os.path.join(test_dir, category), exist_ok=True)

# images = os.listdir(source_dir)
# bacteria_images = [img for img in images if 'bacteria' in img]
# virus_images = [img for img in images if 'virus' in img]

# train_bacteria, test_bacteria = train_test_split(bacteria_images, test_size=0.2, random_state=42)
# val_bacteria, test_bacteria = train_test_split(test_bacteria, test_size=0.5, random_state=42)

# train_virus, test_virus = train_test_split(virus_images, test_size=0.2, random_state=42)
# val_virus, test_virus = train_test_split(test_virus, test_size=0.5, random_state=42)

# def move_files(file_list, dest_dir):
#     for file_name in file_list:
#         shutil.move(os.path.join(source_dir, file_name), os.path.join(dest_dir, file_name))


# move_files(train_bacteria, os.path.join(train_dir, 'bacteria'))
# move_files(val_bacteria, os.path.join(val_dir, 'bacteria'))
# move_files(test_bacteria, os.path.join(test_dir, 'bacteria'))

# move_files(train_virus, os.path.join(train_dir, 'virus'))
# move_files(val_virus, os.path.join(val_dir, 'virus'))
# move_files(test_virus, os.path.join(test_dir, 'virus'))

### 2. Data Loading and Model Definition (DenseNet)

In [None]:
# Per-split preprocessing. Every split ends with ToTensor + ImageNet mean/std
# normalization; only the training split is randomly augmented.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),   # random-area crop, resized to 224x224
        transforms.RandomHorizontalFlip(),   # mirror with probability 0.5
        transforms.ToTensor(),
        transforms.Normalize(_imagenet_mean, _imagenet_std),
    ]),
}

# Validation and test share the same deterministic pipeline.
for _split in ('val', 'test'):
    data_transforms[_split] = transforms.Compose([
        transforms.Resize(256),       # shorter side -> 256 px
        transforms.CenterCrop(224),   # central 224x224 crop
        transforms.ToTensor(),
        transforms.Normalize(_imagenet_mean, _imagenet_std),
    ])

# Folder-per-class datasets produced by the split cell above (tfl/<split>/<class>/).
data_dir = './data/chest_xray/tfl'
phases = ['train', 'val', 'test']
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in phases
}
# Shuffle only the training split. Val/test metrics do not depend on order,
# and a fixed order keeps evaluation deterministic.
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32, shuffle=(x == 'train'), num_workers=4)
    for x in phases
}

# Load ImageNet-pretrained DenseNet121. `weights=True` only worked through
# torchvision's deprecated legacy `pretrained` shim (with a warning); the
# weights enum is passed explicitly instead. IMAGENET1K_V1 is what that shim
# resolves `True` to for densenet121.
model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
num_features = model.classifier.in_features

# Replace the ImageNet classifier with a 2-way head (one logit per class folder).
model.classifier = nn.Linear(num_features, 2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

### 3. Loss Function and Optimizer

In [None]:
# Cross-entropy over the two output logits. Adam updates every model
# parameter: nothing is frozen, so the whole network is fine-tuned.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

### 4. Training Loop

In [None]:
# Per-epoch metrics appended by train_model (one float per epoch per key).
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}


def train_model(model, criterion, optimizer, num_epochs=25):
    """Train `model`, evaluating on the validation split after every epoch.

    Reads the module-level ``dataloaders``, ``image_datasets`` and ``device``,
    and appends per-epoch loss/accuracy (plain Python floats) to the
    module-level ``history`` dict.

    Args:
        model: network returning class logits of shape (batch, num_classes).
        criterion: loss taking (logits, integer class labels).
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over the training split.

    Returns:
        The same ``model`` object, trained in place (weights after the last epoch).

    Raises:
        ValueError: if the train or val dataset is empty.
    """
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # training mode (dropout / batch-norm updates on)
            else:
                model.eval()   # evaluation mode

            num_samples = len(image_datasets[phase])
            if num_samples == 0:
                raise ValueError(f"'{phase}' dataset is empty; cannot compute epoch metrics")

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass; the autograd graph is only built while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)

                    # Backward pass + optimization, only during training.
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # Weight by batch size so the epoch loss is a per-sample mean
                # (assumes the criterion averages over the batch, the
                # CrossEntropyLoss default).
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    return model

# Fine-tune for 10 epochs; train_model hands back the trained network.
model = train_model(model=model, criterion=criterion, optimizer=optimizer, num_epochs=10)

# Visualization
def plot_training_history(history):
    """Draw loss and accuracy curves side by side.

    Args:
        history: dict with 'train_loss', 'val_loss', 'train_acc' and
            'val_acc' lists, one value per epoch.
    """
    # (metric key suffix, y-axis label, train legend, val legend, panel title)
    panels = [
        ('loss', 'Loss', 'Train Loss', 'Validation Loss', 'Loss over Epochs'),
        ('acc', 'Accuracy', 'Train Accuracy', 'Validation Accuracy', 'Accuracy over Epochs'),
    ]

    plt.figure(figsize=(12, 5))
    for position, (key, y_label, train_label, val_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history[f'train_{key}'], label=train_label)
        plt.plot(history[f'val_{key}'], label=val_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.show()


plot_training_history(history)

### 5. Evaluate the Model (Confusion Matrix)

In [None]:
# confusion matrix
def evaluate_and_plot_confusion_matrix(model):
    """Run `model` over the test split and plot its confusion matrix.

    Reads the module-level ``dataloaders['test']``, ``image_datasets['test']``
    and ``device``. Rows are true classes and columns predicted classes, both
    in ``image_datasets['test'].classes`` order.

    Args:
        model: network returning class logits of shape (batch, num_classes).
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloaders['test']:
            outputs = model(inputs.to(device))
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            # Labels are only compared on the CPU, so they never need to visit the device.
            all_labels.extend(labels.numpy())

    class_names = image_datasets['test'].classes
    # Pass every class index so the matrix is always len(class_names) square,
    # even if a class is absent from both labels and predictions (otherwise it
    # would no longer match display_labels).
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

    disp.plot(cmap='Blues')
    plt.title('Confusion Matrix')
    plt.show()


evaluate_and_plot_confusion_matrix(model)
