|
| 1 | +import sys |
| 2 | +import json |
| 3 | +from collections import defaultdict, deque |
| 4 | +from form_blocks import form_blocks |
| 5 | +from util import flatten |
| 6 | + |
def construct_cfg(blocks):
    """Build a control-flow graph over a list of basic blocks.

    Args:
        blocks: list of basic blocks, each a list of instruction dicts.
            A block whose first instruction has a 'label' key is named by
            that label; any other block is named '<bb{index}>'.

    Returns:
        A pair (cfg, labels): cfg maps each block name to the list of its
        successor names, and labels lists the block names in block order.
    """
    labels = []
    for idx, block in enumerate(blocks):
        # Guard against an empty block so naming never indexes past the end.
        first = block[0] if block else {}
        labels.append(first['label'] if 'label' in first else f'<bb{idx}>')

    cfg = {label: [] for label in labels}
    for idx, block in enumerate(blocks):
        last = block[-1] if block else {}
        op = last.get('op')
        if op == 'br':
            successors = list(last['labels'])
        elif op == 'jmp':
            successors = [last['labels'][0]]
        elif op == 'ret':
            # A return leaves the function: control never falls through
            # into the textually following block.
            successors = []
        elif idx + 1 < len(blocks):
            # Any other ending (or a label-only/empty block) falls through.
            successors = [labels[idx + 1]]
        else:
            successors = []
        cfg[labels[idx]].extend(successors)
    return cfg, labels
| 35 | + |
def alias_analysis(blocks, cfg, labels, func_args, alloc_sites):
    """Run a forward may-point-to analysis to a fixed point over the CFG.

    Args:
        blocks: list of basic blocks (lists of instruction dicts).
        cfg: mapping from block name to its list of successor names.
        labels: block names, in the same order as blocks.
        func_args: names of the function's arguments; each starts out
            pointing to the conservative location 'unknown'.
        alloc_sites: mapping from id(instr) of each alloc instruction to
            its allocation-site name (passed through to analyze_block).

    Returns:
        A pair (in_state, out_state), each mapping a block name to a dict
        from variable name to the set of locations it may point to.
    """
    block_map = dict(zip(labels, blocks))
    in_state = {label: {} for label in labels}
    out_state = {label: {} for label in labels}

    init_state = {arg: {'unknown'} for arg in func_args}

    # Compute predecessors once instead of rescanning the CFG per visit.
    preds = {label: [] for label in labels}
    for src, succs in cfg.items():
        for dst in succs:
            known = preds.setdefault(dst, [])
            if src not in known:
                known.append(src)

    entry = labels[0] if labels else None

    worklist = deque(labels)
    queued = set(labels)
    while worklist:
        label = worklist.popleft()
        queued.discard(label)

        # The function entry always receives the argument facts, even when
        # it is also a jump target (e.g. a loop header); blocks with no
        # predecessor at all are seeded the same way.
        if label == entry or not preds[label]:
            in_map = {var: set(locs) for var, locs in init_state.items()}
        else:
            in_map = {}
        # Join: union the points-to sets flowing out of every predecessor.
        for pred in preds[label]:
            for var, locs in out_state[pred].items():
                in_map.setdefault(var, set()).update(locs)

        in_state[label] = in_map

        old_out = out_state[label]
        out_map = analyze_block(block_map[label], in_map, alloc_sites)
        out_state[label] = out_map
        if out_map != old_out:
            # Facts changed: successors must be re-evaluated.
            for succ in cfg[label]:
                if succ not in queued:
                    queued.add(succ)
                    worklist.append(succ)
    return in_state, out_state
