## 1. Design Model Architecture

### 1-1. LLM, ViT load

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPVisionModel, CLIPImageProcessor

def load_models_cuda():
    """Load Phi-2, its tokenizer, the CLIP ViT-B/32 vision encoder and its image processor.

    On a CUDA machine Phi-2 is loaded with bitsandbytes 4-bit quantization and placed on
    the current GPU while loading (via ``device_map``). Without CUDA, bitsandbytes 4-bit
    is unavailable, so an unquantized bfloat16 Phi-2 is loaded on the CPU instead.

    Returns:
        tuple: ``(llm, tokenizer, vision_encoder, image_processor)``.
    """
    from transformers import BitsAndBytesConfig

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

    if device == "cuda":
        # Quantized weights must be placed with `device_map` at load time: calling
        # `.to(device)` on a bitsandbytes-quantized model raises in many
        # transformers/bitsandbytes versions, and the bare `load_in_4bit=` kwarg is
        # deprecated in favour of `quantization_config`.
        llm = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            torch_dtype=torch.bfloat16,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
            device_map={"": torch.cuda.current_device()},
        )
    else:
        llm = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            torch_dtype=torch.bfloat16,
        ).to(device)

    vision_encoder = CLIPVisionModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        torch_dtype=torch.bfloat16
    ).to(device)

    image_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    return llm, tokenizer, vision_encoder, image_processor

llm, tokenizer, vision_encoder, image_processor = load_models_cuda()

### 1-2. Cross-Attention architecture (마지막 4개 레이어에만 적용)

In [None]:
import torch
import torch.nn as nn
from functools import partial

class CrossAttention(nn.Module):
    """Text-to-image cross-attention followed by a residual connection and LayerNorm.

    Text hidden states are the queries; image features are the keys and values.
    Inputs are batch-first: ``(batch, seq, model_dims)``.

    Args:
        model_dims: Feature size shared by the text and image inputs.
        num_heads: Number of attention heads (``nn.MultiheadAttention`` requires it
            to divide ``model_dims``).
        dtype: Parameter dtype for both the attention and the LayerNorm, so the module
            can run directly on inputs of that dtype. Defaults to ``torch.bfloat16``.
    """

    def __init__(self, model_dims: int, num_heads: int, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=model_dims,
            num_heads=num_heads,
            batch_first=True,
            dtype=dtype,
        )
        # Same dtype as the attention: F.layer_norm does not accept mismatched
        # input/weight dtypes outside autocast.
        self.layer_norm = nn.LayerNorm(model_dims, dtype=dtype)

    def forward(self, text_features, image_features):
        """Return ``LayerNorm(text_features + CrossAttn(text -> image))``."""
        attn_output, _ = self.attention(text_features, image_features, image_features)
        # Residual connection keeps the original text signal and stabilizes training.
        return self.layer_norm(text_features + attn_output)

