In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from skimage.morphology import remove_small_objects
import matplotlib.pyplot as plt
import matplotlib
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score
)
matplotlib.use('Agg')  # headless-safe

# --------------------- CTF LOADING --------------------- 
def _find_ctf_header_row(path, required=('X', 'Y', 'Euler1')):
    """Return the 0-based index of the first line containing every column name in ``required``.

    Lets ``load_ctf_file`` skip a free-text preamble that precedes the tabular
    section of a .ctf export. Returns 0 when no such line exists, which keeps the
    old behaviour of treating the first line as the header.
    """
    wanted = set(required)
    with open(path, 'r', errors='replace') as fh:
        for idx, line in enumerate(fh):
            if wanted.issubset(line.split()):
                return idx
    return 0


def load_ctf_file(path, mad_thresh, bands_thresh, min_region_size):
    """
    Load .ctf → (feature H×W×3, target H×W) using:
      target = (MAD < mad_thresh) & (Bands > bands_thresh),
      small regions < min_region_size removed.

    Lines before the column-header row (the first line naming X, Y and Euler1)
    are skipped; if no such line exists the first line is used as the header.

    Grid layout: rows follow the sorted unique Y values and columns the sorted
    unique X values of the indexed points (Bands > 0). Grid cells without an
    indexed point get Euler angles interpolated/filled from neighbours and a
    target of 0.

    Returns
    -------
    (feature, target)
        feature : (H, W, 3) array stacking [Euler1, Euler2, Euler3].
        target  : (H, W) uint8 array with values in {0, 1}.
    None
        If a required column is missing or no point has Bands > 0.

    Raises
    ------
    ValueError
        From ``DataFrame.pivot`` if an (X, Y) pair occurs more than once.
    """
    fname = os.path.basename(path)
    header_row = _find_ctf_header_row(path)
    df = pd.read_csv(path, sep=r"\s+|\t", comment='[', engine='python',
                     skiprows=header_row)
    expected = {'X','Y','Euler1','Euler2','Euler3','Bands','MAD'}
    missing = expected - set(df.columns)
    if missing:
        print(f"[{fname}] Missing columns: {missing}")
        return None

    # Keep only indexed points. Work on an explicit copy so the column added
    # below is never written into a view of the parsed frame.
    df = df[df['Bands'] > 0].copy()
    if df.empty:
        # Nothing to grid: report it the same way as the missing-column case
        # instead of printing a NaN percentage and processing empty arrays.
        print(f"[{fname}] No indexed points (Bands > 0); skipping")
        return None

    # Raw per-point binary target
    df['recryst'] = ((df['MAD'] < mad_thresh) &
                     (df['Bands'] > bands_thresh)).astype(np.uint8)

    frac = df['recryst'].mean() * 100
    print(f"[{fname}] {frac:.1f}% recrystallized "
          f"(MAD<{mad_thresh}, Bands>{bands_thresh})")

    # Rectangular grid spanned by the observed coordinates
    x_vals = np.sort(df['X'].unique())
    y_vals = np.sort(df['Y'].unique())

    # Pivot each Euler angle onto the grid, then fill holes left by the dropped
    # (unindexed) points: linear interpolation along X (axis=1), then along
    # Y (axis=0), then forward/backward fill for anything still missing.
    # NOTE(review): Euler angles are periodic, so linear interpolation across a
    # wrap-around (e.g. 0/360 if the angles are in degrees) gives an unphysical
    # intermediate value.
    eulers = {}
    for col in ('Euler1', 'Euler2', 'Euler3'):
        grid = df.pivot(index='Y', columns='X', values=col)
        grid = grid.reindex(index=y_vals, columns=x_vals)
        grid = grid.interpolate(axis=1, limit_direction='both')
        grid = grid.interpolate(axis=0, limit_direction='both')
        grid = grid.ffill(axis=1).bfill(axis=1)
        grid = grid.ffill(axis=0).bfill(axis=0)
        eulers[col] = grid.values

    # Pivot the mask (missing cells -> 0) and drop connected True regions
    # smaller than min_region_size pixels.
    mask_df = df.pivot(index='Y', columns='X', values='recryst')
    mask_df = mask_df.reindex(index=y_vals, columns=x_vals).fillna(0)
    mask_arr = mask_df.values.astype(bool)
    target = remove_small_objects(mask_arr, min_size=min_region_size).astype(np.uint8)

    feature = np.stack([eulers['Euler1'], eulers['Euler2'], eulers['Euler3']], axis=2)
    return feature, target

