# Load image data and text data

In [21]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from transformers import AutoImageProcessor, ViTModel

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import BertModel, BertTokenizer
# Device string used throughout the notebook when moving tensors and models.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG so that later random operations are reproducible.
torch.manual_seed(13)

# Mapping from the class name found in the metadata CSV to an integer label.
class_ind = {'CC':0, 'EC':1, 'LGSC':2, 'HGSC':3, 'MC':4}

class OvarianDataset(Dataset):
    """Dataset yielding a stack of image patches, one caption and the label per sample.

    Args:
        annotations_file: CSV whose first column is the sample id and whose
            second column is the class name (a key of ``class_ind``).
        img_dir: Directory prefix (with trailing separator) containing
            ``sample_<id>/<id>_<i>.png`` patch files.
        text_dir: Directory prefix (with trailing separator) containing one
            ``<class>.txt`` caption file per class.
        n_patches: Number of patch files read per sample (default 100).

    Each item is ``(image_patches, text, label)`` where ``image_patches`` is a
    float32 tensor of ``n_patches`` stacked patches and ``label`` is a scalar
    tensor, both on the global ``device``; ``text`` is a plain string.
    """

    def __init__(self, annotations_file, img_dir, text_dir, n_patches=100):
        self.img_metadata = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.n_patches = n_patches
        # One list of captions per class. Each file is split on '.' and the
        # second field is kept as the caption text.
        self.texts = {}
        for c in ['CC', 'EC', 'LGSC', 'HGSC', 'MC']:
            self.texts[c] = pd.read_table(text_dir+c+'.txt', header=None, sep='.').iloc[:,1].to_list()

    def __len__(self):
        return len(self.img_metadata)

    def __getitem__(self, idx):
        sample = self.img_metadata.iloc[idx, 0]
        group = self.img_metadata.iloc[idx, 1]

        label = torch.tensor(class_ind[group]).to(device)

        image_patches = []
        for i in range(self.n_patches):
            img_path = self.img_dir + f'sample_{sample}' + f'/{sample}_{i}.png'
            # Close the file handle deterministically once the pixels are read.
            with Image.open(img_path) as img:
                pixels = np.asarray(img)
            # Crop rows/cols 13..236 (a 224x224 window) and reverse the axes.
            # NOTE(review): `.T` on an (H, W, C) array yields (C, W, H), i.e.
            # the spatial axes are swapped too. Kept as-is because existing
            # checkpoints were trained with this orientation.
            patch = torch.tensor(pixels[13:237, 13:237].T, dtype=torch.float32)
            image_patches.append(patch)
        image_patches = torch.stack(image_patches, dim=0).to(device)

        # Cycle through however many captions the class actually has instead
        # of assuming exactly 100, which raised IndexError on shorter files.
        captions = self.texts[group]
        text = captions[idx % len(captions)]

        return image_patches, text, label

In [22]:
# Create Dataset
# Locations of the metadata CSV, the pre-processed image patches and the
# per-class caption files.
metadata = "/scratch1/yuqiuwan/CSCI567/train.csv"
image_dir = "/scratch1/yuqiuwan/CSCI567/preprocess_images_threshold/"
text_dir = "/scratch1/yuqiuwan/CSCI567/textLabel/"

wholedataset = OvarianDataset(metadata, image_dir, text_dir)
# 80% / 20% random split into training and test subsets.
train_set, test_set = torch.utils.data.random_split(wholedataset, [0.8, 0.2])

# Build CLIP model

