In [45]:
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import os
from collections import Counter

In [52]:
# Create a simple vocabulary from captions
word_to_idx = {}
for caption in captions:
    for word in caption.split():
        if word not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)

# Map words in the captions to indices
def caption_to_indices(caption):
    return [word_to_idx.get(word, 0) for word in caption.split()]  # 0 for unknown words

# Update __getitem__ to use the word_to_idx mapping
class ImageCaptionDataset(Dataset):
    def __init__(self, image_paths, captions, transform=None):
        self.image_paths = image_paths
        self.captions = [caption_to_indices(caption) for caption in captions]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        caption = self.captions[idx]

        try:
            image = Image.open(image_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return None, None

        caption_tensor = torch.tensor(caption, dtype=torch.long)
        return image, caption_tensor


# Custom collate function to handle None values (for skipped images)
def collate_fn(batch):
    batch = [b for b in batch if b[0] is not None]  # Remove None values from batch
    if not batch:  # If all items are None, return empty tensors
        return torch.empty(0), torch.empty(0)
    images, captions = zip(*batch)
    return torch.stack(images), torch.stack(captions)

TypeError: Tensor.split() missing 1 required positional argument: 'split_size'

In [47]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

image_paths = ["dog playing in park.jpg", "cat sitting chair.jpg"]
valid_image_paths = [path for path in image_paths if os.path.exists(path)]
captions = ["a dog playing in the park", "a cat sitting on a chair"]

# Initialize the dataset and DataLoader with num_workers=0 to avoid worker-related issues
dataset = ImageCaptionDataset(image_paths=valid_image_paths, captions=captions, transform=transform)

dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0, collate_fn=collate_fn)

In [48]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=2)

    def forward(self, x):
        return self.conv(x)

class CaptionGenerator(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CaptionGenerator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embedding(captions)
        inputs = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        hiddens, _ = self.lstm(inputs)
        outputs = self.fc(hiddens)
        return outputs

In [49]:
embed_size = 256
hidden_size = 512
num_layers = 1
learning_rate = 1e-3
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

feature_extractor = FeatureExtractor().to(device)
caption_generator = CaptionGenerator(embed_size, hidden_size, vocab_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(caption_generator.parameters(), lr=learning_rate)

In [51]:
for epoch in range(num_epochs):
    for images, captions in dataloader:
        images, captions = images.to(device), captions.to(device)

        # Extract features
        features = feature_extractor(images)

        # Forward pass
        outputs = caption_generator(features, captions[:, :-1])
        loss = criterion(outputs.reshape(-1, vocab_size), captions[:, 1:].reshape(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")


IndexError: index out of range in self