class MultimodalPhi2(nn.Module):
    """PEFT-wrapped Phi-2 with image cross-attention injected into its last 4 decoder layers.

    CLIP patch features are projected to the LLM hidden size and cached on the instance
    for the duration of one forward pass. Forward hooks on the ``self_attn`` modules of
    the last 4 decoder layers replace each attention output with
    ``CrossAttention(attn_output, cached_image_features)``.
    """

    def __init__(self, peft_llm, vision_encoder):
        super().__init__()

        self.vision_encoder = vision_encoder
        self.llm = peft_llm

        target_dtype = self.llm.dtype
        model_dims = self.llm.config.hidden_size
        vit_dims = self.vision_encoder.config.hidden_size
        num_heads = self.llm.config.num_attention_heads
        num_llm_layers = self.llm.config.num_hidden_layers

        # Only the last 4 decoder layers receive image cross-attention.
        self.target_layers = range(num_llm_layers - 4, num_llm_layers)

        # Maps ViT patch features (vit_dims) to the LLM hidden size (model_dims).
        self.vision_projection = nn.Linear(vit_dims, model_dims)
        self.vision_projection.to(device=self.llm.device, dtype=target_dtype)

        # One CrossAttention per target layer; ModuleDict keys must be strings.
        self.cross_attentions = nn.ModuleDict({
            str(i): CrossAttention(model_dims, num_heads) for i in self.target_layers
        })

        # Projected image features for the current forward pass, read by the hooks.
        self.image_features_cache = None

        for layer_idx in self.target_layers:
            # For a PEFT-wrapped model the decoder layers sit one `.model` deeper.
            layer = self.llm.model.model.layers[layer_idx]
            layer.self_attn.register_forward_hook(
                partial(self.cross_attention_hook, layer_idx=layer_idx)
            )

        self.cross_attentions.to(device=self.llm.device, dtype=target_dtype)

    def cross_attention_hook(self, module, input, output, layer_idx):
        """Forward hook mixing cached image features into a ``self_attn`` output.

        ``output`` is the hooked module's return tuple; its first element (the hidden
        states) is replaced by the cross-attention result, the rest is passed through.
        If no image features are cached (e.g. the LLM is called on its own), the output
        is returned unchanged instead of failing inside the attention call.
        """
        if self.image_features_cache is None:
            return output

        hidden_states = output[0]
        cross_attn_output = self.cross_attentions[str(layer_idx)](
            hidden_states, self.image_features_cache
        )
        return (cross_attn_output,) + output[1:]

    def forward(self, input_ids: torch.Tensor, pixel_values: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor = None):
        """Encode the images, then run the LLM with image-aware hooks active.

        Returns the LLM's output object (including ``loss`` when ``labels`` is given).
        """
        # 1. Encode and project the images; the hooks read this cache.
        image_outputs = self.vision_encoder(pixel_values)
        image_patch_features = image_outputs.last_hidden_state
        self.image_features_cache = self.vision_projection(
            image_patch_features.to(self.llm.dtype)
        )

        try:
            # 2. Text forward pass; the hooks inject cross-attention.
            outputs = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
        finally:
            # 3. Always drop the cache, even if the LLM forward raises (e.g. CUDA OOM),
            # so stale features are never reused and their graph can be freed.
            self.image_features_cache = None

        return outputs

## 2. Data Processing

### 2-1. 데이터, 전처리 도구 불러오기

In [None]:
from datasets import load_dataset

# Load the MS-COCO captions dataset (its 'train' split is flattened below).
dataset = load_dataset("clip-benchmark/wds_mscoco_captions2017")

# Load the ViT image processor and the Phi-2 tokenizer used for preprocessing.
from transformers import CLIPImageProcessor, AutoTokenizer
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

### 2-2. 데이터 전처리

In [None]:
from datasets import Dataset
from tqdm.auto import tqdm 

# --- 1. Generator that streams (image, single caption) pairs, with a tqdm progress bar ---
def flatten_generator(dataset):
    """Yield one ``{'jpg': image, 'caption': text}`` record per caption of each item.

    Each item's ``'txt'`` field holds several captions separated by newlines; the item's
    image is repeated for every caption. Surrounding whitespace is stripped and blank
    lines (e.g. from a trailing newline) are skipped, so no empty-caption samples are
    produced.
    """
    print("Generator started. Processing data one by one...")
    for item in tqdm(dataset, desc="Flattening dataset"):
        for line in item['txt'].split('\n'):
            caption = line.strip()
            if caption:
                yield {'jpg': item['jpg'], 'caption': caption}

# --- 2. Build a new dataset from the generator (memory-efficient: rows are streamed) ---
print("Creating a new dataset from the generator (memory-efficient)...")
expanded_dataset = Dataset.from_generator(flatten_generator, gen_kwargs={"dataset": dataset['train']})
print(f"Dataset created with {len(expanded_dataset)} samples.")

# padding="max_length" in preprocessing needs a pad token; if the tokenizer has none,
# fall back to the EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer's pad_token has been set to eos_token.")

