In [3]:
import torch
import torch.nn as nn
import torchvision.models as models

from transformers import BertTokenizer, BertModel, BertConfig
from torch.nn.utils.rnn import pack_padded_sequence

In [5]:
class ResNetEncoder(nn.Module):
    """Encode images into fixed-size vectors using a ResNet-101 backbone.

    The torchvision ResNet-101 is truncated before its final two child
    modules (in torchvision's ResNet these are the global average pool and
    the ``fc`` classifier). The remaining convolutional stack is followed by
    an adaptive average pool to 1x1, a linear projection to ``embed_size``
    and a 1-D batch normalization.

    Args:
        embed_size: Dimensionality of the image embedding that is returned.
    """

    def __init__(self, embed_size):
        super(ResNetEncoder, self).__init__()
        resnet = models.resnet101(pretrained=True)
        # Exclude the last two layers of the pretrained network.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        # AdaptiveAvgPool2d takes a single ``output_size`` argument, so the
        # target height/width must be passed as one tuple.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Return one embedding per image.

        Args:
            images: Batch of images accepted by the ResNet backbone
                (presumably ``(batch, 3, H, W)`` floats -- confirm against
                the data pipeline).

        Returns:
            Tensor of shape ``(batch, embed_size)``.
        """
        # The backbone runs without gradient tracking; only ``fc`` and ``bn``
        # receive gradients.
        with torch.no_grad():
            features = self.resnet(images)
        features = self.avgpool(features)
        # Flatten (batch, C, 1, 1) -> (batch, C) for the linear layer.
        features = features.view(features.size(0), -1)
        features = self.bn(self.fc(features))
        return features

In [7]:
# BERT decoder
class BertDecoder(nn.Module):
    """Predict caption tokens from an image embedding and a BERT text encoder.

    A randomly initialised ``BertModel`` (built from ``BertConfig``, not from
    pretrained weights) encodes the caption tokens. The image embedding is
    placed in front of the BERT hidden states and the resulting sequence is
    mapped to vocabulary logits.

    Args:
        embed_size: Size of the image embedding produced by the encoder.
        hidden_size: BERT hidden size. It is also used as BERT's
            ``intermediate_size``.
        num_layers: Number of BERT hidden layers.
        vocab_size: Number of tokens; sizes both the BERT embedding table and
            the output projection.
    """

    def __init__(self, embed_size, hidden_size, num_layers, vocab_size):
        super(BertDecoder, self).__init__()
        self.bert_config = BertConfig(vocab_size=vocab_size,
                                      hidden_size=hidden_size,
                                      num_hidden_layers=num_layers,
                                      num_attention_heads=8,
                                      intermediate_size=hidden_size)
        self.bert = BertModel(self.bert_config)
        # The image embedding must live in the same space as the BERT hidden
        # states before the two can be concatenated along the sequence axis.
        if embed_size == hidden_size:
            self.image_proj = nn.Identity()
        else:
            self.image_proj = nn.Linear(embed_size, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, encoded_images, captions, lengths):
        """Compute vocabulary logits for every caption position.

        Args:
            encoded_images: Image embeddings of shape ``(batch, embed_size)``.
            captions: Padded token ids of shape ``(batch, seq_len)``.
            lengths: Number of valid tokens in each caption (sequence or
                tensor of length ``batch``).

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``. Position 0 is
            computed from the image embedding; position ``t > 0`` is computed
            from the BERT hidden state at position ``t - 1``.
        """
        seq_len = captions.size(1)
        lengths = torch.as_tensor(lengths, device=captions.device)
        positions = torch.arange(seq_len, device=captions.device)
        # 1 for real tokens, 0 for padding, so BERT ignores padded positions.
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        # NOTE(review): BertModel attends bidirectionally by default, so a
        # hidden state may already have seen the token it is asked to predict.
        # Consider a causal setup (e.g. ``is_decoder=True``) -- confirm intent.
        hidden_states = self.bert(input_ids=captions,
                                  attention_mask=attention_mask)[0]

        image_state = self.image_proj(encoded_images).unsqueeze(1)
        # Prepend the image and drop the last text state so the output keeps
        # the caption's sequence length.
        features = torch.cat([image_state, hidden_states[:, :-1]], dim=1)
        outputs = self.fc(features)
        return outputs

In [6]:
# Full model definition (ResNet + BERT)
class ImageCaptioningModel(nn.Module):
    """End-to-end captioning model: a ResNet image encoder feeding a BERT decoder.

    Args:
        embed_size: Passed to both the encoder and the decoder.
        hidden_size: Passed to the decoder.
        num_layers: Passed to the decoder.
        vocab_size: Passed to the decoder.
    """

    def __init__(self, embed_size, hidden_size, num_layers, vocab_size):
        super().__init__()
        self.encoder = ResNetEncoder(embed_size)
        self.decoder = BertDecoder(
            embed_size,
            hidden_size,
            num_layers,
            vocab_size,
        )

    def forward(self, images, captions, lengths):
        """Encode ``images`` and return the decoder's output for ``captions``."""
        return self.decoder(self.encoder(images), captions, lengths)

In [None]:
# Example of loading training data (assumes the MSCOCO dataset is used).
# Images must be preprocessed beforehand and captions must be tokenized.
# The data loader and the training loop must be adapted to the dataset.

# Model initialization
embed_size = 512  # image and sentence embedding size
hidden_size = 512  # BERT hidden size
num_layers = 6  # number of BERT layers
vocab_size = 10000  # vocabulary size (example value)
num_epochs = 5  # number of training epochs (example value)

model = ImageCaptioningModel(embed_size, hidden_size, num_layers, vocab_size)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Example training loop (adapt to the actual dataset).
# NOTE: `data_loader` is not defined in this notebook; it must be created
# beforehand and yield (images, captions, lengths) batches.
total_steps = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, captions, lengths) in enumerate(data_loader):
        optimizer.zero_grad()
        outputs = model(images, captions, lengths)
        # Flatten to (N, vocab_size) logits vs. (N,) targets. Padded positions
        # are included in the loss -- TODO: pass ignore_index=<pad id> to
        # CrossEntropyLoss if captions are padded.
        loss = criterion(outputs.reshape(-1, vocab_size), captions.reshape(-1))
        loss.backward()
        optimizer.step()

        # Report training progress every 100 steps
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], Loss: {loss.item():.4f}')

# Save the model once training is finished
torch.save(model.state_dict(), 'image_captioning_model.pth')