In [22]:
import argparse
from email.mime import image
import torch
import os
import json
from tqdm import tqdm
import shortuuid
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from PIL import Image
import math


def split_list(lst, n):
    """Split ``lst`` into contiguous chunks of ``ceil(len(lst) / n)`` items.

    Every chunk except possibly the last has the same size, so fewer than
    ``n`` chunks are returned when ``len(lst)`` is small relative to ``n``.

    Args:
        lst: Sliceable sequence to partition.
        n: Requested number of chunks.

    Returns:
        A list of slices of ``lst`` in their original order; an empty list
        when ``lst`` is empty.

    Raises:
        ZeroDivisionError: If ``n`` is 0 and ``lst`` is not empty.
    """
    if len(lst) == 0:
        # ceil(0 / n) is 0, and range() rejects a step of 0.
        return []
    chunk_size = math.ceil(len(lst) / n)  # ceiling division, not floor
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    """Return chunk number ``k`` of ``split_list(lst, n)``.

    ``k`` is used as a plain list index into the chunk list, so an index
    outside the produced chunks raises ``IndexError``.
    """
    return split_list(lst, n)[k]


# Custom dataset class
class CustomDataset(Dataset):
    """Images paired with an integer weather label.

    The JSON file at ``args.data_path`` is read once. Only entries whose
    ``'task'`` is ``"vqa"`` are kept, and every ``'gpt'`` conversation turn
    whose ``'value'`` is exactly one of the keys of ``WEATHER_LABELS`` adds
    one (image, label) sample.
    """

    # Exact answer text of a 'gpt' turn -> class id.
    WEATHER_LABELS = {'sunny': 0, 'raining': 1, 'cloudy': 2}

    def __init__(self, args, tokenizer, model_config):
        """Load the annotation file and index the labelled images.

        Args:
            args: Object providing ``data_path``, ``image_folder``,
                ``gen_processor``, ``un_processor`` and ``conv_mode``.
            tokenizer: Stored on the instance; not used by this class.
            model_config: Stored on the instance; not used by this class.
        """
        # Context manager so the annotation file handle is closed promptly.
        with open(args.data_path, "r") as f:
            records = json.load(f)
        self.list_data_dict = [e for e in records if e['task'] == "vqa"]
        self.image_files = []
        self.weather_labels = []
        for data_dict in self.list_data_dict:
            for conversation in data_dict['conversations']:
                if conversation['from'] != 'gpt':
                    continue
                value = conversation['value']
                # Non-string values can never equal a label name; skipping them
                # also avoids hashing an unhashable JSON list/dict.
                label = self.WEATHER_LABELS.get(value) if isinstance(value, str) else None
                if label is not None:
                    self.image_files.append(data_dict['image'])
                    self.weather_labels.append(label)
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.image_folder = args.image_folder
        self.gen_processor = args.gen_processor
        self.un_processor = args.un_processor
        self.conv_mode = args.conv_mode

    def __getitem__(self, index):
        """Return ``(pixel_values, label)`` for sample ``index``.

        The image is opened relative to ``image_folder``, converted to RGB and
        passed through ``un_processor.preprocess``; the first entry of its
        ``'pixel_values'`` output is returned with the integer label.
        """
        image_path = os.path.join(self.image_folder, self.image_files[index])
        # convert() returns an independent image, so the file can be closed here.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        image_un = self.un_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        return image_un, self.weather_labels[index]

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.image_files)


def collate_fn(batch):
    """Collate ``(input_ids, image_un, image_gen)`` triples into one batch.

    ``input_ids`` are always stacked along a new leading dimension. For each
    image stream, ``None`` entries are dropped and the remaining tensors are
    stacked; a stream with no tensors left becomes ``None``.

    Returns:
        ``(input_ids, {'images_un': ..., 'images_gen': ...})``.
    """
    ids_seq, un_seq, gen_seq = zip(*batch)

    def _stack_present(tensors):
        # Stack only the entries that are actual tensors.
        kept = [t for t in tensors if t is not None]
        return torch.stack(kept, dim=0) if kept else None

    stacked_ids = torch.stack(ids_seq, dim=0)
    return stacked_ids, {
        'images_un': _stack_present(un_seq),
        'images_gen': _stack_present(gen_seq),
    }


