# References

Hugging Face - Documentation [online]. Huggingface.co. Available from: https://huggingface.co/docs.

PyTorch Foundation [online]. PyTorch. Available from: https://pytorch.org/.

zhanyuwang, n.d. R2GenGPT: Radiology Report Generation with Frozen LLMs. Available from: https://github.com/wang-zhanyu/R2GenGPT

# <i> Huggingface login </i>

To work with gated repositories, we need to log in to the Hugging Face Hub.

In [None]:
from huggingface_hub import login, notebook_login
from google.colab import userdata

# notebook_login() does not take a token (its first parameter is a boolean
# flag), so the Colab secret has to be passed through login(token=...).
# The interactive widget is kept only as a fallback for an empty secret.
hf_token = userdata.get('HF_TOKEN')
if hf_token:
    login(token=hf_token)
else:
    notebook_login()



VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

# Importing required libraries

In [None]:
!pip install lightning

In [None]:
!pip install -U bitsandbytes

In [None]:
import os
import copy
import math
import json
import re
import yaml
import numpy as np
import functools
import pandas as pd
from PIL import Image
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim

from peft import get_peft_model, LoraConfig, TaskType
from transformers import (
    AutoImageProcessor,
    SwinModel,
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments
)

import lightning.pytorch as pl
from lightning.pytorch import LightningDataModule, seed_everything
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

# Data helper

In [None]:
class FieldParser:
    """Turn one annotation record into model inputs (cleaned report + image tensors).

    Args:
        cache_dir: Directory handed to ``AutoImageProcessor.from_pretrained``.
        image_root: Directory that the relative ``image_path`` entries of a
            record are joined onto.
        image_size: ``(width, height)`` every image is resampled to before it
            is given to the image processor.
    """

    # Same character class as the previous inline pattern, now written as a
    # raw string so that ``\[`` / ``\]`` are not invalid string escapes.
    # Note that ``:-\[`` is a *range* (':' through '['), so a literal '-' is
    # not removed; this is kept unchanged to preserve the existing tokenisation.
    _PUNCT_RE = re.compile(r'[.,?;*!%^&_+():-\[\]{}]')

    # Enumeration markers that are collapsed into a plain sentence break,
    # applied in this exact order.
    _LIST_MARKERS = ('. 2. ', '. 3. ', '. 4. ', '. 5. ',
                     ' 2. ', ' 3. ', ' 4. ', ' 5. ')

    def __init__(self, cache_dir='/content/huggingface',
                 image_root='/content/drive/MyDrive/iu_xray/images',
                 image_size=(224, 224)):
        super().__init__()
        self.image_root = image_root
        self.image_size = tuple(image_size)
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50", cache_dir=cache_dir)

    def _parse_image(self, img):
        """Resample ``img`` to ``self.image_size`` and return its pixel tensor.

        ``img`` may be a PIL image or a uint8 array. ``np.resize`` is
        deliberately not used: it never interpolates, it only truncates or
        tiles the flattened pixel buffer, which scrambles the picture.
        """
        if isinstance(img, Image.Image):
            pil_img = img
        else:
            pil_img = Image.fromarray(np.asarray(img, dtype=np.uint8))
        size = getattr(self, 'image_size', (224, 224))
        resized = np.asarray(pil_img.convert("RGB").resize(size), dtype=np.uint8)
        pixel_values = self.vit_feature_extractor(resized, return_tensors="pt").pixel_values
        return pixel_values[0]

    @classmethod
    def _split_sentences(cls, report):
        """Normalise dots and list numbering, then split on '. '."""
        text = report
        for _ in range(3):
            text = text.replace('..', '.')
        text = text.replace('1. ', '')
        for marker in cls._LIST_MARKERS:
            text = text.replace(marker, '. ')
        return text.strip().lower().split('. ')

    @classmethod
    def _clean_sentence(cls, sentence):
        """Drop quotes, slashes and punctuation from a single sentence."""
        stripped = (sentence.replace('"', '').replace('/', '')
                    .replace('\\', '').replace("'", '').strip().lower())
        return cls._PUNCT_RE.sub('', stripped)

    def clean_report(self, report):
        """Return the report as lower-cased sentences joined by ' . '.

        Sentences that are empty after cleaning are dropped. The previous
        filter compared the cleaned string against ``[]``, which is always
        unequal, so empty sentences leaked into the output as ' .  . '.
        """
        cleaned = (self._clean_sentence(sent) for sent in self._split_sentences(report))
        tokens = [sent for sent in cleaned if sent]
        return ' . '.join(tokens) + ' .'

    def parse(self, features):
        """Build ``{'id', 'input_text', 'image'}`` from one annotation record.

        Reads ``features['id']``, the optional ``features['report']`` and every
        relative path in ``features['image_path']``.
        """
        to_return = {'id': features['id']}
        to_return['input_text'] = self.clean_report(features.get("report", ""))
        # chest x-ray images
        images = []
        for image_path in features['image_path']:
            with Image.open(os.path.join(self.image_root, image_path)) as pil:
                # Force 3-channel RGB so grayscale / RGBA files share one layout.
                images.append(self._parse_image(pil.convert("RGB")))
        to_return["image"] = images
        return to_return

    def transform_with_parse(self, inputs):
        """Alias of :meth:`parse`, used by the dataset class."""
        return self.parse(inputs)

# Data Module

In [None]:
class ParseDataset(data.Dataset):
    """Map-style dataset over one split of the annotation JSON.

    Args:
        split: Top-level key of the annotation file to load ('train', 'val', 'test').
        annotation_path: Location of the annotation JSON file.
        parser: Object exposing ``transform_with_parse(record)``; a new
            ``FieldParser`` is created when omitted.
    """

    def __init__(self, split='train',
                 annotation_path='/content/drive/MyDrive/iu_xray/annotation.json',
                 parser=None):
        # Context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(annotation_path, 'r', encoding='utf-8') as handle:
            annotations = json.load(handle)
        self.meta = annotations[split]
        self.parser = FieldParser() if parser is None else parser

    def __len__(self):
        """Number of records in the selected split."""
        return len(self.meta)

    def __getitem__(self, index):
        """Parse and return the record at ``index``."""
        return self.parser.transform_with_parse(self.meta[index])

In [None]:
def create_datasets():
    """Build the train, validation and test datasets, in that order.

    Returns:
        A 3-tuple ``(train, val, test)`` of ``ParseDataset`` instances.
    """
    return tuple(ParseDataset(split_name) for split_name in ('train', 'val', 'test'))

In [None]:
class DataModule(LightningDataModule):
    """Lightning data module wrapping the three dataset splits.

    Args:
        train_batch_size: Batch size of the training loader.
        eval_batch_size: Batch size of the validation and test loaders.
        num_workers: Worker processes per loader.
        prefetch_factor: Batches prefetched per worker; only forwarded when
            ``num_workers > 0`` because ``DataLoader`` rejects it otherwise.
        shuffle_train: Reshuffle the training split every epoch. The previous
            loader iterated the training data in file order on every epoch.
    """

    def __init__(self, train_batch_size=4, eval_batch_size=8, num_workers=4,
                 prefetch_factor=4, shuffle_train=True):
        super().__init__()
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.shuffle_train = shuffle_train
        self.dataset = {}

    def setup(self, stage: str = None):
        """Create the datasets once; later calls (one per stage) reuse them."""
        if self.dataset:
            return
        train_dataset, dev_dataset, test_dataset = create_datasets()
        self.dataset = {
            "train": train_dataset, "validation": dev_dataset, "test": test_dataset
        }

    def _make_loader(self, split, batch_size, *, shuffle=False, drop_last=False, pin_memory=True):
        """Build a ``DataLoader`` for ``split`` with the shared worker settings."""
        loader_kwargs = {
            "batch_size": batch_size,
            "shuffle": shuffle,
            "drop_last": drop_last,
            "pin_memory": pin_memory,
            "num_workers": self.num_workers,
        }
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return DataLoader(self.dataset[split], **loader_kwargs)

    def train_dataloader(self):
        """Training loader: shuffled (by default), incomplete last batch dropped."""
        return self._make_loader("train", self.train_batch_size,
                                 shuffle=self.shuffle_train, drop_last=True)

    def val_dataloader(self):
        """Validation loader: fixed order, keeps the last partial batch."""
        return self._make_loader("validation", self.eval_batch_size)

    def test_dataloader(self):
        """Test loader: fixed order, no pinned memory (as before)."""
        return self._make_loader("test", self.eval_batch_size, pin_memory=False)

# Callbacks

In [None]:
def add_callbacks():
    """Create the trainer callbacks and loggers, all rooted at ``savedmodels``.

    Returns:
        dict with two keys: ``"callbacks"`` (checkpointing + LR monitor) and
        ``"loggers"`` (CSV logger first, then TensorBoard).
    """
    root = "savedmodels"
    os.makedirs(root, exist_ok=True)
    checkpoint_dir = os.path.join(root, "checkpoints")

    # Checkpoint every 10 training steps, keeping a single file on disk.
    checkpointer = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch}-{step}",
        save_top_k=1,
        every_n_train_steps=10,
        save_last=False,
        save_weights_only=False,
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')

    tensorboard = pl_loggers.TensorBoardLogger(
        save_dir=os.path.join(root, "logs"), name="tensorboard"
    )
    csv_log = CSVLogger(save_dir=os.path.join(root, "logs"), name="csvlog")

    return {
        "callbacks": [checkpointer, lr_monitor],
        "loggers": [csv_log, tensorboard],
    }

# Optimizers

