# Testing the Mecial CLIP foundation models. 
1. Testing CLIP models for zero-shot classification methods.
2. QuiltNet: https://huggingface.co/wisdomik/QuiltNet-B-32; Use this medical CLIP for zero classification.

## Try the CLIP and Medical CLIP methods.

In [None]:
!pip install clip
!pip install open_clip_torch

In [7]:
import numpy as np
import torch
import clip
from tqdm.notebook import tqdm
from pkg_resources import packaging

print("Torch version:", torch.__version__)

Torch version: 1.12.1+cu116


In [4]:
import open_clip

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:wisdomik/QuiltNet-B-32')
tokenizer = open_clip.get_tokenizer('hf-hub:wisdomik/QuiltNet-B-32')   # why I need this item?

image = Image.open("/vol/research/wenjieProject/datasets/chaoyang/train/538849-1-IMG006x019-0.JPG")
image = preprocess_val(image).unsqueeze(0)

# text = tokenizer.encode(["A normal colon slide", "A serrated colon slide", "A adenocarcinoma colon slide", "A adenoma colon slide"])
text = open_clip.tokenize(["A normal colon slide", "A serrated colon slide", "A adenocarcinoma colon slide", "A adenoma colon slide"])
# text = [tokenizer.encode(t) for t in ["A normal colon slide", "A serrated colon slide", "A adenocarcinoma colon slide", "A adenoma colon slide"]]
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]

Label probs: tensor([[0.2777, 0.4056, 0.1222, 0.1946]])


## Chaoyang Dataset Preparation.

## Following the setting of the GuiltNet

In [39]:
preprocess_val

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_to_rgb at 0x7efeb53b1430>
    ToTensor()
    Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
)

In [13]:
import torch.utils.data as data
from PIL import Image
import os
import json
import numpy as np
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch.nn as nn
import open_clip
from datasets import load_metric

In [15]:
# model.visual.logit_scale.requires_grad_(False)
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:wisdomik/QuiltNet-B-16')
# model.visual.logit_scale.requires_grad_(False)
model.visual.predict = nn.Linear(in_features=512, out_features=num_classes)
model.visual.predict.weight.data.normal_(mean=0.0, std=0.02)
# model.visual.predict.weight.data.normal_(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
model.visual.predict.bias.data.zero_()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.visual.predict.parameters())


In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
num_classes = 4  # You mentioned you have 4 classes


class Chaoyang(Dataset):
    def __init__(self, root='', transform=None, is_train=True, input_size=256):
        self.transform = transform
        self.is_train = is_train
        self.input_size = input_size

        if not self.is_train:
            imgs = []
            labels = []
            json_path = os.path.join(root, 'json', 'test_split_2.json')
            with open(json_path, 'r') as f:
                load_list = json.load(f)
                for i in range(len(load_list)):
                    img_path = os.path.join(root, load_list[i]["name"])
                    imgs.append(img_path)
                    labels.append(int(load_list[i]["label"]))
            self.test_data, self.test_labels = np.array(imgs), np.array(labels)
        else:  # is_train = True => Train.
            imgs = []
            labels = []
            json_path = os.path.join(root, 'json', 'train_split_2.json')
            with open(json_path, 'r') as f:
                load_list = json.load(f)
                for i in range(len(load_list)):
                    img_path = os.path.join(root, load_list[i]["name"])
                    imgs.append(img_path)
                    labels.append(int(load_list[i]["label"]))
            self.train_data, self.train_labels = np.array(imgs), np.array(labels)

    
    def __getitem__(self, idx):
        # In the function of building_dataset, the transform has been set following the is_train.
        if self.is_train:
            img, label = self.train_data[idx], self.train_labels[idx]
            img = Image.open(img)
            # img = img.convert('RGB')
            # img = Image.open(img).convert('RGB')
            img = self.transform(img)
            return img, label
        else:
            img, label = self.test_data[idx], self.test_labels[idx]
            # img = Image.open(img).convert('RGB')
            # img = img.convert('RGB')
            img = Image.open(img)
            img = self.transform(img)
            return img, label

    def __len__(self):
        if self.is_train:
            return len(self.train_data)
        else:
            return len(self.test_data)

In [17]:
from torch.utils.data import DataLoader
from transformers import AdamW

