In [2]:
import os
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
from concurrent.futures import ThreadPoolExecutor
from sklearn.model_selection import train_test_split
import shutil

In [36]:
num_tiles_per_patient = 595
files = os.listdir('plip_preprocess/')
train , test_val = train_test_split(files,test_size=0.4)
test,val = train_test_split(test_val,test_size=0.5)
for file in train[:]:
    fol_p = os.path.join('plip_preprocess',file)
    tiles = os.listdir(fol_p)
    selected_tiles = random.sample(tiles, min(num_tiles_per_patient, len(tiles)))
    for tile in selected_tiles:
        tile_p = os.path.join(fol_p,tile)
        new_p = os.path.join('Datasets/train_03/train',tile)
        shutil.copy(tile_p,new_p)

for file in test[:]:
    fol_p = os.path.join('plip_preprocess',file)
    tiles = os.listdir(fol_p)
    selected_tiles = random.sample(tiles, min(num_tiles_per_patient, len(tiles)))
    for tile in selected_tiles:
        tile_p = os.path.join(fol_p,tile)
        new_p = os.path.join('Datasets/train_03/test',tile)
        shutil.copy(tile_p,new_p)

for file in val[:]:
    fol_p = os.path.join('plip_preprocess',file)
    tiles = os.listdir(fol_p)
    selected_tiles = random.sample(tiles, min(num_tiles_per_patient, len(tiles)))
    for tile in selected_tiles:
        tile_p = os.path.join(fol_p,tile)
        new_p = os.path.join('Datasets/train_03/validate',tile)
        shutil.copy(tile_p,new_p)

        

In [11]:
# tiles


In [6]:
torch.load('/home/gp7/ml_pni/aug31/plip_preprocess/TCGA-CV-5966/TCGA-CV-5966_5_12.jpeg.pt').keys()

dict_keys(['tile_data', 'file_data'])

In [9]:
torch.load('/home/gp7/ml_pni/aug31/plip_preprocess/TCGA-CV-5966/TCGA-CV-5966_5_12.jpeg.pt')['tile_data'][0].shape

(3, 224, 224)

In [1]:
# torch.load('Datasets/train_03/train/TCGA-DQ-5625_2_10.jpeg.pt')['tile_data'][0]

In [39]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import random
import shutil

class FlatTileDataset(Dataset):
    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        self.files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, f))]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        data = torch.load(file_path)
        tile_data = torch.from_numpy(data['tile_data'][0])
        file_data = data['file_data']
        return tile_data, file_data


In [40]:
dataset = FlatTileDataset(data_dir='Datasets/train_03/train')
train_data_loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=32)

validation_dataset = FlatTileDataset(data_dir='Datasets/train_03/validate')
validation_data_loader = DataLoader(validation_dataset, batch_size=128, shuffle=False, num_workers=32)


test_dataset = FlatTileDataset(data_dir='Datasets/train_03/test')
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)



In [41]:
from transformers import CLIPVisionModel

class CustomPLIPModel(torch.nn.Module):
    def __init__(self, original_model):
        super(CustomPLIPModel, self).__init__()
        self.vision_model = original_model.vision_model
        self.vision_projection = torch.nn.Linear(768, 512)
        self.fc_layer = torch.nn.Linear(4, 512)  # Fully connected layer for the 4D vector

    def forward(self, pixel_values, score_vector):
        vision_output = self.vision_model(pixel_values)
        pooled_output = vision_output.pooler_output
        vision_features = self.vision_projection(pooled_output)
        score_features = self.fc_layer(score_vector)
        
        return vision_features, score_features
    
model = CLIPVisionModel.from_pretrained("../plip/")
custom_model = CustomPLIPModel(model)

In [43]:
from torch import optim
import torch

criterion = torch.nn.CosineSimilarity(dim=1)
optimizer = optim.Adam(custom_model.parameters(), lr=0.00001)
num_epochs = 1