| 73 | + |
def analyze_block(block, in_map, alloc_sites):
    """Apply the points-to transfer function to one basic block.

    Args:
        block: list of instruction dicts.
        in_map: dict from variable name to the set of locations it may
            point to on entry to the block. It is not modified.
        alloc_sites: mapping from id(instr) of each alloc instruction to
            its allocation-site name.

    Returns:
        A new dict holding the points-to sets after the last instruction.
    """
    state = {var: set(locs) for var, locs in in_map.items()}
    for instr in block:
        op = instr.get('op')
        dest = instr.get('dest')
        # Every rule below writes to the destination; instructions without
        # one (labels, stores, effect-only calls, ...) leave the state as
        # is, rather than recording a bogus entry under the key None.
        if op is None or dest is None:
            continue
        args = instr.get('args', [])
        if op == 'alloc':
            # x = alloc n: x points to exactly this allocation site.
            state[dest] = {alloc_sites[id(instr)]}
        elif op in ('id', 'ptradd'):
            # x = id y / x = ptradd p i: x inherits the source's targets.
            # An untracked source leaves x's previous facts untouched.
            source = args[0]
            if source in state:
                state[dest] = set(state[source])
        elif op in ('load', 'call'):
            # A value read from memory or returned by a call may point
            # anywhere.
            state[dest] = {'unknown'}
    return state
| 103 | + |
def memory_liveness_analysis(blocks, cfg, labels, alias_info):
    """Iterate a backward liveness analysis over memory locations.

    Args:
        blocks: list of basic blocks (lists of instruction dicts).
        cfg: mapping from block name to its list of successor names.
        labels: block names, in the same order as blocks.
        alias_info: mapping from block name to the list of per-instruction
            alias maps handed to analyze_memory_uses.

    Returns:
        A pair (live_in, live_out) mapping each block name to the set of
        locations live at the block's start and end respectively.
    """
    by_name = dict(zip(labels, blocks))
    live_in = {name: set() for name in labels}
    live_out = {name: set() for name in labels}

    while True:
        stable = True
        # Visit blocks in reverse order, since facts flow backward.
        for name in reversed(labels):
            body = by_name[name]
            maps = alias_info[name]

            # Live-out is the union of the successors' live-in sets.
            leaving = set()
            for successor in cfg.get(name, []):
                leaving.update(live_in[successor])

            previous = live_in[name].copy()
            live_out[name] = leaving.copy()
            live_in[name] = analyze_memory_uses(body, live_out[name], maps)
            if live_in[name] != previous:
                stable = False
        if stable:
            break
    return live_in, live_out
| 127 | + |
def analyze_memory_uses(block, live_out_set, alias_maps):
    """Propagate a set of live memory locations backward through a block.

    Args:
        block: list of instruction dicts.
        live_out_set: locations live after the block; left unmodified.
        alias_maps: one alias map per instruction (same indexing as
            block), mapping a variable name to its points-to set.

    Returns:
        The set of locations live before the first instruction.
    """
    live = live_out_set.copy()
    for position, instr in reversed(list(enumerate(block))):
        points_to = alias_maps[position]
        if 'op' not in instr:
            continue
        opcode = instr['op']
        operands = instr.get('args', [])
        if opcode == 'ret':
            # Memory reachable through a returned pointer stays live.
            if operands:
                live.update(points_to.get(operands[0], set()))
        elif opcode == 'store' or opcode == 'load':
            # Both accesses mark the addressed locations as live.
            live.update(points_to.get(operands[0], set()))
        elif opcode == 'free':
            # Freed locations are no longer live above this point.
            live -= points_to.get(operands[0], set())
        elif opcode == 'call':
            # Anything reachable from a call argument may be used.
            for operand in operands:
                live.update(points_to.get(operand, set()))
    return live
