In [1]:
import torch
from transformers import AlignProcessor, AlignModel
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn.functional as F
from datasets import DatasetDict, Dataset
from collections import defaultdict 
from transformers import TrainingArguments
from transformers import Trainer
from transformers import DataCollatorWithPadding
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb
import math
from datetime import datetime

from src.datasets.meme_text_dataloader import get_meme_text_dataloader
from src.utilities import *
from src.models.align_base import align_base


# The dataset and the model

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Target image size (pixels); also used by the preprocessing function further down.
image_height = image_width = 289
# Seed for the shuffle that precedes the train/validation split.
seed = 42
# Build the MemeCap loader for the chosen image size.
# NOTE(review): the original author remarks that loading images into a Dataset turns the
# pixels into list elements, which is inefficient — confirm against the loader implementation.
meme_loader = get_meme_text_dataloader('memecap', (image_height, image_width))

# Explicit (re)loading of memes and texts is left disabled; the author notes 873 steps in total.
# meme_loader.load_datasets(splits=['trainval'], from_idx=0, to_idx=100) # test or trainval
# Display the number of trainval records, their container type, and the first record.
(
    len(meme_loader.trainval_text_data),
    type(meme_loader.trainval_text_data),
    meme_loader.trainval_text_data[0],
)

(5823,
 list,
 {'category': 'memes',
  'img_captions': ['Person in Spider Man outfit gives a lecture on stage.',
   'Person dressed as spider man stands in front of crowd with notes'],
  'meme_captions': ['Meme poster is frustrated about the format of the website and is making a suggestion for improvement.'],
  'title': 'For real though',
  'url': 'https://i.redd.it/m16dhaqyply21.jpg',
  'img_fname': 'memes_bpet7l.png',
  'metaphors': [{'metaphor': 'Spider Man outfit', 'meaning': 'Meme poster'},
   {'metaphor': 'a lecture', 'meaning': 'complaint'},
   {'metaphor': 'spider man', 'meaning': 'Meme poster'},
   {'metaphor': 'crowd', 'meaning': 'meme readers'}],
  'post_id': 'bpet7l'})

In [3]:
import random
# Seed the global `random` module so the shuffle below is reproducible on a fresh kernel.
random.seed(seed)
# NOTE: the loader's list is mutated in place. Re-running this cell without reloading the
# data shuffles an already-shuffled list, so the order (and the split derived from it)
# will differ from a clean Restart & Run All.
random.shuffle(meme_loader.trainval_text_data) # in-place shuffle
# Display size, container type and the first record after shuffling.
len(meme_loader.trainval_text_data), type(meme_loader.trainval_text_data), meme_loader.trainval_text_data[0]

(5823,
 list,
 {'category': 'memes',
  'img_captions': ['The little kid looks confused and is ready to ask a snarky question'],
  'meme_captions': ["Meme poster trying to figure out how much sleep he'll get if they continue to watch tv."],
  'title': 'All my homies are nocturnal',
  'url': 'https://i.redd.it/hrlf4s40cst51.jpg',
  'img_fname': 'memes_jdabfs.png',
  'metaphors': [{'metaphor': 'The little kid', 'meaning': 'Meme poster'},
   {'metaphor': 'looks', 'meaning': 'calculating time'}],
  'post_id': 'jdabfs'})

In [4]:
# 80/20 split of the shuffled records: the first 80% train, the remainder validates.
split_point = int(0.8 * len(meme_loader.trainval_text_data))
train_text, val_text = (
    meme_loader.trainval_text_data[:split_point],
    meme_loader.trainval_text_data[split_point:],
)
# Display the size of each split.
len(train_text), len(val_text)

(4658, 1165)

In [5]:
# align_model = align_base()

In [6]:
# The processor converts the captions into padded token ids.
# Image preprocessing: resize and crop.
import os
# Processors already loaded in this kernel, keyed by checkpoint name.
_ALIGN_PROCESSOR_CACHE = {}


def _get_align_processor(checkpoint="kakaobrain/align-base"):
    """Return the AlignProcessor for `checkpoint`, loading it at most once per kernel."""
    if checkpoint not in _ALIGN_PROCESSOR_CACHE:
        _ALIGN_PROCESSOR_CACHE[checkpoint] = AlignProcessor.from_pretrained(checkpoint)
    return _ALIGN_PROCESSOR_CACHE[checkpoint]


