<a href="https://colab.research.google.com/github/Codedestructor56/Multimodal-Token-Fusion/blob/main/Training_in_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Check which GPU this Colab runtime provides (the recorded output shows a Tesla T4).
!nvidia-smi

Sat May 25 08:43:03 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P8              11W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
!pip install huggingface_hub
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
# Mount Google Drive: the training configurations read the dataset from /content/drive/MyDrive/ct_scan_data.
from google.colab import drive
drive.mount("/content/drive/")

Mounted at /content/drive/


In [5]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass
from dataclasses import dataclass
from typing import Optional,List
from transformers import BertTokenizer, AutoTokenizer, AutoProcessor
import shutil
import cv2
import os
import json
import numpy as np
import pandas as pd
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from transformers import AutoTokenizer
from torch.optim import AdamW
import re

In [6]:
class MedicalDatasetCreator():
    """Loads the CSV index of a medical image/text dataset and exposes one row at a time.

    The dataset directory must contain ``info_dataframe.csv``. The columns read here are
    ``file`` (image path relative to the dataset directory), ``case_text``, ``gender`` and
    ``age``. Rows with any missing value are dropped on load.
    """

    def __init__(self, dataset_name: str):
        # Root directory: both the CSV and the image paths are resolved against it.
        self.dataset_name = dataset_name

        # Filled in by load_dataset() / load_row().
        self.loaded_df = None
        self.image = None
        self.case_text = None
        self.gender = None
        self.age = None
        self.image_labels = None  # NOTE(review): not populated by this class
        self.dataset_length = None

    def load_dataset(self):
        """Read ``info_dataframe.csv``, drop incomplete rows and record the row count."""
        raw_df = pd.read_csv(os.path.join(self.dataset_name, "info_dataframe.csv"))
        # Build a new frame rather than mutating in place. Rows are addressed positionally
        # (``iloc``) in load_row, so the index gaps left by dropna are harmless.
        self.loaded_df = raw_df.dropna()
        self.dataset_length = self.loaded_df.shape[0]

    def load_row(self, index: int = 0):
        """Load the image and metadata of the ``index``-th (positional) row.

        Raises:
            FileNotFoundError: if OpenCV cannot read the referenced image file.
        """
        image_path = os.path.join(self.dataset_name, self.loaded_df["file"].iloc[index])
        image = cv2.imread(image_path)
        # cv2.imread reports failure by returning None instead of raising; fail here with a
        # clear message instead of a confusing error during tensor conversion later on.
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")
        self.image = image
        self.case_text = self.loaded_df["case_text"].iloc[index]
        self.gender = self.loaded_df["gender"].iloc[index]
        self.age = self.loaded_df["age"].iloc[index]

In [7]:
@dataclass
class Parameters:
    """Hyper-parameters and runtime settings shared by the dataset and all model components.

    Dataclass annotations are not enforced at runtime; they document the values the notebook
    passes in.
    """
    device: str  # torch device string, e.g. "cuda" or "cpu"
    num_heads: int  # attention heads; must divide emb_dim (asserted in Attention)
    emb_dim: int  # model width
    max_seq_len: int  # text token sequences are truncated to this length
    tokenizer: object  # tokenizer instance exposing .encode / .decode (e.g. AutoTokenizer)
    max_im_height: int  # images are padded/cropped to this height
    max_im_width: int  # ... and to this width (PatchEmbeddings requires height == width)
    batch_size: int
    dataset_name: str  # directory containing info_dataframe.csv and the images
    use_cache: bool  # enables the key/value cache in Attention
    ffn_hidden_dim: int  # hidden width of the SwiGLU feed-forward block
    thresh: Optional[float]  # RMSnorm epsilon (the notebook passes None)
    num_layers: int  # encoder layers per modality
    vocab_size: int
    patch_size: int  # side length of the square image patches
    token_thresh: float  # TokenFusion keeps a token's own features when its score exceeds this
    imp_layer_hidden: int  # hidden width of TokenFusion's importance scorer
    div_batch: int  # number of chunks Transformer splits the batch into
    # Batch dimension of the key/value cache; Attention reads it when use_cache is True.
    # Optional with a default so existing constructor calls keep working.
    max_batch_size: Optional[int] = None

class Medical_Data(Dataset):
    """Torch ``Dataset`` over a MedicalDatasetCreator that yields ``(token_ids, image)`` pairs.

    ``collate_fn`` right-pads the token ids to a common length and turns the images into
    fixed-size grayscale tensors (see ``pad_im``).
    """

    def __init__(self, params: "Parameters"):
        super().__init__()
        self.device = params.device
        self.max_seq_len = params.max_seq_len
        self.emb_dim = params.emb_dim
        self.max_height = params.max_im_height
        self.max_width = params.max_im_width
        self.batch_size = params.batch_size
        self.tokenizer = params.tokenizer
        # The CSV index is read eagerly so that __len__ works immediately.
        self.dataset = MedicalDatasetCreator(params.dataset_name)
        self.dataset.load_dataset()

    def __len__(self):
        return self.dataset.dataset_length

    def __getitem__(self, idx):
        """Return the tokenised case text (at most ``max_seq_len`` ids) and the raw image.

        An over-long encoding keeps its first ``max_seq_len - 1`` ids plus its final id, so the
        last token survives truncation. Shorter encodings are returned unchanged.
        """
        self.dataset.load_row(idx)
        token_ids = torch.tensor(self.tokenizer.encode(self.dataset.case_text), dtype=torch.int32).to(self.device)
        if token_ids.numel() > self.max_seq_len:
            # Only truncate when needed: appending the last id unconditionally would duplicate
            # it in every sequence that already fits.
            token_ids = torch.cat((token_ids[:self.max_seq_len - 1], token_ids[-1:]), dim=0)
        # OpenCV images are 8-bit unsigned; int8 would wrap every pixel value above 127.
        image = torch.tensor(self.dataset.image, dtype=torch.uint8)
        return token_ids, image

    def pad_im(self, imgs):
        """Convert each BGR image to grayscale and fit it to (max_height, max_width).

        Each image is centre-padded with zeros, then the top-left window of that size is cropped.
        Returns a uint8 tensor of shape (len(imgs), max_height, max_width) on ``self.device``.
        """
        img_batch = []
        for img in imgs:
            img_gray = cv2.cvtColor(np.array(img, dtype=np.uint8), cv2.COLOR_BGR2GRAY)
            height, width = img_gray.shape[:2]

            pad_height = max(0, self.max_height - height)
            pad_width = max(0, self.max_width - width)
            top_pad = pad_height // 2
            bottom_pad = pad_height - top_pad
            left_pad = pad_width // 2
            right_pad = pad_width - left_pad
            img_padded = cv2.copyMakeBorder(img_gray, top_pad, bottom_pad, left_pad, right_pad, cv2.BORDER_CONSTANT, value=0)
            img_batch.append(img_padded[:self.max_height, :self.max_width])

        return torch.tensor(np.array(img_batch), dtype=torch.uint8).to(self.device)

    def collate_fn(self, batch):
        """Collate ``(token_ids, image)`` samples into ``(padded_ids, images)``; ids are right-padded with 0."""
        texts, images = zip(*batch)
        padded_text = pad_sequence(texts, batch_first=True, padding_value=0)
        padded_img = self.pad_im(images)
        return padded_text.to(self.device), padded_img


In [8]:
class RotaryEmbeddings(nn.Module):
    """Rotary position embedding (RoPE) over the last dimension of a 3-D tensor.

    Consecutive feature pairs ``(2i, 2i+1)`` are treated as complex numbers and rotated by
    ``pos * theta_i``, with ``theta_i = theta ** (-2i / emb_dim)``. Positions count from 1.
    """

    def __init__(self, device: str, theta: int = 10000):
        super().__init__()
        self.theta = theta
        self.device = device

    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None, emb_dim: Optional[int] = None) -> torch.Tensor:
        """Rotate ``x`` of shape ``(batch, seq_len, emb_dim)``; the output has the same shape.

        ``seq_len`` and ``emb_dim`` are accepted for compatibility but ignored; both are read
        from ``x.shape``.
        """
        batch_size, seq_len, emb_dim = x.shape
        assert emb_dim % 2 == 0, "Embeddings dimension must be even"
        # theta_i = theta^(-2i/emb_dim), i = 0 .. emb_dim/2 - 1. The exponent must be computed
        # in floating point: integer floor division collapses it to 0 or 1.
        exponents = torch.arange(0, emb_dim, 2, device=x.device, dtype=torch.float32) / emb_dim
        inv_freq = 1.0 / (self.theta ** exponents)
        positions = torch.arange(1, seq_len + 1, device=x.device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)  # (seq_len, emb_dim/2)
        # polar(1, a) = cos(a) + i*sin(a): a unit-modulus rotation by angle a.
        rotation = torch.polar(torch.ones_like(angles), angles).unsqueeze(0)
        # contiguous() guarantees the stride-1 pair layout that view_as_complex requires.
        x_complex = torch.view_as_complex(x.contiguous().reshape(batch_size, seq_len, emb_dim // 2, 2))
        rotated = torch.view_as_real(x_complex * rotation).reshape(batch_size, seq_len, emb_dim)
        return rotated.to(self.device)


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with rotary position embeddings.

    Rotary embeddings are applied to queries and keys. No attention mask is applied, so every
    position attends to every position (including all cached positions when the cache is on).
    """

    def __init__(self, params: "Parameters"):
        super().__init__()
        self.use_cache = params.use_cache
        self.device = params.device
        self.pos_rotor = RotaryEmbeddings(self.device)

        self.num_heads = params.num_heads
        assert params.emb_dim % self.num_heads == 0, "Make the embedding dim divisible by num_heads"
        self.head_dim = params.emb_dim // self.num_heads
        self.wq = nn.Linear(params.emb_dim, self.num_heads * self.head_dim).to(self.device)
        self.wk = nn.Linear(params.emb_dim, self.num_heads * self.head_dim).to(self.device)
        self.wv = nn.Linear(params.emb_dim, self.num_heads * self.head_dim).to(self.device)
        self.wo = nn.Linear(params.emb_dim, self.num_heads * self.head_dim).to(self.device)
        if self.use_cache:
            # Parameters may not define max_batch_size; fall back to batch_size in that case.
            cache_batch = getattr(params, "max_batch_size", None) or params.batch_size
            cache_shape = (cache_batch, params.max_seq_len, self.num_heads, self.head_dim)
            self.c_v = torch.zeros(cache_shape, device=self.device)
            self.c_k = torch.zeros(cache_shape, device=self.device)

    def forward(self, x: torch.Tensor, cur_pos: Optional[int] = None) -> torch.Tensor:
        """Attend over ``x`` of shape (batch, seq_len, emb_dim); returns the same shape.

        ``cur_pos`` is the cache offset at which this call's keys/values are written. It is only
        used when ``use_cache`` is set; None is treated as 0.
        """
        batch_size, seq_len, _ = x.shape
        head_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        # RoPE goes on queries and keys (not values) so that q.k depends on relative position.
        xq = self.pos_rotor(self.wq(x)).reshape(head_shape)
        xk = self.pos_rotor(self.wk(x)).reshape(head_shape)
        xv = self.wv(x).reshape(head_shape)

        if self.use_cache:
            start = 0 if cur_pos is None else cur_pos
            end = start + seq_len
            # NOTE(review): pos_rotor rotates by position within x only, not by cur_pos, so
            # single-token cached steps all receive the same rotation. TODO pass an offset.
            self.c_k[:batch_size, start:end] = xk
            self.c_v[:batch_size, start:end] = xv
            keys = self.c_k[:batch_size, :end]
            values = self.c_v[:batch_size, :end]
        else:
            keys = xk
            values = xv

        # (batch, heads, positions, head_dim)
        q = xq.transpose(1, 2)
        k = keys.transpose(1, 2).to(q.device)
        v = values.transpose(1, 2).to(q.device)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = torch.softmax(scores, dim=-1)
        attended = torch.matmul(weights, v).transpose(1, 2).reshape(batch_size, seq_len, -1)
        return self.wo(attended)

class RMSnorm(nn.Module):
    """Root-mean-square normalisation with a learned per-feature scale.

    y = x * weight / sqrt(mean(x^2 over the last dim) + thresh)
    """

    def __init__(self, dim: int, device: str, thresh: Optional[float] = 1e-4):
        super().__init__()
        self.params = nn.Parameter(torch.ones(dim))
        # Callers forward params.thresh, which may be None; use the default epsilon then.
        self.thresh = 1e-4 if thresh is None else thresh
        self.device = device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(self.device)
        # The epsilon keeps the denominator positive, so an all-zero row maps to zeros, not NaN.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.thresh)
        return x * self.params.to(self.device) / rms

class SwiGLu_Forward(nn.Module):
    """SwiGLU feed-forward block: ``w3(w2(x) * silu(w1(x)))``.

    w1 (gate) and w2 (value) project emb_dim -> ffn_hidden_dim; w3 projects back to emb_dim.
    """

    def __init__(self, params: Parameters):
        super().__init__()
        emb_dim, hidden = params.emb_dim, params.ffn_hidden_dim
        self.hidden_dim = hidden
        self.device = device = params.device
        self.w1 = nn.Linear(emb_dim, hidden).to(device)
        self.w2 = nn.Linear(emb_dim, hidden).to(device)
        self.w3 = nn.Linear(hidden, emb_dim).to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value = self.w2(x)
        gate = nn.functional.silu(self.w1(x))
        return self.w3(value * gate)




class PatchEmbeddings(nn.Module):
    """Linear patch embedding plus a fixed sinusoidal position encoding for square grayscale images."""

    def __init__(self, params: "Parameters"):
        super().__init__()
        self.device = params.device
        self.emb_dim = params.emb_dim
        self.patch_size = params.patch_size
        self.max_height = params.max_im_height
        self.max_width = params.max_im_width
        assert self.max_height == self.max_width, "Width and height should be equal"
        assert self.max_height % self.patch_size == 0, "Patch size and image dims should be compatible"
        self.linear = nn.Linear(self.patch_size**2, self.emb_dim).to(self.device)

    def patchify(self, image: torch.Tensor) -> torch.Tensor:
        """Split ``(batch, height, width)`` images into non-overlapping square patches.

        Returns a float tensor of shape ``(batch, num_patches, patch_size, patch_size)``, with
        patches in row-major order (left to right, then top to bottom). Height and width must be
        multiples of ``patch_size``; otherwise a RuntimeError is raised.
        """
        batch_size, height, width = image.size()
        p = self.patch_size
        grid = image.float().reshape(batch_size, height // p, p, width // p, p)
        # (batch, rows, p, cols, p) -> (batch, rows, cols, p, p): one p x p tile per (row, col).
        return grid.permute(0, 1, 3, 2, 4).reshape(batch_size, -1, p, p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed patches ``(batch, num_patches, p, p)`` into ``(batch, num_patches, emb_dim)``."""
        batch_size, seq_len, patch_height, patch_width = x.size()
        assert patch_width == patch_height, "Uniform patch size should be provided"
        patches = self.linear(x.reshape(batch_size, seq_len, patch_height * patch_width).to(torch.float32))

        # Sinusoidal encoding: even features sin(pos * w_k), odd features cos(pos * w_k).
        positions = torch.arange(seq_len, dtype=torch.float32, device=patches.device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.emb_dim, 2, dtype=torch.float32, device=patches.device)
            * (-math.log(10000.0) / self.emb_dim)
        )
        pe = torch.zeros(seq_len, self.emb_dim, device=patches.device)
        pe[:, 0::2] = torch.sin(positions * div_term)
        pe[:, 1::2] = torch.cos(positions * div_term)
        # Broadcast over the batch instead of materialising batch_size copies.
        return patches + pe.unsqueeze(0)



In [9]:
class Encoder(nn.Module):
    """Pre-norm encoder layer: self-attention, then SwiGLU, each with a residual connection.

    NOTE(review): one RMSnorm instance (shared weights) normalises the input of both sub-layers.
    Standard pre-norm blocks use two separate norms. Kept as-is so existing checkpoints load.
    """

    def __init__(self, params: "Parameters"):
        super().__init__()
        self.device = params.device
        self.emb_dim = params.emb_dim
        self.thresh = params.thresh
        self.norm = RMSnorm(self.emb_dim, self.device, self.thresh)
        self.attention = Attention(params)
        self.ffn = SwiGLu_Forward(params)

    def forward(self, x: torch.Tensor, cur_pos: Optional[int] = None) -> torch.Tensor:
        """Map ``x`` (batch, seq, emb_dim) to the same shape; ``cur_pos`` is forwarded to attention."""
        # No per-layer torch.cuda.empty_cache(): releasing the allocator's cache on every forward
        # only forces costly re-allocation and does not reduce peak memory.
        hidden = x + self.attention(self.norm(x), cur_pos)
        return hidden + self.ffn(self.norm(hidden))

class Transformer(nn.Module):
    """Two modality-specific encoder stacks (text and image) with a shared norm and output head.

    forward() returns:
      * ``cur_pos is None``: the final hidden states split into ``div_batch`` chunks along the
        batch dimension (the tuple produced by torch.chunk). No vocabulary projection is applied
        on this path; TokenFusion consumes these hidden states directly.
      * ``cur_pos`` given: vocabulary logits of shape ``(batch, 1, vocab_size)`` for one step.
    """

    def __init__(self, params: "Parameters"):
        super().__init__()
        self.layers_enc_text = nn.ModuleList()
        self.layers_enc_im = nn.ModuleList()
        for _ in range(params.num_layers):
            self.layers_enc_text.append(Encoder(params))
            self.layers_enc_im.append(Encoder(params))
        self.device = params.device
        self.emb_dim = params.emb_dim
        self.seq_len = params.max_seq_len
        self.vocab_size = params.vocab_size
        self.text_embeddings = nn.Embedding(self.vocab_size, self.emb_dim).to(self.device)
        self.thresh = params.thresh
        self.norm = RMSnorm(self.emb_dim, self.device, self.thresh)
        self.div_batch = params.div_batch
        self.patch_embeddings = PatchEmbeddings(params)
        self.linear = nn.Linear(self.emb_dim, self.vocab_size).to(self.device)

        self.max_height = params.max_im_height
        self.max_width = params.max_im_width
        self.patch_size = params.patch_size

    def forward(self, x: torch.Tensor, cur_pos: Optional[int], im_inc: bool):
        """Encode token ids (``im_inc=False``) or patched images (``im_inc=True``).

        See the class docstring for the two possible return values.
        """
        assert self.div_batch <= x.shape[0], "Batch serializer should not exceed tensor dimensions"
        if im_inc:
            res = self.patch_embeddings(x)
            layers = self.layers_enc_im
            one_token_msg = "Please pass one token at a time"
        else:
            res = self.text_embeddings(x)
            layers = self.layers_enc_text
            one_token_msg = "Pass one token at a time"

        if cur_pos is not None:
            assert res.shape[1] == 1, one_token_msg

        res = self.norm(res)
        for layer in layers:
            res = layer(res, cur_pos)

        if cur_pos is None:
            # The vocabulary logits are intentionally not computed here: nothing on this path
            # uses them, and a (batch, seq, vocab_size) tensor is very large.
            return torch.chunk(res, self.div_batch, dim=0)
        return self.linear(res)

class TokenFusion(nn.Module):
    """Encode text and image separately, then swap low-importance tokens between the streams.

    On the overlapping prefix of the two sequences, a small MLP scores every token. A token
    whose score is not above ``token_thresh`` is replaced by the token at the same position in
    the other modality. Positions beyond the shorter sequence are kept unchanged.
    """

    def __init__(self, params: "Parameters"):
        super().__init__()
        self.device = params.device
        self.emb_dim = params.emb_dim
        self.token_thresh = params.token_thresh
        self.hidden_dim = params.imp_layer_hidden
        self.thresh = params.thresh
        self.norm = RMSnorm(self.emb_dim, self.device, self.thresh)  # NOTE(review): unused in forward
        self.vocab_size = params.vocab_size
        self.seq_len = params.max_seq_len
        # The scorer reads each token's feature vector, so its input width is emb_dim. Sizing it
        # by max_seq_len only worked because both are 256 in this notebook's configuration.
        self.imp_layer1 = nn.Linear(self.emb_dim, self.hidden_dim).to(self.device)
        self.imp_layer2 = nn.Linear(self.hidden_dim, 1).to(self.device)
        self.sigmoid = nn.Sigmoid()
        self.transformer = Transformer(params)

    @staticmethod
    def _merge_chunks(chunks):
        """Concatenate batch chunks (as returned by torch.chunk) into one (batch, seq, dim) tensor."""
        if isinstance(chunks, torch.Tensor):
            return chunks
        # cat (not stack + squeeze) works for any chunk size, including uneven final chunks.
        return torch.cat(tuple(chunks), dim=0)

    def _keep_mask(self, tokens):
        """1.0 where a token's importance score exceeds token_thresh, else 0.0; shape (batch, seq, 1)."""
        scores = self.sigmoid(self.imp_layer2(self.imp_layer1(tokens)))
        # NOTE(review): the hard threshold is not differentiable, so imp_layer1/imp_layer2 get
        # no gradient from the training loss.
        return (scores > self.token_thresh).to(tokens.dtype)

    def forward(self, x: Optional[torch.Tensor], y: Optional[torch.Tensor], cur_pos: Optional[int], im_inc: bool):
        """Fuse text ``x`` and image ``y`` (cur_pos None), or run one decoding step of one modality.

        Returns ``(x_fin, y_fin)`` of shapes (batch, text_len, emb_dim) and
        (batch, image_len, emb_dim) on the fusion path; otherwise the transformer's output.
        """
        if cur_pos is not None:
            # Single-step decoding: only the requested modality is run.
            if im_inc:
                return self.transformer(y, cur_pos, im_inc)
            return self.transformer(x, cur_pos, im_inc)

        x = self._merge_chunks(self.transformer(x, cur_pos, False))
        y = self._merge_chunks(self.transformer(y, cur_pos, im_inc))
        overlap = min(x.shape[1], y.shape[1])
        x_fuse, x_rem = x[:, :overlap, :], x[:, overlap:, :]
        y_fuse, y_rem = y[:, :overlap, :], y[:, overlap:, :]

        keep_x = self._keep_mask(x_fuse)
        keep_y = self._keep_mask(y_fuse)
        x_fin = x_fuse * keep_x + y_fuse * (1 - keep_x)
        y_fin = y_fuse * keep_y + x_fuse * (1 - keep_y)

        # No blanket .squeeze(): it would drop the batch dimension when batch_size == 1.
        return torch.cat((x_fin, x_rem), dim=1), torch.cat((y_fin, y_rem), dim=1)




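Method 1: multi-GPU training with DistributedDataParallel (one worker process per GPU, launched with `mp.spawn`):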

In [10]:
def setup(rank, world_size):
    """Join the NCCL process group as ``rank`` using a single-node rendezvous on localhost:12355."""
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group created by setup()."""
    dist.destroy_process_group()

def train(rank, world_size, params):
    """DDP worker: train TokenFusion for 3 epochs on this rank's shard of the dataset.

    mp.spawn in main() starts one of these per GPU. The process group is always destroyed,
    even if training raises.

    NOTE(review): the loss compares softmax outputs with an Embedding of the token ids (soft
    targets over the sequence axis). It is not a standard language-modelling objective.
    NOTE(review): Medical_Data/collate_fn place tensors on params.device, not on ``rank``;
    with more than one GPU the inputs and model may end up on different devices.
    """
    setup(rank, world_size)
    try:
        torch.manual_seed(42)

        dataset = Medical_Data(params)
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        dataloader = DataLoader(dataset, batch_size=params.batch_size, sampler=sampler, collate_fn=dataset.collate_fn)
        model = TokenFusion(params).to(rank)
        model = DDP(model, device_ids=[rank])

        # Built once rather than per batch. patchify() only reads patch_size, and re-creating the
        # target embedding every step turned the targets into fresh random noise. The embedding
        # is indexed by token ids, so it is sized by vocab_size: sizing it by sequence length
        # failed with "index out of range" for any id >= seq_len. Frozen: it is not optimised.
        patchifier = PatchEmbeddings(params)
        target_embedding = nn.Embedding(params.vocab_size, params.emb_dim).to(params.device).requires_grad_(False)

        optimizer = AdamW(model.parameters(), lr=1e-5)
        model.train()
        for epoch in range(3):
            sampler.set_epoch(epoch)
            for texts, images in dataloader:
                patched_images = patchifier.patchify(images)
                optimizer.zero_grad()
                outputs = model(texts, patched_images, cur_pos=None, im_inc=True)
                probs = [F.softmax(outputs[0], dim=-1), F.softmax(outputs[1], dim=-1)]
                targets = target_embedding(texts)

                # Zero-pad shorter outputs and crop longer ones so the sequence axis matches the
                # targets (probability targets must have exactly the input's shape).
                seq_len = targets.shape[1]
                probs = [
                    F.pad(p, (0, 0, 0, seq_len - p.shape[1])) if p.shape[1] < seq_len else p[:, :seq_len]
                    for p in probs
                ]

                loss = (F.cross_entropy(probs[0], targets) + F.cross_entropy(probs[1], targets)) / 2.0
                loss.backward()
                optimizer.step()
                if rank == 0:
                    print(f"Epoch: {epoch}, Loss: {loss.item()}")
    finally:
        cleanup()

def get_num_gpus():
    """Number of visible CUDA devices, or 0 when CUDA is unavailable."""
    return torch.cuda.device_count() if torch.cuda.is_available() else 0

def main():
    """Build the training configuration and launch one DDP worker per visible GPU.

    NOTE(review): mp.spawn pickles ``train`` by reference. Functions defined in a notebook's
    __main__ often cannot be re-imported by spawned workers; Method 2 trains in-process instead.
    """
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    print(tokenizer.vocab_size)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    params = Parameters(device = device, use_cache = False, num_heads = 16, thresh = None, emb_dim = 256, max_seq_len = 256
                    ,ffn_hidden_dim = 512, batch_size = 8, div_batch = 8,
                    tokenizer = tokenizer, vocab_size = tokenizer.vocab_size+1,
                    max_im_width = 240, max_im_height = 240, num_layers = 1, patch_size = 16, dataset_name = "/content/drive/MyDrive/ct_scan_data",
                    token_thresh = 0.3, imp_layer_hidden = 512)

    world_size = get_num_gpus()
    print(f"Num of GPUs: {world_size}")
    if world_size == 0:
        # setup() uses the NCCL backend, which needs GPUs; with nprocs=0 mp.spawn would start
        # no workers and training would silently not happen.
        raise RuntimeError("Distributed training requires at least one CUDA device")
    mp.spawn(train, args=(world_size, params), nprocs=world_size, join=True)

main()

In [11]:
!tensorboard --logdir=runs

2024-05-25 08:44:10.278523: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-25 08:44:10.278594: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-25 08:44:10.280162: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.2 at http://localhost:6006/ (Press CTRL+C to quit)
^C


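Method 2: single-process training with TensorBoard logging; the trained weights are saved to `token_fusion_model.pth`: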

In [12]:
from torch.utils.tensorboard import SummaryWriter
import math

def _safe_exp(value: float) -> float:
    """Return exp(value), or +inf instead of raising OverflowError for very large values."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")


def train(params):
    """Train TokenFusion in a single process for 3 epochs, logging to TensorBoard.

    Logs the loss and exp(loss) (reported as "perplexity") per batch and per epoch via
    SummaryWriter, then saves the final state dict to ``token_fusion_model.pth``.

    NOTE(review): the loss compares softmax outputs with a fixed random embedding of the token
    ids (soft targets over the sequence axis). It is not a language-modelling loss, so the
    logged "perplexity" is only exp(loss).
    """
    torch.manual_seed(42)
    dataset = Medical_Data(params)
    dataloader = DataLoader(dataset, batch_size=params.batch_size, shuffle=True, collate_fn=dataset.collate_fn)
    model = TokenFusion(params).to(params.device)
    optimizer = AdamW(model.parameters(), lr=1e-5)
    model.train()

    # Built once rather than per batch. patchify() only needs patch_size, and re-creating the
    # target embedding every step re-randomised the targets. Frozen: it is not optimised.
    patchifier = PatchEmbeddings(params)
    target_embedding = nn.Embedding(params.vocab_size, params.emb_dim).to(params.device).requires_grad_(False)

    num_batches = len(dataloader)
    writer = SummaryWriter()
    try:
        for epoch in range(3):
            epoch_loss = 0.0
            for batch_idx, (texts, images) in enumerate(dataloader):
                patched_images = patchifier.patchify(images)
                optimizer.zero_grad()
                outputs = model(texts, patched_images, cur_pos=None, im_inc=True)
                probs = [F.softmax(outputs[0], dim=-1), F.softmax(outputs[1], dim=-1)]
                targets = target_embedding(texts)

                # Zero-pad shorter outputs and crop longer ones so the sequence axis matches the
                # targets (probability targets must have exactly the input's shape).
                seq_len = targets.shape[1]
                probs = [
                    F.pad(p, (0, 0, 0, seq_len - p.shape[1])) if p.shape[1] < seq_len else p[:, :seq_len]
                    for p in probs
                ]
                loss = (F.cross_entropy(probs[0], targets) + F.cross_entropy(probs[1], targets)) / 2.0
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                epoch_loss += loss_value
                global_step = epoch * num_batches + batch_idx
                perplexity = _safe_exp(loss_value)
                writer.add_scalar('Loss/train', loss_value, global_step)
                writer.add_scalar('Perplexity/train', perplexity, global_step)

                print(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss_value}, Perplexity: {perplexity}")

            # An empty loader would otherwise divide by zero.
            avg_epoch_loss = epoch_loss / num_batches if num_batches else float("nan")
            avg_epoch_perplexity = _safe_exp(avg_epoch_loss)
            writer.add_scalar('Loss/epoch', avg_epoch_loss, epoch)
            writer.add_scalar('Perplexity/epoch', avg_epoch_perplexity, epoch)

            print(f"Epoch: {epoch}, Average Loss: {avg_epoch_loss}, Average Perplexity: {avg_epoch_perplexity}")

        torch.save(model.state_dict(), 'token_fusion_model.pth')
        writer.add_text("Model Save", "Model saved as token_fusion_model.pth")
    finally:
        # Flush and close the event file even if training fails part-way.
        writer.close()
    print("Model saved as token_fusion_model.pth")

def main():
    """Load the tokenizer, build the single-GPU training configuration and run train()."""
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    print(tokenizer.vocab_size)
    config = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "use_cache": False,
        "num_heads": 16,
        "thresh": None,
        "emb_dim": 256,
        "max_seq_len": 256,
        "ffn_hidden_dim": 512,
        "batch_size": 8,
        "div_batch": 8,
        "tokenizer": tokenizer,
        "vocab_size": tokenizer.vocab_size + 1,
        "max_im_width": 240,
        "max_im_height": 240,
        "num_layers": 1,
        "patch_size": 16,
        "dataset_name": "/content/drive/MyDrive/ct_scan_data",
        "token_thresh": 0.3,
        "imp_layer_hidden": 512,
    }
    train(Parameters(**config))

main()


In [None]:
# Download the checkpoint written by train() (token_fusion_model.pth) from the Colab runtime.
from google.colab import files

files.download('token_fusion_model.pth')


Inference: load a trained checkpoint and generate text from text-only or image-plus-text prompts:


In [None]:
from tqdm import tqdm


def _sample_top_p(probs, p):
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    top_p_mask = cumulative_probs <= p
    top_p_mask[..., 0] = True
    filtered_probs = sorted_probs * top_p_mask
    filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(filtered_probs, num_samples=1)
    next_token = torch.gather(sorted_indices, -1, next_token)
    return next_token


def pad_im(img, max_height, max_width, device):
    """Convert a BGR image to grayscale and fit it to (max_height, max_width).

    The image is centre-padded with zeros, then its top-left window of that size is cropped.
    Returns a uint8 tensor on ``device``.
    """
    gray = cv2.cvtColor(np.array(img, dtype=np.uint8), cv2.COLOR_BGR2GRAY)
    h, w = gray.shape[:2]
    extra_h = max(0, max_height - h)
    extra_w = max(0, max_width - w)
    top, left = extra_h // 2, extra_w // 2
    bottom, right = extra_h - top, extra_w - left

    padded = cv2.copyMakeBorder(gray, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
    window = padded[:max_height, :max_width]
    return torch.tensor(np.array(window), dtype=torch.uint8).to(device)


def _parse_multimodal_prompt(prompt: str):
    """Split ``"image: <path> text: <text>"`` or ``"text: <text>"`` into ``(image_path or None, text)``."""
    if "image: " in prompt:
        # partition (unlike split + unpack) does not raise when " text: " is missing or repeated.
        image_part, _, text_part = prompt.partition(" text: ")
        return image_part.replace("image: ", ""), text_part
    return None, prompt.replace("text: ", "")


def _load_image_batch(image_paths, batch_size, params):
    """Return a float (batch_size, H, W) tensor on params.device: one image per prompt, then zeros.

    Row k belongs to prompt k, matching the token rows built in infer(). Prompts without an image
    and unused batch rows get all-zero images.

    Raises:
        FileNotFoundError: if an image path cannot be read by OpenCV.
    """
    device = params.device
    height, width = params.max_im_height, params.max_im_width
    rows = []
    for path in image_paths:
        if path is None:
            rows.append(torch.zeros(height, width, device=device))
            continue
        raw = cv2.imread(path)
        if raw is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        rows.append(pad_im(raw, height, width, device).float())
    # Padding rows go after the real prompts so images stay aligned with their token rows.
    rows.extend(torch.zeros(height, width, device=device) for _ in range(batch_size - len(rows)))
    return torch.stack(rows, dim=0)


def _pick_next_token(logits, temp, top_p):
    """Choose one token per row from the last-position logits; returns a (batch,) tensor.

    Uses nucleus sampling when ``temp > 0``, otherwise argmax.
    """
    last = logits[:, -1]
    if temp > 0:
        next_token = _sample_top_p(torch.softmax(last / temp, dim=-1), top_p)
    else:
        next_token = torch.argmax(last, dim=-1)
    return next_token.reshape(-1)


def _truncate_at_eos(token_ids, eos_id):
    """Drop ``eos_id`` and everything after its first occurrence."""
    if eos_id in token_ids:
        return token_ids[:token_ids.index(eos_id)]
    return token_ids


def infer(tokenizer, prompts: list[str], params: Parameters, model_path: Optional[str], model: TokenFusion, temp=0.3, top_p=0.8):
    """Generate tokens for each prompt from the text encoder and from the image encoder.

    Each prompt is ``"text: <text>"`` or ``"image: <path> text: <text>"``. At most
    ``params.batch_size`` prompts are accepted.

    Args:
        model_path: state-dict file to load into ``model``; when None, ``model`` is used as is.
        temp: sampling temperature; ``<= 0`` selects tokens greedily.
        top_p: nucleus threshold used when ``temp > 0``.

    Returns:
        ``(out_tokens, out_text)``. The first ``batch_size`` entries come from the image pass, the
        next ``batch_size`` from the text pass; rows beyond the number of prompts are padding.
        Each sequence is cut at its first EOS.

    NOTE(review): the model is fed one token (or one patch) per step and Attention's cache is
    disabled in this notebook's config, so each step sees no earlier context.
    """
    max_len = params.max_seq_len
    batch_size = params.batch_size
    device = params.device

    if model_path is not None:
        model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    if not prompts:
        return ([], [])

    parsed = [_parse_multimodal_prompt(prompt) for prompt in prompts]
    image_paths = [image_path for image_path, _ in parsed]
    prompt_token_ids = [tokenizer.encode(text) for _, text in parsed]
    assert len(prompt_token_ids) <= batch_size, f"Too many prompts, they should be less than or equal to {batch_size}"
    max_prompt_len = max(len(ids) for ids in prompt_token_ids)
    assert max_prompt_len <= max_len, f"Keep your prompt size below {max_len}"

    total_len = min(params.max_seq_len, max_len + max_prompt_len)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=device)
    # Track prompt positions explicitly: comparing tokens with pad_id would misclassify real
    # prompt tokens whose id equals the fallback pad id 0.
    prompt_tokens_mask = torch.zeros((batch_size, total_len), dtype=torch.bool, device=device)
    for row, ids in enumerate(prompt_token_ids):
        tokens[row, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)
        prompt_tokens_mask[row, :len(ids)] = True

    patched_images = PatchEmbeddings(params).patchify(_load_image_batch(image_paths, batch_size, params))
    num_patches = patched_images.size(1)

    with torch.no_grad():
        eos_reached = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for cur_pos in tqdm(range(1, total_len), desc='Generating tokens'):
            logits = model(tokens[:, cur_pos-1:cur_pos], None, cur_pos, False)
            next_token = _pick_next_token(logits, temp, top_p)
            # Inside the prompt, keep the given token instead of the generated one.
            next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == tokenizer.eos_token_id)
            if bool(eos_reached.all()):
                break

        # The image pass keeps its own EOS flags (reusing the text pass's flags could stop it
        # after one step) and integer token storage. It emits one token per patch, every patch.
        image_tokens = torch.zeros((batch_size, num_patches), dtype=torch.long, device=device)
        image_eos = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for patch_idx in tqdm(range(num_patches), desc='Generating tokens'):
            logits = model(None, patched_images[:, patch_idx:patch_idx+1], patch_idx, True)
            next_token = _pick_next_token(logits, temp, top_p)
            image_tokens[:, patch_idx] = next_token
            image_eos |= next_token == tokenizer.eos_token_id
            if bool(image_eos.all()):
                break

    out_tokens = []
    out_text = []
    for row_tokens in image_tokens.tolist() + tokens.tolist():
        kept = _truncate_at_eos(row_tokens, tokenizer.eos_token_id)
        out_tokens.append(kept)
        out_text.append(tokenizer.decode(kept))

    return (out_tokens, out_text)


device = "cuda" if torch.cuda.is_available() else "cpu"
# Other tokenizers tried: google-bert/bert-base-uncased (BertTokenizer) and
# Intel/llava-llama-4-8b (AutoTokenizer + AutoProcessor).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
print(tokenizer.vocab_size)
# NOTE(review): dataset_name is relative here, while the training cells use
# /content/drive/MyDrive/ct_scan_data. Confirm which copy is intended.
params = Parameters(device = device, use_cache = False, num_heads = 16, thresh = None, emb_dim = 256, max_seq_len = 256
                    ,ffn_hidden_dim = 512, batch_size = 8, div_batch = 8,
                    tokenizer = tokenizer, vocab_size = tokenizer.vocab_size+1,
                    max_im_width = 240, max_im_height = 240, num_layers = 1, patch_size = 16, dataset_name = "ct_scan_data",
                    token_thresh = 0.3, imp_layer_hidden = 512)

# Build the dataset once and reuse it for its collate_fn; each Medical_Data construction re-reads
# the CSV index.
medical_data = Medical_Data(params)
dataloader = DataLoader(
    medical_data,
    batch_size=params.batch_size,
    collate_fn=medical_data.collate_fn
)

# Move the whole model to the device, since infer() builds its inputs on params.device.
tk = TokenFusion(params).to(device)

# Change these queries to get your own results. infer() takes the checkpoint path before the model.
#print(infer(tokenizer, ["text: Hello? How are you?", "text: Medical Imaging right here",
#                        "image: LPMC/PMC908/PMC9088011_fimmu-13-881352-g002_undivided_1_1.jpg text: What does the image describe?"],
#            params, "token_fusion_model.pth", tk))