<div align=justify dir=rtl>
<h2>مدل پیشنهادی </h2>
<p>
در این نوتبوک کد مدل پیشنهادی قرار دارد. 
تفاوت این مدل با مدل قبلی در دو بخش است. اول این‌که ورودی آن متفاوت است: این مدل با داده‌هایی کار می‌کند که توسط مدل DINO فیلتر شده‌اند. البته قصد داشتیم این فرآیند را به‌صورت موازی با مدل دوم پیش ببریم، اما با محدودیت GPU مواجه شدیم.
تفاوت دوم در تکست انکودر (text encoder) است: در این مدل از تکست انکودر PathologyBERT استفاده می‌کنیم که بر روی بیش از ۳۴۰ هزار گزارش پاتولوژی آموزش دیده است. همچنین از توکنایزر مخصوص همین مدل استفاده می‌کنیم.
</p>

## Load libraries

In [1]:
import os
import random
import pandas as pd
from sklearn.model_selection import train_test_split
import pandas as pd
import os
import gc
import multiprocessing
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

import torch
from torch.cuda.amp import autocast
from torch import nn
from transformers import CLIPModel, CLIPConfig, CLIPVisionModel
from transformers import AutoModel, AutoTokenizer
from transformers import TrainingArguments, Trainer
from transformers import default_data_collator
from transformers import pipeline



## Hyperparameters

In [2]:
# Hugging Face id of the CLIP checkpoint whose image pre-processor is loaded below.
IMAGE_MODEL = 'openai/clip-vit-base-patch32'

# Batch size (per device), square image side in pixels, and caption length in tokens.
BATCH_SIZE = 64
IMAGE_SIZE = 224
MAX_LEN = 100

# Per-channel RGB normalisation statistics applied by the dataset transforms.
MEAN = torch.tensor((0.48145466, 0.4578275, 0.40821073))
STD = torch.tensor((0.26862954, 0.26130258, 0.27577711))

## Load PathologyBERT

In [3]:
# Fill-mask pipeline used only as a convenient loader for PathologyBERT's model and tokenizer.
language_model = pipeline(task='fill-mask', model='tsantos/PathologyBERT')

## Tokenizer and vision preprocessor

In [4]:
from transformers import TrainingArguments, AutoTokenizer, CLIPFeatureExtractor

# Text side: reuse the tokenizer bundled with the PathologyBERT pipeline.
tokenizer = language_model.tokenizer
# Image side: pre-processor loaded from the IMAGE_MODEL checkpoint.
vision_preprocessor = CLIPFeatureExtractor.from_pretrained(IMAGE_MODEL)



# Dataset & DataLoader

In [5]:
class CLIPDataset(Dataset):
    """Paired image/caption dataset that yields inputs for the CLIP wrapper model.

    Captions are tokenised once up front (padded/truncated to MAX_LEN). Images
    are opened lazily in ``__getitem__`` and transformed according to ``mode``.

    Args:
        image_paths: Paths of the image files, aligned with ``text``.
        text: Captions, one per image.
        mode: ``'train'`` for random crop/flip augmentation, or ``'test'`` for a
            deterministic resize + centre crop.

    Raises:
        ValueError: If ``mode`` is unknown or the two lists differ in length.
    """

    def __init__(self, image_paths: list, text: list, mode: str = 'train'):
        if len(image_paths) != len(text):
            raise ValueError(
                f"image_paths and text must have the same length "
                f"({len(image_paths)} != {len(text)})")
        self.image_paths = image_paths
        self.tokens = tokenizer(text, padding='max_length',
                                max_length=MAX_LEN, truncation=True)

        if mode == 'train':
            self.augment = transforms.Compose([
                transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=MEAN, std=STD),
            ])
        elif mode == 'test':
            # Evaluation must be deterministic and see the whole tile:
            # no random cropping, only resize + centre crop.
            self.augment = transforms.Compose([
                transforms.Resize(IMAGE_SIZE),
                transforms.CenterCrop(IMAGE_SIZE),
                transforms.ToTensor(),
                transforms.Normalize(mean=MEAN, std=STD),
            ])
        else:
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    def __getitem__(self, idx):
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(self.image_paths[idx]) as img:
            pixel_values = self.augment(img.convert('RGB'))
        # Index the BatchEncoding by key so both fast and slow tokenizers work.
        return {'input_ids': self.tokens['input_ids'][idx],
                'attention_mask': self.tokens['attention_mask'][idx],
                'pixel_values': pixel_values}

    def __len__(self):
        return len(self.image_paths)





