In [None]:
import pandas as pd
import numpy as np
import clip
import torch
import tqdm
import json
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import matplotlib.pyplot as plt

# Read the labelled caption/image table and show its first rows as the cell output.
train_csv_path = '../data/datasets/train-sarcasm-dataset.csv'
data = pd.read_csv(train_csv_path, encoding='utf-8')
data.head()

Unnamed: 0,Image,Caption,Label
0,8ae451edcd8ebf697f8763ece249115813149c55733bf8...,Cô ấy trên mạng vs cô ấy ngoài đời =))),sarcasm
1,35370ffd6c791d6f8c4ab3dd4363ed468fab41e4824ee9...,Người tâm linh giao tiếp với người thực tế :))),not-sarcasm
2,316fdd1477725b9fb1a55015ac06b68b92b50bd4303e08...,Hình như Trăng hôm nay đẹp quá mọi người ạ! 😃 ...,sarcasm
3,8a0f34e0e30e4e5cfb306933c1d25fa801a5da78646b59...,MỌI NGƯỜI NGHĨ SAO VỀ PHÁT BIỂU CỦA SHARK VIỆT...,not-sarcasm
4,e517a5e95d1065886a7c815e82fe254381d4f9f4b244d4...,2 tay hai nàng chứ việc gì phải lệ hai hàng,sarcasm


In [2]:
# Preview a 70/30 random partition of the rows and report the resulting sizes.
# NOTE: `train_dataset` / `val_dataset` are rebound further down by the
# `train_test_split(...)` call, which is the split the data loaders are built from.
train_size = int(0.7 * len(data))
val_size = len(data) - train_size  # remainder, so the two sizes always sum to len(data)

# A seeded generator makes the partition identical on every Restart & Run All.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    data, [train_size, val_size], generator=split_generator
)

print("Length of train data: {}".format(len(train_dataset)))
print("Length of val data: {}".format(len(val_dataset)))

Lenght of train data: 7563
Lenght of val data: 3242