# --- 3. Preprocessing function applied with .map() ---
def preprocess_function(examples):
    """Convert a batch of ``{'jpg', 'caption'}`` rows into model inputs.

    Images go through the CLIP ``image_processor``; captions are tokenized with the
    Phi-2 ``tokenizer`` and padded/truncated to 128 tokens.

    Returns:
        The image-processor output extended with ``input_ids``, ``attention_mask`` and
        ``labels``. In ``labels`` every padding position is -100 (ignored by the
        Hugging Face causal-LM loss). With right padding, the first padding slot after
        each caption keeps an EOS target so the model still learns where a caption ends.
    """
    images = [img.convert("RGB") for img in examples['jpg']]
    captions = examples['caption']

    model_inputs = image_processor(images, return_tensors="pt")
    text_inputs = tokenizer(captions, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    input_ids = text_inputs['input_ids']
    attention_mask = text_inputs['attention_mask']

    # Exclude padding from the loss; otherwise most of each 128-token target would be
    # "predict the pad token" (which is the EOS token when the fallback above applied).
    labels = input_ids.masked_fill(attention_mask == 0, -100)

    if tokenizer.padding_side == "right":
        caption_lengths = attention_mask.sum(dim=1)
        rows = torch.nonzero(caption_lengths < labels.shape[1]).squeeze(1)
        # Labels are shifted inside the model, so an EOS label at index `length`
        # is predicted from the caption's last token.
        labels[rows, caption_lengths[rows]] = tokenizer.eos_token_id

    model_inputs['input_ids'] = input_ids
    model_inputs['attention_mask'] = attention_mask
    model_inputs['labels'] = labels
    return model_inputs


# --- 4. Run preprocessing with .map() (datasets shows a progress bar automatically) ---
print("\nStarting memory-efficient preprocessing with .map()...")

# batched=True hands preprocess_function 100 rows at a time; the original columns
# ('jpg', 'caption') are removed so only the model inputs remain.
final_dataset = expanded_dataset.map(
    function=preprocess_function,
    batched=True,
    batch_size=100,
    remove_columns=expanded_dataset.column_names
)

print("Preprocessing and caching complete!")
print("\n--- Final Processed Dataset Info ---")
print(final_dataset)

### 2-3. 데이터 저장하기

In [None]:
# Persist the preprocessed dataset so later cells can reload it with load_from_disk.
final_save_path = "./my_final_dataset"
print(f"Saving the final processed dataset to '{final_save_path}'...")
final_dataset.save_to_disk(final_save_path)
print("Final dataset saved.")

### 2-4. 데이터 불러오기

In [None]:
from datasets import load_from_disk

final_save_path = "./my_final_dataset"
print(f"Loading the final preprocessed dataset from '{final_save_path}'...")

# Reload the saved, already-preprocessed dataset (no re-processing needed).
final_dataset = load_from_disk(final_save_path)

print("Dataset loaded instantly from disk!")
print(final_dataset)

## 3. Model training

### 3-1. Model loading

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPVisionModel, CLIPImageProcessor


def load_models_cuda():
    """Load Phi-2, its tokenizer, the CLIP ViT-B/32 vision encoder and its image processor.

    On a CUDA machine Phi-2 is loaded with bitsandbytes 4-bit quantization and placed on
    the current GPU while loading (via ``device_map``). Without CUDA, bitsandbytes 4-bit
    is unavailable, so an unquantized bfloat16 Phi-2 is loaded on the CPU instead.

    Returns:
        tuple: ``(llm, tokenizer, vision_encoder, image_processor)`` — the same order as
        the other ``load_models_cuda`` definitions in this notebook, which the
        hyper-parameter search loop relies on when unpacking.
    """
    from transformers import BitsAndBytesConfig

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

    if device == "cuda":
        # Quantized weights must be placed with `device_map` at load time: calling
        # `.to(device)` on a bitsandbytes-quantized model raises in many
        # transformers/bitsandbytes versions, and the bare `load_in_4bit=` kwarg is
        # deprecated in favour of `quantization_config`.
        llm = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            torch_dtype=torch.bfloat16,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
            device_map={"": torch.cuda.current_device()},
        )
    else:
        llm = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            torch_dtype=torch.bfloat16,
        ).to(device)

    vision_encoder = CLIPVisionModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        torch_dtype=torch.bfloat16
    ).to(device)

    image_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    return llm, tokenizer, vision_encoder, image_processor

llm, tokenizer, vision_encoder, image_processor = load_models_cuda()

