In [None]:
# Environment setup (Colab). `%pip` installs into the running kernel's environment.
# torch is installed once from the CUDA 12.1 wheel index and not listed again below,
# so it is not re-resolved from PyPI. torchmetrics is imported directly by the notebook.
!nvcc --version
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
%pip install ftfy regex tqdm yacs transformers pytorch-lightning torchmetrics

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
Looking in indexes: https://download.pytorch.org/whl/cu121


In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torchmetrics
from transformers import BlipProcessor, BlipForConditionalGeneration
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import pandas as pd
from yacs.config import CfgNode

In [None]:
class Custom_Dataset(Dataset):
    """Meme dataset backed by the annotation CSV at ``cfg.info_file``.

    Each item is a dict with a square-resized RGB PIL image, the meme text and
    the integer value of the ``label`` column for that row.

    Args:
        cfg: config node; reads ``info_file`` (CSV path) and ``img_folder``.
        root_folder: stored on the instance; not otherwise used here.
        dataset: dataset name; stored on the instance; not otherwise used here.
        label: CSV column used as the target (e.g. ``'hate'``, ``'target'``).
        split: value of the CSV ``split`` column to keep (``'train'``, ``'val'``, ...).
        image_size: side length the image is resized to.
        fast: stored on the instance; not otherwise used here.
    """

    # Text returned when a meme has no text.
    EMPTY_TEXT = 'null'

    def __init__(self, cfg, root_folder, dataset, label, split='train', image_size=224, fast=True):
        super().__init__()
        self.cfg = cfg
        self.root_folder = root_folder
        self.dataset = dataset
        self.split = split
        self.label = label

        self.image_size = image_size
        self.fast = fast

        self.info_file = cfg.info_file
        df = pd.read_csv(self.info_file)
        df = df[df['split'] == self.split].reset_index(drop=True)

        if self.label == 'target':
            # Only hateful rows are kept for the 'target' task (presumably the
            # target annotation is only meaningful for hateful memes).
            df = df[df['hate'] == 1].reset_index(drop=True)

        # pandas reads integer columns that contain gaps as float; convert them
        # back to nullable Int64, with missing values replaced by -1.
        # NOTE(review): a -1 label reaching an unmasked CrossEntropyLoss would fail.
        float_cols = df.select_dtypes(float).columns
        df[float_cols] = df[float_cols].fillna(-1).astype('Int64')
        self.df = df

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _normalize_text(text):
        """Return ``text`` as a string, mapping missing text to ``EMPTY_TEXT``.

        Since pandas 2.0, ``read_csv`` parses the literal string 'None' as NaN
        by default. Missing text can therefore arrive as NaN/NA or, with other
        parsers/settings, as the string 'None'. Both are handled here.
        """
        if pd.isna(text) or text == 'None':
            return Custom_Dataset.EMPTY_TEXT
        return str(text)

    def __getitem__(self, idx):
        """Return ``{'image': PIL.Image, 'text': str, 'label': int}`` for row ``idx``."""
        row = self.df.iloc[idx]
        text = self._normalize_text(row['text'])

        image_path = os.path.join(self.cfg.img_folder, row['name'])
        # The context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # NOTE(review): the BLIP processor resizes again (384x384 in the observed
        # output), so an image_size below that discards detail before upsampling.
        image = image.resize((self.image_size, self.image_size))

        return {
            'image': image,
            'text': text,
            'label': int(row[self.label]),
        }

In [None]:
class Custom_Collator(object):
    """DataLoader ``collate_fn`` that turns dataset items into BLIP inputs.

    Runs the BLIP processor over the batch's images and texts and stacks the
    labels into a ``torch.long`` tensor.

    Args:
        cfg: config node. An optional ``max_text_length`` attribute caps the
            number of text tokens (default ``DEFAULT_MAX_TEXT_LENGTH``).
    """

    # Assumed to equal text_config.max_position_embeddings of
    # Salesforce/blip-image-captioning-base — TODO confirm if the checkpoint changes.
    DEFAULT_MAX_TEXT_LENGTH = 512

    def __init__(self, cfg):
        self.cfg = cfg
        self.max_text_length = getattr(cfg, 'max_text_length', self.DEFAULT_MAX_TEXT_LENGTH)
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

    def __call__(self, batch):
        """Collate a list of ``{'image', 'text', 'label'}`` dicts into model inputs."""
        images = [item['image'] for item in batch]
        texts = [item['text'] for item in batch]
        labels = torch.tensor([int(item['label']) for item in batch], dtype=torch.long)

        # Process images and text with the BLIP processor. Padding goes to the
        # longest text in the batch; truncation keeps long texts within the
        # text model's position-embedding limit instead of crashing.
        inputs = self.processor(
            images=images,
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_text_length,
        )

        return {
            'pixel_values': inputs['pixel_values'],       # preprocessed images
            'input_ids': inputs['input_ids'],             # tokenized text
            'attention_mask': inputs['attention_mask'],   # text attention mask
            'labels': labels,                             # class labels
        }