# DataLoader
def create_data_loader(args, tokenizer, model_config, batch_size=50, num_workers=4):
    """Build the weather dataset and a non-shuffling DataLoader over it.

    Returns:
        ``(data_loader, dataset)`` so callers can also query the dataset
        directly (for example its length).
    """
    weather_dataset = CustomDataset(args, tokenizer, model_config)
    loader = DataLoader(
        weather_dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
    )
    return loader, weather_dataset

# def generate_image(input_ids,model,num_image_tokens):
#     output_img=[]
#     inputs_embeds=model.get_model().embed_tokens(input_ids) #1, seq_le, 4096
#     with torch.inference_mode():
#         for i in range(num_image_tokens):
#             outputs = model.model(
#                 input_ids=None,
#                 attention_mask=None,
#                 position_ids=None,
#                 past_key_values=None,
#                 inputs_embeds=inputs_embeds,
#             )
#             hidden_states = outputs[0]
#             img = model.get_model().mm_projector_head(hidden_states[:,-1,:])
#             output_img.append(img)
#             if model.get_model().mm_projector_gen is not None:
#                 new_embed=model.get_model().mm_projector_gen(img)
#             else:
#                 new_embed=model.get_model().mm_projector_un(img)
#             new_embed=new_embed.unsqueeze(1).to(inputs_embeds.device)
#             inputs_embeds=torch.cat([inputs_embeds,new_embed],dim=1)
            
#     return output_img


# def generate_image_vq(input_ids,model,num_image_tokens):
#     output_img_id=[]
#     inputs_embeds=model.get_model().embed_tokens(input_ids) #1, seq_le, 4096
#     with torch.inference_mode():
#         for i in range(num_image_tokens):
#             outputs = model.model(
#                 input_ids=None,
#                 attention_mask=None,
#                 position_ids=None,
#                 past_key_values=None,
#                 inputs_embeds=inputs_embeds,
#             )
#             hidden_states = outputs[0]
#             img_logits = model.get_model().mm_projector_head(hidden_states[:,-1,:])
#             img_id=img_logits.argmax(dim=-1) # shape (1,)
#             output_img_id.append(img_id)
#             img_latent=model.get_model().vision_tower_gen.vision_tower.quantize.get_codebook_entry(img_id, shape=None, channel_first=True) # (1,8)
#             if model.get_model().mm_projector_gen is not None:
#                 new_embed=model.get_model().mm_projector_gen(img_latent)
#             else:
#                 new_embed=model.get_model().mm_projector_un(img_latent)
#             new_embed=new_embed.unsqueeze(1).to(inputs_embeds.device)
#             inputs_embeds=torch.cat([inputs_embeds,new_embed],dim=1)
            
#     return output_img_id




    


In [68]:
class Args_main:
    """Hand-edited run settings for this notebook session."""

    def __init__(self):
        # Device string used when loading and running the model.
        self.device = 'cuda:7'
        # Checkpoint sweep settings: first multiplier, step size, and count.
        self.ckpt_start, self.ckpt_step, self.ckpt_num = 10, 30, 1
        # Name of the checkpoint directory for the trained model.
        self.model_name = 'llava-v1.5-7b-siglip-vq_u_weather_biased-2-u-sw-lora'
        # Mode switches, both off by default.
        self.understanding_only = self.generation_only = False
# Instantiate the run settings and unpack them into notebook-level names.
args = Args_main()

# load trained model
device = args.device
model_name = args.model_name
understanding_only = args.understanding_only
generation_only = args.generation_only