### 3-2. Connect LoRA example

In [None]:
from peft import get_peft_model, LoraConfig, TaskType

# LoRA configuration
lora_config = LoraConfig(
    r=16, # rank of the low-rank LoRA update matrices
    lora_alpha=32, # controls the strength of the LoRA update
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)
peft_llm = get_peft_model(llm, lora_config)
peft_llm.print_trainable_parameters()

device = "cuda" if torch.cuda.is_available() else "cpu"

# Wrap the LoRA-adapted LLM and the CLIP encoder into the multimodal model.
multi_modal_model = MultimodalPhi2(peft_llm, vision_encoder).to(device)

### 3-3. Find the best parameters

In [None]:
import wandb
# Authenticate with Weights & Biases; the training cells log runs via wandb.init / wandb.log.
wandb.login()

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW 
from transformers import get_linear_schedule_with_warmup
from transformers import default_data_collator
from tqdm.auto import tqdm
import wandb 
from peft import get_peft_model, LoraConfig, TaskType
import gc


def model_train_eval(llm, vision_encoder, dataset, parameters, device):
    """Train and validate MultimodalPhi2 for one hyper-parameter setting, logging to W&B.

    Args:
        llm: Base causal LM; ``get_peft_model`` attaches the LoRA adapters to it.
        vision_encoder: CLIP vision model. It is frozen here, so only parameters such as
            the LoRA adapters, ``vision_projection`` and ``cross_attentions`` are trained.
        dataset: Preprocessed dataset formatted as torch tensors with ``pixel_values``,
            ``input_ids``, ``attention_mask`` and ``labels`` columns.
        parameters: Run configuration. Required keys: ``"LoRA_R"``, ``"learning_rate"``,
            ``"batch_size"``, ``"num_epochs"``, ``"architecture"``. Optional keys:
            ``"train_size"`` (default 5000) and ``"eval_size"`` (default 500). The first
            ``train_size`` rows are used for training and the next ``eval_size`` rows
            for validation.
        device: Device (or device string) to train on.
    """
    # Everything below reads the caller-supplied configuration, not a global.
    config = parameters
    device = torch.device(device)
    train_size = config.get("train_size", 5000)
    eval_size = config.get("eval_size", 500)

    lora_config = LoraConfig(
        r=config['LoRA_R'],  # rank of the low-rank LoRA update matrices
        lora_alpha=config['LoRA_R'] * 2,  # strength of the LoRA update (alpha = 2 * r)
        target_modules=["q_proj", "v_proj", "k_proj", "dense"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM
    )
    peft_llm = get_peft_model(llm, lora_config)
    model = MultimodalPhi2(peft_llm, vision_encoder).to(device)
    # Keep CLIP out of the optimizer: only adapters/projection/cross-attention are tuned.
    model.vision_encoder.requires_grad_(False)

    wandb.init(
        project="Making-Multimodal-Models",
        name=f"{config['architecture']} with LoRA R={config['LoRA_R']} LoRA alpha={config['LoRA_R']*2} lr={config['learning_rate']}",
        config=config,
    )
    try:
        train_subset = dataset.select(range(train_size))
        eval_subset = dataset.select(range(train_size, train_size + eval_size))

        # default_data_collator stacks the tensor columns into batch tensors.
        train_dataloader = DataLoader(train_subset, batch_size=config["batch_size"], collate_fn=default_data_collator, shuffle=True)
        # Validation order does not change the averaged loss; keep it deterministic.
        eval_dataloader = DataLoader(eval_subset, batch_size=config["batch_size"], collate_fn=default_data_collator, shuffle=False)

        # Optimize only parameters that require gradients.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = AdamW(trainable_params, lr=config["learning_rate"])

        num_training_steps = config["num_epochs"] * len(train_dataloader)
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=0,  # no warm-up phase
            num_training_steps=num_training_steps,  # linear decay over all training steps
        )
        print("Optimizer and Scheduler have been set up.")

        print(f"\n--- Starting Training for {config['num_epochs']} epoch(s) ---")
        for epoch in range(config["num_epochs"]):
            # Training loop
            model.train()
            progress_bar = tqdm(train_dataloader, desc="Training")
            for step, batch in enumerate(progress_bar):
                outputs = model(
                    input_ids=batch['input_ids'].to(device),
                    pixel_values=batch['pixel_values'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    labels=batch['labels'].to(device),
                )
                loss = outputs.loss

                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                loss_value = loss.item()
                if step % 10 == 0:
                    wandb.log({"train/loss": loss_value})
                progress_bar.set_postfix({"loss": loss_value})

            # Validation loop
            model.eval()
            eval_loss_total = 0.0
            print(f"\n--- Validating Epoch {epoch + 1} ---")
            with torch.no_grad():
                for batch in tqdm(eval_dataloader, desc="Validation"):
                    batch = {k: v.to(device) for k, v in batch.items()}
                    eval_loss_total += model(**batch).loss.item()

            avg_eval_loss = eval_loss_total / max(len(eval_dataloader), 1)
            wandb.log({
                "eval/loss": avg_eval_loss,
                "epoch": epoch + 1
            })

        print("Training complete!")
    finally:
        # Close the W&B run even if training fails, so the next grid point starts cleanly.
        wandb.finish()

if __name__ == "__main__":
    import itertools

    # Return torch tensors for the columns the model consumes.
    columns_to_tensorize = ['pixel_values', 'input_ids', 'attention_mask', 'labels']
    final_dataset.set_format(type='torch', columns=columns_to_tensorize)

    parameters = {
        "learning_rate": [5e-5, 3e-5, 1e-5],
        "LoRA_R": [8, 16, 32],
    }
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Grid search over every (learning rate, LoRA rank) pair.
    for learning_rate, lora_r in itertools.product(parameters['learning_rate'], parameters['LoRA_R']):
        # Fresh models per run: get_peft_model attaches adapters to the loaded LLM.
        # Expects load_models_cuda() to return (llm, tokenizer, vision_encoder, image_processor).
        llm, tokenizer, vision_encoder, image_processor = load_models_cuda()

        config = {"num_epochs": 1,
                  "batch_size": 4,
                  "learning_rate": learning_rate,
                  "LoRA_R": lora_r,
                  "architecture": "Cross-Attention Multimodal Phi-2",
                  "dataset": "clip-benchmark/wds_mscoco_captions2017"}
        model_train_eval(llm, vision_encoder, final_dataset, config, device)

        # Release GPU memory before loading the next set of models.
        del llm, vision_encoder, image_processor, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    print("Model test complete!")

# train.py

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW 
import torch.nn as nn
from functools import partial
from transformers import get_linear_schedule_with_warmup
from transformers import default_data_collator
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPVisionModel, CLIPImageProcessor
from tqdm.auto import tqdm
from datasets import load_from_disk
import wandb 
from peft import get_peft_model, LoraConfig, TaskType
import os

#------------------------------------------- Load models
def load_models_cuda():
    """Load Phi-2, its tokenizer, the CLIP ViT-B/32 vision encoder and its image processor.

    On a CUDA machine Phi-2 is loaded with bitsandbytes 4-bit quantization and placed on
    the current GPU while loading (via ``device_map``). Without CUDA, bitsandbytes 4-bit
    is unavailable, so an unquantized bfloat16 Phi-2 is loaded on the CPU instead.

    Returns:
        tuple: ``(llm, tokenizer, vision_encoder, image_processor)``.
    """
    from transformers import BitsAndBytesConfig

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

    if device == "cuda":
        # Quantized weights must be placed with `device_map` at load time: calling
        # `.to(device)` on a bitsandbytes-quantized model raises in many
        # transformers/bitsandbytes versions, and the bare `load_in_4bit=` kwarg is
        # deprecated in favour of `quantization_config`.
        llm = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            torch_dtype=torch.bfloat16,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
            device_map={"": torch.cuda.current_device()},
        )
    else:
        llm = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            torch_dtype=torch.bfloat16,
        ).to(device)

    vision_encoder = CLIPVisionModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        torch_dtype=torch.bfloat16
    ).to(device)

    image_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    return llm, tokenizer, vision_encoder, image_processor