for epoch in range(num_epochs):
    # Training Phase
    custom_model.train()  # Set the model to training mode
    train_loss = 0.0
    for batch_images, batch_scores in train_data_loader:
        optimizer.zero_grad()

        batch_loss = 0
        for img, score in zip(batch_images, batch_scores):
            vision_features, score_features = custom_model(img.unsqueeze(0), score.unsqueeze(0))  # Adjust dimensions if necessary
            cos_sim = criterion(score_features, vision_features)
            loss = -cos_sim.mean()
            
            batch_loss += loss.item()
            
            loss.backward()
        optimizer.step()
        train_loss += batch_loss
        print(f"Batch part loss {batch_loss:.4f}")
        
    avg_train_loss = train_loss / len(train_data_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")

    # Validation Phase
    custom_model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        validation_loss = 0.0
        for batch_images, batch_scores in validation_data_loader:
            batch_loss = 0
            for img, score in zip(batch_images, batch_scores):
                vision_features, score_features = custom_model(img.unsqueeze(0), score.unsqueeze(0))  # Adjust dimensions if necessary
                cos_sim = criterion(score_features, vision_features)
                loss = -cos_sim.mean()

                batch_loss += loss.item()
                
            validation_loss += batch_loss
            print(f"Validation Batch part loss {batch_loss:.4f}")
        
        avg_validation_loss = validation_loss / len(validation_data_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_validation_loss:.4f}")

# Save the model parameters
torch.save(custom_model.state_dict(), "kaal_model.pth")


Batch part loss -55.4341
Batch part loss -59.5735
Batch part loss -62.1469
Batch part loss -63.2711
Batch part loss -63.3499
Batch part loss -61.4691
Batch part loss -65.3415
Batch part loss -64.9783
Batch part loss -68.6788
Batch part loss -64.7548
Batch part loss -66.9452
Batch part loss -66.8607
Batch part loss -69.3504
Batch part loss -69.6781
Batch part loss -71.0662
Batch part loss -71.8540
Batch part loss -67.4748
Batch part loss -66.5620
Batch part loss -73.9682
Batch part loss -66.4608
Batch part loss -70.6438
Batch part loss -71.5281
Batch part loss -73.7284
Batch part loss -72.5080
Batch part loss -71.1007
Batch part loss -71.0023
Batch part loss -68.8267
Batch part loss -69.2588
Batch part loss -68.8446
Batch part loss -69.2159
Batch part loss -67.4212
Batch part loss -70.2703
Batch part loss -68.1071
Batch part loss -69.7616
Batch part loss -67.9216
Batch part loss -73.1598
Batch part loss -71.2504
Batch part loss -68.9539
Batch part loss -69.1610
Batch part loss -74.5227


In [53]:
import torch
from transformers import CLIPVisionModel

# Assuming CustomPLIPModel is defined as before
class CustomPLIPModel(torch.nn.Module):
    def __init__(self, original_model):
        super(CustomPLIPModel, self).__init__()
        self.vision_model = original_model.vision_model
        self.vision_projection = torch.nn.Linear(768, 512)
        self.fc_layer = torch.nn.Linear(4, 512)  # Fully connected layer for the 4D vector

    def forward(self, pixel_values, score_vector):
        vision_output = self.vision_model(pixel_values)
        pooled_output = vision_output.pooler_output
        vision_features = self.vision_projection(pooled_output)
        score_features = self.fc_layer(score_vector)
        
        return vision_features, score_features
    
# Reload the model
original_model = CLIPVisionModel.from_pretrained("../plip/")
custom_model = CustomPLIPModel(original_model)
custom_model.load_state_dict(torch.load("kaal_model.pth"))

# Ensure the model is in evaluation mode for feature extraction
custom_model.eval()



CustomPLIPModel(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (position_embedding): Embedding(50, 768)
    )
    (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
        

In [54]:
import os
import torch
from torch.utils.data import Dataset
from pathlib import Path

class PatientTileDataset(Dataset):
    def __init__(self, data_dir, model, save_dir):
        super().__init__()
        self.data_dir = data_dir
        self.model = model
        self.save_dir = Path(save_dir)
        self.files = []
        for patient_id in os.listdir(data_dir):
            patient_dir = os.path.join(data_dir, patient_id)
            if os.path.isdir(patient_dir):
                for f in os.listdir(patient_dir):
                    if f.endswith('.pt'):
                        self.files.append((os.path.join(patient_dir, f), patient_id))
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        file_path, patient_id = self.files[idx]
        data = torch.load(file_path)
        tile_data = torch.from_numpy(data['tile_data'][0]).unsqueeze(0)  # Add batch dimension
        # Assuming the model takes a batch of images; if not, you might need to adjust this.
        with torch.no_grad():
            vision_features, _ = self.model(pixel_values=tile_data, score_vector=torch.zeros(1, 4))
        feature_path = self.save_dir / patient_id / os.path.basename(file_path)
        feature_path.parent.mkdir(parents=True, exist_ok=True)
        # Save vision features
        torch.save(vision_features, feature_path)
        return feature_path

# Assuming you've instantiated your model somewhere as custom_model
data_dir = 'plip_preprocess/'
save_dir = 'kaal_extract/'

# Initialize your dataset
dataset = PatientTileDataset(data_dir=data_dir, model=custom_model, save_dir=save_dir)

# Example of processing and saving features
for _ in dataset:
    pass


In [66]:
# torch.load('kaal_extract/TCGA-BA-4074/TCGA-BA-4074_10_1.jpeg.pt')

In [144]:
df23 =  pd.read_csv('g2_g3.csv')
df23_2 = df23.set_index('HNSC')
there = set(list(x[:] for x in df23_2.index))
wsi_there = os.listdir('kaal_extract/')
use = list(there.intersection(wsi_there))
df23_2 = df23_2.loc[use]
df23_2['cluster'] = df23_2['cluster'] -2

df23_3  = df23_2.sample(frac=1)

class1 = list(df23_3[df23_3['cluster']==1].index)
class0 = list(df23_3[df23_3['cluster']==0].index)

from sklearn.model_selection import train_test_split
C1_X_train, C1_X_test = train_test_split(class1, test_size=0.3)
C0_X_train, C0_X_test = train_test_split(class0, test_size=0.3)

C1_X_validate, C1_X_test = train_test_split(C1_X_test, test_size=0.4)
C0_X_validate, C0_X_test = train_test_split(C0_X_test, test_size=0.4)


X_train = [];X_train.extend(C1_X_train);X_train.extend(C0_X_train);
X_test = [];X_test.extend(C1_X_test);X_test.extend(C0_X_test)
X_validate = [];X_validate.extend(C1_X_validate);X_validate.extend(C0_X_validate)

random.shuffle(X_train);
random.shuffle(X_test)
random.shuffle(X_validate);

data_info = {};
data_info['train'] = X_train;data_info['test'] = X_test;data_info['validate'] = X_validate




In [165]:
with open('Datasets/data_info0315.pkl','wb') as f:
    pickle.dump(data_info,f)

In [145]:
# list(df23_2.index)

In [147]:
print(" C1 - Train : {} , Validate : {} , Test : {} ".format(len(C1_X_train),len(C1_X_test),len(C1_X_validate)))
print(" C0 - Train : {} , Validate : {} , Test : {} ".format(len(C0_X_train),len(C0_X_test),len(C0_X_validate)))

 C1 - Train : 70 , Validate : 12 , Test : 18 
 C0 - Train : 59 , Validate : 11 , Test : 15 


In [148]:
data = {};
data['train'] = {};data['test'] = {};data['validate'] = {};
data['train']['X'] = [];data['train']['Y'] = []
data['test']['X'] = [];data['test']['Y'] = []
data['validate']['X'] = [];data['validate']['Y'] = []

for i,pID in enumerate(X_train[:]):
    fol_p = os.path.join('kaal_extract/',pID) 
    tiles = os.listdir(fol_p)
    tile_data = []
    for tile in tiles:
        tile_p = os.path.join(fol_p,tile)
        
        np1 = torch.load(tile_p).numpy()
        # print(np1[0].shape)
        tile_data.append(np1)
        
    data['train']['X'].extend(np.array(tile_data))
    data['train']['Y'].extend(list(df23_3.loc[pID] for each in range(len(tile_data)) ))
    # except:
    #     print('not there {}'.format(pID))

data['train']['X'] = np.array(data['train']['X']);
data['train']['Y'] = np.array(data['train']['Y'])
data['train']['X'] = np.squeeze(data['train']['X'], axis=1)


for i, pID in enumerate(X_validate[:]):
    fol_p = os.path.join('kaal_extract/', pID)
    tiles = os.listdir(fol_p)
    tile_data = []
    for tile in tiles:
        tile_p = os.path.join(fol_p, tile)

        np1 = torch.load(tile_p).numpy()
        tile_data.append(np1)

    data['validate']['X'].extend(np.array(tile_data))
    data['validate']['Y'].extend([df23_3.loc[pID] for each in range(len(tile_data))])

data['validate']['X'] = np.array(data['validate']['X'])
data['validate']['Y'] = np.array(data['validate']['Y'])
data['validate']['X'] = np.squeeze(data['validate']['X'], axis=1)


for i, pID in enumerate(X_test[:]):
    fol_p = os.path.join('kaal_extract/', pID)
    tiles = os.listdir(fol_p)
    tile_data = []
    for tile in tiles:
        tile_p = os.path.join(fol_p, tile)

        np1 = torch.load(tile_p).numpy()
        tile_data.append(np1)

    data['test']['X'].extend(np.array(tile_data))
    data['test']['Y'].extend([df23_3.loc[pID] for each in range(len(tile_data))])

data['test']['X'] = np.array(data['test']['X'])
data['test']['Y'] = np.array(data['test']['Y'])
data['test']['X'] = np.squeeze(data['test']['X'], axis=1)

In [150]:
import pickle

In [151]:
with open('Datasets/data_031524_1.pkl','wb') as f:
    pickle.dump(data,f)

In [159]:
wsi_data = {};
for pID in df23_3.index:
    fol_p = os.path.join('kaal_extract/',pID)
    tiles = os.listdir(fol_p) ;
    tile_data = []
    for tile in tiles:
        tile_p = os.path.join(fol_p,tile);
        tile_data.append(torch.load(tile_p).numpy())
        
    np1 = np.array(tile_data)
    wsi_data[pID] = {} ;
    wsi_data[pID]['tiles'] = np.squeeze(np1,axis=1)
    wsi_data[pID]['class'] = df23_3.loc[pID][0]
    

In [161]:
wsi_data['TCGA-UF-A71E']['tiles'].shape

(833, 512)

In [163]:
# wsi_data.keys()

In [164]:
with open('Datasets/wsi_data_g2_g3.pkl','wb') as f:
    pickle.dump(wsi_data,f)