In [23]:
class CLIP(nn.Module):
    """CLIP-style model pairing an image encoder and a text encoder.

    Both encoders' features are projected to ``n_classes`` dimensions,
    L2-normalised, and compared with a dot product.

    Args:
        ImageEncoder: Callable mapping a batch of patches to feature vectors
            of width ``d_embed[0]``.
        TextEncoder: Callable invoked as ``TextEncoder(**text)`` whose result
            exposes ``pooler_output`` of width ``d_embed[1]``.
        d_embed: ``(image_feature_dim, text_feature_dim)``.
        n_classes: Width of the shared embedding space.
        n_patches: Number of patches per sample that are averaged into one
            image embedding (default 100).
    """

    def __init__(self, ImageEncoder, TextEncoder, d_embed=(384, 768), n_classes=5, n_patches=100):
        super().__init__()
        self.image_encoder = ImageEncoder
        self.text_encoder = TextEncoder
        self.image_proj = nn.Linear(d_embed[0], n_classes)
        self.text_proj = nn.Linear(d_embed[1], n_classes)
        self.n_classes = n_classes
        self.n_patches = n_patches

    def forward(self, img, text):
        """Return the (n_samples, n_texts) matrix of image/text cosine similarities.

        ``img`` holds ``n_samples * n_patches`` patches stacked along dim 0;
        ``text`` is a mapping of keyword arguments for the text encoder.
        """
        img_features = self.image_encoder(img)
        # Group the per-patch projections by sample and average them.
        img_embed = self.image_proj(img_features).view(-1, self.n_patches, self.n_classes)
        img_embed = img_embed.mean(dim=1)
        # F.normalize clamps the norm with a small eps, so an all-zero
        # embedding yields zeros instead of NaN.
        img_embed = F.normalize(img_embed, dim=-1)

        text_outputs = self.text_encoder(**text)
        text_embed = self.text_proj(text_outputs.pooler_output)
        text_embed = F.normalize(text_embed, dim=-1)

        logits = img_embed @ text_embed.T
        return logits

# Set up training

In [24]:
from timm.models.vision_transformer import VisionTransformer

def get_pretrained_url(key):
    """Return the download URL of the pretrained weights registered under *key*.

    Args:
        key: Registry name, one of ``"DINO_p16"`` or ``"DINO_p8"``.

    Raises:
        ValueError: If *key* is not in the registry. Previously an unknown
            key silently produced a URL ending in ``/None``.
    """
    URL_PREFIX = "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
    model_zoo_registry = {
        "DINO_p16": "dino_vit_small_patch16_ep200.torch",
        "DINO_p8": "dino_vit_small_patch8_ep200.torch",
    }
    if key not in model_zoo_registry:
        raise ValueError(
            f"Unknown pretrained model key {key!r}; expected one of {sorted(model_zoo_registry)}"
        )
    return f"{URL_PREFIX}/{model_zoo_registry[key]}"


def vit_small(pretrained, progress, key, **kwargs):
    """Build a 224x224 ViT with embed_dim=384 and 6 heads, optionally pretrained.

    Args:
        pretrained: When truthy, download and load the weights registered
            under *key*.
        progress: Forwarded to ``torch.hub.load_state_dict_from_url`` to
            toggle the download progress bar.
        key: Registry name understood by ``get_pretrained_url``.
        **kwargs: Only ``patch_size`` (default 16) is read; anything else is
            ignored.

    Returns:
        The ``VisionTransformer`` instance (``num_classes=0``).
    """
    patch_size = kwargs.get("patch_size", 16)
    model = VisionTransformer(
        img_size=224, patch_size=patch_size, embed_dim=384, num_heads=6, num_classes=0
    )
    if pretrained:
        pretrained_url = get_pretrained_url(key)
        # Deserialize onto the CPU so the checkpoint loads on hosts without a
        # GPU; load_state_dict copies the values into the model's parameters.
        state_dict = torch.hub.load_state_dict_from_url(
            pretrained_url, progress=progress, map_location="cpu"
        )
        verbose = model.load_state_dict(state_dict)
        print(verbose)
    return model

In [25]:
# Fixed token length that captions are padded to in train_loop; also read by
# TextEncoder to size its position-embedding table.
seq_max_length = 50
# BERT WordPiece tokenizer shared by the training and test loops.
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")