llm, tokenizer, vision_encoder, image_processor = load_models_cuda()

#------------------------------------------- Model architecture
class CrossAttention(nn.Module):
    """Text-to-image cross-attention followed by a residual connection and LayerNorm.

    Text hidden states are the queries; image features are the keys and values.
    Inputs are batch-first: ``(batch, seq, model_dims)``.

    Args:
        model_dims: Feature size shared by the text and image inputs.
        num_heads: Number of attention heads (``nn.MultiheadAttention`` requires it
            to divide ``model_dims``).
        dtype: Parameter dtype for both the attention and the LayerNorm, so the module
            can run directly on inputs of that dtype. Defaults to ``torch.bfloat16``.
    """

    def __init__(self, model_dims: int, num_heads: int, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=model_dims,
            num_heads=num_heads,
            batch_first=True,
            dtype=dtype,
        )
        # Same dtype as the attention: F.layer_norm does not accept mismatched
        # input/weight dtypes outside autocast.
        self.layer_norm = nn.LayerNorm(model_dims, dtype=dtype)

    def forward(self, text_features, image_features):
        """Return ``LayerNorm(text_features + CrossAttn(text -> image))``."""
        attn_output, _ = self.attention(text_features, image_features, image_features)
        return self.layer_norm(text_features + attn_output)

