In [None]:

import openslide as osl
import sys
import os
import numpy as np
from PIL import Image
from glob import glob
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.utils import save_image
import random
import staintools
import torch.utils.data as data
from torchvision.utils import save_image
# Converter from a tensor / ndarray image to a PIL image (used for displaying results).
topilimage = transforms.ToPILImage()
# Prefer the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def create_dir(path):
    """Create directory `path`, including any missing parents.

    Safe to call repeatedly: an already-existing directory is left untouched.
    `exist_ok=True` avoids the check-then-create race of testing
    `os.path.exists` first. It raises FileExistsError if `path` exists but is a
    regular file, instead of silently continuing.
    """
    os.makedirs(path, exist_ok=True)

In [None]:
# Paths and parameters for extracting patches from one annotated slide.
root_dir = '../../data/pdl1_wsi_anno/'
# Sorted so that index `i` selects the same JSON on every machine (glob order is filesystem-dependent).
json_list = sorted(glob(os.path.join(root_dir, 'PD-L1_tumor_area', '*.json')))
save_wsi_path = '../../data/pdl1_annotation_patches/'
mpp = 1.0  # target microns-per-pixel for the extracted patches

image_size = 1000        # patch edge length in pixels at the target `mpp`
image_resize_size = 512  # resize target (px); not used by the patch-extraction loop
i = 0  # index into json_list: which annotated slide to process
if not json_list:
    raise FileNotFoundError(
        f"No annotation JSON files found in {os.path.join(root_dir, 'PD-L1_tumor_area')}"
    )
json_path = json_list[i]
# Load the slide's annotation, open the matching WSI, and build a mask at 1/20 of
# level-0 resolution that is 1 where the thumbnail's dark (presumably tissue) pixels
# overlap an annotated polygon.
with open(json_path) as f:
    anno = json.load(f)  # not named `data`: that would shadow the `torch.utils.data` import
slide_mpp = float(anno['mpp_X'])
wsi_name = anno['filename'].replace('.tiff', anno['origin_extension']).replace('WSI-STCA_PDL1', 'CODIPAI-STBX')
wsi_path = os.path.join(root_dir, 'PD-L1(22C3)', wsi_name)
slide_image = osl.open_slide(wsi_path)
wsi_w, wsi_h = slide_image.dimensions

# All masks live on a grid downsampled 20x from level 0.
thumb_w, thumb_h = wsi_w // 20, wsi_h // 20
# get_thumbnail() preserves the aspect ratio and may return an image a pixel
# smaller than requested; resize so its shape matches annotation_mask exactly
# (cv2.bitwise_and below requires identical shapes).
thumbnail = slide_image.get_thumbnail((thumb_w, thumb_h)).resize((thumb_w, thumb_h))