In [None]:
def load_dataset(cfg, split):
    """Build a ``Custom_Dataset`` for ``split``.

    Paths, dataset name, image size, label column and the fast-process flag
    are all taken from ``cfg``.
    """
    dataset_kwargs = dict(
        cfg=cfg,
        root_folder=cfg.root_dir,
        dataset=cfg.dataset_name,
        split=split,
        image_size=cfg.image_size,
        label=cfg.label,
        fast=cfg.fast_process,
    )
    return Custom_Dataset(**dataset_kwargs)


In [None]:
def create_dataloader(cfg, split="train", num_workers=0):
    """Create a DataLoader for one split of the meme dataset.

    Args:
        cfg: config node passed to ``load_dataset`` and ``Custom_Collator``;
            ``cfg.batch_size`` sets the batch size.
        split: dataset split to load; only ``"train"`` is shuffled.
        num_workers: number of DataLoader worker processes. The default of 0
            loads in the main process (the previous behavior). Lightning warns
            that too few workers may bottleneck loading.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding collated BLIP input dicts.
    """
    dataset = load_dataset(cfg, split)
    collator = Custom_Collator(cfg)
    return DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=(split == "train"),
        collate_fn=collator,
        num_workers=num_workers,
        # Reuse worker processes across epochs; only valid when workers exist.
        persistent_workers=num_workers > 0,
    )


In [None]:
# Experiment configuration: every tunable value lives here, and later cells read it from `cfg`.
cfg = CfgNode()

# Paths
cfg.root_dir = './'
cfg.img_folder = '/content/drive/MyDrive/Colab_Notebooks/MemeCLIP-main/dataset/Images'
cfg.info_file = '/content/drive/MyDrive/Colab_Notebooks/MemeCLIP-main/dataset/PrideMM.csv'
cfg.checkpoint_path = os.path.join(cfg.root_dir, 'checkpoints')
cfg.checkpoint_file = os.path.join(cfg.checkpoint_path, 'model.ckpt')

# Model and dataset
cfg.clip_variant = "ViT-L/14"
cfg.dataset_name = 'Pride'
cfg.name = 'MemeBLIP'
cfg.label = 'hate'
cfg.seed = 42
cfg.test_only = False
cfg.device = 'cuda'
cfg.gpus = [0]

# Class names for each supported task, keyed by the CSV label column.
CLASS_NAMES_BY_LABEL = {
    'hate': ['Benign Meme', 'Harmful Meme'],
    'humour': ['No Humour', 'Humour'],
    'target': ['No particular target', 'Individual', 'Community', 'Organization'],
    'stance': ['Neutral', 'Support', 'Oppose'],
}
if cfg.label not in CLASS_NAMES_BY_LABEL:
    raise ValueError(
        f"Unsupported cfg.label {cfg.label!r}; expected one of {sorted(CLASS_NAMES_BY_LABEL)}"
    )
cfg.class_names = list(CLASS_NAMES_BY_LABEL[cfg.label])

# Hyperparameters
cfg.batch_size = 16
cfg.image_size = 224
cfg.num_mapping_layers = 1
cfg.unmapped_dim = 768
cfg.map_dim = 1024
cfg.num_pre_output_layers = 1
cfg.drop_probs = [0.1, 0.4, 0.2]
cfg.lr = 1e-4
cfg.max_epochs = 10
cfg.weight_decay = 1e-4
cfg.num_classes = len(cfg.class_names)
cfg.scale = 30
cfg.print_model = True
cfg.fast_process = True
cfg.reproduce = False
cfg.ratio = 0.7

# Apply the configured seed to the Python, NumPy and torch RNGs so runs repeat.
pl.seed_everything(cfg.seed)

print(cfg)


