# References



Hugging Face - Documentation [online]. Huggingface.co. Available from: https://huggingface.co/docs.

PyTorch Foundation [online]. PyTorch. Available from: https://pytorch.org/.

zhanyuwang, n.d. R2GenGPT: Radiology Report Generation with Frozen LLMs. Available from: https://github.com/wang-zhanyu/R2GenGPT

# <i> Huggingface login </i>

To work with gated repositores, we need to login to huggingface hub

In [None]:
from huggingface_hub import notebook_login
from google.colab import userdata

notebook_login(userdata.get('HF_TOKEN'))



VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Importing demanded libraries

In [None]:
!pip install lightning

In [None]:
!pip install -U bitsandbytes

In [None]:
import os
import copy
import math
import json
import re
import yaml
import numpy as np
import functools
import pandas as pd
from PIL import Image
from collections import defaultdict

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim

from peft import get_peft_model, LoraConfig, TaskType
from transformers import (
    AutoImageProcessor,
    SwinModel,
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments
)

import lightning.pytorch as pl
from lightning.pytorch import LightningDataModule, seed_everything
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

# Data helper

In [None]:
class FieldParser:
    def __init__(self):
        super().__init__()
        cache_dir = '/content/huggingface'
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", cache_dir=cache_dir)

    def _parse_image(self, img):
        img = np.resize(img, (224, 224, 3))
        pixel_values = self.vit_feature_extractor(img, return_tensors="pt").pixel_values
        return pixel_values[0]

    def clean_report(self, report):
        report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
        .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
        .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
        .strip().lower().split('. ')
        sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '').
                                        replace('\\', '').replace("'", '').strip().lower())
        tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
        report = ' . '.join(tokens) + ' .'
        # report = ' '.join(report.split()[:max_txt_len])
        return report

    def parse(self, features):
        to_return = {'id': features['id']}
        report = features.get("report", "")
        report = self.clean_report(report)
        to_return['input_text'] = report
        # chest x-ray images
        images = []
        for image_path in features['image_path']:
            with Image.open(os.path.join('/content/drive/MyDrive/iu_xray/images', image_path)) as pil:
                array = np.array(pil, dtype=np.uint8)
                if array.shape[-1] != 3 or len(array.shape) != 3:
                    array = np.array(pil.convert("RGB"), dtype=np.uint8)
                image = self._parse_image(array)
                images.append(image)
        to_return["image"] = images
        return to_return

    def transform_with_parse(self, inputs):
        return self.parse(inputs)

# Data Module

In [None]:
class ParseDataset(data.Dataset):
    def __init__(self, split='train'):
        filename = '/content/drive/MyDrive/iu_xray/annotation.json'
        self.meta = json.load(open(filename, 'r'))
        self.meta = self.meta[split]
        self.parser = FieldParser()

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, index):
        return self.parser.transform_with_parse(self.meta[index])

In [None]:
def create_datasets():
    train_dataset = ParseDataset('train')
    dev_dataset = ParseDataset('val')
    test_dataset = ParseDataset('test')
    return train_dataset, dev_dataset, test_dataset

In [None]:
class DataModule(LightningDataModule):
    def __init__(self):
        super().__init__()

    def setup(self, stage: str):
        train_dataset, dev_dataset, test_dataset = create_datasets()
        self.dataset = {
            "train": train_dataset, "validation": dev_dataset, "test": test_dataset
        }


    def train_dataloader(self):
        loader = DataLoader(self.dataset["train"], batch_size=4, drop_last=True, pin_memory=True,
                        num_workers=4, prefetch_factor=4)
        return loader


    def val_dataloader(self):
        loader = DataLoader(self.dataset["validation"], batch_size=8, drop_last=False, pin_memory=True,
                            num_workers=4, prefetch_factor=4)
        return loader


    def test_dataloader(self):
        loader = DataLoader(self.dataset["test"], batch_size=8, drop_last=False, pin_memory=False,
                        num_workers=4, prefetch_factor=4)
        return loader

# Callbacks

In [None]:
def add_callbacks():
    log_dir = "savedmodels"
    os.makedirs(log_dir, exist_ok=True)
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(log_dir, "checkpoints"),
        filename="{epoch}-{step}",
        save_top_k=1,
        every_n_train_steps=10,
        save_last=False,
        save_weights_only=False
    )

    lr_monitor_callback = LearningRateMonitor(logging_interval='step')
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.join(log_dir, "logs"), name="tensorboard")
    csv_logger = CSVLogger(save_dir=os.path.join(log_dir, "logs"), name="csvlog")

    to_returns = {
        "callbacks": [checkpoint_callback, lr_monitor_callback],
        "loggers": [csv_logger, tb_logger]
    }
    return to_returns

# Optimizers

In [None]:
def lr_lambda(current_step, num_warmup_steps, num_training_steps):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(
        0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
    )

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):

    return LambdaLR(optimizer,
                    functools.partial(
                        lr_lambda,
                        num_warmup_steps=num_warmup_steps,
                        num_training_steps=num_training_steps),
                    last_epoch)


def config_optimizer(parameters, init_lr, warmup_steps, max_steps, name='lr'):
    optimizer = optim.AdamW(
        parameters, lr=init_lr, eps=1e-8, correct_bias=False
    )

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
    )
    scheduler = {'scheduler': scheduler, 'name': name, 'interval': 'step', 'frequency': 1}

    return optimizer, scheduler

# BLEU

In [None]:
def precook(s, n=4, out=False):
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return (len(words), counts)

def cook_refs(refs, eff=None, n=4):
    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        for (ngram,count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    # Calculate effective reference sentence length.
    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen))/len(reflen)

    return (reflen, maxcounts)

def cook_test(test, crefs, eff=None, n=4):
    reflen, refmaxcounts = crefs[0], crefs[1]

    testlen, counts = precook(test, n, True)

    result = {}

    # Calculate effective reference sentence length.

    if eff == "closest":
        result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
    else: ## i.e., "average" or "shortest" or None
        result["reflen"] = reflen

    result["testlen"] = testlen

    result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]

    result['correct'] = [0]*n
    for (ngram, count) in counts.items():
        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)

    return result

