# In [None]:
import os
from PIL import Image
import numpy as np
from CLIP import clip
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn as nn
# Device string used by the module-level CLIP wrapper and prompt tensor below.
# NOTE(review): hard-coded to CUDA, so this module presumably requires a GPU — confirm.
device = "cuda"

from loguru import logger
from pytorch_grad_cam import GradCAM, \
                            ScoreCAM, \
                            GradCAMPlusPlus, \
                            AblationCAM, \
                            XGradCAM, \
                            EigenCAM, \
                            EigenGradCAM, \
                            LayerCAM, \
                            FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
import cv2
import numpy as np

#preprocess for Grad-cam
def reshape_transform(tensor, height=7, width=7):
    """Turn transformer token activations into a CNN-style feature map.

    Grad-CAM expects activations shaped ``(batch, channels, height, width)``,
    while the hooked transformer layer yields a token sequence. The leading
    token (presumably the class token) is dropped and the remaining
    ``height * width`` tokens are laid out on a 2-D grid.

    Args:
        tensor: Token activations. A 3-D tensor is treated as sequence-first
            ``(tokens, batch, dim)`` and may carry any batch size. Any other
            rank is handled as before: it is viewed as a single sample of
            ``(tokens, dim)``.
        height: Number of patch rows in the grid.
        width: Number of patch columns in the grid.

    Returns:
        A tensor of shape ``(batch, dim, height, width)``.
    """
    if tensor.dim() == 3:
        # Sequence-first -> batch-first. For batch size 1 this gives exactly
        # what the former hard reshape gave; for larger batches the reshape
        # raised, whereas permute keeps every sample's tokens together.
        tensor = tensor.permute(1, 0, 2)
    else:
        tensor = tensor.reshape(1, tensor.shape[0], tensor.shape[-1])

    # Drop the leading token and fold the rest into a (height, width) grid.
    grid = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))

    # Channels-last -> channels-first, as convolutional feature maps are laid out.
    return grid.permute(0, 3, 1, 2)

# Truthy flag handed to GradCAM as `use_cuda`; it also decides whether the
# Grad-CAM input tensor is moved to the GPU.
use_cuda=1
# Not read anywhere in the visible code (and re-initialised further down).
finalresult=[]
 
#a wrapper for Grad-cam
class load_clip(nn.Module):
    """Adapter exposing CLIP as an image-only classifier for Grad-CAM.

    The tokenised prompts are fixed at construction time, so ``forward`` only
    needs an image batch and returns one probability per prompt.
    """

    def __init__(self, text):
        """Load CLIP ViT-B/32 on the module-level ``device`` and keep ``text``.

        Args:
            text: Tokenised prompts passed to the CLIP model on every forward.
        """
        super().__init__()
        clip_model, image_transform = clip.load("ViT-B/32", device=device, jit=False)
        self.mynet = clip_model
        self.preprocess = image_transform
        self.text = text

    def forward(self, image):
        """Return the softmax over the image-to-prompt logits for ``image``."""
        per_image_logits, _ = self.mynet(image, self.text)
        return per_image_logits.softmax(dim=-1)