## Load data and set Dataloaders

In [6]:
# Filtered image/caption table (columns used below: `images`, `captions`).
DATA_CSV = 'E:/NLP/filter_data.csv'
df = pd.read_csv(DATA_CSV)

# Fixed-seed 80/20 split so the held-out set is reproducible.
train, test = train_test_split(df, test_size=0.2, random_state=42)

train_ds = CLIPDataset(
    image_paths=train.images.tolist(), text=train.captions.tolist(), mode='train')
test_ds = CLIPDataset(
    image_paths=test.images.tolist(), text=test.captions.tolist(), mode='test')

## Utils

In [7]:
def clear_gpu():
    """Release unreferenced Python objects and cached (GPU) memory.

    Safe to call on machines without CUDA; the CUDA-specific steps are skipped.
    """
    # Collect garbage first so tensors held only by dead objects are freed
    # before the CUDA caching allocator is asked to release its cache.
    gc.collect()
    torch.clear_autocast_cache()
    if torch.cuda.is_available():
        # ipc_collect() initialises CUDA and raises when no CUDA device exists.
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()


def optimal_workers():
    """Return a DataLoader worker count.

    With GPUs present: four workers per GPU, capped at the CPU count.
    Without GPUs: one fewer than the number of CPUs.
    """
    gpu_count = torch.cuda.device_count()
    cpu_count = multiprocessing.cpu_count()
    if not gpu_count:
        return cpu_count - 1
    return min(cpu_count, 4 * gpu_count)

## Image Encoder

In [8]:
from transformers import CLIPVisionConfig, CLIPVisionModel

# Initialise the image tower from the pretrained CLIP checkpoint named by
# IMAGE_MODEL; its pre-processor is the one that prepares images in this
# notebook. The previous `CLIPVisionModel(CLIPVisionConfig())` produced the
# same architecture but with randomly initialised weights.
vision_encoder = CLIPVisionModel.from_pretrained(IMAGE_MODEL)
# Kept so any code reading `configuration` still gets the vision config.
configuration = vision_encoder.config

## Text Encoder

In [9]:
# Pull the bare BERT encoder out of the pipeline's masked-LM model.
masked_lm = language_model.model
modeltext = masked_lm.bert
class BertPooler(torch.nn.Module):
    """Pool a sequence of hidden states into one vector per example.

    Takes the hidden state at position 0 of each sequence, applies a square
    linear layer followed by tanh. ``config`` only needs ``hidden_size``.
    Attribute names ``dense``/``activation`` define the state_dict keys.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = torch.nn.Linear(width, width)
        self.activation = torch.nn.Tanh()

    def forward(self, hidden_states):
        first_token_state = hidden_states[:, 0]
        return self.activation(self.dense(first_token_state))

# Install the first-token pooler on the BERT encoder; the result is the text encoder.
modeltext.pooler = BertPooler(modeltext.config)
text_encoder = modeltext

## CLIPModel

In [10]:
def clip_wraper_creator():
    """Build an empty CLIPModel shell whose text/vision towers are replaced later.

    Both towers are configured with zero layers and size-1 dimensions so the
    shell is tiny. The projection heads are made identities, so the caller's
    encoders' pooled outputs are compared directly. The shell exists so the
    real encoders can be trained through CLIPModel with CLIPTrainer.
    """
    empty_tower = dict(
        num_hidden_layers=0,
        max_position_embeddings=0,
        vocab_size=0,
        hidden_size=1,
        patch_size=1,
    )
    shell_config = CLIPConfig(text_config_dict=empty_tower,
                              vision_config_dict=empty_tower)
    wrapper = CLIPModel(config=shell_config)
    wrapper.text_projection, wrapper.visual_projection = nn.Identity(), nn.Identity()
    return wrapper

In [11]:
# Plug the real encoders into the dummy CLIP shell, then move everything to the GPU.
clip = clip_wraper_creator()
clip.text_model = text_encoder
clip.vision_model = vision_encoder
clip = clip.cuda()

## Trainer

In [12]:
class CLIPTrainer(Trainer):
    """Trainer for the CLIP wrapper: the loss is the model's own contrastive loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model with ``return_loss=True`` and return its loss.

        Returns ``(loss, outputs)`` when ``return_outputs`` is true, as the
        Trainer API expects. Extra keyword arguments passed by newer Trainer
        releases (e.g. ``num_items_in_batch``) are accepted and ignored.
        """
        outputs = model(**inputs, return_loss=True)
        loss = outputs["loss"]
        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Compute the evaluation loss only; logits and labels are not returned.

        Mixed precision is used only when the run was configured with ``fp16``,
        so evaluation uses the same numeric precision as training.
        """
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad(), autocast(enabled=self.args.fp16):
            loss = self.compute_loss(model, inputs)
        return (loss, None, None)