class BleuScorer(object):

    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
    # special_reflen is used in oracle (proportional effective ref len for a node).

    def copy(self):
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1])
                self.ctest.append(cooked_test) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

        self._score = None ## need to recompute

    def ratio(self, option=None):
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs))
        self._score = None

        return self

    def rescore(self, new_test):
        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None ## need to recompute

        return self

    def compatible(self, other):
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):

        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens))/len(reflens)
        elif option == "closest":
            reflen = min((abs(l-testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        n = self.n
        small = 1e-9
        tiny = 1e-15 ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None: ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess','correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        /(float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1./(k+1)))
            ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
            if ratio < 1:
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1/ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1./(k+1)))
        ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1/ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list

In [None]:
class Bleu:
    def __init__(self, n=4):
        # default compute Blue score up to 4
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res, verbose=0):

        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        bleu_scorer = BleuScorer(n=self._n)
        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) >= 1)

            bleu_scorer += (hypo[0], ref)

        #score, scores = bleu_scorer.compute_score(option='shortest')
        score, scores = bleu_scorer.compute_score(option='closest', verbose=verbose)
        # score, scores = bleu_scorer.compute_score(option='average', verbose=1)

        # return (bleu, bleu_info)
        return score, scores

    def method(self):
        return "Bleu"

# Rouge

In [None]:
def my_lcs(string, sub):
    if(len(string)< len(sub)):
        sub, string = string, sub

    lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]

    for j in range(1,len(sub)+1):
        for i in range(1,len(string)+1):
            if(string[i-1] == sub[j-1]):
                lengths[i][j] = lengths[i-1][j-1] + 1
            else:
                lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])

    return lengths[len(string)][len(sub)]

class Rouge():
    '''
    Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set

    '''
    def __init__(self):
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        # assert(len(candidate)==1)
        # assert(len(refs)>0)
        prec = []
        rec = []

        # split into tokens
        token_c = candidate[0].split(" ")

        for reference in refs:
            # split into tokens
            token_r = reference.split(" ")
            # compute the longest common subsequence
            lcs = my_lcs(token_r, token_c)
            prec.append(lcs/float(len(token_c)))
            rec.append(lcs/float(len(token_r)))

        prec_max = max(prec)
        rec_max = max(rec)

        if(prec_max!=0 and rec_max !=0):
            score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
        else:
            score = 0.0
        return score

    def compute_score(self, gts, res):
        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        score = []
        for id in imgIds:
            hypo = res[id]
            ref  = gts[id]

            score.append(self.calc_score(hypo, ref))

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

        average_score = np.mean(np.array(score))
        return average_score, np.array(score)

    def method(self):
        return "Rouge"

# CIDER

In [None]:
class CiderScorer(object):
    """CIDEr scorer.
    """

    def copy(self):
        ''' copy the refs.'''
        new = CiderScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        ''' singular instance '''
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        self.cook_append(test, refs)
        self.ref_len = None

    def precook(self, s, n=4, out=False):
      words = s.split()
      counts = defaultdict(int)
      for k in range(1,n+1):
          for i in range(len(words)-k+1):
              ngram = tuple(words[i:i+k])
              counts[ngram] += 1
      return counts

    def cook_refs(self, refs, n=4):
        return [self.precook(ref, n) for ref in refs]

    def cook_test(self, test, n=4):
        return self.precook(test, n, True)

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(self.cook_refs(refs))
            if test is not None:
                self.ctest.append(self.cook_test(test)) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new CiderScorer instances
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)

        return self
    def compute_doc_freq(self):
        for refs in self.crefs:
            # refs, k ref captions of one image
            for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
                self.document_frequency[ngram] += 1
            # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    def compute_cider(self):
        def counts2vec(cnts):
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                # give word count 1 if it doesn't appear in reference corpus
                df = np.log(max(1.0, self.document_frequency[ngram]))
                # ngram index
                n = len(ngram)-1
                # tf (term_freq) * idf (precomputed idf) for n-grams
                vec[n][ngram] = float(term_freq)*(self.ref_len - df)
                # compute norm for the vector.  the norm will be used for computing similarity
                norm[n] += pow(vec[n][ngram], 2)

                if n == 1:
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            delta = float(length_hyp - length_ref)
            # measure consine similarity
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                # ngram
                for (ngram,count) in vec_hyp[n].items():
                    # vrama91 : added clipping
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n]*norm_ref[n])

                assert(not math.isnan(val[n]))
                # vrama91: added a length based gaussian penalty
                val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
            return val

        # compute log reference length
        self.ref_len = np.log(float(len(self.crefs)))
        if len(self.crefs) == 1:
            self.ref_len = 1
        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            # compute vector for test captions
            vec, norm, length = counts2vec(test)
            # compute vector for ref captions
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # change by vrama91 - mean of ngram scores, instead of sum
            score_avg = np.mean(score)
            # divide by number of references
            score_avg /= len(refs)
            # multiply score by 10
            score_avg *= 10.0
            # append score of an image to the score list
            scores.append(score_avg)
        return scores

    def compute_score(self, option=None, verbose=0):
        # compute idf
        self.compute_doc_freq()
        # assert to check document frequency
        assert(len(self.ctest) >= max(self.document_frequency.values()))
        # compute cider score
        score = self.compute_cider()
        # debug
        # print score
        return np.mean(np.array(score)), np.array(score)

In [None]:
class Cider:
    """
    Main Class to compute the CIDEr metric

    """
    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # set cider to sum over 1 to 4-grams
        self._n = n
        # set the standard deviation parameter for gaussian penalty
        self._sigma = sigma

    def compute_score(self, gts, res):

        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)

        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

            cider_scorer += (hypo[0], ref)

        (score, scores) = cider_scorer.compute_score()

        return score, scores

    def method(self):
        return "CIDEr"

# Model