def load_all_ctf_data(folder, mad_thresh=0.5, bands_thresh=8, min_region_size=100):
    """Load every ``.ctf`` file (case-insensitive extension) in ``folder``.

    Each file goes through ``load_ctf_file`` with the given thresholds. Files
    for which it returns None are skipped.

    Files are visited in sorted filename order. ``os.listdir`` order is
    arbitrary, and callers split the returned list positionally into
    train/validation, so sorting makes that split reproducible.

    Returns
    -------
    list of (feature, target) tuples, in sorted filename order.
    """
    maps = []
    for fname in sorted(os.listdir(folder)):
        if not fname.lower().endswith('.ctf'):
            continue
        data = load_ctf_file(
            os.path.join(folder, fname),
            mad_thresh, bands_thresh, min_region_size
        )
        if data is not None:
            maps.append(data)
    print(f"Loaded {len(maps)} maps from '{folder}'")
    return maps

# ------------------ PATCH DATASET -----------------------
class PatchDataset(Dataset):
    """Random square crops drawn from full (feature, mask) maps.

    Parameters
    ----------
    maps : list of (feature, mask)
        feature is an (H, W, C) array and mask an (H, W) array.
    patch_size : int
        Side length of each crop. Maps smaller than this in either dimension
        are skipped.
    patches_per_map : int
        Number of crops sampled from each eligible map.
    seed : int or None
        If given, crop positions come from a private
        ``np.random.RandomState(seed)``, so the dataset is reproducible. If None
        (the default), the global NumPy RNG is used, as before.

    Crops are stored as NumPy slices (views) of the source arrays. Items are
    returned as (float32 tensor (C, p, p), int64 tensor (p, p)).
    """
    def __init__(self, maps, patch_size=64, patches_per_map=300, seed=None):
        rng = np.random if seed is None else np.random.RandomState(seed)
        self.patches = []
        for feat, mask in maps:
            H, W, _ = feat.shape
            if H < patch_size or W < patch_size:
                continue
            for _ in range(patches_per_map):
                # x before y, matching the original draw order
                x = rng.randint(0, W - patch_size + 1)
                y = rng.randint(0, H - patch_size + 1)
                self.patches.append((
                    feat[y:y+patch_size, x:x+patch_size],
                    mask[y:y+patch_size, x:x+patch_size]
                ))

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        f, m = self.patches[idx]
        # channels-last (H, W, C) -> channels-first (C, H, W) for Conv2d
        return (
            torch.tensor(f, dtype=torch.float32).permute(2,0,1),
            torch.tensor(m, dtype=torch.long)
        )

