In [1]:
import os
# Getting all the models off my home directory for space concerns
os.environ['TRANSFORMERS_CACHE'] = '/projectnb/sparkgrp/colejh'

# Imports
import pandas as pd
import multiprocessing as mp
from tqdm import tqdm_notebook as tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import random
import matplotlib.pyplot as plt
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from torch.utils.data import DataLoader
import torch.nn as nn
import glob
from transformers import AdamW
import matplotlib.pyplot as plt
import wandb
import torch.optim as optim

#suppressing all the huggingface warnings
SUPPRESS = True
if SUPPRESS:
    from transformers.utils import logging
    logging.set_verbosity(40)

#others
#ignoring UserWarning and FutureWarning
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [2]:
def catch_invalid(img):
    try:
        a = Image.open(img)
        a.close()
    except Exception as e:
        print(e)
        return False
    return True

# Check for valid images

def catch_invalid2(img):
    try:
        a = Image.open(img)
        a.close()
    except Exception as e:
        print(e)
        return False,img
    return True
try:
    #reading in saved cvit file
    save_dir = '/projectnb/sparkgrp/colejh/saved_results/'
    df = pd.read_pickle(save_dir+'cvit.pkl')
except:
    directory = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/tesseract-training/training/CVIT/Images_90K_Normalized/'
    count = 0
    all_img = []
    all_labels = []
    no_tag = []
    # list_dir = os.listdir(directory)

    for (dirpath, dirnames, filenames) in os.walk(directory):
        for file in filenames:
            if file.endswith('.png'):
                image = os.path.join(dirpath, file)
    #             display(Image.open(image))
    #             if catch_invalid(image):
                try:
                    with open(image.split('.')[0]+'.gt.txt') as f:
                        contents = f.readlines()
                    all_img.append(image)
                    all_labels.append(contents[0])
                except FileNotFoundError as f:
    #                             print(f)
                        no_tag.append(image)
                        break # some folders have no ground truth

            count+=1
            if count%10000 == 0:
                print(count)


    # create dataframe from all_img and all_labels
    df = pd.DataFrame({'image':all_img, 'label':all_labels})

    save_dir = '/projectnb/sparkgrp/colejh/saved_results/'
    # save dataframe to pickle at specified directory
    df.to_pickle(save_dir+'cvit.pkl')


    pool = mp.Pool(mp.cpu_count())

    bad = []
    for output in tqdm(pool.imap(catch_invalid2, df['image']), total=len(df['image']),desc = 'Checking for Valid Images'):
        if output == False:
    #         bad.append(output[1])
            print(output)

In [3]:
# Create a CVIT dataset class  
class CVITDataset(Dataset):
    def __init__(self, df, processor, max_target_length=128):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text 
        image = self.df['image'][idx]
        text = self.df['label'][idx]
        # prepare image (i.e. resize + normalize)
        image = Image.open(image).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text, 
                                          padding="max_length", 
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding
        

In [4]:
from accelerate import Accelerator
model_directory = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/TROCR_Training/'
def create_CVIT(df,processor):
    # Training and validaiton splits
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    train_df.reset_index(drop=True, inplace=True)
    val_df.reset_index(drop=True, inplace=True)
    
    train_dataset = CVITDataset(df=train_df,processor=processor)
    val_dataset = CVITDataset(df=val_df,processor=processor)

    return train_dataset,val_dataset



In [18]:
from datasets import load_metric

cer_metric = load_metric("cer")
def compute_cer(pred_ids, label_ids,processor):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
#     print(pred_str,label_str)
    return cer


from accelerate import Accelerator
import wandb
from accelerate import notebook_launcher
from accelerate.utils import set_seed
from torch.optim.lr_scheduler import OneCycleLR
from accelerate import DistributedDataParallelKwargs
def training_loop(df,mixed_precision="fp16", seed: int = 42, batch_size: int = 32,num_processes=4):
    set_seed(seed)

    num_epochs = 3
    # Initialize accelerator
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(mixed_precision=mixed_precision,kwargs_handlers=[ddp_kwargs],log_with="wandb")
    accelerator.init_trackers("CVIT_TRAIN")
    # Build dataloaders
    # Instantiate the model (you build the model here so that the seed also controls new weight initaliziations)
    device = accelerator.device
    model_name = 'microsoft/trocr-base-printed'
    processor = TrOCRProcessor.from_pretrained(model_name) 
    train_dataset,val_dataset = create_CVIT(df,processor)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.vocab_size = model.config.decoder.vocab_size
    model.config.eos_token_id = processor.tokenizer.sep_token_id
    model.config.max_length = 64
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_dataloader = DataLoader(val_dataset, batch_size=batch_size)

    # Intantiate the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-5)
    # Instantiate the learning rate scheduler
    lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=1e-5, epochs=num_epochs, steps_per_epoch=len(train_dataloader))

    # Preparing everything for multi-gpu support
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )
    
    # Getting the progress bar set up 
    num_training_steps = num_epochs * len(train_dataloader)
    progress_bar = tqdm(range(num_training_steps))
    patience = 0
    best_val_loss = float('inf')
    last_val_loss = float('inf')
    running_tloss = []
    running_vloss = []
    running_cer = []
    # Now you train the model
    for epoch in range(num_epochs):
        # Track metrics
        train_loss = 0.0
        running_loss = 0
        model.train()
        for i,batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            train_loss += loss.item()
            
            # Print out the loss 100 times per epoch just to confirm training is going well 
            if accelerator.is_main_process:
                try:
                    running_loss += loss.sum().item()
                    if i % round(len(train_dataloader)/100) == round(len(train_dataloader)/100)-1:   
                        accelerator.print('[Epoch: %d, Batches Processed: %d] loss: %.3f' %
                              (epoch + 1, i + 1, running_loss /round(len(train_dataloader)/100)))
                        running_loss = 0.0
                except ZeroDivisionError: # if the trainloader  is small enough this isnt neccesary 
                    pass
                    
        train_loss = train_loss/len(train_dataloader)      
        running_tloss.append(train_loss)

        # Evaluation 
        model.eval()
        val_loss = 0
        valid_cer = 0
        
        for batch in eval_dataloader:
            with torch.no_grad():

                outputs = model(**batch)
                loss = outputs.loss
                val_loss += loss.item()
                a =accelerator.unwrap_model(model).generate(batch["pixel_values"])
                cer = compute_cer(pred_ids=a, label_ids=batch["labels"],processor = processor)
                valid_cer += cer 
        val_loss = val_loss/len(eval_dataloader)
        valid_cer = valid_cer/len(eval_dataloader)
        running_vloss.append(val_loss)
        running_cer.append(valid_cer/len(eval_dataloader))
        accelerator.log({"train_loss": train_loss, "valid_loss": val_loss,"CER": valid_cer}, step=epoch)
        # Print metrics