# Checkpoint steps to visit: ckpt_num consecutive multiples of ckpt_step,
# starting at ckpt_start * ckpt_step.
ckp_list = [
    args.ckpt_step * idx
    for idx in range(args.ckpt_start, args.ckpt_start + args.ckpt_num)
]
model_list = [
    f'/public_data/jihai/understanding/scripts/v1_5/checkpoints/{model_name}/checkpoint-{i}'
    for i in ckp_list
]

# Index of the checkpoint in model_list that this session evaluates.
k = 0

# Ad-hoc settings object: the values become class attributes of a throwaway
# 'Args' class, and one instance of it is used as the settings holder.
infer_args = type('Args', (), dict(
    model_path=model_list[k],
    model_base='/public_data/jihai/tmp/vicuna-7b-v1.5',
    data_path='/public_data/jihai/data/multimodalout/smart_watch_image_train_weather_tl.json',
    image_folder='/public_data/jihai/data/multimodalout/smart_watch_image_train_weather_tl',
    conv_mode="llava_v1",
    num_chunks=1,
    chunk_idx=0,
    temperature=0,
    top_p=None,
    num_beams=1,
    max_new_tokens=128,
    image_un_size=[3, 224, 224],
    image_gen_size=[3, 256, 256],
))()



In [69]:

# NOTE(review): disable_torch_init is imported from llava.utils; presumably it
# skips redundant default weight initialisation before the checkpoint weights
# are loaded — confirm against that module.
disable_torch_init()
# Expand a leading "~" in the configured checkpoint path, if there is one.
model_path = os.path.expanduser(infer_args.model_path)

# Name derived from the checkpoint path; in the visible cells it is only
# consulted by the later 'plain' / 'finetune' conv_mode check.
model_type = get_model_name_from_path(model_path)
# The loader receives the hand-set model_name rather than the path-derived
# model_type, and returns five objects: tokenizer, model, two image processors
# and a context length.
# NOTE(review): presumably image_processor serves the understanding branch and
# image_processor_gen the generation branch — confirm in llava.model.builder.
tokenizer, model, image_processor,image_processor_gen, context_len = load_pretrained_model(model_path, infer_args.model_base, model_name,device=device)




Loading LLaVA from base model...


