In [10]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append("/code/LLM-crime/single_model")
from my_models import TransformerRegressionModel, ResNet50Model, ViTClassifier
from LLM_feature_extractor import LLaVaFeatureExtractor
from PIL import Image
import torchvision.transforms as transforms
from safety_perception_dataset import *
import neptune
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from sklearn.metrics import r2_score
import shutil
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# Central configuration shared by the data loading, model, and training cells.
# NOTE(review): the import cell sets CUDA_VISIBLE_DEVICES to "1" while 'device'
# below requests "cuda:2" -- confirm that index is valid under that mask.
cfg_paras = dict(
    debug=False,
    dataset_path="/data2/cehou/LLM_safety/img_text_data/dataset_baseline_baseline_baseline_baseline_1401.pkl",
    save_model_path="/data2/cehou/LLM_safety/LLM_models/clip_model/test",
    save_model_name="model_baseline_test.pt",
    device=torch.device("cuda:2" if torch.cuda.is_available() else "cpu"),
    batch_size=192,
    num_workers=4,
    head_lr=1e-3,
    image_encoder_lr=1e-4,
    text_encoder_lr=1e-5,
    weight_decay=1e-3,
    img_type='PlacePulse',
    patience=1,
    factor=0.8,
    epochs=400,
    image_embedding=2048,
    text_embedding=768,
    max_length=512,
    size=(224, 224),

    # Image and text backbones.
    model_name='resnet50',
    text_encoder_model="distilbert-base-uncased",
    text_tokenizer="distilbert-base-uncased",
    pretrained=True,
    trainable=True,

    # Deep learning model hyper-parameters.
    temperature=0.07,
    projection_dim=256,
    dropout=0.1,
    early_stopping_threshold=5,

    # Safety perception settings.
    # 'CLIP_model_path': "/data2/cehou/LLM_safety/LLM_models/clip_model/test/model_baseline_best.pt",
    variables_save_paths=f"/data2/cehou/LLM_safety/middle_variables/test",
    safety_model_save_path=f"/data2/cehou/LLM_safety/LLM_models/safety_perception_model/only_img/",
    placepulse_datapath="/data2/cehou/LLM_safety/PlacePulse2.0/image_perception_score.csv",
    eval_path="/data2/cehou/LLM_safety/eval/test/only_img/",
    train_type='classification',
    safety_epochs=200,
    class_num=2,
    CNN_lr=1e-7,
    weight_on=False,
)

# Start the Neptune run used for experiment logging.
# The API token is a secret, so it is read from the NEPTUNE_API_TOKEN
# environment variable instead of being stored in the notebook.
neptune_api_token = os.environ.get("NEPTUNE_API_TOKEN")
if not neptune_api_token:
    raise RuntimeError(
        "Set the NEPTUNE_API_TOKEN environment variable before initializing the Neptune run."
    )
run = neptune.init_run(
    project="ce-hou/Safety",
    api_token=neptune_api_token,
)

# Load the PlacePulse perception scores and drop rows whose label is 0.
data = pd.read_csv(cfg_paras['placepulse_datapath'])
# .copy() makes data_ls an independent frame, so the relabeling below is a
# plain in-place write rather than an assignment on a slice of `data`.
data_ls = data[data['label'] != 0].copy()
# Remap label -1 to 0.
data_ls.loc[data_ls['label'] == -1, 'label'] = 0
transform = get_transforms(cfg_paras['size'])
# The first 80% of the rows (in file order) go to training, the rest to validation.
split_num = int(len(data_ls) * 0.8)

train_dataset = SafetyPerceptionDataset(data_ls[:split_num], transform=transform, paras=cfg_paras)
valid_dataset = SafetyPerceptionDataset(data_ls[split_num:], transform=transform, paras=cfg_paras)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg_paras['batch_size'], shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=cfg_paras['batch_size'])



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/ce-hou/Safety/e/SAF-302


In [None]:
# Pull a single batch out of valid_loader.
data_iter = iter(valid_loader)
images, labels = next(data_iter)

# Keep the first eight samples of that batch.
eight_images, eight_labels = images[:8], labels[:8]

print(eight_images.shape)
print(eight_labels)

torch.Size([8, 3, 224, 224])
tensor([1, 0, 1, 1, 0, 0, 0, 0])


