In [None]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import transforms
import time
import os
from torch.utils.data import Dataset, DataLoader
from glob import glob
from PIL import Image
import pickle

In [None]:
# Device configuration: first CUDA GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Root directory under which the extracted features are saved
save_address_1024 = '/path/to/save/features/features_from_resnet50/'

# Preprocessing for input patches: convert to a tensor, then normalise each
# channel with a fixed mean/std
# (presumably the ImageNet statistics expected by the pretrained backbone — confirm).
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Custom Dataset class
class My_dataloader(Dataset):
    """Dataset over the entries matched by ``<slide directory>/*``.

    Each item is an ``(image, file_name, folder_name)`` triple: the image is
    opened with PIL, converted to RGB and passed through ``transform`` when
    one is given.
    """

    def __init__(self, data_24, transform):
        """
        Args:
            data_24: Path to input data (slide directory).
            transform: Callable applied to every RGB image; a falsy value
                disables the transformation step.
        """
        self.transform = transform
        self.data_24 = data_24
        # Everything matched directly under the slide directory is treated as a patch.
        self.pathes_24 = glob(self.data_24 + '/*')

    def __len__(self):
        """Number of entries found in the slide directory."""
        return len(self.pathes_24)

    def __getitem__(self, idx):
        """Load the ``idx``-th patch and return it with its file and folder names."""
        patch_path = self.pathes_24[idx]
        image = Image.open(patch_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        folder_name = os.path.basename(os.path.dirname(patch_path))
        file_name = os.path.basename(patch_path)
        return image, file_name, folder_name

# Load ResNet50 (pretrained=True requests torchvision's pretrained weights)
model = torchvision.models.resnet50(pretrained=True)

# Freeze all layers initially
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the last 30 parameter tensors. Weights and biases are counted
# individually by model.parameters(), so this is 30 tensors, not 30 (or 50) layers.
# NOTE(review): test_model below only runs inference, so these flags do not
# appear to influence the extracted features — confirm whether they are still needed.
for param in list(model.parameters())[-30:]:
    param.requires_grad = True

# Modify feature extractor to include AdaptiveAvgPool2d
# model.features = nn.Sequential(model.features, nn.AdaptiveAvgPool2d(output_size=(1, 1)))

# Replace the average pooling layer with an adaptive pool that outputs a 1x1 map
model.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))  # Correct modification

# Reduce features from 2048 → 512 using a linear projection followed by ReLU.
# NOTE(review): this projection is constructed here and no trained weights are
# loaded for it anywhere in this cell, so it appears to be randomly initialised
# (and unseeded) at extraction time — confirm this is intended.
model.fc = nn.Sequential(
    nn.Linear(2048, 512),  # ResNet50 outputs 2048-dimensional features
    nn.ReLU(inplace=True),
)

# Set num_ftrs to 512 (new feature size)
num_ftrs = 512


# Define Custom Fully Connected Model
class fully_connected(nn.Module):
    """Backbone followed by a single linear classification layer.

    The forward pass returns both the backbone output and the class scores
    computed from it.
    """

    def __init__(self, model, num_ftrs, num_classes):
        """
        Args:
            model: Backbone module whose output has ``num_ftrs`` features.
            num_ftrs: Size of the backbone output fed to the classifier.
            num_classes: Number of outputs of the classification layer.
        """
        print(f"Input features: {num_ftrs}, Output classes: {num_classes}")
        super().__init__()
        self.model = model
        self.fc_4 = nn.Linear(in_features=num_ftrs, out_features=num_classes)

    def forward(self, x):
        """Return ``(features, class_scores)`` for the input batch ``x``."""
        features = self.model(x)
        return features, self.fc_4(features)


# Final model: the modified backbone followed by a 4-output classification layer
model_final = fully_connected(model, num_ftrs, 4)

# Resolve the target device, place the model on it and wrap it in DataParallel
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_final = nn.DataParallel(model_final.to(device))

# Loss function
criterion = nn.CrossEntropyLoss()

