# Teacher Model Demo

In [None]:
import av
import numpy as np
import os
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
from math import floor
from transformers import GitForCausalLM, GitConfig, BertTokenizer

from generativeimage2text.torch_common import torch_load, load_state_dict
from generativeimage2text.model import get_git_model
from generativeimage2text.tsv_io import TSVFile, tsv_writer, tsv_reader
from generativeimage2text.inference import get_image_transform

# Build the GIT captioning model (15 temporal frame embeddings) and display its architecture.
param = dict(num_image_with_embedding=15)
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)
transforms = get_image_transform(param)  # image preprocessing built from the same config
model = get_git_model(tokenizer, param)
model

In [None]:
import av
import numpy as np
import os
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
from math import floor
from transformers import GitForCausalLM, GitConfig, BertTokenizer

from generativeimage2text.torch_common import torch_load, load_state_dict
from generativeimage2text.model import get_git_model
from generativeimage2text.tsv_io import TSVFile, tsv_writer, tsv_reader
from generativeimage2text.inference import get_image_transform

# Model configuration: GIT consumes `num_image_with_embedding` frames per video clip.
param = {"num_image_with_embedding": 15}
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
transforms = get_image_transform(param)

# First GPU when available, otherwise CPU.
# NOTE(review): the code below still calls `.cuda()` on its inputs, so a GPU is required in practice.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the GIT checkpoint fine-tuned on MSR-VTT; the weights live under the checkpoint's 'model' key.
model = get_git_model(tokenizer, param)
pretrained = 'model/GIT_BASE_MSRVTT.pt'
checkpoint = torch_load(pretrained)['model']
load_state_dict(model, checkpoint)
model.to(device)
model.eval()

# Set seed for reproducibility: np.random drives the random window chosen by sample_frame_indices.
np.random.seed(45)

def read_video_pyav(container, indices, transform=None):
    '''
    Decode the requested frames of a video with PyAV and preprocess them.

    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]` or 1-D int array): Frame indices to decode. An index
            listed several times yields that frame several times, so the result has
            one entry per index as long as the video actually contains the frame.
        transform (callable, optional): Applied to each decoded PIL image. Defaults
            to the module-level ``transforms`` built with ``get_image_transform``.
    Returns:
        list: ``transform(image)`` for each requested frame, in ascending frame order.
            It is shorter than ``indices`` only if decoding ends before a requested frame.
    '''
    from collections import Counter

    if transform is None:
        transform = transforms
    if len(indices) == 0:
        return []

    # Count occurrences so duplicate indices are not silently collapsed, and get
    # O(1) membership tests instead of scanning `indices` for every decoded frame.
    wanted = Counter(int(idx) for idx in indices)
    end_index = max(wanted)

    images = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        count = wanted.get(i, 0)
        if count:
            # Only convert frames we keep; to_image() on every frame is wasted work.
            image = frame.to_image()
            images.extend([image] * count)
    return [transform(image) for image in images]


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    '''
    Pick `clip_len` evenly spaced frame indices from a randomly placed window of the video.

    Args:
        clip_len (`int`): How many frame indices to return.
        frame_sample_rate (`int`): Approximate stride between consecutive indices; the
            window spans ``int(clip_len * frame_sample_rate)`` frames.
        seg_len (`int`): Number of frames in the video. The window's exclusive end is drawn
            uniformly from ``[window, seg_len)``, so ``seg_len`` must exceed the window size
            (``np.random.randint`` raises ValueError otherwise).
    Returns:
        indices (`np.ndarray` of `np.int64`): Non-decreasing indices inside the window.
    '''
    window = int(clip_len * frame_sample_rate)
    window_end = np.random.randint(window, seg_len)
    window_start = window_end - window
    evenly_spaced = np.linspace(window_start, window_end, num=clip_len)
    # linspace includes window_end itself; clamp so every index stays inside the window.
    return np.clip(evenly_spaced, window_start, window_end - 1).astype(np.int64)