Loading checkpoint shards: 100%|██████████| 2/2 [00:36<00:00, 18.40s/it]
Some weights of LlavaLlamaForCausalLM_ImgGen were not initialized from the model checkpoint at /public_data/jihai/tmp/vicuna-7b-v1.5 and are newly initialized: ['model.mm_projector_gen.bias', 'model.mm_projector_gen.weight', 'model.mm_projector_head.bias', 'model.mm_projector_head.weight', 'model.mm_projector_un.0.bias', 'model.mm_projector_un.0.weight', 'model.mm_projector_un.2.bias', 'model.mm_projector_un.2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading additional LLaVA weights...
Loading LoRA weights...
Merging LoRA weights...
Model is loaded...


In [71]:
# Attach the two image processors so CustomDataset can read them from infer_args.
infer_args.gen_processor=image_processor_gen
infer_args.un_processor=image_processor

# "plain" models that are not fine-tuned are switched to the *_mmtag conversation template.
if 'plain' in model_type and 'finetune' not in model_type.lower() and 'mmtag' not in infer_args.conv_mode:
    infer_args.conv_mode = infer_args.conv_mode + '_mmtag'
    print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {infer_args.conv_mode}.')
data_loader,data_set = create_data_loader(infer_args, tokenizer, model.config)
print (len(data_set))
image_features=[]
labels=[]
# This is pure feature extraction, so run it without autograd. Otherwise every
# stored feature keeps a reference to its computation graph (and the
# activations behind it) for as long as the list is alive.
with torch.no_grad():
    for image, w_label in tqdm(data_loader):
        # Images are moved to the model device and cast to float16 before encoding.
        image_feature=model.get_model().get_vision_tower()(image.to(device).half())

        if isinstance(image_feature, tuple): #VQ
            # Tuple output: keep the first element and flatten its spatial grid
            # into a sequence dimension.
            image_feature = image_feature[0]
            image_feature=image_feature.permute(0,2,3,1) #b,h,w,c
            image_feature=image_feature.view(image_feature.shape[0],-1,image_feature.shape[-1]) #b,seq_len,c

        #image_feature=image_feature[0]

        # Keep only the entry at sequence position 0 of each image.
        # NOTE(review): whether position 0 is a meaningful summary token depends
        # on the vision tower — confirm for the encoder in use.
        image_feature=image_feature[:,0,:]
        image_feature=model.get_model().mm_projector_un(image_feature)
        image_feature=image_feature.cpu()
        image_features.append(image_feature)
        labels.append(w_label)
# Concatenate the per-batch results along the sample dimension.
image_features=torch.cat(image_features,dim=0)
labels=torch.cat(labels, dim=0)

    


5585


  0%|          | 0/112 [00:00<?, ?it/s]

100%|██████████| 112/112 [00:05<00:00, 18.95it/s]


In [72]:
# Create the output directory first so the extraction run is not lost to a
# missing-directory error at save time.
os.makedirs('./eval', exist_ok=True)
torch.save(image_features,'./eval/weather_features_siglip-u.pt')
torch.save(labels,'./eval/weather_labels_siglip-u.pt')

In [4]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# Load previously saved features and labels (the "siglip-siglip" run) and
# render a 2-D t-SNE scatter plot of them, coloured by weather class.
image_features=torch.load('./eval/weather_features_siglip-siglip.pt')
labels=torch.load('./eval/weather_labels_siglip-siglip.pt')
# The loaded data are PyTorch tensors; convert them to numpy arrays first.
# Convert the feature data (shape [5585, 4096] per the original author's note).
features = image_features.detach().cpu().numpy()  # .detach().cpu() can be omitted if already on the CPU
# Convert the label data (shape [5585,] per the original author's note).
labels = labels.detach().cpu().numpy()

# Data preprocessing
# 1. Standardise the features (important, because t-SNE is sensitive to scale).
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)
#features_scaled=features

# 2. Optional: reduce to 50 dimensions with PCA first to speed things up
#    (recommended for high-dimensional data).
pca = PCA(n_components=50)
features_pca = pca.fit_transform(features_scaled)

# Run t-SNE (PCA followed by t-SNE is the recommended order).
# NOTE(review): newer scikit-learn releases rename n_iter to max_iter — confirm
# against the installed version before re-running.
tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, 
            n_iter=1000, random_state=42)
tsne_results = tsne.fit_transform(features_pca)  # pass features_scaled instead to skip the PCA step

# Visualisation
plt.figure(figsize=(4, 4))
colors = ['red', 'green', 'blue']
labels_dict = {0: 'Sunny', 1: 'Rainy', 2: 'Cloudy'}

# One scatter call per class so each class gets its own colour and legend label.
for i in [0, 1, 2]:
    plt.scatter(tsne_results[labels == i, 0], 
                tsne_results[labels == i, 1], 
                c=colors[i], 
                label=labels_dict[i],
                alpha=0.6,
                s=10)  # marker size; adjust as needed

#plt.title('t-SNE Visualization of Image Features', fontsize=14)
#plt.xlabel('t-SNE Component 1', fontsize=12)
#plt.ylabel('t-SNE Component 2', fontsize=12)
ax = plt.gca()
ax.axis('off')
#plt.legend(fontsize=16,markerscale=2)
# NOTE(review): the axes were switched off just above, so this grid is
# presumably not drawn — confirm whether it is still wanted.
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.savefig('./eval/weather_features_siglip-siglip.png', dpi=300)
plt.close()

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.preprocessing import StandardScaler

# Fix the random seed so the split below is repeatable.
torch.manual_seed(42)
image_features=torch.load('./eval/weather_features_vq-u.pt')
labels=torch.load('./eval/weather_labels_vq-u.pt')
# The loaded data are PyTorch tensors; convert them to numpy arrays for
# scikit-learn (shapes [5585, 4096] and [5585,] per the original author's note).
features = image_features.detach().cpu().numpy()  # .detach().cpu() can be omitted if already on the CPU
labels = labels.detach().cpu().numpy()