annotation_mask = np.zeros((thumb_h, thumb_w), dtype=np.uint8)
for region in anno['objects']:
    # cv2.fillPoly requires int32 vertices; a plain np.array of Python ints is int64.
    scaled_points = np.array(
        [(int(x // 20), int(y // 20)) for x, y in region['coordinate']], dtype=np.int32
    )
    cv2.fillPoly(annotation_mask, [scaled_points], 1)

# Otsu threshold on the blurred first (red) channel. THRESH_BINARY marks bright
# pixels as 1, so invert to keep the darker pixels.
blur = cv2.GaussianBlur(np.array(thumbnail), (5, 5), 0)
_, thumbnail_mask = cv2.threshold(blur[:, :, 0], 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
thumbnail_mask = 1 - thumbnail_mask
# Keep thumbnail_mask only inside the annotated polygons.
common_mask = cv2.bitwise_and(thumbnail_mask, thumbnail_mask, mask=annotation_mask)
# Tile level 0 into squares spanning `image_size` pixels at the target `mpp`, and
# save the tiles whose mean common_mask value exceeds `min_coverage`.
slide_patch_size = int(image_size * (mpp / slide_mpp))  # tile edge in level-0 pixels
num_row = wsi_h // slide_patch_size
num_col = wsi_w // slide_patch_size
min_coverage = 0.3       # minimum mean of common_mask over a tile for it to be saved
max_patches_per_col = 1  # keeps the original behavior (first hit per column); None = save all
patch_dir = os.path.join(save_wsi_path, 'PD-L1')
create_dir(patch_dir)  # hoisted: previously re-checked for every saved patch
slide_stem = os.path.splitext(wsi_name)[0]

for col in tqdm(range(num_col)):
    saved_in_col = 0
    for row in range(num_row):
        x1 = col * slide_patch_size
        y1 = row * slide_patch_size
        # Map the tile's level-0 bounds onto the 20x-downsampled mask. A fixed
        # slide_patch_size // 20 stride drifts when the tile size is not a multiple of 20.
        tile_mask = common_mask[y1 // 20:(y1 + slide_patch_size) // 20,
                                x1 // 20:(x1 + slide_patch_size) // 20]
        if tile_mask.size == 0 or tile_mask.mean() <= min_coverage:
            continue
        patch = slide_image.read_region((x1, y1), 0, (slide_patch_size, slide_patch_size)).convert('RGB')
        # Resample to image_size so the saved patch is at the target `mpp`
        # (halving the tile only achieved that when slide_mpp == 0.5).
        patch = patch.resize((image_size, image_size), Image.LANCZOS)
        patch.save(os.path.join(patch_dir, f"{slide_stem}_patch{slide_patch_size}_{x1}_{y1}.tiff"))
        saved_in_col += 1
        if max_patches_per_col is not None and saved_in_col >= max_patches_per_col:
            break

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit computing x + body(x), where body is
    conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm -> Dropout(0.5).

    Channel count and spatial size are preserved (3x3 kernels, padding 1), so the
    skip connection is a plain addition. The layer order inside `self.block`
    determines the state_dict keys (block.0.*, block.3.*) used by checkpoints.
    """

    def __init__(self, features):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(features, features, kernel_size=3, padding=1)

        self.block = nn.Sequential(
            conv3x3(),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(features),
            nn.Dropout(0.5),  # dropout with p=0.5; only active in train() mode
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

# Generator Model
class Generator(nn.Module):
    """ResNet-style encoder/decoder generator for image-to-image translation.

    Layout: 7x7 stem conv (64 ch) -> `n_sampling` stride-2 convs (channels x2 each)
    -> `n_residual_blocks` ResidualBlocks -> `n_sampling` stride-2 transposed convs
    (channels /2 each) -> 7x7 output conv -> Tanh, so outputs lie in [-1, 1].

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        n_residual_blocks: number of ResidualBlocks at the bottleneck.
        n_sampling: number of down/upsampling stages (default 4; an earlier
            version of this model used 2).

    Notes:
        Height and width should be divisible by 2**n_sampling; otherwise the output
        spatial size differs from the input's. The order of layers in `self.model`
        defines the state_dict keys, so with default arguments existing checkpoints
        load unchanged.
    """

    def __init__(self, input_channels, output_channels, n_residual_blocks=9, n_sampling=4):
        super(Generator, self).__init__()
        # Initial convolution block
        model = [
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: each stage halves H and W and doubles the channels.
        in_features = 64
        for _ in range(n_sampling):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks at the bottleneck
        model += [ResidualBlock(in_features) for _ in range(n_residual_blocks)]

        # Upsampling: each stage doubles H and W and halves the channels.
        for _ in range(n_sampling):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer; in_features is back to 64 after the symmetric down/up stages.
        model += [nn.Conv2d(in_features, output_channels, kernel_size=7, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# Discriminator Model
class Discriminator(nn.Module):
    """Convolutional discriminator producing a 1-channel spatial map of scores.

    Four stride-2 4x4 conv stages (64 -> 128 -> 256 -> 512 channels, each followed
    by LeakyReLU(0.2); all but the first also use InstanceNorm), then an
    asymmetric zero pad and a final 4x4 conv down to one channel. There is no
    pooling, so the output keeps a spatial extent (e.g. 4x4 for a 64x64 input).
    """

    def __init__(self, input_channels):
        super().__init__()
        # (in_channels, out_channels, use_instance_norm) for each stride-2 stage
        stages = [
            (input_channels, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]
        layers = []
        for c_in, c_out, use_norm in stages:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
    
# Generator F: 3-channel (RGB) input -> 3-channel output. Here it translates IHC patches to synthetic H&E.
# eval(): disables the Dropout inside each ResidualBlock, which would otherwise
# make every inference pass stochastic.
F = Generator(3, 3).to(device).eval()

model_dir = '../../model/HE_IHC_translation/membrane/'

# NOTE(review): the file holds a state_dict; on torch >= 1.13 consider passing
# weights_only=True so a tampered checkpoint cannot execute code while unpickling.
F.load_state_dict(torch.load(os.path.join(model_dir, 'F_39.pth'), map_location=device))

In [None]:
# Translate the last extracted patch with generator F and display the result.
# Pixels are scaled from [0, 255] to [-1, 1]; the Tanh output is mapped back to [0, 1] for display.
F.eval()  # Dropout in the residual blocks must be off for deterministic output
IHC_img = (
    torch.tensor(np.array(patch.resize((image_resize_size, image_resize_size))))
    .permute(2, 0, 1)  # HWC -> CHW
    .unsqueeze(0)      # add batch dimension
    .float()
    .to(device)
    / 255.0 * 2 - 1
)
with torch.no_grad():
    fake_HE = F(IHC_img)

topilimage(fake_HE.squeeze(0).cpu() * 0.5 + 0.5)