In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


import torch.nn as nn


# Model building blocks (LLM pre-extractor, extractor, adaptor, classifier, full model)

class LLMImageFeaturePrextractor(nn.Module):
    """Run images through the LLaVA image extractor and optionally post-process the result.

    ``process`` selects the post-processing applied in ``forward``:

    * ``'mean_dim1'`` / ``'mean_dim2'``: average over dimension 1 / 2.
    * ``'max_dim1'`` / ``'max_dim2'``: maximum over dimension 1 / 2.
    * ``'reshape'``: collapse all leading dimensions, keeping the sizes of
      dimensions 3 and 4 as the last two axes.
    * ``'conv_dim1'`` / ``'conv_dim2'``: apply a 1x1 convolution mapping
      3 channels to 1 / 3 channels.

    Any other value, including the default ``'mean'``, matches none of the
    branches and the extracted feature is returned unchanged.
    """

    def __init__(self, process='mean'):
        super().__init__()
        self.llava_extractor = LLaVaFeatureExtractor()
        # 1x1 convolutions: 3 input channels -> 1 output channel, and 3 -> 3.
        self.conv_dim1 = nn.Conv2d(3, 1, kernel_size=1)
        self.conv_dim2 = nn.Conv2d(3, 3, kernel_size=1)
        self.process = process

    def forward(self, x):
        """Extract image features from ``x`` and apply the configured post-processing."""
        feature = self.llava_extractor.image_extractor(x)
        mode = self.process

        if mode == 'mean_dim1':
            feature = feature.mean(dim=1)
        elif mode == 'mean_dim2':
            feature = feature.mean(dim=2)
        elif mode == 'max_dim1':
            feature = feature.max(dim=1).values
        elif mode == 'max_dim2':
            feature = feature.max(dim=2).values
        elif mode == 'reshape':
            height, width = feature.shape[3], feature.shape[4]
            feature = feature.reshape(-1, height, width)
        elif mode == 'conv_dim1':
            feature = self.conv_dim1(feature)
        elif mode == 'conv_dim2':
            feature = self.conv_dim2(feature)
        return feature

class Extractor(nn.Module):
    def __init__(self, pretrained_model='resnet18'):
        super(Extractor, self).__init__()
        if pretrained_model == 'ViT':
            pass
        if pretrained_model == 'resnet50':
            self.model = models.resnet50(pretrained=True)
            # 去掉最后的全连接层
            self.model = nn.Sequential(*list(self.model.children())[:-1])
        if pretrained_model == 'resnet18':
            self.model = models.resnet18(pretrained=True)
            # 去掉最后的全连接层
            self.model = nn.Sequential(*list(self.model.children())[:-1])
            

    def forward(self, x):
        # 输入图像 x，返回提取的特征
        with torch.no_grad():  # 禁用梯度计算
            features = self.model(x)
        # 返回特征的展平（flatten）形式
        return features.view(features.size(0), -1)
    
