In [None]:
import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image
from scipy.spatial import Delaunay
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors
from torch.utils.data import DataLoader, Dataset
from torch_geometric.data import Data
from torchvision.models import ViT_B_16_Weights, vit_b_16
from tqdm import tqdm

# Prefer the GPU whenever CUDA is available; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# Integer label per class, numbered in this order. The keys are looked up by
# class sub-directory name under root_dir.
label_to_int = {
    name: idx for idx, name in enumerate(('Benign', 'InSitu', 'Invasive', 'Normal'))
}

# Input images and output locations for the generated graphs.
root_dir = '../data/Photos/'
patches = 'Patches'
graph_dir_delaunay = './graphs_delaunay/'
graph_dir_knn = './graphs_knn/'

# exist_ok makes re-running this cell harmless.
os.makedirs(graph_dir_delaunay, exist_ok=True)
os.makedirs(graph_dir_knn, exist_ok=True)


In [None]:
# ViT-B/16 backbone initialised with ImageNet-1k pretrained weights.
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Swap the 1000-class ImageNet head for a 4-class linear layer (one output per
# entry of label_to_int).
# NOTE(review): this layer is freshly (randomly) initialised and no fine-tuned
# checkpoint is loaded in this cell, so its predictions are effectively random
# until trained weights are loaded, e.g.
# model.load_state_dict(torch.load(CKPT_PATH, map_location=device)).
model.heads = torch.nn.Linear(model.heads.head.in_features, 4)

# features_dict['feat'] holds the detached input of the classification head from
# the most recent forward pass. In torchvision's ViT this is the class-token
# embedding, one row per image in the batch.
features_dict = {}


def _store_head_input(module, inputs):
    """Forward pre-hook on ``model.heads``: record the tensor fed to the head."""
    features_dict['feat'] = inputs[0].detach()


model.heads.register_forward_pre_hook(_store_head_input)

# Inference only. Assigning (instead of a bare expression) avoids printing the
# full model repr as cell output.
model = model.eval().to(device)


In [None]:
# Preprocessing applied to every 224x224 tile before it reaches the model.
# NOTE(review): mean/std 0.5 differ from the ImageNet statistics used to pretrain
# ViT_B_16_Weights.IMAGENET1K_V1 (mean 0.485/0.456/0.406, std 0.229/0.224/0.225).
# Keep 0.5 only if the fine-tuned head/checkpoint was trained with it.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5]*3, std=[0.5]*3)
])


def _normalize_stats(tfm):
    """Return (mean, std) of the first ``T.Normalize`` in ``tfm``; identity stats if none."""
    for step in getattr(tfm, "transforms", [tfm]):
        if isinstance(step, T.Normalize):
            return step.mean, step.std
    return (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)


def blur_except_subpatch(img_tensor, subpatch_idx, patch_size=56, ksize=(25, 25)):
    """Blur an image everywhere except one square sub-patch, then re-apply ``transform``.

    The image is divided into ``patch_size`` x ``patch_size`` cells indexed
    row-major (a 224 px image with ``patch_size=56`` gives a 4x4 grid). Cell
    ``subpatch_idx`` is kept sharp and all other pixels are Gaussian-blurred.

    Args:
        img_tensor: an image already processed by ``transform`` (i.e. normalised),
            shaped (1, 3, H, W) or (3, H, W).
        subpatch_idx: row-major index of the cell to keep sharp.
        patch_size: side of each cell in pixels.
        ksize: Gaussian kernel size for ``cv2.GaussianBlur`` (odd integers).

    Returns:
        Tensor of shape (1, 3, 224, 224), normalised by ``transform``, on ``device``.
    """
    img = img_tensor.detach().cpu()
    if img.dim() == 4:
        img = img.squeeze(0)

    # Undo Normalize before converting to 8-bit pixels. The normalised tensor lies
    # roughly in [-1, 1], and casting negative floats to uint8 corrupts the image.
    mean, std = _normalize_stats(transform)
    mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    img = (img * std + mean).clamp(0.0, 1.0)

    # HWC uint8, contiguous so OpenCV accepts it. GaussianBlur works per channel,
    # so no RGB/BGR conversion is needed.
    sharp = img.mul(255).round().to(torch.uint8).permute(1, 2, 0).contiguous().numpy()
    blurred = cv2.GaussianBlur(sharp, ksize, 0)

    cells_per_row = sharp.shape[1] // patch_size
    row, col = divmod(subpatch_idx, cells_per_row)
    r0, c0 = row * patch_size, col * patch_size
    blurred[r0:r0 + patch_size, c0:c0 + patch_size] = \
        sharp[r0:r0 + patch_size, c0:c0 + patch_size]

    return transform(Image.fromarray(blurred)).unsqueeze(0).to(device)


In [None]:
def _as_points(centroids):
    """Return centroids as a float (N, D) array; an empty input becomes shape (0, 2)."""
    pos = np.asarray(centroids, dtype=float)
    return pos.reshape(0, 2) if pos.size == 0 else pos


