In [None]:
##### -*- coding: utf-8 -*-
import os
import sys
import glob
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from matplotlib.colors import Normalize
from scipy.interpolate import interp1d

# Make the repository root importable whether the notebook is launched from the
# repo root itself or from one level below (e.g. Data-preprocess): the first
# candidate directory that contains src/mt2d_inv is prepended to sys.path
# (only if it is not already there), and the search stops at that candidate.
_repo_candidates = [os.getcwd(), os.path.join(os.getcwd(), "..")]
for _root in _repo_candidates:
    if not os.path.exists(os.path.join(_root, "src", "mt2d_inv")):
        continue
    if _root not in sys.path:
        sys.path.insert(0, _root)
    break

# ================= Configuration =================
# Folder scanned (non-recursively, pattern *.edi) for input EDI files.
edi_dir = r'./1china'
# Measured frequency band in Hz; tensors outside [freq_min, freq_max] are not drawn.
freq_min = 1
freq_max = 10000

# --- Rotation parameters (degrees) ---
# Sign convention: negative = NW, positive = NE.
# The impedance tensors are rotated by MAG_DECLINATION + PROFILE_STRIKE, and the
# stations are projected onto the azimuth PROFILE_STRIKE + 90 (across strike).
MAG_DECLINATION = 16.0    # Magnetic declination. NOTE(review): the previous comment
                          # referred to the NW USA; confirm this value for ./1china.
PROFILE_STRIKE = -26.2    # Strike azimuth; the profile is taken perpendicular to it

# --- Plot tuning ---
ELLIPSE_SIZE = 1.5        # Scales the base ellipse size (0.15 * ELLIPSE_SIZE)
X_STRETCH = 1.0           # Additional horizontal scaling of every ellipse
Y_STRETCH = 4.0           # Additional vertical scaling of every ellipse
SKEW_THRESHOLD = 5        # Skew |beta| (deg) mapped to the ends of the colour scale
SKIP_STATION = 1          # Draw every Nth station (1 = all)
SKIP_FREQ = 1             # Draw every Nth frequency (1 = all)
# ===========================================

def parse_dms(dms_str):
    """Parse an EDI latitude/longitude string into signed decimal degrees.

    Accepts ``D:M:S``, ``D:M`` and plain decimal ``D`` forms (previously only
    ``D:M:S`` worked; decimal strings such as ``"45.5"`` silently became 0.0).
    Fields beyond the third are ignored, as before. A leading minus sign
    makes the result negative even when the degree field is ``-0``.

    Args:
        dms_str: Text after ``LAT=`` / ``LONG=``.

    Returns:
        The angle in decimal degrees, or 0.0 if the string cannot be parsed.
    """
    text = dms_str.strip()
    try:
        fields = [float(part) for part in text.split(':')[:3]]
    except ValueError:
        return 0.0
    # Missing minute/second fields count as zero.
    fields += [0.0] * (3 - len(fields))
    d, m, s = fields
    # "-0:30:00" has d == -0.0, so the textual sign must be checked as well.
    sign = -1 if (d < 0 or text.startswith('-')) else 1
    return sign * (abs(d) + m / 60.0 + s / 3600.0)

class CustomMT:
    """Lightweight container for one MT station read from an EDI file.

    Attributes:
        lat, lon: station coordinates as passed to the constructor.
        freqs: frequencies converted to a NumPy array.
        Z: impedance tensors; ``Z[i]`` is the 2x2 tensor for ``freqs[i]``.
    """

    def __init__(self, lat, lon, freqs, z_array):
        self.lat = lat
        self.lon = lon
        self.freqs = np.array(freqs)
        self.Z = z_array

    def rotate(self, angle):
        """Rotate Z tensor clockwise by ``angle`` degrees, in place.

        Every tensor becomes ``R @ Z @ R.T`` with ``R = [[c, s], [-s, c]]``;
        the result is stored as a complex array of the same shape as ``Z``.
        """
        theta = np.deg2rad(angle)
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rot = np.array([[cos_t, sin_t], [-sin_t, cos_t]])
        n_freq = len(self.freqs)
        rotated = np.zeros_like(self.Z, dtype=complex)
        if n_freq:
            # matmul broadcasts the 2x2 rotation over the leading (frequency) axis.
            rotated[:n_freq] = rot @ np.asarray(self.Z[:n_freq]) @ rot.T
        self.Z = rotated