In [None]:
class R2GenGPT(pl.LightningModule):
    def __init__(self, model_name, vision_model):
        super().__init__()

        cache_dir = '/content/huggingface'

        print(f'Loading vision encoder: {vision_model}')
        self.visual_encoder = AutoModel.from_pretrained(vision_model, cache_dir=cache_dir)
        # self.visual_encoder.train()
        for name,param in self.visual_encoder.named_parameters():
            param.requires_grad = False
        self.visual_encoder.gradient_checkpointing_enable()
        print('Loading trainable vision encoder is Done')

        print('Loading LLAMA')
        # use low resources
        self.llama_model = AutoModelForCausalLM.from_pretrained(model_name,
                                                                torch_dtype=torch.float16,
                                                                device_map="cuda",
                                                                load_in_8bit=True,
                                                                cache_dir=cache_dir,
                                                                trust_remote_code=True)
        self.llama_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, cache_dir=cache_dir)
        # explicitly setting bos_token_id and eos_token_id if not already defined
        if self.llama_tokenizer.bos_token_id is None:
            self.llama_tokenizer.bos_token_id = self.llama_tokenizer.eos_token_id
            self.llama_tokenizer.eos_token_id = 2
        self.llama_tokenizer.pad_token_id = 0
        self.llama_model.generation_config.pad_token_id = self.llama_tokenizer.pad_token_id

        self.embed_tokens = self.llama_model.get_input_embeddings()
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            # target_modules=[
            #     "embed_int", "embed_out", "attention.query_key_value"
            # ]
            # target_modules=["q_proj", "v_proj"]
            target_modules=["transformer.h.0.attn.c_attn", "transformer.h.0.attn.c_proj"]
        )
        self.llama_model = get_peft_model(self.llama_model, peft_config)
        self.llama_model.print_trainable_parameters()
        self.llama_model.gradient_checkpointing_enable()
        print('Loading LLAMA LoRA Done')

        # lora not used
        # self.embed_tokens = self.llama_model.get_input_embeddings()
        # for name, param in self.llama_model.named_parameters():
        #     param.requires_grad = False
        # print('Loading LLAMA is Done')

        self.llama_proj = nn.Linear(self.visual_encoder.num_features, self.llama_model.config.hidden_size)
        self.layer_norm = nn.LayerNorm(self.llama_model.config.hidden_size)
        self.end_sym = '</s>'
        self.prompt = 'Generate a comprehensive and detailed diagnosis report for this chest xray image.'
        self.val_step_outputs = []
        self.test_step_outputs = []
        self.val_score = 0.0


    def score(self, ref, hypo):
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr")
        ]
        final_scores = {}
        for scorer, method in scorers:
            score, scores = scorer.compute_score(ref, hypo)
            if type(score) == list:
                for m, s in zip(method, score):
                    final_scores[m] = s
            else:
                final_scores[method] = score
        return final_scores


    def encode_img(self, images):
        image_embeds = []
        for image in images:
            device = image.device
            image_embed = self.visual_encoder(image)['last_hidden_state'].to(device)
            image_embeds.append(image_embed)

        image_embeds = torch.stack(image_embeds).mean(0)
        inputs_llama = self.llama_proj(image_embeds)
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama


    def prompt_wrap(self, img_embeds, atts_img):
        prompt=f'Human: <Img><ImageHere></Img> {self.prompt} \nAssistant:'
        batch_size = img_embeds.shape[0]
        p_before, p_after = prompt.split('<ImageHere>')
        p_before_tokens = self.llama_tokenizer(
            p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_after_tokens = self.llama_tokenizer(
            p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_before_embeds = self.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        p_after_embeds = self.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img


    def forward(self, samples):
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["input_text"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
            add_special_tokens=False
        ).to(image[0].device)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == 0, -100
        )

        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
                       dtype=torch.long).to(image[0].device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss
        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        result = self(batch)
        self.log_dict(result, prog_bar=True)
        return result

    def save_checkpoint(self, eval_res):
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
        param_grad_dic = {
            k: v.requires_grad for (k, v) in self.named_parameters() if v.requires_grad
        }
        state_dict = self.state_dict()
        for k in list(state_dict.keys()):
            if k not in param_grad_dic.keys():
                del state_dict[k]
        save_obj = {
            "model": state_dict,
            "epoch": current_epoch,
            "step":global_step
        }
        savedmodel_path = '/content/savedmodels'
        os.makedirs(os.path.join(savedmodel_path, 'checkpoints'), exist_ok=True)
        save_to = os.path.join(
            savedmodel_path, 'checkpoints',
            "checkpoint_epoch{}_step{}_bleu{:3f}.pth".format(current_epoch, global_step, eval_res['Bleu_4']),
        )
        self.print("Saving checkpoint at step {} to {}.".format(global_step, save_to))
        torch.save(save_obj, save_to)

    def validation_step(self, samples, batch_idx):
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
            add_special_tokens=False
        )

        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=atts_img.dtype,
                         device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            num_beams=1,
            do_sample=False,
            min_new_tokens=80,
            max_new_tokens=120,
            repetition_penalty=2.0,
            length_penalty=2.0,
            temperature=0,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.val_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return hypo, ref

    def decode(self, output_token):
        if output_token[0] == 0:  # unknown token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('</s>')[0].strip()
        output_text = output_text.replace('<unk>', '')
        return output_text

    def on_validation_epoch_end(self):
        ref, hypo, ids = [], [], []
        for i in self.val_step_outputs:
            ref.extend(i['ref'])
            hypo.extend(i['hypo'])
            ids.extend(i['id'])

        ref = {k:[v] for k, v in zip(ids, ref)}
        hypo = {k:[v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref,hypo=hypo)
        self.log_dict(eval_res, sync_dist=True, logger=True)

        savedmodel_path = '/content/savedmodels'
        result_folder = os.path.join(savedmodel_path, 'result')
        os.makedirs(result_folder, exist_ok=True)
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
        json.dump(hypo, open(os.path.join(result_folder, f"result_{current_epoch}_{global_step}" + '.json'), 'w'))
        json.dump(ref, open(os.path.join(result_folder, 'refs.json'), 'w'))
        self.print(eval_res)

        val_score = 0
        scorer_types = ['Bleu_4']
        weights = [1]
        for score_type, weight in zip(scorer_types, weights):
            val_score += eval_res[score_type] * weight

        if self.trainer.local_rank == 0:
            if val_score > self.val_score:
                self.save_checkpoint(eval_res)
                self.val_score = val_score
        self.val_step_outputs.clear()


    def test_step(self, samples, batch_idx):
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
            add_special_tokens=False
        )

        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=atts_img.dtype,
                         device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            num_beams=3,
            do_sample=False,
            min_new_tokens=80,
            max_new_tokens=120,
            repetition_penalty=2.0,
            length_penalty=2.0,
            temperature=0,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.test_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return hypo, ref


    def on_test_epoch_end(self):
        ref, hypo, ids = [], [], []
        for i in self.test_step_outputs:
            ref.extend(i['ref'])
            hypo.extend(i['hypo'])
            ids.extend(i['id'])

        ref = {k:[v] for k, v in zip(ids, ref)}
        hypo = {k:[v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref,hypo=hypo)

        savedmodel_path = '/content/savedmodels'
        result_folder = os.path.join(savedmodel_path, 'result')
        os.makedirs(result_folder, exist_ok=True)
        json.dump(hypo, open(os.path.join(result_folder, f"test_result.json"), 'w'))
        json.dump(ref, open(os.path.join(result_folder, 'test_refs.json'), 'w'))
        self.print(f"Test result : {eval_res}")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=3, eta_min=1e-6)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
        optimizer.zero_grad()