class Adaptor(nn.Module):
    """Residual projection head mapping extracted features to ``projection_dim``.

    The forward pass is ``LayerNorm(Dropout(Linear(GELU(p))) + p)`` where
    ``p = Linear(x)`` is the initial projection.

    Args:
        input_dim: Size of the incoming feature vectors.
        projection_dim: Size of the projected output vectors.
        data_type: ``'image'`` or ``'text'``; both use the same linear projection.
        dropout: Dropout probability applied before the residual connection.

    Raises:
        ValueError: If ``data_type`` is neither ``'image'`` nor ``'text'``.
    """

    def __init__(
        self,
        input_dim,
        projection_dim,
        data_type,
        dropout=0.1
    ):
        super(Adaptor, self).__init__()
        # Reject unknown data types up front; otherwise self.projection would
        # be missing and forward() would fail with an AttributeError.
        if data_type not in ('image', 'text'):
            raise ValueError(
                f"Unsupported data_type {data_type!r}; expected 'image' or 'text'."
            )
        self.projection = nn.Linear(input_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Project ``x`` and refine it with a residual GELU/linear block."""
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # residual connection around the refinement block
        x = self.layer_norm(x)
        return x


class Classifier(nn.Module):
    """Linear classification head: ``input_dim`` features in, ``num_classes`` scores out."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # A single fully connected layer is the whole classifier.
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        """Map adapted feature vectors to one score per class."""
        scores = self.fc(x)
        return scores

class FullModel(nn.Module):
    """Chain an extractor, an adaptor, and a classifier into one module.

    Args:
        extractor: Module applied to the raw input.
        adaptor: Module applied to the extractor output.
        classifier: Module applied to the adaptor output.
        verbose: When True (default), print the shape of each intermediate
            tensor on every forward pass. Pass False to silence this, e.g.
            inside a training loop.
    """

    def __init__(self, extractor, adaptor, classifier, verbose=True):
        super(FullModel, self).__init__()
        self.extractor = extractor
        self.adaptor = adaptor
        self.classifier = classifier
        self.verbose = verbose

    def forward(self, x):
        """Run extractor -> adaptor -> classifier and return the classifier output."""
        features = self.extractor(x)
        if self.verbose:
            print("extracted feature: ", features.shape)
        adapted_features = self.adaptor(features)
        if self.verbose:
            print("adapted feature: ", adapted_features.shape)
        output = self.classifier(adapted_features)
        if self.verbose:
            print("final feature", output.shape)
        return output

## 训练代码


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

# Dimensions used when wiring the extractor, adaptor, and classifier together.
batch_size = 8
input_dim = 512  # feature dimension output by ResNet18
adaptor_output_dim = 256  # output dimension of the adaptor
num_classes = 10  # assume 10 classes

# Manual "load once" switch for the LLM pre-extractor.
# NOTE(review): LLM_loaded is set to True immediately before the check, so the
# branch below never runs as written; set it to False to build the extractor.
LLM_loaded = True
if not LLM_loaded:
    # Per the original note, the LLM turns an image into a shallow image
    # feature of shape [3, 336, 336] -- TODO confirm.
    LLM_pre_extractor = LLMImageFeaturePrextractor(process='mean_dim1')
    LLM_loaded = True

# Load the PlacePulse perception scores and drop rows whose label is 0.
data = pd.read_csv(cfg_paras['placepulse_datapath'])
# .copy() makes data_ls an independent frame, so the relabeling below is a
# plain in-place write rather than an assignment on a slice of `data`.
data_ls = data[data['label'] != 0].copy()
# Remap label -1 to 0.
data_ls.loc[data_ls['label'] == -1, 'label'] = 0
transform = get_transforms(cfg_paras['size'])
# The first 80% of the rows (in file order) go to training, the rest to validation.
split_num = int(len(data_ls) * 0.8)

train_dataset = SafetyPerceptionDataset(data_ls[:split_num], transform=transform, paras=cfg_paras)
valid_dataset = SafetyPerceptionDataset(data_ls[split_num:], transform=transform, paras=cfg_paras)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg_paras['batch_size'], shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=cfg_paras['batch_size'])


# Split the dataset into training and validation subsets.
# NOTE(review): X and y are not defined in any visible cell -- confirm where
# they are supposed to come from before running this.
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=42, test_size=0.2)

# Convert the splits to PyTorch tensors.
X_train_tensor, y_train_tensor = torch.tensor(X_train), torch.tensor(y_train)
X_val_tensor, y_val_tensor = torch.tensor(X_val), torch.tensor(y_val)

# Wrap the tensors in datasets and loaders. This rebinds train_dataset and
# train_loader, replacing the objects built from SafetyPerceptionDataset above.
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)

# A small fully connected network.
class SimpleNN(nn.Module):
    """Three-layer MLP (20 -> 64 -> 32 -> 1) with ReLU activations and a sigmoid output."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-squashed output of the three linear layers for ``x``."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.sigmoid(self.fc3(hidden))

# Instantiate the model.
model = SimpleNN()

# Loss function and optimizer.
criterion = nn.BCELoss()  # binary cross-entropy loss
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Early stopping helper.
class EarlyStopping:
    """Track validation loss and flag when it has stopped improving.

    Call the instance once per epoch with the validation loss. ``early_stop``
    becomes True once ``patience`` consecutive calls failed to produce a loss
    strictly lower than the best one seen so far.
    """

    def __init__(self, patience=20, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter``, ``best_loss`` and ``early_stop``."""
        # First observation: just remember it.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        # Strict improvement resets the patience counter.
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            return

        # No improvement (ties included): spend one unit of patience.
        self.counter += 1
        if self.counter < self.patience:
            return

        self.early_stop = True
        if self.verbose:
            print("Early stopping triggered")

