# Testing the Mecial CLIP foundation models. 
1. Testing CLIP models for zero-shot classification methods.
2. QuiltNet: https://huggingface.co/wisdomik/QuiltNet-B-32; Use this medical CLIP for zero classification.

## Try the CLIP and Medical CLIP methods.

In [None]:
!pip install clip
!pip install open_clip_torch

In [7]:
import numpy as np
import torch
import clip
from tqdm.notebook import tqdm
from pkg_resources import packaging

print("Torch version:", torch.__version__)

Torch version: 1.12.1+cu116


In [4]:
import open_clip

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:wisdomik/QuiltNet-B-32')
tokenizer = open_clip.get_tokenizer('hf-hub:wisdomik/QuiltNet-B-32')   # why I need this item?

image = Image.open("/vol/research/wenjieProject/datasets/chaoyang/train/538849-1-IMG006x019-0.JPG")
image = preprocess_val(image).unsqueeze(0)

# text = tokenizer.encode(["A normal colon slide", "A serrated colon slide", "A adenocarcinoma colon slide", "A adenoma colon slide"])
text = open_clip.tokenize(["A normal colon slide", "A serrated colon slide", "A adenocarcinoma colon slide", "A adenoma colon slide"])
# text = [tokenizer.encode(t) for t in ["A normal colon slide", "A serrated colon slide", "A adenocarcinoma colon slide", "A adenoma colon slide"]]
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]

Label probs: tensor([[0.2777, 0.4056, 0.1222, 0.1946]])


## Chaoyang Dataset Preparation.

## Following the setting of the GuiltNet

In [39]:
preprocess_val

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_to_rgb at 0x7efeb53b1430>
    ToTensor()
    Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
)

In [13]:
import torch.utils.data as data
from PIL import Image
import os
import json
import numpy as np
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch.nn as nn
import open_clip
from datasets import load_metric

In [15]:
# model.visual.logit_scale.requires_grad_(False)
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:wisdomik/QuiltNet-B-16')
# model.visual.logit_scale.requires_grad_(False)
model.visual.predict = nn.Linear(in_features=512, out_features=num_classes)
model.visual.predict.weight.data.normal_(mean=0.0, std=0.02)
# model.visual.predict.weight.data.normal_(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
model.visual.predict.bias.data.zero_()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.visual.predict.parameters())


In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
num_classes = 4  # You mentioned you have 4 classes


class Chaoyang(Dataset):
    def __init__(self, root='', transform=None, is_train=True, input_size=256):
        self.transform = transform
        self.is_train = is_train
        self.input_size = input_size

        if not self.is_train:
            imgs = []
            labels = []
            json_path = os.path.join(root, 'json', 'test_split_2.json')
            with open(json_path, 'r') as f:
                load_list = json.load(f)
                for i in range(len(load_list)):
                    img_path = os.path.join(root, load_list[i]["name"])
                    imgs.append(img_path)
                    labels.append(int(load_list[i]["label"]))
            self.test_data, self.test_labels = np.array(imgs), np.array(labels)
        else:  # is_train = True => Train.
            imgs = []
            labels = []
            json_path = os.path.join(root, 'json', 'train_split_2.json')
            with open(json_path, 'r') as f:
                load_list = json.load(f)
                for i in range(len(load_list)):
                    img_path = os.path.join(root, load_list[i]["name"])
                    imgs.append(img_path)
                    labels.append(int(load_list[i]["label"]))
            self.train_data, self.train_labels = np.array(imgs), np.array(labels)

    
    def __getitem__(self, idx):
        # In the function of building_dataset, the transform has been set following the is_train.
        if self.is_train:
            img, label = self.train_data[idx], self.train_labels[idx]
            img = Image.open(img)
            # img = img.convert('RGB')
            # img = Image.open(img).convert('RGB')
            img = self.transform(img)
            return img, label
        else:
            img, label = self.test_data[idx], self.test_labels[idx]
            # img = Image.open(img).convert('RGB')
            # img = img.convert('RGB')
            img = Image.open(img)
            img = self.transform(img)
            return img, label

    def __len__(self):
        if self.is_train:
            return len(self.train_data)
        else:
            return len(self.test_data)

In [17]:
from torch.utils.data import DataLoader
from transformers import AdamW

root = '/vol/research/wenjieProject/datasets/chaoyang'
input_size = 384

train_dataset = Chaoyang(root, preprocess_train, is_train=True, input_size=input_size)
test_dataset = Chaoyang(root, preprocess_val, is_train=False, input_size=input_size)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_laoder = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [12]:
model.to(device)
model.train()
from itertools import islice

