# References

Hugging Face - Documentation [online]. Huggingface.co. Available from: https://huggingface.co/docs.

PyTorch Foundation [online]. PyTorch. Available from: https://pytorch.org/.

zhanyuwang, n.d. R2GenGPT: Radiology Report Generation with Frozen LLMs. Available from: https://github.com/wang-zhanyu/R2GenGPT

# <i> Huggingface login </i>

To work with gated repositores, we need to login to huggingface hub

In [None]:
from huggingface_hub import notebook_login
from google.colab import userdata

notebook_login(userdata.get('HF_TOKEN'))



VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Importing demanded libraries

In [None]:
!pip install lightning

In [None]:
!pip install -U bitsandbytes

In [None]:
import os
import copy
import math
import json
import re
import yaml
import numpy as np
import functools
import pandas as pd
from PIL import Image
from collections import defaultdict

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim

from peft import get_peft_model, LoraConfig, TaskType
from transformers import (
    AutoImageProcessor,
    SwinModel,
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments
)

import lightning.pytorch as pl
from lightning.pytorch import LightningDataModule, seed_everything
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

# Data helper

In [None]:
class FieldParser:
    def __init__(self):
        super().__init__()
        cache_dir = '/content/huggingface'
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", cache_dir=cache_dir)

    def _parse_image(self, img):
        img = np.resize(img, (224, 224, 3))
        pixel_values = self.vit_feature_extractor(img, return_tensors="pt").pixel_values
        return pixel_values[0]

    def clean_report(self, report):
        report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
        .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
        .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
        .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '').
                                        replace('\\', '').replace("'", '').strip().lower())
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
        report = ' . '.join(tokens) + ' .'
        # report = ' '.join(report.split()[:max_txt_len])
        return report

    def parse(self, features):
        to_return = {'id': features['id']}
        report = features.get("report", "")
        report = self.clean_report(report)
        to_return['input_text'] = report
        # chest x-ray images
        images = []
        for image_path in features['image_path']:
            with Image.open(os.path.join('/content/drive/MyDrive/iu_xray/images', image_path)) as pil:
                array = np.array(pil, dtype=np.uint8)
                if array.shape[-1] != 3 or len(array.shape) != 3:
                    array = np.array(pil.convert("RGB"), dtype=np.uint8)
                image = self._parse_image(array)
                images.append(image)
        to_return["image"] = images
        return to_return

    def transform_with_parse(self, inputs):
        return self.parse(inputs)

# Data Module

In [None]:
class ParseDataset(data.Dataset):
    def __init__(self, split='train'):
        filename = '/content/drive/MyDrive/iu_xray/annotation.json'
        self.meta = json.load(open(filename, 'r'))
        self.meta = self.meta[split]
        self.parser = FieldParser()

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, index):
        return self.parser.transform_with_parse(self.meta[index])

In [None]:
def create_datasets():
    train_dataset = ParseDataset('train')
    dev_dataset = ParseDataset('val')
    test_dataset = ParseDataset('test')
    return train_dataset, dev_dataset, test_dataset

In [None]:
class DataModule(LightningDataModule):
    def __init__(self):
        super().__init__()

    def setup(self, stage: str):
        train_dataset, dev_dataset, test_dataset = create_datasets()
        self.dataset = {
            "train": train_dataset, "validation": dev_dataset, "test": test_dataset
        }


    def train_dataloader(self):
        loader = DataLoader(self.dataset["train"], batch_size=4, drop_last=True, pin_memory=True,
                        num_workers=4, prefetch_factor=4)
        return loader


    def val_dataloader(self):
        loader = DataLoader(self.dataset["validation"], batch_size=8, drop_last=False, pin_memory=True,
                            num_workers=4, prefetch_factor=4)
        return loader


    def test_dataloader(self):
        loader = DataLoader(self.dataset["test"], batch_size=8, drop_last=False, pin_memory=False,
                        num_workers=4, prefetch_factor=4)
        return loader

# Callbacks

In [None]:
def add_callbacks():
    log_dir = "savedmodels"
    os.makedirs(log_dir, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(log_dir, "checkpoints"),
        filename="{epoch}-{step}",
        save_top_k=1,
        every_n_train_steps=10,
        save_last=False,
        save_weights_only=False
    )

    lr_monitor_callback = LearningRateMonitor(logging_interval='step')
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.join(log_dir, "logs"), name="tensorboard")
    csv_logger = CSVLogger(save_dir=os.path.join(log_dir, "logs"), name="csvlog")

    to_returns = {
        "callbacks": [checkpoint_callback, lr_monitor_callback],
        "loggers": [csv_logger, tb_logger]
    }
    return to_returns

# Optimizers

In [None]:
def lr_lambda(current_step, num_warmup_steps, num_training_steps):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(
        0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
    )

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):

    return LambdaLR(optimizer,
                    functools.partial(
                        lr_lambda,
                        num_warmup_steps=num_warmup_steps,
                        num_training_steps=num_training_steps),
                    last_epoch)


def config_optimizer(parameters, init_lr, warmup_steps, max_steps, name='lr'):
    optimizer = optim.AdamW(
        parameters, lr=init_lr, eps=1e-8, correct_bias=False
    )

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
    )
    scheduler = {'scheduler': scheduler, 'name': name, 'interval': 'step', 'frequency': 1}

    return optimizer, scheduler

# BLEU

In [None]:
def precook(s, n=4, out=False):
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return (len(words), counts)

def cook_refs(refs, eff=None, n=4):
    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        for (ngram,count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    # Calculate effective reference sentence length.
    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen))/len(reflen)

    return (reflen, maxcounts)

