# This Notebook is for finetuning CheXzero and CLIP based models in general

In [44]:
import sys
import torch
import yaml

from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn


sys.path.append(r'C:\Users\Vishi\VSC Codes\VIsLM_seminar\VLP-Seminar')
sys.path.append(r'C:\Users\Vishi\VSC Codes\VIsLM_seminar\VLP-Seminar\cheXzeroCode')

import cheXzeroCode.clip as clip
from cheXzeroCode.train import load_clip
from cheXzeroCode.train import preprocess_text

from Finetune.datasets.data_module import DataModule
from Finetune.datasets.transforms import DataTransforms
from Finetune.datasets.cls_dataset import RSNAImageClsDataset, ChexPertImageClsDataset


In [45]:
checkpoint_path = r'C:\Users\Vishi\VSC Codes\VIsLM_seminar\VLP-Seminar\data\checkpoints\chexZero checkpoints\best_128_0.0002_original_15000_0.859.pt'

In [46]:

# OpenAI CLIP model and preprocessing
model = load_clip(model_path=checkpoint_path, pretrained = True)

Loaded in pretrained model.


  model.load_state_dict(torch.load(model_path, map_location=device))


In [47]:
default_values = {
    "dataset": "rsna",
    "gpus": 1,
    "config": r"C:\Users\Vishi\VSC Codes\VIsLM_seminar\VLP-Seminar\configs\rsna.yaml",
    "batch_size": 12,
    "num_workers": 16,
    "data_pct": 1.0,
    "max_epochs": 50,
    "ckpt_dir": "data/ckpts",
    "logger_dir": "data/log_output"
}

print(default_values)


{'dataset': 'rsna', 'gpus': 1, 'config': 'C:\\Users\\Vishi\\VSC Codes\\VIsLM_seminar\\VLP-Seminar\\configs\\rsna.yaml', 'batch_size': 12, 'num_workers': 16, 'data_pct': 1.0, 'max_epochs': 50, 'ckpt_dir': 'data/ckpts', 'logger_dir': 'data/log_output'}


In [48]:
if default_values["dataset"] == "rsna":
    num_classes = 2
elif default_values["dataset"] == "chexpert":
    num_classes = 14

In [49]:
def load_config(config_path):
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    return config

config = load_config(default_values['config'])


In [50]:
datamodule = DataModule(dataset=RSNAImageClsDataset,
                        config=config, collate_fn=None,
                        transforms=DataTransforms,
                        data_pct=default_values['data_pct'],
                        batch_size=default_values['batch_size'],
                        num_workers=default_values['num_workers']
                        )

In [51]:
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()

Loading RSNA dataset
Dataset size of split train: 18678
Loading RSNA dataset
Dataset size of split valid: 4003


In [52]:

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [53]:
def train_batch(images, texts, model, device, criterion, optimizer):
    images, texts = images.to(device), texts.to(device)
    
    # Forward pass ➡
    logits_per_image, logits_per_text = model(images, texts)
    
    # Create labels
    batch_size = default_values['batch_size']
    labels = torch.arange(batch_size).to(device)
    
    # Compute loss
    loss_img = criterion(logits_per_image, labels)
    loss_txt = criterion(logits_per_text, labels)
    loss = (loss_img + loss_txt)/2 # avg. img and txt loss

    # Backward pass ⬅
    optimizer.zero_grad()
    loss.backward()
    
    # Step with optimizer
    optimizer.step()
        
    return loss


def train_log(loss, example_ct, epoch):
    loss = float(loss)
    print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")
    

In [54]:

loader = val_loader

total_batches = len(loader) * default_values['max_epochs']
example_ct = 0  # number of examples seen
batch_ct = 0
highest_val_auc = 0
for epoch in range(default_values['max_epochs']):
    running_loss = 0.0 # running loss over batch
    for data in tqdm(loader):
        # get the images and labels
        image, label = data

        # print(label == 0)
        #for all 0 entries replace with 'no pneumonia'
        txt = ['no pneumonia' if x == 0 else 'pneumonia' for x in label]

        images = image

        # texts = data['txt']
        texts = preprocess_text(txt, model) 
        loss = train_batch(images, texts, model, device, criterion, optimizer)

        # # perform step for a single batch
        # loss = train_batch(images, texts, model, device, criterion, optimizer)
        # example_ct +=  len(images)
        # batch_ct += 1
        # running_loss += loss.item()




  2%|▏         | 8/333 [01:14<50:13,  9.27s/it]  


KeyboardInterrupt: 