In [None]:
import os
# Move to the project root so the relative paths used by later cells
# ('configs/...', 'data/...') resolve.
# Guarded on the presence of 'configs' so that re-running this cell does not
# climb one directory higher every time it is executed.
if not os.path.isdir('configs'):
    os.chdir('..')

In [None]:
# Print the current GPU status so that free devices can be chosen for
# CUDA_VISIBLE_DEVICES in the configuration cell.
!nvidia-smi

In [None]:
# Environment variables are set first so they are already in place when the
# project helpers (and whatever they import) are loaded.
# remember to set GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from useful_functions import *

device = 'cuda'

# NOTE: yaml.Loader can construct arbitrary Python objects -- only load trusted
# configuration files with it.
# The context manager closes the file handle once the config has been parsed.
with open('configs/Pretrain_4m.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.Loader)

In [None]:
# Data augmentations.
# Normalisation (mean, std) shared by every pipeline below.
normalize = transforms.Normalize(
    (0.48145466, 0.4578275, 0.40821073),
    (0.26862954, 0.26130258, 0.27577711),
)

# Pre-training pipeline: random resized crop + horizontal flip + RandomAugment.
pretrain_transform = transforms.Compose([
    transforms.RandomResizedCrop(
        config['image_res'],
        scale=(0.2, 1.0),
        interpolation=InterpolationMode.BICUBIC,
    ),
    transforms.RandomHorizontalFlip(),
    RandomAugment(
        2, 7, isPIL=True,
        augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'],
    ),
    transforms.ToTensor(),
    normalize,
])

# First resize the picture to 1.5x the target resolution and then crop the centre,
# so that the augmented picture captures a larger region.
# (No random horizontal flip in this pipeline.)
train_transform_zoom_corp = transforms.Compose([
    transforms.Resize(
        (int(config['image_res'] * 1.5), int(config['image_res'] * 1.5)),
        interpolation=InterpolationMode.BICUBIC,
    ),
    transforms.CenterCrop(config['image_res']),
    RandomAugment(2, 7, isPIL=True, augs=['AutoContrast', 'Identity', 'Brightness', 'Sharpness']),
    transforms.ToTensor(),
    normalize,
])

# Original train transform used by the CCLM paper: centre crop only
# (no resize, no random horizontal flip).
train_transform = transforms.Compose([
    transforms.CenterCrop(config['image_res']),
    RandomAugment(2, 7, isPIL=True, augs=['AutoContrast', 'Identity', 'Brightness', 'Sharpness']),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic resize to the square target resolution.
val_transform = transforms.Compose([
    transforms.Resize(
        (config['image_res'], config['image_res']),
        interpolation=InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    normalize,
])

In [None]:
# Report how many GPUs are visible to torch and the name of each one.
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")

for i in range(num_gpus):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

In [None]:
# Load the cached dataset and split it.
# NOTE: pickle.load can execute arbitrary code -- only open cache files that
# come from a trusted source.
data_path = '/mnt/swordfish-pool2/ccu/amith-cache.pkl'
with open(data_path, 'rb') as handle:
    dataset = pickle.load(handle)

# Drop every entry that is not a video, and every video flagged as not
# processed (per the original note: files without jpgs).
keys_to_remove = [
    key
    for key, entry in dataset.items()
    if entry['data_type'] != 'video' or entry['processed'] == False
]
for key in keys_to_remove:
    dataset.pop(key)

# Train / val / test / final-evaluation split.
# The membership checks are independent, so an entry listed under several
# split tags is placed in each matching dictionary.
train_dataset, val_dataset, test_dataset, final_eval_dataset = {}, {}, {}, {}
split_buckets = (
    ('INTERNAL_TRAIN', train_dataset),
    ('EVALUATION_LDC2023E07', final_eval_dataset),
    ('INTERNAL_VAL', val_dataset),
    ('INTERNAL_TEST', test_dataset),
)
for key, entry in dataset.items():
    for split_name, bucket in split_buckets:
        if split_name in entry['splits']:
            bucket[key] = entry

print(f'Inference batch size {16}')

print(len(dataset), len(train_dataset), len(val_dataset), len(test_dataset), len(final_eval_dataset))

In [None]:
def _clip_text(text, limit, keep):
    """Return the first `keep` characters of `text` if it is longer than `limit`, else `text`."""
    return text[:keep] if len(text) > limit else text


def _text_with_context(texts, position):
    """Surround utterance `position` with up to ten neighbouring utterances per side.

    The current utterance is cut to 256 characters; what is left of a
    400-character budget is shared equally between the preceding context
    (its tail is kept) and the following context (its head is kept).
    """
    before = ' '.join(texts[max(0, position - 10):position])
    current = texts[position][:256]
    after = ' '.join(texts[position + 1:position + 11])
    side_budget = int((400 - len(current)) / 2)
    before = before[max(0, len(before) - side_budget):]
    after = after[:side_budget]
    return before + '<Pre_Context>' + current + '<Post_Context>' + after


def construct_dataset_csv(dataset, transcribe_name, part=None, use_context=True):
    '''
    Flatten the utterances of every file in `dataset` into one DataFrame.

    One row is produced per utterance that has at least two video frames; the
    row holds the paths of its first and last frame, its text, its start/end
    time and the label returned by `is_a_changepoint`.

    dataset         : dict keyed by file id; each value provides 'processed_dir',
                      'changepoints' (items with a 'timestamp') and
                      'utterances'[transcribe_name] (a sequence of samples).
    transcribe_name : which transcription of the utterances to read.
    part            : 1 keeps the first half of the rows (and prints the split
                      index), 2 keeps the second half, None keeps everything.
    use_context     : when true, the text is the utterance wrapped in its
                      neighbours, separated by the '<Pre_Context>' and
                      '<Post_Context>' markers; otherwise it is the utterance
                      alone (cut to 500 characters when longer than 512).
    '''
    columns = {
        'file_id': [],
        'img_start': [],
        'img_end': [],
        'text': [],
        'label': [],
        'time_start': [],
        'time_end': [],
    }

    for file_id, entry in dataset.items():
        frame_dir = entry['processed_dir']
        changepoint_times = [changepoint['timestamp'] for changepoint in entry['changepoints']]
        utterances = entry['utterances'][transcribe_name]

        if use_context:
            # Context is drawn from every utterance of the file, including the
            # ones that are skipped below for lack of frames.
            context_texts = [_clip_text(utterance['text'], 512, 300) for utterance in utterances]

        for position, utterance in enumerate(utterances):
            frames = utterance['video_frames']
            if len(frames) < 2:
                continue

            columns['file_id'].append(file_id)
            columns['time_start'].append(utterance['start'])
            columns['time_end'].append(utterance['end'])
            columns['img_start'].append(os.path.join(frame_dir, frames[0][1]))
            columns['img_end'].append(os.path.join(frame_dir, frames[-1][1]))

            if use_context:
                text = _text_with_context(context_texts, position)
            else:
                text = _clip_text(utterance['text'], 512, 500)
            columns['text'].append(text)

            columns['label'].append(
                is_a_changepoint(changepoint_times, utterance['start'], utterance['end'])
            )

    df = pd.DataFrame(columns)

    midpoint = int(len(df) / 2)
    if part == 1:
        print(midpoint)
        return df.iloc[:midpoint, :]
    if part == 2:
        return df.iloc[midpoint:, :]
    return df

In [None]:
# Build the per-utterance tables.
# `val_csv` is needed by the validation dataset and the validation dataloaders
# created in later cells, so it has to be built here for the notebook to run
# top to bottom.
train_csv = construct_dataset_csv(train_dataset,'whisper',use_context = True)
val_csv = construct_dataset_csv(val_dataset,'whisper',use_context = True)
#test_csv = construct_dataset_csv(test_dataset,'whisper',use_context = True)

In [None]:
# test_parameters
train_batch_size_for_test = 32
test_seed = 123

# Validation dataset object.
# Stored under its own name so that the `val_dataset` dictionary produced by
# the train/val/test split is not overwritten; cells that read that dictionary
# can then be re-run safely.
val_torch_dataset = LDCDataset_val(val_csv, val_transform)


# train dataloader for the visual sanity check
# (uses the original CenterCrop-based `train_transform`)
test_data_loader = create_down_sample_dataloader(train_csv, test_seed, train_batch_size_for_test,
                              train_transform,1)

# zoom_corp dataloader
#test_data_loader = create_down_sample_dataloader(train_csv, test_seed, train_batch_size_for_test,
#                              train_transform_zoom_corp)

In [None]:
# Display only the first batch of the sanity-check loader (nothing is shown
# when the loader yields no batch at all).
_no_batch = object()
batch = next(iter(test_data_loader), _no_batch)
if batch is not _no_batch:
    show_a_batch(batch)

In [None]:
# Run this cell before training: it defines the hyper-parameters and builds the
# model, tokenizer, optimizer and learning-rate scheduler.
# parameters
train_batch_size = 15
eval_batch_size = 512
num_epoch = 100
custom_lr = 0.0002
# original 0.0001

# load model
my_nvlr_model = NLVRModel(config=config)
my_nvlr_model.load_pretrained('data/cclm_4m_epoch_29.th', config,  ## need to match 3m or 4m
                              load_nlvr_pretrain= False, is_eval=False) # load_nlvr_pretrain= False because current checkpoint is CCLM model

tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')
# NOTE(review): the lines that register '<Pre_Context>' / '<Post_Context>' as
# special tokens are disabled, so those markers in the text column are
# tokenized like ordinary text -- confirm this is intended.
#special_tokens = ['<Pre_Context>', '<Post_Context>']
#tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
#my_nvlr_model.text_encoder.resize_token_embeddings(len(tokenizer))

# training parameters
world_size = utils.get_world_size()

# Optimizer settings come from the config, with the learning rate overridden by `custom_lr`.
arg_opt = utils.AttrDict(config['optimizer'])
arg_opt['lr'] = custom_lr
optimizer = create_optimizer(arg_opt, my_nvlr_model)
arg_sche = utils.AttrDict(config['schedular'])
# NOTE(review): `train_dataset` is the per-file dictionary built by the split
# cell, so this counts files rather than rows of `train_csv` / batches of the
# training dataloader -- confirm this is the intended number of steps per epoch.
arg_sche['step_per_epoch'] = math.ceil(len(train_dataset)/(train_batch_size*world_size))
lr_scheduler = create_scheduler(arg_sche, optimizer)
log = []
# The model is wrapped in DataParallel only after the optimizer has been created
# from the unwrapped model; keep this order.
my_nvlr_model = nn.DataParallel(my_nvlr_model)
my_nvlr_model.to(device)

In [None]:
# Running meters for the values logged during training.
metric_logger = utils.MetricLogger(delimiter="  ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))

# Number of iterations between two progress lines.
print_freq = 6

# Per-epoch metric histories (one value appended per epoch).
train_accuracys = []
train_recalls = []
train_precisions = []
train_aucs = []
train_auc_precisions = []
train_auc_recalls = []
val_accuracys = []
val_recalls = []
val_precisions = []
val_aucs = []
val_auc_precisions = []
val_auc_recalls = []

# (checkpoint name, validation AUC) pairs of the best epochs seen so far.
top_models = []

# Output directories. They are created here if missing; `exist_ok=True` makes
# the cell safe to re-run.
model_save_dir = '/mnt/swordfish-pool2/kh3074/train_cls_head_first/trained_cls_head_model'
log_save_dir = '/mnt/swordfish-pool2/kh3074/train_cls_head_first/evaluate_results'
for output_dir in (model_save_dir, log_save_dir):
    os.makedirs(output_dir, exist_ok=True)


print('Start training!!')
for epoch in range(num_epoch):
    # ---- training -----------------------------------------------------------
    my_nvlr_model.train()
    # All parameters are trainable. The classification head is handled in its own
    # loop so that the two groups can be toggled independently.
    for param in my_nvlr_model.parameters():
        param.requires_grad = True
    for param in my_nvlr_model.module.cls_head.parameters():
        param.requires_grad = True

    train_data_loader = create_down_sample_dataloader(train_csv, epoch, train_batch_size,
                              train_transform,1,evaluation = False) # use epoch as random seed
    header = 'Train Epoch: [{}]'.format(epoch)
    for i, (image0, image1, text, targets) in enumerate(metric_logger.log_every(train_data_loader, print_freq, header)):
        # NOTE(review): nn.DataParallel scatters inputs along dim 0. `images` holds all
        # first frames followed by all second frames, so confirm that every replica
        # still receives the image pairs that belong to its slice of text/targets.
        images = torch.cat([image0, image1], dim=0)
        images, targets = images.to(device), targets.to(device)

        text_inputs = tokenizer(text, padding='longest', return_tensors="pt").to(device)

        loss = my_nvlr_model(images, text_inputs.input_ids, text_inputs.attention_mask, targets=targets, train=True)
        loss = loss.mean() # aggregate loss from different GPUs
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(loss=loss.item())

    # ---- evaluation ---------------------------------------------------------
    my_nvlr_model.eval()
    # Evaluation dataloaders use the fixed seed 2333.
    print('start eval on train dataset ---------------------------------------------')
    train_eval_data_loader = create_down_sample_dataloader(train_csv, 2333, eval_batch_size,
                              val_transform,1,evaluation = True)
    eval_train_df = eval_on_dataset(my_nvlr_model,train_eval_data_loader,device,tokenizer)
    train_recall, train_precision, train_accuracy, train_pr_auc, train_precision_auc, train_recall_auc = calculate_matrix(eval_train_df)

    train_accuracys.append(train_accuracy)
    train_recalls.append(train_recall)
    train_precisions.append(train_precision)
    train_aucs.append(train_pr_auc)
    train_auc_precisions.append(train_precision_auc)
    train_auc_recalls.append(train_recall_auc)

    print('start eval on validation dataset ---------------------------------------------')
    val_data_loader = create_down_sample_dataloader(val_csv, 2333, eval_batch_size,
                              val_transform,1,evaluation = True) # fix seed 2333
    eval_val_df = eval_on_dataset(my_nvlr_model,val_data_loader,device,tokenizer)
    val_recall, val_precision, val_accuracy, val_pr_auc, val_precision_auc, val_recall_auc = calculate_matrix(eval_val_df)

    val_accuracys.append(val_accuracy)
    val_recalls.append(val_recall)
    val_precisions.append(val_precision)
    val_aucs.append(val_pr_auc)
    val_auc_precisions.append(val_precision_auc)
    val_auc_recalls.append(val_recall_auc)

    # ---- checkpointing ------------------------------------------------------
    # Keep the checkpoints of the three epochs with the best validation AUC
    # (`val_pr_auc`). The weights are written in the very epoch that earns a slot,
    # so each file holds the weights of the epoch named in it; a checkpoint that
    # drops out of the top three is deleted from disk.
    checkpoint_name = f'model_tuned_epoch_{epoch}'
    top_models.append((checkpoint_name, val_pr_auc))
    top_models.sort(key=lambda x: x[1], reverse=True)  # stable: earlier epochs win ties
    evicted_models = top_models[3:]
    del top_models[3:]
    if any(name == checkpoint_name for name, _ in top_models):
        torch.save(my_nvlr_model.state_dict(), os.path.join(model_save_dir, checkpoint_name))
    for name, _ in evicted_models:
        stale_path = os.path.join(model_save_dir, name)
        if os.path.exists(stale_path):
            os.remove(stale_path)
    torch.cuda.empty_cache()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    log.append({k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()})

    # ---- visualisation ------------------------------------------------------
    train_result_df = pd.DataFrame({'train_accuracys':train_accuracys,
                         'train_recalls':train_recalls ,
                         'train_precisions':train_precisions ,
                         'train_aucs':train_aucs ,
                         'train_auc_precisions':train_auc_precisions ,
                         'train_auc_recalls':train_auc_recalls ,
                         'val_accuracys':val_accuracys ,
                         'val_recalls':val_recalls ,
                         'val_precisions':val_precisions ,
                         'val_aucs':val_aucs ,
                         'val_auc_precisions':val_auc_precisions ,
                         'val_auc_recalls':val_auc_recalls })

    fig, ax = plt.subplots(1,4, figsize = (15,5))
    epoch_idx = range(len(train_result_df))

    # One panel per metric: the train curve is drawn first, then the validation curve.
    panels = [('accuracy', 'accuracys'), ('recall', 'recalls'),
              ('precision', 'precisions'), ('auc', 'aucs')]
    for axis, (title, suffix) in zip(ax, panels):
        axis.plot(epoch_idx, train_result_df[f'train_{suffix}'], label='train')
        axis.plot(epoch_idx, train_result_df[f'val_{suffix}'], label='val')
        axis.set_title(title)
        axis.set_xlabel('epoch')
        axis.legend()
    plt.show()
    # A new figure is created every epoch; close it so figures do not accumulate.
    plt.close(fig)

    # The metric history is rewritten after every epoch.
    # NOTE(review): the file name mentions 'zoom_corp' although `train_transform`
    # is the transform used for training above -- confirm the intended name.
    train_result_df.to_csv(os.path.join(log_save_dir,'train_transform_zoom_corp'))