class MultimodalPhi2(nn.Module):
    """PEFT-wrapped Phi-2 with image cross-attention injected into its last 4 decoder layers.

    CLIP patch features are projected to the LLM hidden size and cached on the instance
    for the duration of one forward pass. Forward hooks on the ``self_attn`` modules of
    the last 4 decoder layers replace each attention output with
    ``CrossAttention(attn_output, cached_image_features)``.
    """

    def __init__(self, peft_llm, vision_encoder):
        super().__init__()

        self.vision_encoder = vision_encoder
        self.llm = peft_llm

        target_dtype = self.llm.dtype
        model_dims = self.llm.config.hidden_size
        vit_dims = self.vision_encoder.config.hidden_size
        num_heads = self.llm.config.num_attention_heads
        num_llm_layers = self.llm.config.num_hidden_layers

        # Only the last 4 decoder layers receive image cross-attention.
        self.target_layers = range(num_llm_layers - 4, num_llm_layers)

        # Maps ViT patch features (vit_dims) to the LLM hidden size (model_dims).
        self.vision_projection = nn.Linear(vit_dims, model_dims)
        self.vision_projection.to(device=self.llm.device, dtype=target_dtype)

        # One CrossAttention per target layer; ModuleDict keys must be strings.
        self.cross_attentions = nn.ModuleDict({
            str(i): CrossAttention(model_dims, num_heads) for i in self.target_layers
        })

        # Projected image features for the current forward pass, read by the hooks.
        self.image_features_cache = None

        for layer_idx in self.target_layers:
            # For a PEFT-wrapped model the decoder layers sit one `.model` deeper.
            layer = self.llm.model.model.layers[layer_idx]
            layer.self_attn.register_forward_hook(
                partial(self.cross_attention_hook, layer_idx=layer_idx)
            )

        self.cross_attentions.to(device=self.llm.device, dtype=target_dtype)

    def cross_attention_hook(self, module, input, output, layer_idx):
        """Forward hook mixing cached image features into a ``self_attn`` output.

        ``output`` is the hooked module's return tuple; its first element (the hidden
        states) is replaced by the cross-attention result, the rest is passed through.
        If no image features are cached (e.g. the LLM is called on its own), the output
        is returned unchanged instead of failing inside the attention call.
        """
        if self.image_features_cache is None:
            return output

        hidden_states = output[0]
        cross_attn_output = self.cross_attentions[str(layer_idx)](
            hidden_states, self.image_features_cache
        )
        return (cross_attn_output,) + output[1:]

    def forward(self, input_ids: torch.Tensor, pixel_values: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor = None):
        """Encode the images, then run the LLM with image-aware hooks active.

        Returns the LLM's output object (including ``loss`` when ``labels`` is given).
        """
        # 1. Encode and project the images; the hooks read this cache.
        image_outputs = self.vision_encoder(pixel_values)
        image_patch_features = image_outputs.last_hidden_state
        self.image_features_cache = self.vision_projection(
            image_patch_features.to(self.llm.dtype)
        )

        try:
            # 2. Text forward pass; the hooks inject cross-attention.
            outputs = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
        finally:
            # 3. Always drop the cache, even if the LLM forward raises (e.g. CUDA OOM),
            # so stale features are never reused and their graph can be freed.
            self.image_features_cache = None

        return outputs
    