# ------------------- DOUBLE U-NET -----------------------
class UNetBlock(nn.Module):
    """Two stacked (3×3 Conv2d → BatchNorm2d → ReLU) stages.

    padding=1 keeps the spatial size unchanged. The first conv maps in_c → out_c
    channels and the second out_c → out_c. The layers live in ``self.block``
    (an ``nn.Sequential``), so parameter names are ``block.<i>.*``.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        layers = []
        for src_channels in (in_c, out_c):
            layers.extend([
                nn.Conv2d(src_channels, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            ])
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class UNet(nn.Module):
    """Three-level U-Net: 64 → 128 → 256 channels, with skip connections.

    Each level is a ``UNetBlock``. Downsampling uses 2×2 max-pooling and
    upsampling uses stride-2 transposed convolutions. Because the pools floor
    odd sizes, the output spatial size is 4*floor(H/4) × 4*floor(W/4). Skip
    features are cropped to the upsampled size before concatenation.
    Returns raw logits with ``out_c`` channels.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        # registration order kept stable (affects init RNG use and state_dict)
        self.enc1 = UNetBlock(in_c, 64)
        self.enc2 = UNetBlock(64, 128)
        self.enc3 = UNetBlock(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = UNetBlock(256, 128)
        self.dec1 = UNetBlock(128, 64)
        self.out = nn.Conv2d(64, out_c, 1)

    def _crop(self, enc, dec):
        """Centre-crop the 4-D tensor ``enc`` to ``dec``'s height and width."""
        _, _, enc_h, enc_w = enc.size()
        _, _, dec_h, dec_w = dec.size()
        top = (enc_h - dec_h) // 2
        left = (enc_w - dec_w) // 2
        return enc[:, :, top:top + dec_h, left:left + dec_w]

    def forward(self, x):
        skip_hi = self.enc1(x)
        skip_mid = self.enc2(self.pool(skip_hi))
        bottom = self.enc3(self.pool(skip_mid))

        upsampled = self.up2(bottom)
        merged = torch.cat([upsampled, self._crop(skip_mid, upsampled)], dim=1)
        upsampled = self.up1(self.dec2(merged))
        merged = torch.cat([upsampled, self._crop(skip_hi, upsampled)], dim=1)
        return self.out(self.dec1(merged))

class DoubleUNet(nn.Module):
    """Two cascaded U-Nets.

    The first network's logits are softmaxed over the channel axis. The
    resulting probability maps are concatenated with the input channels and
    fed to the second network. Only the second network's logits are returned.
    """
    def __init__(self, in_ch, out_c):
        super().__init__()
        self.u1 = UNet(in_ch, out_c)
        # stage 2 sees the raw input plus out_c probability maps
        self.u2 = UNet(in_ch + out_c, out_c)

    def forward(self, x):
        probs = torch.softmax(self.u1(x), dim=1)
        # Stage 1 may shrink H/W; keep the shared top-left region of both.
        rows = min(x.size(2), probs.size(2))
        cols = min(x.size(3), probs.size(3))
        stacked = torch.cat([x[:, :, :rows, :cols], probs[:, :, :rows, :cols]], dim=1)
        return self.u2(stacked)

# ------------------ METRICS & PLOTTING --------------------
def evaluate_model(model, loader, device=None):
    """Pixel-wise binary segmentation metrics of ``model`` over ``loader``.

    ``loader`` yields (inputs, targets) batches. Channel 1 of the softmaxed
    output is the positive-class probability used for ROC-AUC, and the argmax
    over channels is the hard prediction. Predictions and targets are cropped to
    their common top-left region before comparison.

    Parameters
    ----------
    device : torch.device or None
        If None, the device of the model's first parameter is used, falling back
        to CUDA-if-available for parameter-less models. This keeps the inputs on
        the same device as the weights.

    Returns
    -------
    dict with keys 'accuracy', 'precision', 'recall', 'f1', 'roc_auc'.
    'roc_auc' is NaN when the targets contain a single class (ROC-AUC is
    undefined there). Every metric is NaN when the loader yields no batches.

    Leaves the model in eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    all_p, all_t, all_pr = [], [], []
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            out = model(xb)
            # align logits (N, C, h, w) and targets (N, H, W) spatially
            h = min(out.size(2), yb.size(1))
            w = min(out.size(3), yb.size(2))
            out, yb = out[:, :, :h, :w], yb[:, :h, :w]
            prob = torch.softmax(out,1)[:,1,:,:].reshape(-1).cpu().numpy()
            p    = torch.argmax(out,1).reshape(-1).cpu().numpy()
            t    = yb.reshape(-1).cpu().numpy()
            all_pr.append(prob); all_p.append(p); all_t.append(t)

    keys = ('accuracy', 'precision', 'recall', 'f1', 'roc_auc')
    if not all_t:
        return {k: float('nan') for k in keys}

    y_prob = np.concatenate(all_pr)
    y_pred = np.concatenate(all_p)
    y_true = np.concatenate(all_t)

    # roc_auc_score raises ValueError when y_true holds only one class
    roc_auc = (roc_auc_score(y_true, y_prob)
               if np.unique(y_true).size > 1 else float('nan'))

    return {
        'accuracy' : accuracy_score(y_true,y_pred),
        'precision': precision_score(y_true,y_pred,average='binary',zero_division=0),
        'recall'   : recall_score(y_true,y_pred,average='binary',zero_division=0),
        'f1'       : f1_score(y_true,y_pred,average='binary',zero_division=0),
        'roc_auc'  : roc_auc
    }

def plot_history(hist, save_path='training_metrics.png'):
    """Save a 3×2 grid of per-epoch curves from a training-history dict.

    ``hist`` must hold equally long lists under 'loss', 'accuracy',
    'precision', 'recall', 'f1' and 'roc_auc'. Metric panels are fixed to
    [0, 1]; the loss panel starts at 0 with an automatic upper limit. The
    figure is written to ``save_path`` and then closed.
    """
    panels = ('loss', 'accuracy', 'precision', 'recall', 'f1', 'roc_auc')
    epoch_axis = range(1, len(hist['loss']) + 1)
    fig, axes = plt.subplots(3, 2, figsize=(12, 10))
    for ax, key in zip(axes.flat, panels):
        ax.plot(epoch_axis, hist[key], '-o')
        ax.set_title(key.replace('_', ' ').title())
        upper = None if key == 'loss' else 1
        ax.set_ylim(0, upper)
        ax.grid()
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

# -------------------- TRAINING --------------------------
def train_model(model, train_loader, val_loader, epochs=10, lr=1e-3):
    """Train ``model`` with Adam and cross-entropy, validating after every epoch.

    The model is moved to CUDA when available, otherwise to CPU. Each epoch
    records the mean per-batch training loss and the ``evaluate_model`` metrics
    on ``val_loader``, and prints them. After training, the history is plotted
    with ``plot_history`` at its default path.

    Returns
    -------
    dict mapping 'loss', 'accuracy', 'precision', 'recall', 'f1', 'roc_auc'
    to per-epoch lists.

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batches (e.g. no training maps, or every
        map smaller than the patch size).
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches; check the train/val "
                         "split and patch_size vs. map sizes")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    hist = {k:[] for k in ('loss','accuracy','precision','recall','f1','roc_auc')}

    for ep in range(1, epochs+1):
        # evaluate_model switches to eval mode, so re-enable training each epoch
        model.train(); running=0.0
        for xb,yb in train_loader:
            xb,yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            out = model(xb)
            loss = loss_fn(out,yb)
            loss.backward(); opt.step()
            running += loss.item()
        train_loss = running/len(train_loader)
        hist['loss'].append(train_loss)

        metrics = evaluate_model(model, val_loader, device)
        for k,v in metrics.items():
            hist[k].append(v)

        print(f"Epoch {ep}/{epochs} — loss {train_loss:.4f}, "
              f"acc {metrics['accuracy']:.4f}, prec {metrics['precision']:.4f}, "
              f"rec {metrics['recall']:.4f}, f1 {metrics['f1']:.4f}, auc {metrics['roc_auc']:.4f}")

    plot_history(hist)
    return hist