# Caption one MSR-VTT clip with the fine-tuned GIT model.
# (The HF demo clip can be used instead via
#  hf_hub_download(repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset").)
file_path = "../../../USC/CSCI567/project/datasets/MSRVTT/videos/all/video0.mp4"

# The model has one temporal embedding per input frame, so sample exactly that many frames.
num_frames = model.num_image_with_embedding
frame_sample_rate = 4

# `with` guarantees the container is closed even if sampling/decoding raises.
with av.open(file_path) as container:
    seg_len = container.streams.video[0].frames
    # sample_frame_indices needs seg_len > clip_len * frame_sample_rate; fail with a clear message.
    if num_frames * frame_sample_rate >= seg_len:
        raise ValueError(
            f"{file_path} reports {seg_len} frames; need more than "
            f"{num_frames * frame_sample_rate} to sample {num_frames} frames every {frame_sample_rate}"
        )
    indices = sample_frame_indices(
        clip_len=num_frames, frame_sample_rate=frame_sample_rate, seg_len=seg_len
    )
    frames = read_video_pyav(container, indices)
# Give each frame a leading batch dimension of 1 and move it to the GPU.
frames = [frame.unsqueeze(0).cuda() for frame in frames]

print("len(frames) = ", len(frames))
print("frames[0].shape = ", frames[0].shape)

# Caption from an empty text prefix: the model input is just [CLS].
max_text_len = 50
prefix_encoding = tokenizer("",
                            padding='do_not_pad',
                            truncation=True,
                            add_special_tokens=False,
                            max_length=max_text_len)
payload = prefix_encoding['input_ids']
# Keep only the last max_text_len - 2 prefix tokens, leaving room for special tokens.
if len(payload) > max_text_len - 2:
    payload = payload[-(max_text_len - 2):]
input_ids = [tokenizer.cls_token_id] + payload

with torch.no_grad():
    result = model({
        'image': frames,
        'prefix': torch.tensor(input_ids).unsqueeze(0).cuda(),
    })
cap = tokenizer.decode(result['predictions'][0].tolist(), skip_special_tokens=True)
cap

## Number of parameters

In [None]:
# Total number of scalar elements across all of the model's parameter tensors.
pytorch_total_params = sum(map(torch.Tensor.numel, model.parameters()))
pytorch_total_params

# Logits extraction

In [None]:
import av
import numpy as np
import os
import random
from tqdm import tqdm,trange
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoModelForCausalLM
import torch

import pickle
from math import floor
from transformers import GitForCausalLM, GitConfig, BertTokenizer

from generativeimage2text.torch_common import torch_load, load_state_dict
from generativeimage2text.model import get_git_model
from generativeimage2text.tsv_io import TSVFile, tsv_writer, tsv_reader
from generativeimage2text.inference import get_image_transform

# Model configuration: GIT consumes `num_image_with_embedding` frames per video clip.
param = {"num_image_with_embedding": 15}
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
transforms = get_image_transform(param)

# First GPU when available, otherwise CPU.
# NOTE(review): the extraction code below still calls `.cuda()` directly, so a GPU is required in practice.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the GIT checkpoint fine-tuned on MSR-VTT; the weights live under the checkpoint's 'model' key.
model = get_git_model(tokenizer, param)
pretrained = 'model/GIT_BASE_MSRVTT.pt'
checkpoint = torch_load(pretrained)['model']
load_state_dict(model, checkpoint)
model.to(device)
model.eval()

num_frames = model.num_image_with_embedding

# Set seed for reproducibility: np.random drives the random window chosen by sample_frame_indices.
np.random.seed(45)