# Train/test sizes (4500 training samples, the remainder for testing).
train_size = 4500
test_size = len(features) - train_size
if test_size <= 0:
    # A non-positive test_size still sums to len(features), so random_split
    # would not reject it and the test split would silently be empty.
    raise ValueError(
        f"train_size={train_size} leaves no test samples out of {len(features)}"
    )

# Draw the split over sample indices *before* scaling. This consumes the RNG
# exactly like splitting the final dataset would, so the membership of each
# split is unchanged.
train_part, test_part = random_split(range(len(features)), [train_size, test_size])
train_indices = list(train_part.indices)
test_indices = list(test_part.indices)

# 1. Preprocessing: fit the scaler on the training samples only, so that
# test-set statistics do not leak into the features the probe is trained on,
# then apply that same transform to every sample.
scaler = StandardScaler()
scaler.fit(features[train_indices])
features_scaled = scaler.transform(features)

# Convert to PyTorch tensors.
X = torch.tensor(features_scaled, dtype=torch.float32)
y = torch.tensor(labels, dtype=torch.long)

# Full dataset plus the two index-based views of it.
dataset = TensorDataset(X, y)
train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)

# Create the DataLoaders.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)

# 2. Define the linear classifier
class LinearProbe(nn.Module):
    """A single fully connected layer mapping features to class logits."""

    def __init__(self, input_dim, num_classes):
        """Create the ``input_dim -> num_classes`` linear layer."""
        super(LinearProbe, self).__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Return the raw (unnormalised) class scores for ``x``."""
        logits = self.linear(x)
        return logits

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Size the probe from the loaded features instead of a hard-coded 4096, so the
# same cell works for feature files of a different width.
model = LinearProbe(X.shape[1], 3).to(device)

# 3. Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
# Multiply the learning rate by 0.1 every 20 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# 4. Training loop
num_epochs = 50
best_acc = 0.0

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    scheduler.step()

    # Evaluation phase
    model.eval()
    test_loss = 0.0
    test_correct = 0
    test_total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            test_total += targets.size(0)
            test_correct += predicted.eq(targets).sum().item()

    # Report the per-batch mean losses and the accuracies for this epoch.
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss/len(train_loader):.4f} | Train Acc: {100*correct/total:.2f}%")
    print(f"Test Loss: {test_loss/len(test_loader):.4f} | Test Acc: {100*test_correct/test_total:.2f}%")

    # Save the best model so far.
    # NOTE(review): the checkpoint is selected on the same split that is
    # reported as "test" accuracy, so the final number is an optimistic
    # estimate; a separate validation split would avoid that.
    current_acc = test_correct / test_total
    if current_acc > best_acc:
        best_acc = current_acc
        torch.save(model.state_dict(), 'best_linear_probe.pth')

print(f"Best Test Accuracy: {100*best_acc:.2f}%")

Epoch 1/50
Train Loss: 0.0107 | Train Acc: 99.27%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 2/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 3/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 4/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 5/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 6/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 7/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 8/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 9/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 10/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 11/50
Train Loss: 0.0000 | Train Acc: 100.00%
Test Loss: 0.0000 | Test Acc: 100.00%
Epoch 12/50
Train Lo

In [6]:


# Run generation over the evaluation samples, decode an image whenever the
# model emits the image marker, and append one JSON record per sample to ans_file.
# NOTE(review): list_data_dict, ans_file and infer_args.answer_image_file are
# not defined in the visible cells, and generate_image_vq only exists as
# commented-out code in this file — they must be restored before re-running.
# NOTE(review): the loop expects data_loader to yield (input_ids, images dict)
# pairs, i.e. batches shaped like collate_fn's output — confirm the loader used.

# Zero-length placeholders passed to the model when a batch has no images of a kind.
images_gen_pad=torch.zeros([0]+infer_args.image_gen_size).to(device=device, dtype=torch.float16)
images_un_pad=torch.zeros([0]+infer_args.image_un_size).to(device=device, dtype=torch.float16)

# Token-id pattern searched for in the generated sequence; loop-invariant, so
# it is built once.
# NOTE(review): presumably the tokenizer's ids for the "<image>" marker — confirm.
img_indicator = torch.tensor([529,  3027, 29958])
# Length of the sub-sequence being searched for.
sub_seq_len = len(img_indicator)

count=0
# try/finally so the answers file is flushed and closed even when the run is
# interrupted part-way (the recorded run ended in a KeyboardInterrupt).
try:
    for (input_ids, images), line in tqdm(zip(data_loader, list_data_dict), total=len(list_data_dict)):
        count+=1
        # count is incremented before the check, so at most 499 samples are processed.
        if count==500: break

        cur_prompt = line["conversations"][0]["value"]
        groun_truth=line["conversations"][1]["value"]
        groun_truth_img_tensor=line["image"]
        input_ids = input_ids.to(device=device, non_blocking=True)
        images['images_gen']=images['images_gen'].to(dtype=torch.float16, device=device, non_blocking=True) if images['images_gen'] is not None else images_gen_pad
        images['images_un']=images['images_un'].to(dtype=torch.float16, device=device, non_blocking=True) if images['images_un'] is not None else images_un_pad
        with torch.inference_mode():
            outputs = model.generate(
                input_ids,
                images=images,
                do_sample=True if infer_args.temperature > 0 else False,
                temperature=infer_args.temperature,
                top_p=infer_args.top_p,
                num_beams=infer_args.num_beams,
                max_new_tokens=infer_args.max_new_tokens,
                use_cache=True)
        output_ids=outputs['generated_tokens']
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0].strip()
        #print(outputs)

        id_seq = output_ids[0].cpu()

        # Sliding-window search for the first occurrence of the marker.
        start_idx = -1
        for i in range(id_seq.size(0) - sub_seq_len + 1):
            if torch.equal(id_seq[i:i + sub_seq_len], img_indicator):
                start_idx = i
                break
        img_file=None
        if start_idx != -1:
            # Keep the generated tokens from position 1 up to and including the
            # marker, append them to the prompt, then generate the image tokens.
            output_ids=output_ids[:,1:start_idx+sub_seq_len]
            input_ids=torch.cat((input_ids, output_ids), dim=1)
            img_id=generate_image_vq(input_ids,model,model.get_model().vision_tower_gen.num_patches)
            with torch.no_grad():
                img=model.get_model().vision_tower_gen.vision_tower.decode_code(img_id,[1,8,16,16])
            # Resize to the configured generation size and move channels last.
            img = F.interpolate(img, size=[infer_args.image_gen_size[1], infer_args.image_gen_size[2]], mode='bicubic').permute(0, 2, 3, 1)[0]
            # Rescale to the 0..255 range and store as uint8 on the CPU.
            img = torch.clamp(127.5 * img + 128.0, 0, 255).to("cpu", dtype=torch.uint8)
            img_file=os.path.join(infer_args.answer_image_file, f'{count}.pt')
            torch.save(img, img_file)

        ans_file.write(json.dumps({"prompt": cur_prompt,
                                    "groun_truth": groun_truth,
                                    "answer": outputs,
                                    "groun_truth_img_tensor": groun_truth_img_tensor,
                                    "output_img_file": img_file,
                                    "model_id": model_name,
                                    "metadata": {}}) + "\n")
        #outputs = tokenizer.batch_decode(input_ids, skip_special_tokens=False)[0].strip()

    print(ans_file)
finally:
    ans_file.close()

  0%|          | 0/600 [00:00<?, ?it/s]

 10%|▉         | 58/600 [02:06<19:43,  2.18s/it]


KeyboardInterrupt: 

In [47]:
# Scratch arithmetic; the recorded cell output shows 196.0.
# NOTE(review): purpose is not stated — presumably a size or token-count check.
1568/8

196.0