def read_custom_edi(file_path):
    """Read one EDI file into a ``CustomMT``.

    Every ``>NAME ...`` line opens a section whose key is NAME with '.'
    removed and upper-cased; the numeric lines that follow are collected into
    a 1-D array for that key. ``LAT=`` and ``LONG=`` (or ``LON=``) lines are
    parsed with ``parse_dms``. The impedance is assembled from the
    ``Z{XX,XY,YX,YY}R`` / ``...I`` section pairs; a missing component stays
    zero, and a component shorter than the frequency list fills only the
    rows it has (previously this raised a broadcasting error).

    Args:
        file_path: Path to the EDI file (decoded as UTF-8, bad bytes ignored).

    Returns:
        A ``CustomMT``, or None when there is no non-empty FREQ/FREQUENCIES
        section.
    """
    data_map = {}
    lat, lon = 0.0, 0.0
    current_key, current_vals = None, []

    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('LAT='):
                lat = parse_dms(line.partition('=')[2])
                continue
            if line.startswith(('LONG=', 'LON=')):
                lon = parse_dms(line.partition('=')[2])
                continue
            if line.startswith('>'):
                # Close the previous section before starting the next one.
                if current_key:
                    data_map[current_key] = np.array(current_vals)
                tokens = line[1:].split('//')[0].split()
                # A bare '>' used to raise IndexError; now it just ends the section.
                current_key = tokens[0].replace('.', '').upper() if tokens else None
                current_vals = []
            elif current_key:
                try:
                    # All-or-nothing per line: non-numeric header lines
                    # (e.g. 'KEY=value') fail conversion and are skipped whole.
                    current_vals.extend([float(x) for x in line.split()])
                except ValueError:
                    pass
    if current_key:
        data_map[current_key] = np.array(current_vals)

    freqs = data_map.get('FREQ')
    if freqs is None:
        freqs = data_map.get('FREQUENCIES')
    if freqs is None or freqs.size == 0:
        return None

    n = len(freqs)
    Z = np.zeros((n, 2, 2), dtype=complex)
    for i, c1 in enumerate('XY'):
        for j, c2 in enumerate('XY'):
            re_key, im_key = f'Z{c1}{c2}R', f'Z{c1}{c2}I'
            if re_key in data_map and im_key in data_map:
                re_part = data_map[re_key][:n]
                im_part = data_map[im_key][:n]
                # Tolerate truncated sections instead of crashing on broadcast.
                count = min(len(re_part), len(im_part))
                Z[:count, i, j] = re_part[:count] + 1j * im_part[:count]

    return CustomMT(lat, lon, freqs, Z)
    
def project_stations(lats, lons, azimuth, km_per_deg=111.0):
    """Project stations onto a straight profile and return distances in km.

    Uses a flat-earth approximation around the first station: north/east
    offsets are ``km_per_deg`` per degree, with longitude scaled by
    cos(latitude of the first station). Distances are measured along
    ``azimuth`` (degrees clockwise from north) and shifted so the minimum is 0.

    Args:
        lats, lons: Station coordinates in decimal degrees (any sequence).
        azimuth: Profile direction in degrees.
        km_per_deg: Kilometres per degree of latitude (default 111.0).

    Returns:
        1-D float array of profile distances (empty for empty input).
    """
    lats = np.asarray(lats, dtype=float)
    lons = np.asarray(lons, dtype=float)
    if lats.size == 0:
        return np.zeros(0)
    lat0, lon0 = lats[0], lons[0]
    north = (lats - lat0) * km_per_deg
    east = (lons - lon0) * km_per_deg * np.cos(np.deg2rad(lat0))
    rad = np.deg2rad(azimuth)
    dist = east * np.sin(rad) + north * np.cos(rad)
    return dist - dist.min()

# --- 2. Phase tensor calculation ---
def calc_phase_tensor(Z):
    """Compute phase-tensor parameters for one 2x2 complex impedance tensor.

    The phase tensor is ``Phi = inv(Re Z) @ Im Z`` (Caldwell et al., 2004).

    Args:
        Z: 2x2 complex NumPy array.

    Returns:
        ``(phi_max, phi_min, alpha, beta)`` with alpha and beta in degrees,
        or None when Re(Z) is singular or Phi contains NaN/inf (e.g. an
        all-zero tensor for a component missing from the EDI file).

    Note:
        alpha is measured relative to the current X axis. It is not by itself
        the ellipse/strike orientation; the major axis lies at alpha - beta.
    """
    X = Z.real
    Y = Z.imag
    try:
        Phi = np.linalg.inv(X) @ Y
    except np.linalg.LinAlgError:
        return None
    if not np.all(np.isfinite(Phi)):
        return None

    xx, xy = Phi[0, 0], Phi[0, 1]
    yx, yy = Phi[1, 0], Phi[1, 1]

    tr = xx + yy
    sk = xy - yx

    # Skew angle beta.
    beta = 0.5 * np.degrees(np.arctan2(sk, tr))
    # alpha, relative to the current X axis.
    alpha = 0.5 * np.degrees(np.arctan2(xy + yx, xx - yy))

    # Principal values: phi_max/min = (pi2 +/- pi1) / 2.
    pi1 = np.sqrt((xx - yy) ** 2 + (xy + yx) ** 2)
    pi2 = np.sqrt(tr ** 2 + sk ** 2)
    phi_max = (pi2 + pi1) / 2.0
    phi_min = (pi2 - pi1) / 2.0

    return phi_max, phi_min, alpha, beta

