In [12]:
import os
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from sklearn.model_selection import train_test_split
import shutil

In [13]:
from transformers import CLIPVisionModel

class PLIPModel_Vision(torch.nn.Module):
    def __init__(self, original_model):
        super(PLIPModel_Vision, self).__init__()
        self.vision_model = original_model.vision_model

    def forward(self, pixel_values):
        vision_output = self.vision_model(pixel_values)
        pooled_output = vision_output.pooler_output
        return pooled_output

model = CLIPVisionModel.from_pretrained("../plip/")
custom_model = PLIPModel_Vision(model)
custom_model.eval()

PLIPModel_Vision(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (position_embedding): Embedding(50, 768)
    )
    (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
       

In [14]:
import os
import torch
from torch.utils.data import Dataset
from pathlib import Path

class PatientTileDataset(Dataset):
    def __init__(self, data_dir, model, save_dir):
        super().__init__()
        self.data_dir = data_dir
        self.model = model
        self.save_dir = Path(save_dir)
        self.files = []
        for patient_id in os.listdir(data_dir):
            patient_dir = os.path.join(data_dir, patient_id)
            if os.path.isdir(patient_dir):
                for f in os.listdir(patient_dir):
                    if f.endswith('.pt'):
                        self.files.append((os.path.join(patient_dir, f), patient_id))
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        file_path, patient_id = self.files[idx]
        data = torch.load(file_path)
        tile_data = torch.from_numpy(data['tile_data'][0]).unsqueeze(0)  # Add batch dimension
        # Assuming the model takes a batch of images; if not, you might need to adjust this.
        with torch.no_grad():
            vision_features = self.model(pixel_values=tile_data)
        feature_path = self.save_dir / patient_id / os.path.basename(file_path)
        feature_path.parent.mkdir(parents=True, exist_ok=True)
        # Save vision features
        torch.save(vision_features, feature_path)
        return feature_path

# Assuming you've instantiated your model somewhere as custom_model
data_dir = 'plip_preprocess/'
save_dir = 'regular_plip_extracted/'

# Initialize your dataset
dataset = PatientTileDataset(data_dir=data_dir, model=custom_model, save_dir=save_dir)

# Example of processing and saving features
for _ in dataset:
    pass


In [15]:
torch.load('regular_plip_extracted/TCGA-CQ-5329/TCGA-CQ-5329_10_21.jpeg.pt').shape

torch.Size([1, 768])

In [66]:
# torch.load('kaal_extract/TCGA-BA-4074/TCGA-BA-4074_10_1.jpeg.pt')

In [144]:
df23 =  pd.read_csv('g2_g3.csv')
df23_2 = df23.set_index('HNSC')
there = set(list(x[:] for x in df23_2.index))
wsi_there = os.listdir('kaal_extract/')
use = list(there.intersection(wsi_there))
df23_2 = df23_2.loc[use]
df23_2['cluster'] = df23_2['cluster'] -2

df23_3  = df23_2.sample(frac=1)

class1 = list(df23_3[df23_3['cluster']==1].index)
class0 = list(df23_3[df23_3['cluster']==0].index)

from sklearn.model_selection import train_test_split
C1_X_train, C1_X_test = train_test_split(class1, test_size=0.3)
C0_X_train, C0_X_test = train_test_split(class0, test_size=0.3)

C1_X_validate, C1_X_test = train_test_split(C1_X_test, test_size=0.4)
C0_X_validate, C0_X_test = train_test_split(C0_X_test, test_size=0.4)


X_train = [];X_train.extend(C1_X_train);X_train.extend(C0_X_train);
X_test = [];X_test.extend(C1_X_test);X_test.extend(C0_X_test)
X_validate = [];X_validate.extend(C1_X_validate);X_validate.extend(C0_X_validate)

random.shuffle(X_train);
random.shuffle(X_test)
random.shuffle(X_validate);

data_info = {};
data_info['train'] = X_train;data_info['test'] = X_test;data_info['validate'] = X_validate




In [165]:
with open('Datasets/data_info0315.pkl','wb') as f:
    pickle.dump(data_info,f)

In [145]:
# list(df23_2.index)

In [147]:
print(" C1 - Train : {} , Validate : {} , Test : {} ".format(len(C1_X_train),len(C1_X_test),len(C1_X_validate)))
print(" C0 - Train : {} , Validate : {} , Test : {} ".format(len(C0_X_train),len(C0_X_test),len(C0_X_validate)))

 C1 - Train : 70 , Validate : 12 , Test : 18 
 C0 - Train : 59 , Validate : 11 , Test : 15 


In [148]:
data = {};
data['train'] = {};data['test'] = {};data['validate'] = {};
data['train']['X'] = [];data['train']['Y'] = []
data['test']['X'] = [];data['test']['Y'] = []
data['validate']['X'] = [];data['validate']['Y'] = []

for i,pID in enumerate(X_train[:]):
    fol_p = os.path.join('kaal_extract/',pID) 
    tiles = os.listdir(fol_p)
    tile_data = []
    for tile in tiles:
        tile_p = os.path.join(fol_p,tile)
        
        np1 = torch.load(tile_p).numpy()
        # print(np1[0].shape)
        tile_data.append(np1)
        
    data['train']['X'].extend(np.array(tile_data))
    data['train']['Y'].extend(list(df23_3.loc[pID] for each in range(len(tile_data)) ))
    # except:
    #     print('not there {}'.format(pID))

data['train']['X'] = np.array(data['train']['X']);
data['train']['Y'] = np.array(data['train']['Y'])
data['train']['X'] = np.squeeze(data['train']['X'], axis=1)


for i, pID in enumerate(X_validate[:]):
    fol_p = os.path.join('kaal_extract/', pID)
    tiles = os.listdir(fol_p)
    tile_data = []
    for tile in tiles:
        tile_p = os.path.join(fol_p, tile)

        np1 = torch.load(tile_p).numpy()
        tile_data.append(np1)

    data['validate']['X'].extend(np.array(tile_data))
    data['validate']['Y'].extend([df23_3.loc[pID] for each in range(len(tile_data))])

data['validate']['X'] = np.array(data['validate']['X'])
data['validate']['Y'] = np.array(data['validate']['Y'])
data['validate']['X'] = np.squeeze(data['validate']['X'], axis=1)


for i, pID in enumerate(X_test[:]):
    fol_p = os.path.join('kaal_extract/', pID)
    tiles = os.listdir(fol_p)
    tile_data = []
    for tile in tiles:
        tile_p = os.path.join(fol_p, tile)

        np1 = torch.load(tile_p).numpy()
        tile_data.append(np1)

    data['test']['X'].extend(np.array(tile_data))
    data['test']['Y'].extend([df23_3.loc[pID] for each in range(len(tile_data))])

data['test']['X'] = np.array(data['test']['X'])
data['test']['Y'] = np.array(data['test']['Y'])
data['test']['X'] = np.squeeze(data['test']['X'], axis=1)

In [150]:
import pickle

In [151]:
with open('Datasets/data_031524_1.pkl','wb') as f:
    pickle.dump(data,f)

In [159]:
wsi_data = {};
for pID in df23_3.index:
    fol_p = os.path.join('kaal_extract/',pID)
    tiles = os.listdir(fol_p) ;
    tile_data = []
    for tile in tiles:
        tile_p = os.path.join(fol_p,tile);
        tile_data.append(torch.load(tile_p).numpy())
        
    np1 = np.array(tile_data)
    wsi_data[pID] = {} ;
    wsi_data[pID]['tiles'] = np.squeeze(np1,axis=1)
    wsi_data[pID]['class'] = df23_3.loc[pID][0]
    

In [161]:
wsi_data['TCGA-UF-A71E']['tiles'].shape

(833, 512)

In [163]:
# wsi_data.keys()

In [164]:
with open('Datasets/wsi_data_g2_g3.pkl','wb') as f:
    pickle.dump(wsi_data,f)