def pre_porcess(text_batch):
    '''
    Load the meme images for a batch of annotation records and run them,
    together with their captions, through the ALIGN processor.

    Args:
        text_batch (list[dict]): annotation records. Each record must have an
            'img_fname' entry (file name inside 'data/memes'); the first entry
            of 'meme_captions' is used as the caption. A missing or empty
            'meme_captions' falls back to an empty caption.

    Returns:
        The processor output for the captions and images, with PyTorch tensors.
        Records whose image cannot be opened are reported and skipped, so the
        output may contain fewer rows than `text_batch`.
    '''
    # Reuse one processor instead of re-loading it for every batch.
    processor = _get_align_processor()
    directory = 'data/memes'
    images = []
    captions = []
    for item in text_batch:
        img_path = os.path.join(directory, item['img_fname'])
        # Resolve the caption before touching the image lists so that the two
        # lists can never get out of step. Only the first meme caption is used.
        caption = (item.get('meme_captions') or [""])[0]

        try:
            # NOTE(review): `Image` and `resize_and_crop_image` are not imported
            # explicitly in this notebook; presumably they arrive through
            # `from src.utilities import *` — confirm.
            with Image.open(img_path) as img:
                resized = resize_and_crop_image(img, image_width, image_height)
                img_array = np.array(resized) # dtype=numpy.uint8
        except IOError:
            print(f"Error opening image {img_path}")
            continue

        if len(img_array.shape) == 2: # handle gray images with shape of (802, 640)
            img_array = np.stack([img_array, img_array, img_array], axis=-1)
        elif img_array.shape[2] == 4: # handle images with shape of (802, 640, 4)
            img_array = img_array[:, :, :-1]
        images.append(img_array)
        captions.append(caption)
    # `truncation` is kept even though a warning reports it as unused.
    # NOTE(review): the warning appears to come from the image side; the
    # tokenizer may still rely on it — confirm before removing.
    return  processor(text=captions, 
                                    images=images, 
                                    return_tensors="pt",
                                    truncation=True )


In [7]:
# Smoke-test the preprocessing on the first three training records and display the result.
pre_porcess(train_text[:3])

Unused or unrecognized kwargs: truncation.


