In [1]:
import kagglehub

path = kagglehub.dataset_download("hammadjavaid/6992-labeled-meme-images-dataset")

print("Path to dataset files:", path)

Path to dataset files: /root/.cache/kagglehub/datasets/hammadjavaid/6992-labeled-meme-images-dataset/versions/1


In [2]:
import os
import pandas as pd
from PIL import Image
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import BertTokenizer, BertModel


In [3]:

images_path = "/root/.cache/kagglehub/datasets/hammadjavaid/6992-labeled-meme-images-dataset/versions/1/images/images"
labels_file = "/root/.cache/kagglehub/datasets/hammadjavaid/6992-labeled-meme-images-dataset/versions/1/labels.csv"

data = pd.read_csv(labels_file)

data

Unnamed: 0.1,Unnamed: 0,image_name,text_ocr,text_corrected,overall_sentiment
0,0,image_1.jpg,LOOK THERE MY FRIEND LIGHTYEAR NOW ALL SOHALIK...,LOOK THERE MY FRIEND LIGHTYEAR NOW ALL SOHALIK...,very_positive
1,1,image_2.jpeg,The best of #10 YearChallenge! Completed in le...,The best of #10 YearChallenge! Completed in le...,very_positive
2,2,image_3.JPG,Sam Thorne @Strippin ( Follow Follow Saw every...,Sam Thorne @Strippin ( Follow Follow Saw every...,positive
3,3,image_4.png,10 Year Challenge - Sweet Dee Edition,10 Year Challenge - Sweet Dee Edition,positive
4,4,image_5.png,10 YEAR CHALLENGE WITH NO FILTER 47 Hilarious ...,10 YEAR CHALLENGE WITH NO FILTER 47 Hilarious ...,neutral
...,...,...,...,...,...
6987,6987,image_6988.jpg,Tuesday is Mardi Gras Wednesday is Valentine's...,Tuesday is Mardi Gras Wednesday is Valentine's...,neutral
6988,6988,image_6989.jpg,MUST WATCH MOVIES OF 2017 ITI Chennai memes MA...,MUST WATCH MOVIES OF 2017 ITI Chennai memes MA...,neutral
6989,6989,image_6990.png,LESS MORE TALKING PLANNING SODA JUNK FOOD COMP...,LESS MORE TALKING PLANNING SODA JUNK FOOD COMP...,positive
6990,6990,image_6991.jpg,When I VERY have time is a fantasy No one has ...,When I have time is a fantasy. no one has time...,very_positive


In [4]:
class MemeDataset(Dataset):
    def __init__(self, data, images_path, transform=None, tokenizer=None, max_len=50):
        self.data = data
        self.images_path = images_path
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_id = self.data.iloc[idx]['image_name']
        image_path = os.path.join(self.images_path, image_id)
        label = str(self.data.iloc[idx]['text_corrected'])

        try:
            image = Image.open(image_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
        except (OSError, IOError):
            return None

        tokens = self.tokenizer(label, padding='max_length', max_length=self.max_len, truncation=True, return_tensors="pt")

        return image, tokens.input_ids.flatten(), tokens.attention_mask.flatten()


In [5]:
class MemeCaptioningModel(nn.Module):
    def __init__(self, feature_dim, hidden_dim, vocab_size, max_len=50):
        super(MemeCaptioningModel, self).__init__()

        self.encoder = models.resnet50(pretrained=True)
        self.encoder.fc = nn.Linear(self.encoder.fc.in_features, feature_dim)

        # Decoder: LSTM для генерации текста
        self.lstm = nn.LSTM(input_size=feature_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, images, captions):
        features = self.encoder(images)
        features = features.unsqueeze(1).repeat(1, captions.size(1), 1)
        lstm_out, _ = self.lstm(features)
        outputs = self.fc(lstm_out)
        return outputs


In [6]:
def collate_fn(batch):
    batch = [item for item in batch if item is not None]
    return torch.utils.data.default_collate(batch)


In [7]:
import torch
from PIL import Image
from transformers import BertTokenizer

def generate_caption(image_path, model, tokenizer, max_len=50, device='cuda'):
    image = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    image = transform(image).unsqueeze(0).to(device)

    model.eval()

    caption = [tokenizer.cls_token_id]
    attention_mask = torch.ones(1, max_len).to(device)

    with torch.no_grad():
        for _ in range(max_len - 1):
            input_ids = torch.tensor(caption).unsqueeze(0).to(device)
            outputs = model(image, input_ids)

            next_token_id = outputs[0, -1, :].argmax(dim=-1).item()

            caption.append(next_token_id)

            if next_token_id == tokenizer.sep_token_id:
                break

    caption_text = tokenizer.decode(caption, skip_special_tokens=True)
    return caption_text

In [None]:
from tqdm import tqdm
from torch.utils.data import DataLoader

feature_dim = 512
hidden_dim = 1024
max_len = 50
batch_size = 32
epochs = 40
learning_rate = 1e-4

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = MemeDataset(data, images_path, transform=transform, tokenizer=tokenizer, max_len=max_len)

def collate_fn(batch):
    batch = [item for item in batch if item is not None]
    return torch.utils.data.default_collate(batch)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

model = MemeCaptioningModel(feature_dim, hidden_dim, vocab_size).to('cuda')
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    with tqdm(total=len(dataloader), desc=f"Epoch {epoch + 1}/{epochs}") as pbar:
        for images, captions, attention_masks in dataloader:
            optimizer.zero_grad()
            images = images.to('cuda')
            captions = captions.to('cuda')
            attention_masks = attention_masks.to('cuda')

            outputs = model(images, captions)
            loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
            pbar.update(1)

    avg_loss = epoch_loss / len(dataloader)
    test_image_path = "/content/1.jpg"

    caption = generate_caption(test_image_path, model, tokenizer)
    print("Generated Caption:", caption)
    print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Epoch 1/40: 100%|██████████| 219/219 [02:45<00:00,  1.33it/s, Loss=3.3842]


Generated Caption: 
Epoch [1/40], Average Loss: 3.8650


Epoch 2/40: 100%|██████████| 219/219 [02:42<00:00,  1.35it/s, Loss=3.1024]


Generated Caption: i
Epoch [2/40], Average Loss: 3.1043


Epoch 3/40: 100%|██████████| 219/219 [02:42<00:00,  1.35it/s, Loss=3.2561]


Generated Caption: i i
Epoch [3/40], Average Loss: 3.0287


Epoch 4/40: 100%|██████████| 219/219 [02:42<00:00,  1.35it/s, Loss=3.1591]


Generated Caption: i i
Epoch [4/40], Average Loss: 2.9871


Epoch 5/40: 100%|██████████| 219/219 [02:41<00:00,  1.36it/s, Loss=2.8461]


Generated Caption: i you
Epoch [5/40], Average Loss: 2.9544


Epoch 6/40: 100%|██████████| 219/219 [02:42<00:00,  1.35it/s, Loss=3.2632]


Generated Caption: i you.
Epoch [6/40], Average Loss: 2.9133


Epoch 7/40: 100%|██████████| 219/219 [02:41<00:00,  1.35it/s, Loss=2.4394]


Generated Caption: i you
Epoch [7/40], Average Loss: 2.8812


Epoch 8/40: 100%|██████████| 219/219 [02:42<00:00,  1.35it/s, Loss=2.6331]


Generated Caption: i you...
Epoch [8/40], Average Loss: 2.8585


Epoch 9/40: 100%|██████████| 219/219 [02:42<00:00,  1.35it/s, Loss=2.7388]


Generated Caption: i
Epoch [9/40], Average Loss: 2.8379


Epoch 10/40: 100%|██████████| 219/219 [02:45<00:00,  1.32it/s, Loss=3.1134]


Generated Caption: i
Epoch [10/40], Average Loss: 2.8169


Epoch 11/40:  10%|▉         | 21/219 [00:15<02:20,  1.41it/s, Loss=2.6313]

In [None]:

test_image_path = "/content/1.jpg"

caption = generate_caption(test_image_path, model, tokenizer)
print("Generated Caption:", caption)
