# Main Training Notebook

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the project files stored on Drive
# are reachable from this Colab runtime (the next cell changes into a
# directory below this mount point).
drive.mount('/content/drive')

In [None]:
import os
# Project source directory on the mounted Drive. Later cells use paths relative
# to it (requirements.txt, data/...), so it is made the working directory.
datadir = "/content/drive/MyDrive/CS_444/DL_project/clip-gpt-captioning/src"

os.chdir(datadir)
# Print the working directory to confirm the change took effect.
!pwd

In [None]:
!pip install -r requirements.txt

In [None]:
!pip install wandb

In [None]:
!conda install pytorch

# Preparing the data for training

In [None]:
import os
import pickle
import random

import numpy as np
import pandas as pd
from PIL import Image

import torch
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm

if __name__ == '__main__':
    # Precompute one CLIP image embedding per image and pair it with every
    # caption of that image. The resulting list of
    # (image name, image embedding, caption) tuples is pickled to
    # data/processed/dataset.pkl.

    # Set constants
    SEED = 100
    DATA_PATH = os.path.join('data')

    # Set random seed
    random.seed(SEED)
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load CLIP model and processor (only the vision tower is kept)
    preprocessor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
    model = CLIPModel.from_pretrained('openai/clip-vit-large-patch14').vision_model.to(device)

    # Load dataset; the header cells carry surrounding whitespace because of
    # the '|' separator, so the column names are stripped.
    df = pd.read_csv(os.path.join(DATA_PATH, 'raw', 'results.csv'), sep='|')
    df.columns = [col.strip() for col in df.columns]

    df = df.drop(['comment_number'], axis=1)

    # Collect the captions of every image in a single pass. sort=False keeps
    # the images in file order. Unlike striding over every 5th row, this does
    # not assume that each image has exactly 5 consecutive captions.
    ds = [
        (img_name, captions.values)
        for img_name, captions in df.groupby('image_name', sort=False)['comment']
    ]

    # Based on loaded dataset, create a list of (image name, image embedding, caption) tuples
    results = []
    loop = tqdm(ds, total=len(ds), position=0, leave=True)
    for img_name, cap in loop:
        # NOTE(review): the training cell reads test images from
        # 'flickr8k_images' while this cell reads 'flickr10k_images' - confirm
        # which directory name is correct.
        img_path = os.path.join(DATA_PATH, 'raw', 'flickr10k_images', img_name)

        # Only image-loading problems are treated as "skip this image"; any
        # other error (including KeyboardInterrupt) now propagates instead of
        # being silently reported as a missing image.
        try:
            # PIL opens lazily, so the preprocessing (which decodes the pixels)
            # must happen while the file is still open.
            with Image.open(img_path) as img:
                img_prep = preprocessor(images=img, return_tensors='pt').to(device)
        except FileNotFoundError:
            print(f'Lack of image {img_name}')
            continue
        except OSError as err:
            print(f'Could not read image {img_name}: {err}')
            continue

        with torch.no_grad():
            img_features = model(**img_prep).pooler_output

        # A CUDA tensor cannot be converted to numpy directly; move it to host
        # memory first (a no-op when already on the CPU).
        img_features = img_features.squeeze().cpu().numpy()

        for c in cap:
            # Malformed CSV rows yield NaN instead of a caption string.
            if not isinstance(c, str):
                continue
            # because of the separator there is a space at the beginning of the caption
            results.append((img_name, img_features, c[1:] if c[:1] == ' ' else c))

    # save data into pickle file
    # img_name, img_features, caption
    processed_dir = os.path.join(DATA_PATH, 'processed')
    os.makedirs(processed_dir, exist_ok=True)
    with open(os.path.join(processed_dir, 'dataset.pkl'), 'wb') as f:
        pickle.dump(results, f)

# Training the Model

In [None]:
import argparse
import os
import random

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import random_split

import wandb
from data import MiniFlickrDataset, get_loader
from model import Net, Trainer
from utils import ConfigS, ConfigL, LRWarmup



