In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.optim.lr_scheduler import ReduceLROnPlateau

from loaders import get_train_loader, get_val_test_loader, get_length_vocab, get_pad_index, get_vocab, show_image

from train_forcing import train, validate, val_visualize_captions

from forcing_model import EncoderDecoder


import pandas as pd
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_dir = '../data/Images/'
captions_file = '../data/captions.txt'

# Single seed for the data split and for torch's RNG (used by RandomHorizontalFlip
# and by model weight initialisation in later cells).
# NOTE(review): if train_forcing uses Python's `random` for teacher forcing, seed it too.
SEED = 42
torch.manual_seed(SEED)

input_size = (224,224)

# Standard ImageNet per-channel statistics, shared by the train and val/test pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, random horizontal flip (augmentation), tensor conversion, normalisation.
transform_train = transforms.Compose([
    transforms.Resize(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation/test pipeline: same as training but without augmentation.
transform_val = transforms.Compose([
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Split by unique image (not by caption row) so every caption of an image lands in
# the same split. Proportions of unique images: 64% train / 16% val / 20% test.
df_captions = pd.read_csv(captions_file)
unique_images = df_captions['image'].unique()
train_images, test_images = train_test_split(unique_images, test_size=0.2, random_state=SEED)
train_images, val_images = train_test_split(train_images, test_size=0.2, random_state=SEED)

# Sanity check: no image may appear in more than one split.
assert set(train_images).isdisjoint(val_images)
assert set(train_images).isdisjoint(test_images)
assert set(val_images).isdisjoint(test_images)

train_df = df_captions[df_captions['image'].isin(train_images)]
val_df = df_captions[df_captions['image'].isin(val_images)]
test_df = df_captions[df_captions['image'].isin(test_images)]

# Padding index, derived from the training split; used as ignore_index for the loss.
pad_index = get_pad_index(data_dir=data_dir, dataframe=train_df, transform=transform_train)

# NOTE(review): a separate vocabulary is built per split, but only vocab_train_df is
# used by the cells in this notebook. If the loaders numericalise captions with a
# vocabulary built from their own dataframe, val/test token ids would not match the
# training vocabulary -- confirm in loaders.py.
vocab_train_df = get_vocab(data_dir=data_dir, dataframe=train_df, transform=transform_train)
vocab_val_df = get_vocab(data_dir=data_dir, dataframe=val_df, transform=transform_val)
vocab_test_df = get_vocab(data_dir=data_dir, dataframe=test_df, transform=transform_val)

# Create train, validation, and test data loaders.
train_dataloader = get_train_loader(data_dir=data_dir, dataframe=train_df, transform=transform_train, batch_size=32, num_workers=1)
val_dataloader = get_val_test_loader(data_dir=data_dir, dataframe=val_df, transform=transform_val, batch_size=16, num_workers=1)
test_dataloader = get_val_test_loader(data_dir=data_dir, dataframe=test_df, transform=transform_val, batch_size=16, num_workers=1)

# Number of validation batches (presumably; assumes a torch DataLoader is returned).
print(len(val_dataloader))

405


In [3]:
# Hyperparameters consumed by the model-construction and training cells below.
embed_size = 312                        # passed to EncoderDecoder as embed_size
hidden_size = 512                       # passed to EncoderDecoder as hidden_size
num_layers = 2                          # passed to EncoderDecoder as num_layers
learning_rate = 0.003                   # Adam learning rate
teacher_forcing_prob = 0.5              # forwarded to train() and validate()
vocab_size_train = len(vocab_train_df)  # size of the training-split vocabulary

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [4]:
# Initialise the model, loss and optimizer.
# Set load_model = True to start from previously saved weights instead of a fresh model.
load_model = False
# NOTE(review): a .py extension is unusual for a saved state_dict; confirm the real file name (e.g. .pt/.pth).
weights_path = 'model_weights.py'
torch.backends.cudnn.benchmark = True  # let cuDNN autotune kernels; inputs are resized to a fixed input_size

model = EncoderDecoder(embed_size, hidden_size, vocab_size_train, num_layers).to(device)

if load_model:
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine (and vice versa).
    # Only load weight files you trust: torch.load unpickles its input.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()  # NOTE(review): if training continues below, confirm train() switches back to train mode

# Loss and optimizer are created unconditionally so the training cell works whether
# or not weights were loaded (previously they only existed when load_model was False).
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)  # padded target positions are excluded from the loss
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # plain Adam; no weight_decay (L2) is applied
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)



In [5]:
# Train for num_epochs epochs. After each epoch, record the mean training loss and the
# validation loss, then re-plot both curves so progress is visible while training runs.
num_epochs = 50
losses = {"train": [], "val": []}

for epoch in range(num_epochs):
    print(f'Starting epoch {epoch + 1}...')

    # train() returns a sequence of loss values for the epoch (presumably one per batch);
    # it is averaged here.
    train_loss = train(epoch, criterion, model, optimizer, train_dataloader, vocab_size_train, device, teacher_forcing_prob)
    # float() accepts both Python numbers and 0-d tensors, so the history holds plain
    # floats (plottable, detached from any device) whatever train()/validate() return.
    t_loss = float(sum(train_loss) / len(train_loss))
    print(f'Train set: Average loss: {t_loss}')
    val_loss = validate(criterion, model, val_dataloader, vocab_size_train, vocab_train_df, device, teacher_forcing_prob)

    losses["train"].append(t_loss)
    losses["val"].append(float(val_loss))

    # Epochs on the x-axis are 1-based to match the "Starting epoch N" messages.
    epochs_done = range(1, len(losses["train"]) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs_done, losses["train"], label="training")
    ax.plot(epochs_done, losses["val"], label="validation")
    ax.set(title=f'Loss after epoch {epoch + 1}', xlabel='Epochs', ylabel='Loss')
    ax.legend()
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch


Starting epoch 1...
Train Epoch: 1; Loss: 8.09651


KeyboardInterrupt: 