# Training

Steps to overcome GPU limitations



1.   Reducing batch size
2.   using model gradient checkpointing
3.   setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
4.   using load_in_8bit in model (update bitsandbytes)
5.   clearing GPU before using



In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [None]:
model_names = {
    'llama': [
        # LLama
        "meta-llama/Llama-3.2-3B-Instruct",
        "meta-llama/Llama-3.2-11B-Vision-Instruct",
        "meta-llama/Llama-3.2-3B",
        "meta-llama/Llama-3.2-11B-Vision",
        "meta-llama/Meta-Llama-3-8B"
    ],
    'deepseek': [
        # Deepseek
        "deepseek-ai/deepseek-llm-7b-chat",
        "nvidia/DeepSeek-R1-FP4",
        "deepseek-ai/DeepSeek-V2"
    ],
    'gpt': [
        # GPT
        "EleutherAI/gpt-neo-1.3B",
        "EleutherAI/gpt-neox-20b"
    ],
    'google': [
        # Gemini
        "google/gemma-3-1b-it"
    ]
}
vision_models = {
    "microsoft": [
        "microsoft/swin-tiny-patch4-window7-224",
        "microsoft/swin-base-patch4-window12-384"
    ],
    "facebook": [
        "facebook/maskformer-swin-base-coco",
        "facebook/detr-resnet-50"
    ]
}
small_models = ["Qwen/Qwen2-1.5B-Instruct", "meta-llama/Llama-3.2-1B",
                "cerebras/Cerebras-GPT-1.3B", "microsoft/Phi-3.5-vision-instruct", "stabilityai/stablelm-zephyr-3b",
                "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mtgv/MobileLLaMA-2.7B-Base", "MBZUAI/LaMini-GPT-1.5B",
                "google/gemma-3-4b-it", "openbmb/MiniCPM3-4B", "apple/OpenELM-3B", "ThisIsATest/dclm-1b-raw",
                "tensoropera/Fox-1-1.6B"]

In [None]:
!pip install flash-attn --no-build-isolation

In [None]:
dm = DataModule()
callbacks = add_callbacks()


trainer = pl.Trainer(
    devices=1,
    num_nodes=1,
    strategy='auto',
    accelerator='auto',
    precision='bf16-mixed',
    enable_checkpointing=True,
    # limit_train_batches=1.0,
    # limit_val_batches=1.0,
    val_check_interval=1.0,
    max_epochs=15,
    num_sanity_val_steps=2,
    accumulate_grad_batches=8,
    gradient_clip_val=1.0,
    callbacks=callbacks["callbacks"],
    logger=callbacks["loggers"]
)


os.makedirs('savedmodels', exist_ok=True)
seed_everything(42, workers=True)


model = R2GenGPT("cerebras/Cerebras-GPT-1.3B", vision_models['microsoft'][0])
trainer.fit(model, datamodule=dm)

INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
INFO:lightning.pytorch.utilities.rank_zero:`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


Loading vision encoder: microsoft/swin-tiny-patch4-window7-224
Loading trainable vision encoder is Done
Loading LLAMA


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
INFO: You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLAMA LoRA Done


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinModel            | 27.5 M | eval 
1 | llama_model    | PeftModelForCausalLM | 1.3 B  | train
2 | embed_tokens   | Embedding            | 102 M  | eval 
3 | llama_proj     | Linear               | 1.6 M  | train
4 | layer_norm     | LayerNorm            | 4.1 K  | train
----------------------------------------------------------------
1.8 M     Trainable params
1.3 B     Non-trainable par

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.07163886162897778, 'Bleu_2': 2.6725378924406374e-10, 'Bleu_3': 4.1671921866839303e-13, 'Bleu_4': 1.6522579116923744e-14, 'ROUGE_L': np.float64(0.06234460669707777), 'CIDEr': np.float64(0.009191026076713342)}
Saving checkpoint at step 0 to /content/savedmodels/checkpoints/checkpoint_epoch0_step0_bleu0.000000.pth.


Training: |          | 0/? [00:00<?, ?it/s]

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.08769931662869535, 'Bleu_2': 0.02157628166134561, 'Bleu_3': 0.006233209957181524, 'Bleu_4': 3.683013792125495e-07, 'ROUGE_L': np.float64(0.05299153056951855), 'CIDEr': np.float64(0.0036625670072388365)}
Saving checkpoint at step 65 to /content/savedmodels/checkpoints/checkpoint_epoch0_step65_bleu0.000000.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.21875608032589092, 'Bleu_2': 0.08915185368757404, 'Bleu_3': 0.04134115237010449, 'Bleu_4': 0.02168431185888477, 'ROUGE_L': np.float64(0.14573184020020324), 'CIDEr': np.float64(0.014784668942041787)}
Saving checkpoint at step 130 to /content/savedmodels/checkpoints/checkpoint_epoch1_step130_bleu0.021684.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.23774484612674007, 'Bleu_2': 0.09703324440589961, 'Bleu_3': 0.043239587904740806, 'Bleu_4': 0.022231392399097865, 'ROUGE_L': np.float64(0.16440558824592802), 'CIDEr': np.float64(0.018875243641001828)}
Saving checkpoint at step 195 to /content/savedmodels/checkpoints/checkpoint_epoch2_step195_bleu0.022231.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.22517880489116163, 'Bleu_2': 0.09469291770832573, 'Bleu_3': 0.04648159337610684, 'Bleu_4': 0.024713430196469346, 'ROUGE_L': np.float64(0.16701066646417448), 'CIDEr': np.float64(0.02073413444515822)}
Saving checkpoint at step 260 to /content/savedmodels/checkpoints/checkpoint_epoch3_step260_bleu0.024713.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.20550896808757246, 'Bleu_2': 0.09857847073766463, 'Bleu_3': 0.054684126181752564, 'Bleu_4': 0.0307470310055665, 'ROUGE_L': np.float64(0.1710867676375799), 'CIDEr': np.float64(0.01899149576293443)}
Saving checkpoint at step 325 to /content/savedmodels/checkpoints/checkpoint_epoch4_step325_bleu0.030747.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.20801546661092027, 'Bleu_2': 0.10246477067632238, 'Bleu_3': 0.058242691831938184, 'Bleu_4': 0.03470496106224875, 'ROUGE_L': np.float64(0.17304014505353713), 'CIDEr': np.float64(0.015709183075994903)}
Saving checkpoint at step 390 to /content/savedmodels/checkpoints/checkpoint_epoch5_step390_bleu0.034705.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.21066208082544075, 'Bleu_2': 0.10798509932148882, 'Bleu_3': 0.0629354275972706, 'Bleu_4': 0.03769871090016194, 'ROUGE_L': np.float64(0.1761389066455386), 'CIDEr': np.float64(0.0219274817879064)}
Saving checkpoint at step 455 to /content/savedmodels/checkpoints/checkpoint_epoch6_step455_bleu0.037699.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.21522164767306515, 'Bleu_2': 0.10973911262758664, 'Bleu_3': 0.0635038546881491, 'Bleu_4': 0.03781773902545432, 'ROUGE_L': np.float64(0.17886514803799267), 'CIDEr': np.float64(0.019340173599680117)}
Saving checkpoint at step 520 to /content/savedmodels/checkpoints/checkpoint_epoch7_step520_bleu0.037818.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.22214156079853656, 'Bleu_2': 0.1137987974298749, 'Bleu_3': 0.06719807259794856, 'Bleu_4': 0.041509641888606556, 'ROUGE_L': np.float64(0.18533972486556846), 'CIDEr': np.float64(0.021093983426993184)}
Saving checkpoint at step 585 to /content/savedmodels/checkpoints/checkpoint_epoch8_step585_bleu0.041510.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.21833224032688198, 'Bleu_2': 0.11270166654945159, 'Bleu_3': 0.06608660092804028, 'Bleu_4': 0.0401807243435374, 'ROUGE_L': np.float64(0.18457975307847727), 'CIDEr': np.float64(0.020767382676066527)}


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.21858328393597024, 'Bleu_2': 0.11097929178696044, 'Bleu_3': 0.06567577290452517, 'Bleu_4': 0.04156784718951624, 'ROUGE_L': np.float64(0.18335032116428968), 'CIDEr': np.float64(0.01551311244779132)}
Saving checkpoint at step 715 to /content/savedmodels/checkpoints/checkpoint_epoch10_step715_bleu0.041568.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.21901216182674516, 'Bleu_2': 0.11269443277610454, 'Bleu_3': 0.06771316910700731, 'Bleu_4': 0.042637934955139935, 'ROUGE_L': np.float64(0.18753294071070647), 'CIDEr': np.float64(0.018326146187073903)}
Saving checkpoint at step 780 to /content/savedmodels/checkpoints/checkpoint_epoch11_step780_bleu0.042638.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.2119798114465188, 'Bleu_2': 0.10916304366222256, 'Bleu_3': 0.06516717124674198, 'Bleu_4': 0.041520299600718756, 'ROUGE_L': np.float64(0.18688293185066235), 'CIDEr': np.float64(0.016206677941951974)}


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.21644793562600745, 'Bleu_2': 0.11149147866924645, 'Bleu_3': 0.0658973479554643, 'Bleu_4': 0.04207835785245351, 'ROUGE_L': np.float64(0.19020708110516854), 'CIDEr': np.float64(0.016174653789098915)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=15` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=15` reached.


{'Bleu_1': 0.21437317509653317, 'Bleu_2': 0.10972896512866859, 'Bleu_3': 0.06582636658419258, 'Bleu_4': 0.04237523424812528, 'ROUGE_L': np.float64(0.18996892362832493), 'CIDEr': np.float64(0.01926251985296887)}


# Validation

In [None]:
temp_model = AutoModelForCausalLM.from_pretrained('cerebras/Cerebras-GPT-1.3B')
print(temp_model)

In [None]:
trainer.validate(model, datamodule=dm)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.0002901461936162852, 'Bleu_2': 5.4360320036983315e-12, 'Bleu_3': 1.470265040864592e-14, 'Bleu_4': 7.757338710926842e-16, 'ROUGE_L': np.float64(0.0003738576819515393), 'CIDEr': np.float64(1.8261508622314814e-06)}


[{'Bleu_1': 0.0002901461848523468,
  'Bleu_2': 5.436031833006005e-12,
  'Bleu_3': 1.4702650514707323e-14,
  'Bleu_4': 7.757338959796104e-16,
  'ROUGE_L': 0.0003738576779142022,
  'CIDEr': 1.8261508785144542e-06}]

# Test

In [None]:
trainer.test(model, datamodule=dm)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

Test result : {'Bleu_1': 0.00030478726532532643, 'Bleu_2': 4.060231264320486e-12, 'Bleu_3': 9.842314031824098e-15, 'Bleu_4': 4.933600046780447e-16, 'ROUGE_L': np.float64(0.00035118754602809376), 'CIDEr': np.float64(7.314112983465712e-06)}


[{}]

# Clinical Efficacy

In [None]:
class R2GenGPT(pl.LightningModule):
    def __init__(self, model_name, vision_model):
        super().__init__()

        cache_dir = '/content/huggingface'

        print(f'Loading vision encoder: {vision_model}')
        self.visual_encoder = AutoModel.from_pretrained(vision_model, cache_dir=cache_dir)
        self.visual_encoder.train()
        self.visual_encoder.gradient_checkpointing_enable()
        print('Loading trainable vision encoder is Done')

        print('Loading LLAMA')
        # use low resources
        self.llama_model = AutoModelForCausalLM.from_pretrained(model_name,
                                                                torch_dtype=torch.float16,
                                                                device_map="cuda",
                                                                load_in_8bit=True,
                                                                cache_dir=cache_dir,
                                                                trust_remote_code=True)
        self.llama_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, cache_dir=cache_dir)
        # explicitly setting bos_token_id and eos_token_id if not already defined
        if self.llama_tokenizer.bos_token_id is None:
            self.llama_tokenizer.bos_token_id = self.llama_tokenizer.eos_token_id
            self.llama_tokenizer.eos_token_id = 2
        self.llama_tokenizer.pad_token_id = 0
        self.llama_model.generation_config.pad_token_id = self.llama_tokenizer.pad_token_id

        self.embed_tokens = self.llama_model.get_input_embeddings()
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["transformer.h.0.attn.c_attn", "transformer.h.0.attn.c_proj"]
        )
        self.llama_model = get_peft_model(self.llama_model, peft_config)
        self.llama_model.print_trainable_parameters()
        self.llama_model.gradient_checkpointing_enable()
        print('Loading LLAMA LoRA Done')

        self.llama_proj = nn.Linear(self.visual_encoder.num_features, self.llama_model.config.hidden_size)
        self.layer_norm = nn.LayerNorm(self.llama_model.config.hidden_size)
        self.end_sym = '</s>'
        self.prompt = 'Generate a comprehensive and detailed diagnosis report for this chest xray image.'
        self.val_step_outputs = []
        self.test_step_outputs = []
        self.val_score = 0.0


    def score(self, ref, hypo):
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr")
        ]
        final_scores = {}
        for scorer, method in scorers:
            score, scores = scorer.compute_score(ref, hypo)
            if type(score) == list:
                for m, s in zip(method, score):
                    final_scores[m] = s
            else:
                final_scores[method] = score
        return final_scores


    def encode_img(self, images):
        image_embeds = []
        for image in images:
            device = image.device
            image_embed = self.visual_encoder(image)['last_hidden_state'].to(device)
            image_embeds.append(image_embed)

        image_embeds = torch.stack(image_embeds).mean(0)
        inputs_llama = self.llama_proj(image_embeds)
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama


    def prompt_wrap(self, img_embeds, atts_img):
        prompt=f'Human: <Img><ImageHere></Img> {self.prompt} \nAssistant:'
        batch_size = img_embeds.shape[0]
        p_before, p_after = prompt.split('<ImageHere>')
        p_before_tokens = self.llama_tokenizer(
            p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_after_tokens = self.llama_tokenizer(
            p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_before_embeds = self.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        p_after_embeds = self.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img


    def forward(self, samples):
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["input_text"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
            add_special_tokens=False
        ).to(image[0].device)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == 0, -100
        )

        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
                       dtype=torch.long).to(image[0].device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss
        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        result = self(batch)
        self.log_dict(result, prog_bar=True)
        return result

    def save_checkpoint(self, eval_res):
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
        param_grad_dic = {
            k: v.requires_grad for (k, v) in self.named_parameters() if v.requires_grad
        }
        state_dict = self.state_dict()
        for k in list(state_dict.keys()):
            if k not in param_grad_dic.keys():
                del state_dict[k]
        save_obj = {
            "model": state_dict,
            "epoch": current_epoch,
            "step":global_step
        }
        savedmodel_path = '/content/savedmodels'
        os.makedirs(os.path.join(savedmodel_path, 'checkpoints'), exist_ok=True)
        save_to = os.path.join(
            savedmodel_path, 'checkpoints',
            "checkpoint_epoch{}_step{}_bleu{:3f}.pth".format(current_epoch, global_step, eval_res['Bleu_4']),
        )
        self.print("Saving checkpoint at step {} to {}.".format(global_step, save_to))
        torch.save(save_obj, save_to)

    def validation_step(self, samples, batch_idx):
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
            add_special_tokens=False
        )

        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=atts_img.dtype,
                         device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            num_beams=1,
            do_sample=False,
            min_new_tokens=80,
            max_new_tokens=120,
            repetition_penalty=2.0,
            length_penalty=2.0,
            temperature=0,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.val_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return hypo, ref

    def decode(self, output_token):
        if output_token[0] == 0:  # unknown token <unk> at the beginning. remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # start token <s> at the beginning. remove it
            output_token = output_token[1:]
        output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('</s>')[0].strip()
        output_text = output_text.replace('<unk>', '')
        return output_text

    def extract_clinical_entities(self, text):
        """
        Extract important clinical findings from radiology reports.
        Returns a dictionary with entities as keys and True/False as values.
        """
        # Define important clinical findings/entities for IU X-ray dataset
        entities = [
            "pneumonia", "cardiomegaly", "effusion", "pneumothorax", "edema",
            "atelectasis", "consolidation", "mass", "nodule", "infiltration",
            "fracture", "pleural thickening", "emphysema", "fibrosis", "opacity"
        ]

        # Handle negations (common in medical reports)
        negation_terms = ["no ", "not ", "without ", "free of ", "absence of ", "negative for "]

        findings = {}
        text = text.lower()

        for entity in entities:
            # Check if entity exists in text
            if entity in text:
                # Check if entity is negated
                negated = False
                for negation in negation_terms:
                    # Look for negation terms within reasonable distance before the entity
                    search_window = 20  # characters
                    entity_pos = text.find(entity)
                    if entity_pos >= 0:
                        start_pos = max(0, entity_pos - search_window)
                        if negation in text[start_pos:entity_pos + len(entity)]:
                            negated = True
                            break
                findings[entity] = not negated

        return findings

    def calculate_clinical_metrics(self, reference_texts, generated_texts):
        """
        Calculate precision, recall, and F1 score for clinical entities.
        """
        total_true_positives = 0
        total_false_positives = 0
        total_false_negatives = 0

        entity_metrics = {}  # For per-entity metrics

        for ref_text, gen_text in zip(reference_texts, generated_texts):
            ref_findings = self.extract_clinical_entities(ref_text)
            gen_findings = self.extract_clinical_entities(gen_text)

            # Calculate overall metrics
            for entity in set(list(ref_findings.keys()) + list(gen_findings.keys())):
                ref_positive = ref_findings.get(entity, False)
                gen_positive = gen_findings.get(entity, False)

                # Initialize entity metrics if not exists
                if entity not in entity_metrics:
                    entity_metrics[entity] = {"tp": 0, "fp": 0, "fn": 0}

                # Both positive = true positive
                if ref_positive and gen_positive:
                    total_true_positives += 1
                    entity_metrics[entity]["tp"] += 1
                # Only in generated = false positive
                elif not ref_positive and gen_positive:
                    total_false_positives += 1
                    entity_metrics[entity]["fp"] += 1
                # Only in reference = false negative
                elif ref_positive and not gen_positive:
                    total_false_negatives += 1
                    entity_metrics[entity]["fn"] += 1

        # Calculate overall metrics
        precision = total_true_positives / (total_true_positives + total_false_positives) if (total_true_positives + total_false_positives) > 0 else 0
        recall = total_true_positives / (total_true_positives + total_false_negatives) if (total_true_positives + total_false_negatives) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        # Calculate per-entity metrics
        per_entity_results = {}
        for entity, counts in entity_metrics.items():
            if counts["tp"] + counts["fp"] > 0:
                entity_precision = counts["tp"] / (counts["tp"] + counts["fp"])
            else:
                entity_precision = 0

            if counts["tp"] + counts["fn"] > 0:
                entity_recall = counts["tp"] / (counts["tp"] + counts["fn"])
            else:
                entity_recall = 0

            if entity_precision + entity_recall > 0:
                entity_f1 = 2 * entity_precision * entity_recall / (entity_precision + entity_recall)
            else:
                entity_f1 = 0

            per_entity_results[entity] = {
                "precision": entity_precision,
                "recall": entity_recall,
                "f1": entity_f1,
                "support": counts["tp"] + counts["fn"]  # How many times this entity appears in references
            }

        return {
            "overall": {
                "precision": precision,
                "recall": recall,
                "f1": f1
            },
            "per_entity": per_entity_results
        }

    def on_validation_epoch_end(self):
        ref, hypo, ids = [], [], []
        for i in self.val_step_outputs:
            ref.extend(i['ref'])
            hypo.extend(i['hypo'])
            ids.extend(i['id'])

        # Regular text generation metrics
        ref_dict = {k:[v] for k, v in zip(ids, ref)}
        hypo_dict = {k:[v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref_dict, hypo=hypo_dict)

        # Calculate clinical metrics
        clinical_metrics = self.calculate_clinical_metrics(ref, hypo)

        # Add clinical metrics to evaluation results
        eval_res["clinical_precision"] = clinical_metrics["overall"]["precision"]
        eval_res["clinical_recall"] = clinical_metrics["overall"]["recall"]
        eval_res["clinical_f1"] = clinical_metrics["overall"]["f1"]

        self.log_dict(eval_res, sync_dist=True, logger=True)

        # Print detailed metrics
        self.print(f"Text Generation Metrics: BLEU-4: {eval_res['Bleu_4']:.4f}, ROUGE-L: {eval_res['ROUGE_L']:.4f}, CIDEr: {eval_res['CIDEr']:.4f}")
        self.print(f"Clinical Metrics - Precision: {clinical_metrics['overall']['precision']:.4f}, Recall: {clinical_metrics['overall']['recall']:.4f}, F1: {clinical_metrics['overall']['f1']:.4f}")

        # Save detailed per-entity metrics
        savedmodel_path = '/content/savedmodels'
        result_folder = os.path.join(savedmodel_path, 'result')
        os.makedirs(result_folder, exist_ok=True)
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step

        # Save clinical metrics
        json.dump(clinical_metrics, open(os.path.join(result_folder, f"clinical_metrics_{current_epoch}_{global_step}.json"), 'w'))

        # Original code for saving regular metrics
        json.dump(hypo_dict, open(os.path.join(result_folder, f"result_{current_epoch}_{global_step}" + '.json'), 'w'))
        json.dump(ref_dict, open(os.path.join(result_folder, 'refs.json'), 'w'))

        # Update validation score to include clinical metrics
        val_score = 0
        scorer_types = ['Bleu_4', 'clinical_f1']  # Consider both language and clinical metrics
        weights = [0.6, 0.4]  # Adjust weights based on your priorities
        for score_type, weight in zip(scorer_types, weights):
            val_score += eval_res[score_type] * weight

        if self.trainer.local_rank == 0:
            if val_score > self.val_score:
                self.save_checkpoint(eval_res)
                self.val_score = val_score
        self.val_step_outputs.clear()


    def test_step(self, samples, batch_idx):
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=100,
            add_special_tokens=False
        )

        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=atts_img.dtype,
                         device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            num_beams=3,
            do_sample=False,
            min_new_tokens=80,
            max_new_tokens=120,
            repetition_penalty=2.0,
            length_penalty=2.0,
            temperature=0,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.test_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return hypo, ref


    def on_test_epoch_end(self):
        ref, hypo, ids = [], [], []
        for i in self.test_step_outputs:
            ref.extend(i['ref'])
            hypo.extend(i['hypo'])
            ids.extend(i['id'])

        ref_dict = {k:[v] for k, v in zip(ids, ref)}
        hypo_dict = {k:[v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref_dict, hypo=hypo_dict)

        # Calculate clinical metrics
        clinical_metrics = self.calculate_clinical_metrics(ref, hypo)

        # Add clinical metrics to evaluation results
        eval_res["clinical_precision"] = clinical_metrics["overall"]["precision"]
        eval_res["clinical_recall"] = clinical_metrics["overall"]["recall"]
        eval_res["clinical_f1"] = clinical_metrics["overall"]["f1"]

        savedmodel_path = '/content/savedmodels'
        result_folder = os.path.join(savedmodel_path, 'result')
        os.makedirs(result_folder, exist_ok=True)

        json.dump(hypo_dict, open(os.path.join(result_folder, f"test_result.json"), 'w'))
        json.dump(ref_dict, open(os.path.join(result_folder, 'test_refs.json'), 'w'))
        json.dump(clinical_metrics, open(os.path.join(result_folder, f"test_clinical_metrics.json"), 'w'))

        self.print(f"Test result (Text): {eval_res}")
        self.print(f"Test result (Clinical) - Precision: {clinical_metrics['overall']['precision']:.4f}, Recall: {clinical_metrics['overall']['recall']:.4f}, F1: {clinical_metrics['overall']['f1']:.4f}")

        # Create a comparison table with previous methods if available
        self.print("\nComparison with previous methods:")
        self.print("-" * 80)
        self.print("{:<20} {:<10} {:<10} {:<10} {:<10} {:<10}".format("Method", "BLEU-4", "ROUGE-L", "CIDEr", "Clinical-F1", "Avg"))
        self.print("-" * 80)

        # Add previous methods' results (you'll need to update these values)
        previous_methods = [
            # Model name, BLEU-4, ROUGE-L, CIDEr, Clinical-F1, Average
            ["R2Gen", 0.103, 0.277, 0.126, 0.345, "-"],
            ["CMN", 0.106, 0.294, 0.154, 0.320, "-"],
            ["PPKED", 0.113, 0.306, 0.185, 0.371, "-"],
            # Add more methods as needed
        ]

        for method in previous_methods:
            avg = sum(method[1:5])/4
            self.print("{:<20} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f}".format(
                method[0], method[1], method[2], method[3], method[4], avg
            ))

        # Add current method
        current_avg = (eval_res['Bleu_4'] + eval_res['ROUGE_L'] + eval_res['CIDEr'] + clinical_metrics['overall']['f1'])/4
        self.print("{:<20} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f} {:<10.4f}".format(
            "R2GenGPT", eval_res['Bleu_4'], eval_res['ROUGE_L'], eval_res['CIDEr'],
            clinical_metrics['overall']['f1'], current_avg
        ))
        self.print("-" * 80)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=3, eta_min=1e-6)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
        optimizer.zero_grad()

In [None]:
dm = DataModule()
callbacks = add_callbacks()


trainer = pl.Trainer(
    devices=1,
    num_nodes=1,
    strategy='auto',
    accelerator='auto',
    precision='bf16-mixed',
    enable_checkpointing=True,
    # limit_train_batches=1.0,
    # limit_val_batches=1.0,
    val_check_interval=1.0,
    max_epochs=15,
    num_sanity_val_steps=2,
    accumulate_grad_batches=8,
    gradient_clip_val=1.0,
    callbacks=callbacks["callbacks"],
    logger=callbacks["loggers"]
)


os.makedirs('savedmodels', exist_ok=True)
seed_everything(42, workers=True)


model = R2GenGPT("cerebras/Cerebras-GPT-1.3B", vision_models['microsoft'][0])
trainer.fit(model, datamodule=dm)

INFO: Using bfloat16 Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
INFO:lightning.pytorch.utilities.rank_zero:`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