#------------------------------------------- Load datasets    
final_save_path = "./my_final_dataset"
# NOTE(review): presumably preprocessed the same way as my_final_dataset — it must
# provide the pixel_values / input_ids / attention_mask / labels columns used below.
eval_save_path = "./test_dataset"
print(f"Loading the final preprocessed dataset from '{final_save_path}'...")

# Reload the saved, already-preprocessed datasets.
dataset = load_from_disk(final_save_path)
eval_dataset = load_from_disk(eval_save_path)

#------------------------------------------- Config
# Single run configuration used by the setup and training loop below.
config = {"num_epochs": 5,
          "batch_size": 8,
          "learning_rate": 5e-5,
          "LoRA_R": 32,
          "LoRA_alpha": 64,
          "architecture": "Cross-Attention Multimodal Phi-2",
          "dataset": "clip-benchmark/wds_mscoco_captions2017"}

#------------------------------------------- Model training

device = "cuda" if torch.cuda.is_available() else "cpu"

# Return torch tensors for the columns the model consumes.
columns_to_tensorize = ['pixel_values', 'input_ids', 'attention_mask', 'labels']
dataset.set_format(type='torch', columns=columns_to_tensorize)
eval_dataset.set_format(type='torch', columns=columns_to_tensorize)

lora_config = LoraConfig(
    r=config['LoRA_R'], # rank of the low-rank LoRA update matrices
    lora_alpha=config['LoRA_alpha'], # controls the strength of the LoRA update
    target_modules=["q_proj", "v_proj", "k_proj", "dense"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)
peft_llm = get_peft_model(llm, lora_config)
model = MultimodalPhi2(peft_llm, vision_encoder).to(device)

# Only the LoRA adapters, vision_projection and cross_attentions are saved after each
# epoch, so keep the CLIP encoder frozen; otherwise its weights would be updated during
# training and then lost when the saved checkpoint is reloaded.
model.vision_encoder.requires_grad_(False)

wandb.init(
        project="final-multimodal-training",
        name="multimodal_phi2_training",
        config=config,
    )

# default_data_collator stacks the tensor columns into batch tensors.
train_dataloader = DataLoader(dataset, batch_size=config["batch_size"], collate_fn=default_data_collator, shuffle=True)
# Validation order does not change the averaged loss; keep it deterministic.
eval_dataloader = DataLoader(eval_dataset, batch_size=config["batch_size"], collate_fn=default_data_collator, shuffle=False)

# Optimize only parameters that require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=config["learning_rate"])

num_training_steps = config["num_epochs"] * len(train_dataloader)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0, # no warm-up phase
    num_training_steps=num_training_steps, # linear decay over all training steps
)
print("Optimizer and Scheduler have been set up.")

# Start training
print(f"\n--- Starting Training for {config['num_epochs']} epoch(s) ---")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

save_directory = "./save_model"

