In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append("/code/LLM-crime/single_model")
from my_models import TransformerRegressionModel, ResNet50Model, ViTClassifier
from LLM_feature_extractor import LLaVaFeatureExtractor
from PIL import Image
import torchvision.transforms as transforms
from safety_perception_dataset import *
import neptune
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from sklearn.metrics import r2_score
import shutil
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Central configuration read by every cell below. Paths are absolute paths on the
# training server.
cfg_paras = {
    'debug': False,
    'dataset_path': "/data2/cehou/LLM_safety/img_text_data/dataset_baseline_baseline_baseline_baseline_1401.pkl",
    'save_model_path': "/data2/cehou/LLM_safety/LLM_models/clip_model/test",
    'save_model_name': "model_baseline_test.pt",
    # CUDA_VISIBLE_DEVICES="1" is set in the import cell, so exactly one GPU is visible
    # to torch and it is addressed as cuda:0; "cuda:2" was an invalid device ordinal.
    'device': torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    'batch_size': 192,
    'num_workers': 4,
    'head_lr': 1e-3,
    'image_encoder_lr': 1e-4,
    'text_encoder_lr': 1e-5,
    'weight_decay': 1e-3,
    'img_type': 'PlacePulse',
    'patience': 1,
    'factor': 0.8,
    'epochs': 400,
    'image_embedding': 2048,
    'text_embedding': 768,
    'max_length': 512,
    'size': (224, 224),

    # models for image and text
    'model_name': 'resnet50',
    'text_encoder_model': "distilbert-base-uncased",
    'text_tokenizer': "distilbert-base-uncased",
    'pretrained': True,
    'trainable': True,

    # deep learning model parameters
    'temperature': 0.07,
    'projection_dim': 256,
    'dropout': 0.1,
    # Training stops once validation loss has failed to improve for MORE than this
    # many consecutive epochs.
    'early_stopping_threshold': 5,

    # safety perception
    # 'CLIP_model_path': "/data2/cehou/LLM_safety/LLM_models/clip_model/test/model_baseline_best.pt",
    'variables_save_paths': "/data2/cehou/LLM_safety/middle_variables/test",
    'safety_model_save_path': "/data2/cehou/LLM_safety/LLM_models/safety_perception_model/only_img/",
    'placepulse_datapath': "/data2/cehou/LLM_safety/PlacePulse2.0/image_perception_score.csv",
    'eval_path': "/data2/cehou/LLM_safety/eval/test/only_img/",
    'train_type': 'classification',  # 'classification' or 'regression'
    'safety_epochs': 200,
    'class_num': 2,
    'CNN_lr': 1e-7,
    # When True, train_model also reads paras['class_weights'], which is not defined here.
    'weight_on': False,
}

# Experiment tracker used by train_model for loss/metric curves.
# The API token is read from the environment instead of being committed to the
# notebook: export NEPTUNE_API_TOKEN before starting the kernel.
# NOTE(review): the token that was previously hard-coded here is in version history
# and should be revoked.
run = neptune.init_run(
    project="ce-hou/Safety",
    api_token=os.environ.get("NEPTUNE_API_TOKEN"),
)

# Load PlacePulse perception scores and keep only rows with a non-zero label,
# remapping label -1 to 0 (presumably leaving labels in {0, 1} to match
# class_num=2 -- confirm against the CSV).
data = pd.read_csv(cfg_paras['placepulse_datapath'])
# .copy() makes data_ls an independent frame, so the relabelling below neither writes
# through a view of `data` nor triggers SettingWithCopyWarning.
data_ls = data[data['label'] != 0].copy()
# A boolean mask hits exactly the -1 rows; the old index-label lookup would also
# relabel other rows sharing a duplicated index value.
data_ls.loc[data_ls['label'] == -1, 'label'] = 0
transform = get_transforms(cfg_paras['size'])

# Sequential (unshuffled) 80/20 split by row position.
split_num = int(len(data_ls) * 0.8)

train_dataset = SafetyPerceptionDataset(data_ls.iloc[:split_num], transform=transform, paras=cfg_paras)
valid_dataset = SafetyPerceptionDataset(data_ls.iloc[split_num:], transform=transform, paras=cfg_paras)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg_paras['batch_size'], shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=cfg_paras['batch_size'])

