In [15]:
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
from cycler import cycler

# Evaluate the Paired colormap at indices 0-9 and keep the first four entries
# as the default color cycle for all axes.
paired_colors = plt.cm.Paired(range(10))
selected_colors = paired_colors[:4]
plt.rc("axes", prop_cycle=cycler("color", selected_colors))

# Step 1: Register the custom font file with Matplotlib.
# The path is machine-specific, so a missing file must not abort the whole
# notebook: report the problem and carry on without the custom font file.
font_path = "/Users/henrismidt/Documents/Informatik/Fonts/libertinus/LibertinusSerif-Regular.otf"
try:
    font_manager.fontManager.addfont(font_path)
except OSError as err:
    print(
        f"Could not register font file {font_path!r} ({err}); "
        "Matplotlib will fall back to an available font."
    )

# Step 2: Request the font family by default. If the family is not available
# (neither registered above nor installed system-wide), Matplotlib substitutes
# a fallback font at draw time.
plt.rcParams["font.family"] = "Libertinus Serif"

In [2]:
import torch
import pandas as pd
import numpy as np
from transformers import MobileViTImageProcessor
import os
from dataset import MRIImageDataModule
from models import MobileViTLightning, EfficientNetBaseline
from utils import get_best_device

# Compute device chosen by the project helper
device = get_best_device()

# Metadata CSV handed to the data module
csv_path = "Data/metadata_for_preprocessed_files.csv"

# Per-model settings: the checkpoint identifier passed to the model/processor
# loaders and the slice numbers to process. Both models use the same slices;
# the list literal is evaluated once per entry so each config owns its own list.
model_configs = {
    name: {
        "model_ckpt": ckpt,
        "slice_numbers": ["65", "86", "56", "95", "62", "35", "59", "74", "80", "134"],
    }
    for name, ckpt in (
        ("MobileVit-s", "apple/mobilevit-small"),
        ("efficientnet-b2", None),
    )
}

# Build the image transform that matches a given model
def get_transform(model_name, model_ckpt):
    """Return the image transform for ``model_name``, or None if it needs none.

    For "MobileVit-s" a ``MobileViTImageProcessor`` is loaded from ``model_ckpt``
    and wrapped in a callable that runs the processor with
    ``return_tensors="pt"`` and returns its "pixel_values" entry with
    dimension 0 squeezed. Any other name, including "efficientnet-b2",
    yields None.
    """
    if model_name != "MobileVit-s":
        return None

    image_processor = MobileViTImageProcessor.from_pretrained(model_ckpt)

    def _to_pixel_values(image):
        encoded = image_processor(image, return_tensors="pt")
        return encoded["pixel_values"].squeeze(0)

    return _to_pixel_values

# Collect the feature maps of one slice into a dictionary keyed by sample id
def extract_features_to_dict(model_name, model, data_loader, slice_number, feature_dict):
    """Run ``model.extract_features`` on every batch and file the results by id.

    Args:
        model_name: Accepted for the caller's convenience; not used here.
        model: Object exposing ``extract_features(inputs)``.
        data_loader: Iterable yielding 4-item batches ``(inputs, labels, age, ids)``.
        slice_number: Used to build the ``"slice_<n>"`` key for this pass.
        feature_dict: Mutated in place. Each id maps to a dict holding
            ``"label"`` (set the first time the id is seen) and one
            ``"slice_<n>"`` entry per processed slice (nested Python lists).

    Relies on the module-level ``device``. Gradients are disabled throughout.
    """
    slice_key = f"slice_{slice_number}"
    with torch.no_grad():
        for images, labels, _age, ids in data_loader:
            batch_features = model.extract_features(images.to(device).float()).cpu().numpy()
            # NOTE(review): ids must compare equal across slices to merge into
            # one entry (e.g. strings) - confirm against the dataset.
            for row, sample_id in enumerate(ids):
                if sample_id not in feature_dict:
                    feature_dict[sample_id] = {"label": labels[row].item()}
                feature_dict[sample_id][slice_key] = batch_features[row].tolist()

