# Load image data and text data

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from transformers import AutoImageProcessor, ViTModel

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import BertModel, BertTokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(13)

class_ind = {'CC':0, 'EC':1, 'LGSC':2, 'HGSC':3, 'MC':4}

class OvarianDataset(Dataset):
    def __init__(self, annotations_file, img_dir, text_dir):
        self.img_metadata = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.texts = {}
        for c in ['CC', 'EC', 'LGSC', 'HGSC', 'MC']:
            self.texts[c] = pd.read_table(text_dir+c+'.txt', header=None, sep='.').iloc[:,1].to_list()
       
    def __len__(self):
        return len(self.img_metadata)

    def __getitem__(self, idx):
        sample_idx = idx
        sample = self.img_metadata.iloc[sample_idx, 0]
        group = self.img_metadata.iloc[sample_idx, 1]

        label = torch.tensor(class_ind[group]).to(device)  

        image_patches = []
        for i in range(100):
            img_path = self.img_dir + f'sample_{sample}' + f'/{sample}_{i}.png'
            patch = torch.tensor(np.asarray(Image.open(img_path))[13:237, 13:237].T, dtype=torch.float32)
            image_patches.append(patch)
        image_patches = torch.stack(image_patches, dim=0).to(device)

        text = self.texts[group][idx % 100]
        # text = torch.tensor(self.tokenizer.encode(text, max_length=seq_max_length, padding="max_length")).to(device)

        return image_patches, text, label

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create Dataset
metadata = "/scratch1/yuqiuwan/CSCI567/train.csv"
image_dir = "/scratch1/yuqiuwan/CSCI567/preprocess_images_threshold/"
text_dir = "/scratch1/yuqiuwan/CSCI567/textLabel/"

wholedataset = OvarianDataset(metadata, image_dir, text_dir)
train_set, test_set = torch.utils.data.random_split(wholedataset, [0.8, 0.2])

# Build CLIP model

In [3]:
class CLIP(nn.Module):
    def __init__(self, ImageEncoder, TextEncoder, d_embed=768, n_classes=5):
        super().__init__()
        self.image_encoder = ImageEncoder
        self.text_encoder = TextEncoder
        self.image_proj = nn.Linear(d_embed, n_classes)
        self.text_proj = nn.Linear(d_embed, n_classes)

        self.n_classes = n_classes
        self.d_embed = d_embed

    def forward(self, img, text):
        img_outputs = self.image_encoder(img)
        img_features = img_outputs.last_hidden_state[:, 0, :].view(-1, 100, self.d_embed)
        img_features = torch.mean(img_features, 1)
        img_embed = self.image_proj(img_features)
        img_embed = img_embed / torch.norm(img_embed, dim=-1, keepdim=True)

        text_outputs = self.text_encoder(**text)
        text_embed = text_outputs.pooler_output
        text_embed = self.text_proj(text_embed)
        text_embed = text_embed / torch.norm(text_embed, dim=-1, keepdim=True)

        logits = img_embed @ text_embed.T
        return logits

# Set up training

In [4]:
seq_max_length = 50
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")

def contrastive_loss(logits, labels):
    image_loss = F.cross_entropy(logits, labels, reduction="mean")
    text_loss = F.cross_entropy(logits.transpose(0, 1), labels, reduction="mean")
    loss = (image_loss + text_loss) / 2

    return loss

def train_loop(dataloader, model, loss_fn, optimizer, image_processor=None):
    size = len(dataloader.dataset)
    model.train()
    best_loss = np.inf
    for batch, (imgs, texts, _) in enumerate(dataloader):
        labels = torch.tensor(range(imgs.shape[0])).to(device)
        imgs = imgs.view(-1, 3, 224, 224)
        texts = tokenizer(texts, padding='max_length', max_length=seq_max_length, return_tensors='pt').to(device)


        if image_processor:
            imgs = image_processor(imgs, return_tensors="pt").to(device)
            # Compute prediction and loss
            logits = model(imgs['pixel_values'], texts)
        else:
            # Compute prediction and loss
            logits = model(imgs, texts)
        loss = loss_fn(logits, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc = torch.sum(torch.argmax(logits, axis=1) == labels) / logits.shape[0]

        loss, current = loss.item(), (batch + 1) * batch_size
        if train_acc == 1:
            if loss < best_loss:
                torch.save(model.state_dict(), '/scratch1/yuqiuwan/CSCI567/bert_phikon_model_state_dict_current_best.pt')
            
      
        print(f"loss: {loss:>7f}; train_acc: {train_acc:>7f}  [{current:>5d}/{size:>5d}]")
            
######################## Train CLIP ########################
batch_size = 3
dropout = 0.0
learning_rate = 10**(-3)

train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

# load phikon
image_processor = AutoImageProcessor.from_pretrained("owkin/phikon")
img_encoder = ViTModel.from_pretrained("owkin/phikon", add_pooling_layer=False)
img_encoder.eval()

# load bert
text_encoder = BertModel.from_pretrained('bert-base-uncased')
text_encoder.eval()

clip_model = CLIP(img_encoder, text_encoder, d_embed=768, n_classes=5).to(device)
# clip_model.load_state_dict(torch.load('/scratch1/yuqiuwan/CSCI567/trainedModels/bert_phikon_model_state_dict_current_best.pt'))

# Freeze the pre-trained model
for p in clip_model.image_encoder.parameters():
    p.requires_grad = False
for p in clip_model.text_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(clip_model.parameters(), lr=learning_rate)

Some weights of the model checkpoint at owkin/phikon were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
epochs = 19
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, clip_model, contrastive_loss, optimizer, image_processor)
print("Done!")