## Set Args

In [13]:
# Evaluate, checkpoint and log every 5000 optimisation steps; report to TensorBoard.
_training_kwargs = dict(
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=5000,
    save_steps=5000,
    logging_steps=5000,
    learning_rate=3e-6,
    weight_decay=0.003,
    warmup_steps=100,
    fp16=False,
    prediction_loss_only=True,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=300,
    report_to='tensorboard',
)
args = TrainingArguments("clip-fa", **_training_kwargs)

## Start Training

In [14]:
# Size the DataLoader worker pool, then train on the train split and evaluate on the test split.
args.dataloader_num_workers = optimal_workers()
trainer = CLIPTrainer(
    clip,
    args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)
trainer.train()

***** Running training *****
  Num examples = 67068
  Num Epochs = 300
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 314400
  Number of trainable parameters = 183481345


  0%|          | 0/314400 [00:00<?, ?it/s]

## Test

## Load Model

In [None]:
# NOTE(review): training above writes checkpoints under "clip-fa", but this
# path points to "clip-fa_base_line" — confirm it is the intended checkpoint.
CHECKPOINT_PATH = 'E:/NLP/clip-fa_base_line/checkpoint-50000/pytorch_model.bin'
state_dict = torch.load(CHECKPOINT_PATH)
clip.load_state_dict(state_dict)
clip.eval()


In [None]:


def _encode_texts(texts, batch_size=BATCH_SIZE):
    """Return L2-normalised text embeddings, one row per entry of ``texts``.

    Captions are tokenised and encoded in chunks of ``batch_size``, so the
    caption set never has to pass through the text encoder in one batch.
    Truncation uses MAX_LEN, matching how captions were tokenised for training.
    """
    chunks = []
    for start in range(0, len(texts), batch_size):
        tokens = tokenizer(texts[start:start + batch_size], padding=True,
                           truncation=True, max_length=MAX_LEN)
        with torch.no_grad():
            feats = clip.text_model(torch.tensor(tokens.input_ids).cuda(),
                                    torch.tensor(tokens.attention_mask).cuda()).pooler_output.float()
        chunks.append(feats / feats.norm(dim=-1, keepdim=True))
    return torch.cat(chunks)


def _encode_image(path):
    """Return the L2-normalised embedding of the single image stored at ``path``."""
    with Image.open(path) as img:
        pixel_values = vision_preprocessor(img.convert("RGB"))['pixel_values'][0]
    image_input = torch.tensor(np.stack([pixel_values])).cuda()
    with torch.no_grad():
        feats = clip.vision_model(image_input).pooler_output.float()
    return feats / feats.norm(dim=-1, keepdim=True)


# Top-1 retrieval accuracy: each test image is ranked against every distinct
# test caption, and counts as correct when its own caption ranks first.
# The candidate set is the same for every image, so it is encoded only once.
test_captions = test.captions.tolist()
unique_captions = list(dict.fromkeys(test_captions))
caption_position = {caption: pos for pos, caption in enumerate(unique_captions)}
all_text_features = _encode_texts(unique_captions)

counter = 0
for image_path, caption in tqdm(zip(test.images.tolist(), test_captions), total=len(test)):
    similarity = _encode_image(image_path) @ all_text_features.T
    # Softmax is monotonic, so the top-ranked caption is the arg-max similarity.
    if similarity.argmax(dim=-1).item() == caption_position[caption]:
        counter += 1

print("Accuracy on Testset= ", counter / len(test))


## Plot Example

In [None]:
!pip install arabic-reshaper
!pip install python-bidi
import matplotlib.pyplot as plt
from bidi.algorithm import get_display
from arabic_reshaper import reshape

In [None]:
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

clip.eval()