# Function to convert feature dictionary to DataFrame and save as CSV
def save_features_dict_to_csv(feature_dict, file_name):
    """Write ``feature_dict`` to ``file_name`` as CSV, one row per id.

    Args:
        feature_dict: Mapping ``id -> {column: value}``. The keys become the
            ``"id"`` column; the inner dict keys become the other columns.
        file_name: Destination path. Its parent directory is created when it
            does not exist yet, because ``DataFrame.to_csv`` does not create
            missing directories.
    """
    parent_dir = os.path.dirname(file_name)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    df = pd.DataFrame.from_dict(feature_dict, orient='index')
    # from_dict(orient='index') puts the ids in the index; expose them as a column.
    df = df.reset_index().rename(columns={"index": "id"})
    df.to_csv(file_name, index=False)
    print(f"Saved features to {file_name}")

# Main loop to extract and save features
def _load_slice_model(model_name, model_path, model_ckpt):
    """Restore the checkpoint at ``model_path`` with the class matching ``model_name``."""
    if model_name == "MobileVit-s":
        return MobileViTLightning.load_from_checkpoint(model_path, model_ckpt=model_ckpt, num_labels=4)
    if model_name == "efficientnet-b2":
        return EfficientNetBaseline.load_from_checkpoint(model_path, model_name="efficientnet-b2", num_classes=4)
    # Fail loudly instead of reusing the model left over from a previous iteration.
    raise ValueError(f"Unknown model name: {model_name!r}")


def _select_dataloader(data_module, stage):
    """Return the data loader of ``data_module`` that belongs to ``stage``."""
    if stage == 'train':
        return data_module.train_dataloader()
    if stage == 'val':
        return data_module.val_dataloader()
    if stage == 'test':
        return data_module.test_dataloader()
    raise ValueError(f"Unknown stage: {stage!r}")


# One CSV is written per (stage, model) pair; every available slice checkpoint
# contributes one "slice_<n>" column to it.
for stage in ['train', 'val', 'test']:
    for model_name, config in model_configs.items():
        feature_dict = {}
        model_ckpt = config["model_ckpt"]
        transform = get_transform(model_name, model_ckpt)

        for slice_number in config["slice_numbers"]:
            model_path = f"model_checkpoints/{model_name}/slice_number_{slice_number}_lr_1e-05.ckpt"
            if not os.path.exists(model_path):
                # Make the gap visible: the resulting CSV will lack this slice column.
                print(f"Skipping slice {slice_number} for {model_name}: no checkpoint at {model_path}")
                continue

            # Load the model and switch it to inference mode on the target device
            model = _load_slice_model(model_name, model_path, model_ckpt)
            model = model.to(device)
            model.eval()

            # Initialize the data module for this slice
            data_module = MRIImageDataModule(
                csv_path,
                slice_number=int(slice_number),
                transform=transform,
                batch_size=48,
                num_workers=0,
                always_return_id=True,
            )
            data_module.setup()

            # Get the data loader for the current stage
            data_loader = _select_dataloader(data_module, stage)

            # Extract and store features
            extract_features_to_dict(model_name, model, data_loader, slice_number, feature_dict)

        # Save feature dictionary to CSV
        save_features_dict_to_csv(feature_dict, f"extracted_features/{model_name}_{stage}_features.csv")