batch_size: 16
checkpoint_file: ./checkpoints/model.ckpt
checkpoint_path: ./checkpoints
class_names: ['Benign Meme', 'Harmful Meme']
clip_variant: ViT-L/14
dataset_name: Pride
device: cuda
drop_probs: [0.1, 0.4, 0.2]
fast_process: True
gpus: [0]
image_size: 224
img_folder: /content/drive/MyDrive/Colab_Notebooks/MemeCLIP-main/dataset/Images
info_file: /content/drive/MyDrive/Colab_Notebooks/MemeCLIP-main/dataset/PrideMM.csv
label: hate
lr: 0.0001
map_dim: 1024
max_epochs: 10
name: MemeBLIP
num_classes: 2
num_mapping_layers: 1
num_pre_output_layers: 1
print_model: True
ratio: 0.7
reproduce: False
root_dir: ./
scale: 30
seed: 42
test_only: False
unmapped_dim: 768
weight_decay: 0.0001


In [None]:
# Sanity-check the annotation file and one training batch before training.
info_df = pd.read_csv(cfg.info_file)
print(info_df.columns)

# Columns Custom_Dataset reads: image file name, text, split and the task label
# ('hate' is also needed to filter rows for the 'target' task).
required_columns = {'name', 'text', 'split', cfg.label}
if cfg.label == 'target':
    required_columns.add('hate')
missing_columns = required_columns - set(info_df.columns)
assert not missing_columns, f"{cfg.info_file} is missing columns: {sorted(missing_columns)}"

# Load the training and validation data
train_loader = create_dataloader(cfg, split="train")
val_loader = create_dataloader(cfg, split="val")

# Inspect the first training batch.
first_batch = next(iter(train_loader))
print(first_batch.keys())
print(first_batch['pixel_values'].shape)  # images
print(first_batch['input_ids'].shape)     # text token ids
print(first_batch['labels'].shape)        # labels


Index(['name', 'hate', 'target', 'stance', 'humour', 'split', 'text'], dtype='object')
dict_keys(['pixel_values', 'input_ids', 'attention_mask', 'labels'])
torch.Size([16, 3, 384, 384])
torch.Size([16, 491])
torch.Size([16])