def contrastive_loss(logits, labels):
    """Symmetric cross-entropy over a similarity matrix.

    The mean-reduced cross-entropy is evaluated once on ``logits`` (rows as
    queries) and once on its transpose (columns as queries), both against
    ``labels``; the two terms are averaged.
    """
    row_term = F.cross_entropy(logits, labels, reduction="mean")
    col_term = F.cross_entropy(logits.transpose(0, 1), labels, reduction="mean")
    return (row_term + col_term) / 2

def train_loop(dataloader, model, loss_fn, optimizer, image_processor=None):
    """Run one pass over *dataloader*, training *model* contrastively.

    For each batch, image i is paired with caption i, so the target for row i
    of the logits matrix is i. Reads the module-level ``tokenizer``,
    ``seq_max_length`` and ``device``.

    Args:
        dataloader: Yields ``(image_patches, captions, class_label)``; the
            class label is ignored here.
        model: Called as ``model(images, tokenized_text)``; returns logits.
        loss_fn: Called as ``loss_fn(logits, labels)``.
        optimizer: Optimizer stepping the trainable parameters.
        image_processor: Optional Hugging Face image processor applied to the
            patches before the forward pass.

    Side effects:
        Prints the loss and accuracy per batch and writes a checkpoint
        whenever a batch is classified perfectly with a new lowest loss.
    """
    size = len(dataloader.dataset)
    model.train()
    # model.train() recursively switches every sub-module to training mode,
    # which re-enables dropout inside frozen pretrained encoders. Keep any
    # child whose parameters are all frozen in eval mode.
    for child in model.children():
        child_params = list(child.parameters())
        if child_params and not any(p.requires_grad for p in child_params):
            child.eval()

    best_loss = np.inf
    seen = 0
    for imgs, texts, _ in dataloader:
        imgs = imgs.view(-1, 3, 224, 224)
        # truncation=True keeps every caption at exactly seq_max_length
        # tokens so the batch can always be stacked into one tensor.
        texts = tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=seq_max_length,
            return_tensors='pt',
        ).to(device)

        if image_processor:
            imgs = image_processor(imgs, return_tensors="pt").to(device)
            logits = model(imgs['pixel_values'], texts)
        else:
            logits = model(imgs, texts)

        # Size the targets from the actual batch: the last batch of an epoch
        # can be smaller than the configured batch size.
        n_pairs = logits.shape[0]
        labels = torch.arange(n_pairs, device=logits.device)
        loss = loss_fn(logits, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc = (torch.argmax(logits, dim=1) == labels).float().mean().item()

        loss = loss.item()
        seen += n_pairs
        # Only overwrite the checkpoint when this perfect batch also improves
        # on the best loss seen so far in this call.
        if train_acc == 1 and loss < best_loss:
            best_loss = loss
            torch.save(model.state_dict(), '/scratch1/yuqiuwan/CSCI567/bert_lunit_model_state_dict_current_best.pt')

        print(f"loss: {loss:>7f}; train_acc: {train_acc:>7f}  [{seen:>5d}/{size:>5d}]")
            
######################## Train CLIP ########################
# Hyper-parameters. `dropout` is read as a global by the Head,
# MultiHeadAttention and FeedFoward modules defined later in this notebook.
batch_size = 2
dropout = 0.0
learning_rate = 10**(-3)

train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

# load the phikon image processor and the Lunit DINO ViT-S/16 image encoder
image_processor = AutoImageProcessor.from_pretrained("owkin/phikon")
img_encoder = vit_small(pretrained=True, progress=False, key="DINO_p16", patch_size=16)
img_encoder.eval()

# load bert
text_encoder = BertModel.from_pretrained('bert-base-uncased')
text_encoder.eval()

# d_embed gives the feature widths of the image and text encoders.
clip_model = CLIP(img_encoder, text_encoder, d_embed=[384, 768], n_classes=5).to(device)

# Freeze the pre-trained model
# Only the two projection layers of CLIP keep requires_grad=True.
for p in clip_model.image_encoder.parameters():
    p.requires_grad = False
for p in clip_model.text_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(clip_model.parameters(), lr=learning_rate)

<All keys matched successfully>


In [14]:
# Number of full passes over the training set.
epochs = 1
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    # The phikon image processor is applied to every batch before the forward pass.
    train_loop(train_dataloader, clip_model, contrastive_loss, optimizer, image_processor)
print("Done!")

Epoch 1
-------------------------------


In [None]:
# Persist the trained CLIP weights. The trained model is bound to
# `clip_model`; the name `model` is used elsewhere in this notebook for
# stand-alone example encoders, so saving `model` stored the wrong weights
# (or raised NameError in a fresh session).
torch.save(clip_model.state_dict(), '/scratch1/yuqiuwan/CSCI567/bert_lunit_model_state_dict_last_epoch_output.pt')

# Tests

In [26]:
# Rebuild the model with the same architecture used for training, then
# restore the weights saved after the last epoch.
image_processor = AutoImageProcessor.from_pretrained("owkin/phikon")
img_encoder = vit_small(pretrained=True, progress=False, key="DINO_p16", patch_size=16)
img_encoder.eval()
# load bert
text_encoder = BertModel.from_pretrained('bert-base-uncased')
text_encoder.eval()

clip_model = CLIP(img_encoder, text_encoder, d_embed=[384, 768], n_classes=5).to(device)
# map_location lets a checkpoint written on a GPU machine load on whichever
# device this session is using.
clip_model.load_state_dict(
    torch.load(
        '/scratch1/yuqiuwan/CSCI567/bert_lunit_model_state_dict_last_epoch_output.pt',
        map_location=device,
    )
)

<All keys matched successfully>


<All keys matched successfully>

In [27]:
# Class name -> integer label. The prompts in test_loop are listed in this
# index order.
class_ind = {'CC':0, 'EC':1, 'LGSC':2, 'HGSC':3, 'MC':4}

def test_loop(dataloader, model, image_processor=None):
    """Classify every sample by its most similar class prompt and report accuracy.

    Five fixed prompts, one per class and listed in class-index order, are
    tokenized once. For each batch, the predicted class is the argmax over
    the prompt similarities. Reads the module-level ``tokenizer`` and
    ``device``.

    Args:
        dataloader: Yields ``(image_patches, caption, label)``; the caption
            is ignored.
        model: Called as ``model(images, tokenized_prompts)``; returns logits.
        image_processor: Optional Hugging Face image processor applied to the
            patches before the forward pass.

    Returns:
        The fraction of correctly classified samples.
    """
    size = len(dataloader.dataset)
    model.eval()
    texts = ['This type of cells have cell cytoplasm that are see through, and often have clear cell boundaries',
             'Cells exhibit a back-to-back glandular pattern',
             'This type of cells have cells close to normal healthy cells, cells containing single nuclei, and alive cell',
             'There are many cells that are often deformed in shape, and many cells with multiple nucleus, and tissues often present many dead cells',
             'This type of cells often have goblet cells, they are often goblet-like or cell-like']
    texts = tokenizer(texts, padding='max_length', max_length=30, return_tensors='pt').to(device)

    correct_num = 0
    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        for batch, (imgs, _, labels) in enumerate(dataloader):
            imgs = imgs.view(-1, 3, 224, 224)

            if image_processor:
                imgs = image_processor(imgs, return_tensors="pt").to(device)
                logits = model(imgs['pixel_values'], texts)
            else:
                logits = model(imgs, texts)
            print(logits, 'True label:', labels)

            correct_num += torch.sum(torch.argmax(logits, axis=1) == labels)
            print(batch, 'correct_num:', correct_num)

    test_acc = correct_num  / size
    print('Test_accuracy:', test_acc)
    return test_acc

# Evaluate one sample at a time on the held-out split.
test_dataloader = DataLoader(test_set, batch_size=1, shuffle=True)
test_loop(test_dataloader, clip_model, image_processor)

tensor([[-0.4119, -0.7330, -0.6756,  0.5110, -0.8785]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([3], device='cuda:0')
0 correct_num: tensor(1, device='cuda:0')
tensor([[ 0.9085, -0.3696,  0.3149, -0.0672,  0.3685]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([0], device='cuda:0')
1 correct_num: tensor(2, device='cuda:0')
tensor([[-0.9527, -0.2206, -0.5433,  0.4799, -0.7716]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([3], device='cuda:0')
2 correct_num: tensor(3, device='cuda:0')
tensor([[-0.8024,  0.0017, -0.0387,  0.4935, -0.2982]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([3], device='cuda:0')
3 correct_num: tensor(4, device='cuda:0')
tensor([[-0.6433,  0.7872, -0.2268, -0.4326, -0.0730]], device='cuda:0',
       grad_fn=<MmBackward0>) True label: tensor([1], device='cuda:0')
4 correct_num: tensor(5, device='cuda:0')
tensor([[-0.7546,  0.7108, -0.4456, -0.4166, -0.2887]], device='cuda:0

# Examples and our old Transformer

In [None]:
#################### Bert Example ############################################
from transformers import BertTokenizer

# Define the path to your file
file_path = '/scratch1/yuqiuwan/CSCI567/textLabel/CC.txt'
output_CC = []

# Loading the tokenizer is expensive and does not depend on the line being
# read, so do it once instead of once per line.
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")

# Open the file using the 'with' statement to ensure it gets closed after reading
with open(file_path, 'r') as file:
    for line in file:
        clean_line = line.strip()
        # Blank lines carry no caption and would fail on parts[1] below.
        if not clean_line:
            continue
        # Drop the leading "<number>. " prefix and keep the caption text.
        parts = clean_line.split('. ', 1)
        sample = parts[1]
        # truncation=True guarantees every encoding has exactly 512 ids, so
        # the list stacks into a rectangular array.
        encoding = tokenizer.encode(sample, max_length=512, padding="max_length", truncation=True)
        output_CC.append(encoding)

output_CC = np.array(output_CC)

In [None]:
# Load a single example patch and shape it like one model input.
# NOTE(review): this reads from 'preprocess_images', whereas the dataset cell
# uses 'preprocess_images_threshold' — confirm which folder is intended.
dir_path = '/scratch1/yuqiuwan/CSCI567/preprocess_images/sample_10077/10077_0.png'
img = np.asarray(Image.open(dir_path))
# Keep rows/cols 13..236, i.e. a 224x224 window.
img = img[13:237, 13:237]
# Reverse the axis order and add a leading batch axis of size 1.
img = torch.tensor(img.T[None, :, :, :], dtype=torch.float32)

In [None]:
#################### Phikon Example ############################################
from PIL import Image
import torch
from transformers import AutoImageProcessor, ViTModel

# load an image
# Replicates a single 250x250 plane three times along a trailing channel axis.
# NOTE(review): if `img` is the tensor built in the previous cell (leading
# batch axis, 224x224 crop), `img[0].reshape(250,250)` cannot succeed; this
# cell presumably relied on a different `img` — confirm before running.
image = np.stack([img[0].reshape(250,250)]*3, axis=-1) # Image.open("assets/example.tif")

# load phikon
image_processor = AutoImageProcessor.from_pretrained("owkin/phikon")
model = ViTModel.from_pretrained("owkin/phikon", add_pooling_layer=False)

# process the image
inputs = image_processor(image, return_tensors="pt")

# get the features
with torch.no_grad():
    outputs = model(**inputs)
    # Take the hidden state of the first token of each sequence.
    features = outputs.last_hidden_state[:, 0, :] 

In [None]:
#################### Lunit Example ############################################
import torch
from timm.models.vision_transformer import VisionTransformer

def get_pretrained_url(key):
    """Return the download URL of the pretrained weights registered under *key*.

    Args:
        key: Registry name, one of ``"DINO_p16"`` or ``"DINO_p8"``.

    Raises:
        ValueError: If *key* is not in the registry. Previously an unknown
            key silently produced a URL ending in ``/None``.
    """
    URL_PREFIX = "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
    model_zoo_registry = {
        "DINO_p16": "dino_vit_small_patch16_ep200.torch",
        "DINO_p8": "dino_vit_small_patch8_ep200.torch",
    }
    if key not in model_zoo_registry:
        raise ValueError(
            f"Unknown pretrained model key {key!r}; expected one of {sorted(model_zoo_registry)}"
        )
    return f"{URL_PREFIX}/{model_zoo_registry[key]}"


def vit_small(pretrained, progress, key, **kwargs):
    """Build a 224x224 ViT with embed_dim=384 and 6 heads, optionally pretrained.

    Args:
        pretrained: When truthy, download and load the weights registered
            under *key*.
        progress: Forwarded to ``torch.hub.load_state_dict_from_url`` to
            toggle the download progress bar.
        key: Registry name understood by ``get_pretrained_url``.
        **kwargs: Only ``patch_size`` (default 16) is read; anything else is
            ignored.

    Returns:
        The ``VisionTransformer`` instance (``num_classes=0``).
    """
    patch_size = kwargs.get("patch_size", 16)
    model = VisionTransformer(
        img_size=224, patch_size=patch_size, embed_dim=384, num_heads=6, num_classes=0
    )
    if pretrained:
        pretrained_url = get_pretrained_url(key)
        # Deserialize onto the CPU so the checkpoint loads on hosts without a
        # GPU; load_state_dict copies the values into the model's parameters.
        state_dict = torch.hub.load_state_dict_from_url(
            pretrained_url, progress=progress, map_location="cpu"
        )
        verbose = model.load_state_dict(state_dict)
        print(verbose)
    return model

# Build the pretrained patch-16 encoder and run it on the example `img`
# tensor prepared in an earlier cell.
model = vit_small(pretrained=True, progress=False, key="DINO_p16", patch_size=16)
t = model(img)

In [None]:
# Helper Functions 
def patchify(images, n_patches):
    """Cut each square image into an n_patches x n_patches grid of flat patches.

    ``images`` has shape (n, c, h, w). The result has shape
    (n, n_patches**2, h*w*c // n_patches**2); patches are ordered row by row
    and each one is flattened channel-first.
    """
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"

    grid_cells = n_patches ** 2
    out = torch.zeros(n, grid_cells, h * w * c // grid_cells)
    side = h // n_patches

    for img_idx, image in enumerate(images):
        for row in range(n_patches):
            rows = slice(row * side, (row + 1) * side)
            for col in range(n_patches):
                cols = slice(col * side, (col + 1) * side)
                out[img_idx, row * n_patches + col] = image[:, rows, cols].flatten()
    return out

def get_positional_embeddings(sequence_length, d):
    """Build a (sequence_length, d) table of sinusoidal position encodings.

    Even columns hold sin(pos / 10000**(col / d)); odd columns hold
    cos(pos / 10000**((col - 1) / d)).
    """
    table = torch.ones(sequence_length, d)
    for pos in range(sequence_length):
        for col in range(d):
            if col % 2 == 0:
                value = np.sin(pos / (10000 ** (col / d)))
            else:
                value = np.cos(pos / (10000 ** ((col - 1) / d)))
            table[pos][col] = value
    return table


In [None]:
class Head(nn.Module):
    """One self-attention head, causally masked when ``block_size`` is given.

    Args:
        d_embed: Width of the input features.
        head_size: Width of the key/query/value projections and of the output.
        block_size: Maximum sequence length for the causal mask; ``None``
            disables masking.

    Reads the module-level ``dropout`` rate.
    """

    def __init__(self, d_embed, head_size, block_size=None):
        super().__init__()
        self.key = nn.Linear(d_embed, head_size, bias=False)
        self.query = nn.Linear(d_embed, head_size, bias=False)
        self.value = nn.Linear(d_embed, head_size, bias=False)
        self.block_size = block_size
        self.head_size = head_size
        if block_size is not None:
            # Lower-triangular mask: position t may only attend to <= t.
            self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C) and return (B, T, head_size).

        (1) Projects ``x`` to keys and queries.
        (2) Scores are scaled query-key products; in masked mode, future
            positions are set to -inf. Softmax then dropout give the weights.
        (3) Returns the attention-weighted sum of the value projections.
        """
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # compute attention scores ("affinities")
        if self.block_size is not None:
            # NOTE(review): the masked branch scales by the input width C
            # while the unmasked branch scales by head_size. Left unchanged
            # so existing weights keep their meaning — confirm intent.
            weight = q @ k.transpose(-2, -1) * C**-0.5  # (B, T, T)
            weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        else:
            weight = q @ k.transpose(-2, -1) * (self.head_size**-0.5)
        weight = F.softmax(weight, dim=-1)  # (B, T, T)
        weight = self.dropout(weight)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B, T, head_size)
        out = weight @ v   # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
        return out

class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, concatenated and projected.

    Args:
        num_heads: Number of ``Head`` modules.
        head_size: Output width of each head.
        d_embed: Width of the input and of the projected output.
        block_size: Forwarded to each ``Head`` (causal mask length or None).

    Reads the module-level ``dropout`` rate.
    """

    def __init__(self, num_heads, head_size, d_embed, block_size=None):
        super().__init__()
        self.heads = nn.ModuleList([Head(d_embed, head_size, block_size=block_size) for _ in range(num_heads)])
        # The concatenated heads are num_heads * head_size wide. Sizing the
        # projection from that (rather than assuming it equals d_embed) keeps
        # forward() valid for any head configuration and is identical when
        # head_size == d_embed // num_heads.
        self.proj = nn.Linear(num_heads * head_size, d_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (B, T, d_embed) to a tensor of shape (B, T, d_embed).

        (1) Runs every head on ``x`` and concatenates along the last dim.
        (2) Projects the concatenation back to ``d_embed``.
        (3) Applies dropout.
        """
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """Two-layer MLP: widen to 4 * d_embed, ReLU, narrow back, then dropout."""

    def __init__(self, d_embed):
        super().__init__()
        hidden = 4 * d_embed
        layers = [
            nn.Linear(d_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_embed),
            nn.Dropout(dropout),  # rate comes from the module-level `dropout`
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``; the shape is preserved."""
        out = self.net(x)
        return out
      
class Block(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each with a residual."""

    def __init__(self, d_embed, n_heads, block_size=None):
        # d_embed is the embedding width; it is split evenly across n_heads.
        super().__init__()
        assert d_embed % n_heads == 0, f"Can't divide dimension {d_embed} into {n_heads} heads"
        per_head = d_embed // n_heads
        self.sa = MultiHeadAttention(n_heads, per_head, d_embed, block_size=block_size)
        self.ffwd = FeedFoward(d_embed)
        self.ln1, self.ln2 = nn.LayerNorm(d_embed), nn.LayerNorm(d_embed)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed
    
class ImageEncoder(nn.Module):
    """Small Vision Transformer producing ``n_classes`` logits per image.

    Args:
        data_shape: ``(C, H, W)`` of the input images.
        n_patches: Patches per side; H and W must be divisible by it.
        n_blocks: Number of transformer blocks.
        d_embed: Token embedding width.
        n_heads: Attention heads per block.
        n_classes: Width of the output projection.
    """

    def __init__(self, data_shape, n_patches=10, n_blocks=2, d_embed=10, n_heads=2, n_classes=5):
        # Super constructor
        super().__init__()

        # Attributes
        self.chw = data_shape # ( C , H , W )
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.d_embed = d_embed

        # Input and patches sizes
        assert data_shape[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert data_shape[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (data_shape[1] / n_patches, data_shape[2] / n_patches)

        # 1) Linear mapper from a flattened patch to a token
        self.input_d = int(data_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, d_embed)

        # 2) Learnable classification token
        self.classify_token = nn.Parameter(torch.rand(1, 1, d_embed))

        # 3) Learned positional embedding: one row per patch plus one for the
        #    classification token
        self.position_embedding_table = nn.Embedding(n_patches ** 2 + 1, d_embed)

        # 4) Transformer encoder blocks (no causal mask)
        self.blocks = nn.ModuleList([Block(d_embed, n_heads, block_size=None) for _ in range(n_blocks)])

        # 5) Classification projection
        self.proj = nn.Linear(d_embed, n_classes)

    def forward(self, images):
        """Map ``images`` of shape (n, C, H, W) to logits of shape (n, n_classes)."""
        n = images.shape[0]
        # Work on whatever device this module's weights live on rather than
        # a module-level global, so the encoder runs wherever it was moved.
        module_device = self.linear_mapper.weight.device

        # Dividing images into flattened patches
        patches = patchify(images, self.n_patches).to(module_device)

        # Map each flattened patch to the token width
        tokens = self.linear_mapper(patches)

        # Prepend the classification token to every sequence
        tokens = torch.cat((self.classify_token.expand(n, -1, -1), tokens), dim=1)

        # Adding positional embedding
        positions = torch.arange(self.n_patches ** 2 + 1, device=module_device)
        out = tokens + self.position_embedding_table(positions)

        # Transformer Blocks
        for block in self.blocks:
            out = block(out)

        # The classification token (position 0) summarises the whole input
        out = out[:, 0]
        logits = self.proj(out)

        return logits
    
class TextEncoder(nn.Module):
    """Small transformer over token ids producing ``n_classes`` logits per sequence.

    Args:
        vocab_size: Size of the token embedding table.
        block_size: Forwarded to every ``Block`` as the causal-mask length
            (``None`` disables masking).
        n_blocks: Number of transformer blocks.
        d_embed: Token embedding width.
        n_heads: Attention heads per block.
        n_classes: Width of the output projection.

    The position table is sized from the module-level ``seq_max_length``.

    NOTE(review): the classification token is placed at position 0. When
    ``block_size`` is not None, ``Head`` applies a lower-triangular mask, so
    position 0 can only attend to itself and the returned logits do not
    depend on the input tokens. Pass ``block_size=None`` or move the
    classification token to the end — confirm the intended design.
    """

    def __init__(self, vocab_size, block_size, n_blocks=2, d_embed=10, n_heads=2, n_classes=5):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, d_embed)
        # One extra row for the prepended classification token.
        self.position_embedding_table = nn.Embedding(seq_max_length + 1, d_embed)
        self.classify_token = nn.Parameter(torch.rand(1, 1, d_embed))

        self.blocks = nn.Sequential(*[Block(d_embed, n_heads, block_size=block_size) for _ in range(n_blocks)])
        self.ln_f = nn.LayerNorm(d_embed) # final layer norm
        self.lm_head = nn.Linear(d_embed, n_classes)


    def forward(self, idx, targets=None):
        """Map token ids ``idx`` of shape (B, T) to logits of shape (B, n_classes).

        ``targets`` is accepted for interface compatibility and is not used.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx) # (B, T, d_embed)
        tok_emb = torch.cat((self.classify_token.expand(B, -1, -1), tok_emb), dim=1) # (B, T+1, d_embed)
        # Build the position ids on the input's device instead of relying on
        # a module-level global.
        pos_emb = self.position_embedding_table(torch.arange(T+1, device=idx.device)) # (T+1, d_embed)
        x = tok_emb + pos_emb # (B, T+1, d_embed)

        x = self.blocks(x) # (B, T+1, d_embed)
        x = self.ln_f(x) # (B, T+1, d_embed)
        x = x[:, 0] # get classify token -> (B, d_embed)
        logits = self.lm_head(x) # (B, n_classes)

        return logits