In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
class ImageTextDataset(Dataset):
    def __init__(self, image_dir, text_file, transform=None):
        """
        Args:
            image_dir (str): Directory with all the images.
            text_file (str): Path to the text file with image names and corresponding captions.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_dir = image_dir
        self.transform = transform
        self.image_caption_pairs = self._load_image_caption_pairs(text_file)
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    def _load_image_caption_pairs(self, text_file):
        pairs = []
        with open(text_file, 'r') as file:
            for line in file:
                try:
                    image_name, caption = line.strip().split('\t')
                    pairs.append((image_name, caption))
                except:
                    image_name = line.strip().split('\t')[0]
                    pairs.append((image_name, ''))
        return pairs

    def __len__(self):
        return len(self.image_caption_pairs)

    def __getitem__(self, idx):
        image_name, caption = self.image_caption_pairs[idx]
        image_path = os.path.join(self.image_dir, image_name)
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        caption_ids = self.tokenizer.encode(caption, add_special_tokens=True)
        return image, torch.tensor(caption_ids)


In [9]:
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image_dir = './data/img'
text_file = './data/annotations.txt'
dataset = ImageTextDataset(image_dir=image_dir, text_file=text_file, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)


In [10]:
class ImageCaptioningModel(nn.Module):
    def __init__(self, swin_model, gpt2_model):
        super(ImageCaptioningModel, self).__init__()
        self.swin = swin_model
        self.gpt2 = gpt2_model

    def forward(self, images, captions):
        # Extract features from images
        image_features = self.swin(images).last_hidden_state.mean(dim=1)
        # Generate text from image features
        outputs = self.gpt2(input_ids=captions, labels=captions)
        return outputs.loss, outputs.logits


In [11]:
swin_model = models.swin_tiny_patch4_window7_224(pretrained=True)
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2')
model = ImageCaptioningModel(swin_model, gpt2_model)


AttributeError: module 'torchvision.models' has no attribute 'swin_tiny_patch4_window7_224'

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()


In [None]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for images, captions in dataloader:
        images, captions = images.to(device), captions.to(device)
        optimizer.zero_grad()
        loss, logits = model(images, captions)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