# -------------------- INFERENCE VIS ----------------------
def visualize_predictions(model, maps, output_dir='output_images', min_size=64):
    """Run ``model`` on each full feature map and save a colour-coded prediction.

    For the i-th (feature, target) pair in ``maps``, the (H, W, C) feature array
    is fed to the model as a single-item batch. The per-pixel argmax over
    classes is taken on the region where output and input overlap. Connected
    predicted-True regions smaller than ``min_size`` pixels are removed. The
    result is written to ``output_dir/pred_{i}.png`` with blue (0, 0, 255) for
    class 1 and red (255, 0, 0) for class 0. The target arrays are ignored.

    Uses CUDA when available, otherwise CPU, and leaves the model in eval mode
    on that device.
    """
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    with torch.no_grad():
        for i, (fm, _) in enumerate(maps):
            H, W, _ = fm.shape
            x = torch.tensor(fm, dtype=torch.float32).permute(2,0,1).unsqueeze(0).to(device)
            raw = model(x)
            h, w = min(H, raw.size(2)), min(W, raw.size(3))
            # Take batch element 0 explicitly: .squeeze() would also drop a
            # size-1 class or spatial axis and make argmax use the wrong dim.
            logits = raw[0, :, :h, :w]
            pred = torch.argmax(logits, dim=0).cpu().numpy()
            mask = remove_small_objects(pred.astype(bool), min_size=min_size).astype(np.uint8)
            color = np.zeros((h, w, 3), dtype=np.uint8)
            color[mask == 1] = [0, 0, 255]
            color[mask == 0] = [255, 0, 0]
            plt.imsave(os.path.join(output_dir, f'pred_{i}.png'), color)