## 用来调试的代码

In [None]:
# Smoke test: push one validation batch through the LLM pre-extractor and the full model.
if __name__ == '__main__':
    # Assume the input images have shape (batch_size, 3, 224, 224), i.e. RGB images.
    batch_size = 8
    input_dim = 512  # feature dimension output by ResNet18
    adaptor_output_dim = 256  # output dimension of the adaptor
    num_classes = 10  # assume 10 classes

    # Build the LLM pre-extractor only when the kernel does not already hold
    # one. The previous hard-coded `LLM_loaded = True` skipped construction
    # unconditionally, so a fresh kernel failed with NameError below.
    LLM_loaded = 'LLM_pre_extractor' in globals()
    if not LLM_loaded:
        # Per the original note, the LLM turns an image into a shallow image
        # feature of shape [3, 336, 336] -- TODO confirm.
        LLM_pre_extractor = LLMImageFeaturePrextractor(process='mean_dim1')
        LLM_loaded = True

    # Take one batch from valid_loader.
    data_iter = iter(valid_loader)
    images, labels = next(data_iter)

    # Use at most batch_size images, and no more than the batch actually holds.
    sample_count = min(batch_size, len(images))
    x = LLM_pre_extractor([images[i] for i in range(sample_count)])
    print(x.shape)

    # Build the individual modules.
    extractor = Extractor(pretrained_model='resnet18')
    adaptor = Adaptor(input_dim=input_dim, projection_dim=adaptor_output_dim, data_type='image')
    classifier = Classifier(input_dim=adaptor_output_dim, num_classes=num_classes)

    # Combine them into one network.
    model = FullModel(extractor, adaptor, classifier)

    # Put the model and the input on the same device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    x = x.to(device)

    # Forward pass.
    output = model(x)
    print(output.shape)  # shape of the classification output


图像检查


In [28]:
# Run eight images of the current batch through the LLM pre-extractor and
# show x[0, 0] of the result as a heatmap.
x = LLM_pre_extractor([images[idx] for idx in range(8)])
sns.heatmap(x[0, 0].detach().cpu().numpy(), cbar=False, cmap='viridis')
plt.axis('off')

