In [None]:
import re
import os
import math
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.ticker import MaxNLocator

# ---------------------------------------------------------------------------
# Global plot configuration
# ---------------------------------------------------------------------------

# Datasets plotted individually, in display order.
DATASETS = ['PEMS08', 'ECL', 'Weather', 'ETTm1']

# Palette cycled over the sorted model names to give each model a color.
COLOR_LIST = [
    '#4C72B0', '#DD6B6B', '#55A868', '#C44E52',
    '#8172B2', '#CCB974', '#64B5CD',
]

# --- Single-dataset figures -------------------------------------------------
FIG_SIZE = (12, 8)          # figure size in inches
BUBBLE_SCALE = 12000        # scatter marker area per GB of memory
LABEL_PADDING = 8           # extra gap (points) between bubble edge and label
NAME_FONT = 12              # model-name label size
MEM_FONT = 11               # memory label size
LINE_GAP_POINTS = 4         # vertical gap (points) between the two label lines
GRID_STYLE = dict(alpha=0.3, linestyle='--', linewidth=0.7, color='gray')
SPINE_COLOR = '#E0E0E0'
SPINE_WIDTH = 1.2
LEGEND_FONT = 11
AXIS_LABEL_SIZE = 13
TITLE_SIZE = 16

# --- 2x2 composite figure ---------------------------------------------------
COMPOSITE_WIDTH_MM = 178    # total figure width in millimetres
COMPOSITE_DPI = 300

# (dataset, (row, col)) position of each panel in the 2x2 grid.
dataset_order = [
    ('PEMS08', (0, 0)),
    ('ECL', (0, 1)),
    ('Weather', (1, 0)),
    ('ETTm1', (1, 1)),
]

# Composite-specific constants (scaled down to match original proportions + ~1px)
C_NAME_FONT = 4
C_MEM_FONT = 3.5
C_AXIS_LABEL_SIZE = 5
C_TITLE_SIZE = 6
C_LEGEND_FONT = 5
C_BUBBLE_SCALE = 1500
C_LABEL_PADDING = 1
C_LINE_GAP_POINTS = 1
C_SPINE_WIDTH = 0.5
C_GRID_STYLE = dict(alpha=0.3, linestyle='--', linewidth=0.4, color='gray')


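## 1. Helper Functions
Define `plot_efficiency`. It reads `efficiency_results_<dataset>.txt` and parses each model's MSE, training time, and average allocated GPU memory. It then draws a training-time vs. MSE bubble chart in which bubble area scales with memory, optionally labelling each bubble while avoiding label overlaps.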

In [None]:
def plot_efficiency(dataset_name, with_labels=True):
    """Draw a training-time vs. MSE bubble chart for one dataset.

    Reads ``efficiency_results_<dataset_name>.txt`` from the working
    directory. Each model block supplies MSE, training time, and average
    allocated GPU memory; the function plots one bubble per model, with
    bubble area proportional to memory in GB. It prints a notice and returns
    early if the file is missing, cannot be read, or has no complete model
    blocks.

    Args:
        dataset_name: Dataset label; used to build the file name and as the
            plot title.
        with_labels: If True, annotate each bubble with its model name and
            memory usage. Labels are placed so they stay inside the axes and
            avoid overlapping previously placed labels.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    file_path = f'efficiency_results_{dataset_name}.txt'

    if not os.path.exists(file_path):
        print(f"Notice: file '{file_path}' not found, skipping this dataset.")
        return

    try:
        with open(file_path, 'r') as f:
            content = f.read()
    except Exception as e:
        print(f"Error reading file: {e}")
        return

    # Parse each model block for key metrics. Numbers may carry an exponent
    # (e.g. "1.5e-05"); a bare [\d.]+ pattern would silently truncate those.
    number = r'(\d*\.?\d+(?:[eE][-+]?\d+)?)'
    models_data = []
    blocks = content.split('================================================================================')

    for block in blocks:
        if 'Model:' in block and 'Results:' in block:
            model_match = re.search(r'Model: (\w+)', block)
            mse_match = re.search(r'- MSE: ' + number, block)
            time_match = re.search(r'- Training Time: ' + number, block)
            memory_match = re.search(r'- Avg Allocated GPU Memory: ' + number, block)

            # Blocks missing any metric are skipped rather than half-filled.
            if all([model_match, mse_match, time_match, memory_match]):
                memory_mb = float(memory_match.group(1))
                models_data.append({
                    'name': model_match.group(1),
                    'mse': float(mse_match.group(1)),
                    'training_time': float(time_match.group(1)),
                    'memory_gb': memory_mb / 1024  # convert MB to GB
                })

    if not models_data:
        print(f"No model data parsed in {dataset_name}.")
        return

    print(f"Plotting {dataset_name} with {len(models_data)} models...")

    # Prepare ranges
    times = [m['training_time'] for m in models_data]
    mses = [m['mse'] for m in models_data]
    min_time, max_time = min(times), max(times)
    min_mse, max_mse = min(mses), max(mses)
    # Fall back to 1 so the x margins are non-zero when all times coincide.
    time_range = max_time - min_time if max_time != min_time else 1
    # y-limits, shared by set_ylim and the label-placement heuristic below.
    y_low, y_high = min_mse * 0.95, max_mse * 1.05

    # Build plot
    fig, ax = plt.subplots(figsize=FIG_SIZE)
    fig.patch.set_facecolor('white')
    ax.set_facecolor('white')

    x_margin = time_range * 0.1
    x_start = max(0, min_time - x_margin)
    x_end = max_time + x_margin * 2  # extra room on the right for labels

    ax.set_xlim(x_start, x_end)
    ax.set_ylim(y_low, y_high)

    ax.set_xlabel('Training Time (ms/iter)', fontsize=AXIS_LABEL_SIZE, fontweight='bold')
    ax.set_ylabel('MSE', fontsize=AXIS_LABEL_SIZE, fontweight='bold')
    ax.set_title(f'{dataset_name}', fontsize=TITLE_SIZE, fontweight='bold', pad=15)

    ax.grid(**GRID_STYLE)

    for spine in ax.spines.values():
        spine.set_color(SPINE_COLOR)
        spine.set_linewidth(SPINE_WIDTH)

    # Assign colors per model. Names are deduplicated so repeated entries for
    # one model share a single color and a single legend patch.
    sorted_names = sorted({m['name'] for m in models_data})
    color_map = {name: COLOR_LIST[i % len(COLOR_LIST)]
                 for i, name in enumerate(sorted_names)}

    # Draw larger bubbles first so smaller ones are painted on top of them.
    models_data.sort(key=lambda x: x['memory_gb'], reverse=True)

    # Scatter bubbles
    for model in models_data:
        size = model['memory_gb'] * BUBBLE_SCALE
        c = color_map[model['name']]

        ax.scatter(
            model['training_time'],
            model['mse'],
            s=size,
            c=c,
            alpha=0.3,
            edgecolors=c,
            linewidth=0.8
        )

    if with_labels:
        # Dynamic label placement helpers. A draw is forced so that the axes
        # extent and data->pixel transform reflect the current layout.
        fig.canvas.draw()
        axes_bbox = ax.get_window_extent()
        px_per_point = fig.dpi / 72.0
        placed_boxes = []

        def points_to_pixels(value):
            return value * px_per_point

        def compute_bbox(base_px, dx_pts, dy_pts, ha, va, width_px, height_px):
            """Pixel box (x0, y0, x1, y1) of a label anchored at base_px + (dx, dy) points."""
            anchor_x = base_px[0] + points_to_pixels(dx_pts)
            anchor_y = base_px[1] + points_to_pixels(dy_pts)

            if ha == 'left':
                x0 = anchor_x
                x1 = anchor_x + width_px
            elif ha == 'right':
                x0 = anchor_x - width_px
                x1 = anchor_x
            else:  # center
                x0 = anchor_x - width_px / 2
                x1 = anchor_x + width_px / 2

            if va == 'bottom':
                y0 = anchor_y
                y1 = anchor_y + height_px
            elif va == 'top':
                y0 = anchor_y - height_px
                y1 = anchor_y
            else:  # center
                y0 = anchor_y - height_px / 2
                y1 = anchor_y + height_px / 2

            return (x0, y0, x1, y1)

        def clamp_to_axes(base_px, dx_pts, dy_pts, ha, va, width_px, height_px):
            """Nudge the offset (a bounded number of passes) until the box fits the axes."""
            for _ in range(6):
                bbox = compute_bbox(base_px, dx_pts, dy_pts, ha, va, width_px, height_px)
                adjusted = False
                if bbox[0] < axes_bbox.x0:
                    dx_pts += (axes_bbox.x0 - bbox[0]) / px_per_point
                    adjusted = True
                if bbox[2] > axes_bbox.x1:
                    dx_pts -= (bbox[2] - axes_bbox.x1) / px_per_point
                    adjusted = True
                if bbox[1] < axes_bbox.y0:
                    dy_pts += (axes_bbox.y0 - bbox[1]) / px_per_point
                    adjusted = True
                if bbox[3] > axes_bbox.y1:
                    dy_pts -= (bbox[3] - axes_bbox.y1) / px_per_point
                    adjusted = True
                if not adjusted:
                    break
            final_bbox = compute_bbox(base_px, dx_pts, dy_pts, ha, va, width_px, height_px)
            return dx_pts, dy_pts, final_bbox

        def bbox_overlap_area(b1, b2):
            x0 = max(b1[0], b2[0])
            y0 = max(b1[1], b2[1])
            x1 = min(b1[2], b2[2])
            y1 = min(b1[3], b2[3])
            if x1 <= x0 or y1 <= y0:
                return 0
            return (x1 - x0) * (y1 - y0)

        def inside_axes(bbox):
            return (
                bbox[0] >= axes_bbox.x0 and
                bbox[2] <= axes_bbox.x1 and
                bbox[1] >= axes_bbox.y0 and
                bbox[3] <= axes_bbox.y1
            )

        def expand_bbox(bbox, padding_px):
            return (
                bbox[0] - padding_px,
                bbox[1] - padding_px,
                bbox[2] + padding_px,
                bbox[3] + padding_px
            )

        def build_candidate_sequence(x_rel, y_rel, offset_dist):
            """Ordered label placements, preferring directions away from nearby axes edges."""
            directions = []

            def push(direction):
                if direction not in directions:
                    directions.append(direction)

            if x_rel <= 0.35:
                push('right_up')
                push('right_down')
            if x_rel >= 0.65:
                push('left_up')
                push('left_down')
            if y_rel <= 0.25:
                push('right_up')
                push('left_up')
            if y_rel >= 0.75:
                push('right_down')
                push('left_down')

            for base_dir in ['right_up', 'left_up', 'right_down', 'left_down', 'up', 'down']:
                push(base_dir)

            diag = offset_dist
            vertical = offset_dist + 10

            mapping = {
                'right_up':  {'dx': diag,  'dy': diag,   'ha': 'left',  'va': 'bottom'},
                'right_down':{'dx': diag,  'dy': -diag,  'ha': 'left',  'va': 'top'},
                'left_up':   {'dx': -diag, 'dy': diag,   'ha': 'right', 'va': 'bottom'},
                'left_down': {'dx': -diag, 'dy': -diag,  'ha': 'right', 'va': 'top'},
                'up':        {'dx': 0,     'dy': vertical,    'ha': 'center','va': 'bottom'},
                'down':      {'dx': 0,     'dy': -vertical,   'ha': 'center','va': 'top'}
            }

            return [mapping[d] for d in directions]

        for model in models_data:
            c = color_map[model['name']]
            size = model['memory_gb'] * BUBBLE_SCALE
            # Scatter `s` is an area in points^2, so this is the bubble radius in points.
            radius_pts = math.sqrt(size / math.pi)
            base_point = (model['training_time'], model['mse'])
            base_px = ax.transData.transform(base_point)

            mem_text = f"{model['memory_gb']:.3f}GB"
            # Rough text-extent estimate: ~0.55 em per character plus slack.
            block_width_pts = max(len(model['name']) * NAME_FONT * 0.55,
                                  len(mem_text) * MEM_FONT * 0.55) + 8
            block_height_pts = NAME_FONT + MEM_FONT + LINE_GAP_POINTS
            block_width_px = points_to_pixels(block_width_pts)
            block_height_px = points_to_pixels(block_height_pts)

            x_rel = (model['training_time'] - x_start) / (x_end - x_start)
            y_rel = (model['mse'] - y_low) / ((y_high - y_low) or 1)

            offset_dist = radius_pts + LABEL_PADDING
            candidates = build_candidate_sequence(x_rel, y_rel, offset_dist)

            best_choice = None
            best_penalty = float('inf')

            for idx, cand in enumerate(candidates):
                dx_adj, dy_adj, bbox = clamp_to_axes(
                    base_px,
                    cand['dx'],
                    cand['dy'],
                    cand['ha'],
                    cand['va'],
                    block_width_px,
                    block_height_px
                )

                overlap_area = sum(bbox_overlap_area(bbox, placed) for placed in placed_boxes)
                if not inside_axes(bbox):
                    continue

                penalty = overlap_area + idx * 1e-3  # slight preference for earlier candidates
                if penalty < best_penalty:
                    best_penalty = penalty
                    best_choice = {
                        'dx': dx_adj,
                        'dy': dy_adj,
                        'ha': cand['ha'],
                        'va': cand['va'],
                        'bbox': bbox
                    }

                if overlap_area == 0:
                    break

            # No candidate fit inside the axes: use the first one, clamped.
            if best_choice is None:
                cand = candidates[0]
                dx_adj, dy_adj, bbox = clamp_to_axes(
                    base_px,
                    cand['dx'],
                    cand['dy'],
                    cand['ha'],
                    cand['va'],
                    block_width_px,
                    block_height_px
                )
                best_choice = {
                    'dx': dx_adj,
                    'dy': dy_adj,
                    'ha': cand['ha'],
                    'va': cand['va'],
                    'bbox': bbox
                }

            # Reserve the chosen box (plus 3pt padding) for later labels.
            placed_boxes.append(expand_bbox(best_choice['bbox'], points_to_pixels(3)))

            # Stack the two lines: name above memory. The anchor is the top of
            # the name ('top') or the bottom of the memory line (otherwise).
            if best_choice['va'] == 'top':
                name_offset = best_choice['dy']
                mem_offset = best_choice['dy'] - (NAME_FONT + LINE_GAP_POINTS)
            else:  # treat center as bottom
                mem_offset = best_choice['dy']
                name_offset = best_choice['dy'] + MEM_FONT + LINE_GAP_POINTS

            ax.annotate(
                model['name'],
                base_point,
                xytext=(best_choice['dx'], name_offset),
                textcoords='offset points',
                ha=best_choice['ha'],
                va=best_choice['va'],
                fontsize=NAME_FONT,
                fontweight='bold',
                color=c
            )

            ax.annotate(
                mem_text,
                base_point,
                xytext=(best_choice['dx'], mem_offset),
                textcoords='offset points',
                ha=best_choice['ha'],
                va=best_choice['va'],
                fontsize=MEM_FONT,
                fontweight='normal',
                color='#606060'
            )

    legend_elements = [
        mpatches.Patch(facecolor=color_map[name],
            edgecolor=color_map[name],
            linewidth=1,
            label=name,
            alpha=0.3)
        for name in sorted_names
    ]
    ax.legend(handles=legend_elements, loc='best', fontsize=LEGEND_FONT)

    plt.tight_layout()
    plt.show()

## 2. Generate Plots
The first cell below renders the annotated efficiency plots (model name and memory labels). The second repeats the same plots without on-chart text, giving a clean bubble view with only the legend.

In [None]:
# Annotated plots (model name + memory labels), one figure per dataset
for ds_name in DATASETS:
    plot_efficiency(ds_name, with_labels=True)

In [None]:
# Same figures without on-chart text: bubbles and legend only
for ds_name in DATASETS:
    plot_efficiency(ds_name, with_labels=False)

In [None]:
# Composite 2x2 figure (178mm wide, 300 ppi) with shared legend below

def load_dataset_data(dataset_name):
    """Parse ``efficiency_results_<dataset_name>.txt`` into per-model metric dicts.

    The file is split on a line of '=' characters. Each block containing both
    'Model:' and 'Results:' lines and all three metrics (MSE, training time,
    average allocated GPU memory) becomes one dict with keys ``name``,
    ``mse``, ``training_time`` and ``memory_gb``; memory is converted from MB
    to GB by dividing by 1024. Blocks missing any metric are skipped.

    Args:
        dataset_name: Dataset label used to build the file name.

    Returns:
        A list of model dicts. The list is empty (after printing a notice) if
        the file is missing or cannot be read.
    """
    file_path = f'efficiency_results_{dataset_name}.txt'
    if not os.path.exists(file_path):
        print(f"Notice: file '{file_path}' not found, skipping {dataset_name}.")
        return []
    try:
        with open(file_path, 'r') as f:
            content = f.read()
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []

    # Unsigned decimal with optional exponent; a bare [\d.]+ would turn
    # "1.5e-05" into 1.5 without any error.
    number = r'(\d*\.?\d+(?:[eE][-+]?\d+)?)'

    models_data = []
    blocks = content.split('================================================================================')
    for block in blocks:
        if 'Model:' in block and 'Results:' in block:
            model_match = re.search(r'Model: (\w+)', block)
            mse_match = re.search(r'- MSE: ' + number, block)
            time_match = re.search(r'- Training Time: ' + number, block)
            memory_match = re.search(r'- Avg Allocated GPU Memory: ' + number, block)
            if all([model_match, mse_match, time_match, memory_match]):
                memory_mb = float(memory_match.group(1))
                models_data.append({
                    'name': model_match.group(1),
                    'mse': float(mse_match.group(1)),
                    'training_time': float(time_match.group(1)),
                    'memory_gb': memory_mb / 1024
                })
    return models_data

def draw_subplot(ax, dataset_name, models_data, color_map, with_labels=True):
    """Render one composite-figure panel: training time vs. MSE bubbles.

    Uses the compact ``C_*`` style constants. Bubble area is proportional to
    each model's memory in GB. When ``with_labels`` is True, each bubble gets
    a two-line label (model name, memory). Labels are kept inside the axes
    and placed to avoid labels drawn earlier in this panel. An empty
    ``models_data`` hides the axes instead.

    Args:
        ax: Matplotlib axes to draw into.
        dataset_name: Panel title. It also selects a hand-tuned horizontal
            label shift for ``MODE`` on ECL and Weather.
        models_data: List of dicts with ``name``, ``mse``, ``training_time``
            and ``memory_gb`` keys, as produced by ``load_dataset_data``.
            The caller's list is not modified.
        color_map: Mapping from model name to color; unknown names are
            drawn grey.
        with_labels: Whether to draw the text labels.

    Note:
        Label placement works in pixel coordinates taken from the axes'
        current extent. The figure layout should therefore be final before
        this is called.
    """
    if not models_data:
        ax.set_visible(False)
        return

    times = [m['training_time'] for m in models_data]
    mses = [m['mse'] for m in models_data]
    min_time, max_time = min(times), max(times)
    min_mse, max_mse = min(mses), max(mses)
    # Fall back to 1 so the x margins are non-zero when all times coincide.
    time_range = max_time - min_time if max_time != min_time else 1
    x_margin = time_range * 0.1
    x_start = max(0, min_time - x_margin)
    x_end = max_time + x_margin * 2  # extra room on the right for labels
    y_low, y_high = min_mse * 0.95, max_mse * 1.05
    ax.set_xlim(x_start, x_end)
    ax.set_ylim(y_low, y_high)
    ax.set_xlabel('Training Time (ms/iter)', fontsize=C_AXIS_LABEL_SIZE, fontweight='bold')
    ax.set_ylabel('MSE', fontsize=C_AXIS_LABEL_SIZE, fontweight='bold')
    ax.set_title(dataset_name, fontsize=C_TITLE_SIZE, fontweight='bold', pad=4)
    ax.grid(**C_GRID_STYLE)

    # Consistent grid lines
    ax.xaxis.set_major_locator(MaxNLocator(nbins=5))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=4))

    for spine in ax.spines.values():
        spine.set_color(SPINE_COLOR)
        spine.set_linewidth(C_SPINE_WIDTH)
    ax.tick_params(axis='both', which='major', labelsize=C_AXIS_LABEL_SIZE-1, width=C_SPINE_WIDTH)

    # Draw larger bubbles first so smaller ones stay visible on top. Sort a
    # copy so the caller's list (shared across figures) is not reordered.
    models_data = sorted(models_data, key=lambda x: x['memory_gb'], reverse=True)
    for model in models_data:
        size = model['memory_gb'] * C_BUBBLE_SCALE
        c = color_map.get(model['name'], '#999999')
        ax.scatter(model['training_time'], model['mse'], s=size, c=c, alpha=0.3, edgecolors=c, linewidth=0.6)

    if with_labels:
        # Force a draw so the axes extent and data->pixel transform are current.
        fig = ax.get_figure()
        fig.canvas.draw()
        axes_bbox = ax.get_window_extent()
        px_per_point = fig.dpi / 72.0
        placed_boxes = []

        def points_to_pixels(value):
            return value * px_per_point

        def compute_bbox(base_px, dx_pts, dy_pts, ha, va, width_px, height_px):
            """Pixel box (x0, y0, x1, y1) of a label anchored at base_px + (dx, dy) points."""
            anchor_x = base_px[0] + points_to_pixels(dx_pts)
            anchor_y = base_px[1] + points_to_pixels(dy_pts)
            if ha == 'left':
                x0, x1 = anchor_x, anchor_x + width_px
            elif ha == 'right':
                x0, x1 = anchor_x - width_px, anchor_x
            else:
                x0, x1 = anchor_x - width_px / 2, anchor_x + width_px / 2
            if va == 'bottom':
                y0, y1 = anchor_y, anchor_y + height_px
            elif va == 'top':
                y0, y1 = anchor_y - height_px, anchor_y
            else:
                y0, y1 = anchor_y - height_px / 2, anchor_y + height_px / 2
            return (x0, y0, x1, y1)

        def clamp_to_axes(base_px, dx_pts, dy_pts, ha, va, width_px, height_px):
            """Nudge the offset (a bounded number of passes) until the box fits the axes."""
            for _ in range(10):
                bbox = compute_bbox(base_px, dx_pts, dy_pts, ha, va, width_px, height_px)
                adjusted = False
                if bbox[0] < axes_bbox.x0:
                    dx_pts += (axes_bbox.x0 - bbox[0]) / px_per_point
                    adjusted = True
                if bbox[2] > axes_bbox.x1:
                    dx_pts -= (bbox[2] - axes_bbox.x1) / px_per_point
                    adjusted = True
                if bbox[1] < axes_bbox.y0:
                    dy_pts += (axes_bbox.y0 - bbox[1]) / px_per_point
                    adjusted = True
                if bbox[3] > axes_bbox.y1:
                    dy_pts -= (bbox[3] - axes_bbox.y1) / px_per_point
                    adjusted = True
                if not adjusted:
                    break
            return dx_pts, dy_pts, compute_bbox(base_px, dx_pts, dy_pts, ha, va, width_px, height_px)

        def bbox_overlap_area(b1, b2):
            x0 = max(b1[0], b2[0])
            y0 = max(b1[1], b2[1])
            x1 = min(b1[2], b2[2])
            y1 = min(b1[3], b2[3])
            if x1 <= x0 or y1 <= y0:
                return 0
            return (x1 - x0) * (y1 - y0)

        def inside_axes(bbox):
            return (bbox[0] >= axes_bbox.x0 and bbox[2] <= axes_bbox.x1 and bbox[1] >= axes_bbox.y0 and bbox[3] <= axes_bbox.y1)

        def expand_bbox(bbox, padding_px):
            return (bbox[0] - padding_px, bbox[1] - padding_px, bbox[2] + padding_px, bbox[3] + padding_px)

        def build_candidate_sequence(x_rel, y_rel, offset_dist):
            """Ordered label placements, preferring directions away from nearby axes edges."""
            directions = []
            def push(direction):
                if direction not in directions:
                    directions.append(direction)
            if x_rel <= 0.35:
                push('right_up')
                push('right_down')
            if x_rel >= 0.65:
                push('left_up')
                push('left_down')
            if y_rel <= 0.25:
                push('right_up')
                push('left_up')
            if y_rel >= 0.75:
                push('right_down')
                push('left_down')
            for base_dir in ['right_up', 'left_up', 'right_down', 'left_down', 'up', 'down']:
                push(base_dir)
            diag = offset_dist
            vertical = offset_dist + 10
            mapping = {
                'right_up':  {'dx': diag,  'dy': diag,   'ha': 'left',  'va': 'bottom'},
                'right_down':{'dx': diag,  'dy': -diag,  'ha': 'left',  'va': 'top'},
                'left_up':   {'dx': -diag, 'dy': diag,   'ha': 'right', 'va': 'bottom'},
                'left_down': {'dx': -diag, 'dy': -diag,  'ha': 'right', 'va': 'top'},
                'up':        {'dx': 0,     'dy': vertical,    'ha': 'center','va': 'bottom'},
                'down':      {'dx': 0,     'dy': -vertical,   'ha': 'center','va': 'top'}
            }
            return [mapping[d] for d in directions]

        for model in models_data:
            c = color_map.get(model['name'], '#999999')
            size = model['memory_gb'] * C_BUBBLE_SCALE
            # Scatter `s` is an area in points^2, so this is the bubble radius in points.
            radius_pts = math.sqrt(size / math.pi)
            base_point = (model['training_time'], model['mse'])
            base_px = ax.transData.transform(base_point)
            mem_text = f"{model['memory_gb']:.3f}GB"
            # Rough text-extent estimate: ~0.55 em per character plus slack.
            block_width_pts = max(len(model['name']) * C_NAME_FONT * 0.55, len(mem_text) * C_MEM_FONT * 0.55) + 8
            block_height_pts = C_NAME_FONT + C_MEM_FONT + C_LINE_GAP_POINTS
            block_width_px = points_to_pixels(block_width_pts)
            block_height_px = points_to_pixels(block_height_pts)
            x_rel = (model['training_time'] - x_start) / (x_end - x_start)
            y_rel = (model['mse'] - y_low) / ((y_high - y_low) or 1)

            offset_dist = radius_pts + C_LABEL_PADDING
            # Hand-tuned: these models' labels sit closer to their bubbles.
            if model['name'] in ['BiMamba4TS', 'PatchTST']:
                offset_dist = radius_pts + 0.1

            candidates = build_candidate_sequence(x_rel, y_rel, offset_dist)
            best_choice, best_penalty = None, float('inf')
            for idx, cand in enumerate(candidates):
                dx_adj, dy_adj, bbox = clamp_to_axes(base_px, cand['dx'], cand['dy'], cand['ha'], cand['va'], block_width_px, block_height_px)
                overlap_area = sum(bbox_overlap_area(bbox, placed) for placed in placed_boxes)
                if not inside_axes(bbox):
                    continue
                penalty = overlap_area + idx * 1e-3  # slight preference for earlier candidates
                if penalty < best_penalty:
                    best_penalty = penalty
                    best_choice = {'dx': dx_adj, 'dy': dy_adj, 'ha': cand['ha'], 'va': cand['va'], 'bbox': bbox}
                if overlap_area == 0:
                    break
            # No candidate fit inside the axes: use the first one, clamped.
            if best_choice is None:
                cand = candidates[0]
                dx_adj, dy_adj, bbox = clamp_to_axes(base_px, cand['dx'], cand['dy'], cand['ha'], cand['va'], block_width_px, block_height_px)
                best_choice = {'dx': dx_adj, 'dy': dy_adj, 'ha': cand['ha'], 'va': cand['va'], 'bbox': bbox}

            # Hand-tuned: nudge MODE's label right on these panels. The shifted
            # box is not re-clamped to the axes.
            if dataset_name in ['ECL', 'Weather'] and model['name'] == 'MODE':
                shift_pts = 6
                best_choice['dx'] += shift_pts
                shift_px = points_to_pixels(shift_pts)
                old_bbox = best_choice['bbox']
                best_choice['bbox'] = (old_bbox[0] + shift_px, old_bbox[1], old_bbox[2] + shift_px, old_bbox[3])

            # Reserve the chosen box (plus 3pt padding) for later labels.
            placed_boxes.append(expand_bbox(best_choice['bbox'], points_to_pixels(3)))
            # Stack the two lines: name above memory. The anchor is the top of
            # the name ('top') or the bottom of the memory line (otherwise).
            if best_choice['va'] == 'top':
                name_offset = best_choice['dy']
                mem_offset = best_choice['dy'] - (C_NAME_FONT + C_LINE_GAP_POINTS)
            else:
                mem_offset = best_choice['dy']
                # Step over the memory line's height (C_MEM_FONT). This keeps
                # the drawn text within block_height_pts used for collisions.
                name_offset = best_choice['dy'] + C_MEM_FONT + C_LINE_GAP_POINTS
            ax.annotate(model['name'], base_point, xytext=(best_choice['dx'], name_offset), textcoords='offset points', ha=best_choice['ha'], va=best_choice['va'], fontsize=C_NAME_FONT, fontweight='bold', color=c)
            ax.annotate(mem_text, base_point, xytext=(best_choice['dx'], mem_offset), textcoords='offset points', ha=best_choice['ha'], va=best_choice['va'], fontsize=C_MEM_FONT, fontweight='normal', color='#606060')

# Load every dataset once; the union of model names drives a shared palette.
dataset_data = {name: load_dataset_data(name) for name, _ in dataset_order}
all_models = sorted({m['name'] for data in dataset_data.values() for m in data})
if not all_models:
    print('No model data found; composite figure not created.')
else:
    # One color per model across all panels so the single shared legend is valid.
    color_map = {name: COLOR_LIST[i % len(COLOR_LIST)] for i, name in enumerate(all_models)}
    fig_width_in = COMPOSITE_WIDTH_MM / 25.4  # millimetres -> inches
    base_height_in = fig_width_in / 1.5
    legend_pad_in = 0.25  # extra height reserved for the legend row
    fig_height_in = base_height_in + legend_pad_in
    fig, axes = plt.subplots(2, 2, figsize=(fig_width_in, fig_height_in), dpi=COMPOSITE_DPI)
    # Fix the final subplot geometry BEFORE drawing the panels. draw_subplot
    # measures each axes' pixel extent to clamp and de-overlap labels, so
    # adjusting the layout afterwards would invalidate those measurements.
    fig.subplots_adjust(left=0.05, right=0.98, top=0.95, bottom=0.12, hspace=0.25, wspace=0.15)
    for name, (r, c) in dataset_order:
        draw_subplot(axes[r][c], name, dataset_data.get(name, []), color_map, with_labels=True)
    legend_elements = [mpatches.Patch(facecolor=color_map[name], edgecolor=color_map[name], linewidth=1, label=name, alpha=0.3) for name in all_models]
    fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 0.003), ncol=min(len(legend_elements), 6), fontsize=C_LEGEND_FONT, frameon=False)
    plt.show()

In [None]:
# Custom plot for Weather and ETTm1 with Pred Len 96, excluding BiMamba4TS

# (dataset name, prediction length) pairs to plot.
target_datasets = [
    ('Weather', 96),
    ('ETTm1', 96),
]

# Models dropped when loading data for the custom plot.
excluded_models = ['BiMamba4TS']

# Custom palette: the first three were specified explicitly, the rest are extras.
CUSTOM_COLORS = [
    '#544AFF',  # purple
    '#00AAEE',  # blue
    '#3FB704',  # green
    '#DD6B6B',
    "#A85DB1",
    '#CCB974',
    '#64B5CD',
]

# NOTE(review): presumably the order in which models receive CUSTOM_COLORS;
# its consumer is outside the visible code — confirm.
MODEL_PRIORITY = [
    'MODE',
    'S_Mamba',
    'iTransformer',
    'PatchTST',
    'DLinear',
    'Flowformer',
]

def load_and_filter_data(dataset_name, excluded=None):
    """Parse ``efficiency_results_<dataset_name>.txt``, dropping excluded models.

    The file is split on a line of '=' characters. Each block containing both
    'Model:' and 'Results:' lines and all three metrics becomes one dict with
    keys ``name``, ``mse``, ``training_time`` and ``memory_gb`` (memory MB /
    1024), unless its model name is in the exclusion list.

    Args:
        dataset_name: Dataset label used to build the file name.
        excluded: Iterable of model names to skip. Defaults to the
            module-level ``excluded_models`` list.

    Returns:
        A list of model dicts. The list is empty (after printing a notice) if
        the file is missing or cannot be read.
    """
    if excluded is None:
        excluded = excluded_models
    file_path = f'efficiency_results_{dataset_name}.txt'
    if not os.path.exists(file_path):
        print(f"Notice: file '{file_path}' not found, skipping {dataset_name}.")
        return []
    try:
        with open(file_path, 'r') as f:
            content = f.read()
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []

    # Unsigned decimal with optional exponent; a bare [\d.]+ would turn
    # "1.5e-05" into 1.5 without any error.
    number = r'(\d*\.?\d+(?:[eE][-+]?\d+)?)'

    models_data = []
    blocks = content.split('================================================================================')
    for block in blocks:
        if 'Model:' in block and 'Results:' in block:
            model_match = re.search(r'Model: (\w+)', block)
            mse_match = re.search(r'- MSE: ' + number, block)
            time_match = re.search(r'- Training Time: ' + number, block)
            memory_match = re.search(r'- Avg Allocated GPU Memory: ' + number, block)

            if all([model_match, mse_match, time_match, memory_match]):
                model_name = model_match.group(1)
                if model_name in excluded:
                    continue
                memory_mb = float(memory_match.group(1))
                models_data.append({
                    'name': model_name,
                    'mse': float(mse_match.group(1)),
                    'training_time': float(time_match.group(1)),
                    'memory_gb': memory_mb / 1024
                })
    return models_data

def draw_custom_subplot(ax, dataset_name, pred_len, models_data, color_map):
    """Draw one efficiency bubble chart (training time vs. MSE) onto ``ax``.

    Each model is drawn as a bubble at (training_time, mse). The bubble area
    scales with ``memory_gb * C_BUBBLE_SCALE``. Each bubble is labelled with
    the model name and its memory in GB. Labels are placed by a greedy search
    over candidate offsets around each bubble. The search prefers positions
    that stay inside the axes and do not overlap labels already placed.

    Args:
        ax: matplotlib Axes to draw on; hidden when ``models_data`` is empty.
        dataset_name: Dataset name. Used in the title and for a hand-tuned
            label nudge on 'ECL' / 'Weather'.
        pred_len: Prediction length, shown in the title only.
        models_data: List of dicts with keys 'name', 'mse', 'training_time'
            and 'memory_gb' (as produced by ``load_and_filter_data``). The
            list is NOT modified; a sorted copy is used internally.
        color_map: Mapping model name -> colour; unknown models use '#999999'.

    Note:
        Calls ``fig.canvas.draw()`` so that the axes extent and the
        data->pixel transform are final before labels are laid out.
    """
    if not models_data:
        ax.set_visible(False)
        return

    times = [m['training_time'] for m in models_data]
    mses = [m['mse'] for m in models_data]
    min_time, max_time = min(times), max(times)
    min_mse, max_mse = min(mses), max(mses)
    
    # Avoid a zero-width x range when every model has the same training time.
    time_range = max_time - min_time if max_time != min_time else 1
    x_margin = time_range * 0.1
    x_start = max(0, min_time - x_margin)
    # Extra right margin leaves room for labels placed right of the last bubble.
    x_end = max_time + x_margin * 2
    
    ax.set_xlim(x_start, x_end)
    ax.set_ylim(min_mse * 0.95, max_mse * 1.05)
    
    ax.set_xlabel('Training Time (ms/iter)', fontsize=C_AXIS_LABEL_SIZE, fontweight='bold')
    ax.set_ylabel('MSE', fontsize=C_AXIS_LABEL_SIZE, fontweight='bold')
    # Increased padding for title
    ax.set_title(f"{dataset_name} (Pred Len: {pred_len})", fontsize=C_TITLE_SIZE, fontweight='bold', pad=12)
    
    ax.grid(**C_GRID_STYLE)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=5))
    ax.yaxis.set_major_locator(MaxNLocator(nbins=4))

    for spine in ax.spines.values():
        spine.set_color(SPINE_COLOR)
        spine.set_linewidth(C_SPINE_WIDTH)
    ax.tick_params(axis='both', which='major', labelsize=C_AXIS_LABEL_SIZE-1, width=C_SPINE_WIDTH)

    # Largest memory first, so bigger bubbles are drawn (and labelled) first.
    # Use a sorted copy: sorting in place would reorder the caller's list.
    models_data = sorted(models_data, key=lambda x: x['memory_gb'], reverse=True)
    for model in models_data:
        size = model['memory_gb'] * C_BUBBLE_SCALE
        c = color_map.get(model['name'], '#999999')
        
        # Adjust zorder: iTransformer lower (background), others higher
        zorder = 10
        if model['name'] == 'iTransformer':
            zorder = 1
            
        ax.scatter(model['training_time'], model['mse'], s=size, c=c, alpha=0.3, edgecolors=c, linewidth=0.6, zorder=zorder)

    # Label placement logic: all geometry below is in display pixels, while
    # annotation offsets are given in points.
    fig = ax.get_figure()
    fig.canvas.draw()
    axes_bbox = ax.get_window_extent()
    px_per_point = fig.dpi / 72.0
    placed_boxes = []

    def points_to_pixels(value):
        """Convert a length in typographic points to display pixels."""
        return value * px_per_point

    def compute_bbox(base_px, dx_pts, dy_pts, ha, va, width_px, height_px):
        """Pixel bbox (x0, y0, x1, y1) of a label anchored at base + (dx, dy) points."""
        anchor_x = base_px[0] + points_to_pixels(dx_pts)
        anchor_y = base_px[1] + points_to_pixels(dy_pts)
        if ha == 'left':
            x0, x1 = anchor_x, anchor_x + width_px
        elif ha == 'right':
            x0, x1 = anchor_x - width_px, anchor_x
        else:
            x0, x1 = anchor_x - width_px / 2, anchor_x + width_px / 2
        if va == 'bottom':
            y0, y1 = anchor_y, anchor_y + height_px
        elif va == 'top':
            y0, y1 = anchor_y - height_px, anchor_y
        else:
            y0, y1 = anchor_y - height_px / 2, anchor_y + height_px / 2
        return (x0, y0, x1, y1)

    def clamp_to_axes(base_px, dx_pts, dy_pts, ha, va, width_px, height_px):
        """Shift the offset (up to 10 passes) until the label bbox lies inside the axes.

        Returns the adjusted (dx_pts, dy_pts) and the resulting bbox. A label
        larger than the axes cannot be fully clamped.
        """
        for _ in range(10):
            bbox = compute_bbox(base_px, dx_pts, dy_pts, ha, va, width_px, height_px)
            adjusted = False
            if bbox[0] < axes_bbox.x0:
                dx_pts += (axes_bbox.x0 - bbox[0]) / px_per_point
                adjusted = True
            if bbox[2] > axes_bbox.x1:
                dx_pts -= (bbox[2] - axes_bbox.x1) / px_per_point
                adjusted = True
            if bbox[1] < axes_bbox.y0:
                dy_pts += (axes_bbox.y0 - bbox[1]) / px_per_point
                adjusted = True
            if bbox[3] > axes_bbox.y1:
                dy_pts -= (bbox[3] - axes_bbox.y1) / px_per_point
                adjusted = True
            if not adjusted:
                break
        return dx_pts, dy_pts, compute_bbox(base_px, dx_pts, dy_pts, ha, va, width_px, height_px)

    def bbox_overlap_area(b1, b2):
        """Area of the intersection of two (x0, y0, x1, y1) boxes; 0 if disjoint."""
        x0 = max(b1[0], b2[0])
        y0 = max(b1[1], b2[1])
        x1 = min(b1[2], b2[2])
        y1 = min(b1[3], b2[3])
        if x1 <= x0 or y1 <= y0:
            return 0
        return (x1 - x0) * (y1 - y0)

    def inside_axes(bbox):
        """True when the bbox lies entirely within the axes' pixel extent."""
        return (bbox[0] >= axes_bbox.x0 and bbox[2] <= axes_bbox.x1 and bbox[1] >= axes_bbox.y0 and bbox[3] <= axes_bbox.y1)

    def expand_bbox(bbox, padding_px):
        """Grow a bbox by ``padding_px`` on every side."""
        return (bbox[0] - padding_px, bbox[1] - padding_px, bbox[2] + padding_px, bbox[3] + padding_px)

    def build_candidate_sequence(x_rel, y_rel, offset_dist):
        """Ordered label offsets to try, biased away from the nearest axes edges.

        ``x_rel``/``y_rel`` are the bubble's relative position in the axes
        (0..1). Points near the left edge try right-side placements first,
        points near the right edge try left-side ones, and so on. The six
        generic directions are always appended as fallbacks.
        """
        directions = []
        def push(direction):
            if direction not in directions:
                directions.append(direction)
        if x_rel <= 0.35:
            push('right_up')
            push('right_down')
        if x_rel >= 0.65:
            push('left_up')
            push('left_down')
        if y_rel <= 0.25:
            push('right_up')
            push('left_up')
        if y_rel >= 0.75:
            push('right_down')
            push('left_down')
        for base_dir in ['right_up', 'left_up', 'right_down', 'left_down', 'up', 'down']:
            push(base_dir)
        diag = offset_dist
        vertical = offset_dist + 10
        mapping = {
            'right_up':  {'dx': diag,  'dy': diag,   'ha': 'left',  'va': 'bottom'},
            'right_down':{'dx': diag,  'dy': -diag,  'ha': 'left',  'va': 'top'},
            'left_up':   {'dx': -diag, 'dy': diag,   'ha': 'right', 'va': 'bottom'},
            'left_down': {'dx': -diag, 'dy': -diag,  'ha': 'right', 'va': 'top'},
            'up':        {'dx': 0,     'dy': vertical,    'ha': 'center','va': 'bottom'},
            'down':      {'dx': 0,     'dy': -vertical,   'ha': 'center','va': 'top'}
        }
        return [mapping[d] for d in directions]

    for model in models_data:
        c = color_map.get(model['name'], '#999999')
        size = model['memory_gb'] * C_BUBBLE_SCALE
        # Bubble radius in points, treating the scatter size as a circle area.
        # NOTE(review): matplotlib's `s` is in points**2 and for the 'o' marker
        # is roughly diameter**2, so this may overestimate the radius slightly.
        # Label offsets were tuned against it, so it is left unchanged.
        radius_pts = math.sqrt(size / math.pi)
        base_point = (model['training_time'], model['mse'])
        base_px = ax.transData.transform(base_point)
        mem_text = f"{model['memory_gb']:.3f}GB"
        # Approximate two-line label size (~0.55 em per character + 8 pt slack).
        block_width_pts = max(len(model['name']) * C_NAME_FONT * 0.55, len(mem_text) * C_MEM_FONT * 0.55) + 8
        block_height_pts = C_NAME_FONT + C_MEM_FONT + C_LINE_GAP_POINTS
        block_width_px = points_to_pixels(block_width_pts)
        block_height_px = points_to_pixels(block_height_pts)
        x_rel = (model['training_time'] - x_start) / (x_end - x_start)
        y_rel = (model['mse'] - (min_mse * 0.95)) / ((max_mse * 1.05) - (min_mse * 0.95) or 1)
        
        offset_dist = radius_pts + C_LABEL_PADDING
        # Hand-tuned: PatchTST's label sits almost on its bubble edge.
        if model['name'] in ['PatchTST']:
            offset_dist = radius_pts + 0.1

        # Greedy search: take the first candidate inside the axes with no overlap.
        # Otherwise keep the least-overlapping one; the idx term breaks ties in
        # favour of earlier (preferred) directions.
        candidates = build_candidate_sequence(x_rel, y_rel, offset_dist)
        best_choice, best_penalty = None, float('inf')
        for idx, cand in enumerate(candidates):
            dx_adj, dy_adj, bbox = clamp_to_axes(base_px, cand['dx'], cand['dy'], cand['ha'], cand['va'], block_width_px, block_height_px)
            overlap_area = sum(bbox_overlap_area(bbox, placed) for placed in placed_boxes)
            if not inside_axes(bbox):
                continue
            # Increased penalty for overlap
            penalty = overlap_area * 10 + idx * 1e-3
            if penalty < best_penalty:
                best_penalty = penalty
                best_choice = {'dx': dx_adj, 'dy': dy_adj, 'ha': cand['ha'], 'va': cand['va'], 'bbox': bbox}
            if overlap_area == 0:
                break
        # No candidate fit inside the axes: fall back to the first (clamped) one.
        if best_choice is None:
            cand = candidates[0]
            dx_adj, dy_adj, bbox = clamp_to_axes(base_px, cand['dx'], cand['dy'], cand['ha'], cand['va'], block_width_px, block_height_px)
            best_choice = {'dx': dx_adj, 'dy': dy_adj, 'ha': cand['ha'], 'va': cand['va'], 'bbox': bbox}
        
        # Hand-tuned nudge: move MODE's label 6 pt right on these datasets. The
        # recorded bbox is shifted too, but it is not re-clamped to the axes.
        if dataset_name in ['ECL', 'Weather'] and model['name'] == 'MODE':
            shift_pts = 6
            best_choice['dx'] += shift_pts
            shift_px = points_to_pixels(shift_pts)
            old_bbox = best_choice['bbox']
            best_choice['bbox'] = (old_bbox[0] + shift_px, old_bbox[1], old_bbox[2] + shift_px, old_bbox[3])

        # Reserve the label area (with 3 pt breathing room) for later labels.
        placed_boxes.append(expand_bbox(best_choice['bbox'], points_to_pixels(3)))
        # Stack the two text lines so the name is always above the memory line:
        # 'top'-aligned labels hang down from dy, 'bottom'-aligned ones grow up.
        if best_choice['va'] == 'top':
            name_offset = best_choice['dy']
            mem_offset = best_choice['dy'] - (C_NAME_FONT + C_LINE_GAP_POINTS)
        else:
            mem_offset = best_choice['dy']
            name_offset = best_choice['dy'] + C_NAME_FONT + C_LINE_GAP_POINTS
        ax.annotate(model['name'], base_point, xytext=(best_choice['dx'], name_offset), textcoords='offset points', ha=best_choice['ha'], va=best_choice['va'], fontsize=C_NAME_FONT, fontweight='bold', color=c)
        ax.annotate(mem_text, base_point, xytext=(best_choice['dx'], mem_offset), textcoords='offset points', ha=best_choice['ha'], va=best_choice['va'], fontsize=C_MEM_FONT, fontweight='normal', color='#606060')

# Prepare data: parse each target dataset and collect the union of model names.
dataset_data = {}
all_models = set()
for name, _ in target_datasets:
    data = load_and_filter_data(name)
    dataset_data[name] = data
    all_models.update(m['name'] for m in data)

if not all_models:
    print('No model data found; figure not created.')
else:
    def get_priority(name):
        """Sort key: MODEL_PRIORITY order first, then any other model alphabetically.

        The alphabetical tie-breaker matters because ``all_models`` is a set of
        strings, and its iteration order changes between interpreter runs
        (hash randomisation). With a constant key, unlisted models would get a
        different legend position and colour on every kernel restart.
        """
        if name in MODEL_PRIORITY:
            return (MODEL_PRIORITY.index(name), '')
        return (len(MODEL_PRIORITY), name)
    
    sorted_models = sorted(all_models, key=get_priority)
    
    # Assign colors: fixed colours for key models, then the unused entries of
    # CUSTOM_COLORS in priority order, cycling through the pool if it runs out.
    color_map = {}
    priority_colors = {
        'MODE': '#544AFF',      # Purple
        'S_Mamba': '#00AAEE',   # Blue
        'DLinear': '#3FB704',   # Green (reassigned from iTransformer)
        'iTransformer': '#999999', # Gray
        'Flowformer': '#1E7636', # Dark Green
    }
    
    # Only colours of models actually present count as used, so a missing
    # priority model frees its colour for another model.
    used_colors = set()
    for name in sorted_models:
        if name in priority_colors:
            color_map[name] = priority_colors[name]
            used_colors.add(priority_colors[name])
    
    available_colors = [c for c in CUSTOM_COLORS if c not in used_colors]
    extra_pool = CUSTOM_COLORS # Fallback (colours may repeat)
    
    for name in sorted_models:
        if name not in color_map:
            if available_colors:
                c = available_colors.pop(0)
            else:
                c = extra_pool[len(color_map) % len(extra_pool)]
            color_map[name] = c

    # Create figure: one panel per target dataset, at the composite width.
    # squeeze=False keeps the Axes container indexable even for a single panel;
    # `axes` stays a 1-D array, as before.
    fig_width_in = COMPOSITE_WIDTH_MM / 25.4
    fig_height_in = fig_width_in / 2.5 # More elongated (smaller height)
    fig, axes_grid = plt.subplots(1, len(target_datasets), figsize=(fig_width_in, fig_height_in), dpi=COMPOSITE_DPI, squeeze=False)
    axes = axes_grid[0]

    for ax, (name, pred_len) in zip(axes, target_datasets):
        draw_custom_subplot(ax, name, pred_len, dataset_data[name], color_map)

    legend_elements = [mpatches.Patch(facecolor=color_map[name], edgecolor=color_map[name], linewidth=1, label=name, alpha=0.3) for name in sorted_models]
    
    # Adjust layout: 
    # - wspace reduced to 0.15 (closer subplots)
    # - bottom reduced to 0.20 (closer legend)
    fig.subplots_adjust(left=0.08, right=0.98, top=0.80, bottom=0.20, wspace=0.15)
    
    # Legend in one row (ncol=len), closer to plots (bbox y adjusted)
    fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 0.02), ncol=len(legend_elements), fontsize=C_LEGEND_FONT, frameon=False)
    plt.show()