# Training loop: train, validate, and save a checkpoint every epoch.
for epoch in range(config["num_epochs"]):
    model.train()
    progress_bar = tqdm(train_dataloader, desc="Training")
    for step, batch in enumerate(progress_bar):
        outputs = model(
            input_ids=batch['input_ids'].to(device),
            pixel_values=batch['pixel_values'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            labels=batch['labels'].to(device),
        )
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        if step % 10 == 0:
            wandb.log({"train/loss": loss_value})
        progress_bar.set_postfix({"loss": loss_value})

    # Validation loop
    model.eval()
    eval_loss_total = 0.0
    print(f"\n--- Validating Epoch {epoch + 1} ---")
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Validation"):
            batch = {k: v.to(device) for k, v in batch.items()}
            eval_loss_total += model(**batch).loss.item()

    avg_eval_loss = eval_loss_total / max(len(eval_dataloader), 1)
    wandb.log({
        "eval/loss": avg_eval_loss,
        "epoch": epoch + 1
    })

    # Save the trainable parts: LoRA adapters, vision projection, cross-attention.
    epoch_dir = os.path.join(save_directory, f"{epoch}epoch")
    os.makedirs(epoch_dir, exist_ok=True)

    lora_save_path = os.path.join(epoch_dir, f"llm_adapters{epoch}")
    model.llm.save_pretrained(lora_save_path)
    print(f"LoRA adapters saved to {lora_save_path}")

    vision_projection_path = os.path.join(epoch_dir, f"vision_projection{epoch}.pt")
    torch.save(model.vision_projection.state_dict(), vision_projection_path)
    print(f"Vision projection saved to {vision_projection_path}")

    cross_attentions_path = os.path.join(epoch_dir, f"cross_attentions{epoch}.pt")
    torch.save(model.cross_attentions.state_dict(), cross_attentions_path)
    print(f"Cross attentions saved to {cross_attentions_path}")

# Close the W&B run explicitly once training is done.
wandb.finish()


## 4. Test

In [None]:
import pandas as pd

# Test questions; the inference cell reads its img_path, Question and A-D columns.
test_dataset = pd.read_csv("./open/test.csv")

In [None]:
## 추론
import matplotlib.pyplot as plt
from PIL import Image
from torch.functional import F

def inference(model, tokenizer, image_processor, prompt, image_path):
    """Greedy-decode a continuation of ``prompt`` conditioned on the image at ``image_path``.

    The image is encoded and projected exactly as in ``MultimodalPhi2.forward`` and
    placed in ``model.image_features_cache`` so the cross-attention hooks use it during
    ``model.llm.generate``. The cache is cleared afterwards, even on error.

    Returns:
        str: The generated text with the prompt tokens removed, whitespace-stripped.
    """
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print("Tokenizer pad_token is set to eos_token.")

    prompt_tokens = tokenizer(prompt, return_tensors="pt")
    input_ids = prompt_tokens.input_ids.to(device)
    attention_mask = prompt_tokens.attention_mask.to(device)

    # Close the file handle once the RGB copy has been made.
    with Image.open(image_path) as raw_image:
        pil_image = raw_image.convert("RGB")
    image = image_processor(images=pil_image, return_tensors="pt").pixel_values.to(device)

    try:
        with torch.no_grad():
            image_patch_features = model.vision_encoder(image).last_hidden_state
            # Same dtype cast as MultimodalPhi2.forward before the projection.
            model.image_features_cache = model.vision_projection(
                image_patch_features.to(model.llm.dtype)
            )
            generated_ids = model.llm.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,  # maximum number of newly generated tokens
                do_sample=False,     # greedy decoding (sampling knobs like temperature/top_p do not apply)
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id
            )
    finally:
        model.image_features_cache = None

    # Drop the echoed prompt tokens; keep only the newly generated part.
    input_token_len = input_ids.shape[1]
    generated_text_ids = generated_ids[:, input_token_len:]

    generated_text = tokenizer.batch_decode(generated_text_ids, skip_special_tokens=True)[0]
    return generated_text.strip()

# Run inference on one multiple-choice test question and show the image.

index = 1
row = test_dataset.loc[index]

image_path = "./open/" + row['img_path']

# "<Question> A.<a> B.<b> C.<c> D.<d>"
question = " ".join([row['Question'], *(f"{option}.{row[option]}" for option in "ABCD")])
# Built line by line so no source-code indentation or doubled newlines leak into the prompt.
prompt = (
    "ROLE: You're an ASSISTANT, I give you an image and a question, and you answer it for me.\n"
    "RULE: Please select only one correct answer from A,B,C,D\n"
    f"USER: {question}\n"
    "ASSISTANT:"
)

result = inference(model, tokenizer, image_processor, prompt, image_path)

plt.imshow(plt.imread(image_path))
plt.axis('off')
plt.show()

print(f"Question: {question}")
print(f"Generated text: {result}")