In [None]:
def train_model(train_loader, valid_loader, paras, model=None):
    """Train a safety-perception model with per-epoch validation and early stopping.

    Each epoch runs one pass over ``train_loader`` and one over
    ``valid_loader``. The summed validation loss decides checkpointing and
    early stopping. Metrics are appended to the module-level Neptune ``run``,
    and an evaluation plot is written under ``paras['eval_path']`` every epoch.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        valid_loader: Iterable of ``(inputs, labels)`` validation batches.
        paras: Configuration dict. Keys read here: ``device``, ``train_type``
            (``'regression'`` or ``'classification'``), ``class_num``,
            ``weight_on``, ``class_weights`` (only when ``weight_on`` is true),
            ``CNN_lr``, ``safety_epochs``, ``safety_model_save_path``,
            ``eval_path`` and ``early_stopping_threshold``.
        model: Optional module to train. When omitted, a ``ResNet50Model`` is
            built with one output for regression or ``paras['class_num']``
            outputs for classification.

    Raises:
        ValueError: If ``paras['train_type']`` is not a supported value.
    """
    device = paras['device']
    train_type = paras['train_type']
    print(f'device: {paras["device"]}')

    if train_type == 'regression':
        output_dim = 1
        criterion = nn.MSELoss()
    elif train_type == 'classification':
        output_dim = paras['class_num']
        if paras['weight_on']:
            class_weights = torch.FloatTensor(paras['class_weights']).to(device)
            criterion = nn.CrossEntropyLoss(weight=class_weights)
        else:
            criterion = nn.CrossEntropyLoss()
    else:
        raise ValueError(
            f"Unsupported train_type {train_type!r}; expected 'regression' or 'classification'."
        )

    # The classification branch previously never created a model, so building
    # the optimizer failed. Both modes now share the same default backbone.
    if model is None:
        model = ResNet50Model(output_dim)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=paras["CNN_lr"])

    # Training loop
    num_epochs = paras['safety_epochs']
    best_loss = float('inf')
    count_after_best = 0
    for epoch in range(num_epochs):
        model.train()
        train_running_loss = 0.0
        tqdm_loader = tqdm(train_loader, total=len(train_loader))
        for inputs, labels in tqdm_loader:
            inputs = inputs.to(device)
            if train_type == 'classification':
                labels = labels.to(device).long()
            else:
                labels = labels.to(device).float()

            optimizer.zero_grad()
            outputs = model(inputs)
            # A single-column output is squeezed to match 1-D labels.
            if outputs.shape[1] == 1:
                outputs = outputs.squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_running_loss += loss.item()

            # Show the epoch and the cumulative epoch loss on the progress bar.
            tqdm_loader.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            tqdm_loader.set_postfix(loss=train_running_loss)

        model.eval()
        val_running_loss = 0.0
        correct = 0
        total = 0
        # Collect per-batch arrays and concatenate once after the loop.
        pred_batches = []
        label_batches = []
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs = inputs.to(device)
                if train_type == 'classification':
                    labels = labels.to(device).long()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    batch_preds = predicted
                else:
                    labels = labels.to(device).float()
                    outputs = model(inputs).squeeze(1)
                    loss = criterion(outputs, labels)
                    batch_preds = outputs
                val_running_loss += loss.item()
                pred_batches.append(batch_preds.cpu().numpy())
                label_batches.append(labels.cpu().numpy())
        all_preds = np.concatenate(pred_batches)
        all_labels = np.concatenate(label_batches)

        # Checkpoint on improvement; otherwise count epochs since the best one.
        count_after_best += 1
        if val_running_loss < best_loss:
            best_loss = val_running_loss
            count_after_best = 0
            os.makedirs(paras['safety_model_save_path'], exist_ok=True)
            torch.save(model.state_dict(), os.path.join(paras['safety_model_save_path'], f"best_{paras['train_type']}_model.pth"))
            print(f"save the best model to {os.path.join(paras['safety_model_save_path'])}.")

        # Mean loss per batch; the same values are logged and printed.
        train_epoch_loss = train_running_loss / len(train_loader)
        valid_epoch_loss = val_running_loss / len(valid_loader)
        run["train/total_loss"].append(train_epoch_loss)
        run["valid/total_loss"].append(valid_epoch_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_epoch_loss:.4f}, Validation Loss: {valid_epoch_loss:.4f}")
        if train_type == 'classification':
            run["valid/accuracy"].append(correct / total)
            print(f"Accuracy: {100 * correct / total:.2f}%")
            # Confusion matrix of the validation predictions for this epoch.
            cm = confusion_matrix(all_labels, all_preds)
            plt.figure(figsize=(10, 8))
            sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
                        annot_kws={"size": 12, "weight": "bold", "color": "red"}
                        )
            plt.title(f"Confusion Matrix epoch {epoch+1} acc: {correct/total:0.2%}")
            cm_savepath = os.path.join(paras['eval_path'], 'valid_cm')
            os.makedirs(cm_savepath, exist_ok=True)
            plt.savefig(os.path.join(cm_savepath, f"confusion_matrix_epoch_{epoch+1}.png"))
            plt.close()
        else:
            # r2_score takes (y_true, y_pred); the score is not symmetric.
            r2 = r2_score(all_labels, all_preds)
            run["valid/r2_score"].append(r2)
            print(f"R2 score: {r2:.2f}")
            # Scatter of predictions against true labels with a fitted line.
            plt.figure(figsize=(10, 8))
            sns.regplot(x=all_labels, y=all_preds, scatter_kws={'s':10}, line_kws={"color":"red"})
            plt.xlim(-0.5,1.5)
            plt.ylim(-0.5,1.5)
            plt.xlabel("True Labels")
            plt.ylabel("Predicted Labels")
            plt.title(f"Regression Results epoch {epoch+1} R2: {r2:.2f}")
            regplot_savepath = os.path.join(paras['eval_path'], 'regression_plots')
            os.makedirs(regplot_savepath, exist_ok=True)
            plt.savefig(os.path.join(regplot_savepath, f"regression_plot_epoch_{epoch+1}.png"))
            plt.close()

        if count_after_best > paras['early_stopping_threshold']:
            print("Early Stopping!")
            break