In [1]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import cv2
import networkx as nx
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image
import pickle
from torch.utils.data import Dataset, DataLoader
import gc
import psutil
import sys

In [2]:
# Per-process RSS ceiling in bytes (6 GiB) consulted by check_memory().
MEMORY_LIMIT = 6 * 1024**3
# Report physical RAM so the limit above can be judged against the machine.
total_memory = psutil.virtual_memory().total / 1024**3
print(f"Total system memory: {total_memory:.2f}GB")

Total system memory: 15.63GB


In [3]:
def check_memory(limit=None):
    """Return this process's resident set size (RSS) in bytes, raising if it is too high.

    Args:
        limit: Maximum allowed RSS in bytes. When ``None`` (default) the
            module-level ``MEMORY_LIMIT`` is used.

    Returns:
        The current RSS of this process, in bytes.

    Raises:
        MemoryError: if the RSS exceeds ``limit``. The message reports the
            limit actually in force, derived from ``limit``.
    """
    if limit is None:
        limit = MEMORY_LIMIT
    memory_use = psutil.Process(os.getpid()).memory_info().rss
    if memory_use > limit:
        raise MemoryError(
            f"Memory usage ({memory_use / 1024**3:.2f}GB) "
            f"exceeded limit of {limit / 1024**3:.2f}GB"
        )
    return memory_use

def log_memory_usage(tag=""):
    """Print this process's resident set size (RSS) in GiB.

    Args:
        tag: Free-form label embedded in the message to identify the call site.
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_gib = rss_bytes / 1024**3
    print(f"Memory usage {tag}: {rss_gib:.2f}GB")


In [4]:
class ImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label, idx)`` for rows of a file-listing DataFrame.

    Args:
        df: DataFrame with ``'FilePath'`` and ``'Label'`` columns. Rows are
            addressed positionally (``iloc``), so ``idx`` is a 0-based position,
            not the DataFrame index label.
        transform: Optional callable applied to the RGB PIL image
            (e.g. a torchvision ``Compose``).
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load, convert to RGB and transform the image at position ``idx``.

        Returns:
            ``(image, label, idx)`` on success, or ``None`` if the image could
            not be read or transformed. PyTorch's default collate cannot batch
            ``None``; a ``collate_fn`` that drops such samples is required.

        Raises:
            MemoryError: propagated from ``check_memory`` so the memory guard
                is not silently converted into a skipped sample.
        """
        try:
            row = self.df.iloc[idx]
            img_path = row['FilePath']
            label = row['Label']
            # Context manager closes the underlying file handle; convert()
            # returns a fully loaded copy that remains valid afterwards.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)

            check_memory()  # raises MemoryError above the configured limit
            return image, label, idx
        except MemoryError:
            # MemoryError subclasses Exception; re-raise so the guard is effective.
            raise
        except Exception as e:
            print(f"Error loading image at index {idx}: {str(e)}")
            return None

In [5]:
def scan_folder_to_dataframe(base_folder, extensions=None):
    """Recursively list files under ``base_folder`` together with a label for each.

    A file's label is the name of the directory that directly contains it
    (``os.path.basename(root)``), so a layout like ``base/<class>/<img>`` gives
    one row per file labelled ``<class>``.

    Args:
        base_folder: Root directory to scan.
        extensions: Optional extension or iterable of extensions to keep
            (e.g. ``{'.jpg', '.png'}``), compared case-insensitively. When
            ``None`` (default) every file is kept.

    Returns:
        DataFrame with columns ``'FilePath'`` and ``'Label'``, ordered
        deterministically (directories and file names sorted).

    Raises:
        MemoryError: if ``check_memory`` reports usage above the limit.
    """
    print("scan folder to dataframe")
    check_memory()
    if isinstance(extensions, str):
        extensions = [extensions]
    allowed = None if extensions is None else {ext.lower() for ext in extensions}

    data = []
    for root, dirs, files in os.walk(base_folder):
        # os.walk order is filesystem-dependent; sorting (dirs in place steers
        # the walk) makes row order, and therefore a seeded split, reproducible.
        dirs.sort()
        for file in sorted(files):
            if allowed is not None and os.path.splitext(file)[1].lower() not in allowed:
                continue
            data.append((os.path.join(root, file), os.path.basename(root)))

    df = pd.DataFrame(data, columns=['FilePath', 'Label'])
    log_memory_usage("after DataFrame creation")
    return df



In [6]:
def save_features_increment(features_dict, output_file):
    """Merge ``features_dict`` into the pickled dict at ``output_file`` and rewrite it.

    If ``output_file`` exists, its dict is loaded and updated with
    ``features_dict`` (new entries win on key collision); otherwise a new file
    is created. The pickle is written to a sibling ``.tmp`` file and then moved
    over ``output_file`` with ``os.replace``. An interruption mid-write
    therefore leaves the previous file intact instead of a truncated pickle.

    Every call re-reads and re-writes the whole file, so I/O and peak memory
    grow with the number of entries already saved.

    Args:
        features_dict: Mapping to persist. It is not mutated.
        output_file: Path of the pickle file.

    Raises:
        MemoryError: if ``check_memory`` reports usage above the limit.
    """
    print("save_features_increment")
    check_memory()

    if os.path.exists(output_file):
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(output_file, 'rb') as f:
            existing_dict = pickle.load(f)
        existing_dict.update(features_dict)
        features_dict = existing_dict

    tmp_path = os.fspath(output_file) + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(features_dict, f)
        os.replace(tmp_path, output_file)
    except BaseException:
        # Don't leave a half-written temp file behind (also on KeyboardInterrupt).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