def _assemble_graph(node_feats, pos, edges):
    """Build a PyG ``Data`` object from node features, positions and undirected edges.

    Each undirected edge (u, v) is stored in both directions. Its weight is the
    mean of the inverse centroid distance and the cosine similarity of the two
    nodes' feature vectors.
    NOTE(review): 1/dist is in 1/pixels (small for centroids tens of pixels apart)
    while cosine similarity lies in [-1, 1], so the cosine term dominates. Consider
    normalising the distance term if both cues should matter.
    """
    x = torch.as_tensor(np.asarray(node_feats, dtype=np.float32))
    edge_index, edge_attr = [], []
    if edges:
        cos_sim = cosine_similarity(x.numpy())
        # Sorted so edge order (and therefore the saved tensors) is deterministic.
        for u, v in sorted(edges):
            dist = np.linalg.norm(pos[u] - pos[v])
            sim = float((1.0 / (dist + 1e-6) + cos_sim[u, v]) / 2)
            edge_index += [[u, v], [v, u]]
            edge_attr += [sim, sim]
    return Data(
        x=x,
        # PyG expects a contiguous int64 (2, E) tensor, including (2, 0) with no edges.
        edge_index=torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous(),
        edge_attr=torch.tensor(edge_attr, dtype=torch.float),
        pos=torch.tensor(pos, dtype=torch.float),
    )


def _delaunay_edges(pos):
    """Undirected Delaunay edges of 2-D points as a set of (u, v) with u < v.

    Qhull cannot triangulate fewer than 3 points or collinear points (common
    when centroids come from a regular grid). In that degenerate case the
    Delaunay graph is the path joining the points in order along their common
    line, and that path is returned instead.
    """
    n = len(pos)
    if n < 2:
        return set()
    centred = pos - pos.mean(axis=0)
    if n < 3 or np.linalg.matrix_rank(centred) < 2:
        axis = np.linalg.svd(centred, full_matrices=False)[2][0]  # principal direction
        order = np.argsort(centred @ axis, kind="stable")
        return {tuple(sorted((int(a), int(b)))) for a, b in zip(order[:-1], order[1:])}
    edges = set()
    for simplex in Delaunay(pos).simplices:
        for a, b in ((0, 1), (0, 2), (1, 2)):
            edges.add(tuple(sorted((int(simplex[a]), int(simplex[b])))))
    return edges


def _knn_edges(pos, k):
    """Undirected k-nearest-neighbour edges as (u, v) with u < v; mutual pairs appear once."""
    n = len(pos)
    k = min(k, n - 1)  # kneighbors requires n_neighbors <= number of points
    if k < 1:
        return set()
    _, indices = NearestNeighbors(n_neighbors=k + 1).fit(pos).kneighbors(pos)
    edges = set()
    for i, row in enumerate(indices):
        # A point is normally its own first neighbour, but with duplicate
        # centroids that is not guaranteed, so drop i explicitly.
        neighbours = [int(nb) for nb in row if nb != i][:k]
        edges.update((min(i, j), max(i, j)) for j in neighbours)
    return edges


def build_graph_delaunay(node_feats, centroids):
    """Graph whose edges are the Delaunay triangulation of the node centroids.

    Args:
        node_feats: sequence of equal-length 1-D feature vectors, one per node.
        centroids: sequence of (x, y) node positions, aligned with ``node_feats``.

    Returns:
        ``torch_geometric.data.Data`` with ``x``, symmetric ``edge_index`` (int64),
        per-edge weight ``edge_attr`` (float32) and ``pos``.
    """
    pos = _as_points(centroids)
    return _assemble_graph(node_feats, pos, _delaunay_edges(pos))


def build_graph_knn(node_feats, centroids, k=4):
    """Graph connecting each node to its ``k`` nearest centroids (symmetrised, deduplicated).

    ``k`` is capped at ``len(centroids) - 1``. Arguments and return value are as
    for ``build_graph_delaunay``.
    """
    pos = _as_points(centroids)
    return _assemble_graph(node_feats, pos, _knn_edges(pos, k))