metric= load_metric("accuracy")
save_acc = False
acc_list = []
num_epochs = 30
for epoch in range(num_epochs):
    train_correct_predictions = 0
    train_total_predictions = 0
    training_accuracy = 0
    # for samples, labels in islice(tqdm(train_loader), 2):
    for samples, labels in tqdm(train_loader):
        # print(batch.shape)
        optimizer.zero_grad()
        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        # inputs = feature_extractor(images=samples, return_tensors="pt").pixel_values.to(model.device)
        image_features = model.encode_image(samples)
        # outputs = model(samples, labels=labels)
        # loss = outputs.loss
        logits = model.visual.predict(image_features)

        # print(logits)
        loss = criterion(logits, labels)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        # lr_scheduler.step()
        optimizer.step()

        _, predicted = torch.max(logits.data, 1)  # gives us the index of the highest value
        train_total_predictions += labels.size(0)  # add total
        train_correct_predictions += (predicted == labels).sum().item()
        training_accuracy = train_correct_predictions / train_total_predictions
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Training accuracy: {training_accuracy}")
    ### Add the commands to save the trained model. 
    model.eval()
    # if epoch == 4:
    #     save_acc = True                                                                                                           
    with torch.no_grad():
        total_accuracy = 0
        total_batches = 0
        # for i, (images, labels) in islice(tqdm(enumerate(test_laoder)), 2):
        for i, (images, labels) in tqdm(enumerate(test_laoder)):
            images = images.to(device)
            labels = labels.to(device)

            # inputs = feature_extractor(images=images, return_tensors="pt").pixel_values.to(model.device)
            # outputs = model(images)

            image_features = model.encode_image(images)
            # outputs = model(samples, labels=labels)
            # loss = outputs.loss
            logits = model.visual.predict(image_features)
            # print(outputs.logits)

            predictions = torch.argmax(logits.data, 1)
            
            # print(predictions)
            # metric.add_batch(predictions=predictions, references=labels)
            accuracy = metric.compute(predictions=predictions.tolist(), references=labels.tolist())["accuracy"]
            if save_acc:
                acc_list.append(accuracy)
            # print('the {} th image: {:.2f}'.format(str(i), accuracy))
            total_accuracy += accuracy
            total_batches += 1

        # metric.compute()
        average_accuracy = total_accuracy / total_batches
        print(f"Epoch [{epoch+1}/{num_epochs}]  Accuracy: {average_accuracy}")
       
    model.train()

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 1/20, Loss: 0.3145594000816345, Training accuracy: 0.6913703058940562


0it [00:00, ?it/s]

Epoch [1/20]  Accuracy: 0.7328980099502487


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 2/20, Loss: 0.22075827419757843, Training accuracy: 0.7197214623228053


0it [00:00, ?it/s]

Epoch [2/20]  Accuracy: 0.745957711442786


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 3/20, Loss: 0.5656706094741821, Training accuracy: 0.7284257647351405


0it [00:00, ?it/s]

Epoch [3/20]  Accuracy: 0.7468905472636815


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 4/20, Loss: 0.4472316801548004, Training accuracy: 0.7406117881124098


0it [00:00, ?it/s]

Epoch [4/20]  Accuracy: 0.7660136815920398


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 5/20, Loss: 0.6933461427688599, Training accuracy: 0.7495647848793833


0it [00:00, ?it/s]

Epoch [5/20]  Accuracy: 0.7548196517412935


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 6/20, Loss: 0.08249319344758987, Training accuracy: 0.7527978114896792


0it [00:00, ?it/s]

Epoch [6/20]  Accuracy: 0.7669465174129353


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 7/20, Loss: 0.7747183442115784, Training accuracy: 0.7580203929370803


0it [00:00, ?it/s]

Epoch [7/20]  Accuracy: 0.7720771144278606


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 8/20, Loss: 0.5288785696029663, Training accuracy: 0.7597612534195474


0it [00:00, ?it/s]

Epoch [8/20]  Accuracy: 0.7613495024875622


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 9/20, Loss: 0.15270616114139557, Training accuracy: 0.7610047251927381


0it [00:00, ?it/s]

Epoch [9/20]  Accuracy: 0.7730099502487562


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 10/20, Loss: 0.5572280883789062, Training accuracy: 0.7627455856752051


0it [00:00, ?it/s]

Epoch [10/20]  Accuracy: 0.7688121890547264


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 11/20, Loss: 0.6232172250747681, Training accuracy: 0.7644864461576723


0it [00:00, ?it/s]

Epoch [11/20]  Accuracy: 0.7828047263681591


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 12/20, Loss: 0.44895705580711365, Training accuracy: 0.7637403630937578


0it [00:00, ?it/s]

Epoch [12/20]  Accuracy: 0.7786069651741293


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 13/20, Loss: 0.8275815844535828, Training accuracy: 0.7764237751803034


0it [00:00, ?it/s]

Epoch [13/20]  Accuracy: 0.777518656716418


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 14/20, Loss: 0.7952988147735596, Training accuracy: 0.7771698582442178


0it [00:00, ?it/s]

Epoch [14/20]  Accuracy: 0.7837375621890547


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 15/20, Loss: 0.5941692590713501, Training accuracy: 0.7791594130813231


0it [00:00, ?it/s]

Epoch [15/20]  Accuracy: 0.7688121890547264


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 16/20, Loss: 0.20964916050434113, Training accuracy: 0.7769211638895797