def read_video_pyav(container, indices, transform=None):
    '''
    Decode the requested frames of a video with PyAV and preprocess them.

    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]` or 1-D int array): Frame indices to decode. An index
            listed several times yields that frame several times, so the result has
            one entry per index as long as the video actually contains the frame.
        transform (callable, optional): Applied to each decoded PIL image. Defaults
            to the module-level ``transforms`` built with ``get_image_transform``.
    Returns:
        list: ``transform(image)`` for each requested frame, in ascending frame order.
            It is shorter than ``indices`` only if decoding ends before a requested frame.
    '''
    from collections import Counter

    if transform is None:
        transform = transforms
    if len(indices) == 0:
        return []

    # Count occurrences so duplicate indices are not silently collapsed, and get
    # O(1) membership tests instead of scanning `indices` for every decoded frame.
    wanted = Counter(int(idx) for idx in indices)
    end_index = max(wanted)

    images = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        count = wanted.get(i, 0)
        if count:
            # Only convert frames we keep; to_image() on every frame is wasted work.
            image = frame.to_image()
            images.extend([image] * count)
    return [transform(image) for image in images]
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    '''
    Pick `clip_len` evenly spaced frame indices from a randomly placed window of the video.

    Args:
        clip_len (`int`): How many frame indices to return.
        frame_sample_rate (`int`): Approximate stride between consecutive indices; the
            window spans ``int(clip_len * frame_sample_rate)`` frames.
        seg_len (`int`): Number of frames in the video. The window's exclusive end is drawn
            uniformly from ``[window, seg_len)``, so ``seg_len`` must exceed the window size
            (``np.random.randint`` raises ValueError otherwise).
    Returns:
        indices (`np.ndarray` of `np.int64`): Non-decreasing indices inside the window.
    '''
    window = int(clip_len * frame_sample_rate)
    window_end = np.random.randint(window, seg_len)
    window_start = window_end - window
    evenly_spaced = np.linspace(window_start, window_end, num=clip_len)
    # linspace includes window_end itself; clamp so every index stays inside the window.
    return np.clip(evenly_spaced, window_start, window_end - 1).astype(np.int64)


# Extract GIT teacher logits for a slice of MSR-VTT videos, pickling one result file per video.
file_path = "../../../USC/CSCI567/project/datasets/MSRVTT/videos/all/"
# file_path = "dataset/MSRVTT/data/MSRVTT/videos/all"

# Empty text prefix: the model input is just [CLS].
max_text_len = 50
prefix_encoding = tokenizer("",
                            padding='do_not_pad',
                            truncation=True,
                            add_special_tokens=False,
                            max_length=max_text_len)
payload = prefix_encoding['input_ids']
if len(payload) > max_text_len - 2:
    payload = payload[-(max_text_len - 2):]
input_ids = [tokenizer.cls_token_id] + payload
# Loop-invariant: build the prefix tensor once rather than once per video.
prefix = torch.tensor(input_ids).unsqueeze(0).cuda()

batch_size = 1
frame_sample_rate = 4
# With batch_size == 1 every video is staged in slot 0 of the buffers below.
batch_count = 0
# One staging buffer per temporal position: frames_batch_list[t][b] holds frame t of video b.
frames_batch_list = [
    torch.empty((batch_size, 3, 224, 224)).cuda()
    for _ in range(param['num_image_with_embedding'])
]

filename_list = sorted(os.listdir(file_path))
video_name_list = []

save_path = "/../mnt/e/datasets/MSRVTT/video_logits_batch"
os.makedirs(save_path, exist_ok=True)

for filename in tqdm(filename_list[3000:4000]):
    video_path = os.path.join(file_path, filename)
    if not os.path.isfile(video_path) or not filename.endswith(('.mp4', '.avi', '.mkv')):
        continue
    # `with` closes the container on every path, including the "too short" skip.
    with av.open(video_path) as container:
        seg_len = container.streams.video[0].frames
        # sample_frame_indices needs more than clip_len * frame_sample_rate frames.
        if num_frames * frame_sample_rate >= seg_len:
            continue
        indices = sample_frame_indices(
            clip_len=num_frames, frame_sample_rate=frame_sample_rate, seg_len=seg_len
        )
        frames = read_video_pyav(container, indices)
    # The stream's frame count is container metadata; if decoding yields fewer frames
    # than the model needs, skip the video instead of crashing on frames[t] below.
    if len(frames) < param['num_image_with_embedding']:
        print('skipping', filename, ': decoded only', len(frames), 'frames')
        continue
    video_name_list.append(filename)
    for t in range(param['num_image_with_embedding']):
        frames_batch_list[t][batch_count] = frames[t].cuda()

    with torch.no_grad():
        result = model.get_video_logits({
            'image': frames_batch_list,
            'prefix': prefix,
        })

    with open(f"{save_path}/{filename}.pkl", "wb") as f:
        pickle.dump(result, f)