In [7]:
def extract_features_batch(model, dataloader, device, output_file, batch_size=32, flush_every=500):
    """Run ``model`` over ``dataloader`` and persist per-sample outputs to ``output_file``.

    The model is put in eval mode and run under ``torch.no_grad()``. Each
    output row is buffered as ``{'label': ..., 'features': np.ndarray}`` under
    the sample index yielded by the dataloader. The buffer is flushed to disk
    via ``save_features_increment`` whenever it reaches ``flush_every`` entries,
    and once more at the end.

    Args:
        model: torch module applied to each image batch.
        dataloader: Iterable of ``(images, labels, indices)`` batches. A batch
            whose ``images`` is ``None`` is skipped.
        device: Device the images are moved to before the forward pass.
        output_file: Pickle path passed to ``save_features_increment``.
        batch_size: Kept for backward compatibility. Progress counting uses
            the real number of samples in each batch instead.
        flush_every: Buffered-entry count that triggers a save (default 500).

    Raises:
        MemoryError: if ``check_memory`` reports usage above the limit.
    """
    print("extract_features_batch")
    model.eval()
    features_dict = {}
    total_processed = 0

    with torch.no_grad():
        for batch_imgs, batch_labels, batch_indices in dataloader:
            check_memory()

            # Skip batches where nothing could be loaded
            if batch_imgs is None:
                continue

            batch_imgs = batch_imgs.to(device)
            features = model(batch_imgs).cpu().numpy()

            for label, feature, original_idx in zip(batch_labels, features, batch_indices):
                # default_collate turns numeric labels into 0-d tensors; store
                # plain Python values instead.
                if isinstance(label, torch.Tensor):
                    label = label.item()
                features_dict[int(original_idx)] = {
                    'label': label,
                    'features': feature,
                }

            # Release GPU memory held by this batch
            del batch_imgs, features
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Count actual samples: the last (or a filtered) batch may be smaller.
            total_processed += len(batch_indices)

            if len(features_dict) >= flush_every:
                save_features_increment(features_dict, output_file)
                features_dict.clear()
                gc.collect()

            log_memory_usage(f"after processing {total_processed} images")

    # Save any remaining features
    if features_dict:
        save_features_increment(features_dict, output_file)


In [8]:
def build_graph_from_features(feature_file_path, graph_file_path, chunk_size=50, similarity_threshold=0.5):
    """Build a cosine-similarity graph from a pickled features dict.

    Every key of the dict becomes a node carrying a ``'label'`` attribute. An
    undirected edge with ``weight`` equal to the cosine similarity is added
    for each distinct pair of nodes whose similarity is strictly greater than
    ``similarity_threshold``. Similarities are computed in
    ``chunk_size x chunk_size`` blocks over the upper block-triangle, so the
    full n x n matrix is never materialised.

    Args:
        feature_file_path: Pickle of ``{node: {'label': ..., 'features': array}}``.
        graph_file_path: Where the final pickled graph is written. A
            ``graph_file_path + '.temp'`` snapshot is written every 10 row-chunks
            and removed at the end.
        chunk_size: Rows per similarity block.
        similarity_threshold: Minimum (exclusive) cosine similarity for an edge.

    Returns:
        The constructed ``networkx.Graph``.
    """
    print("graph building xxx")
    G = nx.Graph()

    # The whole features dict is loaded at once. NOTE: pickle.load can execute
    # arbitrary code; only use trusted feature files.
    with open(feature_file_path, 'rb') as f:
        features_dict = pickle.load(f)

    nodes = list(features_dict.keys())
    G.add_nodes_from((node, {'label': features_dict[node]['label']}) for node in nodes)

    n_nodes = len(nodes)
    n_chunks = (n_nodes + chunk_size - 1) // chunk_size
    print(n_nodes)
    print(n_chunks)

    if n_nodes:
        # Stack and L2-normalise every vector once, instead of rebuilding and
        # re-normalising each chunk for every (i, j) pair.
        unit_features = np.stack([np.asarray(features_dict[node]['features']) for node in nodes])
        norms = np.linalg.norm(unit_features, axis=1, keepdims=True)
        norms[norms == 0] = 1  # zero vectors get similarity 0 instead of NaN
        unit_features = unit_features / norms
    else:
        unit_features = None
    # Vectors now live in unit_features; drop the dict to lower peak memory.
    del features_dict
    gc.collect()

    for i in range(n_chunks):
        start_i = i * chunk_size
        end_i = min(start_i + chunk_size, n_nodes)
        block_i = unit_features[start_i:end_i]

        # Only chunk pairs with j >= i (upper block-triangle) are compared.
        for j in range(i, n_chunks):
            start_j = j * chunk_size
            end_j = min(start_j + chunk_size, n_nodes)

            similarities = block_i @ unit_features[start_j:end_j].T
            mask = similarities > similarity_threshold
            if i == j:
                # Diagonal block: keep strictly-upper entries only (no
                # self-loops, no mirrored duplicates).
                mask = np.triu(mask, k=1)

            rows, cols = np.nonzero(mask)
            G.add_weighted_edges_from(
                (nodes[start_i + r], nodes[start_j + c], float(similarities[r, c]))
                for r, c in zip(rows, cols)
            )
            # No gc.collect() here: arrays are freed by refcounting, and a full
            # collection over a large graph per block pair is very slow.

        # Report and snapshot progress periodically
        if i % 10 == 0:
            print(f"Processed {end_i}/{n_nodes} nodes")
            print(f"Current graph size: Nodes={G.number_of_nodes()}, Edges={G.number_of_edges()}")
            log_memory_usage(f"after chunk {i}")

            with open(graph_file_path + '.temp', 'wb') as f:
                pickle.dump(G, f)

    # Save final graph
    with open(graph_file_path, 'wb') as f:
        pickle.dump(G, f)

    # Remove temporary snapshot if it exists
    if os.path.exists(graph_file_path + '.temp'):
        os.remove(graph_file_path + '.temp')

    return G