model = LLaVaFeatureExtractor()

In [63]:
class ProjectionHead(nn.Module):
    """Project an image or text embedding into a shared ``projection_dim`` space.

    Linear projection -> GELU -> Linear -> Dropout, added back to the projection
    (residual), then LayerNorm.

    Args:
        cfg_paras: config dict; reads ``image_embedding`` or ``text_embedding`` (input
            width, chosen by ``data_type``), ``projection_dim`` and ``dropout``.
        data_type: ``'image'`` or ``'text'``.

    Raises:
        ValueError: if ``data_type`` is neither ``'image'`` nor ``'text'``.

    NOTE(review): a later cell defines ProjectionHead again; keep the two definitions in
    sync or delete one.
    """

    def __init__(
        self,
        cfg_paras,
        data_type
    ):
        super().__init__()
        if data_type == 'image':
            in_features = cfg_paras['image_embedding']
        elif data_type == 'text':
            in_features = cfg_paras['text_embedding']
        else:
            # Previously an unknown type left self.projection unset and only failed
            # later, inside forward(), with an AttributeError.
            raise ValueError(f"data_type must be 'image' or 'text', got {data_type!r}")
        projection_dim = cfg_paras['projection_dim']
        # Registration order matches the original, so state_dict keys are unchanged.
        self.projection = nn.Linear(in_features, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(cfg_paras['dropout'])
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Map ``x`` (last dim = input width) to a layer-normalised ``projection_dim`` vector."""
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # residual connection around GELU -> fc -> dropout
        x = self.layer_norm(x)
        return x

In [None]:
# Take one batch from valid_loader to experiment with feature extraction.
data_iter = iter(valid_loader)
images, labels = next(data_iter)

# Keep the first five images (fewer if the batch is smaller).
five_images = images[:5]
five_labels = labels[:5]

# image_extractor receives a list of per-image tensors permuted with (1, 2, 0),
# presumably CHW -> HWC. Iterating five_images (instead of indexing range(5)) avoids
# an IndexError on batches with fewer than five images.
# NOTE(review): the original note gave the result shape as [batch_size, 3, 3, 336, 336] -- confirm.
img_feature = model.image_extractor([img.permute(1, 2, 0) for img in five_images])

In [101]:
# Three ways to collapse the first sample's feature tensor, img_feature[0]:
# mean and max over its two leading axes, or folding those two axes together
# (noted in the original as [3, 3, 336, 336] -> [9, 336, 336]).
img_feature_mean = torch.mean(img_feature[0], dim=(0, 1))
img_feature_max = torch.max(torch.max(img_feature[0], dim=0).values, dim=0).values
img_feature_reshaped = img_feature[0].reshape(
    -1,
    img_feature.shape[3],
    img_feature.shape[4],
)


In [None]:
class ProjectionHead(nn.Module):
    """Project an image or text embedding into a shared ``projection_dim`` space.

    Linear projection -> GELU -> Linear -> Dropout, added back to the projection
    (residual), then LayerNorm.

    Args:
        cfg_paras: config dict; reads ``image_embedding`` or ``text_embedding`` (input
            width, chosen by ``data_type``), ``projection_dim`` and ``dropout``.
        data_type: ``'image'`` or ``'text'``.

    Raises:
        ValueError: if ``data_type`` is neither ``'image'`` nor ``'text'``.

    NOTE(review): an earlier cell also defines ProjectionHead; this later definition
    shadows it when cells run top to bottom. Keep them in sync or delete one.
    """

    def __init__(
        self,
        cfg_paras,
        data_type
    ):
        super().__init__()
        if data_type == 'image':
            in_features = cfg_paras['image_embedding']
        elif data_type == 'text':
            in_features = cfg_paras['text_embedding']
        else:
            # Previously an unknown type left self.projection unset and only failed
            # later, inside forward(), with an AttributeError.
            raise ValueError(f"data_type must be 'image' or 'text', got {data_type!r}")
        projection_dim = cfg_paras['projection_dim']
        # Registration order matches the original, so state_dict keys are unchanged.
        self.projection = nn.Linear(in_features, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(cfg_paras['dropout'])
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Map ``x`` (last dim = input width) to a layer-normalised ``projection_dim`` vector."""
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # residual connection around GELU -> fc -> dropout
        x = self.layer_norm(x)
        return x
    
class SimpleCNN(nn.Module):
    """Two conv -> ReLU -> 2x2 max-pool stages followed by a two-layer MLP classifier.

    Each pooling stage halves the spatial size (floor), so the flattened width fed to
    ``fc1`` is ``32 * (H // 4) * (W // 4)``.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input images (default 3, as before).
        input_size: spatial size of the input; an int for square inputs or an
            ``(H, W)`` pair. The default 224 reproduces the original hard-coded
            ``32 * 56 * 56`` flatten width.
    """

    def __init__(self, num_classes=10, in_channels=3, input_size=224):
        super(SimpleCNN, self).__init__()
        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        self.flat_features = 32 * (height // 4) * (width // 4)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(in_features=self.flat_features, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)``; an unbatched (C, H, W) input gives N = 1."""
        if x.dim() == 3:
            # Unbatched input, which Conv2d also accepts; the old view(-1, ...) turned
            # it into a batch of one, so keep that result.
            x = x.unsqueeze(0)
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # flatten(x, 1) keeps the batch dimension. The old view(-1, 32*56*56) silently
        # re-chunked any non-224x224 input into a wrong batch size instead of failing.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [3]:
import torch.nn as nn

class FeatureExtractor(nn.Module):
    """LLaVA image features -> optional pooling -> SimpleCNN -> ProjectionHead -> Linear(projection_dim, 2048).

    Args:
        cfg_paras: config dict; reads ``class_num`` (SimpleCNN output width) and
            ``projection_dim``, plus the keys ProjectionHead('image') reads.
        process: how the extracted feature tensor is collapsed before the CNN:
            ``'mean'`` / ``'max'`` reduce over dims 0 and 1, ``'reshape'`` folds the
            leading dims into one, and ``None`` passes it through unchanged.

    Raises:
        ValueError: for any other ``process`` value (previously ignored silently,
            which behaved like ``None``).

    NOTE(review): as written the stage widths do not line up. SimpleCNN emits
    ``class_num`` features, while ProjectionHead('image') expects ``image_embedding``
    inputs. 'mean'/'max' also reduce over dim 0, which is presumably the batch axis.
    Confirm the intended shapes before training.
    """

    _PROCESS_MODES = ('mean', 'max', 'reshape', None)

    def __init__(self, cfg_paras, process='mean'):
        super(FeatureExtractor, self).__init__()
        # Validate before any sub-model is constructed so a typo fails immediately.
        if process not in self._PROCESS_MODES:
            raise ValueError(f"process must be one of {self._PROCESS_MODES}, got {process!r}")
        self.process = process
        self.llava_extractor = LLaVaFeatureExtractor()
        self.cnn_block = SimpleCNN(num_classes=cfg_paras['class_num'])
        self.projection_head = ProjectionHead(cfg_paras, data_type='image')
        self.fc = nn.Linear(cfg_paras['projection_dim'], 2048)

    def forward(self, x):
        """Run the LLaVA image extractor on ``x`` and push the result through the head.

        ``x`` is whatever ``LLaVaFeatureExtractor.image_extractor`` accepts (elsewhere in
        this notebook, a list of images permuted with ``permute(1, 2, 0)``).
        """
        img_feature = self.llava_extractor.image_extractor(x)
        if self.process == 'mean':
            img_feature = img_feature.mean(dim=(0, 1))
        elif self.process == 'max':
            img_feature = img_feature.max(dim=0)[0].max(dim=0)[0]
        elif self.process == 'reshape':
            img_feature = img_feature.reshape(-1, img_feature.shape[3], img_feature.shape[4])
        img_feature = self.cnn_block(img_feature)
        projected_feature = self.projection_head(img_feature)
        output = self.fc(projected_feature)
        return output

# Create the model instance on the configured device.
feature_extractor_model = FeatureExtractor(cfg_paras, process='mean').to(cfg_paras['device'])

# Smoke-test the model. FeatureExtractor.forward() runs the LLaVA image extractor
# itself, so it takes the same raw input as model.image_extractor (images permuted
# with permute(1, 2, 0)). It does not take an already-extracted feature such as
# img_feature_mean. The unused second standalone extraction has been dropped.
sample_images = [img.permute(1, 2, 0) for img in images[:5]]
output = feature_extractor_model(sample_images)

print(output.shape)  # the original note expected torch.Size([5, 2048]) -- TODO confirm

NameError: name 'cfg_paras' is not defined

In [None]:
class Trainer:
    """Generic epoch loop: train, validate, checkpoint the best model, stop early.

    Args:
        model: ``nn.Module`` to train; moved to ``device`` when ``train()`` starts.
        train_loader: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        valid_loader: same, consumed by ``validate()``.
        criterion: loss callable, ``criterion(outputs, labels)``.
        optimizer: optimizer already constructed over ``model.parameters()``.
        device: device that the model and every batch are moved to.
        save_path: file receiving the ``state_dict`` with the lowest validation loss;
            missing parent directories are created.
        num_epochs: maximum number of epochs.
        early_stopping_threshold: stop once validation loss has not improved for more
            than this many consecutive epochs.
    """

    def __init__(self, model, train_loader, valid_loader, criterion, optimizer, device, save_path, num_epochs=25, early_stopping_threshold=5):
        self.model = model
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.save_path = save_path
        self.num_epochs = num_epochs
        self.early_stopping_threshold = early_stopping_threshold

    def train(self):
        """Run up to ``num_epochs`` epochs, saving the model whenever validation loss improves."""
        # Batches are moved to self.device below; the model must live there too,
        # otherwise the first forward pass fails with a device mismatch.
        self.model.to(self.device)
        best_loss = float('inf')
        count_after_best = 0

        for epoch in range(self.num_epochs):
            train_running_loss = self._train_one_epoch()
            valid_running_loss = self.validate()

            if valid_running_loss < best_loss:
                best_loss = valid_running_loss
                count_after_best = 0
                self._save_checkpoint()
                print(f"Saved best model to {self.save_path}")
            else:
                count_after_best += 1

            # max(..., 1) guards against ZeroDivisionError for an empty loader.
            print(f"Epoch [{epoch+1}/{self.num_epochs}], Train Loss: {train_running_loss/max(len(self.train_loader), 1):.4f}, Validation Loss: {valid_running_loss/max(len(self.valid_loader), 1):.4f}")

            if count_after_best > self.early_stopping_threshold:
                print("Early Stopping!")
                break

    def _train_one_epoch(self):
        """One optimisation pass over ``train_loader``; returns the sum of per-batch losses."""
        self.model.train()
        train_running_loss = 0.0
        for inputs, labels in self.train_loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            train_running_loss += loss.item()
        return train_running_loss

    def _save_checkpoint(self):
        """Write the model's ``state_dict`` to ``save_path``, creating its directory if needed."""
        parent_dir = os.path.dirname(self.save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(self.model.state_dict(), self.save_path)

    def validate(self):
        """Evaluate on ``valid_loader`` without gradients; returns the sum of per-batch losses."""
        self.model.eval()
        valid_running_loss = 0.0
        with torch.no_grad():
            for inputs, labels in self.valid_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                valid_running_loss += loss.item()
        return valid_running_loss

In [None]:
# Build the model first: the optimizer needs its parameters.
# NOTE(review): LLaVaFeatureExtractor is used here as a trainable classifier; confirm it
# is an nn.Module whose forward() returns class logits (FeatureExtractor may have been
# the intended model).
safety_model = LLaVaFeatureExtractor()

# Make sure the checkpoint directory exists before the first save.
os.makedirs(cfg_paras['safety_model_save_path'], exist_ok=True)

trainer = Trainer(
    model=safety_model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    criterion=nn.CrossEntropyLoss(),
    # optim.Adam requires the parameter iterable; passing only lr raised a TypeError.
    optimizer=optim.Adam(safety_model.parameters(), lr=cfg_paras["CNN_lr"]),
    # Use the configured device. The hard-coded 'cuda:1' does not exist when
    # CUDA_VISIBLE_DEVICES="1" exposes a single GPU.
    device=cfg_paras['device'],
    save_path=os.path.join(cfg_paras['safety_model_save_path'], "best_llm_model.pth"),
    num_epochs=cfg_paras['safety_epochs'],
    early_stopping_threshold=cfg_paras['early_stopping_threshold']
)

trainer.train()

In [None]:
def train_model(train_loader, valid_loader, paras, model=None):
    """Train a safety-perception model with best-checkpointing, early stopping and Neptune logging.

    Args:
        train_loader: DataLoader yielding ``(inputs, labels)`` batches; must support ``len()``.
        valid_loader: same, evaluated once per epoch.
        paras: config dict. Reads ``device``, ``train_type`` ('regression' or
            'classification'), ``weight_on`` (and ``class_weights`` when it is True),
            ``CNN_lr``, ``safety_epochs``, ``early_stopping_threshold``,
            ``safety_model_save_path`` and ``eval_path``.
        model: optional ``nn.Module`` to train. If omitted, regression builds
            ``ResNet50Model(1)``. Classification has no default: every candidate was
            commented out in the original, which left ``model`` unbound.

    Returns:
        The model after the last epoch that ran. The lowest-validation-loss weights are
        saved to ``<safety_model_save_path>/best_<train_type>_model.pth``.

    Raises:
        ValueError: for an unknown ``train_type``, or classification without ``model``.

    Side effects: appends losses/metrics to the module-level Neptune ``run`` and saves
    one evaluation figure per epoch under ``eval_path``.
    """
    print(f'device: {paras["device"]}')
    model, criterion, optimizer = _build_training_objects(paras, model)

    num_epochs = paras['safety_epochs']
    save_dir = paras['safety_model_save_path']
    best_loss = float('inf')
    count_after_best = 0
    for epoch in range(num_epochs):
        train_running_loss = _fit_epoch(model, train_loader, criterion, optimizer, paras, epoch, num_epochs)
        val_running_loss, all_preds, all_labels = _evaluate_epoch(model, valid_loader, criterion, paras)

        count_after_best += 1
        if val_running_loss < best_loss:
            best_loss = val_running_loss
            count_after_best = 0
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(save_dir, f"best_{paras['train_type']}_model.pth"))
            print(f"save the best model to {save_dir}.")

        # Mean per-batch loss. The console line used to divide by batch_size, which
        # disagreed with the values logged to Neptune; both now use the batch count.
        mean_train_loss = train_running_loss / max(len(train_loader), 1)
        mean_val_loss = val_running_loss / max(len(valid_loader), 1)
        run["train/total_loss"].append(mean_train_loss)
        run["valid/total_loss"].append(mean_val_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {mean_train_loss:.4f}, Validation Loss: {mean_val_loss:.4f}")

        if paras['train_type'] == 'classification':
            _report_classification(all_labels, all_preds, epoch, paras)
        else:
            _report_regression(all_labels, all_preds, epoch, paras)

        if count_after_best > paras['early_stopping_threshold']:
            print("Early Stopping!")
            break
    return model


def _build_training_objects(paras, model):
    """Return ``(model, criterion, optimizer)`` for ``paras['train_type']``, with the model on ``paras['device']``."""
    train_type = paras['train_type']
    if train_type == 'regression':
        if model is None:
            # Single regression output. TransformerRegressionModel was an earlier alternative.
            model = ResNet50Model(1)
        criterion = nn.MSELoss()
    elif train_type == 'classification':
        if model is None:
            # Candidates tried before: ResNet50Model(class_num), ViTClassifier(class_num).
            raise ValueError("classification training needs a model: call train_model(..., model=...)")
        if paras['weight_on']:
            class_weights = torch.FloatTensor(paras['class_weights']).to(paras['device'])
            criterion = nn.CrossEntropyLoss(weight=class_weights)
        else:
            criterion = nn.CrossEntropyLoss()
    else:
        raise ValueError(f"unknown train_type {train_type!r}; expected 'regression' or 'classification'")
    model = model.to(paras['device'])
    optimizer = optim.Adam(model.parameters(), lr=paras["CNN_lr"])
    return model, criterion, optimizer


def _labels_to_device(labels, paras):
    """Move labels to the device as int64 class ids (classification) or float32 targets (regression)."""
    labels = labels.to(paras['device'])
    return labels.long() if paras['train_type'] == 'classification' else labels.float()


def _fit_epoch(model, train_loader, criterion, optimizer, paras, epoch, num_epochs):
    """One optimisation pass with a tqdm progress bar; returns the sum of per-batch losses."""
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, total=len(train_loader))
    progress.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
    for inputs, labels in progress:
        inputs = inputs.to(paras['device'])
        labels = _labels_to_device(labels, paras)
        optimizer.zero_grad()
        outputs = model(inputs)
        if outputs.shape[1] == 1:
            outputs = outputs.squeeze(1)  # (N, 1) -> (N,) to match regression targets
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        progress.set_postfix(loss=running_loss)
    return running_loss


def _evaluate_epoch(model, valid_loader, criterion, paras):
    """Evaluate without gradients; return ``(summed_loss, preds, labels)`` with numpy arrays.

    For classification ``preds`` are arg-max class indices; for regression they are the
    raw model outputs. Chunks are concatenated once at the end (the old per-batch
    ``np.concatenate`` was quadratic and undefined for an empty loader).
    """
    model.eval()
    running_loss = 0.0
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs = inputs.to(paras['device'])
            labels = _labels_to_device(labels, paras)
            outputs = model(inputs)
            if paras['train_type'] == 'regression':
                outputs = outputs.squeeze(1)
                preds = outputs
            else:
                preds = torch.max(outputs, 1).indices
            running_loss += criterion(outputs, labels).item()
            pred_chunks.append(preds.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.array([])
    all_labels = np.concatenate(label_chunks) if label_chunks else np.array([])
    return running_loss, all_preds, all_labels


def _report_classification(all_labels, all_preds, epoch, paras):
    """Log accuracy to Neptune and save this epoch's confusion-matrix heatmap under ``eval_path``."""
    total = len(all_labels)
    if total == 0:
        print("Validation set is empty; skipping accuracy and confusion matrix.")
        return
    accuracy = int((all_preds == all_labels).sum()) / total
    run["valid/accuracy"].append(accuracy)
    print(f"Accuracy: {100 * accuracy:.2f}%")

    cm = confusion_matrix(all_labels, all_preds)
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
                annot_kws={"size": 12, "weight": "bold", "color": "red"}, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title(f"Confusion Matrix epoch {epoch+1} acc: {accuracy:0.2%}")
    cm_savepath = os.path.join(paras['eval_path'], 'valid_cm')
    os.makedirs(cm_savepath, exist_ok=True)
    fig.savefig(os.path.join(cm_savepath, f"confusion_matrix_epoch_{epoch+1}.png"))
    plt.close(fig)


def _report_regression(all_labels, all_preds, epoch, paras):
    """Log R^2 to Neptune and save this epoch's true-vs-predicted plot under ``eval_path``."""
    if len(all_labels) < 2:
        print("Fewer than two validation samples; skipping R2 score and plot.")
        return
    # r2_score takes (y_true, y_pred). The arguments used to be swapped, and since
    # R^2 is not symmetric that reported a different, wrong score.
    r2 = r2_score(all_labels, all_preds)
    run["valid/r2_score"].append(r2)
    print(f"R2 score: {r2:.2f}")

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.regplot(x=all_labels, y=all_preds, scatter_kws={'s': 10}, line_kws={"color": "red"}, ax=ax)
    ax.set_xlim(-0.5, 1.5)
    ax.set_ylim(-0.5, 1.5)
    ax.set_xlabel("True Labels")
    ax.set_ylabel("Predicted Labels")
    ax.set_title(f"Regression Results epoch {epoch+1} R2: {r2:.2f}")
    regplot_savepath = os.path.join(paras['eval_path'], 'regression_plots')
    os.makedirs(regplot_savepath, exist_ok=True)
    fig.savefig(os.path.join(regplot_savepath, f"regression_plot_epoch_{epoch+1}.png"))
    plt.close(fig)