root = '/vol/research/wenjieProject/datasets/chaoyang'
input_size = 384

train_dataset = Chaoyang(root, preprocess_train, is_train=True, input_size=input_size)
test_dataset = Chaoyang(root, preprocess_val, is_train=False, input_size=input_size)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_laoder = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [18]:
model.to(device)
model.train()
from itertools import islice

metric= load_metric("accuracy")
save_acc = False
acc_list = []
num_epochs = 30
for epoch in range(num_epochs):
    train_correct_predictions = 0
    train_total_predictions = 0
    training_accuracy = 0
    # for samples, labels in islice(tqdm(train_loader), 2):
    for samples, labels in tqdm(train_loader):
        # print(batch.shape)
        optimizer.zero_grad()
        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        # inputs = feature_extractor(images=samples, return_tensors="pt").pixel_values.to(model.device)
        image_features = model.encode_image(samples)
        # outputs = model(samples, labels=labels)
        # loss = outputs.loss
        logits = model.visual.predict(image_features)

        # print(logits)
        loss = criterion(logits, labels)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        # lr_scheduler.step()
        optimizer.step()

        _, predicted = torch.max(logits.data, 1)  # gives us the index of the highest value
        train_total_predictions += labels.size(0)  # add total
        train_correct_predictions += (predicted == labels).sum().item()
        training_accuracy = train_correct_predictions / train_total_predictions
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Training accuracy: {training_accuracy}")
    ### Add the commands to save the trained model. 
    model.eval()
    # if epoch == 4:
    #     save_acc = True                                                                                                           
    with torch.no_grad():
        total_accuracy = 0
        total_batches = 0
        # for i, (images, labels) in islice(tqdm(enumerate(test_laoder)), 2):
        for i, (images, labels) in tqdm(enumerate(test_laoder)):
            images = images.to(device)
            labels = labels.to(device)

            # inputs = feature_extractor(images=images, return_tensors="pt").pixel_values.to(model.device)
            # outputs = model(images)

            image_features = model.encode_image(images)
            # outputs = model(samples, labels=labels)
            # loss = outputs.loss
            logits = model.visual.predict(image_features)
            # print(outputs.logits)

            predictions = torch.argmax(logits.data, 1)
            
            # print(predictions)
            # metric.add_batch(predictions=predictions, references=labels)
            accuracy = metric.compute(predictions=predictions.tolist(), references=labels.tolist())["accuracy"]
            if save_acc:
                acc_list.append(accuracy)
            # print('the {} th image: {:.2f}'.format(str(i), accuracy))
            total_accuracy += accuracy
            total_batches += 1

        # metric.compute()
        average_accuracy = total_accuracy / total_batches
        print(f"Epoch [{epoch+1}/{num_epochs}]  Accuracy: {average_accuracy}")
       
    model.train()

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 1/30, Loss: 0.3477548062801361, Training accuracy: 0.668994708994709


0it [00:00, ?it/s]

Epoch [1/30]  Accuracy: 0.6414351851851852


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 2/30, Loss: 0.9428674578666687, Training accuracy: 0.7375661375661375


0it [00:00, ?it/s]

Epoch [2/30]  Accuracy: 0.6879629629629629


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 3/30, Loss: 1.0729057788848877, Training accuracy: 0.7521693121693122


0it [00:00, ?it/s]

Epoch [3/30]  Accuracy: 0.7289351851851852


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 4/30, Loss: 0.3970719873905182, Training accuracy: 0.7604232804232804


0it [00:00, ?it/s]

Epoch [4/30]  Accuracy: 0.7094907407407407


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 5/30, Loss: 0.4364991784095764, Training accuracy: 0.7659259259259259


0it [00:00, ?it/s]

Epoch [5/30]  Accuracy: 0.712962962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 6/30, Loss: 0.4534083306789398, Training accuracy: 0.7712169312169312


0it [00:00, ?it/s]

Epoch [6/30]  Accuracy: 0.7275462962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 7/30, Loss: 1.1124347448349, Training accuracy: 0.7807407407407407


0it [00:00, ?it/s]

Epoch [7/30]  Accuracy: 0.674074074074074


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 8/30, Loss: 0.13685578107833862, Training accuracy: 0.783068783068783


