In [36]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

In [43]:
# Encoder
class EncoderCNN(nn.Module):
    """Image encoder: frozen ImageNet ResNet-50 backbone + trainable linear projection.

    Maps a batch of images to ``embed_size``-dimensional feature vectors. Only
    ``self.embed`` has trainable parameters; the backbone is frozen in ``__init__``.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = self._pretrained_resnet50()
        # Freeze the backbone so only the projection layer is learned.
        for param in resnet.parameters():
            param.requires_grad_(False)
        # Drop the last child (the `fc` classifier head in torchvision's ResNet).
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    @staticmethod
    def _pretrained_resnet50():
        """Build ResNet-50 with the same ImageNet weights the legacy `pretrained=True` selected."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 has no weights enums; only the legacy flag exists there.
            return models.resnet50(pretrained=True)
        # `pretrained=True` maps to IMAGENET1K_V1 (DEFAULT would be V2 and change features).
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, images):
        """Encode images into feature vectors.

        Args:
            images: image batch; assumed (batch, 3, H, W), normalized as the backbone
                expects — TODO confirm against the transform used by callers.

        Returns:
            Tensor of shape (batch, embed_size).
        """
        features = self.resnet(images)
        # Collapse everything after the batch dimension before projecting.
        features = torch.flatten(features, 1)
        return self.embed(features)

# Decoder
class DecoderRNN(nn.Module):
    """LSTM caption decoder: image feature vector + caption tokens -> per-step vocab logits."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.hidden_dim = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # Last (h, c) produced by forward(). Kept for backward compatibility; it is
        # not used as the initial state of forward() or sample().
        self.hidden = (torch.zeros(num_layers, 1, hidden_size), torch.zeros(num_layers, 1, hidden_size))

    def forward(self, features, captions):
        """Teacher-forced decoding.

        The image feature is fed as the first time step, followed by the embeddings
        of ``captions[:, :-1]`` (the last token is never an input).

        Args:
            features: (batch, embed_size) image features.
            captions: (batch, seq_len) integer token ids.

        Returns:
            (batch, seq_len, vocab_size) logits, one step per caption position.
        """
        cap_embedding = self.embed(captions[:, :-1])
        embeddings = torch.cat((features.unsqueeze(dim=1), cap_embedding), dim=1)
        lstm_out, hidden = self.lstm(embeddings)
        # Detach before storing: otherwise the module attribute keeps this batch's
        # whole autograd graph alive until the next forward pass.
        self.hidden = tuple(h.detach() for h in hidden)
        return self.linear(lstm_out)

    def sample(self, inputs, states=None, max_len=20, end_idx=1):
        """Greedy decoding for a single image.

        Args:
            inputs: (1, 1, embed_size) first LSTM input (the image feature).
                Batch size must be 1 because each prediction is read with ``.item()``.
            states: optional initial (h, c) for the LSTM.
            max_len: maximum number of tokens to generate.
            end_idx: token id that stops generation; it is included in the result.

        Returns:
            List of predicted token ids (Python ints).
        """
        res = []
        # Inference only: no need to record an autograd graph.
        with torch.no_grad():
            for _ in range(max_len):
                lstm_out, states = self.lstm(inputs, states)
                outputs = self.linear(lstm_out.squeeze(dim=1))
                _, predicted_idx = outputs.max(dim=1)
                token = predicted_idx.item()
                res.append(token)
                if token == end_idx:
                    break
                # Feed the predicted token back in as the next time step.
                inputs = self.embed(predicted_idx).unsqueeze(1)
        return res


In [40]:
from torch import device


def train_model(encoder, decoder, data_loader, criterion, optimizer, num_epochs=5,
                log_file='./training_log.txt', model_dir='./models', device=None):
    """Train the encoder/decoder pair and checkpoint both after every epoch.

    Args:
        encoder: module mapping an image batch to features.
        decoder: module called as ``decoder(features, captions)``; its output is
            assumed to be (batch, seq_len, vocab) logits aligned with ``captions``.
        data_loader: iterable of ``(images, captions)`` batches that supports ``len()``.
        criterion: loss taking (N, vocab) logits and (N,) targets.
        optimizer: optimizer over the parameters to train.
        num_epochs: number of passes over ``data_loader``.
        log_file: path of the per-step training log (overwritten).
        model_dir: directory for ``encoder-{epoch}.pkl`` / ``decoder-{epoch}.pkl``;
            created if missing.
        device: device to move each batch to. Defaults to the device of the
            decoder's parameters, so batches always land where the model lives.
    """
    if device is None:
        device = next(decoder.parameters()).device
    os.makedirs(model_dir, exist_ok=True)
    num_steps = len(data_loader)

    # `with` guarantees the log is closed even if a training step raises.
    with open(log_file, "w") as f:
        for epoch in range(1, num_epochs + 1):
            for i_step, (images, captions) in enumerate(data_loader):
                images = images.to(device)
                captions = captions.to(device)

                decoder.zero_grad()
                encoder.zero_grad()

                features = encoder(images)
                outputs = decoder(features, captions)

                # Flatten time into the batch so every caption position is one sample;
                # the vocab size comes from the logits rather than a notebook global.
                loss = criterion(outputs.reshape(-1, outputs.size(-1)), captions.reshape(-1))

                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                stats = (
                    f"Epoch [{epoch}/{num_epochs}], Step [{i_step + 1}/{num_steps}], "
                    f"Loss: {loss_value:.4f}, Perplexity: {np.exp(loss_value):.4f}"
                )
                f.write(stats + "\n")
                f.flush()
                if i_step % 10 == 0:
                    print("\r" + stats, end="")

            torch.save(decoder.state_dict(), os.path.join(model_dir, f"decoder-{epoch}.pkl"))
            torch.save(encoder.state_dict(), os.path.join(model_dir, f"encoder-{epoch}.pkl"))
    # Terminate the carriage-return progress line.
    print()

def generate_caption(encoder, decoder, image_path, vocab, image_transform=None):
    """Generate a caption string for a single image file.

    Args:
        encoder: image encoder; called on a (1, C, H, W) batch.
        decoder: caption decoder exposing ``sample(features)`` -> list of token ids.
        image_path: path of the image to caption.
        vocab: mapping from token id to word; ids missing from it become ``"<unk>"``.
        image_transform: preprocessing passed to ``load_image``. Defaults to the
            module-level ``transform``.

    Returns:
        The predicted words joined by single spaces.
    """
    if image_transform is None:
        image_transform = transform
    image = load_image(image_path, image_transform)
    # Inference only: avoid building an autograd graph through the encoder/decoder.
    with torch.no_grad():
        # (1, embed) -> (1, 1, embed): one time step for the batch-first LSTM.
        features = encoder(image).unsqueeze(1)
        output = decoder.sample(features)
    caption = [vocab.get(idx, "<unk>") for idx in output]
    return ' '.join(caption)

def load_image(image_path, transform):
    """Read an image file and turn it into a single-image batch on ``device``.

    The file is converted to RGB, passed through ``transform`` and given a leading
    batch dimension. The result is moved to the module-level ``device``.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    batched = transform(rgb_image)[None]
    return batched.to(device)