{'input_ids': tensor([[  101,  2033,  4168, 13082,  2667,  2000,  3275,  2041,  2129,  2172,
          3637,  2002,  1005,  2222,  2131,  2065,  2027,  3613,  2000,  3422,
          2694,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  2033,  4168, 13082,  2003,  2667,  2000, 16636,  1996,  9164,
          1997,  2111,  2040,  2113,  2037, 12010,  3570,  1998,  2216,  2040,
          2024, 12010,  1998,  2024,  2025,  5204,  1012,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,


In [8]:
from src.utilities import recall_at_k
from src.evaluation import similarity_align

def evaluation(model, dataset=val_text):
    """
    Compute retrieval recall@k for `model` in both directions on `dataset`.

    The whole dataset is preprocessed in one go, a similarity matrix is
    computed with `similarity_align`, and `recall_at_k` is applied to the
    matrix and to its transpose.

    Args:
        model: the model to evaluate; it is switched to eval mode and left there.
        dataset (list): annotation records; defaults to the validation split
            that existed when this function was defined.

    Returns:
        The image-to-text metrics (prefix 'i2t_') updated with the
        text-to-image metrics (prefix 't2i_').
    """
    model.eval()
    with torch.no_grad():
        print('Number of samples:', len(dataset))
        batch = pre_porcess(dataset)
        # Similarity between every text and every meme.
        sim_matrix = similarity_align(batch, model, device)
        # Text-to-image recall first, then image-to-text on the transposed matrix.
        text_to_image = recall_at_k(sim_matrix, prefix='t2i_')
        image_to_text = recall_at_k(sim_matrix.T, prefix='i2t_')
    # Fold both directions into a single result.
    image_to_text.update(text_to_image)
    return image_to_text

In [9]:
# evaluation(align_model)

In [10]:
def save_checkpoint(model, optimizer, epoch, save_dir, filename="checkpoint.pth"):
    """
    Saves a checkpoint of the model, optimizer, and training parameters.

    The checkpoint is a dict with the keys 'epoch', 'model_state_dict' and
    'optimizer_state_dict', written with `torch.save`.

    Args:
        model (torch.nn.Module): The model to save.
        optimizer (torch.optim.Optimizer): The optimizer used for training.
        epoch (int): The current epoch.
        save_dir (str): The directory to save the checkpoint in; created
            (including parents) when it does not exist yet.
        filename (str, optional): The filename of the checkpoint. Default: 'checkpoint.pth'.
    """

    # exist_ok avoids the check-then-create race and is a no-op for an existing directory.
    os.makedirs(save_dir, exist_ok=True)

    checkpoint_path = os.path.join(save_dir, filename)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")

def load_checkpoint(model, optimizer, checkpoint_path, map_location="cpu"):
    """
    Loads a checkpoint and resumes training from the saved state.

    Args:
        model (torch.nn.Module): The model to load the checkpoint into.
        optimizer (torch.optim.Optimizer): The optimizer to load the checkpoint into.
        checkpoint_path (str): The path to the checkpoint file.
        map_location (optional): Forwarded to `torch.load`. The default 'cpu'
            lets a checkpoint written on a GPU machine be opened on a CPU-only
            one; the state is then loaded into `model` and `optimizer` via
            their `load_state_dict` methods.

    Returns:
        tuple: `(model, optimizer, epoch)` where `epoch` is the saved epoch + 1,
        i.e. the epoch to continue training from.

    Raises:
        FileNotFoundError: If `checkpoint_path` does not exist.
    """

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file; only open checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch'] + 1  # Add 1 to epoch for next iteration

    print(f"Checkpoint loaded from {checkpoint_path}. Resuming training from epoch {epoch}.")
    return model, optimizer, epoch

In [11]:

def train(model, train_data, validation_data, epochs, 
          optimizer, batch_size, log_step=10, evaluation_step =50, 
          saving_model_step=50, out_dir='./output'):
  """
  Trains `model` on `train_data` with periodic W&B logging, evaluation and checkpointing.

  Args:
      model (torch.nn.Module): The model to train; its forward output must expose 'loss'.
      train_data (list): Training annotation records; sliced into batches and
          passed through `pre_porcess`.
      validation_data (list): Validation annotation records, passed to `evaluation`.
      epochs (int): Number of training epochs.
      optimizer (torch.optim.Optimizer): Optimizer already bound to the model parameters.
      batch_size (int): Number of records per training batch.
      log_step (int): Log the mean training loss every `log_step` steps.
      evaluation_step (int): Run `evaluation` every `evaluation_step` steps.
      saving_model_step (int): Save a checkpoint every `saving_model_step` steps.
      out_dir (str): Directory the checkpoints are written to.

  Uses the notebook-level `device`. Starts a W&B run when none is active and
  finishes the active run before returning.
  """
  train_num_batches = math.ceil(len(train_data) / batch_size)
  val_num_batches = math.ceil(len(validation_data) / batch_size)
  total_steps = epochs * train_num_batches
  print(f'train_num_batches: {train_num_batches}, val_num_batches: {val_num_batches}, total_steps: {total_steps}')

  # Cosine schedule stepped once per epoch.
  # NOTE(review): with T_max=epochs the learning rate stays constant inside an epoch
  # (and for the whole run when epochs == 1); switch to a per-step schedule if a
  # within-epoch decay is intended.
  scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
  # Initialize Wandb (optional)
  if wandb.run is None:
    wandb.init(project="meme-text")  # Replace with your project name

  step = 0
  running_loss = 0.0
  steps_since_log = 0  # number of losses accumulated in running_loss
  for epoch in range(epochs):
    for start in range(0, len(train_data), batch_size):
      model.train()  # evaluation() leaves the model in eval mode
      # Slicing clamps at the end of the list, so the last batch may be shorter.
      batch = pre_porcess(train_data[start:start + batch_size])
      batch = batch.to(device)
      optimizer.zero_grad()
      output = model(**batch)
      loss = output['loss']
      loss.backward()
      optimizer.step()
      running_loss += loss.item()
      steps_since_log += 1

      # Log training metrics to Wandb (optional)
      if step%log_step == (log_step - 1) and wandb.run is not None:
        # Average over the steps accumulated since the last log, not the global step.
        mean_loss = running_loss / steps_since_log
        print(f'step: {step}/{total_steps}; training loss: {mean_loss}')
        wandb.log({"train_loss": mean_loss, 'learning_rate': optimizer.param_groups[0]["lr"]})
        running_loss = 0.0
        steps_since_log = 0
      # Validation step (optional)
      if step%evaluation_step == (evaluation_step - 1) and wandb.run is not None:
        # Evaluate on the data the caller passed in rather than evaluation()'s default.
        metrics = evaluation(model, validation_data)
        print(f'step: {step}/{total_steps}; metrics: {metrics}')
        wandb.log(metrics)
      if step%saving_model_step == (saving_model_step - 1):
        # save the model
        now = datetime.now()
        dt_string = now.strftime("%d_%m_%H-%M-%S") # 27_12_10-09-20
        save_checkpoint(model, optimizer, epoch, out_dir, f'{dt_string}_align_{step}.pth')
      step += 1
    # Update scheduler after each epoch
    scheduler.step()

  # Finish Wandb run (optional)
  if wandb.run is not None:
    wandb.finish()


# Single training loop

In [12]:
# # Login to your Wandb account (optional)
# wandb.login()
# wandb.init(project="meme-text") 
# learning_rate = 1e-5
# optimizer = Adam(align_model.model.parameters(), lr=learning_rate)
# train(align_model.model, train_text, val_text, epochs=3, 
#       optimizer=optimizer, batch_size=8,
#       log_step=10, evaluation_step=50, saving_model_step=218)


# Sweep


In [13]:
# Bayesian hyper-parameter search over the number of epochs and the learning rate.
sweep_config = {
                'method': 'bayes',
                # The metric name must match a key that the training loop actually logs;
                # `train` logs the running loss as "train_loss".
                'metric': {'goal':'minimize', 'name':'train_loss'},
                'parameters': {
                    'epochs': {'values': [1, 2]},
                    'learning_rate': {'distribution': 'uniform',
                                        'max': 4e-5,
                                        'min': 1e-6},
                    # NOTE(review): the `sweep` function in this notebook always builds
                    # Adam and does not read this value.
                    'optimizer': {'values': ['adam']}
                }
 }

# Register the sweep with W&B; the returned id is what the agent is started with.
sweep_id = wandb.sweep(sweep_config, project="meme-text")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: 57xm3plb
Sweep URL: https://wandb.ai/adl_shilingdeng/meme-text/sweeps/57xm3plb


In [14]:
def sweep(config=None):
    """
    Run a single sweep trial.

    Opens a W&B run, builds a fresh ALIGN model on the notebook-level `device`,
    trains it on `train_text` / `val_text` with Adam using the learning rate and
    epoch count from `wandb.config`, then releases cached CUDA memory.

    Args:
        config (optional): configuration forwarded to `wandb.init`.
    """
    with wandb.init(project='meme-text', entity='serkar', config=config):
        hparams = wandb.config

        # Fresh model for this trial, moved to the target device.
        aligner = align_base()
        aligner.model.to(device)

        # Adam over all model parameters with the sampled learning rate.
        adam = torch.optim.Adam(
            aligner.model.parameters(),
            lr=hparams.learning_rate,
        )

        wandb.watch(aligner.model, log="all")

        # Train with fixed batch size and logging / evaluation / checkpoint cadence.
        train(
            aligner.model,
            train_text,
            val_text,
            epochs=hparams.epochs,
            optimizer=adam,
            batch_size=6,
            log_step=10,
            evaluation_step=50,
            saving_model_step=218,
        )
        # Hand cached GPU blocks back once the trial is done.
        torch.cuda.empty_cache()

    print('Finished Training')


In [15]:
# Start a W&B agent in this kernel that calls `sweep` for the sweep registered as `sweep_id` (count=10).
wandb.agent(sweep_id, function=sweep, count=10)

[34m[1mwandb[0m: Agent Starting Run: a5ep8inu with config:
[34m[1mwandb[0m: 	epochs: 1
[34m[1mwandb[0m: 	learning_rate: 3.645353107667553e-05
[34m[1mwandb[0m: 	optimizer: adam
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshilingdeng7187[0m. Use [1m`wandb login --relogin`[0m to force relogin


train_num_batches: 777, val_num_batches: 195, total_steps: 777


Unused or unrecognized kwargs: truncation.
Unused or unrecognized kwargs: truncation.
Unused or unrecognized kwargs: truncation.


# Load a checkpoint

In [None]:
# new_model = align_base()
# optimizer = Adam(new_model.model.parameters(), lr=0.1)
# load_checkpoint(new_model.model, optimizer, './output/align_1525_27_05_00-21-48.pth')
# # val dataset
# evaluation(new_model.model)
# # test dataset
# evaluation(new_model.model, meme_loader.test_text_data)



Checkpoint loaded from ./output/align_1525_27_05_00-21-48.pth. Resuming training from epoch 2.
Number of samples: 559


Unused or unrecognized kwargs: truncation.


torch.Size([559, 640])
torch.Size([559, 640])


{'i2t_r1': 0.49016100178890876,
 'i2t_r5': 0.6869409660107334,
 'i2t_r10': 0.738819320214669,
 'i2t_r_mean': 0.6386404293381037,
 't2i_r1': 0.5116279069767442,
 't2i_r5': 0.6797853309481217,
 't2i_r10': 0.7584973166368515,
 't2i_r_mean': 0.6499701848539058}

# Tensorboard
tensorboard --logdir=check_points/runs

In [None]:
import os
# Ask the operating system, through the shell, to shut the machine down.
# WARNING: running this cell affects the whole host, not just the kernel.
# NOTE(review): the `-t  +30` argument form is unusual and its meaning depends on the
# platform's `shutdown` implementation — confirm it really schedules a 30-minute delay.
os.system("shutdown -t  +30 ") # in minutes