def test_model(model, criterion, num_epochs=25):
    """Extract per-patch features for every slide and pickle them per slide.

    Walks the class directories matched by the hard-coded patch glob, treats
    every entry inside a class directory as one slide, runs all patches of
    that slide through ``model`` and writes a
    ``{patch file name: feature vector}`` dictionary to
    ``<save_address_1024>/<class>/<slide>_resnet50Features_dict.pickle``.

    Args:
        model: Module returning ``(features, class_scores)`` for an image
            batch; only the features are stored.
        criterion: Unused; kept so existing callers keep working.
        num_epochs: Unused; kept so existing callers keep working.

    Returns:
        The ``model`` that was passed in, left in evaluation mode.
    """
    since = time.time()

    model.eval()   # Setting in evaluate mode

    # Processing pipeline
    class_dirs = glob("/path/to/patches/...")

    # Inference only: without no_grad() autograd records a graph for every
    # batch whenever any parameter has requires_grad=True, which wastes memory
    # on large batches and is never used here.
    with torch.no_grad():
        for class_dir in class_dirs:
            class_name = os.path.basename(class_dir)
            slide_paths = glob(os.path.join(class_dir, '*'))

            # Create directory for class in save path
            class_save_path = os.path.join(save_address_1024, class_name)
            os.makedirs(class_save_path, exist_ok=True)

            for slide_path in slide_paths:
                slide_name = os.path.basename(slide_path)
                print(f"Processing Slide: {slide_name} in Class: {class_name}")

                # Create dataset and dataloader for the slide
                test_imagedataset = My_dataloader(slide_path, trans)
                dataloader_test = torch.utils.data.DataLoader(
                    test_imagedataset, batch_size=600, shuffle=False, num_workers=16)

                # Maps patch file name -> feature vector for this slide
                slide_features = {}

                # Extract features for all patches in the slide
                for ii, (inputs, img_name, _folder_name) in enumerate(dataloader_test):
                    print("Batch count:", ii)
                    print("Input shape:", inputs.shape)
                    inputs = inputs.to(device)
                    # The class scores are not needed for feature extraction
                    features, _ = model(inputs)
                    features = features.detach().cpu().numpy()

                    # One feature row per patch, in the same order as the names
                    for patch_name, feature in zip(img_name, features):
                        slide_features[patch_name] = feature

                # Save features to a pickle file in the class-specific directory
                output_file_name = f"{slide_name}_resnet50Features_dict.pickle"
                output_file_path = os.path.join(class_save_path, output_file_name)
                with open(output_file_path, 'wb') as outfile:
                    pickle.dump(slide_features, outfile)
                    print(f"Saved features to: {output_file_path}")

    time_elapsed = time.time() - since
    print('Evaluation completed in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    return model

# Run feature extraction for every class/slide directory.
# `criterion` is passed only to satisfy the signature; test_model never reads it.
test_model(model_final, criterion)


## Build one combined HDF5 file containing the features of all classes

In [None]:
import os
import pickle
import h5py
import numpy as np

# Input: directory holding one sub-folder of .pickle feature files per class
pickle_dir = '/path/to/features/features_from_resnet50'
# Output: combined HDF5 file
h5_output_file = '/path/to/save/features/features_from_resnet50/features_resnet50.hdf5'

# Class folder name -> integer label (several folder names share a label)
class_labels = {
    "TCIA-normal": 0,
    "osmf_oscc": 1,
    "OSMF": 1,
    "wd": 1,
    "tcia-wd": 1,
    "md": 2,
    "tcia-md": 2,
    "pd": 3,
    "tcia-pd": 3,
    'Normal_aug_all_TCIA_inhouse': 0,
}

# Function to extract (row, col) from patch name
def extract_coords_from_patch_name(patch_name):
    """Parse the ``(row, col)`` position encoded in a patch file name.

    The name is split on underscores and must contain the token ``tile``
    followed by the row and the column, e.g. ``slideA_tile_3_7.png``.

    Args:
        patch_name: Patch file name; ``.png`` is stripped before parsing.

    Returns:
        ``(row, col)`` as a tuple of ints, or ``(0, 0)`` when the name cannot
        be parsed (an error message is printed in that case).
    """
    try:
        parts = patch_name.replace('.png', '').split('_')  # Remove .png and split by _
        if 'tile' not in parts:
            raise ValueError("Unexpected format in patch name")
        tile_idx = parts.index('tile')
        return int(parts[tile_idx + 1]), int(parts[tile_idx + 2])
    except (AttributeError, IndexError, TypeError, ValueError) as e:
        # Covers: non-string names, missing row/col tokens, non-numeric tokens.
        # Parsing stays best-effort: report the problem and fall back to (0, 0).
        print(f"Error parsing patch name {patch_name}: {e}")
        return (0, 0)
 

# Create HDF5 file: one group per slide holding its patch embeddings,
# patch coordinates and class label.
with h5py.File(h5_output_file, 'w') as h5_file:
    # Loop through class directories. sorted() because os.listdir returns
    # entries in arbitrary order; sorting makes runs (and logs) reproducible.
    for class_dir in sorted(os.listdir(pickle_dir)):
        class_path = os.path.join(pickle_dir, class_dir)
        if not os.path.isdir(class_path):
            continue

        # Resolve the label before anything is written: looking it up after the
        # group was created meant a directory without a mapping (for example a
        # stray .ipynb_checkpoints folder) aborted the run with a KeyError and
        # left a half-written group behind.
        if class_dir not in class_labels:
            print(f"Skipping directory with no entry in class_labels: {class_dir}")
            continue
        label = class_labels[class_dir]

        # Process each slide in the class directory
        for pickle_file in sorted(os.listdir(class_path)):
            if not pickle_file.endswith('.pickle'):
                continue

            slide_name = os.path.splitext(pickle_file)[0]
            slide_id = slide_name.split('_')[0]  # Extract slide_id
            slide_path = os.path.join(class_path, pickle_file)
            print(f"Processing slide: {slide_name} (Slide ID: {slide_id})")

            # Load .pickle file
            # NOTE: pickle.load can execute arbitrary code; only load files
            # produced by this pipeline.
            with open(slide_path, 'rb') as f:
                patch_features = pickle.load(f)

            # Extract embeddings and coordinates, ordered by patch name so that
            # row i of `embeddings` belongs to row i of `coords`.
            embeddings = []
            coords = []
            for patch_name in sorted(patch_features):
                embeddings.append(patch_features[patch_name])
                coords.append(extract_coords_from_patch_name(patch_name))

            embeddings = np.array(embeddings)
            coords = np.array(coords)

            # Create HDF5 group for the slide
            group = h5_file.create_group(slide_name)
            group.create_dataset("embeddings", data=embeddings)
            group.create_dataset("coords", data=coords)
            group.create_dataset("label", data=label)
            group.attrs["slide_id"] = slide_id
            group.attrs["path"] = slide_name

            print(f"Saved {slide_name} (Slide ID: {slide_id}) to HDF5.")


MODIFIED HDF5 CREATION SCRIPT (for Multi-class OSCC Grading)

In [None]:
import os
import pickle
import h5py
import numpy as np

# Paths: input root with one sub-folder per class, and the HDF5 file to write
pickle_dir = '/path/to/features/features_organized_by_class'
h5_output_file = '/path/to/save/features/oscc_grading_features_resnet50.hdf5'

# Label mapping: each class folder is named after its own integer label
# ("2" = WD, "3" = MD, "4" = PD)
class_labels = {folder: int(folder) for folder in ("2", "3", "4")}

# Extract patch coords from name
def extract_coords_from_patch_name(patch_name):
    """Return the ``(row, col)`` pair that follows the ``tile`` token of a patch name.

    The name has ``.png`` removed and is split on underscores. When there is
    no ``tile`` token, or anything goes wrong while parsing, ``(0, 0)`` is
    returned without reporting an error.
    """
    fallback = (0, 0)
    try:
        tokens = patch_name.replace('.png', '').split('_')
        if 'tile' not in tokens:
            return fallback
        anchor = tokens.index('tile')
        row, col = (int(tok) for tok in (tokens[anchor + 1], tokens[anchor + 2]))
        return row, col
    except Exception:
        return fallback

# Start HDF5 creation: one group per slide with its embeddings, coords and label
with h5py.File(h5_output_file, 'w') as h5_file:
    total_saved = 0
    skipped = []

    # sorted(): os.listdir returns entries in arbitrary order; sorting keeps
    # the processing order and the log reproducible between runs.
    for class_folder in sorted(os.listdir(pickle_dir)):
        class_path = os.path.join(pickle_dir, class_folder)
        # Ignore plain files and folders that have no label mapping
        if not os.path.isdir(class_path) or class_folder not in class_labels:
            continue

        label = class_labels[class_folder]

        for pickle_file in sorted(os.listdir(class_path)):
            if not pickle_file.endswith(".pickle"):
                continue

            slide_name = pickle_file.replace("_resnet50Features_dict.pickle", "")
            slide_path = os.path.join(class_path, pickle_file)
            group = None  # set once this slide's group exists in the HDF5 file

            try:
                # NOTE: pickle.load can execute arbitrary code; only load files
                # produced by this pipeline.
                with open(slide_path, 'rb') as f:
                    patch_features = pickle.load(f)

                if not patch_features:
                    print(f" Skipped (empty features): {slide_name}")
                    skipped.append(slide_name)
                    continue

                embeddings = []
                coords = []

                # Ordered by patch name so embeddings[i] and coords[i] match
                for patch_name, feature in sorted(patch_features.items()):
                    if feature is not None:
                        embeddings.append(feature)
                        coords.append(extract_coords_from_patch_name(patch_name))

                if len(embeddings) == 0:
                    print(f"⚠️ Skipped (no valid patches): {slide_name}")
                    skipped.append(slide_name)
                    continue

                embeddings = np.array(embeddings)
                coords = np.array(coords)

                group = h5_file.create_group(slide_name)
                group.create_dataset("embeddings", data=embeddings)
                group.create_dataset("coords", data=coords)
                group.create_dataset("label", data=label)
                group.attrs["slide_name"] = slide_name
                group.attrs["label"] = label

                print(f" Saved slide {slide_name} with {embeddings.shape[0]} patches.")
                total_saved += 1

            except Exception as e:
                # Roll back a group this iteration created but could not finish,
                # so a slide reported as skipped never leaves an incomplete
                # group in the file. `group` stays None when create_group itself
                # failed (e.g. duplicate name), so an earlier slide's group is
                # never removed.
                if group is not None:
                    del h5_file[slide_name]
                print(f"Error in slide {slide_name}: {e}")
                skipped.append(slide_name)

print(f"\n✅ HDF5 creation complete. Total saved: {total_saved}")
print(f"⚠️ Slides skipped: {len(skipped)}")
if skipped:
    print("   Examples:", skipped[:5])