In [41]:
# Set Parameters and Initialize Models

In [42]:
# Hyperparameters
embed_size = 256
hidden_size = 512
vocab_size = 5000  # Update with actual vocab size

# File paths
encoder_file = './models/encoder-3.pkl'
decoder_file = './models/decoder-3.pkl'

# Choose the device *before* loading/training. train_model and load_image move
# tensors to `device`, so it must already be a real torch.device here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models directly on the device so training (if any) runs there too.
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size).to(device)

# Load or train models
if os.path.exists(encoder_file) and os.path.exists(decoder_file):
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    encoder.load_state_dict(torch.load(encoder_file, map_location=device))
    decoder.load_state_dict(torch.load(decoder_file, map_location=device))
else:
    # NOTE(review): `data_loader` is not defined anywhere in this notebook; it must be
    # built before this branch can run on a fresh kernel.
    criterion = nn.CrossEntropyLoss()
    # Only trainable parameters: the ResNet backbone is frozen inside EncoderCNN.
    trainable_params = [
        p for p in list(encoder.parameters()) + list(decoder.parameters()) if p.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_params)
    train_model(encoder, decoder, data_loader, criterion, optimizer, num_epochs=5)
    os.makedirs(os.path.dirname(encoder_file), exist_ok=True)
    torch.save(encoder.state_dict(), encoder_file)
    torch.save(decoder.state_dict(), decoder_file)

encoder.eval()
decoder.eval()

# Define image transform
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Example vocabulary dictionary.
# NOTE(review): this must be the id->word mapping used in training. With only 6
# entries and vocab_size=5000, almost every predicted id falls outside it.
vocab = {0: '<start>', 1: '<end>', 2: 'a', 3: 'man', 4: 'with', 5: 'hat'}
if len(vocab) < vocab_size:
    print(f"Warning: vocab maps {len(vocab)} of {vocab_size} token ids; "
          "any other predicted id is rendered as '<unk>'.")

# Generate caption (override the image with the CAPTION_IMAGE_PATH env var).
image_path = os.environ.get("CAPTION_IMAGE_PATH", '/home/hariom/python/Image detector/cat.jpeg')
caption = generate_caption(encoder, decoder, image_path, vocab)
print("Generated Caption:", caption)


Generated Caption: <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>