def cook_test(test, crefs, eff=None, n=4):
    reflen, refmaxcounts = crefs[0], crefs[1]

    testlen, counts = precook(test, n, True)

    result = {}

    # Calculate effective reference sentence length.

    if eff == "closest":
        result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
    else: ## i.e., "average" or "shortest" or None
        result["reflen"] = reflen

    result["testlen"] = testlen

    result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]

    result['correct'] = [0]*n
    for (ngram, count) in counts.items():
        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)

    return result

class BleuScorer(object):

    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
    # special_reflen is used in oracle (proportional effective ref len for a node).

    def copy(self):
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1])
                self.ctest.append(cooked_test) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

        self._score = None ## need to recompute

    def ratio(self, option=None):
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs))
        self._score = None

        return self

    def rescore(self, new_test):
        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None ## need to recompute

        return self

    def compatible(self, other):
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):

        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens))/len(reflens)
        elif option == "closest":
            reflen = min((abs(l-testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        n = self.n
        small = 1e-9
        tiny = 1e-15 ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None: ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess','correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        /(float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1./(k+1)))
            ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
            if ratio < 1:
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1/ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1./(k+1)))
        ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1/ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list

In [None]:
class Bleu:
    def __init__(self, n=4):
        # default compute Blue score up to 4
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res, verbose=0):

        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        bleu_scorer = BleuScorer(n=self._n)
        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) >= 1)

            bleu_scorer += (hypo[0], ref)

        #score, scores = bleu_scorer.compute_score(option='shortest')
        score, scores = bleu_scorer.compute_score(option='closest', verbose=verbose)
        # score, scores = bleu_scorer.compute_score(option='average', verbose=1)

        # return (bleu, bleu_info)
        return score, scores

    def method(self):
        return "Bleu"

# Rouge

In [None]:
def my_lcs(string, sub):
    if(len(string)< len(sub)):
        sub, string = string, sub

    lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]

    for j in range(1,len(sub)+1):
        for i in range(1,len(string)+1):
            if(string[i-1] == sub[j-1]):
                lengths[i][j] = lengths[i-1][j-1] + 1
            else:
                lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])

    return lengths[len(string)][len(sub)]

class Rouge():
    '''
    Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set

    '''
    def __init__(self):
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        # assert(len(candidate)==1)
        # assert(len(refs)>0)
        prec = []
        rec = []

        # split into tokens
        token_c = candidate[0].split(" ")

        for reference in refs:
            # split into tokens
            token_r = reference.split(" ")
            # compute the longest common subsequence
            lcs = my_lcs(token_r, token_c)
            prec.append(lcs/float(len(token_c)))
            rec.append(lcs/float(len(token_r)))

        prec_max = max(prec)
        rec_max = max(rec)

        if(prec_max!=0 and rec_max !=0):
            score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
        else:
            score = 0.0
        return score

    def compute_score(self, gts, res):
        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        score = []
        for id in imgIds:
            hypo = res[id]
            ref  = gts[id]

            score.append(self.calc_score(hypo, ref))

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

        average_score = np.mean(np.array(score))
        return average_score, np.array(score)

    def method(self):
        return "Rouge"

# CIDER

In [None]:
class CiderScorer(object):
    """CIDEr scorer.
    """

    def copy(self):
        ''' copy the refs.'''
        new = CiderScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        ''' singular instance '''
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        self.cook_append(test, refs)
        self.ref_len = None

    def precook(self, s, n=4, out=False):
      words = s.split()
      counts = defaultdict(int)
      for k in range(1,n+1):
          for i in range(len(words)-k+1):
              ngram = tuple(words[i:i+k])
              counts[ngram] += 1
      return counts

    def cook_refs(self, refs, n=4):
        return [self.precook(ref, n) for ref in refs]

    def cook_test(self, test, n=4):
        return self.precook(test, n, True)

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(self.cook_refs(refs))
            if test is not None:
                self.ctest.append(self.cook_test(test)) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new CiderScorer instances
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)

        return self
    def compute_doc_freq(self):
        for refs in self.crefs:
            # refs, k ref captions of one image
            for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
                self.document_frequency[ngram] += 1
            # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    def compute_cider(self):
        def counts2vec(cnts):
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                # give word count 1 if it doesn't appear in reference corpus
                df = np.log(max(1.0, self.document_frequency[ngram]))
                # ngram index
                n = len(ngram)-1
                # tf (term_freq) * idf (precomputed idf) for n-grams
                vec[n][ngram] = float(term_freq)*(self.ref_len - df)
                # compute norm for the vector.  the norm will be used for computing similarity
                norm[n] += pow(vec[n][ngram], 2)

                if n == 1:
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            delta = float(length_hyp - length_ref)
            # measure consine similarity
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                # ngram
                for (ngram,count) in vec_hyp[n].items():
                    # vrama91 : added clipping
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n]*norm_ref[n])

                assert(not math.isnan(val[n]))
                # vrama91: added a length based gaussian penalty
                val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
            return val

        # compute log reference length
        self.ref_len = np.log(float(len(self.crefs)))
        if len(self.crefs) == 1:
            self.ref_len = 1
        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            # compute vector for test captions
            vec, norm, length = counts2vec(test)
            # compute vector for ref captions
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # change by vrama91 - mean of ngram scores, instead of sum
            score_avg = np.mean(score)
            # divide by number of references
            score_avg /= len(refs)
            # multiply score by 10
            score_avg *= 10.0
            # append score of an image to the score list
            scores.append(score_avg)
        return scores

    def compute_score(self, option=None, verbose=0):
        # compute idf
        self.compute_doc_freq()
        # assert to check document frequency
        assert(len(self.ctest) >= max(self.document_frequency.values()))
        # compute cider score
        score = self.compute_cider()
        # debug
        # print score
        return np.mean(np.array(score)), np.array(score)

