In [1]:
# Import utility functions for data preprocessing and feature extraction and all other necessary libraries

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
import numpy as np
import re
import os
from Decoder.Decoder_Utils import *

In [2]:
# Load the Flickr30k caption file and derive the decoder's vocabulary:
# the caption DataFrame, word<->index lookup tables, and the vocabulary size.
CAPTIONS_CSV = "./data/flickr30k_images/results.csv"
df, word_to_idx, idx_to_word, vocab_size = preprocess_captions(CAPTIONS_CSV)

In [3]:
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18

# Build a ResNet18 with a 4-way output head (the shape of the self-supervised
# rotation checkpoint — presumably one class per rotation) and load those weights
# so the network can be used as an image feature extractor.
net = resnet18(num_classes=4)
# Inference mode for BatchNorm/Dropout before any features are extracted. This is a
# no-op if extract_image_features already calls eval() (NOTE(review): confirm).
net.eval()
# map_location="cpu" lets a checkpoint saved on MPS/CUDA load on any machine;
# weights_only=True restricts unpickling to tensors/primitive containers.
net.load_state_dict(
    torch.load(
        "save_model/self_supervised_rotation_model.pth",
        map_location="cpu",
        weights_only=True,
    )
)

<All keys matched successfully>

In [4]:
# Extract image features using the ResNet18 model
# The result is a Tensor (per the %store output). Later cells index it with df row
# indices, so it is expected to hold one feature row per df row, in df order — TODO confirm.
# NOTE(review): confirm extract_image_features runs under torch.no_grad() and uses
# `net` in eval mode; otherwise autograd history / BatchNorm batch statistics leak in.
image_features = extract_image_features(net, df)

In [5]:
# Persist image_features across kernel sessions with IPython's %store.
# Uncomment to refresh the stored copy after re-running feature extraction.
# %store image_features

Stored 'image_features' (Tensor)


In [None]:
# Restore image_features previously saved with `%store`, so the expensive extraction
# cell can be skipped in a new kernel.
# NOTE(review): on a top-to-bottom run this replaces the features just computed above
# with whatever copy was stored last — make sure the stored copy is current.
%store -r image_features

In [6]:
from Decoder.Decoder_DataSetup import ImageCaptionDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

In [7]:
# Hold out a fixed fraction of caption rows for validation; the fixed seed keeps
# the split reproducible across kernel restarts.
# NOTE(review): this splits individual rows. If `df` holds several captions per image
# (TODO confirm how preprocess_captions shapes it), one image can land in both splits
# and the validation loss will be optimistic — consider a group-aware split on the
# image column (e.g. sklearn GroupShuffleSplit).
VAL_FRACTION = 0.2
SPLIT_SEED = 42
df_train, df_val = train_test_split(df, random_state=SPLIT_SEED, test_size=VAL_FRACTION)

In [8]:
# Map each split back to row *positions* in `df`.
# `image_features` holds one row per df row (TODO confirm extract_image_features keeps
# df order), so positions — not index labels — select the matching features. With a
# default RangeIndex the two agree; with any other index, label-based tensor indexing
# would silently misalign or run out of bounds.
assert len(image_features) == len(df), "expected one image_features row per df row"
train_indices = df_train.index
val_indices = df_val.index
train_positions = torch.as_tensor(df.index.get_indexer(train_indices), dtype=torch.long)
val_positions = torch.as_tensor(df.index.get_indexer(val_indices), dtype=torch.long)

# Feature rows for each split, in the same row order as df_train / df_val.
img_features_train = image_features[train_positions]
img_features_val = image_features[val_positions]

# Create training and validation dataset instances
train_dataset = ImageCaptionDataset(img_features_train, df_train)
val_dataset = ImageCaptionDataset(img_features_val, df_val)

# Mini-batch size shared by both loaders.
batch_size = 128

# Shuffle only the training data; validation batches keep a deterministic order.
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


In [9]:
# Select the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

device(type='mps')

In [10]:
from Decoder.transformer import Transformer
import torch.optim as optim
# Caption-decoder hyperparameters.
# input_dim=512 presumably matches the width of the ResNet18 features fed in, and
# max_length=17 the caption length used by the dataset — TODO confirm both.
decoder_config = dict(
    word_to_idx=word_to_idx,
    input_dim=512,
    wordvec_dim=128,
    num_heads=4,
    num_layers=6,
    max_length=17,
)
# nn.Module.to returns the module itself, so construction and device move chain.
model = Transformer(**decoder_config).to(device)

In [12]:
# Import training function and learning rate scheduler

from Decoder.Decoder_Train import train_model
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Cross-entropy over vocabulary logits; positions whose target is the padding token
# contribute no loss. NOTE(review): confirm "<TAB>" is the padding token that
# preprocess_captions emits.
criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx["<TAB>"])
# Adam with an L2 weight-decay penalty on all decoder parameters.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# Halve the learning rate when the metric passed to scheduler.step() stops decreasing
# for 2 consecutive calls. `verbose` is deprecated in recent PyTorch; inspect the
# current rate via optimizer.param_groups[0]["lr"] instead.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)



In [13]:
# Train for a fixed number of epochs; train_model prints training and validation
# loss after each epoch.
# NOTE(review): `criterion` from the previous cell is not passed here — confirm that
# train_model builds an equivalent loss (ignoring the padding token) internally.
NUM_EPOCHS = 5
train_model(
    model,
    train_data_loader,
    val_data_loader,
    optimizer,
    scheduler,
    num_epochs=NUM_EPOCHS,
)

Epoch [1/5], Loss: 5.1270
Validation Loss: 4.4356
Epoch [2/5], Loss: 4.2864
Validation Loss: 4.0439
Epoch [3/5], Loss: 4.0414
Validation Loss: 3.8444
Epoch [4/5], Loss: 3.9126
Validation Loss: 3.7625
Epoch [5/5], Loss: 3.8209
Validation Loss: 3.6798