In [None]:
def lr_lambda(current_step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: linear warm-up followed by linear decay to 0.

    Args:
        current_step: Index of the optimizer step being scheduled.
        num_warmup_steps: Steps over which the factor rises from 0 to 1.
        num_training_steps: Step at which the factor reaches 0.

    Returns:
        A float factor that is never negative.
    """
    if current_step >= num_warmup_steps:
        steps_left = float(num_training_steps - current_step)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, steps_left / decay_span)
    # Still warming up; max(1, ...) guards a zero-length warm-up.
    return float(current_step) / float(max(1, num_warmup_steps))

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """Wrap ``optimizer`` in a ``LambdaLR`` driven by ``lr_lambda``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Forwarded to ``lr_lambda``.
        num_training_steps: Forwarded to ``lr_lambda``.
        last_epoch: Passed through to ``LambdaLR``.

    Returns:
        The configured ``LambdaLR`` scheduler.
    """
    # The step counts are bound with functools.partial rather than a closure.
    schedule_fn = functools.partial(
        lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    return LambdaLR(optimizer, schedule_fn, last_epoch)


def config_optimizer(parameters, init_lr, warmup_steps, max_steps, name='lr'):
    """Create an AdamW optimizer and a per-step warm-up/linear-decay schedule.

    Args:
        parameters: Parameters (or parameter groups) to optimise.
        init_lr: Peak learning rate.
        warmup_steps: Number of warm-up steps.
        max_steps: Total number of training steps.
        name: Name stored in the scheduler config dict.

    Returns:
        ``(optimizer, scheduler_config)`` where ``scheduler_config`` is a dict
        with the keys ``scheduler``, ``name``, ``interval`` and ``frequency``.
    """
    # `optim` is torch.optim here, and torch.optim.AdamW has no `correct_bias`
    # keyword, so passing it raised TypeError before any training could start.
    # torch's AdamW always applies bias correction.
    optimizer = optim.AdamW(parameters, lr=init_lr, eps=1e-8)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
    )
    # Step the schedule after every optimizer step rather than once per epoch.
    scheduler = {'scheduler': scheduler, 'name': name, 'interval': 'step', 'frequency': 1}

    return optimizer, scheduler

# BLEU

In [None]:
def precook(s, n=4, out=False):
    """Count every 1..n-gram of a whitespace-tokenised sentence.

    Args:
        s: Sentence to tokenise with ``str.split()``.
        n: Largest n-gram order to count.
        out: Accepted for call compatibility; not used.

    Returns:
        ``(token_count, counts)`` where ``counts`` is a ``defaultdict(int)``
        keyed by n-gram tuples.
    """
    tokens = s.split()
    total = len(tokens)
    ngram_counts = defaultdict(int)
    for order in range(1, n + 1):
        for start in range(total - order + 1):
            ngram_counts[tuple(tokens[start:start + order])] += 1
    return (total, ngram_counts)

def cook_refs(refs, eff=None, n=4):
    """Summarise the reference sentences of one segment for BLEU.

    Args:
        refs: Iterable of reference sentences.
        eff: ``"shortest"`` or ``"average"`` to collapse the reference lengths
            into a single number; any other value keeps the list of lengths.
        n: Largest n-gram order.

    Returns:
        ``(reflen, maxcounts)`` where ``maxcounts`` maps each n-gram to its
        highest count in any single reference.
    """
    lengths = []
    clipped = {}
    for sentence in refs:
        length, ngrams = precook(sentence, n)
        lengths.append(length)
        for gram, freq in ngrams.items():
            if freq > clipped.get(gram, 0):
                clipped[gram] = freq

    # Effective reference length.
    if eff == "shortest":
        effective = min(lengths)
    elif eff == "average":
        effective = float(sum(lengths)) / len(lengths)
    else:
        effective = lengths

    return (effective, clipped)

def cook_test(test, crefs, eff=None, n=4):
    """Summarise one hypothesis sentence against its cooked references.

    Args:
        test: Hypothesis sentence.
        crefs: ``(reflen, maxcounts)`` pair as produced by ``cook_refs``.
        eff: ``"closest"`` picks the reference length nearest the hypothesis
            length; any other value passes the reference length through.
        n: Largest n-gram order.

    Returns:
        dict with the keys ``reflen``, ``testlen``, ``guess`` (candidate n-gram
        slots per order) and ``correct`` (clipped matches per order).
    """
    ref_lengths, ref_max = crefs[0], crefs[1]
    hyp_len, hyp_counts = precook(test, n, True)

    if eff == "closest":
        effective = min((abs(l - hyp_len), l) for l in ref_lengths)[1]
    else:  # "average", "shortest" or None
        effective = ref_lengths

    matched = [0] * n
    for gram, freq in hyp_counts.items():
        matched[len(gram) - 1] += min(ref_max.get(gram, 0), freq)

    return {
        "reflen": effective,
        "testlen": hyp_len,
        "guess": [max(0, hyp_len - order + 1) for order in range(1, n + 1)],
        "correct": matched,
    }

class BleuScorer(object):
    """Accumulates hypothesis/reference pairs and computes BLEU-1..BLEU-n.

    Pairs are added with ``scorer += (hypothesis, references)``.
    ``compute_score`` returns ``(corpus_scores, per_segment_scores)``: a list of
    ``n`` corpus-level values and a list of ``n`` lists holding one value per
    added segment.
    """

    __slots__ = ("n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen",
                 "special_reflen", "_bleu_list")
    # special_reflen is used in oracle (proportional effective ref len for a node).
    # _bleu_list caches the per-segment scores next to _score.

    def copy(self):
        """Return a scorer sharing the cooked data, with the score cache cleared."""
        new = BleuScorer(n=self.n, special_reflen=self.special_reflen)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        """Create a scorer, optionally seeded with one hypothesis/reference pair."""
        self.n = n
        self.crefs = []
        self.ctest = []
        self._bleu_list = None
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        """Cook and store one segment; does nothing when ``refs`` is None."""
        if refs is not None:
            # Forward self.n so the cooked statistics always have n entries;
            # relying on the helpers' default of 4 broke scorers with n > 4.
            self.crefs.append(cook_refs(refs, n=self.n))
            if test is not None:
                self.ctest.append(cook_test(test, self.crefs[-1], n=self.n))
            else:
                self.ctest.append(None)  # lens of crefs and ctest have to match

        self._score = None  # need to recompute

    def ratio(self, option=None):
        """Hypothesis/reference length ratio of the whole corpus."""
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        # NOTE(review): `fscore` is not defined anywhere in this class, so this
        # method (and score_ratio_str) raises AttributeError when called.
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        """Summed effective reference length."""
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        """Summed hypothesis length."""
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        """Replace all hypotheses, keeping the stored references."""
        if isinstance(new_test, str):
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = [cook_test(t, rs, n=self.n) for t, rs in zip(new_test, self.crefs)]
        self._score = None

        return self

    def rescore(self, new_test):
        """Replace the hypotheses and return the freshly computed score."""
        return self.retest(new_test).compute_score()

    def size(self):
        """Number of stored segments."""
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        """Add a ``(hypothesis, references)`` tuple or merge another scorer."""
        if isinstance(other, tuple):
            # avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None  # need to recompute

        return self

    def compatible(self, other):
        """True when ``other`` is a BleuScorer with the same n-gram order."""
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        """Effective reference length of the first stored segment."""
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):
        """Collapse a list of reference lengths according to ``option``."""
        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens)) / len(reflens)
        elif option == "closest":
            reflen = min((abs(l - testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        """Drop the cached score and compute it again."""
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        """Compute BLEU for everything added so far.

        Args:
            option: Reference-length policy ("shortest", "average" or
                "closest"). When None, "average" is used for a single segment
                and "closest" otherwise.
            verbose: Values above 0 print the corpus totals; values above 1
                also print each segment's statistics.

        Returns:
            ``(corpus_scores, per_segment_scores)``. The cached path returns the
            same tuple shape as the first call (it used to return only the
            corpus list). The cache does not depend on ``option``.
        """
        if self._score is not None:
            return self._score, self._bleu_list

        n = self.n
        small = 1e-9
        tiny = 1e-15  # so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None:  # need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess', 'correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        / (float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1. / (k + 1)))
            ratio = (testlen + tiny) / (reflen + small)  # N.B.: avoid zero division
            if ratio < 1:
                # brevity penalty for hypotheses shorter than the reference
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1 / ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1. / (k + 1)))
        ratio = (self._testlen + tiny) / (self._reflen + small)  # N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1 / ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        # Store the corpus ratio: ratio() reads it and it was never assigned.
        self._ratio = ratio
        self._score = bleus
        self._bleu_list = bleu_list
        return self._score, bleu_list

In [None]:
class Bleu:
    """Front end that feeds per-image hypotheses and references to BleuScorer."""

    def __init__(self, n=4):
        # Highest n-gram order to score (BLEU-1 .. BLEU-n).
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res, verbose=0):
        """Score ``res`` against ``gts``.

        Args:
            gts: Maps each image id to a non-empty list of reference strings.
            res: Maps the same ids to a one-element list holding the hypothesis.
            verbose: Forwarded to ``BleuScorer.compute_score``.

        Returns:
            ``(score, scores)`` exactly as returned by the scorer, using the
            'closest' reference-length option.
        """
        assert(gts.keys() == res.keys())

        scorer = BleuScorer(n=self._n)
        for image_id in gts.keys():
            hypothesis = res[image_id]
            references = gts[image_id]

            # Sanity check.
            assert(type(hypothesis) is list)
            assert(len(hypothesis) == 1)
            assert(type(references) is list)
            assert(len(references) >= 1)

            scorer += (hypothesis[0], references)

        score, scores = scorer.compute_score(option='closest', verbose=verbose)
        return score, scores

    def method(self):
        """Name of the metric."""
        return "Bleu"

# Rouge

In [None]:
def my_lcs(string, sub):
    """Length of the longest common subsequence of two sequences.

    Args:
        string: First sequence (string or list of tokens).
        sub: Second sequence; the two are swapped internally so that the
            longer one indexes the table rows.

    Returns:
        The LCS length as an int.
    """
    if len(string) < len(sub):
        sub, string = string, sub

    rows, cols = len(string), len(sub)
    # table[i][j] = LCS length of string[:i] and sub[:j]
    table = [[0] * (cols + 1) for _ in range(rows + 1)]

    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            if string[i - 1] == sub[j - 1]:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])

    return table[rows][cols]

class Rouge():
    """ROUGE-L (longest-common-subsequence F-measure) over a set of candidates."""

    def __init__(self):
        # Weight of recall relative to precision in the F-measure.
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """ROUGE-L of one candidate against its references.

        Args:
            candidate: One-element list holding the candidate sentence.
            refs: List of reference sentences.

        Returns:
            The F-measure built from the best precision and the best recall
            over all references, or 0.0 when either of them is zero.
        """
        # split into tokens
        token_c = candidate[0].split(" ")

        prec = []
        rec = []
        for reference in refs:
            token_r = reference.split(" ")
            # compute the longest common subsequence
            lcs = my_lcs(token_r, token_c)
            prec.append(lcs / float(len(token_c)))
            rec.append(lcs / float(len(token_r)))

        prec_max = max(prec)
        rec_max = max(rec)

        if prec_max != 0 and rec_max != 0:
            score = ((1 + self.beta**2) * prec_max * rec_max) / float(rec_max + self.beta**2 * prec_max)
        else:
            score = 0.0
        return score

    def compute_score(self, gts, res):
        """Score every candidate in ``res`` against ``gts``.

        Args:
            gts: Maps each image id to a non-empty list of reference strings.
            res: Maps the same ids to a one-element list holding the candidate.

        Returns:
            ``(mean_score, scores)`` with ``scores`` as a NumPy array in the
            iteration order of ``gts``.
        """
        assert(gts.keys() == res.keys())

        score = []
        for image_id in gts.keys():
            hypo = res[image_id]
            ref = gts[image_id]

            # Sanity check *before* scoring: previously calc_score ran first,
            # so a bare string hypothesis was silently scored on its first
            # character and an empty reference list failed inside max().
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

            score.append(self.calc_score(hypo, ref))

        average_score = np.mean(np.array(score))
        return average_score, np.array(score)

    def method(self):
        """Name of the metric."""
        return "Rouge"

# CIDER

In [None]:
class CiderScorer(object):
    """CIDEr scorer.

    Segments are added with ``scorer += (hypothesis, references)``;
    ``compute_score`` returns ``(mean, per_segment_array)``.
    """

    def copy(self):
        ''' copy the refs.'''
        # Carry sigma over as well; it used to fall back to the default.
        new = CiderScorer(n=self.n, sigma=self.sigma)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        ''' singular instance '''
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        self.cook_append(test, refs)
        self.ref_len = None

    def precook(self, s, n=4, out=False):
        """Count all 1..n-grams of the whitespace-tokenised sentence ``s``.

        ``out`` is accepted for call compatibility and not used.
        """
        words = s.split()
        counts = defaultdict(int)
        for k in range(1, n + 1):
            for i in range(len(words) - k + 1):
                ngram = tuple(words[i:i + k])
                counts[ngram] += 1
        return counts

    def cook_refs(self, refs, n=4):
        """N-gram counts for each reference sentence."""
        return [self.precook(ref, n) for ref in refs]

    def cook_test(self, test, n=4):
        """N-gram counts for one hypothesis sentence."""
        return self.precook(test, n, True)

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''
        if refs is not None:
            # Cook with self.n: the helper default of 4 produced n-grams that
            # counts2vec cannot index when the scorer was built with n < 4.
            self.crefs.append(self.cook_refs(refs, n=self.n))
            if test is not None:
                self.ctest.append(self.cook_test(test, n=self.n))
            else:
                self.ctest.append(None)  # lens of crefs and ctest have to match

    def size(self):
        """Number of stored segments."""
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''
        if type(other) is tuple:
            # avoid creating new CiderScorer instances
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)

        return self

    def compute_doc_freq(self):
        """Count, per n-gram, the number of segments whose references contain it."""
        # Start from scratch so that calling compute_score() more than once
        # does not keep adding to the counts and change the scores.
        self.document_frequency = defaultdict(float)
        for refs in self.crefs:
            # refs, k ref captions of one image
            for ngram in set(ngram for ref in refs for ngram in ref):
                self.document_frequency[ngram] += 1

    def compute_cider(self):
        """Return the list of per-segment CIDEr scores."""
        def counts2vec(cnts):
            """Map n-gram counts to per-order tf-idf vectors, their norms and a length."""
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                # give word count 1 if it doesn't appear in reference corpus
                df = np.log(max(1.0, self.document_frequency[ngram]))
                # index of this n-gram order (0 for unigrams)
                idx = len(ngram) - 1
                # tf (term_freq) * idf (precomputed idf) for n-grams
                vec[idx][ngram] = float(term_freq) * (self.ref_len - df)
                # squared norm, used later for the cosine similarity
                norm[idx] += pow(vec[idx][ngram], 2)

                # idx == 1 selects 2-grams, so `length` is the bigram count
                if idx == 1:
                    length += term_freq
            norm = [np.sqrt(value) for value in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            """Clipped cosine similarity per n-gram order, with a length penalty."""
            delta = float(length_hyp - length_ref)
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                for (ngram, count) in vec_hyp[n].items():
                    # vrama91 : added clipping
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n] * norm_ref[n])

                assert(not math.isnan(val[n]))
                # vrama91: added a length based gaussian penalty
                val[n] *= np.e**(-(delta**2) / (2 * self.sigma**2))
            return val

        # compute log reference length
        self.ref_len = np.log(float(len(self.crefs)))
        if len(self.crefs) == 1:
            # log(1) would be 0 and zero out every tf-idf weight
            self.ref_len = 1
        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            # compute vector for test captions
            vec, norm, length = counts2vec(test)
            # compute vector for ref captions
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # change by vrama91 - mean of ngram scores, instead of sum
            score_avg = np.mean(score)
            # divide by number of references
            score_avg /= len(refs)
            # multiply score by 10
            score_avg *= 10.0
            # append score of an image to the score list
            scores.append(score_avg)
        return scores

    def compute_score(self, option=None, verbose=0):
        """Compute CIDEr over all stored segments.

        ``option`` and ``verbose`` are accepted for interface parity with the
        BLEU scorer and are not used.

        Returns:
            ``(mean_score, scores)`` with ``scores`` as a NumPy array.
        """
        # compute idf
        self.compute_doc_freq()
        # assert to check document frequency
        assert(len(self.ctest) >= max(self.document_frequency.values()))
        # compute cider score
        score = self.compute_cider()
        return np.mean(np.array(score)), np.array(score)

In [None]:
class Cider:
    """Front end that feeds per-image hypotheses and references to CiderScorer."""

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # ``test`` and ``refs`` are accepted but not stored.
        # n: highest n-gram order; sigma: std-dev of the gaussian length penalty.
        self._n, self._sigma = n, sigma

    def compute_score(self, gts, res):
        """Score ``res`` against ``gts``.

        Args:
            gts: Maps each image id to a non-empty list of reference strings.
            res: Maps the same ids to a one-element list holding the hypothesis.

        Returns:
            ``(score, scores)`` as returned by ``CiderScorer.compute_score``.
        """
        assert(gts.keys() == res.keys())

        scorer = CiderScorer(n=self._n, sigma=self._sigma)
        for image_id in gts.keys():
            hypothesis = res[image_id]
            references = gts[image_id]

            # Sanity check.
            assert(type(hypothesis) is list)
            assert(len(hypothesis) == 1)
            assert(type(references) is list)
            assert(len(references) > 0)

            scorer += (hypothesis[0], references)

        score, scores = scorer.compute_score()
        return score, scores

    def method(self):
        """Name of the metric."""
        return "CIDEr"

# Ablation Study

In [None]:
class AblationR2GenGPT(pl.LightningModule):
    def __init__(self, config):
        """Build the model from a configuration mapping.

        Keys read here: ``cache_dir`` and ``model_name`` (required), plus the
        optional ``debug_mode``, ``end_sym`` and ``prompt``. The ``_setup_*``
        helpers read the ``vision_encoder``, ``llm`` and ``projection``
        sections of the same mapping.
        """
        super().__init__()
        self.config = config
        self.cache_dir = config['cache_dir']
        self.model_name = config['model_name']

        # Debug mode for troubleshooting
        self.debug_mode = config.get('debug_mode', False)
        if self.debug_mode:
            print("Debug mode enabled - will print extra diagnostic information")

        # Set up components based on config.
        # Order matters: _setup_projection reads visual_encoder.num_features
        # and llama_model.config.hidden_size, so it has to run last.
        self._setup_vision_encoder()
        self._setup_llm()
        self._setup_projection()

        # Store params
        self.end_sym = config.get('end_sym', '</s>')
        self.prompt = config.get('prompt', 'Generate a comprehensive and detailed diagnosis report for this chest xray image.')
        # Buffers for step outputs and the validation score; they start empty / at zero.
        self.val_step_outputs = []
        self.test_step_outputs = []
        self.val_score = 0.0

        # Save config for reference
        self.save_hyperparameters()

    def _setup_vision_encoder(self):
        """Load the vision encoder described by ``config['vision_encoder']``.

        Keys read: ``model_name`` (required), ``freeze`` (default True) and
        ``gradient_checkpointing`` (default True).

        Sets ``self.visual_encoder`` and ``self.visual_encoder.num_features``
        (the feature width the projection layer is built from).

        Raises:
            ValueError: If the feature width of a non-Swin encoder cannot be
                read from its config.
        """
        vision_config = self.config['vision_encoder']
        vision_model_name = vision_config['model_name']
        freeze_encoder = vision_config.get('freeze', True)

        print(f'Loading vision encoder: {vision_model_name}')

        # Swin checkpoints are loaded through SwinModel, everything else
        # through AutoModel. The dataset handles preprocessing, so no
        # separate processor is loaded here.
        is_swin = "swin" in vision_model_name.lower()
        loader = SwinModel if is_swin else AutoModel
        self.visual_encoder = loader.from_pretrained(vision_model_name, cache_dir=self.cache_dir)

        # Record the feature dimension for *every* encoder. Previously only
        # the Swin branch set it, so any other model failed on the print below
        # and again in _setup_projection.
        if hasattr(self.visual_encoder.config, 'hidden_size'):
            self.visual_encoder.num_features = self.visual_encoder.config.hidden_size
        elif is_swin:
            # Fallback table, keyed by the variant name found in the model id.
            swin_dims = {
                'tiny': 768,
                'small': 768,
                'base': 1024,
                'large': 1536
            }
            variant = next((v for v in swin_dims if v in vision_model_name.lower()), 'base')
            self.visual_encoder.num_features = swin_dims[variant]
            print(f"Detected Swin variant: {variant}, using dimension: {self.visual_encoder.num_features}")
        else:
            raise ValueError(
                f"Cannot determine the feature dimension of vision encoder "
                f"'{vision_model_name}': its config has no 'hidden_size'"
            )

        # Freeze if needed
        if freeze_encoder:
            for param in self.visual_encoder.parameters():
                param.requires_grad = False

        # Enable gradient checkpointing for memory efficiency
        if vision_config.get('gradient_checkpointing', True) and hasattr(self.visual_encoder, 'gradient_checkpointing_enable'):
            self.visual_encoder.gradient_checkpointing_enable()

        print('Loading vision encoder is Done')
        print(f'Visual encoder output dimension: {self.visual_encoder.num_features}')

    def _setup_llm(self):
        """Load the causal LM and tokenizer described by ``config['llm']``.

        Keys read: ``lora`` (optional mapping with ``r``, ``alpha``,
        ``dropout``, ``target_modules``), ``tokenizer`` (defaults to the model
        name), ``custom_eos_token_id`` and ``gradient_checkpointing``
        (default True).

        Sets ``self.llama_model``, ``self.llama_tokenizer`` and
        ``self.embed_tokens``.
        """
        llm_settings = self.config['llm']
        lora_settings = llm_settings.get('lora', None)

        print('Loading LLAMA')
        # NOTE(review): trust_remote_code=True lets the checkpoint repository
        # run its own Python code at load time; confirm the source is trusted.
        load_kwargs = dict(
            torch_dtype=torch.float16,
            device_map="cuda",
            load_in_8bit=True,
            cache_dir=self.cache_dir,
            trust_remote_code=True,
        )
        self.llama_model = AutoModelForCausalLM.from_pretrained(self.model_name, **load_kwargs)

        # Tokenizer (may come from a different checkpoint than the model).
        tokenizer = AutoTokenizer.from_pretrained(
            llm_settings.get('tokenizer', self.model_name),
            use_fast=True,
            cache_dir=self.cache_dir,
        )
        self.llama_tokenizer = tokenizer

        # Special tokens: fall back to EOS when no BOS exists, apply the
        # optional EOS override, and pad with token id 0.
        if tokenizer.bos_token_id is None:
            tokenizer.bos_token_id = tokenizer.eos_token_id
        custom_eos = llm_settings.get('custom_eos_token_id')
        if custom_eos is not None:
            tokenizer.eos_token_id = custom_eos
        tokenizer.pad_token_id = 0
        self.llama_model.generation_config.pad_token_id = tokenizer.pad_token_id

        # Keep a handle on the input embedding layer.
        self.embed_tokens = self.llama_model.get_input_embeddings()

        # Wrap the model with LoRA adapters when a LoRA section is configured.
        if lora_settings:
            adapter_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                r=lora_settings.get('r', 16),
                lora_alpha=lora_settings.get('alpha', 32),
                lora_dropout=lora_settings.get('dropout', 0.1),
                target_modules=lora_settings.get('target_modules', ["transformer.h.0.attn.c_attn", "transformer.h.0.attn.c_proj"])
            )
            self.llama_model = get_peft_model(self.llama_model, adapter_config)
            self.llama_model.print_trainable_parameters()

        # Gradient checkpointing is on unless the config disables it.
        if llm_settings.get('gradient_checkpointing', True):
            self.llama_model.gradient_checkpointing_enable()

        print('Loading LLM Done')

    def _setup_projection(self):
        """Create the vision-to-LLM projection and its normalisation layer.

        ``config['projection']['type']`` selects one of 'linear_layernorm'
        (default), 'mlp_layernorm', 'linear', 'linear_batchnorm' or
        'identity'; ``hidden_dim`` sets the MLP width (default: twice the LLM
        width).

        Sets ``self.llama_proj`` and ``self.layer_norm``.

        Raises:
            ValueError: For an unknown type, or for 'identity' when the two
                widths differ.
        """
        proj_config = self.config['projection']
        proj_type = proj_config.get('type', 'linear_layernorm')

        visual_dim = self.visual_encoder.num_features
        llm_dim = self.llama_model.config.hidden_size

        known_types = ('linear_layernorm', 'mlp_layernorm', 'linear', 'linear_batchnorm', 'identity')
        if proj_type not in known_types:
            raise ValueError(f"Unknown projection type: {proj_type}")

        # Projection module first, so it is built before the normaliser.
        if proj_type == 'identity':
            if visual_dim != llm_dim:
                raise ValueError(f"Cannot use identity projection when dimensions don't match: {visual_dim} vs {llm_dim}")
            projector = nn.Identity()
        elif proj_type == 'mlp_layernorm':
            width = proj_config.get('hidden_dim', llm_dim * 2)
            projector = nn.Sequential(
                nn.Linear(visual_dim, width),
                nn.GELU(),
                nn.Linear(width, llm_dim),
            )
        else:
            projector = nn.Linear(visual_dim, llm_dim)

        # Normalisation applied after the projection.
        if proj_type in ('linear_layernorm', 'mlp_layernorm'):
            normalizer = nn.LayerNorm(llm_dim)
        elif proj_type == 'linear_batchnorm':
            normalizer = nn.BatchNorm1d(llm_dim)
        else:
            normalizer = nn.Identity()

        self.llama_proj = projector
        self.layer_norm = normalizer

    def encode_img(self, images):
        """Encode image views with the vision encoder and project them for the LLM.

        Each element of ``images`` is run through ``self.visual_encoder`` under
        ``torch.no_grad()``; a 3-D tensor first gets a leading batch dimension.
        When several elements are given, embeddings with more than two
        dimensions are averaged element-wise across the elements, while flat
        embeddings are concatenated along dim 0. The result is passed through
        ``self.llama_proj`` and then ``self.layer_norm``.

        Args:
            images: Non-empty sequence of image tensors, already preprocessed
                by the dataset. Assumes each element is one view of the whole
                batch, e.g. (B, C, H, W) -- TODO confirm against the dataset.

        Returns:
            Tuple ``(inputs_llama, atts_llama)``: the projected embeddings and
            an all-ones ``torch.long`` mask shaped like the embeddings without
            their last dimension.

        Raises:
            ValueError: If ``images`` is empty, or the encoder returns an
                output type this method does not know how to read.
        """
        # Fail early with a clear message instead of an IndexError further down.
        if len(images) == 0:
            raise ValueError("encode_img received no images to encode")

        image_embeds = []
        for image in images:
            device = image.device

            # Add batch dimension if not present
            if image.dim() == 3:
                image = image.unsqueeze(0)

            try:
                # The encoder is always run without gradient tracking here.
                with torch.no_grad():
                    outputs = self.visual_encoder(image)

                # Handle different return types from different vision models
                if isinstance(outputs, dict):
                    if 'last_hidden_state' in outputs:
                        image_embed = outputs['last_hidden_state'].to(device)
                    elif 'pooler_output' in outputs:
                        # Insert a length-1 axis at dim 1 so the pooled vector
                        # looks like a one-token sequence.
                        pooler = outputs['pooler_output'].to(device)
                        image_embed = pooler.unsqueeze(1)
                    else:
                        # Fall back to the first value in the dict
                        key = list(outputs.keys())[0]
                        image_embed = outputs[key].to(device)
                elif hasattr(outputs, 'last_hidden_state'):
                    image_embed = outputs.last_hidden_state.to(device)
                elif isinstance(outputs, tuple):
                    # Presumably (last_hidden_state, pooled_output) -- verify per encoder
                    image_embed = outputs[0].to(device)
                elif isinstance(outputs, torch.Tensor):
                    image_embed = outputs.to(device)
                else:
                    # As a fallback, log the output type for debugging
                    print(f"Unexpected output type from vision encoder: {type(outputs)}")
                    if hasattr(outputs, '_fields'):  # namedtuple
                        print(f"Fields: {outputs._fields}")
                    raise ValueError(f"Unsupported output format: {type(outputs)}")

            except Exception as e:
                # Print context, then let the original exception propagate.
                print(f"Error processing image with shape {image.shape}")
                print(f"Error details: {str(e)}")
                raise

            image_embeds.append(image_embed)

        if len(image_embeds) > 1:
            if image_embeds[0].dim() > 2:
                # Sequence-shaped embeddings: mean pooling across the elements
                image_embeds = torch.stack(image_embeds).mean(0)
            else:
                # Flat embeddings: concatenate along the first dimension
                image_embeds = torch.cat(image_embeds, dim=0)
        else:
            image_embeds = image_embeds[0]

        # Project to LLM dimension
        inputs_llama = self.llama_proj(image_embeds)

        if isinstance(self.layer_norm, nn.BatchNorm1d):
            # BatchNorm1d normalises dim 1, so move the feature axis there and back.
            inputs_llama = inputs_llama.transpose(1, 2)
            inputs_llama = self.layer_norm(inputs_llama)
            inputs_llama = inputs_llama.transpose(1, 2)
        else:
            inputs_llama = self.layer_norm(inputs_llama)

        # Build the mask on the embeddings' own device rather than relying on
        # the loop variable that leaked out of the for-loop above.
        atts_llama = torch.ones(
            inputs_llama.size()[:-1], dtype=torch.long, device=inputs_llama.device)
        return inputs_llama, atts_llama


    def prompt_wrap(self, img_embeds, atts_img):
        """Surround the image embeddings with embedded prompt text.

        The layout is chosen by ``config['feature_position']``:

        * ``before_prompt``: text before ``<ImageHere>``, image, text after.
        * ``after_prompt``: whole prompt (placeholder removed), then image.
        * ``beginning_and_end``: image, text before, image again, text after.

        Args:
            img_embeds: Image embeddings; dim 0 is used as the batch size and
                the prompt embeddings are concatenated along dim 1.
            atts_img: Attention mask for the image embeddings; only its first
                column is used, expanded to the wrapped length.

        Returns:
            Tuple ``(wrapped_img_embeds, wrapped_atts_img)``.

        Raises:
            ValueError: If the configured feature position is not recognised.
        """
        position = self.config.get('feature_position', 'before_prompt')
        template = self.config.get('prompt_template', f'Human: <Img><ImageHere></Img> {self.prompt} \nAssistant:')

        n_samples = img_embeds.shape[0]

        def embed_text(text):
            # Tokenise one prompt fragment and repeat its embedding per sample.
            tokens = self.llama_tokenizer(
                text, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
            return self.embed_tokens(tokens.input_ids).expand(n_samples, -1, -1)

        if position == 'after_prompt':
            pieces = [embed_text(template.replace('<ImageHere>', '')), img_embeds]
        elif position in ('before_prompt', 'beginning_and_end'):
            head_text, tail_text = template.split('<ImageHere>')
            head_embeds = embed_text(head_text)
            tail_embeds = embed_text(tail_text)
            if position == 'before_prompt':
                pieces = [head_embeds, img_embeds, tail_embeds]
            else:
                # Use image embeddings at both beginning and end
                pieces = [img_embeds, head_embeds, img_embeds, tail_embeds]
        else:
            raise ValueError(f"Unknown feature position: {position}")

        wrapped_img_embeds = torch.cat(pieces, dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img


    def debug_swin_output(self, image):
        """
        Print a structural summary of the vision encoder's output for ``image``.

        A 3-D ``image`` gets a leading batch dimension before the encoder is
        called under ``torch.no_grad()``. Returns the raw encoder output, or
        ``None`` when anything inside the inspection raises (the traceback is
        printed instead of propagating).
        """
        print("\n=== DEBUG SWIN OUTPUT ===")
        print(f"Input image shape: {image.shape}")

        def report(label, value):
            # Tensors are described by shape, anything else by type.
            if isinstance(value, torch.Tensor):
                print(f"{label}: shape={value.shape}")
            else:
                print(f"{label}: type={type(value)}")

        try:
            with torch.no_grad():
                batched = image.unsqueeze(0) if image.dim() == 3 else image
                outputs = self.visual_encoder(batched)

            print(f"Output type: {type(outputs)}")

            if isinstance(outputs, tuple):
                # Tuple: describe every element
                print(f"Tuple length: {len(outputs)}")
                for position, element in enumerate(outputs):
                    if isinstance(element, torch.Tensor):
                        print(f"  Item {position}: type={type(element)}, shape={element.shape}")
                    else:
                        print(f"  Item {position}: type={type(element)}")
            elif isinstance(outputs, dict):
                # Mapping: list the keys, then describe each value
                print(f"Dictionary keys: {list(outputs.keys())}")
                for name, entry in outputs.items():
                    report(f"  Key '{name}'", entry)
            elif hasattr(outputs, 'last_hidden_state'):
                # Attribute-style output object
                print("BaseModelOutput with last_hidden_state")
                print(f"  last_hidden_state shape: {outputs.last_hidden_state.shape}")
                pooled = getattr(outputs, 'pooler_output', None)
                if pooled is not None:
                    print(f"  pooler_output shape: {pooled.shape}")
            elif isinstance(outputs, torch.Tensor):
                print(f"Tensor shape: {outputs.shape}")

            # Independent of the branch above: named tuples also expose _fields
            if hasattr(outputs, '_fields'):
                print(f"Named tuple fields: {outputs._fields}")
                for field_name in outputs._fields:
                    report(f"  Field '{field_name}'", getattr(outputs, field_name))

            print("=== END DEBUG ===\n")
            return outputs

        except Exception as e:
            print(f"Error in debug_swin_output: {str(e)}")
            import traceback
            traceback.print_exc()
            return None


    def forward(self, samples):
        """Compute the language-modelling loss for one batch.

        The input sequence is ``[BOS, wrapped image/prompt embeddings, report
        tokens]``. Labels are ``-100`` for the BOS and image/prompt positions
        and for every padded report position, so only real report tokens
        contribute to the loss.

        Args:
            samples: Batch dict with ``"image"`` (a tensor or a list of
                tensors) and ``"input_text"`` (iterable of report strings).

        Returns:
            ``{"loss": loss}`` where ``loss`` is taken from the language
            model's output.
        """
        image = samples["image"]

        # Optional structural dump of the encoder output for the first image
        if getattr(self, 'debug_mode', False):
            self.debug_swin_output(image[0] if isinstance(image, list) else image[0:1])

        try:
            img_embeds, atts_img = self.encode_img(image)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)
        except Exception as e:
            print(f"Error in image encoding: {str(e)}")
            # Log image shape for debugging
            if isinstance(image, list):
                print(f"Image list length: {len(image)}")
                for i, img in enumerate(image):
                    print(f"Processing image {i} with shape {img.shape}")
            else:
                print(f"Image tensor shape: {image.shape}")
            raise

        # Resolve the target device once; it was previously recomputed inline.
        device = image[0].device if isinstance(image, list) else image.device

        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["input_text"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.config.get('max_text_length', 100),
            add_special_tokens=False
        ).to(device)

        # Mask padding by the tokenizer's own attention mask instead of a
        # hard-coded token id 0: id 0 is not the pad id for every tokenizer and
        # may be a real vocabulary token that must stay in the loss.
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.attention_mask == 0, -100
        )

        # No loss on BOS (the "+ 1") or on the image/prompt positions
        empty_targets = torch.full(
            (atts_img.shape[0], atts_img.shape[1] + 1), -100,
            dtype=torch.long, device=device)
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                        dtype=to_regress_tokens.input_ids.dtype,
                        device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss
        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        """Run the module on one training batch, log its output dict, and return it.

        ``batch_idx`` is accepted for the trainer's calling convention and is
        not used.
        """
        step_output = self(batch)
        self.log_dict(step_output, prog_bar=True)
        return step_output

    def validation_step(self, samples, batch_idx):
        """Generate reports for one validation batch and queue them for scoring.

        The reference texts are tokenised (padded/truncated to
        ``config['max_text_length']``) and decoded again with ``self.decode``.
        Generation starts from ``[BOS, wrapped image/prompt embeddings]`` and
        is configured by ``config['decoding']``.

        Args:
            samples: Batch dict with ``"image"``, ``"input_text"`` and ``"id"``.
            batch_idx: Index of the batch; the optional encoder debug dump only
                runs for batch 0.

        Returns:
            Tuple ``(hypo, ref)`` of decoded hypothesis and reference strings.
            The same data, plus the sample ids, is appended to
            ``self.val_step_outputs``.
        """
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.config.get('max_text_length', 100),
            add_special_tokens=False
        )

        image = samples["image"]

        # Enable debug mode for the first batch only
        if batch_idx == 0 and getattr(self, 'debug_mode', False):
            self.debug_swin_output(image[0] if isinstance(image, list) else image[0:1])

        try:
            img_embeds, atts_img = self.encode_img(image)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)
        except Exception as e:
            print(f"Validation error in batch {batch_idx}: {str(e)}")
            if isinstance(image, list):
                print(f"Image list length: {len(image)}")
                for i, img in enumerate(image):
                    print(f"Validation image {i} with shape {img.shape}")
            else:
                print(f"Validation image tensor shape: {image.shape}")
            raise

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                        dtype=atts_img.dtype,
                        device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        # Configure generation based on settings
        decoding_config = self.config.get('decoding', {})
        do_sample = decoding_config.get('do_sample', False)
        generation_kwargs = {
            "num_beams": decoding_config.get('num_beams', 1),
            "do_sample": do_sample,
            "min_new_tokens": decoding_config.get('min_new_tokens', 80),
            "max_new_tokens": decoding_config.get('max_new_tokens', 120),
            "repetition_penalty": decoding_config.get('repetition_penalty', 2.0),
            "length_penalty": decoding_config.get('length_penalty', 2.0),
        }
        # Temperature is a sampling-only setting. Forwarding the default of 0
        # to greedy/beam decoding is an unused (and invalid-looking) flag, so
        # only pass it along when sampling is actually enabled.
        if do_sample:
            generation_kwargs["temperature"] = decoding_config.get('temperature', 0)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generation_kwargs,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.val_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return hypo, ref

    def decode(self, output_token):
        """Turn a sequence of token ids into cleaned report text.

        At most one leading token with id 0 and then at most one leading token
        with id 1 are dropped before decoding. The decoded text is cut at the
        first ``'</s>'``, stripped, and any ``'<unk>'`` is removed.

        NOTE(review): ids 0 and 1 are hard-coded; the original comments
        identify them as ``<unk>`` and ``<s>`` -- confirm this holds for the
        configured tokenizer.

        Args:
            output_token: Indexable sequence of token ids (list or 1-D
                tensor). May be empty.

        Returns:
            The cleaned text; an empty sequence is simply passed on to the
            tokenizer instead of raising.
        """
        # Length guards: an empty sequence (or one emptied by the first strip)
        # previously raised IndexError on output_token[0].
        if len(output_token) > 0 and output_token[0] == 0:  # leading id 0: remove it
            output_token = output_token[1:]
        if len(output_token) > 0 and output_token[0] == 1:  # leading id 1: remove it
            output_token = output_token[1:]
        output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('</s>')[0].strip()
        output_text = output_text.replace('<unk>', '')
        return output_text

    def on_validation_epoch_end(self):
        """Score the collected validation outputs, save them, and track the best model.

        Gathers everything in ``self.val_step_outputs``, computes metrics with
        ``self.score``, logs them, writes hypotheses and references as JSON
        under ``<output_dir>/result``, and calls ``self.save_checkpoint`` on
        local rank 0 when the weighted validation score beats
        ``self.val_score``. The step-output buffer is cleared at the end.
        """
        ref, hypo, ids = [], [], []
        for step_output in self.val_step_outputs:
            ref.extend(step_output['ref'])
            hypo.extend(step_output['hypo'])
            ids.extend(step_output['id'])

        # The scorers receive {id: [text]} mappings
        ref = {k: [v] for k, v in zip(ids, ref)}
        hypo = {k: [v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref, hypo=hypo)
        self.log_dict(eval_res, sync_dist=True, logger=True)

        # Save results; context managers make sure the files are flushed and closed
        result_folder = os.path.join(self.config['output_dir'], 'result')
        os.makedirs(result_folder, exist_ok=True)
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
        hypo_path = os.path.join(result_folder, f"result_{current_epoch}_{global_step}" + '.json')
        with open(hypo_path, 'w') as f:
            json.dump(hypo, f)
        with open(os.path.join(result_folder, 'refs.json'), 'w') as f:
            json.dump(ref, f)
        self.print(eval_res)

        # Track best model: weighted sum of the configured metrics
        scorer_types = self.config.get('val_metrics', ['Bleu_4'])
        weights = self.config.get('val_weights', [1])
        val_score = 0
        for score_type, weight in zip(scorer_types, weights):
            val_score += eval_res[score_type] * weight

        if self.trainer.local_rank == 0:
            if val_score > self.val_score:
                self.save_checkpoint(eval_res)
                self.val_score = val_score
        self.val_step_outputs.clear()

    def test_step(self, samples, batch_idx):
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.config.get('max_text_length', 100),
            add_special_tokens=False
        )

        image = samples["image"]

        try:
            img_embeds, atts_img = self.encode_img(image)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)
        except Exception as e:
            print(f"Test error in batch {batch_idx}: {str(e)}")
            if isinstance(image, list):
                print(f"Image list length: {len(image)}")
                for i, img in enumerate(image):
                    print(f"Test image {i} with shape {img.shape}")
            else:
                print(f"Test image tensor shape: {image.shape}")
            raise

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                        dtype=atts_img.dtype,
                        device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        # Configure generation based on settings
        decoding_config = self.config.get('decoding', {})
        num_beams = decoding_config.get('num_beams', 3)  # Default to 3 for test
        do_sample = decoding_config.get('do_sample', False)
        temperature = decoding_config.get('temperature', 0)
        min_new_tokens = decoding_config.get('min_new_tokens', 80)
        max_new_tokens = decoding_config.get('max_new_tokens', 120)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            num_beams=num_beams,
            do_sample=do_sample,
            min_new_tokens=min_new_tokens,
            max_new_tokens=max_new_tokens,
            repetition_penalty=decoding_config.get('repetition_penalty', 2.0),
            length_penalty=decoding_config.get('length_penalty', 2.0),
            temperature=temperature,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.test_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return hypo, ref

    def on_test_epoch_end(self):
        """Score the collected test outputs, log the metrics, and save everything as JSON.

        Gathers everything in ``self.test_step_outputs``, computes metrics
        with ``self.score``, logs them, and writes ``test_result.json``,
        ``test_refs.json`` and ``test_metrics_<name>.json`` under
        ``<output_dir>/result``. The step-output buffer is cleared at the end.
        """
        ref, hypo, ids = [], [], []
        for step_output in self.test_step_outputs:
            ref.extend(step_output['ref'])
            hypo.extend(step_output['hypo'])
            ids.extend(step_output['id'])

        # The scorers receive {id: [text]} mappings
        ref = {k: [v] for k, v in zip(ids, ref)}
        hypo = {k: [v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref, hypo=hypo)

        # Log the metrics (same call as the validation hook) so they are part
        # of the logged results of the test run rather than only on disk.
        self.log_dict(eval_res, sync_dist=True, logger=True)

        result_folder = os.path.join(self.config['output_dir'], 'result')
        os.makedirs(result_folder, exist_ok=True)
        # Context managers make sure each file is flushed and closed
        with open(os.path.join(result_folder, "test_result.json"), 'w') as f:
            json.dump(hypo, f)
        with open(os.path.join(result_folder, 'test_refs.json'), 'w') as f:
            json.dump(ref, f)

        # Save test results by config name
        config_name = self.config.get('name', 'default')
        with open(os.path.join(result_folder, f"test_metrics_{config_name}.json"), 'w') as f:
            json.dump(eval_res, f)
        self.print(f"Test result for {config_name}: {eval_res}")

        # Mirror the validation hook: do not carry outputs over to another test run
        self.test_step_outputs.clear()

    def save_checkpoint(self, eval_res):
        """Save the trainable weights, epoch, step and config to a ``.pth`` file.

        Only state-dict entries whose name matches a parameter with
        ``requires_grad=True`` are kept. The file goes to
        ``<output_dir>/checkpoints`` and its name embeds the config name,
        epoch, global step and ``eval_res['Bleu_4']``.

        Args:
            eval_res: Metric dict; must contain ``'Bleu_4'``.
        """
        trainer = self.trainer
        epoch_idx, step_idx = trainer.current_epoch, trainer.global_step

        # Names of every parameter that is currently being trained
        trainable_names = {
            name for name, param in self.named_parameters() if param.requires_grad
        }

        # Drop everything else from the state dict in place
        weights = self.state_dict()
        for key in [key for key in weights if key not in trainable_names]:
            del weights[key]

        payload = {
            "model": weights,
            "epoch": epoch_idx,
            "step": step_idx,
            "config": self.config
        }

        run_name = self.config.get('name', 'default')
        checkpoint_dir = os.path.join(self.config['output_dir'], 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)
        save_to = os.path.join(
            checkpoint_dir,
            f"checkpoint_{run_name}_epoch{epoch_idx}_step{step_idx}_bleu{eval_res['Bleu_4']:.3f}.pth"
        )
        self.print(f"Saving checkpoint at step {step_idx} to {save_to}.")
        torch.save(payload, save_to)

    def score(self, ref, hypo):
        """Compute BLEU-1..4, ROUGE-L and CIDEr for the given references and hypotheses.

        Args:
            ref: References, passed unchanged as the first argument of each
                scorer's ``compute_score``.
            hypo: Hypotheses, passed unchanged as the second argument.

        Returns:
            Dict mapping metric name to the scorer's overall score. A scorer
            that returns a list of scores contributes one entry per name.
        """
        metric_table = (
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Rouge(), "ROUGE_L"),
            (Cider(), "CIDEr"),
        )
        results = {}
        for metric, label in metric_table:
            # The second return value (per-sample scores) is not used
            overall, _ = metric.compute_score(ref, hypo)
            if type(overall) == list:
                results.update(zip(label, overall))
            else:
                results[label] = overall
        return results

    def configure_optimizers(self):
        """Build the optimizer and optional LR scheduler from ``self.config``.

        ``config['optimizer']`` selects ``adamw`` (default) or ``adam``
        (case-insensitive) with ``lr`` and ``weight_decay``.
        ``config['scheduler']`` selects ``cosine`` (default), ``linear`` or
        ``none``.

        Returns:
            ``{"optimizer": ...}`` when the scheduler type is ``none``,
            otherwise ``{"optimizer": ..., "lr_scheduler": ...}``.

        Raises:
            ValueError: For an unknown optimizer or scheduler type.
        """
        opt_cfg = self.config.get('optimizer', {})
        opt_name = opt_cfg.get('type', 'adamw')
        lr = opt_cfg.get('lr', 1e-5)
        weight_decay = opt_cfg.get('weight_decay', 0.0)

        # Dispatch table instead of an if/elif chain
        optimizer_classes = {
            'adamw': torch.optim.AdamW,
            'adam': torch.optim.Adam,
        }
        optimizer_cls = optimizer_classes.get(opt_name.lower())
        if optimizer_cls is None:
            raise ValueError(f"Unknown optimizer type: {opt_name}")
        optimizer = optimizer_cls(self.parameters(), lr=lr, weight_decay=weight_decay)

        sched_cfg = self.config.get('scheduler', {})
        sched_name = sched_cfg.get('type', 'cosine')

        # No scheduler requested: return the optimizer on its own
        if sched_name == 'none':
            return {"optimizer": optimizer}

        if sched_name == 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer,
                T_max=sched_cfg.get('T_max', 3),
                eta_min=sched_cfg.get('eta_min', 1e-6)
            )
        elif sched_name == 'linear':
            scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer=optimizer,
                start_factor=sched_cfg.get('start_factor', 1.0),
                end_factor=sched_cfg.get('end_factor', 0.1),
                total_iters=sched_cfg.get('total_iters', 3)
            )
        else:
            raise ValueError(f"Unknown scheduler type: {sched_name}")

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def get_progress_bar_dict(self):
        """Return the parent class's progress-bar items without the ``v_num`` entry.

        NOTE(review): relies on the base class providing
        ``get_progress_bar_dict`` -- confirm the installed Lightning version
        still defines and calls this hook.
        """
        bar_items = super().get_progress_bar_dict()
        bar_items.pop("v_num", None)  # don't show the version number
        return bar_items

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
        """Reset the gradients of ``optimizer``.

        ``epoch`` and ``batch_idx`` are accepted but not used; the method only
        calls ``optimizer.zero_grad()`` with its default arguments.
        """
        optimizer.zero_grad()

# Variants

In [None]:
# Base configuration that all variants will extend
# Every ablation variant is a deep copy of this dict with a few keys overridden.
base_config = {
    # Identifies the experiment; also used for the YAML file name
    "name": "baseline",
    "model_name": "cerebras/Cerebras-GPT-1.3B",
    "cache_dir": "/content/huggingface",
    "output_dir": "/content/ablation_study_results",

    # Vision encoder settings
    "vision_encoder": {
        "model_name": "microsoft/swin-base-patch4-window7-224",
        "freeze": True,
        "gradient_checkpointing": False
    },

    # Language model and LoRA settings
    "llm": {
        "gradient_checkpointing": True,
        "tokenizer": "cerebras/Cerebras-GPT-1.3B",
        "lora": {
            "r": 16,
            "alpha": 32,
            "dropout": 0.1,
            # "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]
        }
    },

    # How visual features are mapped into the LLM embedding space
    "projection": {
        "type": "linear_layernorm"
    },

    # Where the image embeddings go relative to the prompt text
    "feature_position": "before_prompt",

    # "<ImageHere>" marks where the image embeddings are inserted
    "prompt_template": "Human: <Img><ImageHere></Img> Generate a comprehensive and detailed diagnosis report for this chest xray image. \nAssistant:",

    # Maximum tokenised length of the report text
    "max_text_length": 100,

    # Generation settings read by validation_step / test_step
    "decoding": {
        "num_beams": 3,
        "do_sample": False,
        "min_new_tokens": 80,
        "max_new_tokens": 120,
        "repetition_penalty": 2.0,
        "length_penalty": 2.0,
        "temperature": 0
    },

    # Read by configure_optimizers
    "optimizer": {
        "type": "adamw",
        "lr": 1e-5,
        "weight_decay": 0.0
    },

    # Read by configure_optimizers
    "scheduler": {
        "type": "cosine",
        "T_max": 3,
        "eta_min": 1e-6
    },

    # Metrics (and their weights) combined into the score used to pick the best model
    "val_metrics": ["Bleu_4"],
    "val_weights": [1.0]
}

# Helper function to create a variant config
def _merge_nested(target, updates):
    """Recursively merge ``updates`` into ``target`` in place and return ``target``."""
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(target.get(key), dict):
            # Both sides are dicts: descend so sibling keys are preserved
            _merge_nested(target[key], value)
        else:
            # Copy so later edits to the variant never leak into ``updates``
            target[key] = copy.deepcopy(value)
    return target


def create_variant(base, name, updates):
    """Return a deep copy of ``base`` renamed to ``name`` with ``updates`` merged in.

    Dict-valued updates are merged recursively into existing dict values, so
    an override such as ``{"llm": {"lora": {"r": 8}}}`` changes only ``r`` and
    keeps the other ``lora`` and ``llm`` settings. Any other value replaces
    the existing one. Neither ``base`` nor ``updates`` is modified.

    Args:
        base: Configuration dict to start from.
        name: Value stored under the variant's ``"name"`` key (an explicit
            ``"name"`` in ``updates`` still takes precedence).
        updates: Possibly nested dict of overrides.

    Returns:
        The new configuration dict.
    """
    variant = copy.deepcopy(base)
    variant["name"] = name
    _merge_nested(variant, updates)
    return variant

# Create variant configs
# Each entry below is a (variant name, overrides) pair that is applied to
# base_config with create_variant.
# 1. Vision Encoder Variants
vision_variants = [
    ("vision_resnet50", {"vision_encoder": {"model_name": "microsoft/resnet-50"}}),
    ("vision_densenet121", {"vision_encoder": {"model_name": "microsoft/densenet-121"}}),
    ("vision_vit", {"vision_encoder": {"model_name": "google/vit-base-patch16-224"}}),
    ("vision_convnext", {"vision_encoder": {"model_name": "facebook/convnext-tiny-224"}}),
    ("vision_mask", {"vision_encoder": {"model_name": "facebook/maskformer-swin-base-ade"}}),
    ("vision_swin", {"vision_encoder": {"model_name": "masafresh/swin-transformer"}})
]

# 2. Visual Feature Projection Variants
projection_variants = [
    ("proj_mlp_layernorm", {"projection": {"type": "mlp_layernorm", "hidden_dim": 2048}}),
    ("proj_linear", {"projection": {"type": "linear"}}),
    ("proj_linear_batchnorm", {"projection": {"type": "linear_batchnorm"}}),
    # Note: identity projection would require matching dimensions, so this may not work without additional changes
]

# 3. Feature Position Variants
position_variants = [
    ("pos_after_prompt", {"feature_position": "after_prompt"}),
    ("pos_beginning_and_end", {"feature_position": "beginning_and_end"}),
    # Cross-attention would require architectural changes beyond config
]

# 4. Prompt Engineering Variants
# Every template keeps the "<ImageHere>" placeholder for the image embeddings.
prompt_variants = [
    ("prompt_detailed", {"prompt_template": "Human: <Img><ImageHere></Img> Please analyze this chest X-ray image in detail. Describe any abnormalities, potential diagnoses, and recommendations. Consider lung fields, heart size, pleural spaces, mediastinum, bones, and soft tissues in your analysis. \nAssistant:"}),
    ("prompt_concise", {"prompt_template": "Human: <Img><ImageHere></Img> What do you see in this chest X-ray? \nAssistant:"}),
    ("prompt_role", {"prompt_template": "Human: <Img><ImageHere></Img> You are an expert radiologist. Write a professional report for this chest X-ray. \nAssistant:"}),
]

# 5. Decoding Strategy Variants
# Keys not listed here (e.g. min/max new tokens) keep their base_config values.
decoding_variants = [
    ("decode_greedy", {"decoding": {"num_beams": 1, "do_sample": False}}),
    ("decode_beam5", {"decoding": {"num_beams": 5, "do_sample": False}}),
    ("decode_sample_t07", {"decoding": {"num_beams": 1, "do_sample": True, "temperature": 0.7}}),
    ("decode_sample_t1", {"decoding": {"num_beams": 1, "do_sample": True, "temperature": 1.0}}),
]

# Gather all variant configs
all_variants = [base_config]  # Start with baseline

# Extend with all variant types
for name, updates in (
    vision_variants +
    projection_variants +
    position_variants +
    prompt_variants +
    decoding_variants
):
    all_variants.append(create_variant(base_config, name, updates))

# Create config directory
os.makedirs("/content/configs", exist_ok=True)

# Save all configs
# One YAML file per variant, named after the variant; run_experiment reads these back.
for config in all_variants:
    config_path = f"/content/configs/{config['name']}.yaml"
    with open(config_path, "w") as f:
        yaml.dump(config, f)
    print(f"Saved config: {config_path}")

Saved config: /content/configs/baseline.yaml
Saved config: /content/configs/vision_resnet50.yaml
Saved config: /content/configs/vision_densenet121.yaml
Saved config: /content/configs/vision_vit.yaml
Saved config: /content/configs/vision_convnext.yaml
Saved config: /content/configs/vision_mask.yaml
Saved config: /content/configs/vision_swin.yaml
Saved config: /content/configs/proj_mlp_layernorm.yaml
Saved config: /content/configs/proj_linear.yaml
Saved config: /content/configs/proj_linear_batchnorm.yaml
Saved config: /content/configs/pos_after_prompt.yaml
Saved config: /content/configs/pos_beginning_and_end.yaml
Saved config: /content/configs/prompt_detailed.yaml
Saved config: /content/configs/prompt_concise.yaml
Saved config: /content/configs/prompt_role.yaml
Saved config: /content/configs/decode_greedy.yaml
Saved config: /content/configs/decode_beam5.yaml
Saved config: /content/configs/decode_sample_t07.yaml
Saved config: /content/configs/decode_sample_t1.yaml


# Experiments

In [None]:
def run_experiment(config_path):
    """Train, validate and test one ablation variant described by a YAML config.

    Creates ``<output_dir>/<name>`` for the experiment, stores a copy of the
    config there, builds the datasets and loaders, trains the model with
    checkpointing and early stopping on ``Bleu_4``, runs the test set, and
    writes the test results to ``test_results.json`` in the experiment
    directory.

    Args:
        config_path: Path to the variant's YAML configuration file.

    Returns:
        Whatever ``trainer.test`` returns for the test loader.
    """
    # Load configuration
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    print(f"Running experiment: {config['name']}")

    # Ensure output directories exist
    os.makedirs(config['output_dir'], exist_ok=True)
    os.makedirs(config['cache_dir'], exist_ok=True)
    experiment_dir = os.path.join(config['output_dir'], config['name'])
    os.makedirs(experiment_dir, exist_ok=True)

    # Copy config to experiment dir
    with open(os.path.join(experiment_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    # Create datasets
    train_dataset = ParseDataset(split='train')
    val_dataset = ParseDataset(split='val')
    test_dataset = ParseDataset(split='test')

    print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    # Create data loaders - adjust batch size based on your GPU memory
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Update config with experiment-specific output dir
    # (done after the copy above was written, so the saved YAML keeps the original path)
    config['output_dir'] = experiment_dir

    # Initialize model
    model = AblationR2GenGPT(config)

    # Set up callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(experiment_dir, 'checkpoints'),
        # The template must name the metric exactly as it is logged ('Bleu_4');
        # the previous 'val_Bleu_4' key is never logged, so every file name
        # showed a placeholder value instead of the real score.
        filename='{epoch}-{Bleu_4:.2f}',
        save_top_k=3,
        verbose=True,
        monitor='Bleu_4',
        mode='max'
    )

    early_stop_callback = EarlyStopping(
        monitor='Bleu_4',
        min_delta=0.00,
        patience=2,
        verbose=True,
        mode='max'
    )

    # Set up logger
    logger = TensorBoardLogger(save_dir=experiment_dir, name="logs")

    # Define trainer with appropriate resources
    trainer = pl.Trainer(
        max_epochs=5,
        callbacks=[checkpoint_callback, early_stop_callback],
        logger=logger,
        accelerator='auto',  # Use GPU if available
        devices=1,
        precision=16,  # Use mixed precision for better memory efficiency
        gradient_clip_val=1.0,  # Clip gradients to avoid exploding gradients
        log_every_n_steps=10,
    )

    # Train the model
    trainer.fit(model, train_loader, val_loader)

    # Test the model
    test_results = trainer.test(model, test_loader)

    # Save test results
    with open(os.path.join(experiment_dir, 'test_results.json'), 'w') as f:
        json.dump(test_results, f, indent=2)

    print(f"Experiment {config['name']} completed!")

    # Clean up to free memory
    del model, trainer
    torch.cuda.empty_cache()

    return test_results

In [None]:
# Main execution
def main():
    """Run the selected ablation variants and write summary files.

    For every name in ``variants_to_run`` the matching YAML config under
    ``/content/configs`` is passed to ``run_experiment``; missing configs are
    skipped and a failing experiment is reported without stopping the others.
    The collected results are saved to ``summary.json`` and a
    Bleu_4 / ROUGE_L / CIDEr table is printed and saved to
    ``comparative_results.csv`` under ``/content/ablation_study_results``.
    """
    # Choose which variants to run - comment out those you don't want to run
    variants_to_run = [
        # Vision encoder variants
        # (each name needs its own trailing comma: adjacent string literals
        # are silently concatenated into one bogus variant name otherwise)
        "vision_swin",
        "vision_convnext",
        "vision_mask",
        # Projection variants
        "proj_mlp_layernorm",
    ]

    results_dir = "/content/ablation_study_results"

    # Create a results summary dictionary
    results_summary = {}

    # Run selected experiments
    for variant in variants_to_run:
        config_path = f"/content/configs/{variant}.yaml"
        if not os.path.exists(config_path):
            print(f"Config not found for {variant}, skipping...")
            continue

        try:
            results = run_experiment(config_path)
            results_summary[variant] = results
        except Exception as e:
            print(f"Error running experiment {variant}: {e}")
            # Continue with next experiment instead of stopping everything

    # The directory is normally created by run_experiment, but it does not
    # exist yet when every variant was skipped or failed early.
    os.makedirs(results_dir, exist_ok=True)

    # Save overall summary
    with open(os.path.join(results_dir, "summary.json"), "w") as f:
        json.dump(results_summary, f, indent=2)

    print("All experiments completed!")

    # Create a simple comparative table of results
    results_table = {"Variant": [], "Bleu_4": [], "ROUGE_L": [], "CIDEr": []}

    for variant, result_list in results_summary.items():
        if result_list:
            result = result_list[0]  # Take first test result
            results_table["Variant"].append(variant)
            # Extract metrics, defaulting to 0 if not present
            results_table["Bleu_4"].append(result.get("Bleu_4", 0))
            results_table["ROUGE_L"].append(result.get("ROUGE_L", 0))
            results_table["CIDEr"].append(result.get("CIDEr", 0))

    # Convert to DataFrame for pretty display
    results_df = pd.DataFrame(results_table)
    print("Ablation Study Results:")
    print(results_df)

    # Save as CSV
    results_df.to_csv(os.path.join(results_dir, "comparative_results.csv"), index=False)

# Training

In [None]:
import warnings
# Suppress every Python warning for the rest of the session
# (presumably to keep the long training output below readable).
warnings.filterwarnings("ignore")

In [None]:
# Launch the ablation study; per-variant failures are printed by main() and
# do not stop the remaining variants.
main()

Running experiment: vision_convnext
Dataset sizes - Train: 2069, Val: 296, Test: 590
Loading vision encoder: facebook/convnext-tiny-224
Error running experiment vision_convnext: ConvNextModel does not support gradient checkpointing.
Running experiment: vision_mask
Dataset sizes - Train: 2069, Val: 296, Test: 590
Loading vision encoder: facebook/maskformer-swin-base-ade
Error running experiment vision_mask: MaskFormerModel does not support gradient checkpointing.
Running experiment: proj_mlp_layernorm
Dataset sizes - Train: 2069, Val: 296, Test: 590
Loading vision encoder: microsoft/swin-base-patch4-window7-224
Loading vision encoder is Done
Loading LLAMA


config.json:   0%|          | 0.00/360 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


pytorch_model.bin:   0%|          | 0.00/5.36G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/5.36G [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLM Done


INFO: 
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinModel            | 86.7 M | eval 
1 | llama_model    | PeftModelForCausalLM | 1.3 B  | train
2 | embed_tokens   | Embedding            | 102 M  | eval 
3 | llama_proj     | Sequential           | 6.3 M  | train
4 | layer_norm     | LayerNorm            | 4.1 K  | train
----------------------------------------------------------------
6.5 M     Trainable params
1.4 B     Non-trainable params
1.4 B     Total params
5,635.851 Total estimated model params size (MB)
27        Modules in train mode
803       Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinModel            | 86.7 M | eval 
1 | llama_model    | PeftModelForCausalLM | 1.3 B  | train
2 | embed_tokens   | Embed

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.04302744409664883, 'Bleu_2': 3.6087901921599845e-10, 'Bleu_3': 7.472216965854099e-13, 'Bleu_4': 3.4449878727291925e-14, 'ROUGE_L': np.float64(0.055041872236622276), 'CIDEr': np.float64(0.0038096185076584117)}
Saving checkpoint at step 0 to /content/ablation_study_results/proj_mlp_layernorm/checkpoints/checkpoint_proj_mlp_layernorm_epoch0_step0_bleu0.000.pth.


Training: |          | 0/? [00:00<?, ?it/s]

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved. New best score: 0.085
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved. New best score: 0.085
INFO: Epoch 0, global step 518: 'Bleu_4' reached 0.08461 (best 0.08461), saving model to '/content/ablation_study_results/proj_mlp_layernorm/checkpoints/epoch=0-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 518: 'Bleu_4' reached 0.08461 (best 0.08461), saving model to '/content/ablation_study_results/proj_mlp_layernorm/checkpoints/epoch=0-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.2397422483767492, 'Bleu_2': 0.15871505141678824, 'Bleu_3': 0.11433535070607453, 'Bleu_4': 0.08460932736676735, 'ROUGE_L': np.float64(0.3173623343606279), 'CIDEr': np.float64(0.20526751527738551)}
Saving checkpoint at step 518 to /content/ablation_study_results/proj_mlp_layernorm/checkpoints/checkpoint_proj_mlp_layernorm_epoch0_step518_bleu0.085.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved by 0.043 >= min_delta = 0.0. New best score: 0.128
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved by 0.043 >= min_delta = 0.0. New best score: 0.128
INFO: Epoch 1, global step 1036: 'Bleu_4' reached 0.12789 (best 0.12789), saving model to '/content/ablation_study_results/proj_mlp_layernorm/checkpoints/epoch=1-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 1036: 'Bleu_4' reached 0.12789 (best 0.12789), saving model to '/content/ablation_study_results/proj_mlp_layernorm/checkpoints/epoch=1-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.3307111337020077, 'Bleu_2': 0.22673770712706337, 'Bleu_3': 0.16857433224724777, 'Bleu_4': 0.1278884211297102, 'ROUGE_L': np.float64(0.3441842045983181), 'CIDEr': np.float64(0.30433303733364864)}
Saving checkpoint at step 1036 to /content/ablation_study_results/proj_mlp_layernorm/checkpoints/checkpoint_proj_mlp_layernorm_epoch1_step1036_bleu0.128.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 2, global step 1554: 'Bleu_4' reached 0.11345 (best 0.12789), saving model to '/content/ablation_study_results/proj_mlp_layernorm/checkpoints/epoch=2-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 1554: 'Bleu_4' reached 0.11345 (best 0.12789), saving model to '/content/ablation_study_results/proj_mlp_layernorm/checkpoints/epoch=2-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.30668025187187614, 'Bleu_2': 0.20666614897561567, 'Bleu_3': 0.15197540261196332, 'Bleu_4': 0.11344508285012152, 'ROUGE_L': np.float64(0.33330003947757003), 'CIDEr': np.float64(0.24120630707731924)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Monitored metric Bleu_4 did not improve in the last 2 records. Best score: 0.128. Signaling Trainer to stop.
INFO:lightning.pytorch.callbacks.early_stopping:Monitored metric Bleu_4 did not improve in the last 2 records. Best score: 0.128. Signaling Trainer to stop.
INFO: Epoch 3, global step 2072: 'Bleu_4' reached 0.11076 (best 0.12789), saving model to '/content/ablation_study_results/proj_mlp_layernorm/checkpoints/epoch=3-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 2072: 'Bleu_4' reached 0.11076 (best 0.12789), saving model to '/content/ablation_study_results/proj_mlp_layernorm/checkpoints/epoch=3-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.30535718752709784, 'Bleu_2': 0.2047824513117109, 'Bleu_3': 0.14911442132208963, 'Bleu_4': 0.1107632030335465, 'ROUGE_L': np.float64(0.33515103065177876), 'CIDEr': np.float64(0.23613885363789755)}


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

Test result for proj_mlp_layernorm: {'Bleu_1': 4.712717577319185e-32, 'Bleu_2': 3.619907057996158e-26, 'Bleu_3': 3.315172175416523e-24, 'Bleu_4': 3.1725642272616665e-23, 'ROUGE_L': np.float64(0.0), 'CIDEr': np.float64(0.0)}
Experiment proj_mlp_layernorm completed!
All experiments completed!
Ablation Study Results:
              Variant  Bleu_4  ROUGE_L  CIDEr
0  proj_mlp_layernorm       0        0      0


In [None]:
# Re-run of the ablation study. The recorded output below covers only
# vision_swin, so variants_to_run was presumably edited before this run.
main()

Running experiment: vision_swin
Dataset sizes - Train: 2069, Val: 296, Test: 590
Loading vision encoder: masafresh/swin-transformer
Loading vision encoder is Done
Visual encoder output dimension: 768
Loading LLAMA


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinMode

trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLM Done


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.02282642267711173, 'Bleu_2': 2.773377924829479e-10, 'Bleu_3': 6.448230297738249e-13, 'Bleu_4': 3.136651539350067e-14, 'ROUGE_L': np.float64(0.015316877457052068), 'CIDEr': np.float64(0.001959233969146106)}
Saving checkpoint at step 0 to /content/ablation_study_results/vision_swin/checkpoints/checkpoint_vision_swin_epoch0_step0_bleu0.000.pth.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved. New best score: 0.081
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved. New best score: 0.081
INFO: Epoch 0, global step 518: 'Bleu_4' reached 0.08087 (best 0.08087), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=0-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 518: 'Bleu_4' reached 0.08087 (best 0.08087), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=0-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.286564047290723, 'Bleu_2': 0.17875446350036026, 'Bleu_3': 0.11993455598033266, 'Bleu_4': 0.08087232202108591, 'ROUGE_L': np.float64(0.2930690785301276), 'CIDEr': np.float64(0.15285022684213345)}
Saving checkpoint at step 518 to /content/ablation_study_results/vision_swin/checkpoints/checkpoint_vision_swin_epoch0_step518_bleu0.081.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved by 0.020 >= min_delta = 0.0. New best score: 0.101
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved by 0.020 >= min_delta = 0.0. New best score: 0.101
INFO: Epoch 1, global step 1036: 'Bleu_4' reached 0.10068 (best 0.10068), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=1-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 1036: 'Bleu_4' reached 0.10068 (best 0.10068), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=1-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.3123820573652303, 'Bleu_2': 0.20070161983468227, 'Bleu_3': 0.13982654867518715, 'Bleu_4': 0.10067625306059928, 'ROUGE_L': np.float64(0.316114407041808), 'CIDEr': np.float64(0.18344967702330237)}
Saving checkpoint at step 1036 to /content/ablation_study_results/vision_swin/checkpoints/checkpoint_vision_swin_epoch1_step1036_bleu0.101.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 2, global step 1554: 'Bleu_4' reached 0.09317 (best 0.10068), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=2-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 1554: 'Bleu_4' reached 0.09317 (best 0.10068), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=2-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.28070417994344327, 'Bleu_2': 0.18678102894479554, 'Bleu_3': 0.13071193371833653, 'Bleu_4': 0.09316758493111504, 'ROUGE_L': np.float64(0.31470247153082936), 'CIDEr': np.float64(0.20142501554441028)}


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved by 0.002 >= min_delta = 0.0. New best score: 0.102
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved by 0.002 >= min_delta = 0.0. New best score: 0.102
INFO: Epoch 3, global step 2072: 'Bleu_4' reached 0.10245 (best 0.10245), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=3-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 2072: 'Bleu_4' reached 0.10245 (best 0.10245), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=3-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.30588366440149106, 'Bleu_2': 0.20341940804096736, 'Bleu_3': 0.1428093138890827, 'Bleu_4': 0.10245381623530858, 'ROUGE_L': np.float64(0.3092008713432334), 'CIDEr': np.float64(0.21164292978630844)}
Saving checkpoint at step 2072 to /content/ablation_study_results/vision_swin/checkpoints/checkpoint_vision_swin_epoch3_step2072_bleu0.102.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved by 0.001 >= min_delta = 0.0. New best score: 0.104
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved by 0.001 >= min_delta = 0.0. New best score: 0.104
INFO: Epoch 4, global step 2590: 'Bleu_4' reached 0.10386 (best 0.10386), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=4-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 4, global step 2590: 'Bleu_4' reached 0.10386 (best 0.10386), saving model to '/content/ablation_study_results/vision_swin/checkpoints/epoch=4-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.31570935695362506, 'Bleu_2': 0.20880915804063668, 'Bleu_3': 0.14590125932281392, 'Bleu_4': 0.10386350163987325, 'ROUGE_L': np.float64(0.31285272746578185), 'CIDEr': np.float64(0.22938862311496092)}
Saving checkpoint at step 2590 to /content/ablation_study_results/vision_swin/checkpoints/checkpoint_vision_swin_epoch4_step2590_bleu0.104.pth.


INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

Test result for vision_swin: {'Bleu_1': 4.712717577319185e-32, 'Bleu_2': 3.619907057996158e-26, 'Bleu_3': 3.315172175416523e-24, 'Bleu_4': 3.1725642272616665e-23, 'ROUGE_L': np.float64(0.0), 'CIDEr': np.float64(0.0)}
Experiment vision_swin completed!
All experiments completed!
Ablation Study Results:
       Variant  Bleu_4  ROUGE_L  CIDEr
0  vision_swin       0        0      0


In [None]:
# Re-run of the ablation study. The recorded output below covers only
# vision_mask, so variants_to_run was presumably edited before this run.
main()

Running experiment: vision_mask
Dataset sizes - Train: 2069, Val: 296, Test: 590
Loading vision encoder: facebook/maskformer-swin-base-ade


config.json:   0%|          | 0.00/12.5k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of SwinModel were not initialized from the model checkpoint at facebook/maskformer-swin-base-ade and are newly initialized: ['embeddings.norm.bias', 'embeddings.norm.weight', 'embeddings.patch_embeddings.projection.bias', 'embeddings.patch_embeddings.projection.weight', 'encoder.layers.0.blocks.0.attention.output.dense.bias', 'encoder.layers.0.blocks.0.attention.output.dense.weight', 'encoder.layers.0.blocks.0.attention.self.key.bias', 'encoder.layers.0.blocks.0.attention.self.key.weight', 'encoder.layers.0.blocks.0.attention.self.query.bias', 'encoder.layers.0.blocks.0.attention.self.query.weight', 'encoder.layers.0.blocks.0.attention.self.relative_position_bias_table', 'encoder.layers.0.blocks.0.attention.self.relative_position_index', 'encoder.layers.0.blocks.0.attention.self.value.bias', 'encoder.layers.0.blocks.0.attention.self.value.weight', 'encoder.layers.0.blocks.0.intermediate.dense.bias', 'encoder.layers.0.blocks.0.intermediate.dense.weight', 'encoder.layers.0.b

Loading vision encoder is Done
Visual encoder output dimension: 1024
Loading LLAMA


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors:   0%|          | 0.00/411M [00:00<?, ?B/s]

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


trainable params: 196,608 || all params: 1,315,919,872 || trainable%: 0.0149
Loading LLM Done


INFO: 
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinModel            | 86.9 M | eval 
1 | llama_model    | PeftModelForCausalLM | 1.3 B  | train
2 | embed_tokens   | Embedding            | 102 M  | eval 
3 | llama_proj     | Linear               | 2.1 M  | train
4 | layer_norm     | LayerNorm            | 4.1 K  | train
----------------------------------------------------------------
2.3 M     Trainable params
1.4 B     Non-trainable params
1.4 B     Total params
5,619.607 Total estimated model params size (MB)
24        Modules in train mode
803       Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name           | Type                 | Params | Mode 
----------------------------------------------------------------
0 | visual_encoder | SwinModel            | 86.9 M | eval 
1 | llama_model    | PeftModelForCausalLM | 1.3 B  | train
2 | embed_tokens   | Embed

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

{'Bleu_1': 0.022002626947828063, 'Bleu_2': 2.6527851903124087e-10, 'Bleu_3': 6.118902647115422e-13, 'Bleu_4': 2.9475390776985495e-14, 'ROUGE_L': np.float64(0.018871156904814394), 'CIDEr': np.float64(4.997475439007447e-27)}
Saving checkpoint at step 0 to /content/ablation_study_results/vision_mask/checkpoints/checkpoint_vision_mask_epoch0_step0_bleu0.000.pth.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved. New best score: 0.062
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved. New best score: 0.062
INFO: Epoch 0, global step 518: 'Bleu_4' reached 0.06170 (best 0.06170), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=0-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 518: 'Bleu_4' reached 0.06170 (best 0.06170), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=0-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.22299493334126114, 'Bleu_2': 0.13667856931454803, 'Bleu_3': 0.09215197368578729, 'Bleu_4': 0.06169697737675803, 'ROUGE_L': np.float64(0.25431115590586845), 'CIDEr': np.float64(0.09066097447074568)}
Saving checkpoint at step 518 to /content/ablation_study_results/vision_mask/checkpoints/checkpoint_vision_mask_epoch0_step518_bleu0.062.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved by 0.019 >= min_delta = 0.0. New best score: 0.081
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved by 0.019 >= min_delta = 0.0. New best score: 0.081
INFO: Epoch 1, global step 1036: 'Bleu_4' reached 0.08056 (best 0.08056), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=1-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 1036: 'Bleu_4' reached 0.08056 (best 0.08056), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=1-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.26492040688178475, 'Bleu_2': 0.16885797919832005, 'Bleu_3': 0.11558669980981322, 'Bleu_4': 0.0805597491955816, 'ROUGE_L': np.float64(0.30652186665672054), 'CIDEr': np.float64(0.19031630238886166)}
Saving checkpoint at step 1036 to /content/ablation_study_results/vision_mask/checkpoints/checkpoint_vision_mask_epoch1_step1036_bleu0.081.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved by 0.015 >= min_delta = 0.0. New best score: 0.095
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved by 0.015 >= min_delta = 0.0. New best score: 0.095
INFO: Epoch 2, global step 1554: 'Bleu_4' reached 0.09519 (best 0.09519), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=2-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 1554: 'Bleu_4' reached 0.09519 (best 0.09519), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=2-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.296064616353392, 'Bleu_2': 0.19185254342809732, 'Bleu_3': 0.13334329064864986, 'Bleu_4': 0.09519006676520395, 'ROUGE_L': np.float64(0.32462122620845707), 'CIDEr': np.float64(0.22258064538005348)}
Saving checkpoint at step 1554 to /content/ablation_study_results/vision_mask/checkpoints/checkpoint_vision_mask_epoch2_step1554_bleu0.095.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Metric Bleu_4 improved by 0.001 >= min_delta = 0.0. New best score: 0.097
INFO:lightning.pytorch.callbacks.early_stopping:Metric Bleu_4 improved by 0.001 >= min_delta = 0.0. New best score: 0.097
INFO: Epoch 3, global step 2072: 'Bleu_4' reached 0.09661 (best 0.09661), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=3-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 2072: 'Bleu_4' reached 0.09661 (best 0.09661), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=3-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.28998191972213444, 'Bleu_2': 0.18911351250315317, 'Bleu_3': 0.13327370744360387, 'Bleu_4': 0.09660969614809206, 'ROUGE_L': np.float64(0.3266172859572732), 'CIDEr': np.float64(0.21542882937098232)}
Saving checkpoint at step 2072 to /content/ablation_study_results/vision_mask/checkpoints/checkpoint_vision_mask_epoch3_step2072_bleu0.097.pth.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: Epoch 4, global step 2590: 'Bleu_4' reached 0.09497 (best 0.09661), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=4-val_Bleu_4=0.00.ckpt' as top 3
INFO:lightning.pytorch.utilities.rank_zero:Epoch 4, global step 2590: 'Bleu_4' reached 0.09497 (best 0.09661), saving model to '/content/ablation_study_results/vision_mask/checkpoints/epoch=4-val_Bleu_4=0.00.ckpt' as top 3


{'Bleu_1': 0.28829099003876474, 'Bleu_2': 0.18782600180922093, 'Bleu_3': 0.13179491435536198, 'Bleu_4': 0.09496586121661371, 'ROUGE_L': np.float64(0.3268771364083306), 'CIDEr': np.float64(0.19904010413590306)}


INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

Test result for vision_mask: {'Bleu_1': 4.712717577319185e-32, 'Bleu_2': 3.619907057996158e-26, 'Bleu_3': 3.315172175416523e-24, 'Bleu_4': 3.1725642272616665e-23, 'ROUGE_L': np.float64(0.0), 'CIDEr': np.float64(0.0)}
Experiment vision_mask completed!
All experiments completed!
Ablation Study Results:
       Variant  Bleu_4  ROUGE_L  CIDEr
0  vision_mask       0        0      0
