<a href="https://colab.research.google.com/github/avikumart/LLM-GenAI-Transformers-Notebooks/blob/main/DeepLearningFiles/Pytorch_Hybrid_CLIP_T5_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# import the dependencies
import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutput
import open_clip

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Clinical T5 tokenizer; a single module-level instance that other code in
# this file can reference by the global name `tokenizer`.
tokenizer = AutoTokenizer.from_pretrained("avikalsingh/Clinical-T5-Base", legacy=False)

# Training hyperparameters.
# NOTE(review): the original comment suggests these mirror values from the
# preprocessing step - confirm they stay in sync with it.
EPOCHS = 5
BATCH_SIZE = 4
MAX_TEXT_LEN = 256

# add the vision wrapper to generate the embeddings
class CLIPVisionWrapper(nn.Module):
  """Wrap a CLIP vision tower so it can yield pooled and token-level features.

  Attributes:
    vision_model: the wrapped vision module; called directly in `forward`.
    embed_dim: width of the pooled embedding, read from the first of
      `output_dim`, `embed_dim`, `width` found on the vision model
      (defaults to 512 when none is present).
    token_dim: width of the per-token features, read from
      `vision_model.trunk.embed_dim` or `trunk.blocks[0].dim` when a trunk
      exists; otherwise it falls back to `embed_dim`.
  """

  def __init__(self, vision_model):
    super().__init__()
    self.vision_model = vision_model
    self.embed_dim = getattr(
        vision_model,
        "output_dim",
        getattr(vision_model, "embed_dim", getattr(vision_model, "width", 512)),
    )

    # Token width may differ from the pooled width when the tower applies a
    # projection after its trunk, so look it up on the trunk first.
    token_dim = None
    trunk = getattr(vision_model, "trunk", None)
    if trunk is not None:
      token_dim = getattr(trunk, "embed_dim", None)
      if token_dim is None and hasattr(trunk, "blocks") and len(trunk.blocks) > 0:
        token_dim = getattr(trunk.blocks[0], "dim", None)
    self.token_dim = int(token_dim) if token_dim is not None else int(self.embed_dim)

  def forward(self, x):
    """Return one pooled feature vector per input.

    A 2-D output of the vision model is returned unchanged. A 4-D output is
    reduced over its last two dimensions with the average of mean- and
    max-pooling.

    Raises:
      RuntimeError: if the vision model returns a tensor that is neither
        2-D nor 4-D.
    """
    feats = self.vision_model(x)
    if feats.ndim == 2:
      return feats
    if feats.ndim == 4:
      return 0.5 * (feats.mean(dim=(2, 3)) + feats.amax(dim=(2, 3)))
    raise RuntimeError("Invalid input shape")

  @torch.no_grad()
  def forward_tokens(self, x):
    """Return a 3-D token tensor for `x`, computed without gradients.

    Tries `trunk.forward_features`, then `vision_model.forward_features`,
    and uses the first result that is a 3-D tensor. If neither yields one,
    the pooled output of `forward` is returned with a singleton dimension
    inserted at position 1.
    """
    v = self.vision_model
    trunk = getattr(v, "trunk", None)
    if trunk is not None and hasattr(trunk, "forward_features"):
      out = trunk.forward_features(x)
      if isinstance(out, torch.Tensor) and out.ndim == 3:
        return out
    if hasattr(v, "forward_features"):
      out = v.forward_features(x)
      if isinstance(out, torch.Tensor) and out.ndim == 3:
        return out
    pooled = self.forward(x)
    return pooled.unsqueeze(1)

  # Backward-compatible alias: the method was originally named `forward_text`
  # even though it produces visual tokens.
  forward_text = forward_tokens


# stage 1 for the classification
def build_classifier(num_labels: int):
  """Build the stage-1 image classifier.

  Loads the BiomedCLIP checkpoint through `open_clip`, wraps its visual
  tower in `CLIPVisionWrapper`, and chains a linear head that maps the
  wrapper's `embed_dim` features to `num_labels` outputs.

  Args:
    num_labels: number of output units of the classification head.

  Returns:
    An `nn.Sequential` of (vision wrapper, linear head).
  """
  # Only the model is needed here; the two returned transforms are unused.
  clip_model, _train_tf, _eval_tf = open_clip.create_model_and_transforms(
      "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
  )
  backbone = CLIPVisionWrapper(clip_model.visual)
  head = nn.Linear(backbone.embed_dim, num_labels)
  return nn.Sequential(backbone, head)