In [None]:
class ImageDataset(Dataset):
    """Map-style dataset yielding ``(transformed_image, label_index)`` pairs.

    Args:
        data: Indexable collection of mappings. Each item must provide an
            ``'Image'`` entry (file name inside ``image_dir``) and a
            ``'Label'`` entry.
        subcategories: Sequence of class names. A sample's label is the
            position of its ``'Label'`` value in this sequence.
        image_dir: Directory the ``'Image'`` file names are resolved against.
            Defaults to the training image folder used by this notebook.
        transform: Callable applied to the RGB ``PIL.Image``. When omitted,
            images are resized to 224x224 and converted to a tensor.

    NOTE(review): the default pipeline does not apply the mean/std
    normalisation that ``CLIPProcessor`` performs; pass ``transform=`` to add
    it if the backbone should see processor-equivalent inputs.
    """

    def __init__(self, data, subcategories,
                 image_dir="../data/images/train-images", transform=None):
        self.data = data
        self.subcategories = subcategories
        self.image_dir = image_dir
        if transform is None:
            # Same preprocessing the notebook has always used.
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``(image, label_index)``.

        The label index is ``-1`` when the sample's ``'Label'`` is not listed
        in ``subcategories``.
        NOTE(review): ``-1`` is outside the valid class range; a loss such as
        ``nn.CrossEntropyLoss`` needs ``ignore_index=-1`` to accept it.
        """
        item = self.data[idx]
        image_path = "{}/{}".format(self.image_dir, item['Image'])
        # Close the file handle deterministically; ``convert`` returns an
        # in-memory copy that stays valid after the file is closed.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        subcategory = item['Label']
        try:
            label = self.subcategories.index(subcategory)
        except ValueError:
            label = -1  # unknown class name
        return self.transform(image), label

In [4]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Pretrained CLIP backbone and its matching tokenizer/processor.
clip_checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(clip_checkpoint)
model = model.to(device)
processor = CLIPProcessor.from_pretrained(clip_checkpoint)

# Tokenise one short prompt per class name and move the tensors to the device.
labels = ["sarcasm", "not-sarcasm"]
prompts = [f"{label}." for label in labels]
text_inputs = processor(text=prompts, return_tensors="pt", padding=True)
text_inputs = text_inputs.to(device)

# Class names in order of first appearance in the CSV; list position == class index.
subcategories = list(data['Label'].unique())

# Fixed-seed 80/20 split over row dicts, then wrap each part in a batched loader.
records = data.to_dict(orient='records')
train_dataset, val_dataset = train_test_split(records, test_size=0.2, random_state=42)

train_images = ImageDataset(train_dataset, subcategories)
val_images = ImageDataset(val_dataset, subcategories)
train_loader = DataLoader(train_images, batch_size=32, shuffle=True)
val_loader = DataLoader(val_images, batch_size=32, shuffle=False)



In [5]:
# Pull a single batch from the training loader and report its tensor shapes.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print("Image Shape:", images.shape)
    print("Labels Shape:", labels.shape)

Image Shape: torch.Size([32, 3, 224, 224])
Labels Shape: torch.Size([32])


In [6]:
import torch.nn as nn

class CLIPFineTuner(nn.Module):
    """CLIP model followed by a linear classification head.

    The head maps the model's image embedding (of size
    ``clip_model.config.projection_dim``) to ``num_classes`` logits.
    """

    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
        embed_dim = clip_model.config.projection_dim
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, images, input_ids):
        """Return class logits for ``images``.

        ``input_ids`` is forwarded to the wrapped model unchanged; only the
        ``image_embeds`` field of its output feeds the classification head.
        """
        clip_out = self.clip_model(pixel_values=images, input_ids=input_ids)
        return self.fc(clip_out.image_embeds)

# One output logit per distinct label found in the dataset.
num_classes = len(subcategories)

# Only the linear head's parameters are given to the optimizer, so freeze the
# CLIP backbone: this skips the backward pass through it and stops its `.grad`
# buffers (never zeroed by an optimizer that does not own them) from piling up.
# The head's gradients, and therefore the training result, are unchanged.
model.requires_grad_(False)

model_ft = CLIPFineTuner(model, num_classes).to(device)

In [7]:
import torch.optim as optim

# Adam is handed the linear head's parameters only; the loss is cross-entropy
# over the class logits.
head_parameters = model_ft.fc.parameters()
optimizer = optim.Adam(head_parameters, lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [8]:
from tqdm import tqdm

# Train the classification head for `num_epochs` passes over the training set,
# reporting the mean training loss and the validation accuracy after each pass.
num_epochs = 4

# The prompt token ids never change, so fetch them once. The classifier only
# consumes the image embeddings, which do not depend on how many text rows are
# supplied, so the ids are passed as-is instead of being tiled per batch.
input_ids = text_inputs.input_ids

for epoch in range(num_epochs):
    model_ft.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}, Loss: 0.0000")

    for step, (images, labels) in enumerate(pbar, start=1):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model_ft(images, input_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Average over the batches seen so far, so the bar shows a true running
        # mean rather than a value that only becomes correct on the last batch.
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/step:.4f}")

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    epoch_loss = running_loss / max(len(train_loader), 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    model_ft.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model_ft(images, input_ids)
            predicted = outputs.argmax(dim=1)  # index of the highest logit per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Report NaN instead of crashing when the validation set is empty.
    val_accuracy = 100 * correct / total if total else float("nan")
    print(f'Validation Accuracy: {val_accuracy}%')

  attn_output = torch.nn.functional.scaled_dot_product_attention(
Epoch 1/16, Loss: 0.6898: 100%|██████████| 271/271 [01:37<00:00,  2.78it/s]


Epoch [1/16], Loss: 0.6898
Validation Accuracy: 58.861638130495145%


Epoch 2/16, Loss: 0.6763: 100%|██████████| 271/271 [01:37<00:00,  2.77it/s]


Epoch [2/16], Loss: 0.6763
Validation Accuracy: 59.18556223970384%


Epoch 3/16, Loss: 0.6669: 100%|██████████| 271/271 [01:37<00:00,  2.77it/s]


Epoch [3/16], Loss: 0.6669
Validation Accuracy: 61.26793151318834%


Epoch 4/16, Loss: 0.6596: 100%|██████████| 271/271 [01:37<00:00,  2.79it/s]


Epoch [4/16], Loss: 0.6596
Validation Accuracy: 63.39657565941694%


Epoch 5/16, Loss: 0.6532: 100%|██████████| 271/271 [01:37<00:00,  2.79it/s]


Epoch [5/16], Loss: 0.6532
Validation Accuracy: 63.48912540490514%


Epoch 6/16, Loss: 0.6476: 100%|██████████| 271/271 [01:37<00:00,  2.79it/s]


Epoch [6/16], Loss: 0.6476
Validation Accuracy: 64.09069875057844%


Epoch 7/16, Loss: 0.6426: 100%|██████████| 271/271 [01:36<00:00,  2.80it/s]


Epoch [7/16], Loss: 0.6426
Validation Accuracy: 64.64599722350763%


Epoch 8/16, Loss: 0.6388: 100%|██████████| 271/271 [01:37<00:00,  2.78it/s]


Epoch [8/16], Loss: 0.6388
Validation Accuracy: 65.06247107820454%


Epoch 9/16, Loss: 0.6355: 100%|██████████| 271/271 [01:36<00:00,  2.80it/s]


Epoch [9/16], Loss: 0.6355
Validation Accuracy: 65.20129569643683%


Epoch 10/16, Loss: 0.6319: 100%|██████████| 271/271 [01:37<00:00,  2.79it/s]


Epoch [10/16], Loss: 0.6319
Validation Accuracy: 65.20129569643683%


Epoch 11/16, Loss: 0.6287: 100%|██████████| 271/271 [01:37<00:00,  2.77it/s]


Epoch [11/16], Loss: 0.6287
Validation Accuracy: 65.29384544192503%


Epoch 12/16, Loss: 0.6265: 100%|██████████| 271/271 [01:37<00:00,  2.79it/s]


Epoch [12/16], Loss: 0.6265
Validation Accuracy: 65.61776955113373%


Epoch 13/16, Loss: 0.6235: 100%|██████████| 271/271 [01:37<00:00,  2.79it/s]


Epoch [13/16], Loss: 0.6235
Validation Accuracy: 65.47894493290143%


Epoch 14/16, Loss: 0.6226: 100%|██████████| 271/271 [01:37<00:00,  2.79it/s]


Epoch [14/16], Loss: 0.6226
Validation Accuracy: 65.38639518741323%


Epoch 15/16, Loss: 0.6199: 100%|██████████| 271/271 [01:37<00:00,  2.79it/s]


Epoch [15/16], Loss: 0.6199
Validation Accuracy: 65.47894493290143%


Epoch 16/16, Loss: 0.6180: 100%|██████████| 271/271 [01:37<00:00,  2.79it/s]


Epoch [16/16], Loss: 0.6180
Validation Accuracy: 65.75659416936604%


In [None]:
# Persist the fine-tuned module's state dict (wrapped CLIP model and linear head).
checkpoint_path = '../data/models/clip-vit-base-patch32-finetuned.pth'
finetuned_state = model_ft.state_dict()
torch.save(finetuned_state, checkpoint_path)