# In [None]:
import json
def load_jsonl(file_path):
    """Read a JSON Lines file into a list of decoded objects.

    Args:
        file_path: Path of the ``.jsonl`` file; one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank lines (typically a trailing one) carry no record, so they are
        # skipped instead of being passed to json.loads, which would raise.
        return [json.loads(line) for line in file if line.strip()]

# Two tokenised prompts for the Grad-CAM wrapper: index 0 is the "hateful"
# prompt, index 1 the "not hateful" one.
text = torch.cat(
    [
        clip.tokenize('the photo says "gas yourself",this photo is ' + verdict)
        for verdict in ("hateful", "not hateful")
    ]
).to(device)
model = load_clip(text)


# Annotation records that the few-shot training subsets are drawn from.
file_path = '/tmp/DannyWang/CLIP/datasets/hatefulmeme/hateful_memes/dev_seen_m.jsonl'
json_data = load_jsonl(file_path)


from loguru import logger

class YourDataset(Dataset):
    """Few-shot, in-memory subset of meme annotation records.

    Records are scanned in order and at most ``sam_num`` samples are kept for
    each of the two labels (0 and 1). Every kept image is preprocessed
    immediately and stored; all text labels are tokenised once at the end.

    Attributes set by the constructor:
        img_process: The preprocessing callable that was passed in.
        samples: Preprocessed images, one per kept record.
        sam_labels: Text paired with each image (see ``__init__``).
        samid: The record's integer ``label`` for each sample.
        caption: The record's ``text`` for each sample.
        tokens: ``clip.tokenize(sam_labels)``.
    """

    def __init__(self, dataset, preprocess, sam_num,
                 img_root='/tmp/DannyWang/CLIP/datasets/hatefulmeme/hateful_memes'):
        """Select, preprocess and tokenise the samples.

        Args:
            dataset: Sequence of records with ``label`` (0 or 1), ``text`` and
                ``img`` (path relative to ``img_root``) keys.
            preprocess: Callable applied to each opened PIL image. It must
                fully consume the image, because the file is closed right
                after the call returns.
            sam_num: Maximum number of samples kept per label.
            img_root: Directory that the records' ``img`` paths are relative
                to. Defaults to the location that used to be hard-coded.
        """
        self.img_process = preprocess

        self.samples = []
        self.sam_labels = []
        self.samid = []
        self.caption = []
        catcount = [0, 0]  # samples kept so far, indexed by label

        for record in dataset:
            cl_id = record['label']

            if catcount[cl_id] < sam_num:
                # Label-0 records get a fixed prompt; label-1 records are
                # paired with their own meme text.
                if cl_id == 0:
                    label = "this photo is not hateful"
                else:
                    label = f"{record['text']}"

                # Image.open keeps the file open until it is closed, so the
                # handle is released as soon as preprocessing is done.
                with Image.open(os.path.join(img_root, record['img'])) as img:
                    self.samples.append(self.img_process(img))
                self.sam_labels.append(label)
                self.samid.append(cl_id)
                self.caption.append(record['text'])
                catcount[cl_id] += 1

            # Stop scanning once both labels have reached their quota.
            if all(count >= sam_num for count in catcount):
                break

        # NOTE(review): clip.tokenize presumably rejects texts longer than the
        # model's context length — confirm for long meme captions.
        self.tokens = clip.tokenize(self.sam_labels)

    def __len__(self):
        """Return the number of kept samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, token, label_id, caption)`` for sample ``idx``."""
        return self.samples[idx], self.tokens[idx], self.samid[idx], self.caption[idx]

# Per-label sample counts to experiment with; each value is passed to
# mytrain() as `sam_num` by the loop at the bottom of the file.
sam_numo=[2,3,4,8]
# Re-initialised here; not read anywhere in the visible code.
finalresult=[]

def _finetune_clip(net, dataloader, optimizer, device, epoches, total_length):
    """Run the contrastive fine-tuning loop and log one loss value per epoch."""
    # Create the loss functions (one per direction of the similarity matrix).
    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()
    phase = "train"
    # Compare the device *type*: the caller passes a torch.device object, and
    # testing it against the plain string "cpu" is not a reliable check.
    on_cpu = torch.device(device).type == "cpu"

    net.train()
    for epoch in range(epoches):
        # Accumulate Python floats so no tensors are kept alive across the
        # epoch just for logging.
        total_loss = 0.0
        with torch.cuda.amp.autocast(enabled=True):
            for images, labels, cl_id, caption in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                logits_per_image, logits_per_text = net(images, labels)
                # The i-th image is paired with the i-th text in the batch.
                ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
                cur_loss = (loss_img(logits_per_image, ground_truth)
                            + loss_txt(logits_per_text, ground_truth)) / 2
                total_loss += cur_loss.item()

                cur_loss.backward()
                optimizer.step()
                if not on_cpu:
                    # Same post-step weight conversion the GPU path always ran.
                    clip.model.convert_weights(net)

            # NOTE(review): per-batch mean losses are summed and divided by
            # total_length (2 * sam_num), not by the number of batches; kept
            # unchanged so logged values stay comparable with earlier runs.
            epoch_loss = total_loss / total_length
            logger.info('{} Loss: {:.4f}'.format(phase, epoch_loss))


def _report_prediction(net, preprocess, image_path, device):
    """Classify the probe image against two prompts and print the verdict."""
    # Close the image file as soon as it has been preprocessed.
    with Image.open(image_path) as img:
        image_input = preprocess(img).unsqueeze(0).to(device)
    text_inputs = torch.cat([clip.tokenize("this photo is hateful"),
                             clip.tokenize("this photo is not hateful")]).to(device)

    with torch.no_grad():
        image_features = net.encode_image(image_input)
        text_features = net.encode_text(text_inputs)

    # Cosine similarity: normalise both embeddings before the dot product.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(1)
    # Prompt 0 is the "hateful" prompt, which is treated as the correct answer.
    if indices.item() == 0:
        print("预测对")
    else:
        print("预测错")


def _save_grad_cam(net, image_path, out_path):
    """Render a Grad-CAM overlay for the probe image and write it to disk."""
    # The module-level wrapper supplies the fixed prompts; swap in the
    # freshly fine-tuned network.
    model.mynet = net
    # The visual tower is frozen during training. Grad-CAM relies on
    # gradients reaching the hooked layer, so re-enable them here.
    net.visual.requires_grad_(True)

    # The layer Grad-CAM focuses on.
    target_layers = [model.mynet.visual.transformer.resblocks[11].ln_1]
    cam = GradCAM(model=model,
                  target_layers=target_layers,
                  use_cuda=use_cuda,
                  reshape_transform=reshape_transform)

    # Read the input image (OpenCV loads BGR; flip the channels to RGB).
    bgr_img = cv2.imread(image_path, 1)
    if bgr_img is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    rgb_img = bgr_img[:, :, ::-1]
    rgb_img = (cv2.resize(rgb_img, (224, 224))).astype(float)
    rgb_img /= 255

    # NOTE(review): these are ImageNet statistics; the `preprocess` returned
    # by clip.load presumably normalises differently — confirm which is meant.
    input_tensor = preprocess_image(rgb_img,
                                    mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225]).to(torch.float32)
    if use_cuda:
        input_tensor = input_tensor.cuda()

    # The prediction target Grad-CAM focuses on (index 1 of the wrapper's prompts).
    target_category = [ClassifierOutputTarget(1)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=target_category)
    grayscale_cam = grayscale_cam[0, :]

    # Lay the Grad-CAM output over the raw picture. The heat-map is requested
    # in RGB so that it shares the photo's channel order; the blend is then
    # converted once to BGR, which is what cv2.imwrite expects. Without
    # use_rgb=True the heat-map colours came out channel-swapped.
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    visualization = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)

    # cv2.imwrite does not create directories and only signals failure through
    # its return value, so prepare the folder and report a failed write.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if not cv2.imwrite(out_path, visualization):
        logger.error('failed to write Grad-CAM image to {}'.format(out_path))


def mytrain(sam_num):
    """Fine-tune CLIP on a few-shot subset, then evaluate and visualise it.

    Steps:
        1. Load a fresh CLIP ViT-B/32 and freeze its visual tower.
        2. Train for 15 epochs on up to ``sam_num`` samples per label drawn
           from the module-level ``json_data``.
        3. Print whether the probe image is classified as "hateful".
        4. Save a Grad-CAM overlay to ``newftgas15次/opp{sam_num}.jpg``.

    Args:
        sam_num: Number of training samples to keep per label.

    Side effects:
        Replaces ``model.mynet`` on the module-level Grad-CAM wrapper, prints
        the verdict, logs per-epoch losses and writes an image file.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net, preprocess = clip.load("ViT-B/32", device=device, jit=False)

    # Freeze the visual tower. Assigning `net.visual.requires_grad = False`
    # merely sets an attribute on the module and freezes nothing;
    # requires_grad_() updates every parameter of the sub-module.
    net.visual.requires_grad_(False)
    params = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.Adam(params, lr=1e-6, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.001)

    datasets = YourDataset(json_data, preprocess, sam_num)
    your_dataloader = DataLoader(dataset=datasets, batch_size=2, shuffle=False,
                                 num_workers=4, pin_memory=False)

    _finetune_clip(net, your_dataloader, optimizer, device,
                   epoches=15, total_length=sam_num * 2)

    net.eval()
    image_path = "/tmp/DannyWang/gas.jpg"
    _report_prediction(net, preprocess, image_path, device)
    _save_grad_cam(net, image_path, f'newftgas15次/opp{sam_num}.jpg')


# Run the fine-tune / evaluate / Grad-CAM pipeline once per few-shot size.
for i in sam_numo:
    mytrain(i)