# -------------------- MAIN ------------------------------
if __name__=='__main__':
    # Fixed seeds for a reproducible run: NumPy's global RNG picks PatchDataset
    # crop positions; torch's RNG drives weight init and DataLoader shuffling.
    SEED = 42
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    maps = load_all_ctf_data('CTF',
                             mad_thresh=0.5,
                             bands_thresh=8,
                             min_region_size=0)
    if len(maps) < 2:
        raise RuntimeError(
            f"Need at least 2 usable .ctf maps for a train/val split, got {len(maps)}")

    # Split by whole map (~80/20) so validation patches never come from
    # training maps; clamp so both sides keep at least one map.
    split = min(max(int(0.8 * len(maps)), 1), len(maps) - 1)
    train_maps, val_maps = maps[:split], maps[split:]

    train_loader = DataLoader(PatchDataset(train_maps), batch_size=16, shuffle=True)
    val_loader   = DataLoader(PatchDataset(val_maps, patches_per_map=100),
                              batch_size=16, shuffle=False)

    model   = DoubleUNet(3,2)
    history = train_model(model, train_loader, val_loader, epochs=10, lr=1e-3)

    print("Final Train Metrics:", evaluate_model(model, train_loader))
    print("Final  Val Metrics:", evaluate_model(model, val_loader))

    visualize_predictions(model, maps)


In [None]:
from PIL import Image
import numpy as np

# Load the prediction image
image_path = 'output_images/pred_0.png'
# convert('RGB') normalises RGBA / palette / greyscale PNGs to exactly three
# channels; the context manager closes the underlying file handle.
with Image.open(image_path) as image:
    image_array = np.array(image.convert('RGB'))

# Colours used by visualize_predictions: blue = recrystallized, red = deformed
recrystallized_color = [0, 0, 255]  # Blue
deformed_color = [255, 0, 0]        # Red

# Count the pixels for each category (alpha, if any, was dropped above)
total_pixels = image_array.shape[0] * image_array.shape[1]
recrystallized_pixels = np.sum(np.all(image_array == recrystallized_color, axis=-1))
deformed_pixels = np.sum(np.all(image_array == deformed_color, axis=-1))

# Calculate percentages
recrystallized_percentage = (recrystallized_pixels / total_pixels) * 100
deformed_percentage = (deformed_pixels / total_pixels) * 100

print(f"Recrystallized: {recrystallized_percentage:.2f}%")
print(f"Deformed: {deformed_percentage:.2f}%")