In [None]:
class Cider:
    """
    Main Class to compute the CIDEr metric

    """
    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # set cider to sum over 1 to 4-grams
        self._n = n
        # set the standard deviation parameter for gaussian penalty
        self._sigma = sigma

    def compute_score(self, gts, res):

        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)

        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

            cider_scorer += (hypo[0], ref)

        (score, scores) = cider_scorer.compute_score()

        return score, scores

    def method(self):
        return "CIDEr"

# Prompt Engineering

In [None]:
class R2GenGPT(pl.LightningModule):
    def __init__(self, model_name, vision_model, prompt_style='instruction'):
        super().__init__()

        # Add prompt_style parameter to control which prompt format to use
        self.prompt_style = prompt_style

        cache_dir = '/content/huggingface'

        print(f'Loading vision encoder: {vision_model}')
        self.visual_encoder = AutoModel.from_pretrained(vision_model, cache_dir=cache_dir)
        # self.visual_encoder.train()
        for name,param in self.visual_encoder.named_parameters():
            param.requires_grad = False
        self.visual_encoder.gradient_checkpointing_enable()
        print('Loading trainable vision encoder is Done')

        print('Loading LLAMA')
        # use low resources
        self.llama_model = AutoModelForCausalLM.from_pretrained(model_name,
                                                                torch_dtype=torch.float16,
                                                                device_map="cuda",
                                                                load_in_8bit=True,
                                                                cache_dir=cache_dir,
                                                                trust_remote_code=True)
        self.llama_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, cache_dir=cache_dir)
        # explicitly setting bos_token_id and eos_token_id if not already defined
        if self.llama_tokenizer.bos_token_id is None:
            self.llama_tokenizer.bos_token_id = self.llama_tokenizer.eos_token_id
            self.llama_tokenizer.eos_token_id = 2
        self.llama_tokenizer.pad_token_id = 0
        self.llama_model.generation_config.pad_token_id = self.llama_tokenizer.pad_token_id

        self.embed_tokens = self.llama_model.get_input_embeddings()
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            # target_modules=["q_proj", "v_proj"]
            target_modules=["transformer.h.0.attn.c_attn", "transformer.h.0.attn.c_proj"]
        )
        self.llama_model = get_peft_model(self.llama_model, peft_config)
        self.llama_model.print_trainable_parameters()
        self.llama_model.gradient_checkpointing_enable()
        print('Loading LLAMA LoRA Done')

        self.llama_proj = nn.Linear(self.visual_encoder.num_features, self.llama_model.config.hidden_size)
        self.layer_norm = nn.LayerNorm(self.llama_model.config.hidden_size)
        self.end_sym = '</s>'

        # Set the prompt based on style
        self.set_prompt(prompt_style)

        self.val_step_outputs = []
        self.test_step_outputs = []
        self.val_score = 0.0

    def set_prompt(self, style):
        """
        Set different prompt styles for experiments

        Args:
            style (str): One of 'instruction', 'question', 'example', 'detailed', 'minimal'
        """
        if style == 'instruction':
            # Direct instruction style
            self.prompt = 'Generate a comprehensive and detailed diagnosis report for this chest xray image.'

        elif style == 'question':
            # Question-based style
            self.prompt = 'What abnormalities or findings can you observe in this chest xray image? Please provide a complete radiological report.'

        elif style == 'example':
            # Example-based style with few-shot learning pattern
            self.prompt = '''Here's how to analyze a chest xray:
1. Examine for lung opacities, nodules, or masses
2. Check heart size and mediastinal contours
3. Evaluate pleural spaces
4. Assess diaphragm and costophrenic angles
5. Look for skeletal abnormalities

Now, generate a comprehensive diagnosis report for this chest xray image.'''

        elif style == 'detailed':
            # Detailed instruction with specific report structure
            self.prompt = '''Analyze this chest xray and provide a detailed radiological report including:
- Assessment of lung fields (opacifications, nodules, masses)
- Cardiac silhouette evaluation (size, contour)
- Mediastinal structures assessment
- Pleural spaces examination
- Bony thorax evaluation
- Soft tissue analysis
- Final impression with diagnostic considerations'''

        elif style == 'minimal':
            # Minimal prompt
            self.prompt = 'Describe this chest xray.'

        else:
            # Default to instruction style
            self.prompt = 'Generate a comprehensive and detailed diagnosis report for this chest xray image.'

        print(f"Using prompt style: {style}")
        print(f"Prompt: {self.prompt}")

    def score(self, ref, hypo):
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr")
        ]
        final_scores = {}
        for scorer, method in scorers:
            score, scores = scorer.compute_score(ref, hypo)
            if type(score) == list:
                for m, s in zip(method, score):
                    final_scores[m] = s
            else:
                final_scores[method] = score
        return final_scores


    def encode_img(self, images):
        image_embeds = []
        for image in images:
            device = image.device
            image_embed = self.visual_encoder(image)['last_hidden_state'].to(device)
            image_embeds.append(image_embed)

        image_embeds = torch.stack(image_embeds).mean(0)
        inputs_llama = self.llama_proj(image_embeds)
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama


    def prompt_wrap(self, img_embeds, atts_img):
        prompt=f'Human: <Img><ImageHere></Img> {self.prompt} \nAssistant:'
        batch_size = img_embeds.shape[0]
        p_before, p_after = prompt.split('<ImageHere>')
        p_before_tokens = self.llama_tokenizer(
            p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_after_tokens = self.llama_tokenizer(
            p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_before_embeds = self.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        p_after_embeds = self.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img


    def forward(self, samples):
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["input_text"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
            add_special_tokens=False
        ).to(image[0].device)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == 0, -100
        )

        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
                       dtype=torch.long).to(image[0].device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss
        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        result = self(batch)
        self.log_dict(result, prog_bar=True)
        return result

    def save_checkpoint(self, eval_res):
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
        param_grad_dic = {
            k: v.requires_grad for (k, v) in self.named_parameters() if v.requires_grad
        }
        state_dict = self.state_dict()
        for k in list(state_dict.keys()):
            if k not in param_grad_dic.keys():
                del state_dict[k]
        save_obj = {
            "model": state_dict,
            "epoch": current_epoch,
            "step": global_step,
            "prompt_style": self.prompt_style  # Save prompt style for reference
        }
        savedmodel_path = f'/content/savedmodels/{self.prompt_style}'  # Add prompt style to path
        os.makedirs(os.path.join(savedmodel_path, 'checkpoints'), exist_ok=True)
        save_to = os.path.join(
            savedmodel_path, 'checkpoints',
            f"checkpoint_{self.prompt_style}_epoch{current_epoch}_step{global_step}_bleu{eval_res['Bleu_4']:.3f}.pth"
        )
        self.print(f"Saving checkpoint at step {global_step} to {save_to}.")
        torch.save(save_obj, save_to)

    def validation_step(self, samples, batch_idx):
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
            add_special_tokens=False
        )

        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=atts_img.dtype,
                         device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            num_beams=1,
            do_sample=False,
            min_new_tokens=80,
            max_new_tokens=120,
            repetition_penalty=2.0,
            length_penalty=2.0,
            temperature=0,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.val_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return hypo, ref

    def decode(self, output_token):
        if output_token[0] == 0:  # unknown token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('</s>')[0].strip()
        output_text = output_text.replace('<unk>', '')
        return output_text

    def on_validation_epoch_end(self):
        ref, hypo, ids = [], [], []
        for i in self.val_step_outputs:
            ref.extend(i['ref'])
            hypo.extend(i['hypo'])
            ids.extend(i['id'])

        ref = {k:[v] for k, v in zip(ids, ref)}
        hypo = {k:[v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref,hypo=hypo)

        # Add prompt style to metrics for comparison
        self.log(f"prompt_style/{self.prompt_style}", 1, logger=True)
        self.log_dict(eval_res, sync_dist=True, logger=True)

        savedmodel_path = f'/content/savedmodels/{self.prompt_style}'
        result_folder = os.path.join(savedmodel_path, 'result')
        os.makedirs(result_folder, exist_ok=True)
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step

        # Save results with prompt style in filename
        json.dump(hypo, open(os.path.join(result_folder, f"result_{self.prompt_style}_{current_epoch}_{global_step}" + '.json'), 'w'))
        json.dump(ref, open(os.path.join(result_folder, 'refs.json'), 'w'))
        self.print(f"Prompt style: {self.prompt_style} - Eval results: {eval_res}")

        val_score = 0
        scorer_types = ['Bleu_4']
        weights = [1]
        for score_type, weight in zip(scorer_types, weights):
            val_score += eval_res[score_type] * weight

        if self.trainer.local_rank == 0:
            if val_score > self.val_score:
                self.save_checkpoint(eval_res)
                self.val_score = val_score
        self.val_step_outputs.clear()


    def test_step(self, samples, batch_idx):
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
            add_special_tokens=False
        )

        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=atts_img.dtype,
                         device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            num_beams=3,
            do_sample=False,
            min_new_tokens=80,
            max_new_tokens=120,
            repetition_penalty=2.0,
            length_penalty=2.0,
            temperature=0,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.test_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return {"hypo": hypo, "ref": ref}


    def on_test_epoch_end(self):
        ref, hypo, ids = [], [], []
        for i in self.test_step_outputs:
            ref.extend(i['ref'])
            hypo.extend(i['hypo'])
            ids.extend(i['id'])

        ref = {k:[v] for k, v in zip(ids, ref)}
        hypo = {k:[v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref,hypo=hypo)

        # Log metrics for the test set
        for key, value in eval_res.items():
            self.log(f"test_{key}", value, logger=True)

        savedmodel_path = f'/content/savedmodels/{self.prompt_style}'
        result_folder = os.path.join(savedmodel_path, 'result')
        os.makedirs(result_folder, exist_ok=True)
        json.dump(hypo, open(os.path.join(result_folder, f"test_result_{self.prompt_style}.json"), 'w'))
        json.dump(ref, open(os.path.join(result_folder, f"test_refs_{self.prompt_style}.json"), 'w'))
        self.print(f"Test result for {self.prompt_style} prompt: {eval_res}")

        # Clear the outputs list
        self.test_step_outputs.clear()

        # Return the evaluation results so they can be collected by the trainer
        return eval_res

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=3, eta_min=1e-6)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
        optimizer.zero_grad()