#         accelerator.print(f"Epoch {epoch} validation loss: {val_loss}")
        # Save model if the validation loss is the best we've seen so far.
        accelerator.print(f"Epoch {epoch} train loss: {train_loss} CER: {valid_cer:.2f} Val Loss: {val_loss}")
        
        if val_loss < best_val_loss:
            patience = 0
            best_val_loss = val_loss
#             accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
#             if accelerator.is_main_process:
#                 unwrapped_model.save_pretrained(os.path.join(model_directory,"best_model"))
            accelerator.print("Saving model checkpoint to ",os.path.join(model_directory,model_name.split('/')[1]+'_best_model_CVITpoint1.pt'))
            accelerator.wait_for_everyone()
            torch.save(unwrapped_model.state_dict(), model_directory + model_name.split('/')[1]+'_best_model_CVITpoint1.pt')
            
#         elif val_loss >= last_val_loss:
#             patience += 1
#             if patience == 3:
                
#                 accelerator.print("Stopping training")
# #                 
#                 return 2
        last_val_loss = val_loss
    accelerator.end_training()
        # Use accelerator.print to print only on the main process.
#     global_metrics['train_loss'] = running_tloss
#     global_metrics['val_loss'] = running_vloss
#     global_metrics['cer'] = running_cer
# #     accelerator.wait_for_everyone()
#     global_metrics['model'] = accelerator.unwrap_model(model)
#     global_metrics['processor'] = processor

In [None]:
# global_metrics = {}
df2 = df.sample(frac=.1)
args = (df2,"fp16", 42, 8)
notebook_launcher(training_loop, args, num_processes=4)

Launching training on 4 GPUs.


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlilloukas[0m. Use [1m`wandb login --relogin`[0m to force relogin


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm(range(num_training_steps))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm(range(num_training_steps))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm(range(num_training_steps))
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm(range(num_training_steps))


  0%|          | 0/38430 [00:00<?, ?it/s]

  0%|          | 0/38430 [00:00<?, ?it/s]

  0%|          | 0/38430 [00:00<?, ?it/s]

  0%|          | 0/38430 [00:00<?, ?it/s]

[Epoch: 1, Batches Processed: 128] loss: 8.818
[Epoch: 1, Batches Processed: 256] loss: 2.823
[Epoch: 1, Batches Processed: 384] loss: 1.648
[Epoch: 1, Batches Processed: 512] loss: 1.255
[Epoch: 1, Batches Processed: 640] loss: 1.012
[Epoch: 1, Batches Processed: 768] loss: 0.928
[Epoch: 1, Batches Processed: 896] loss: 0.910
[Epoch: 1, Batches Processed: 1024] loss: 0.804
[Epoch: 1, Batches Processed: 1152] loss: 0.740
[Epoch: 1, Batches Processed: 1280] loss: 0.690
[Epoch: 1, Batches Processed: 1408] loss: 0.612
[Epoch: 1, Batches Processed: 1536] loss: 0.613
[Epoch: 1, Batches Processed: 1664] loss: 0.591
[Epoch: 1, Batches Processed: 1792] loss: 0.547
[Epoch: 1, Batches Processed: 1920] loss: 0.525
[Epoch: 1, Batches Processed: 2048] loss: 0.511
[Epoch: 1, Batches Processed: 2176] loss: 0.497
[Epoch: 1, Batches Processed: 2304] loss: 0.469
[Epoch: 1, Batches Processed: 2432] loss: 0.436
[Epoch: 1, Batches Processed: 2560] loss: 0.466
[Epoch: 1, Batches Processed: 2688] loss: 0.463