In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive so the
# dataset archive, project modules and checkpoints stored there can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install OpenAI CLIP straight from GitHub together with transformers and
# bitsandbytes; -q keeps pip output short. Versions are not pinned.
!pip install -q git+https://github.com/openai/CLIP.git transformers bitsandbytes

[K     |████████████████████████████████| 5.8 MB 33.7 MB/s 
[K     |████████████████████████████████| 62.5 MB 1.2 MB/s 
[K     |████████████████████████████████| 53 kB 1.9 MB/s 
[K     |████████████████████████████████| 7.6 MB 50.0 MB/s 
[K     |████████████████████████████████| 182 kB 77.8 MB/s 
[?25h  Building wheel for clip (setup.py) ... [?25l[?25hdone


In [None]:
# Copy the dataset directory from Drive to local storage, unpack the training
# videos archive into the current directory and move the extracted
# `videos_train` folder under /content/data.
!cp -r /content/drive/MyDrive/Olimpiads/nto_hack_2022/V2/data /content/
!unzip -q /content/drive/MyDrive/Olimpiads/nto_hack_2022/V2/videos_train.zip 
!mv videos_train /content/data

# Copy the project modules (config, utils, model) to /content so the notebook
# can import them. `-i` asks for confirmation before overwriting a file.
!cp -i /content/drive/MyDrive/Olimpiads/nto_hack_2022/V2/config.py /content/
!cp -i /content/drive/MyDrive/Olimpiads/nto_hack_2022/V2/utils.py /content/
!cp -i /content/drive/MyDrive/Olimpiads/nto_hack_2022/V2/model.py /content/

In [None]:
import bitsandbytes as bnb
import gc
import io
import os
import random
import numpy as np
from tqdm import tqdm
import pandas as pd
import cv2
import sys

import torchvision
import transformers
import torch
from torch.nn import functional as nnf
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.cuda.amp import autocast

from utils import *
from model import *
from config import CFG

from transformers.optimization import Adafactor, AdafactorSchedule
from torch.utils.checkpoint import checkpoint_sequential
import warnings
warnings.simplefilter('ignore')

import time

In [None]:
# Display the batch size configured in config.py.
CFG.batch_size

24

In [None]:
def get_caption(prefix, model, device, tokenizer, prompt=''):
    """Generate one caption for every CLIP embedding in ``prefix``.

    Args:
        prefix: batch of CLIP embeddings; the first dimension is the batch.
        model: captioning model exposing ``clip_project``; it is passed on to
            ``generate2`` together with the projected prefix.
        device: device the embeddings are moved to before projection.
        tokenizer: tokenizer forwarded to ``generate2``.
        prompt: optional text prompt. When non-empty it is forwarded to
            ``generate2`` and ``len(prompt)`` leading characters are cut from
            every generated string.

    Returns:
        list[str]: one stripped caption per row of ``prefix`` (an empty list
        for an empty batch).
    """
    prefix = prefix.to(device)

    # An empty batch cannot be reshaped with an inferred (-1) dimension.
    if len(prefix) == 0:
        return []

    answers = []
    with torch.no_grad():
        prefix_embed = model.clip_project(prefix).reshape(len(prefix), CFG.prefix_length, -1)

        for sample_embed in prefix_embed:
            # Keep the embedding on the device it was computed on: the model
            # that consumes it lives there too, so moving it to the CPU would
            # only create a device mismatch.
            cur_prefix_embed = sample_embed.unsqueeze(0)

            if prompt:
                generated_text_prefix = generate2(model, tokenizer, prompt=prompt, embed=cur_prefix_embed)
            else:
                generated_text_prefix = generate2(model, tokenizer, embed=cur_prefix_embed)

            cleaned = generated_text_prefix.replace('\n', ' ')
            # Strip the complete end-of-text token first (the old pattern missed
            # the closing '>' and left it in the caption), then the truncated
            # form that was handled before.
            cleaned = cleaned.replace('<|endoftext|>', '').replace('<|endoftext|', '')
            answers.append(cleaned)

    return [x[len(prompt):].strip() for x in answers]

# def get_ans(model, clip_emb, prompt, device, tokenizer):
#     output = get_caption(clip_emb, model, device, tokenizer, prompt=prompt)
#     return output

In [None]:
def train(train_loader, model, optimizer, scheduler, device, epoch):
    """Run one pass over ``train_loader`` and update ``model`` in place.

    Every batch yields ``(tokens, mask, prefix)``. The model is called as
    ``model(tokens, prefix, mask)`` and optimised with token-level
    cross-entropy; ``optimizer`` and ``scheduler`` are stepped once per batch.

    Args:
        train_loader: iterable of ``(tokens, mask, prefix)`` batches.
        model: captioning model; moved to ``device`` and put in train mode.
        optimizer: optimizer stepped after every backward pass.
        scheduler: learning-rate scheduler stepped after every optimizer step.
        device: target device for the model and the batch tensors.
        epoch: current epoch number (accepted for the caller, not used here).

    Returns:
        The model, located on ``device``.
    """
    model = model.to(device)
    model.train()

    running_loss = AverageMeter()
    bar = tqdm(total=len(train_loader))

    for tokens, mask, prefix in train_loader:
        model.zero_grad()

        tokens = tokens.to(device)
        mask = mask.to(device)
        prefix = prefix.to(device, dtype=torch.float32)

        outputs = model(tokens, prefix, mask)
        # Drop the prefix positions and the last step so the remaining logits
        # line up with the caption tokens (presumably next-token targets).
        caption_logits = outputs.logits[:, CFG.prefix_length-1: -1]
        vocab_size = caption_logits.shape[-1]

        # Targets equal to 0 are ignored by the loss (presumably padding).
        loss = nnf.cross_entropy(
            caption_logits.reshape(-1, vocab_size),
            tokens.flatten(),
            ignore_index=0,
        )

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        running_loss.update(loss.item(), len(mask))
        bar.set_description(f"loss: {running_loss.avg:.5f}")
        bar.update()

        # Release cached memory after every step.
        torch.clear_autocast_cache()
        torch.cuda.empty_cache()

    bar.close()
    return model


def valid(model, valid_loader, device, gt):
    """Caption every batch of ``valid_loader`` and score the result with BLEU.

    Args:
        model: captioning model; switched to eval mode.
        valid_loader: iterable of ``(tokens, mask, prefix)`` batches; only
            ``prefix`` is used for generation.
        device: device the prefixes are moved to.
        gt: ground-truth captions, in the same order as the loader's samples.

    Returns:
        The value returned by ``bleu_metric(gt, predictions)``.
    """
    model.eval()
    tokenizer = GPT2Tokenizer.from_pretrained(CFG.backbone)

    all_answers = []
    for tokens, mask, prefix in tqdm(valid_loader, total=len(valid_loader)):
        prefix = prefix.to(device, dtype=torch.float32)
        # extend() keeps the captions as plain Python ``str`` objects.
        # np.concatenate would turn them into numpy.str_, which fails an exact
        # ``type(x) == str`` check, and it raises on an empty list.
        all_answers.extend(get_caption(prefix, model, device, tokenizer, prompt='Caption: '))

    return bleu_metric(gt, all_answers)
    

import nltk
def bleu_metric(ground_truth, prediction):
    """Mean sentence-level BLEU (uni- and bigram, equal weights) times 100.

    The two sequences are paired with ``zip``. A pair contributes 0 when
    either side is not a string (for example a NaN read from a CSV), instead
    of reusing the previous pair's score or failing on an unbound variable.

    Args:
        ground_truth: iterable of reference captions.
        prediction: iterable of generated captions; any ``<|endoftext|>``
            marker is removed before scoring.

    Returns:
        The average score scaled to 0..100, or ``0.0`` when there are no pairs.
    """
    scores = []
    for gt, pred in zip(ground_truth, prediction):
        # isinstance also accepts str subclasses such as numpy.str_.
        if isinstance(gt, str) and isinstance(pred, str):
            score = nltk.translate.bleu_score.sentence_bleu(
                [gt.lower().split()],
                pred.lower().replace('<|endoftext|>', '').split(),
                weights=(0.5, 0.5),
            )
        else:
            score = 0.0
        scores.append(score)

    if not scores:
        return 0.0
    return np.array(scores).mean() * 100

In [None]:
def main():
    """Fine-tune the captioning model and validate it after every epoch.

    Reads the validation captions and the pre-computed feature files named in
    ``CFG``, restores the ``v2_1.pt`` checkpoint from Drive, trains for
    ``CFG.epochs`` epochs with AdamW and a linear warm-up schedule, prints the
    validation BLEU after each epoch and saves the weights to ``CFG.out_dir``
    every ``CFG.save_every`` epochs.
    """
    valid_df = pd.read_csv(CFG.valid_df_path)
    train_ds = ClipCocoDataset(CFG.train_features_path, CFG.prefix_length)
    valid_ds = ClipCocoDataset(CFG.valid_features_path, CFG.prefix_length)

    train_loader = DataLoader(train_ds, batch_size=CFG.batch_size, shuffle=True, drop_last=True)
    # Validate on the validation features. The loader used to wrap `train_ds`,
    # so predictions for training samples were scored against `valid_df`.
    # NOTE(review): assumes `valid_ds` is ordered like the rows of `valid_df`.
    valid_loader = DataLoader(valid_ds, batch_size=CFG.batch_size, shuffle=False)

    model = ClipCaptionModel(prefix_length = CFG.prefix_length, backbone = CFG.backbone)
    device = torch.device('cuda') # xm.xla_device()

    os.makedirs(CFG.out_dir, exist_ok=True)

    # model.load_state_dict(torch.load('/content/drive/MyDrive/Olimpiads/nto_hack_2022/V2/coco_flickr-pretrained.pt', map_location='cpu')) 

    # Resume from the previously trained checkpoint.
    model.load_state_dict(torch.load('/content/drive/MyDrive/Olimpiads/nto_hack_2022/V2/v2_1.pt', map_location='cpu'))
    model = model.to(device)

    #model = freeze(model)

    model.train()
    optimizer = AdamW(model.parameters(),lr=CFG.learning_rate, betas=(0.9, 0.995))
    #optimizer = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995))
    #optimizer = SM3(model.parameters(),lr=args.lr)
    #Adafactor(model.parameters(),scale_parameter=True, relative_step=True, warmup_init=True, lr=None)

    # The schedule spans every optimizer step of the whole run.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=CFG.warmup_steps, num_training_steps=CFG.epochs * len(train_loader))
    #AdafactorSchedule(optimizer)#num_training_steps=epochs * len(train_loader)

    for epoch in range(1, 1+CFG.epochs):
        train(train_loader, model, optimizer, scheduler, device, epoch)
        score = valid(model, valid_loader, device, valid_df.caption.tolist())
        # Report the metric; it used to be computed and then thrown away.
        print(f"epoch {epoch}: validation BLEU = {score:.5f}")

        if epoch % CFG.save_every==0:
            torch.save(model.state_dict(),os.path.join(CFG.out_dir, f"{CFG.model_name}.pt"))
    
# Start training; this runs every epoch configured in CFG.
main()

### Validation

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm, trange
import os
import pickle
import sys
import argparse
import json
from typing import Tuple, Optional, Union
#from torch.cuda.amp import autocast
import io
import os
import PIL
import random
import numpy as np
import torch
import torchvision
import transformers
import more_itertools
import numpy as np
import matplotlib.pyplot as plt
#from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
#from tqdm import tqdm
from dataclasses import dataclass, field
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import cv2
from PIL import Image
import clip

import transformers

from utils import *
from model import *
import re

import warnings
warnings.simplefilter('ignore')

import time

In [None]:
def get_caption(prefix, prompt=''):
    """Generate a caption for a single CLIP embedding.

    Relies on the notebook-level ``model``, ``tokenizer``, ``device`` and
    ``prefix_length`` defined in this cell.

    Args:
        prefix: CLIP embedding of one sample; it is projected and reshaped to
            ``(1, prefix_length, -1)``.
        prompt: optional text prompt forwarded to ``generate2`` when non-empty.

    Returns:
        str: the generated text with newlines replaced by spaces and the
        end-of-text marker removed. The prompt is NOT cut off here.
    """
    prefix = prefix.to(device)
    with torch.no_grad():
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)

        if prompt:
            generated_text_prefix = generate2(model, tokenizer, prompt=prompt, embed=prefix_embed)
        else:
            generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)

    cleaned = generated_text_prefix.replace('\n', ' ')
    # Remove the complete end-of-text token first (the previous pattern lacked
    # the closing '>' and left it behind), then the truncated form as before.
    return cleaned.replace('<|endoftext|>', '').replace('<|endoftext|', '')

def get_ans(clip_emb, prompt):
    """Caption ``clip_emb`` with ``prompt`` and return it without the prompt.

    Returns:
        dict: ``{'answer': caption}``, where the caption is the generated text
        with its first ``len(prompt)`` characters dropped and surrounding
        whitespace stripped.
    """
    generated = get_caption(clip_emb, prompt=prompt)
    return {'answer': generated[len(prompt):].strip()}


#from tqdm import tqdm, trange



# parser = argparse.ArgumentParser()
# parser.add_argument('--input_path', default='./input_test.csv', type=str, help='input path')
# parser.add_argument('--video_path', default='./videos_val/', type=str, help='input path')
# parser.add_argument('--output_path', default='./output/', type=str, help='config path')
# args = parser.parse_args()

# Paths and hyper-parameters for the validation run.
config = dict(
        model_path = '/content/drive/MyDrive/Olimpiads/nto_hack_2022/V2/v2_1.pt',
        video_path = '/content/data/videos_train/',
        val_path = '/content/data/new_valid.csv',
        gpt = 'gpt2',
        prefix_len = 35
    )


prefix_length = config['prefix_len']#40

device = 'cuda'
# CLIP image encoder and its matching preprocessing transform.
clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
clip_model.to(device)

tokenizer = GPT2Tokenizer.from_pretrained(config['gpt'])


# NOTE: despite its name this holds the GPT backbone identifier, not the
# checkpoint path (that is config['model_path']).
model_path = config['gpt']
model = ClipCaptionModel(prefix_length = prefix_length, backbone = config['gpt'])

model.load_state_dict(torch.load(config['model_path'], map_location='cpu'))
model.to(device)
# Inference only: switch to eval mode so train-time layers such as dropout are
# disabled while generating captions.
model.eval()

out_path = 'Features_val.pkl'


val_embeddings = []
val_captions = []
# Paths whose embedding was computed successfully, aligned one-to-one with
# `val_embeddings`. Videos that fail are skipped, so later cells can use this
# list to match predictions to the right ground-truth rows.
val_paths = []

input_test = pd.read_csv(config['val_path'])

# Only the first MAX_VIDEOS rows are processed. This is the same subset the
# old `c > 50` counter produced, which stopped after 51 videos.
MAX_VIDEOS = 51

for p in tqdm(input_test.paths.iloc[:MAX_VIDEOS]):
    text = 'Caption:'
    path = f'{config["video_path"]}{p}'
    try:
        # One frame per video is read and encoded with CLIP.
        video = read_video(path, transform=None,frames_num=1)

        i = image_grid(video,1,1)
        image = preprocess(i).unsqueeze(0).to(device)

        with torch.no_grad():
            prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
    except Exception as e:
        # Best effort: report which video failed and continue with the rest.
        print(f'Skipping {path}: {e}')
        continue

    val_embeddings.append(prefix)
    val_captions.append(text)
    val_paths.append(p)

# Caption every stored embedding with the fixed prompt and save the results.
answers = []
for emb in tqdm(val_embeddings):
    answers.append(get_ans(emb, 'Caption: ')['answer'])

df = pd.DataFrame({'captions': answers})
df.to_csv('/content/data/answer.csv', index=False)

  4%|▍         | 50/1131 [00:33<12:12,  1.48it/s]
  0%|          | 0/51 [00:00<?, ?it/s]

torch.Size([1, 768])

  2%|▏         | 1/51 [00:00<00:15,  3.23it/s]

0.3034696578979492


torch.Size([1, 768])

  4%|▍         | 2/51 [00:00<00:09,  5.35it/s]

0.09210205078125


torch.Size([1, 768])

  6%|▌         | 3/51 [00:00<00:11,  4.03it/s]

0.31274938583374023


torch.Size([1, 768])

0.08498620986938477


torch.Size([1, 768])

 10%|▉         | 5/51 [00:01<00:09,  4.92it/s]

0.23816752433776855


torch.Size([1, 768])

 12%|█▏        | 6/51 [00:01<00:10,  4.39it/s]

0.2779276371002197


torch.Size([1, 768])

 14%|█▎        | 7/51 [00:01<00:08,  4.96it/s]

0.1304011344909668


torch.Size([1, 768])

 16%|█▌        | 8/51 [00:01<00:11,  3.85it/s]

0.39212965965270996


torch.Size([1, 768])

 18%|█▊        | 9/51 [00:02<00:09,  4.57it/s]

0.11242198944091797


torch.Size([1, 768])

 20%|█▉        | 10/51 [00:02<00:07,  5.43it/s]

0.09522724151611328


torch.Size([1, 768])

 22%|██▏       | 11/51 [00:02<00:10,  3.85it/s]

0.4285433292388916


torch.Size([1, 768])

 24%|██▎       | 12/51 [00:03<00:15,  2.51it/s]

0.7149248123168945


torch.Size([1, 768])

 25%|██▌       | 13/51 [00:03<00:13,  2.79it/s]

0.2564268112182617


torch.Size([1, 768])

 27%|██▋       | 14/51 [00:03<00:12,  2.99it/s]

0.2703378200531006


torch.Size([1, 768])

 29%|██▉       | 15/51 [00:04<00:15,  2.32it/s]

0.6493005752563477


torch.Size([1, 768])

0.07866024971008301


torch.Size([1, 768])

 33%|███▎      | 17/51 [00:04<00:09,  3.45it/s]

0.15224575996398926


torch.Size([1, 768])

 35%|███▌      | 18/51 [00:05<00:13,  2.43it/s]

0.7770967483520508


torch.Size([1, 768])

 37%|███▋      | 19/51 [00:05<00:10,  2.93it/s]

0.12859082221984863


torch.Size([1, 768])

 39%|███▉      | 20/51 [00:06<00:11,  2.72it/s]

0.43067240715026855


torch.Size([1, 768])

 41%|████      | 21/51 [00:06<00:11,  2.72it/s]

0.3566253185272217


torch.Size([1, 768])

 43%|████▎     | 22/51 [00:06<00:08,  3.22it/s]

0.1589794158935547


torch.Size([1, 768])

 45%|████▌     | 23/51 [00:06<00:07,  3.54it/s]

0.20769381523132324


torch.Size([1, 768])

 47%|████▋     | 24/51 [00:07<00:06,  3.97it/s]

0.17076396942138672


torch.Size([1, 768])

 49%|████▉     | 25/51 [00:07<00:07,  3.70it/s]

0.306041955947876


torch.Size([1, 768])

0.07049560546875


torch.Size([1, 768])

 53%|█████▎    | 27/51 [00:07<00:04,  5.24it/s]

0.10656166076660156


torch.Size([1, 768])

 55%|█████▍    | 28/51 [00:07<00:05,  4.10it/s]

0.40012168884277344


torch.Size([1, 768])

 57%|█████▋    | 29/51 [00:08<00:04,  4.68it/s]

0.12095379829406738


torch.Size([1, 768])

 59%|█████▉    | 30/51 [00:08<00:05,  4.01it/s]

0.33652377128601074


torch.Size([1, 768])

 61%|██████    | 31/51 [00:08<00:04,  4.02it/s]

0.23685932159423828


torch.Size([1, 768])

 63%|██████▎   | 32/51 [00:08<00:05,  3.70it/s]

0.3125741481781006


torch.Size([1, 768])

 65%|██████▍   | 33/51 [00:09<00:05,  3.38it/s]

0.3499138355255127


torch.Size([1, 768])

 67%|██████▋   | 34/51 [00:09<00:04,  3.84it/s]

0.1642293930053711


torch.Size([1, 768])

0.0750422477722168


torch.Size([1, 768])

 71%|███████   | 36/51 [00:09<00:02,  5.35it/s]

0.10628890991210938


torch.Size([1, 768])

 73%|███████▎  | 37/51 [00:10<00:03,  4.26it/s]

0.37145566940307617


torch.Size([1, 768])

 75%|███████▍  | 38/51 [00:10<00:03,  3.68it/s]

0.36944079399108887


torch.Size([1, 768])

 76%|███████▋  | 39/51 [00:10<00:02,  4.08it/s]

0.16354870796203613


torch.Size([1, 768])

 78%|███████▊  | 40/51 [00:10<00:02,  4.43it/s]

0.1637716293334961


torch.Size([1, 768])

 80%|████████  | 41/51 [00:11<00:02,  4.63it/s]

0.1833963394165039


torch.Size([1, 768])

 82%|████████▏ | 42/51 [00:11<00:02,  3.89it/s]

0.35005927085876465


torch.Size([1, 768])

 84%|████████▍ | 43/51 [00:11<00:02,  3.97it/s]

0.2335970401763916


torch.Size([1, 768])

 86%|████████▋ | 44/51 [00:11<00:01,  4.50it/s]

0.14165067672729492


torch.Size([1, 768])

 88%|████████▊ | 45/51 [00:11<00:01,  5.17it/s]

0.11670351028442383


torch.Size([1, 768])

 90%|█████████ | 46/51 [00:12<00:00,  5.73it/s]

0.12346720695495605


torch.Size([1, 768])

 92%|█████████▏| 47/51 [00:12<00:00,  6.34it/s]

0.1095726490020752


torch.Size([1, 768])

0.07436227798461914


torch.Size([1, 768])

 96%|█████████▌| 49/51 [00:12<00:00,  5.16it/s]

0.3842051029205322


torch.Size([1, 768])

 98%|█████████▊| 50/51 [00:12<00:00,  5.68it/s]

0.11292529106140137


torch.Size([1, 768])

100%|██████████| 51/51 [00:12<00:00,  3.98it/s]

0.08947086334228516





In [None]:
# Keep only Latin letters, digits, spaces, dots, commas and hyphens in every
# predicted caption and write the cleaned captions to a separate CSV.
path_pred = '/content/data/answer.csv'
df_pred = pd.read_csv(path_pred)

# Missing captions come back from read_csv as float NaN. A type check is
# reliable here, whereas the identity test `x is not np.nan` lets a NaN that is
# not the numpy singleton through and then fails when iterating over a float.
new_preds = [
    re.sub(r'[^A-Za-z0-9 \.\,\-]', '', x) if isinstance(x, str) else ''
    for x in df_pred['captions']
]

df_pred['captions'] = new_preds
df_pred.to_csv('/content/data/new_answer.csv', index=False)

In [None]:
import pandas as pd
import nltk
import numpy as np

# Score the saved predictions against the ground-truth captions with
# sentence-level BLEU (uni- and bigram, equal weights), averaged and scaled
# to 0..100. Rows are paired by position.
path_gt = '/content/data/new_valid.csv'
path_pred = '/content/data/answer.csv'
df_eval = pd.read_csv(path_gt)
df_pred = pd.read_csv(path_pred)

scores = []
# df_eval holds the references and df_pred the model output. The loop
# variables used to be swapped, so the prediction served as the reference and
# the ground truth as the hypothesis.
for gt, pred in zip(df_eval.caption, df_pred.captions):
    if isinstance(gt, str) and isinstance(pred, str):
        score = nltk.translate.bleu_score.sentence_bleu(
            [gt.lower().split()],
            pred.lower().replace('<|endoftext|>','').split(),
            weights = (0.5, 0.5),
        )
    else:
        # A missing caption (NaN) scores 0 instead of silently reusing the
        # previous row's score.
        score = 0.0
    scores.append(score)

ans = np.mean(scores) * 100 if scores else 0.0
print(round(ans, 5))

7.11994


In [None]:
# Print the first 50 ground-truth / predicted caption pairs for inspection.
gt_captions = df_eval.caption.tolist()
pred_captions = df_pred.captions.tolist()
for idx in range(50):
    print(gt_captions[idx])
    print(pred_captions[idx])
    print()

Beautiful blue and yellow flowers in an arrangement close-up, pan shot.
Bouquet of little pink flowers of different colors and sizes

Dirty car bumper
Cars driving in a highway

Girl running on the street with masks, front view and close to the shot, with the panorama of the street out of focus in the background, with cars, trees, the sidewalk and the sun lighting up.
A woman in a black mask runs through the park with her hands on her feet, surrounded by trees and buildings.

Woman with a hat watching the ocean
A calm sea and mountains

Married couple receiving at the door of their house another neighbor couple who arrived with welcome gifts, while smiling happily.
Happy businessman gives a thumbs up as he hugs his girlfriend during a break on a date.

View under a bridge that crosses a river in a big city, where some people go sailing by boat underneath.
A group of people rowing on a river surrounded by trees and buildings, during a sunny day.

Guangzhou illuminated cityscape with clo