# Prompt Styles

In [None]:
def main(args):
    # Define prompt styles to experiment with
    prompt_styles = ['instruction', 'question', 'example']

    # Create a results summary file
    os.makedirs('/content/prompt_engineering_results', exist_ok=True)
    summary_file = open('/content/prompt_engineering_results/summary.txt', 'w')
    summary_file.write("Prompt Engineering Results Summary\n")
    summary_file.write("================================\n\n")

    # Dictionary to store results for each prompt style
    results = {}

    # Initialize data module once
    data_module = DataModule()

    # Train and evaluate a model for each prompt style
    for prompt_style in prompt_styles:
        print(f"\n{'='*80}")
        print(f"Running experiment with prompt style: {prompt_style}")
        print(f"{'='*80}\n")

        # Configure logger and checkpoint callback for this prompt style
        logger = TensorBoardLogger(
            save_dir=f'/content/logs',
            name=f'prompt_style_{prompt_style}'
        )

        checkpoint_callback = ModelCheckpoint(
            dirpath=f'/content/savedmodels/{prompt_style}/checkpoints',
            filename=f'{prompt_style}_' + '{epoch:02d}-{val_Bleu_4:.4f}',
            monitor='Bleu_4',
            mode='max',
            save_top_k=1,
            verbose=True
        )

        # Initialize model with the current prompt style
        model = R2GenGPT(
            model_name=args['model_name'],
            vision_model=args['vision_model'],
            prompt_style=prompt_style
        )

        # Configure trainer
        trainer = pl.Trainer(
            max_epochs=args['max_epochs'],
            accelerator='gpu',
            devices=args['devices'],
            precision=16,
            callbacks=[checkpoint_callback],
            logger=logger,
            log_every_n_steps=10,
            gradient_clip_val=1.0
        )

        # Train the model
        trainer.fit(model, data_module)

        # Load the best checkpoint before testing
        if checkpoint_callback.best_model_path:
            print(f"Loading best model from {checkpoint_callback.best_model_path}")
            best_model = R2GenGPT.load_from_checkpoint(
                checkpoint_callback.best_model_path,
                model_name=args['model_name'],
                vision_model=args['vision_model'],
                prompt_style=prompt_style
            )
        else:
            best_model = model
            print("No checkpoint found, using current model for testing")

        # Test the model
        test_results = trainer.test(model, data_module)

        # Check if we have valid test results
        if not test_results or len(test_results) == 0:
            print(f"Warning: No test results returned for prompt style {prompt_style}")
            test_metrics = {'Bleu_4': 0.0, 'ROUGE_L': 0.0, 'CIDEr': 0.0}
        else:
            test_metrics = test_results[0]
            print(f"Test results for {prompt_style}: {test_metrics}")

        # Store results
        results[prompt_style] = {
            'test': test_metrics,
            'best_val_score': checkpoint_callback.best_model_score.item() if checkpoint_callback.best_model_score else 0
        }

        # Write results to summary file
        summary_file.write(f"Prompt Style: {prompt_style}\n")
        summary_file.write(f"Prompt: {model.prompt}\n")
        summary_file.write(f"Best Validation Bleu_4: {results[prompt_style]['best_val_score']:.4f}\n")
        summary_file.write(f"Test Results: {results[prompt_style]['test']}\n\n")
        summary_file.flush()

        # Clear memory between runs
        del model, trainer
        torch.cuda.empty_cache()

    # Final comparison
    summary_file.write("Comparative Results\n")
    summary_file.write("==================\n\n")
    summary_file.write("Prompt Style | Best Val Bleu_4 | Test Bleu_4 | Test ROUGE_L | Test CIDEr\n")
    summary_file.write("------------ | -------------- | ----------- | ------------ | ----------\n")

    for style in prompt_styles:
        val_score = results[style]['best_val_score']
        test_bleu = results[style]['test'].get('test_Bleu_4', results[style]['test'].get('Bleu_4', 0.0))
        test_rouge = results[style]['test'].get('test_ROUGE_L', results[style]['test'].get('ROUGE_L', 0.0))
        test_cider = results[style]['test'].get('test_CIDEr', results[style]['test'].get('CIDEr', 0.0))

        summary_file.write(f"{style.ljust(12)} | {val_score:.4f} | {test_bleu:.4f} | {test_rouge:.4f} | {test_cider:.4f}\n")

    # Find best performing prompt style
    best_style = max(results.keys(), key=lambda x: results[x]['test'].get('test_Bleu_4', results[x]['test'].get('Bleu_4', 0.0)))

    summary_file.write(f"\nBest performing prompt style: {best_style}\n")
    summary_file.write(f"Best prompt: {prompt_styles[prompt_styles.index(best_style)]}\n")

    summary_file.close()
    print(f"Prompt engineering experiments completed. Results saved to /content/prompt_engineering_results/summary.txt")


