In [2]:
import os
from torchvision.transforms.functional import to_tensor
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
import numpy as np
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import networkx as nx
from PIL import Image



In [9]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
import os
import time
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

def extract_patches_from_wsi_folder(folder_path, patch_size=224, stride=224, visualize=False, save_path="patches", break_interval=10, break_time=5):
    """Tile every PNG found under ``folder_path`` into square patches saved as PNG files.

    The function walks ``folder_path`` recursively. For each ``.png`` file (extension
    matched case-insensitively), it crops ``patch_size`` x ``patch_size`` tiles whose
    top-left corners lie on a ``stride`` grid. Each tile is written to ``save_path`` as
    ``<image stem>_<x>_<y>.png``. Edge strips narrower than ``patch_size`` are dropped,
    so an image smaller than one patch yields no tiles.

    Args:
        folder_path: Root directory searched for ``.png`` images.
        patch_size: Side length of each square patch, in pixels.
        stride: Step between consecutive patch origins, in pixels.
        visualize: Accepted for backward compatibility; currently unused.
        save_path: Output directory, created if missing. The walk never descends into
            it, so it may safely live inside ``folder_path``. All patches are written flat
            into this one directory. Images that share a file name in different
            sub-folders therefore overwrite each other's patches.
        break_interval: Pause after every ``break_interval`` successfully tiled images.
            A falsy value (0/None) disables the pauses.
        break_time: Length of each pause, in seconds.

    Unreadable images (decompression bombs, corrupt or non-image files) are reported
    and skipped rather than aborting the whole run.
    """
    os.makedirs(save_path, exist_ok=True)
    # Resolved output dir, used to keep os.walk from re-tiling patches written during this run.
    save_dir = os.path.realpath(save_path)

    processed_images = 0

    for root, dirs, files in os.walk(folder_path):
        # Pruning `dirs` in place (topdown walk) stops os.walk from entering the output dir.
        dirs[:] = [d for d in dirs if os.path.realpath(os.path.join(root, d)) != save_dir]

        for file in files:
            if not file.lower().endswith(".png"):
                continue
            image_path = os.path.join(root, file)
            stem = os.path.splitext(file)[0]

            try:
                # The context manager releases the source file handle even if decoding fails.
                with Image.open(image_path) as source:
                    image = source.convert("RGB")
            except Image.DecompressionBombError:
                print(f"Skipping large image (possible decompression bomb): {image_path}")
                continue
            except OSError as exc:
                # PIL.UnidentifiedImageError and truncated-file errors are OSError subclasses.
                print(f"Skipping unreadable image: {image_path} ({exc})")
                continue

            width, height = image.size
            for y in range(0, height - patch_size + 1, stride):
                for x in range(0, width - patch_size + 1, stride):
                    patch = image.crop((x, y, x + patch_size, y + patch_size))
                    patch.save(os.path.join(save_path, f"{stem}_{x}_{y}.png"))

            processed_images += 1
            if break_interval and processed_images % break_interval == 0:
                print(f"Processed {processed_images} images. Taking a {break_time}-second break to prevent SSD overload...")
                time.sleep(break_time)


# extract_patches_from_wsi_folder("data/", visualize=True)


In [10]:
class FeatureExtractor(nn.Module):
    """ResNet-152 backbone whose classification head is replaced by a linear projection.

    The backbone is created with ``pretrained=True``. Its ``fc`` layer is then swapped for
    ``nn.Linear(fc.in_features, output_dim)``, so ``forward`` returns ``output_dim``-wide
    vectors instead of class scores.

    NOTE(review): the replacement ``fc`` layer is freshly (randomly) initialised here,
    and nothing in this class loads weights into it. No seed is set in the visible cells
    either, so the projection can differ between kernel restarts. Unless the GAT
    checkpoint was trained on features from this exact layer, its inputs will not match
    training. Confirm how the training pipeline produced its features.
    """

    def __init__(self, output_dim=512):
        super().__init__()
        self.model = models.resnet152(pretrained=True)
        # Read the backbone's feature width from its existing head instead of hard-coding it.
        self.model.fc = nn.Linear(self.model.fc.in_features, output_dim)

    def forward(self, x):
        """Return ``output_dim``-dimensional embeddings for the image batch ``x``."""
        # assumes x is a (batch, 3, H, W) float tensor normalised like `transform` — TODO confirm
        return self.model(x)

feature_extractor = FeatureExtractor().to(device).eval()

In [11]:
def load_trained_gat_model(model_path, input_dim=512, hidden_dim=64, output_dim=3, heads=4):
    """Build the GAT node classifier and load its trained ``state_dict`` from ``model_path``.

    The architecture has three GATConv layers:
      1. ``input_dim`` → ``hidden_dim`` with ``heads`` concatenated heads, then ReLU.
      2. ``hidden_dim * heads`` → ``hidden_dim`` with ``heads`` concatenated heads, then ReLU.
      3. ``hidden_dim * heads`` → ``output_dim`` with a single head.

    All hyper-parameters must match those used when the checkpoint was saved, otherwise
    ``load_state_dict`` raises.

    Args:
        model_path: Path to a file holding the model's ``state_dict``.
        input_dim: Node feature width.
        hidden_dim: Per-head hidden width.
        output_dim: Number of output logits per node.
        heads: Attention heads used in the first two layers.

    Returns:
        The classifier moved to ``device`` and switched to eval mode.
    """
    from torch_geometric.nn import GATConv

    class GATClassifier(nn.Module):
        def __init__(self, input_dim, hidden_dim, output_dim, heads=4):
            super().__init__()
            self.gat1 = GATConv(input_dim, hidden_dim, heads=heads, concat=True)
            self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=heads, concat=True)
            self.gat3 = GATConv(hidden_dim * heads, output_dim, heads=1, concat=False)

        def forward(self, data):
            # Follow the model's own weights rather than a global, so .to(other_device) still works.
            param_device = next(self.parameters()).device
            x = data.x.to(param_device)
            edge_index = data.edge_index.to(param_device)
            x = torch.relu(self.gat1(x, edge_index))
            x = torch.relu(self.gat2(x, edge_index))
            return self.gat3(x, edge_index)

    model = GATClassifier(input_dim, hidden_dim, output_dim, heads=heads).to(device)
    # map_location: a checkpoint saved on GPU must still load when `device` fell back to CPU.
    # weights_only: a state_dict is plain tensors, so refuse arbitrary pickled objects.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Restore the trained GAT classifier from the checkpoint in the working directory.