# Feature extraction

In [None]:
import av
import numpy as np
import os
import random
from tqdm import tqdm
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoProcessor, AutoModelForCausalLM
import torch

import pickle
from math import floor
from transformers import GitForCausalLM, GitConfig, BertTokenizer

from generativeimage2text.torch_common import torch_load, load_state_dict
from generativeimage2text.model import get_git_model
from generativeimage2text.tsv_io import TSVFile, tsv_writer, tsv_reader
from generativeimage2text.inference import get_image_transform

# Model configuration: GIT consumes `num_image_with_embedding` frames per video clip.
param = {"num_image_with_embedding": 15}
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
transforms = get_image_transform(param)

# First GPU when available, otherwise CPU.
# NOTE(review): the extraction code below still calls `.cuda()` directly, so a GPU is required in practice.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the GIT checkpoint fine-tuned on MSR-VTT; the weights live under the checkpoint's 'model' key.
model = get_git_model(tokenizer, param)
pretrained = 'model/GIT_BASE_MSRVTT.pt'
checkpoint = torch_load(pretrained)['model']
load_state_dict(model, checkpoint)
model.to(device)
model.eval()

num_frames = model.num_image_with_embedding

# Set seed for reproducibility: np.random drives the random window chosen by sample_frame_indices.
np.random.seed(45)

def read_video_pyav(container, indices, transform=None):
    '''
    Decode the requested frames of a video with PyAV and preprocess them.

    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]` or 1-D int array): Frame indices to decode. An index
            listed several times yields that frame several times, so the result has
            one entry per index as long as the video actually contains the frame.
        transform (callable, optional): Applied to each decoded PIL image. Defaults
            to the module-level ``transforms`` built with ``get_image_transform``.
    Returns:
        list: ``transform(image)`` for each requested frame, in ascending frame order.
            It is shorter than ``indices`` only if decoding ends before a requested frame.
    '''
    from collections import Counter

    if transform is None:
        transform = transforms
    if len(indices) == 0:
        return []

    # Count occurrences so duplicate indices are not silently collapsed, and get
    # O(1) membership tests instead of scanning `indices` for every decoded frame.
    wanted = Counter(int(idx) for idx in indices)
    end_index = max(wanted)

    images = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        count = wanted.get(i, 0)
        if count:
            # Only convert frames we keep; to_image() on every frame is wasted work.
            image = frame.to_image()
            images.extend([image] * count)
    return [transform(image) for image in images]
def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    '''
    Pick `clip_len` evenly spaced frame indices from a randomly placed window of the video.

    Args:
        clip_len (`int`): How many frame indices to return.
        frame_sample_rate (`int`): Approximate stride between consecutive indices; the
            window spans ``int(clip_len * frame_sample_rate)`` frames.
        seg_len (`int`): Number of frames in the video. The window's exclusive end is drawn
            uniformly from ``[window, seg_len)``, so ``seg_len`` must exceed the window size
            (``np.random.randint`` raises ValueError otherwise).
    Returns:
        indices (`np.ndarray` of `np.int64`): Non-decreasing indices inside the window.
    '''
    window = int(clip_len * frame_sample_rate)
    window_end = np.random.randint(window, seg_len)
    window_start = window_end - window
    evenly_spaced = np.linspace(window_start, window_end, num=clip_len)
    # linspace includes window_end itself; clamp so every index stays inside the window.
    return np.clip(evenly_spaced, window_start, window_end - 1).astype(np.int64)


# Extract GIT image-encoder features for MSR-VTT, pickling `pca_video_size` files per chunk.
# Row k of each saved feature tensor corresponds to entry k of the same chunk's saved name list;
# rows past the end of the name list are zero.
file_path = "../../../USC/CSCI567/project/datasets/MSRVTT/videos/all/"
# file_path = "dataset/MSRVTT/data/MSRVTT/videos/all"