0it [00:00, ?it/s]

Epoch [8/30]  Accuracy: 0.725462962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 9/30, Loss: 0.40192848443984985, Training accuracy: 0.7826455026455027


0it [00:00, ?it/s]

Epoch [9/30]  Accuracy: 0.712962962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 10/30, Loss: 1.0750982761383057, Training accuracy: 0.788994708994709


0it [00:00, ?it/s]

Epoch [10/30]  Accuracy: 0.7108796296296296


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 11/30, Loss: 0.9324085116386414, Training accuracy: 0.7866666666666666


0it [00:00, ?it/s]

Epoch [11/30]  Accuracy: 0.6726851851851852


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 12/30, Loss: 0.21405000984668732, Training accuracy: 0.7839153439153439


0it [00:00, ?it/s]

Epoch [12/30]  Accuracy: 0.7358796296296297


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 13/30, Loss: 0.32930833101272583, Training accuracy: 0.7968253968253968


0it [00:00, ?it/s]

Epoch [13/30]  Accuracy: 0.6310185185185185


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 14/30, Loss: 0.4702834188938141, Training accuracy: 0.7991534391534392


0it [00:00, ?it/s]

Epoch [14/30]  Accuracy: 0.6476851851851851


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 15/30, Loss: 0.8269513845443726, Training accuracy: 0.7953439153439154


0it [00:00, ?it/s]

Epoch [15/30]  Accuracy: 0.700462962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 16/30, Loss: 0.4041214883327484, Training accuracy: 0.8002116402116403


0it [00:00, ?it/s]

Epoch [16/30]  Accuracy: 0.712962962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 17/30, Loss: 0.24836046993732452, Training accuracy: 0.8031746031746032


0it [00:00, ?it/s]

Epoch [17/30]  Accuracy: 0.737962962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 18/30, Loss: 0.2045007199048996, Training accuracy: 0.7976719576719576


0it [00:00, ?it/s]

Epoch [18/30]  Accuracy: 0.756712962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 19/30, Loss: 1.4663653373718262, Training accuracy: 0.7974603174603174


0it [00:00, ?it/s]

Epoch [19/30]  Accuracy: 0.7546296296296297


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 20/30, Loss: 0.20951437950134277, Training accuracy: 0.8031746031746032


0it [00:00, ?it/s]

Epoch [20/30]  Accuracy: 0.737962962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 21/30, Loss: 0.28925246000289917, Training accuracy: 0.8044444444444444


0it [00:00, ?it/s]

Epoch [21/30]  Accuracy: 0.7101851851851851


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 22/30, Loss: 0.10776881873607635, Training accuracy: 0.8035978835978836


0it [00:00, ?it/s]

Epoch [22/30]  Accuracy: 0.6768518518518518


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 23/30, Loss: 0.7444950342178345, Training accuracy: 0.8023280423280423


0it [00:00, ?it/s]

Epoch [23/30]  Accuracy: 0.7039351851851852


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 24/30, Loss: 0.8247544169425964, Training accuracy: 0.806984126984127


0it [00:00, ?it/s]

Epoch [24/30]  Accuracy: 0.686574074074074


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 25/30, Loss: 0.2940449118614197, Training accuracy: 0.8076190476190476


0it [00:00, ?it/s]

Epoch [25/30]  Accuracy: 0.7143518518518519


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 26/30, Loss: 0.5143049955368042, Training accuracy: 0.8059259259259259


0it [00:00, ?it/s]

Epoch [26/30]  Accuracy: 0.6949074074074074


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 27/30, Loss: 0.6700918674468994, Training accuracy: 0.8078306878306878


0it [00:00, ?it/s]

Epoch [27/30]  Accuracy: 0.7289351851851852


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 28/30, Loss: 0.6491243243217468, Training accuracy: 0.8082539682539682


0it [00:00, ?it/s]

Epoch [28/30]  Accuracy: 0.6587962962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 29/30, Loss: 0.06101410463452339, Training accuracy: 0.8186243386243386


0it [00:00, ?it/s]

Epoch [29/30]  Accuracy: 0.7025462962962963


  0%|          | 0/591 [00:00<?, ?it/s]