# Training

In [None]:
args = {
    "model_name": "cerebras/Cerebras-GPT-1.3B",
    "vision_model": "microsoft/swin-tiny-patch4-window7-224",
    "batch_size": 4,
    "num_workers": 4,
    "max_epochs": 5,
    "devices": 1,
    "data_dir": "/content/data"
}

main(args)


Running experiment with prompt style: instruction

Loading vision encoder: microsoft/swin-tiny-patch4-window7-224
Loading trainable vision encoder is Done
Loading LLAMA


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLAMA LoRA Done
Using prompt style: instruction
Prompt: Generate a comprehensive and detailed diagnosis report for this chest xray image.


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinModel            | 27.5 M | eval 
1 | llama_model    | PeftModelForCausalLM | 1.3 B  | train
2 | embed_tokens   | Embedding            | 102 M  | eval 
3 | llama_proj     | Linear               | 1.6 M  | train
4 | layer_norm     | LayerNorm            | 4.1 K  | train
----------------------------------------------------------------
1.8 M     Trainable params
1.3 B     Non-trainable params
1.3 B     Total params
5,380.073 Total estimated model params size (MB)
24        Modules in train mode
575       Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_en

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Prompt style: instruction - Eval results: {'Bleu_1': 0.023411371237419046, 'Bleu_2': 2.0056351556443408e-10, 'Bleu_3': 4.14218169661345e-13, 'Bleu_4': 1.895968801149949e-14, 'ROUGE_L': np.float64(0.01307339139635925), 'CIDEr': np.float64(5.162277159450624e-05)}
Saving checkpoint at step 0 to /content/savedmodels/instruction/checkpoints/checkpoint_instruction_epoch0_step0_bleu0.000.pth.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 0, global step 517: 'Bleu_4' reached 0.06034 (best 0.06034), saving model to '/content/savedmodels/instruction/checkpoints/instruction_epoch=00-val_Bleu_4=0.0000-v3.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 517: 'Bleu_4' reached 0.06034 (best 0.06034), saving model to '/content/savedmodels/instruction/checkpoints/instruction_epoch=00-val_Bleu_4=0.0000-v3.ckpt' as top 1