In [9]:
def _collate_skip_failed(batch):
    """Collate a DataLoader batch after dropping samples that came back as ``None``.

    ImageDataset.__getitem__ returns ``None`` for unreadable images, which the
    default collate cannot batch. Returns ``(None, None, None)`` when every
    sample failed, so the consumer can skip the batch.
    """
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None, None, None
    return torch.utils.data.dataloader.default_collate(batch)


def main(base_folder, batch_size=16):  # Reduced default batch size
    """End-to-end pipeline: scan images, extract ResNet-50 outputs, build a similarity graph.

    The scanned files are split 80/20 (stratified by label, ``random_state=42``).
    Only the training split is featurised. Features go to ``features.pkl`` and
    the graph to ``graph.pkl`` in the working directory. Graph nodes are
    positional indices into the training split.

    Args:
        base_folder: Root folder whose files are labelled by their parent directory name.
        batch_size: DataLoader batch size.

    Returns:
        The graph built by ``build_graph_from_features``.

    Exits:
        Calls ``sys.exit(1)`` on any exception, after printing the error (and
        the traceback for non-memory errors).
    """
    import traceback  # local: only needed for error reporting

    print("main fn")
    try:
        log_memory_usage("start")

        # Create dataset
        dataset = scan_folder_to_dataframe(base_folder)
        train_set, test_set = train_test_split(dataset, test_size=0.2, stratify=dataset['Label'], random_state=42)

        # Setup data loading (ImageNet normalisation for the pretrained ResNet)
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        train_dataset = ImageDataset(train_set, transform=transform)
        # Custom collate so a single unreadable image doesn't crash the loader
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False,
                                  collate_fn=_collate_skip_failed)

        # Setup model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = resnet50(pretrained=True).to(device)

        # Extract features. save_features_increment merges into an existing
        # file, so remove leftovers from a previous run to avoid stale entries.
        feature_file = "features.pkl"
        if os.path.exists(feature_file):
            os.remove(feature_file)
        extract_features_batch(model, train_loader, device, feature_file, batch_size)

        # Free the model before graph construction
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        # Build graph
        graph_file = "graph.pkl"
        G = build_graph_from_features(feature_file, graph_file)

        log_memory_usage("end")
        return G

    except MemoryError as e:
        print(f"Memory limit exceeded: {str(e)}")
        sys.exit(1)
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()  # keep the stack trace that sys.exit would otherwise hide
        sys.exit(1)

In [10]:
if __name__ == "__main__":
    # Root folder; each image is labelled by the directory that contains it.
    base_folder = "lung_image_sets"
    G = main(base_folder=base_folder)

main fn
Memory usage start: 0.60GB
scan folder to dataframe
Memory usage after DataFrame creation: 0.60GB




extract_features_batch
Memory usage after processing 16 images: 1.17GB
Memory usage after processing 32 images: 1.17GB
Memory usage after processing 48 images: 0.99GB
Memory usage after processing 64 images: 0.99GB
Memory usage after processing 80 images: 0.99GB
Memory usage after processing 96 images: 0.99GB
Memory usage after processing 112 images: 0.99GB
Memory usage after processing 128 images: 0.99GB
Memory usage after processing 144 images: 0.99GB
Memory usage after processing 160 images: 0.99GB
Memory usage after processing 176 images: 0.99GB
Memory usage after processing 192 images: 0.99GB
Memory usage after processing 208 images: 0.99GB
Memory usage after processing 224 images: 0.99GB
Memory usage after processing 240 images: 0.99GB
Memory usage after processing 256 images: 0.99GB
Memory usage after processing 272 images: 0.99GB
Memory usage after processing 288 images: 0.99GB
Memory usage after processing 304 images: 0.99GB
Memory usage after processing 320 images: 0.99GB
Mem

KeyboardInterrupt: 