0it [00:00, ?it/s]

Epoch [16/20]  Accuracy: 0.7860696517412935


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 17/20, Loss: 0.4575260281562805, Training accuracy: 0.7843819945287241


0it [00:00, ?it/s]

Epoch [17/20]  Accuracy: 0.7921330845771144


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 18/20, Loss: 0.6169818639755249, Training accuracy: 0.7811489679184282


0it [00:00, ?it/s]

Epoch [18/20]  Accuracy: 0.7840485074626866


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 19/20, Loss: 1.1779961585998535, Training accuracy: 0.7816463566277045


0it [00:00, ?it/s]

Epoch [19/20]  Accuracy: 0.7737873134328358


  0%|          | 0/503 [00:00<?, ?it/s]

Epoch 20/20, Loss: 0.3746304512023926, Training accuracy: 0.7901019646854016


0it [00:00, ?it/s]

Epoch [20/20]  Accuracy: 0.7935323383084577


## This function is copied from the github website.

In [None]:
def run_zero_shot(model_name, data_dict, pre_trained_path='', context_lenght=77, wrks=0, bsz=8, device='gpu'):
    model, pct, pv = open_clip.create_model_and_transforms(model_name, pretrained=pre_trained_path)
    tokenizer = open_clip.get_tokenizer(model_name)

    model.to(device)
    model.eval()

    res = {}
    for data_name, data in data_dict.items():  
        labels = data.classes
        print(labels)
        
        templates = ["a histopathology slide showing {c}",
            "histopathology image of {c}",
            "pathology tissue showing {c}",
            "presence of {c} tissue on image"]
        
        if pv:
            data.transform = pv

        loader = torch.utils.data.DataLoader(data, batch_size=bsz, 
                                             shuffle=True, num_workers=wrks,
                                             pin_memory=False
                                            )

        pred = []
        true = []

        for images, target in tqdm(loader):
            images, target = images.to(device), target.to(device)

            with torch.no_grad(): 
                image_features = model.encode_image(images, normalize=True)

                zeroshot_weights = []
                for classname in labels:
                    texts = [template.format(c=classname) for template in templates]
                    texts = tokenizer(texts).to(device)  # tokenize
                    class_embeddings = model.encode_text(texts)
                    class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
                    class_embedding /= class_embedding.norm()
                    zeroshot_weights.append(class_embedding)
                    
                zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
                logits = (100. * image_features @ zeroshot_weights).detach()#.softmax(dim=-1)

                true.append(target.cpu())
                pred.append(logits.float().cpu())

        pred = torch.cat(pred)
        true = torch.cat(true)

        topk = (1,)
        pred_n = pred.topk(max(topk), 1, True, True)[1].t()
        correct = pred_n.eq(true.view(1, -1).expand_as(pred_n))
        n = len(true)
        res[data_name] = [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n for k in topk][0]
        print(res)
    return res

In [None]:
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests
import torch
import torch.nn as nn
from tqdm import tqdm 
from transformers import get_scheduler
from datasets import load_metric
from tqdm import tqdm
import cv2

torch.cuda.empty_cache()
num_epochs = 20
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)


# num_training_steps = num_epochs * len(train_loader)
# lr_scheduler = get_scheduler(
#     "linear",
#     optimizer=optimizer,
#     num_warmup_steps=0,
#     num_training_steps=num_training_steps
# )

model.to(device)
model.train()

metric= load_metric("accuracy")
save_acc = False
acc_list = []
for epoch in range(num_epochs):
    for samples, labels in tqdm(train_loader):
        # print(batch.shape)
        optimizer.zero_grad()
        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        # inputs = feature_extractor(images=samples, return_tensors="pt").pixel_values.to(model.device)
        outputs = model(samples, labels=labels)
        loss = outputs.loss
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        # lr_scheduler.step()
        optimizer.step()
        
        # loss_accum += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {loss.item()}")
    ### Add the commands to save the trained model. 
    model.eval()
    if epoch == 4:
        save_acc = True                                                                                                           
    with torch.no_grad():
        total_accuracy = 0
        total_batches = 0
        for i, (images, labels) in tqdm(enumerate(test_laoder)):
            images = images.to(model.device)
            labels = labels.to(model.device)

            # inputs = feature_extractor(images=images, return_tensors="pt").pixel_values.to(model.device)
            outputs = model(images, labels=labels)
            # print(outputs.logits)

            predictions = torch.argmax(outputs.logits, 1)
            # print(predictions)
            # metric.add_batch(predictions=predictions, references=labels)
            accuracy = metric.compute(predictions=predictions.tolist(), references=labels.tolist())["accuracy"]
            if save_acc:
                acc_list.append(accuracy)
            # print('the {} th image: {:.2f}'.format(str(i), accuracy))
            total_accuracy += accuracy
            total_batches += 1

        # metric.compute()
        average_accuracy = total_accuracy / total_batches
        print(f"Epoch [{epoch+1}/{num_epochs}]  Accuracy: {average_accuracy}")
       
    model.train()