# --- 3. Plot main ---
def plot_rotated_pt(output_path="Rotated_PT_Section-usa.png"):
    """Read all EDI files in ``edi_dir``, rotate them and plot a phase-tensor section.

    Impedances are rotated by ``MAG_DECLINATION + PROFILE_STRIKE`` degrees and
    stations are projected onto azimuth ``PROFILE_STRIKE + 90``. One ellipse is
    drawn per (station, frequency) inside ``[freq_min, freq_max]``. Ellipses
    are oriented along the phase-tensor major axis (alpha - beta) and coloured
    by the skew angle beta.

    Args:
        output_path: PNG file the figure is saved to (default is the
            historical file name).
    """
    files = sorted(glob.glob(os.path.join(edi_dir, '*.edi')))
    print(f"Reading {len(files)} files...")

    mts = [m for m in (read_custom_edi(f) for f in files) if m]
    if not mts:
        return
    lats = np.array([m.lat for m in mts])
    lons = np.array([m.lon for m in mts])

    # --- Core: rotate here ---
    print(f"Rotating data: declination {MAG_DECLINATION} deg + strike {PROFILE_STRIKE} deg ...")
    total_rotation = MAG_DECLINATION + PROFILE_STRIKE
    # After rotating by the strike, a 2-D response should have its major axis
    # near 0 or 90 deg.
    for m in mts:
        m.rotate(total_rotation)

    # Profile coordinates: the profile runs perpendicular to strike.
    dists = project_stations(lats, lons, PROFILE_STRIKE + 90)

    fig, ax = plt.subplots(figsize=(14, 8))

    # The same norm drives the ellipse colours and the colorbar.
    norm = Normalize(vmin=-SKEW_THRESHOLD, vmax=SKEW_THRESHOLD)
    cmap = plt.cm.nipy_spectral
    base_size = 0.15 * ELLIPSE_SIZE

    print("Generating ellipses...")
    plotted_log_periods = []
    for idx, mt in enumerate(mts):
        if idx % SKIP_STATION != 0:
            continue
        x_pos = dists[idx]

        for f_idx, freq in enumerate(mt.freqs):
            if f_idx % SKIP_FREQ != 0:
                continue
            if freq < freq_min or freq > freq_max:
                continue

            res = calc_phase_tensor(mt.Z[f_idx])
            if res is None:
                continue
            phmax, phmin, alpha, beta = res

            # Ellipse aspect ratio from the principal values mapped to 0-90 deg.
            pmx_ang = np.degrees(np.arctan(phmax))
            pmn_ang = np.degrees(np.arctan(phmin))
            if pmx_ang <= 0:
                continue
            ratio = pmn_ang / pmx_ang

            # The major axis of the phase-tensor ellipse lies at alpha - beta
            # (Caldwell et al., 2004); using alpha alone mis-orients skewed tensors.
            azimuth = alpha - beta
            # Sign convention kept from the original display code.
            # NOTE(review): the y axis is inverted and the axes are stretched
            # non-uniformly, so on-screen angles are only approximate.
            plot_angle = -azimuth

            log_period = np.log10(1.0 / freq)
            ax.add_patch(Ellipse(
                xy=(x_pos, log_period),
                width=base_size * X_STRETCH,
                height=base_size * ratio * Y_STRETCH,
                angle=plot_angle,
                facecolor=cmap(norm(beta)),
                edgecolor='k',
                linewidth=0.2,
            ))
            plotted_log_periods.append(log_period)

    ax.set_xlabel("Profile Distance (km)")
    ax.set_ylabel("Log10 Period (s)")

    # Y range: the periods actually drawn, falling back to all periods.
    if plotted_log_periods:
        ymin, ymax = min(plotted_log_periods), max(plotted_log_periods)
    else:
        all_periods = np.concatenate([1. / m.freqs for m in mts])
        ymin, ymax = np.log10(all_periods.min()), np.log10(all_periods.max())
    # bottom > top puts short periods (high frequencies) on top.
    ax.set_ylim(ymax + 0.5, ymin - 0.5)
    ax.set_xlim(dists.min() - 2, dists.max() + 2)

    # Colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, aspect=30)
    cbar.set_label("Skew Angle (deg)")

    plt.title(f"Rotated Phase Tensor (Rotation: {total_rotation:.1f} deg)\nGreen=2D, Red/Blue=3D | Ideal 2D: Ellipses are Horizontal/Vertical")
    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    print(f"Done. Output: {output_path}")
    plt.show()

if __name__ == "__main__":
    plot_rotated_pt()

In [None]:
import os

# Print working directory
print("Working dir:", os.getcwd())

# Print the absolute path of the PNG the plotting cell saves by default.
# The previous "-china" name pointed at a file the plotting code never writes.
OUTPUT_IMAGE = "Rotated_PT_Section-usa.png"
_output_abs = os.path.abspath(OUTPUT_IMAGE)
print("Output image:", _output_abs)
if not os.path.exists(_output_abs):
    print("Warning: output image not found (has the plotting cell been run?)")