Prompt style: instruction - Eval results: {'Bleu_1': 0.29809782608692953, 'Bleu_2': 0.1585479034237781, 'Bleu_3': 0.0960237260081301, 'Bleu_4': 0.06034022289729853, 'ROUGE_L': np.float64(0.2158701012374544), 'CIDEr': np.float64(0.040395933273930054)}
Saving checkpoint at step 517 to /content/savedmodels/instruction/checkpoints/checkpoint_instruction_epoch0_step517_bleu0.060.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 1, global step 1034: 'Bleu_4' reached 0.07095 (best 0.07095), saving model to '/content/savedmodels/instruction/checkpoints/instruction_epoch=01-val_Bleu_4=0.0000-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 1034: 'Bleu_4' reached 0.07095 (best 0.07095), saving model to '/content/savedmodels/instruction/checkpoints/instruction_epoch=01-val_Bleu_4=0.0000-v1.ckpt' as top 1


Prompt style: instruction - Eval results: {'Bleu_1': 0.3360312587217103, 'Bleu_2': 0.18116831088604837, 'Bleu_3': 0.10956314136616688, 'Bleu_4': 0.07095067268068825, 'ROUGE_L': np.float64(0.23396270893907792), 'CIDEr': np.float64(0.07060717834125374)}
Saving checkpoint at step 1034 to /content/savedmodels/instruction/checkpoints/checkpoint_instruction_epoch1_step1034_bleu0.071.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 2, global step 1551: 'Bleu_4' reached 0.07489 (best 0.07489), saving model to '/content/savedmodels/instruction/checkpoints/instruction_epoch=02-val_Bleu_4=0.0000-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 1551: 'Bleu_4' reached 0.07489 (best 0.07489), saving model to '/content/savedmodels/instruction/checkpoints/instruction_epoch=02-val_Bleu_4=0.0000-v1.ckpt' as top 1


Prompt style: instruction - Eval results: {'Bleu_1': 0.33736108661136555, 'Bleu_2': 0.1866276758639535, 'Bleu_3': 0.11538814729971218, 'Bleu_4': 0.07489369727649552, 'ROUGE_L': np.float64(0.2398458787360321), 'CIDEr': np.float64(0.07447690096551514)}
Saving checkpoint at step 1551 to /content/savedmodels/instruction/checkpoints/checkpoint_instruction_epoch2_step1551_bleu0.075.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 3, global step 2068: 'Bleu_4' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 2068: 'Bleu_4' was not in top 1