# multitask model for the hybrid training strategies
class ClipMultiTaskModel(nn.Module):
  def __init__(self, clip_encoder, text_decoder, num_labels: int, use_label_token: bool = True, label_token_scale: float = 0.5, max_visual_tokens: int = 64, use_decoder_aux: bool = False):
    super().__init__()
    self.clip_encoder = clip_encoder
    self.text_decoder = text_decoder
    self.num_labels = num_labels

    # -- llama tokenizer for refinement ---
    self.llama_tokeniner = AutoTokenizer.from_pretrained("meta-llam/Llama-3.3-1B")
    if self.llama_tokenizer.pad_token is None:
      self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

    embed_dim = getattr(clip_encoder, "embed_dim", 512)
    token_dim = getattr(clip_encoder, "token_dim", embed_dim)
    d_model = text_decoder.config.d_model # T5 hidden size

    self.enc_ln = nn.LayerNorm(d_model)
    self.enc_gate = nn.Parameter(torch.tensor(1.0))
    self.dec_ln = nn.LayerNorm(d_model)
    self.class_head = nn.Linear(d_model, num_labels)

    # ---- NEW NLP PROJECTOR ----
    self.visual_projection = nn.Sequential(
        nn.Linear(int(token_dim), int(d_model)),
        nn.GELU(),
        nn.Linear(int(d_model), int(d_model)),
        nn.Dropout(p=0.1)
    )

    self.max_visual_tokens = int(max_visual_tokens)
    self.log_sigma_clf = nn.Parameter(torch.zeros(1))
    self.log_sigma_gen = nn.Parameter(torch.zeros(1))

    # llama 3 refiner
    self.llm = AutoModelForCausalLM.from_pretrained("meta-llam/Llama-3.3-1B")
    for param in self.llm.parameters():
      param.requires_grad = False


    # auxiliary head (decoder side)
    llama_dim = self.llm.config.hidden_size
    self.use_decoder_aux = bool(use_decoder_aux)
    self.decoder_cls_head = nn.Linear(llama_dim, num_labels) if self.use_decoder_aux else None

  def forward(
      self,
      images,
      clf_labels = None,
      t5_labels =None,
      decoder_attention_mask= None,
      generate: bool = False,
      generate_kwargs: dict | None = None,
      **kwargs):

     base_features = self.clip_encoder(images)
     vis_tokens = self.clip_encoder.forward_tokens(images)

     if vis_tokens.size(1) > self.max_visual_tokens:
      idx = torch.linspace(0, vis_tokens.size(1) -1, steps=self.max_visual_tokens).long().to(vis_tokens.device)
      vis_tokens = vis_tokens[:, idx, :]


     encoder_hidden = self.visual_projection(vis_tokens)
     encoder_hidden = self.enc_ln(encoder_hidden)
     gate = torch.sigmoid(self.enc_gate)
     encoder_hidden = gate * encoder_hidden


     # create the attention mask for te T5
     enc_attn_mask = torch.ones((images.size(0), encoder_hidden.size(1)), dtype=torch.long, device=images.device)
     encoder_outputs = BaseModelOutput(last_hidden_state=encoder_hidden)

     # stage 1 clfs
     class_logits = self.class_head(base_features)

     # generation mode:
     if generate:
      if generate_kwargs is None: generate_kwargs = {}
            # Defaults for clinical accuracy
      generate_kwargs.setdefault("max_new_tokens", 128)
      generate_kwargs.setdefault("num_beams", 4)
      generate_kwargs.setdefault("do_sample", False)
      generate_kwargs.setdefault("decoder_start_token_id", self.decoder.config.pad_token_id or 0)

            # Step A: Generate draft with T5 (using projected visual tokens)
      t5_gen_ids = self.decoder.generate(
                encoder_outputs=encoder_outputs,
                attention_mask=enc_attn_mask,
                **generate_kwargs,
            )

            # Step B: Decode T5 output to text
      draft_texts = tokenizer.batch_decode(t5_gen_ids, skip_special_tokens=True)

            # Step C: Refine with Llama
            # Prompt Engineering for Refinement
      refinement_prompts = [
                f"Fix grammar and medical terminology in this report:\nDraft: {txt}\nRefined Report:"
                for txt in draft_texts
            ]

      llama_inputs = self.llama_tokenizer(
                refinement_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(images.device)

      final_ids = self.llm.generate(
                **llama_inputs,
                max_new_tokens=128,
                do_sample=True, # Slight sampling for natural flow in refinement
                temperature=0.7
            )

      return class_logits, final_ids

        # ---------------------------------------------------------
        # TRAINING MODE
        # ---------------------------------------------------------
     if t5_labels is None:
      return class_logits, None

        # T5 Forward
     outputs = self.decoder(
            labels=t5_labels,
            attention_mask=enc_attn_mask,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask = decoder_attention_mask,
            output_hidden_states=self.use_decoder_aux,
            return_dict=True,
        )

      # Llama Aux Head Logic
     aux_logits = None
     if self.use_decoder_aux:
            with torch.no_grad():
                # Re-tokenize T5 labels for Llama
                t5_labels_replaced = t5_labels.clone()
                t5_labels_replaced[t5_labels_replaced == -100] = tokenizer.pad_token_id
                decoded_labels = tokenizer.batch_decode(t5_labels_replaced, skip_special_tokens=True)

                llama_inputs = self.llama_tokenizer(
                    decoded_labels,
                    return_tensors="pt",
                    padding=True,
                    truncation=True
                ).to(images.device)

                llm_outputs = self.llm(
                    input_ids=llama_inputs.input_ids,
                    attention_mask=llama_inputs.attention_mask,
                    output_hidden_states=True
                )

            # Feature Extraction (Last Hidden State Mean Pool)
            hidden_states = llm_outputs.hidden_states[-1]
            decoder_hidden_pooled = hidden_states.mean(dim=1)
            aux_logits = self.decoder_cls_head(decoder_hidden_pooled)

     return class_logits, outputs, aux_logits

