In [None]:
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid
import sys
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from moellava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from moellava.conversation import conv_templates, SeparatorStyle
from moellava.model.builder import load_pretrained_model
from moellava.utils import disable_torch_init
from moellava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from moellava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image
import math

In [None]:
# Load the fine-tuned checkpoint once; below it is only run under
# torch.inference_mode to extract hidden states for the classifier.
model_path = "./checkpoints-clip336/llavastablelm-1.6b-finetune"
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,
    model_name,
    device_map='auto',
)

In [None]:
import json
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
import torch.nn.functional as F

# Load the training annotations; every record must carry a 'modality' field.
with open('Your json files with modality for training.', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Map each modality name to a class index. Sorting makes the mapping deterministic:
# iterating a raw set of strings follows hash order, which is randomized per Python
# process (PYTHONHASHSEED), so without sorting a saved classifier could later be
# paired with a different index -> modality mapping.
unique_modalities = sorted({item['modality'] for item in data})
label_map = {modality: idx for idx, modality in enumerate(unique_modalities)}
# e.g. {'CT': 0, 'MRI': 1, 'X-Ray': 2, 'pathology': 3}

class PVQADataset(Dataset):
    """Map-style dataset of (input_ids, pixel_values, label) triples for modality classification.

    Each record in ``data`` is read for ``conversations[0]['value']`` (the question; a
    leading ``<image>`` line is stripped and re-inserted in the prompt format used here),
    ``image`` (file name joined onto ``image_folder``) and ``modality`` (looked up in
    ``label_map``).

    Args:
        data: list of annotation records.
        tokenizer: tokenizer passed to ``tokenizer_image_token``.
        image_processor: mapping whose ``'image'`` entry exposes ``preprocess``.
        max_length: stored on the instance. NOTE(review): not currently used to
            truncate prompts.
        label_map: modality name -> integer class index.
        image_folder: directory joined in front of ``item['image']`` (placeholder
            default -- replace with the real image root). An absolute ``item['image']``
            overrides it, per ``os.path.join`` semantics.
        conv_mode: key into ``conv_templates`` used to build the prompt.
        mm_use_im_start_end: whether to wrap the image token with start/end tokens.
            ``None`` (default) reads it from the notebook's loaded ``model.config``.
    """

    def __init__(self, data, tokenizer, image_processor, max_length, label_map,
                 image_folder="image_path", conv_mode="stablelm", mm_use_im_start_end=None):
        self.data = data
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.label_map = label_map
        self.image_folder = image_folder
        self.conv_mode = conv_mode
        if mm_use_im_start_end is None:
            # Fall back to the model loaded earlier in the notebook, as before,
            # but resolve it once here instead of on every __getitem__ call.
            mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
        self.mm_use_im_start_end = mm_use_im_start_end

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, pixel_values, label)`` for record ``idx``.

        ``input_ids`` and ``pixel_values`` have their leading dimension squeezed off
        (presumably a batch dim of 1 -- TODO confirm); ``label`` is a Python int.
        """
        item = self.data[idx]
        qs = item['conversations'][0]['value'].replace('<image>\n', '')  # question text
        image_file = os.path.join(self.image_folder, item['image'])
        label = self.label_map[item['modality']]

        # Close the file handle promptly; convert() returns a loaded copy of the image.
        with Image.open(image_file) as img:
            image = img.convert('RGB')
        images = self.image_processor['image'].preprocess(image, return_tensors='pt')['pixel_values']

        if self.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)  # presumably leaves an open assistant turn
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

        return input_ids.squeeze(0), images.squeeze(0), label

# Build the dataset and a shuffled DataLoader that yields one sample per step;
# samples are accumulated into larger optimizer batches in the training loop.
max_length = 2048
dataset = PVQADataset(
    data=data,
    tokenizer=tokenizer,
    image_processor=image_processor,
    max_length=max_length,
    label_map=label_map,
)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=1)

# MLP model
class MLPClassifier(nn.Module):
    """Classification head mapping a pooled embedding to per-class logits.

    Despite the name this is a single ``nn.Linear`` layer (no hidden layer or
    activation). The submodule is named ``fc1`` so saved ``state_dict`` keys are
    ``fc1.weight`` / ``fc1.bias``.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        logits = self.fc1(x)
        return logits


# Width of the pooled hidden-state embedding fed to the classifier. Read it from the
# loaded model's config (2048 for the StableLM checkpoint) rather than hard-coding it,
# so switching backbones does not produce a shape mismatch in the Linear layer.
embedding_dim = getattr(model.config, 'hidden_size', 2048)
num_classes = len(label_map)

mlp = MLPClassifier(embedding_dim, num_classes).to("cuda")
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mlp.parameters(), lr=0.001)

accumulation_steps = 32  # samples per optimizer step (the DataLoader yields one sample at a time)
num_epochs = 10
accumulated_embeddings = []
accumulated_labels = []


def _optimizer_step(embeddings, labels):
    """Run one classifier update on accumulated samples and return the loss value.

    Args:
        embeddings: list of float32 pooled embeddings, one per sample
            (presumably each (1, embedding_dim) -- TODO confirm).
        labels: list of label tensors aligned with ``embeddings``.

    Returns:
        The scalar cross-entropy loss for this step, as a Python float.
    """
    batch_embeddings = torch.cat(embeddings, dim=0)
    batch_labels = torch.cat(labels, dim=0)
    optimizer.zero_grad()
    logits = mlp(batch_embeddings)
    loss = criterion(logits, batch_labels)
    loss.backward()
    optimizer.step()
    return loss.item()


for epoch in range(num_epochs):
    epoch_losses = []
    for input_ids, images, label in dataloader:
        input_ids = input_ids.cuda()
        images = images.to("cuda", dtype=torch.float16)
        label = label.cuda()

        # The vision-language model is only a feature extractor: no autograd graph needed.
        with torch.inference_mode():
            outputs = model(input_ids, images=images, output_hidden_states=True, return_dict=True)
            hidden_states = outputs.hidden_states

        # Average over all returned layers, then over dim 1
        # (presumably the sequence axis of (batch, seq, dim) -- TODO confirm).
        all_layer_mean = torch.mean(torch.stack(hidden_states), dim=0)
        embedding = all_layer_mean.mean(dim=1)
        accumulated_embeddings.append(embedding.float())
        accumulated_labels.append(label)

        if len(accumulated_embeddings) == accumulation_steps:
            epoch_losses.append(_optimizer_step(accumulated_embeddings, accumulated_labels))
            accumulated_embeddings = []
            accumulated_labels = []

    # Also train on the partial last group, so trailing samples are neither dropped
    # (last epoch) nor merged into the next epoch's first step. This also guarantees
    # a loss exists to report even when the dataset is smaller than accumulation_steps.
    if accumulated_embeddings:
        epoch_losses.append(_optimizer_step(accumulated_embeddings, accumulated_labels))
        accumulated_embeddings = []
        accumulated_labels = []

    if epoch_losses:
        print(f"Epoch {epoch + 1}, Loss: {sum(epoch_losses) / len(epoch_losses)}")

print("Done")


In [None]:
model_save_path = 'save path'
torch.save(mlp.state_dict(), model_save_path)
print(f"Model has been saved to {model_save_path}")

# The weights alone cannot be decoded without knowing which output index is which
# modality, so persist the label map next to the checkpoint.
label_map_path = model_save_path + '.label_map.json'
with open(label_map_path, 'w', encoding='utf-8') as label_file:
    json.dump(label_map, label_file, ensure_ascii=False, indent=2)
print(f"Label map has been saved to {label_map_path}")