Some weights of MobileViTForImageClassification were not initialized from the model checkpoint at apple/mobilevit-small and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 640]) in the checkpoint and torch.Size([4, 640]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of MobileViTForImageClassification were not initialized from the model checkpoint at apple/mobilevit-small and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 640]) in the checkpoint and torch.Size([4, 640]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-st

Saved features to MobileVit-s_train_features.csv
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Saved features to efficientnet-b2_train_features.csv


Some weights of MobileViTForImageClassification were not initialized from the model checkpoint at apple/mobilevit-small and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 640]) in the checkpoint and torch.Size([4, 640]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of MobileViTForImageClassification were not initialized from the model checkpoint at apple/mobilevit-small and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 640]) in the checkpoint and torch.Size([4, 640]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-st

Saved features to MobileVit-s_val_features.csv
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Saved features to efficientnet-b2_val_features.csv


Some weights of MobileViTForImageClassification were not initialized from the model checkpoint at apple/mobilevit-small and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 640]) in the checkpoint and torch.Size([4, 640]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of MobileViTForImageClassification were not initialized from the model checkpoint at apple/mobilevit-small and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 640]) in the checkpoint and torch.Size([4, 640]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-st

Saved features to MobileVit-s_test_features.csv
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
Saved features to efficientnet-b2_test_features.csv


In [7]:
import json
import pickle

import numpy as np
import pandas as pd
from tqdm import tqdm

# Directory the extraction step writes its CSVs to. The pooled pickles are
# stored next to them so the completion message names the real location.
FEATURES_DIR = "extracted_features"

csv_names = [
    'efficientnet-b2_train_features',
    'efficientnet-b2_val_features',
    'efficientnet-b2_test_features',
    'MobileVit-s_train_features',
    'MobileVit-s_val_features',
    'MobileVit-s_test_features',
]


def global_average_pooling(feature_map):
    """Parse one serialized feature map and average it over axes 1 and 2.

    Args:
        feature_map: Text form of a nested numeric list, as stored in the
            feature CSV cells.

    Returns:
        A NumPy array with axes 1 and 2 averaged away.
        (Presumably axis 0 is the channel axis - confirm against the model.)
    """
    # json.loads parses the nested-list text without executing it as code,
    # unlike eval, and is considerably faster on large lists.
    feature_map_array = np.array(json.loads(feature_map))
    return np.mean(feature_map_array, axis=(1, 2))


for csv_name in csv_names:
    # Read the CSV file produced by the extraction step
    df = pd.read_csv(f"{FEATURES_DIR}/{csv_name}.csv")

    # Get slice columns
    slice_cols = [col for col in df.columns if col.startswith("slice_")]

    # Apply global average pooling to all slice columns with a progress bar
    for col in tqdm(slice_cols, desc=f"Processing {csv_name}"):
        df[col] = df[col].apply(global_average_pooling)

    # Save the resulting dataframe as a pickle (the cells now hold NumPy arrays)
    pooled_path = f"{FEATURES_DIR}/{csv_name}_pooled.pkl"
    with open(pooled_path, 'wb') as f:
        pickle.dump(df, f)

    print(f"Global average pooling applied and saved to '{pooled_path}'")

    # Sanity check: every pooled vector in a column should have the same shape
    for col in slice_cols:
        shapes = df[col].apply(lambda x: x.shape).unique()
        print(f"Unique shapes in column {col} of {csv_name}: {shapes}")


Processing efficientnet-b2_train_features: 100%|██████████| 10/10 [03:36<00:00, 21.64s/it]


Global average pooling applied and saved to 'efficientnet-b2_train_features_pooled.pkl'
Unique shapes in column slice_65 of efficientnet-b2_train_features: [(1408,)]
Unique shapes in column slice_86 of efficientnet-b2_train_features: [(1408,)]
Unique shapes in column slice_56 of efficientnet-b2_train_features: [(1408,)]
Unique shapes in column slice_95 of efficientnet-b2_train_features: [(1408,)]
Unique shapes in column slice_62 of efficientnet-b2_train_features: [(1408,)]
Unique shapes in column slice_35 of efficientnet-b2_train_features: [(1408,)]
Unique shapes in column slice_59 of efficientnet-b2_train_features: [(1408,)]
Unique shapes in column slice_74 of efficientnet-b2_train_features: [(1408,)]
Unique shapes in column slice_80 of efficientnet-b2_train_features: [(1408,)]
Unique shapes in column slice_134 of efficientnet-b2_train_features: [(1408,)]


Processing efficientnet-b2_val_features: 100%|██████████| 10/10 [00:35<00:00,  3.55s/it]


Global average pooling applied and saved to 'efficientnet-b2_val_features_pooled.pkl'
Unique shapes in column slice_65 of efficientnet-b2_val_features: [(1408,)]
Unique shapes in column slice_86 of efficientnet-b2_val_features: [(1408,)]
Unique shapes in column slice_56 of efficientnet-b2_val_features: [(1408,)]
Unique shapes in column slice_95 of efficientnet-b2_val_features: [(1408,)]
Unique shapes in column slice_62 of efficientnet-b2_val_features: [(1408,)]
Unique shapes in column slice_35 of efficientnet-b2_val_features: [(1408,)]
Unique shapes in column slice_59 of efficientnet-b2_val_features: [(1408,)]
Unique shapes in column slice_74 of efficientnet-b2_val_features: [(1408,)]
Unique shapes in column slice_80 of efficientnet-b2_val_features: [(1408,)]
Unique shapes in column slice_134 of efficientnet-b2_val_features: [(1408,)]


Processing efficientnet-b2_test_features: 100%|██████████| 10/10 [00:36<00:00,  3.65s/it]


Global average pooling applied and saved to 'efficientnet-b2_test_features_pooled.pkl'
Unique shapes in column slice_65 of efficientnet-b2_test_features: [(1408,)]
Unique shapes in column slice_86 of efficientnet-b2_test_features: [(1408,)]
Unique shapes in column slice_56 of efficientnet-b2_test_features: [(1408,)]
Unique shapes in column slice_95 of efficientnet-b2_test_features: [(1408,)]
Unique shapes in column slice_62 of efficientnet-b2_test_features: [(1408,)]
Unique shapes in column slice_35 of efficientnet-b2_test_features: [(1408,)]
Unique shapes in column slice_59 of efficientnet-b2_test_features: [(1408,)]
Unique shapes in column slice_74 of efficientnet-b2_test_features: [(1408,)]
Unique shapes in column slice_80 of efficientnet-b2_test_features: [(1408,)]
Unique shapes in column slice_134 of efficientnet-b2_test_features: [(1408,)]


Processing MobileVit-s_train_features: 100%|██████████| 10/10 [00:33<00:00,  3.35s/it]


Global average pooling applied and saved to 'MobileVit-s_train_features_pooled.pkl'
Unique shapes in column slice_65 of MobileVit-s_train_features: [(160,)]
Unique shapes in column slice_86 of MobileVit-s_train_features: [(160,)]
Unique shapes in column slice_56 of MobileVit-s_train_features: [(160,)]
Unique shapes in column slice_95 of MobileVit-s_train_features: [(160,)]
Unique shapes in column slice_62 of MobileVit-s_train_features: [(160,)]
Unique shapes in column slice_35 of MobileVit-s_train_features: [(160,)]
Unique shapes in column slice_59 of MobileVit-s_train_features: [(160,)]
Unique shapes in column slice_74 of MobileVit-s_train_features: [(160,)]
Unique shapes in column slice_80 of MobileVit-s_train_features: [(160,)]
Unique shapes in column slice_134 of MobileVit-s_train_features: [(160,)]


Processing MobileVit-s_val_features: 100%|██████████| 10/10 [00:05<00:00,  1.81it/s]


Global average pooling applied and saved to 'MobileVit-s_val_features_pooled.pkl'
Unique shapes in column slice_65 of MobileVit-s_val_features: [(160,)]
Unique shapes in column slice_86 of MobileVit-s_val_features: [(160,)]
Unique shapes in column slice_56 of MobileVit-s_val_features: [(160,)]
Unique shapes in column slice_95 of MobileVit-s_val_features: [(160,)]
Unique shapes in column slice_62 of MobileVit-s_val_features: [(160,)]
Unique shapes in column slice_35 of MobileVit-s_val_features: [(160,)]
Unique shapes in column slice_59 of MobileVit-s_val_features: [(160,)]
Unique shapes in column slice_74 of MobileVit-s_val_features: [(160,)]
Unique shapes in column slice_80 of MobileVit-s_val_features: [(160,)]
Unique shapes in column slice_134 of MobileVit-s_val_features: [(160,)]


Processing MobileVit-s_test_features: 100%|██████████| 10/10 [00:05<00:00,  1.80it/s]

Global average pooling applied and saved to 'MobileVit-s_test_features_pooled.pkl'
Unique shapes in column slice_65 of MobileVit-s_test_features: [(160,)]
Unique shapes in column slice_86 of MobileVit-s_test_features: [(160,)]
Unique shapes in column slice_56 of MobileVit-s_test_features: [(160,)]
Unique shapes in column slice_95 of MobileVit-s_test_features: [(160,)]
Unique shapes in column slice_62 of MobileVit-s_test_features: [(160,)]
Unique shapes in column slice_35 of MobileVit-s_test_features: [(160,)]
Unique shapes in column slice_59 of MobileVit-s_test_features: [(160,)]
Unique shapes in column slice_74 of MobileVit-s_test_features: [(160,)]
Unique shapes in column slice_80 of MobileVit-s_test_features: [(160,)]
Unique shapes in column slice_134 of MobileVit-s_test_features: [(160,)]