# Open the example tile and convert it to RGB; the file handle is closed right away.
with Image.open('E:/NLP/content/normalized/TCGA-EW-A1J2-01Z-00-DX1/3072/15860_43508.png') as img:
    image = img.convert("RGB")

# Reuse the pre-processor already loaded for IMAGE_MODEL
# ('openai/clip-vit-base-patch32') instead of loading it a second time.
pixel_values = vision_preprocessor(image)['pixel_values'][0]
image_input = torch.tensor(np.stack([pixel_values])).cuda()

text_descriptions = [
    'Invasive lobular carcinoma in greatest linear dimension.'.lower(),
    'INVASIVE DUCTAL CARCINOMA, DUCTAL CARCINOMA IN SITU'.lower(),
    'LOBULAR CARCINOMA IN SITU. INFILTRATING LOBULAR CARCINOMA. LOBULAR CARCINOMA IN SITU'.lower(),
]
# Truncate to MAX_LEN so captions are encoded the same way as during training.
tokens = tokenizer(text_descriptions, padding=True, truncation=True, max_length=MAX_LEN)

with torch.no_grad():
    image_features = clip.vision_model(image_input).pooler_output.float()
    text_features = clip.text_model(torch.tensor(tokens.input_ids).cuda(),
                                    torch.tensor(tokens.attention_mask).cuda()).pooler_output.float()

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

# Softmax over cosine similarities, sorted from most to least likely caption.
text_probs = (image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(len(text_descriptions), dim=-1)

fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_size_inches(12, 8)

# Reshape/bidi-reorder labels so right-to-left text renders correctly.
persian_descriptions = [get_display(reshape(description)) for description in text_descriptions]
# Horizontal bars in ranked order; tick labels follow the same ranking.
ax1.barh(range(len(persian_descriptions)), top_probs[0])
ax1.set_yticks(range(len(persian_descriptions)))
ax1.set_yticklabels([persian_descriptions[i] for i in top_labels[0].numpy()])
ax1.grid()
ax2.imshow(image)
ax2.axis("off")
plt.show()

## Plot More Examples

In [None]:
# Plot up to MAX_EXAMPLES randomly drawn test tiles whose true caption the
# model ranks first against two distinct, randomly chosen distractor captions.
MAX_EXAMPLES = 10
candidate_captions = list(dict.fromkeys(c.lower() for c in test.captions.tolist()))
if len(candidate_captions) < 3:
    raise ValueError("need at least 3 distinct test captions to draw 2 distractors")

counter = 0  # number of examples plotted so far (only successes are counted)

for _ in range(len(test)):
    k = random.randint(0, len(test) - 1)
    true_caption = test.iloc[k]['captions'].lower()
    # Three distinct candidates contain at most one copy of the true caption,
    # so at least two genuine distractors always remain.
    distractors = [c for c in random.sample(candidate_captions, 3) if c != true_caption][:2]
    text_descriptions = [true_caption, *distractors]

    with Image.open(test.iloc[k]['images']) as img:
        image = img.convert("RGB")
    pixel_values = vision_preprocessor(image)['pixel_values'][0]
    image_input = torch.tensor(np.stack([pixel_values])).cuda()
    # Truncate to MAX_LEN so captions are encoded the same way as during training.
    tokens = tokenizer(text_descriptions, padding=True, truncation=True, max_length=MAX_LEN)

    with torch.no_grad():
        image_features = clip.vision_model(image_input).pooler_output.float()
        text_features = clip.text_model(torch.tensor(tokens.input_ids).cuda(),
                                        torch.tensor(tokens.attention_mask).cuda()).pooler_output.float()

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.cpu().topk(len(text_descriptions), dim=-1)
    # Only plot examples where the true caption (index 0) is ranked first.
    if top_labels[0][0].item() != 0:
        continue

    counter += 1
    fig, (ax1, ax2) = plt.subplots(1, 2)
    fig.set_size_inches(12, 8)
    persian_descriptions = [get_display(reshape(description)) for description in text_descriptions]

    ax1.barh(range(len(persian_descriptions)), top_probs[0])
    ax1.set_yticks(range(len(persian_descriptions)))
    ax1.set_yticklabels([persian_descriptions[i] for i in top_labels[0].numpy()])
    ax1.grid()
    ax2.imshow(image)
    ax2.axis("off")

    plt.show()
    plt.close(fig)

    if counter >= MAX_EXAMPLES:
        break
    