# Empty text prefix: the model input is just [CLS].
max_text_len = 50
prefix_encoding = tokenizer("",
                            padding='do_not_pad',
                            truncation=True,
                            add_special_tokens=False,
                            max_length=max_text_len)
payload = prefix_encoding['input_ids']
if len(payload) > max_text_len - 2:
    payload = payload[-(max_text_len - 2):]
input_ids = [tokenizer.cls_token_id] + payload
# Loop-invariant: build the prefix tensor once.
prefix = torch.tensor(input_ids).unsqueeze(0).cuda()

batch_size = 20
pca_video_size = 500
frame_sample_rate = 4
# 197 tokens per frame x 768 dims is hard-coded from the original; presumably the encoder's
# per-frame token count (patches + CLS) -- TODO confirm against model.get_image_feature.
feature_shape = (197 * param['num_image_with_embedding'], 768)

video_feature_batch = torch.zeros((pca_video_size, *feature_shape)).cuda()
# One staging buffer per temporal position: frames_batch_list[t][b] holds frame t of video b.
frames_batch_list = [
    torch.empty((batch_size, 3, 224, 224)).cuda()
    for _ in range(param['num_image_with_embedding'])
]

# Sort so the file -> chunk assignment is reproducible (os.listdir order is arbitrary).
filename_list = sorted(os.listdir(file_path))

save_path = "/../mnt/e/datasets/MSRVTT/video_feature_batch"
os.makedirs(save_path, exist_ok=True)


def _store_feature_batch(n_videos, first_row):
    """Encode the staged frames and write the first `n_videos` results into rows first_row.. of video_feature_batch."""
    # Always run the full staging buffer (the batch size the model was run with originally);
    # for a final partial batch the trailing slots hold stale frames whose features are discarded.
    with torch.no_grad():
        features_batch = model.get_image_feature({
            'image': frames_batch_list,
            'prefix': prefix,
        })
    video_feature_batch[first_row:first_row + n_videos] = features_batch[:n_videos]


for cur_idx in range(int(10000 / pca_video_size)):
    chunk_files = filename_list[cur_idx * pca_video_size:(cur_idx + 1) * pca_video_size]
    # Start every chunk clean so no rows carry over features from the previous chunk, and
    # keep names per chunk so they line up with this chunk's feature rows.
    video_feature_batch.zero_()
    video_name_list = []
    batch_count = 0
    for filename in tqdm(chunk_files):
        video_path = os.path.join(file_path, filename)
        if not os.path.isfile(video_path) or not filename.endswith(('.mp4', '.avi', '.mkv')):
            continue
        # `with` closes the container on every path, including the skips.
        with av.open(video_path) as container:
            seg_len = container.streams.video[0].frames
            if num_frames * frame_sample_rate >= seg_len:
                # Too short to sample: skip without reserving a feature row.
                print('num-frame: ', seg_len)
                continue
            indices = sample_frame_indices(
                clip_len=num_frames, frame_sample_rate=frame_sample_rate, seg_len=seg_len
            )
            frames = read_video_pyav(container, indices)
        # The stream's frame count is container metadata; skip if decoding came up short.
        if len(frames) < param['num_image_with_embedding']:
            print('decoded frames: ', filename, len(frames))
            continue
        for t in range(param['num_image_with_embedding']):
            frames_batch_list[t][batch_count] = frames[t].cuda()
        video_name_list.append(filename)
        batch_count += 1
        if batch_count == batch_size:
            _store_feature_batch(batch_count, len(video_name_list) - batch_count)
            batch_count = 0
    # Encode the last, partially filled batch.
    if batch_count:
        _store_feature_batch(batch_count, len(video_name_list) - batch_count)

    print(f"chunk {cur_idx}: {len(video_name_list)} videos")
    with open(f"{save_path}/video_feature_batch_{cur_idx}.pkl", "wb") as f:
        pickle.dump(video_feature_batch.to("cpu"), f)
    with open(f"{save_path}/video_name_batch_{cur_idx}.pkl", "wb") as f:
        pickle.dump(video_name_list, f)