In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
import os
import json
from tqdm import tqdm

# --- Encoder ---
class EncoderCNN(nn.Module):
    """ResNet-101 trunk that returns a flattened grid of spatial feature vectors.

    The last two children of torchvision's ResNet-101 (global pool and classifier)
    are dropped. The remaining feature map is adaptively pooled to
    ``encoded_image_size x encoded_image_size`` and flattened so each spatial cell
    becomes one ``encoder_dim``-sized vector.

    Args:
        encoded_image_size: side length of the pooled feature grid (default 14).
        pretrained: if True (default, original behaviour), initialise the trunk with
            torchvision's default pretrained weights. Pass False to skip the weight
            download when a checkpoint will overwrite every parameter anyway.
    """

    def __init__(self, encoded_image_size=14, pretrained=True):
        super(EncoderCNN, self).__init__()
        self.enc_image_size = encoded_image_size
        self.encoder_dim = 2048  # channel dim of the trunk output (see forward)

        resnet = models.resnet101(weights='DEFAULT' if pretrained else None)
        # Keep the convolutional trunk; drop the final pooling and fc layers.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        self.fine_tune()

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: image batch, presumably (B, 3, H, W) normalised tensors.

        Returns:
            Tensor of shape (B, enc_image_size**2, encoder_dim).
        """
        out = self.resnet(images)
        out = self.adaptive_pool(out)
        out = out.permute(0, 2, 3, 1)  # (B, 14, 14, 2048)
        out = out.view(out.size(0), -1, out.size(-1))  # (B, num_pixels, 2048)
        return out

    def fine_tune(self, fine_tune=True):
        """Freeze the whole trunk, then set ``requires_grad=fine_tune`` on children 5+.

        For torchvision's ResNet, children 5+ of the truncated trunk are presumably
        layer2-layer4; the earlier stem/layer1 always stay frozen.
        """
        for p in self.resnet.parameters():
            p.requires_grad = False
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune

# --- Decoder ---
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on the mean of the encoder feature vectors.

    At each step the LSTMCell input is the current word embedding concatenated with
    the unweighted average of all encoder vectors (no attention is applied).

    Args:
        embed_size: word-embedding dimension.
        hidden_size: LSTM hidden/cell state size.
        vocab_size: number of output tokens.
        encoder_dim: size of each encoder feature vector (default 2048).
        dropout: dropout probability applied to the hidden state before ``fc``.
        pretrained_embeddings: optional tensor used as the embedding table; it is
            expected to be (vocab_size, embed_size) — not checked here.
        freeze_embeddings: if True, pretrained embeddings are not trained.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, encoder_dim=2048, dropout=0.5,
                 pretrained_embeddings=None, freeze_embeddings=False):
        super(DecoderRNN, self).__init__()
        self.encoder_dim = encoder_dim
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(vocab_size, embed_size)
        if pretrained_embeddings is not None:
            self.embedding.weight = nn.Parameter(pretrained_embeddings)
            self.embedding.weight.requires_grad = not freeze_embeddings
        else:
            self.embedding.weight.data.uniform_(-0.1, 0.1)

        self.dropout = nn.Dropout(p=dropout)
        # Initial LSTM state is predicted from the mean encoder feature.
        self.init_h = nn.Linear(encoder_dim, hidden_size)
        self.init_c = nn.Linear(encoder_dim, hidden_size)
        # Each step consumes [word embedding ; mean encoder feature].
        self.lstm = nn.LSTMCell(embed_size + encoder_dim, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Zero the output-layer bias and draw its weights from U(-0.1, 0.1)."""
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        """Return the initial (h, c), each (B, hidden_size), from pixel-averaged features."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, captions, caplens):
        """Teacher-forced decoding over a batch.

        Args:
            encoder_out: encoder features, reshaped here to (B, num_pixels, encoder_dim).
            captions: (B, T) token ids.
            caplens: caption lengths, shape (B, 1) or (B,).

        Returns:
            predictions: (B, max(decode_lengths), vocab_size) logits in sorted row
                order; entries past a row's decode length stay zero.
            captions: captions reordered by descending length.
            decode_lengths: list of ``caplen - 1`` per sorted row.
            sort_ind: the permutation applied to the batch.
        """
        batch_size = encoder_out.size(0)
        encoder_out = encoder_out.view(batch_size, -1, self.encoder_dim)
        # Flatten so both (B, 1) and (B,) length tensors are accepted.
        caplens, sort_ind = caplens.reshape(-1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        captions = captions[sort_ind]
        embeddings = self.embedding(captions)

        h, c = self.init_hidden_state(encoder_out)
        # The last token is never fed as input, so each caption decodes len - 1 steps.
        decode_lengths = (caplens - 1).tolist()
        max_decode_len = max(decode_lengths)
        predictions = torch.zeros(batch_size, max_decode_len, self.vocab_size,
                                  device=encoder_out.device)

        # The mean encoder feature does not change across steps; compute it once.
        awe_all = encoder_out.mean(dim=1)

        for t in range(max_decode_len):
            # Rows are sorted by length, so the captions still decoding form a prefix.
            batch_size_t = sum(l > t for l in decode_lengths)
            input_lstm = torch.cat([embeddings[:batch_size_t, t, :], awe_all[:batch_size_t]], dim=1)
            h, c = self.lstm(input_lstm, (h[:batch_size_t], c[:batch_size_t]))
            preds = self.fc(self.dropout(h))
            predictions[:batch_size_t, t, :] = preds

        return predictions, captions, decode_lengths, sort_ind

In [7]:
import os
import json
import torch
from PIL import Image
from tqdm import tqdm
from torchvision import transforms

# ---- SETUP ----
# --- Aspect list and checkpoint folders ---
# Fine-tuned checkpoints live in two separate Kaggle inputs;
# load_model_for_aspect chooses the folder per aspect.
ckpt_dir1 = '/kaggle/input/single-aspects/pytorch/default/3/fine-tuned-models'
ckpt_dir2 = '/kaggle/input/single-aspects-part-2/pytorch/default/2/fine-tuned-models'
aspects = [
    "general_impression", "subject", "use_of_camera",
    "color_light", "composition", "dof_and_focus"
]
WORD_MAP_PATH = '/kaggle/input/food-iac-fine-tune-dataset/preprocessed_dataset/wordmap_all.json'
ALL_JSON_PATH = '/kaggle/input/food-iac-fine-tune-dataset/final/all.json'
IMAGE_DIR = '/kaggle/input/dpchallenge-images-food-gallery/images'

# JSON is UTF-8 by spec; read it explicitly as UTF-8 rather than relying on the
# platform's locale-dependent default encoding (user comments may be non-ASCII).
with open(WORD_MAP_PATH, 'r', encoding='utf-8') as f:
    word_map = json.load(f)

with open(ALL_JSON_PATH, 'r', encoding='utf-8') as f:
    data_json = json.load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# --- Load Model ---
def load_model_for_aspect(aspect):
    """Build the encoder/decoder pair for one aspect and load its fine-tuned weights.

    The checkpoint ``<aspect>_best.pth`` is read from ``ckpt_dir1`` for the
    color_light / composition / dof_and_focus aspects, and from ``ckpt_dir2`` for
    every other aspect. It must contain ``"encoder"`` and ``"decoder"`` state dicts.

    Returns:
        ``(encoder, decoder)``, both on ``device`` and in eval mode.
    """
    first_folder_aspects = ("color_light", "composition", "dof_and_focus")
    source_dir = ckpt_dir1 if aspect in first_folder_aspects else ckpt_dir2
    state = torch.load(os.path.join(source_dir, f"{aspect}_best.pth"), map_location=device)

    # EncoderCNN() builds its ResNet trunk with pretrained weights; the strict
    # load_state_dict below replaces them with the fine-tuned checkpoint values.
    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN(
        embed_size=300,
        hidden_size=512,
        vocab_size=len(word_map),
        pretrained_embeddings=None,
    ).to(device)

    for model, key in ((encoder, "encoder"), (decoder, "decoder")):
        model.load_state_dict(state[key])
    encoder.eval()
    decoder.eval()
    return encoder, decoder

# --- Image preprocessing ---
from torchvision import transforms
# Resize to a fixed 256x256 (aspect ratio is not preserved), convert to a tensor,
# then normalise per channel with the usual ImageNet mean/std.
# NOTE(review): presumably this must match the preprocessing used during
# fine-tuning — confirm against the training notebook.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [9]:
def generate_caption(encoder, decoder, image_tensor, word_map, max_len=30, device='cuda'):
    """Generate a caption for a single image with greedy decoding.

    Args:
        encoder: image encoder. It is called on a (1, C, H, W) batch and must return
            features that can be viewed as (1, num_pixels, encoder_dim).
        decoder: decoder exposing ``init_hidden_state``, ``embedding``, ``lstm`` and
            ``fc`` (as DecoderRNN does).
        image_tensor: ONE preprocessed image WITHOUT a batch dimension, e.g.
            (3, H, W); the batch dimension is added here.
        word_map: token -> index mapping containing '<start>' and '<end>'.
        max_len: maximum number of decoding steps.
        device: device the image is moved to; should match the models' device.

    Returns:
        The caption as a space-joined string. '<start>' and '<pad>' tokens are
        dropped, decoding stops at '<end>', and unknown indices become '<unk>'.

    Side effects:
        Puts ``encoder`` and ``decoder`` into eval mode.
    """
    encoder.eval()
    decoder.eval()
    image_tensor = image_tensor.unsqueeze(0).to(device)  # (1, C, H, W)

    word_idxs = []
    # Pure inference: keep autograd off for the encoder AND the whole decoding loop.
    # Otherwise every LSTM step records a graph through the decoder's parameters.
    with torch.no_grad():
        encoder_out = encoder(image_tensor)
        encoder_out = encoder_out.view(1, -1, encoder_out.size(-1))  # (1, num_pixels, encoder_dim)

        h, c = decoder.init_hidden_state(encoder_out)
        # No attention: the context vector is the mean feature, identical every step.
        awe = encoder_out.mean(dim=1)
        embeddings = decoder.embedding(torch.tensor([word_map['<start>']], device=device))

        for _ in range(max_len):
            h, c = decoder.lstm(torch.cat([embeddings, awe], dim=1), (h, c))
            predicted = decoder.fc(h).argmax(1)  # (1,)
            idx = predicted.item()
            word_idxs.append(idx)
            if idx == word_map['<end>']:
                break
            embeddings = decoder.embedding(predicted)

    # Map indices back to tokens.
    idx2word = {v: k for k, v in word_map.items()}
    words = []
    for idx in word_idxs:
        word = idx2word.get(idx, '<unk>')
        if word == '<end>':
            break
        if word not in ('<start>', '<pad>'):
            words.append(word)
    return ' '.join(words)

In [10]:
# --- Load the fine-tuned (encoder, decoder) pair for every aspect ---
aspect_models = {aspect: load_model_for_aspect(aspect) for aspect in aspects}

In [11]:
# --------- MAIN LOOP ----------
# For every image: load it, caption it with each aspect's model, and store the
# ground-truth reference sentences alongside the generated captions.
all_samples = []
for img_info in tqdm(data_json['images']):
    # Optionally restrict to a single split here, e.g. only test/val images:
    # if img_info['split'] != 'test':
    #     continue
    filename = img_info['filename']
    img_path = os.path.join(IMAGE_DIR, filename)
    if not os.path.exists(img_path):
        # tqdm.write prints without breaking the progress bar.
        tqdm.write(f'Not found: {img_path}')
        continue

    # Load and preprocess the image; the context manager releases the file handle.
    with Image.open(img_path) as img:
        image_tensor = transform(img.convert('RGB'))

    # Caption the image with every aspect model. generate_caption adds the batch
    # dimension itself. Inference only, so autograd bookkeeping is disabled.
    with torch.no_grad():
        aspect_captions = {
            aspect: generate_caption(encoder, decoder, image_tensor, word_map, device=device)
            for aspect, (encoder, decoder) in aspect_models.items()
        }

    # Ground-truth references: every raw sentence for this image, kept as a list.
    reference = [sent['raw'] for sent in img_info['sentences']]
    all_samples.append({
        'image_id': filename,
        'captions': aspect_captions,
        'reference': reference
    })

# ---- SAVE ----
with open('dae_dataset.json', 'w', encoding='utf8') as f:
    json.dump(all_samples, f, ensure_ascii=False, indent=2)

100%|██████████| 7221/7221 [15:27<00:00,  7.79it/s]