In [None]:
class HistopathologyPatchDataset(Dataset):
    """Index of ``.tif`` images laid out as ``root_dir/<class_name>/*.tif``.

    Each item is ``(image_path, integer_label)``. Images are not loaded here, so
    ``transform`` is only stored on ``self.transform`` for callers that want it.

    Args:
        root_dir: directory whose sub-directories are named after classes.
        transform: image transform, stored but not applied by this class.
        label_map: dict mapping class-directory name to integer label.
            Sub-directories whose name is not a key are skipped, because
            ``__getitem__`` would have no label for them.
    """

    def __init__(self, root_dir, transform, label_map):
        self.samples = []
        self.transform = transform
        self.label_map = label_map

        # Sorted because os.listdir/glob order is arbitrary; this makes sample
        # order (and any index-based split) reproducible.
        for class_name in sorted(os.listdir(root_dir)):
            class_path = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            if class_name not in label_map:
                print(f"Skipping {class_path}: '{class_name}' is not in label_map")
                continue
            self.samples.extend(
                (path, class_name)
                for path in sorted(glob(os.path.join(class_path, "*.tif")))
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_path, label)`` for sample ``idx``."""
        img_path, label_name = self.samples[idx]
        return img_path, self.label_map[label_name]


In [None]:
# Build per-image graphs. Each 224x224 tile of an image is split into a 4x4 grid
# of 56x56 sub-patches. Each sub-patch is shown to the ViT with the rest of the
# tile blurred, and sub-patches the model classifies correctly become graph nodes
# (head-input features, positioned at the sub-patch centre). One Delaunay graph
# and one k-NN graph are saved per image, plus a metadata CSV for each type.
#
# Uses names defined in earlier cells: root_dir, transform, label_to_int, model,
# device, graph_dir_delaunay, graph_dir_knn, HistopathologyPatchDataset,
# blur_except_subpatch, build_graph_delaunay, build_graph_knn.

PATCH_SIZE = 224                     # side of each tile cropped from the full image
SUBPATCH_SIZE = 56                   # side of the sub-patch left un-blurred
GRID = PATCH_SIZE // SUBPATCH_SIZE   # sub-patches per tile side (4)
MIN_NODES = 3                        # images with fewer correct sub-patches are skipped
KNN_K = 4

dataset = HistopathologyPatchDataset(root_dir, transform, label_to_int)
print(f"{len(dataset)} images found under {root_dir}")
loader = DataLoader(dataset, batch_size=1, shuffle=False)


def _correct_subpatch_nodes(image, label):
    """Classify every sub-patch of ``image`` and keep the correctly classified ones.

    Only full PATCH_SIZE tiles are used; right/bottom margins narrower than a tile
    are ignored.

    Returns:
        feats: list of 1-D numpy arrays, the input of ``model.heads`` per kept sub-patch.
        coords: list of (cx, cy) pixel centres of the kept sub-patches in ``image``.
        n_correct, n_total: number of correct / evaluated sub-patches.
    """
    captured = {}
    # Temporarily capture the embedding fed to the classification head.
    hook = model.heads.register_forward_pre_hook(
        lambda _module, inputs: captured.__setitem__("feat", inputs[0].detach())
    )
    feats, coords = [], []
    n_correct = n_total = 0
    width, height = image.size
    try:
        with torch.no_grad():
            for y in range(0, height - PATCH_SIZE + 1, PATCH_SIZE):
                for x in range(0, width - PATCH_SIZE + 1, PATCH_SIZE):
                    tile = image.crop((x, y, x + PATCH_SIZE, y + PATCH_SIZE))
                    tile_tensor = transform(tile).unsqueeze(0).to(device)
                    # All GRID*GRID blurred variants of the tile in one forward pass.
                    batch = torch.cat([
                        blur_except_subpatch(tile_tensor, i, patch_size=SUBPATCH_SIZE)
                        for i in range(GRID * GRID)
                    ])
                    preds = model(batch).argmax(dim=1).cpu()
                    tile_feats = captured["feat"].cpu().numpy()
                    n_total += len(preds)
                    for i in torch.nonzero(preds == label).flatten().tolist():
                        row, col = divmod(i, GRID)
                        feats.append(tile_feats[i])
                        coords.append((x + col * SUBPATCH_SIZE + SUBPATCH_SIZE // 2,
                                       y + row * SUBPATCH_SIZE + SUBPATCH_SIZE // 2))
                        n_correct += 1
    finally:
        hook.remove()
    return feats, coords, n_correct, n_total


csv_delaunay, csv_knn, failures = [], [], []

for img_path, label in tqdm(loader):
    img_path, label = img_path[0], label.item()   # undo DataLoader's batch of 1

    try:
        # The context manager closes the file handle once the pixels are decoded.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")

        feats, coords, n_correct, n_total = _correct_subpatch_nodes(image, label)
        if len(feats) < MIN_NODES:
            continue

        # NOTE(review): files sharing a basename in different class folders would
        # overwrite each other here.
        fname = os.path.splitext(os.path.basename(img_path))[0]
        path_d = os.path.join(graph_dir_delaunay, f"{fname}.pt")
        path_k = os.path.join(graph_dir_knn, f"{fname}.pt")
        torch.save(build_graph_delaunay(feats, coords), path_d)
        torch.save(build_graph_knn(feats, coords, k=KNN_K), path_k)

        accuracy = round(n_correct / n_total, 4) if n_total > 0 else 0.0
        csv_delaunay.append({'graph_path': path_d, 'label': label, 'accuracy': accuracy})
        csv_knn.append({'graph_path': path_k, 'label': label, 'accuracy': accuracy})
    except Exception as e:
        # Best effort: one bad image should not abort the run. Failures are
        # collected so they are summarised below rather than lost in the log.
        failures.append((img_path, repr(e)))
        print(f"❌ Error: {img_path} — {e}")

pd.DataFrame(csv_delaunay).to_csv("d_meta.csv", index=False)
pd.DataFrame(csv_knn).to_csv("k_meta.csv", index=False)
print("✅ Saved metadata to d_meta.csv and k_meta.csv")
print(f"{len(csv_delaunay)} images graphed, {len(failures)} failed")