In [101]:
torch.save(clip_model.state_dict(), '/scratch1/yuqiuwan/CSCI567/bert_phikon_model_state_dict_last_epoch_output.pt')

# Tests

In [8]:
clip_model = CLIP(img_encoder, text_encoder, d_embed=768, n_classes=5).to(device)
clip_model.load_state_dict(torch.load('/scratch1/yuqiuwan/CSCI567/trainedModels/bert_phikon_model_state_dict_last_epoch_output.pt'))

<All keys matched successfully>

In [6]:
class_ind = {'CC':0, 'EC':1, 'LGSC':2, 'HGSC':3, 'MC':4}

def test_loop(dataloader, model, image_processor=None):
    size = len(dataloader.dataset)
    model.eval()
    # texts = ['Cells are clear and transparent', 
    #          'Cells exhibit a back-to-back glandular arrangement', 
    #          'This cell type includes cells resembling normal healthy cells, cells with single nuclei, and living cells', 
    #          'Numerous cells frequently exhibit deformed shapes, several cells contain multiple nuclei, and many tissues typically contain dead cells.',
    #          'Cells have goblet-like appearance']
    
    texts = ['This type of cells have cell cytoplasm that are see through, and often have clear cell boundaries', 
            'Cells exhibit a back-to-back glandular pattern', 
            'This type of cells have cells close to normal healthy cells, cells containing single nuclei, and alive cell', 
            'There are many cells that are often deformed in shape, and many cells with multiple nucleus, and tissues often present many dead cells',
            'This type of cells often have goblet cells, they are often goblet-like or cell-like']
    
    texts = tokenizer(texts, padding='max_length', max_length=30, return_tensors='pt').to(device)
    
    correct_num = 0
    for batch, (imgs, _, labels) in enumerate(dataloader):
        imgs = imgs.view(-1, 3, 224, 224)
        
        if image_processor:
            imgs = image_processor(imgs, return_tensors="pt").to(device)
            # Compute prediction and loss
            logits = model(imgs['pixel_values'], texts)
        else:
            # Compute prediction and loss
            logits = model(imgs, texts)
        print(logits, 'True label:', labels)

        correct_num += torch.sum(torch.argmax(logits, axis=1) == labels)
        print(batch, 'correct_num:', correct_num)

    test_acc = correct_num  / size
    print('Test_accuracy:', test_acc)

test_dataloader = DataLoader(test_set, batch_size=1, shuffle=True)
test_loop(test_dataloader, clip_model, image_processor)

tensor([[ 0.9743, -0.3738,  0.3041, -0.3364,  0.5164]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([0], device='cuda:0')
0 correct_num: tensor(1, device='cuda:0')
tensor([[-0.6804, -0.2192,  0.5631,  0.3523, -0.0291]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([2], device='cuda:0')
1 correct_num: tensor(2, device='cuda:0')
tensor([[-0.3674, -0.5154,  0.1742,  0.9608, -0.7074]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([3], device='cuda:0')
2 correct_num: tensor(3, device='cuda:0')
tensor([[-0.5799, -0.4739,  0.4927,  0.8359, -0.2638]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([3], device='cuda:0')
3 correct_num: tensor(4, device='cuda:0')
tensor([[-0.4328,  0.9029, -0.4565, -0.7340,  0.2061]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([1], device='cuda:0')
4 correct_num: tensor(5, device='cuda:0')
tensor([[-0.4934, -0.4036,  0.0249,  0.9065, -0.6038]], device='cuda:0