model = load_trained_gat_model(model_path="best_model.pth")

  model.load_state_dict(torch.load(model_path))


In [13]:
# Preprocessing for the pretrained ResNet backbone:
#   1. resize to a fixed 224x224;
#   2. convert the PIL image to a float tensor;
#   3. normalise each channel with the ImageNet mean/std constants.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [17]:
def classify_patch(image):
    """Predict the class index of a single image patch.

    The patch goes through three steps:
      1. It is preprocessed with ``transform``.
      2. It is embedded by ``feature_extractor``.
      3. The embedding is classified by the GAT ``model`` as a one-node graph whose only
         edge is a self-loop.

    Args:
        image: An image accepted by ``transform`` (a PIL image in this notebook).

    Returns:
        int: Index of the highest-scoring logit.
    """
    batch = transform(image).unsqueeze(0).to(device)
    # Single-node graph: node 0 connected only to itself.
    self_loop = torch.tensor([[0], [0]], dtype=torch.long, device=device)
    # Inference only: keep both the CNN and the GAT out of autograd, and keep the
    # features on `device` instead of round-tripping through the CPU.
    with torch.no_grad():
        feature = feature_extractor(batch)
        logits = model(Data(x=feature, edge_index=self_loop))
    return logits.argmax().item()

In [18]:
def classify_existing_patches(patch_folder, visualize=False, cancer_label=1):
    """Classify every PNG patch under ``patch_folder`` and report those predicted as cancer.

    Files are visited in sorted order (directories and file names), so the report is
    reproducible. Each ``.png`` (extension matched case-insensitively) is loaded as RGB
    and classified with ``classify_patch``. When the prediction equals ``cancer_label``,
    a line is printed. If ``visualize`` is set, the patch is also shown.

    Args:
        patch_folder: Root directory searched recursively for patch images.
        visualize: If True, display each detected patch with matplotlib.
        cancer_label: Class index treated as "cancer". The default of 1 is an
            assumption carried over from the original code — confirm against the
            label mapping used in training.
    """
    for root, dirs, files in os.walk(patch_folder):
        dirs.sort()  # in-place sort makes os.walk descend in a deterministic order
        for file in sorted(files):
            if not file.lower().endswith(".png"):
                continue
            patch_path = os.path.join(root, file)
            # The context manager closes the file handle once the RGB copy is decoded.
            with Image.open(patch_path) as source:
                patch = source.convert("RGB")
            pred_label = classify_patch(patch)

            if pred_label != cancer_label:
                continue
            print(f"Cancer detected in patch: {patch_path} (Label: {pred_label})")
            if visualize:
                fig, ax = plt.subplots()
                ax.imshow(patch)
                ax.set_title(f"Cancer Patch: {file} (Label: {pred_label})")
                ax.axis("off")
                plt.show()
                # Release the figure so long runs don't accumulate open figures.
                plt.close(fig)

In [None]:
# Scan the patch directory produced by the tiling step and report detections.
classify_existing_patches(patch_folder="patches/")

Cancer detected in patch: patches/TCGA-44-3917-11A-01-BS1.f7097e3f-4b6d-48e5-b06e-c0cd4e9fcc56_10080_10752.png (Label: 1)
Cancer detected in patch: patches/TCGA-44-3917-11A-01-BS1.f7097e3f-4b6d-48e5-b06e-c0cd4e9fcc56_10080_12096.png (Label: 1)
Cancer detected in patch: patches/TCGA-44-3917-11A-01-BS1.f7097e3f-4b6d-48e5-b06e-c0cd4e9fcc56_10080_12768.png (Label: 1)
Cancer detected in patch: patches/TCGA-44-3917-11A-01-BS1.f7097e3f-4b6d-48e5-b06e-c0cd4e9fcc56_10080_13216.png (Label: 1)
Cancer detected in patch: patches/TCGA-44-3917-11A-01-BS1.f7097e3f-4b6d-48e5-b06e-c0cd4e9fcc56_10080_13664.png (Label: 1)
Cancer detected in patch: patches/TCGA-44-3917-11A-01-BS1.f7097e3f-4b6d-48e5-b06e-c0cd4e9fcc56_10080_2912.png (Label: 1)
Cancer detected in patch: patches/TCGA-44-3917-11A-01-BS1.f7097e3f-4b6d-48e5-b06e-c0cd4e9fcc56_10080_3584.png (Label: 1)
Cancer detected in patch: patches/TCGA-44-3917-11A-01-BS1.f7097e3f-4b6d-48e5-b06e-c0cd4e9fcc56_10080_5152.png (Label: 1)
Cancer detected in patch: p