In [None]:
# File name of the checkpoint inside config.weights_dir that is handed to the Trainer.
checkpoint_name = 'model_train.pt'
# Model size selector: 'L' picks ConfigL, anything else picks ConfigS.
size = 'S'

# Compare against 'L' explicitly: a bare `size.upper()` is truthy for every
# non-empty string and would always select the large config.
config = ConfigL() if size.upper() == 'L' else ConfigS()

# set seed
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed(config.seed)
torch.backends.cudnn.deterministic = True

if __name__ == '__main__':
    # Build the data loaders, model, optimizer and Trainer, then run the
    # train/validate/test loop with per-epoch logging to wandb and a periodic
    # checkpoint.
    is_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if is_cuda else 'cpu')

    # Must be the file written by the data-preparation cell ('dataset.pkl').
    dataset = MiniFlickrDataset(os.path.join('data', 'processed', 'dataset.pkl'))

    # The config sizes are overwritten in place, fractions -> absolute counts,
    # so this relies on `config` having been freshly created beforehand.
    config.train_size = int(config.train_size * len(dataset))
    config.val_size = int(config.val_size * len(dataset))
    # The test split takes the remainder so the three sizes sum to len(dataset).
    config.test_size = len(dataset) - config.train_size - config.val_size

    train_dataset, val_dataset, test_dataset = random_split(dataset, [config.train_size, config.val_size, config.test_size])

    # On CPU a fixed small batch-size exponent and no worker processes are used.
    train_loader = get_loader(
        train_dataset,
        bs_exp=config.batch_size_exp if is_cuda else 2,
        shuffle=True,
        num_workers=config.num_workers if is_cuda else 0,
        pin_memory=is_cuda
    )

    valid_loader = get_loader(
        val_dataset,
        bs_exp=config.batch_size_exp if is_cuda else 2,
        shuffle=False,
        num_workers=config.num_workers if is_cuda else 0,
        pin_memory=is_cuda
    )

    model = Net(
        clip_model=config.clip_model,
        text_model=config.text_model,
        ep_len=config.ep_len,
        num_layers=config.num_layers,
        n_heads=config.n_heads,
        forward_expansion=config.forward_expansion,
        dropout=config.dropout,
        max_len=config.max_len,
        device=device
    )

    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    warmup = LRWarmup(epochs=config.epochs, max_lr=config.lr, k=config.k)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, warmup.lr_warmup)
    # Gradient scaling is only meaningful on CUDA; disabling it explicitly on
    # CPU avoids the warning emitted by an unconditionally enabled scaler.
    scaler = torch.cuda.amp.GradScaler(enabled=is_cuda)

    ckp_path = os.path.join(config.weights_dir, checkpoint_name)

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        scaler=scaler,
        scheduler=scheduler,
        train_loader=train_loader,
        valid_loader=valid_loader,
        test_dataset=test_dataset,
        # NOTE(review): the data-preparation cell reads images from
        # 'flickr10k_images' - confirm which directory name is correct.
        test_path=os.path.join('data', 'raw', 'flickr8k_images'),
        ckp_path=ckp_path,
        device=device
    )

    # Make sure the checkpoint directory exists once, before the loop.
    os.makedirs(config.weights_dir, exist_ok=True)

    # Number of epochs between periodic checkpoints.
    ckp_every = 6

    # build train model process with experiment tracking from wandb
    wandb.init(project='clipXgpt2 captioner', config=config.__dict__)
    wandb.watch(trainer.model, log='all')
    # Starts from trainer.epoch rather than 0.
    for epoch in range(trainer.epoch, config.epochs):
        trainer.train_epoch()
        trainer.valid_epoch()
        trainer.test_step()

        metadata = trainer.get_training_data()

        # log loss to wandb
        wandb.log({
            'train_loss/loss': metadata['train_loss'][-1],
            'valid_loss/loss': metadata['valid_loss'][-1],
            'lr': metadata['lr'],
            'examples': wandb.Image(metadata['examples']),
        })

        if (epoch + 1) % ckp_every == 0:
            trainer.save_ckp(os.path.join(config.weights_dir, f'epoch_{epoch + 1}.pt'))