Prompt style: instruction - Eval results: {'Bleu_1': 0.33103795767548444, 'Bleu_2': 0.17884706268733103, 'Bleu_3': 0.10898800951383718, 'Bleu_4': 0.07022329904352385, 'ROUGE_L': np.float64(0.23851785899154146), 'CIDEr': np.float64(0.07161455086676079)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 4, global step 2585: 'Bleu_4' reached 0.08057 (best 0.08057), saving model to '/content/savedmodels/instruction/checkpoints/instruction_epoch=04-val_Bleu_4=0.0000.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 4, global step 2585: 'Bleu_4' reached 0.08057 (best 0.08057), saving model to '/content/savedmodels/instruction/checkpoints/instruction_epoch=04-val_Bleu_4=0.0000.ckpt' as top 1


Prompt style: instruction - Eval results: {'Bleu_1': 0.34203428173667955, 'Bleu_2': 0.19265137739340737, 'Bleu_3': 0.12120667250806899, 'Bleu_4': 0.08057228099736626, 'ROUGE_L': np.float64(0.24464444911610903), 'CIDEr': np.float64(0.08350405603245456)}
Saving checkpoint at step 2585 to /content/savedmodels/instruction/checkpoints/checkpoint_instruction_epoch4_step2585_bleu0.081.pth.


INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


Loading best model from /content/savedmodels/instruction/checkpoints/instruction_epoch=04-val_Bleu_4=0.0000.ckpt
Loading vision encoder: microsoft/swin-tiny-patch4-window7-224
Loading trainable vision encoder is Done
Loading LLAMA


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLAMA LoRA Done
Using prompt style: instruction
Prompt: Generate a comprehensive and detailed diagnosis report for this chest xray image.


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

Test result for instruction prompt: {'Bleu_1': 4.712717577319185e-32, 'Bleu_2': 3.619907057996158e-26, 'Bleu_3': 3.315172175416523e-24, 'Bleu_4': 3.1725642272616665e-23, 'ROUGE_L': np.float64(0.0), 'CIDEr': np.float64(0.0)}


Test results for instruction: {'test_Bleu_1': 4.7127176012990256e-32, 'test_Bleu_2': 3.619907049248295e-26, 'test_Bleu_3': 3.315172176973198e-24, 'test_Bleu_4': 3.172564111337878e-23, 'test_ROUGE_L': 0.0, 'test_CIDEr': 0.0}

Running experiment with prompt style: question

Loading vision encoder: microsoft/swin-tiny-patch4-window7-224
Loading trainable vision encoder is Done
Loading LLAMA


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLAMA LoRA Done
Using prompt style: question
Prompt: What abnormalities or findings can you observe in this chest xray image? Please provide a complete radiological report.


/usr/local/lib/python3.11/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /content/savedmodels/question/checkpoints exists and is not empty.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinModel            | 27.5 M | eval 
1 | llama_model    | PeftModelForCausalLM | 1.3 B  | train
2 | embed_tokens   | Embedding            | 102 M  | eval 
3 | llama_proj     | Linear               | 1.6 M  | train
4 | layer_norm     | LayerNorm            | 4.1 K  | train
----------------------------------------------------------------
1.8 M     Trainable params
1.3 B     Non-trainable params
1.3 B     Total params
5,380.073 Total estimated model params size (MB)
24        Modules in train mode
575       Modules in eval mode
INFO

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Prompt style: question - Eval results: {'Bleu_1': 0.08163265306113193, 'Bleu_2': 0.016816423539901128, 'Bleu_3': 6.929194990896017e-08, 'Bleu_4': 1.41325514244261e-10, 'ROUGE_L': np.float64(0.06125415483963764), 'CIDEr': np.float64(0.015554587672263188)}
Saving checkpoint at step 0 to /content/savedmodels/question/checkpoints/checkpoint_question_epoch0_step0_bleu0.000.pth.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 0, global step 517: 'Bleu_4' reached 0.06240 (best 0.06240), saving model to '/content/savedmodels/question/checkpoints/question_epoch=00-val_Bleu_4=0.0000.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 517: 'Bleu_4' reached 0.06240 (best 0.06240), saving model to '/content/savedmodels/question/checkpoints/question_epoch=00-val_Bleu_4=0.0000.ckpt' as top 1


Prompt style: question - Eval results: {'Bleu_1': 0.2943845231980573, 'Bleu_2': 0.15852765290425416, 'Bleu_3': 0.09669792087227617, 'Bleu_4': 0.062400330874393325, 'ROUGE_L': np.float64(0.2211873080480498), 'CIDEr': np.float64(0.05074366569966207)}
Saving checkpoint at step 517 to /content/savedmodels/question/checkpoints/checkpoint_question_epoch0_step517_bleu0.062.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 1, global step 1034: 'Bleu_4' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 1034: 'Bleu_4' was not in top 1


Prompt style: question - Eval results: {'Bleu_1': 0.2838886353263134, 'Bleu_2': 0.14751648481287913, 'Bleu_3': 0.08806399890643885, 'Bleu_4': 0.05555752945804105, 'ROUGE_L': np.float64(0.21759926093187423), 'CIDEr': np.float64(0.05159731107768328)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 2, global step 1551: 'Bleu_4' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 1551: 'Bleu_4' was not in top 1


Prompt style: question - Eval results: {'Bleu_1': 0.2832165252674081, 'Bleu_2': 0.14803949337928504, 'Bleu_3': 0.08806641986695954, 'Bleu_4': 0.056275265098527524, 'ROUGE_L': np.float64(0.21743482086391538), 'CIDEr': np.float64(0.049989254673936896)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 3, global step 2068: 'Bleu_4' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 2068: 'Bleu_4' was not in top 1


Prompt style: question - Eval results: {'Bleu_1': 0.28765580417477943, 'Bleu_2': 0.15062010534212422, 'Bleu_3': 0.09058786204296385, 'Bleu_4': 0.058808787354006756, 'ROUGE_L': np.float64(0.2181165985315501), 'CIDEr': np.float64(0.04694540073513178)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 4, global step 2585: 'Bleu_4' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 4, global step 2585: 'Bleu_4' was not in top 1
INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


Prompt style: question - Eval results: {'Bleu_1': 0.3061975209915789, 'Bleu_2': 0.15899777772433882, 'Bleu_3': 0.09242042661853213, 'Bleu_4': 0.057778124631584084, 'ROUGE_L': np.float64(0.22778706493998863), 'CIDEr': np.float64(0.05642261180725288)}
Loading best model from /content/savedmodels/question/checkpoints/question_epoch=00-val_Bleu_4=0.0000.ckpt
Loading vision encoder: microsoft/swin-tiny-patch4-window7-224
Loading trainable vision encoder is Done
Loading LLAMA


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLAMA LoRA Done
Using prompt style: question
Prompt: What abnormalities or findings can you observe in this chest xray image? Please provide a complete radiological report.


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

Test result for question prompt: {'Bleu_1': 4.712717577319185e-32, 'Bleu_2': 3.619907057996158e-26, 'Bleu_3': 3.315172175416523e-24, 'Bleu_4': 3.1725642272616665e-23, 'ROUGE_L': np.float64(0.0), 'CIDEr': np.float64(0.0)}


Test results for question: {'test_Bleu_1': 4.7127176012990256e-32, 'test_Bleu_2': 3.619907049248295e-26, 'test_Bleu_3': 3.315172176973198e-24, 'test_Bleu_4': 3.172564111337878e-23, 'test_ROUGE_L': 0.0, 'test_CIDEr': 0.0}

Running experiment with prompt style: example

Loading vision encoder: microsoft/swin-tiny-patch4-window7-224
Loading trainable vision encoder is Done
Loading LLAMA


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLAMA LoRA Done
Using prompt style: example
Prompt: Here's how to analyze a chest xray:
1. Examine for lung opacities, nodules, or masses
2. Check heart size and mediastinal contours
3. Evaluate pleural spaces
4. Assess diaphragm and costophrenic angles
5. Look for skeletal abnormalities

Now, generate a comprehensive diagnosis report for this chest xray image.


/usr/local/lib/python3.11/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /content/savedmodels/example/checkpoints exists and is not empty.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinModel            | 27.5 M | eval 
1 | llama_model    | PeftModelForCausalLM | 1.3 B  | train
2 | embed_tokens   | Embedding            | 102 M  | eval 
3 | llama_proj     | Linear               | 1.6 M  | train
4 | layer_norm     | LayerNorm            | 4.1 K  | train
----------------------------------------------------------------
1.8 M     Trainable params
1.3 B     Non-trainable params
1.3 B     Total params
5,380.073 Total estimated model params size (MB)
24        Modules in train mode
575       Modules in eval mode
INFO:

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Prompt style: example - Eval results: {'Bleu_1': 0.009333888811392101, 'Bleu_2': 9.670352497857946e-11, 'Bleu_3': 2.1624060131547566e-13, 'Bleu_4': 1.0438649473223255e-14, 'ROUGE_L': np.float64(0.022685507024184454), 'CIDEr': np.float64(0.00032217782276443903)}
Saving checkpoint at step 0 to /content/savedmodels/example/checkpoints/checkpoint_example_epoch0_step0_bleu0.000.pth.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 0, global step 517: 'Bleu_4' reached 0.04806 (best 0.04806), saving model to '/content/savedmodels/example/checkpoints/example_epoch=00-val_Bleu_4=0.0000.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 517: 'Bleu_4' reached 0.04806 (best 0.04806), saving model to '/content/savedmodels/example/checkpoints/example_epoch=00-val_Bleu_4=0.0000.ckpt' as top 1


Prompt style: example - Eval results: {'Bleu_1': 0.2597297694228037, 'Bleu_2': 0.13421842428893568, 'Bleu_3': 0.07730125680173716, 'Bleu_4': 0.04805760782045615, 'ROUGE_L': np.float64(0.2058875289878237), 'CIDEr': np.float64(0.04261161241879104)}
Saving checkpoint at step 517 to /content/savedmodels/example/checkpoints/checkpoint_example_epoch0_step517_bleu0.048.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 1, global step 1034: 'Bleu_4' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 1034: 'Bleu_4' was not in top 1


Prompt style: example - Eval results: {'Bleu_1': 0.2166259666259556, 'Bleu_2': 0.11003129154130434, 'Bleu_3': 0.06606296805768401, 'Bleu_4': 0.04197125011757748, 'ROUGE_L': np.float64(0.18186844932592292), 'CIDEr': np.float64(0.023020847851664594)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 2, global step 1551: 'Bleu_4' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 1551: 'Bleu_4' was not in top 1


Prompt style: example - Eval results: {'Bleu_1': 0.2306251338043122, 'Bleu_2': 0.11740454679325346, 'Bleu_3': 0.06972915809174128, 'Bleu_4': 0.044070351507260616, 'ROUGE_L': np.float64(0.18656452246783845), 'CIDEr': np.float64(0.027502656031379497)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 3, global step 2068: 'Bleu_4' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 2068: 'Bleu_4' was not in top 1


Prompt style: example - Eval results: {'Bleu_1': 0.22251069028709924, 'Bleu_2': 0.1113956221772691, 'Bleu_3': 0.06532505956350836, 'Bleu_4': 0.04099333059837074, 'ROUGE_L': np.float64(0.1810531211270015), 'CIDEr': np.float64(0.03174554781839384)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 4, global step 2585: 'Bleu_4' was not in top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 4, global step 2585: 'Bleu_4' was not in top 1
INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


Prompt style: example - Eval results: {'Bleu_1': 0.22863717184393742, 'Bleu_2': 0.11362163248071061, 'Bleu_3': 0.06638938418331018, 'Bleu_4': 0.04205931017047892, 'ROUGE_L': np.float64(0.18306094789558247), 'CIDEr': np.float64(0.021522010134749376)}
Loading best model from /content/savedmodels/example/checkpoints/example_epoch=00-val_Bleu_4=0.0000.ckpt
Loading vision encoder: microsoft/swin-tiny-patch4-window7-224
Loading trainable vision encoder is Done
Loading LLAMA


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLAMA LoRA Done
Using prompt style: example
Prompt: Here's how to analyze a chest xray:
1. Examine for lung opacities, nodules, or masses
2. Check heart size and mediastinal contours
3. Evaluate pleural spaces
4. Assess diaphragm and costophrenic angles
5. Look for skeletal abnormalities

Now, generate a comprehensive diagnosis report for this chest xray image.


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

Test result for example prompt: {'Bleu_1': 4.712717577319185e-32, 'Bleu_2': 3.619907057996158e-26, 'Bleu_3': 3.315172175416523e-24, 'Bleu_4': 3.1725642272616665e-23, 'ROUGE_L': np.float64(0.0), 'CIDEr': np.float64(0.0)}


Test results for example: {'test_Bleu_1': 4.7127176012990256e-32, 'test_Bleu_2': 3.619907049248295e-26, 'test_Bleu_3': 3.315172176973198e-24, 'test_Bleu_4': 3.172564111337878e-23, 'test_ROUGE_L': 0.0, 'test_CIDEr': 0.0}
Prompt engineering experiments completed. Results saved to /content/prompt_engineering_results/summary.txt


In [None]:
!tensorboard --logdir=/content/logs/prompt_style_example/version_0/events.out.tfevents

2025-05-05 22:32:05.063707: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746484325.084621   69479 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746484325.091115   69479 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.18.0 at http://localhost:6006/ (Press CTRL+C to quit)