In [None]:
class MemeBLIP(pl.LightningModule):
    """Meme classifier built on the BLIP image-captioning checkpoint.

    The pipeline is:

    1. BLIP's vision transformer encodes the image into image tokens.
    2. BLIP's text transformer encodes the meme text, cross-attending to
       those image tokens.
    3. The text hidden states are mean-pooled over real (non-padding) tokens.
    4. The pooled vector is projected to ``cfg.map_dim`` and classified into
       ``cfg.num_classes`` classes.

    Batches are expected in the format produced by ``Custom_Collator``
    (``pixel_values``, ``input_ids``, ``attention_mask``, ``labels``). Raw
    batches with ``images`` / ``text`` keys are also accepted and run through
    the BLIP processor.

    Args:
        cfg: config node; reads ``map_dim``, ``num_classes``, ``lr`` and
            ``weight_decay``.
    """

    CHECKPOINT = "Salesforce/blip-image-captioning-base"
    # Tensors the BLIP backbone consumes.
    INPUT_KEYS = ("pixel_values", "input_ids", "attention_mask")
    # Custom_Dataset fills missing labels with -1; such rows are ignored by loss and metrics.
    MISSING_LABEL = -1

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # BLIP processor and backbone. Lightning moves the module to the
        # trainer's device, so no explicit .to(cfg.device) is needed here.
        self.processor = BlipProcessor.from_pretrained(self.CHECKPOINT)
        self.model = BlipForConditionalGeneration.from_pretrained(self.CHECKPOINT)

        # Classification head. The input size comes from the backbone config
        # rather than assuming it equals cfg.map_dim (1024 in the config).
        hidden_size = self.model.config.text_config.hidden_size
        self.map_dim = cfg.map_dim
        self.projection = nn.Sequential(nn.Linear(hidden_size, self.map_dim), nn.ReLU())
        self.classifier = nn.Linear(self.map_dim, cfg.num_classes)
        self.cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=self.MISSING_LABEL)

        # One metric object per stage so training and validation state never mix.
        # F1 is macro-averaged: micro F1 equals accuracy for single-label multiclass.
        metric_kwargs = dict(
            task="multiclass", num_classes=cfg.num_classes, ignore_index=self.MISSING_LABEL
        )
        self.train_acc = torchmetrics.Accuracy(**metric_kwargs)
        self.val_acc = torchmetrics.Accuracy(**metric_kwargs)
        self.val_f1 = torchmetrics.F1Score(average="macro", **metric_kwargs)

    def preprocess_batch(self, batch):
        """Return the BLIP input tensors of ``batch`` on this module's device.

        Collated batches already hold processed tensors and are used as-is;
        raw batches with ``images``/``text`` keys are processed here.
        """
        if "pixel_values" in batch:
            inputs = batch
        else:
            inputs = self.processor(
                images=batch["images"],
                text=batch["text"],
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
        return {key: inputs[key].to(self.device) for key in self.INPUT_KEYS}

    def encode(self, inputs):
        """Fuse image and text into one feature vector per example.

        Returns a tensor presumably shaped (batch, hidden_size): the mean over
        the sequence axis of the text model's last hidden state.
        """
        image_embeds = self.model.vision_model(pixel_values=inputs["pixel_values"]).last_hidden_state
        text_outputs = self.model.text_decoder.bert(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            encoder_hidden_states=image_embeds,
            # Bidirectional self-attention so every token sees the whole text.
            is_decoder=False,
            return_dict=True,
        )
        token_states = text_outputs.last_hidden_state
        mask = inputs["attention_mask"].unsqueeze(-1).to(token_states.dtype)
        # Average real tokens only; clamp avoids division by zero on an all-padding row.
        return (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def forward(self, batch):
        """Return class logits of shape (batch_size, num_classes)."""
        features = self.encode(self.preprocess_batch(batch))
        return self.classifier(self.projection(features))

    @torch.no_grad()
    def generate_captions(self, batch, **generate_kwargs):
        """Generate BLIP captions conditioned on image and text (the former ``forward`` behavior)."""
        inputs = self.preprocess_batch(batch)
        return self.model.generate(**inputs, **generate_kwargs)

    def common_step(self, batch):
        """Compute loss and predictions for one batch.

        Returns:
            dict with ``loss`` (scalar tensor), ``preds`` and ``labels``.
        """
        logits = self(batch)
        labels = torch.as_tensor(batch["labels"], device=logits.device)
        loss = self.cross_entropy_loss(logits, labels)
        preds = torch.argmax(logits, dim=-1)
        return {"loss": loss, "preds": preds, "labels": labels}

    def training_step(self, batch, batch_idx):
        outputs = self.common_step(batch)
        # Calling the metric updates its state and caches the batch value for step logging.
        self.train_acc(outputs["preds"], outputs["labels"])
        self.log("train_loss", outputs["loss"], prog_bar=True)
        self.log("train_acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True)
        return outputs["loss"]

    def validation_step(self, batch, batch_idx):
        outputs = self.common_step(batch)
        # Only epoch-level values are logged; Lightning computes and resets the metrics.
        self.val_acc.update(outputs["preds"], outputs["labels"])
        self.val_f1.update(outputs["preds"], outputs["labels"])
        self.log("val_loss", outputs["loss"], prog_bar=True)
        self.log("val_acc", self.val_acc, prog_bar=True)
        self.log("val_f1", self.val_f1, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(
            self.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
        )


In [None]:
# Train the classifier. If CUDA is requested but unavailable, fall back to CPU
# so the cell still runs (slowly) instead of raising a device error.
model = MemeBLIP(cfg)

use_gpu = cfg.device == "cuda" and torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=cfg.max_epochs,
    accelerator="gpu" if use_gpu else "cpu",
    # Select the configured GPU ids explicitly rather than only their count.
    devices=list(cfg.gpus) if use_gpu else 1,
    # Logs and checkpoints go under the configured checkpoint directory.
    default_root_dir=cfg.checkpoint_path,
)

trainer.fit(model, train_loader, val_loader)


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name               | Type                         | Params | Mode 
----------------------------------------------------------------------------
0 | model              | BlipForConditionalGeneration | 247 M  | eval 
1 | classifier         | Linear                       | 2.0 K  | train
2 | cross_entropy_loss | CrossEntropyLoss             | 0      | train
3 | acc                | MulticlassAccuracy           | 0      | train
4 | f1                 | MulticlassF1Score            | 0      | train
----------------------------------------------------------------------------
247 M     Trainable params
0         Non-tra

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


KeyError: 'images'