Loading vision encoder: microsoft/swin-tiny-patch4-window7-224


config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/113M [00:00<?, ?B/s]

Loading trainable vision encoder is Done
Loading LLAMA


config.json:   0%|          | 0.00/360 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


pytorch_model.bin:   0%|          | 0.00/5.36G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/5.36G [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

INFO: You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLAMA LoRA Done


preprocessor_config.json:   0%|          | 0.00/290 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinModel            | 27.5 M | train
1 | llama_model    | PeftModelForCausalLM | 1.3 B  | train
2 | embed_tokens   | Embedding            | 102 M  | eval 
3 | llama_proj     | Linear               | 1.6 M  | train
4 | layer_norm     | LayerNorm            | 4.1 K  | train
----------------------------------------------------------------
29.3 M    Trainable params
1.3 B     Non-trainable par

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0000, ROUGE-L: 0.0623, CIDEr: 0.0092
Clinical Metrics - Precision: 0.0000, Recall: 0.0000, F1: 0.0000
Saving checkpoint at step 0 to /content/savedmodels/checkpoints/checkpoint_epoch0_step0_bleu0.000000.pth.


Training: |          | 0/? [00:00<?, ?it/s]

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0133, ROUGE-L: 0.1622, CIDEr: 0.0266
Clinical Metrics - Precision: 0.2901, Recall: 0.2520, F1: 0.2697
Saving checkpoint at step 65 to /content/savedmodels/checkpoints/checkpoint_epoch0_step65_bleu0.013331.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0762, ROUGE-L: 0.2370, CIDEr: 0.0733
Clinical Metrics - Precision: 0.2647, Recall: 0.4102, F1: 0.3218
Saving checkpoint at step 130 to /content/savedmodels/checkpoints/checkpoint_epoch1_step130_bleu0.076208.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0557, ROUGE-L: 0.2094, CIDEr: 0.0463
Clinical Metrics - Precision: 0.2395, Recall: 0.4718, F1: 0.3177


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0583, ROUGE-L: 0.2229, CIDEr: 0.0558
Clinical Metrics - Precision: 0.2771, Recall: 0.5147, F1: 0.3602
Saving checkpoint at step 260 to /content/savedmodels/checkpoints/checkpoint_epoch3_step260_bleu0.058325.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0469, ROUGE-L: 0.1971, CIDEr: 0.0270
Clinical Metrics - Precision: 0.2368, Recall: 0.6139, F1: 0.3418


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0417, ROUGE-L: 0.1881, CIDEr: 0.0219
Clinical Metrics - Precision: 0.1979, Recall: 0.5684, F1: 0.2936


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0470, ROUGE-L: 0.1995, CIDEr: 0.0189
Clinical Metrics - Precision: 0.2181, Recall: 0.5952, F1: 0.3192


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0432, ROUGE-L: 0.1853, CIDEr: 0.0129
Clinical Metrics - Precision: 0.1968, Recall: 0.6971, F1: 0.3070


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0444, ROUGE-L: 0.1917, CIDEr: 0.0239
Clinical Metrics - Precision: 0.2109, Recall: 0.6944, F1: 0.3235


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0433, ROUGE-L: 0.1891, CIDEr: 0.0220
Clinical Metrics - Precision: 0.2132, Recall: 0.6568, F1: 0.3219


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0481, ROUGE-L: 0.1953, CIDEr: 0.0212
Clinical Metrics - Precision: 0.1836, Recall: 0.6434, F1: 0.2857


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0430, ROUGE-L: 0.1822, CIDEr: 0.0112
Clinical Metrics - Precision: 0.1947, Recall: 0.7158, F1: 0.3062


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0425, ROUGE-L: 0.1852, CIDEr: 0.0106
Clinical Metrics - Precision: 0.2041, Recall: 0.7131, F1: 0.3174


Validation: |          | 0/? [00:00<?, ?it/s]

Text Generation Metrics: BLEU-4: 0.0376, ROUGE-L: 0.1789, CIDEr: 0.0098
Clinical Metrics - Precision: 0.2057, Recall: 0.6997, F1: 0.3179


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=15` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=15` reached.


Text Generation Metrics: BLEU-4: 0.0401, ROUGE-L: 0.1834, CIDEr: 0.0151
Clinical Metrics - Precision: 0.2095, Recall: 0.6139, F1: 0.3124