| 159 | + |
def remove_dead_stores(func, alias_info, live_out):
    """Delete stores whose target locations are never used afterwards.

    The function's instructions are re-split into basic blocks, each block
    is scanned backward while tracking live memory locations, and
    func['instrs'] is rewritten in place without the dead stores.

    Args:
        func: function dict with an 'instrs' list; modified in place.
        alias_info: mapping from block name to a list of per-instruction
            alias maps (variable name -> points-to set), indexed like the
            block's instructions.
        live_out: mapping from block name to the set of memory locations
            live at the end of that block.
    """
    blocks = list(form_blocks(func['instrs']))
    cfg, labels = construct_cfg(blocks)
    block_map = dict(zip(labels, blocks))
    for label in labels:
        block = block_map[label]
        alias_maps = alias_info[label]
        live_vars = live_out[label].copy()
        kept = []
        for idx in reversed(range(len(block))):
            instr = block[idx]
            alias_map = alias_maps[idx]
            op = instr.get('op')
            args = instr.get('args', [])
            if op == 'store':
                pts = alias_map.get(args[0], set())
                # A store is dead only when its targets are precisely known
                # and none of them is live. An empty points-to set means
                # the analysis has no information about the pointer, so the
                # store must be kept rather than treated as trivially
                # disjoint from the live set.
                if (pts
                        and 'unknown' not in pts
                        and 'unknown' not in live_vars
                        and pts.isdisjoint(live_vars)):
                    continue
            elif op == 'load':
                live_vars.update(alias_map.get(args[0], set()))
            elif op == 'free':
                live_vars -= alias_map.get(args[0], set())
            elif op == 'ret':
                if args:
                    live_vars.update(alias_map.get(args[0], set()))
            elif op == 'call':
                for arg in args:
                    live_vars.update(alias_map.get(arg, set()))
            kept.append(instr)
        # Instructions were collected back to front; restore program order.
        block[:] = reversed(kept)
    func['instrs'] = flatten([block_map[label] for label in labels])
| 203 | + |
| 204 | + |
def collect_alloc_sites(func):
    """Name every alloc instruction of a function.

    Returns a dict keyed by id(instr) whose values are 'alloc_0',
    'alloc_1', ... numbered in the order the alloc instructions appear in
    func['instrs'].
    """
    allocations = (
        instr for instr in func['instrs']
        if 'op' in instr and instr['op'] == 'alloc'
    )
    return {
        id(instr): f'alloc_{number}'
        for number, instr in enumerate(allocations)
    }
| 213 | + |
def get_func_args(func):
    """Return the names of a function's arguments, in declaration order.

    A function dict without an 'args' entry yields an empty list.
    """
    return [param['name'] for param in func.get('args', [])]
| 219 | + |
def optimize_function(func):
    """Run dead-store elimination on a single function, in place.

    Steps: split the instructions into basic blocks, build the CFG, run
    the alias analysis, derive a points-to map for the program point just
    before every instruction, run the memory liveness analysis, and
    finally remove the dead stores from func['instrs'].

    Args:
        func: function dict with an 'instrs' list and optional 'args'.
    """
    blocks = list(form_blocks(func['instrs']))
    cfg, labels = construct_cfg(blocks)
    alloc_sites = collect_alloc_sites(func)
    func_args = get_func_args(func)

    in_alias, _ = alias_analysis(blocks, cfg, labels, func_args, alloc_sites)

    # Replay each block from its entry state, one instruction at a time,
    # recording the state seen *before* every instruction. The replay goes
    # through analyze_block so these maps can never drift away from the
    # transfer function the fixed-point computation itself used.
    alias_info = {}
    for label, block in zip(labels, blocks):
        state = dict(in_alias.get(label, {}))
        snapshots = []
        for instr in block:
            snapshots.append(state)
            # analyze_block returns a fresh dict, so earlier snapshots are
            # never mutated by later steps.
            state = analyze_block([instr], state, alloc_sites)
        alias_info[label] = snapshots

    _, live_out = memory_liveness_analysis(blocks, cfg, labels, alias_info)
    remove_dead_stores(func, alias_info, live_out)
| 259 | + |
def main():
    """Read a JSON program from stdin, optimize each function, print it.

    Every entry of the program's 'functions' list is optimized in place
    and the resulting program is written to stdout as indented JSON.
    """
    program = json.load(sys.stdin)
    functions = program['functions']
    for function in functions:
        optimize_function(function)
    json.dump(program, sys.stdout, indent=2)


if __name__ == '__main__':
    main()
0 commit comments