Epoch 30/30, Loss: 0.5668994188308716, Training accuracy: 0.8165079365079365


0it [00:00, ?it/s]

Epoch [30/30]  Accuracy: 0.7094907407407407


## This function is copied from the github website.

In [None]:
def run_zero_shot(model_name, data_dict, pre_trained_path='', context_lenght=77, wrks=0, bsz=8, device='gpu'):
    model, pct, pv = open_clip.create_model_and_transforms(model_name, pretrained=pre_trained_path)
    tokenizer = open_clip.get_tokenizer(model_name)

    model.to(device)
    model.eval()

    res = {}
    for data_name, data in data_dict.items():  
        labels = data.classes
        print(labels)
        
        templates = ["a histopathology slide showing {c}",
            "histopathology image of {c}",
            "pathology tissue showing {c}",
            "presence of {c} tissue on image"]
        
        if pv:
            data.transform = pv

        loader = torch.utils.data.DataLoader(data, batch_size=bsz, 
                                             shuffle=True, num_workers=wrks,
                                             pin_memory=False
                                            )

        pred = []
        true = []

        for images, target in tqdm(loader):
            images, target = images.to(device), target.to(device)

            with torch.no_grad(): 
                image_features = model.encode_image(images, normalize=True)

                zeroshot_weights = []
                for classname in labels:
                    texts = [template.format(c=classname) for template in templates]
                    texts = tokenizer(texts).to(device)  # tokenize
                    class_embeddings = model.encode_text(texts)
                    class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
                    class_embedding /= class_embedding.norm()
                    zeroshot_weights.append(class_embedding)
                    
                zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
                logits = (100. * image_features @ zeroshot_weights).detach()#.softmax(dim=-1)

                true.append(target.cpu())
                pred.append(logits.float().cpu())

        pred = torch.cat(pred)
        true = torch.cat(true)

        topk = (1,)
        pred_n = pred.topk(max(topk), 1, True, True)[1].t()
        correct = pred_n.eq(true.view(1, -1).expand_as(pred_n))
        n = len(true)
        res[data_name] = [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n for k in topk][0]
        print(res)
    return res

In [None]:
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests
import torch
import torch.nn as nn
from tqdm import tqdm 
from transformers import get_scheduler
from datasets import load_metric
from tqdm import tqdm
import cv2

torch.cuda.empty_cache()
num_epochs = 20
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)


# num_training_steps = num_epochs * len(train_loader)
# lr_scheduler = get_scheduler(
#     "linear",
#     optimizer=optimizer,
#     num_warmup_steps=0,
#     num_training_steps=num_training_steps
# )

model.to(device)
model.train()

metric= load_metric("accuracy")
save_acc = False
acc_list = []
for epoch in range(num_epochs):
    for samples, labels in tqdm(train_loader):
        # print(batch.shape)
        optimizer.zero_grad()
        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        # inputs = feature_extractor(images=samples, return_tensors="pt").pixel_values.to(model.device)
        outputs = model(samples, labels=labels)
        loss = outputs.loss
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        # lr_scheduler.step()
        optimizer.step()
        
        # loss_accum += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {loss.item()}")
    ### Add the commands to save the trained model. 
    model.eval()
    if epoch == 4:
        save_acc = True                                                                                                           
    with torch.no_grad():
        total_accuracy = 0
        total_batches = 0
        for i, (images, labels) in tqdm(enumerate(test_laoder)):
            images = images.to(model.device)
            labels = labels.to(model.device)

            # inputs = feature_extractor(images=images, return_tensors="pt").pixel_values.to(model.device)
            outputs = model(images, labels=labels)
            # print(outputs.logits)

            predictions = torch.argmax(outputs.logits, 1)
            # print(predictions)
            # metric.add_batch(predictions=predictions, references=labels)
            accuracy = metric.compute(predictions=predictions.tolist(), references=labels.tolist())["accuracy"]
            if save_acc:
                acc_list.append(accuracy)
            # print('the {} th image: {:.2f}'.format(str(i), accuracy))
            total_accuracy += accuracy
            total_batches += 1

        # metric.compute()
        average_accuracy = total_accuracy / total_batches
        print(f"Epoch [{epoch+1}/{num_epochs}]  Accuracy: {average_accuracy}")
       
    model.train()