# Imports

In [1]:
# Expose only GPU index 0 to CUDA for this kernel.
%env CUDA_VISIBLE_DEVICES=0

In [2]:
import os
import numpy as np
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import cv2
from random import shuffle
from PIL import Image

%matplotlib inline

import glob
import gc
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

print(f"Torch: {torch.__version__}")

Torch: 2.0.1+cu117


In [3]:
from comet_ml import Experiment, ExistingExperiment
from comet_ml.integration.pytorch import log_model

# Load data paths and build the model

In [4]:
# Training manifest: each line is later opened as one image path by
# get_images, which also strips the trailing newline kept here.
train_paths = None
with open('images/MAE_train.txt', 'r') as manifest:
    train_paths = list(manifest)

len(train_paths)

43406717

In [5]:
# Validation manifest, read the same way as the training one
# (one path per line, trailing newlines kept).
val_paths = None
with open('images/MAE_val.txt', 'r') as manifest:
    val_paths = list(manifest)

len(val_paths)

80117

In [6]:
# Side length (pixels) every image is resized to; also passed to the
# model as `image_size`.
IMG_SIZE = 224

# Per-channel mean / std handed to Normalize (three channels).
_NORM_MEAN = [0.533, 0.425, 0.374]
_NORM_STD = [0.244, 0.214, 0.202]

# Preprocessing pipeline: resize -> tensor -> normalise.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

In [7]:
def get_images(transform, imgs):
    '''
    Load a batch of image files and stack them into a single tensor.

    Images that cannot be opened or transformed are reported and skipped,
    so the returned batch may be smaller than `imgs`.

    :param transform: callable applied to each PIL image; it must return
                      tensors of identical shape so they can be stacked
    :param imgs: iterable of image paths (a trailing newline is tolerated)
    :return: stacked tensor with one entry per successfully loaded image,
             or an empty list when no image could be loaded / stacked
    '''
    img_tensors = []
    for img_name in imgs:
        try:
            img_name = img_name.rstrip('\r\n')
            # The context manager closes the file handle deterministically.
            # convert('RGB') gives every image three channels, so grayscale
            # or RGBA files are no longer rejected by the 3-channel Normalize.
            with Image.open(img_name) as img:
                img_tensors.append(transform(img.convert('RGB')))
        except Exception as err:
            # Best effort: one unreadable file must not kill the batch,
            # but the reason is reported instead of being swallowed.
            print(f'{img_name} does not open: {err!r}')

    # Callers test for "no images" with an empty list, so keep that contract.
    if not img_tensors:
        return []

    try:
        return torch.stack(img_tensors)
    except Exception as err:
        print(f'Could not stack image batch: {err!r}')
        return []

In [8]:
from mae.encoder import ViTBaseEncoder
from mae.mae import MAE

In [9]:
def build_model(args):
    '''
    Assemble the masked auto-encoder: a ViT encoder wrapped by the MAE
    module, both moved to the configured device.

    :param args: dict with the keys 'image_size', 'patch_size', 'vit_dim',
                 'vit_depth', 'vit_heads', 'vit_mlp_dim', 'masking_ratio',
                 'decoder_dim', 'decoder_depth' and 'device'
    :return: the MAE model placed on args['device']
    '''
    # Encoder settings are collected first, then handed over in one call.
    encoder_kwargs = dict(
        image_size=args['image_size'],
        patch_size=args['patch_size'],
        dim=args['vit_dim'],
        depth=args['vit_depth'],
        heads=args['vit_heads'],
        mlp_dim=args['vit_mlp_dim'],
        masking_ratio=args['masking_ratio'],
        device=args['device'],
    )
    encoder = ViTBaseEncoder(**encoder_kwargs).to(args['device'])

    # The decoder is configured directly on the MAE wrapper.
    autoencoder = MAE(
        encoder=encoder,
        decoder_dim=args['decoder_dim'],
        decoder_depth=args['decoder_depth'],
        device=args['device'],
    )
    return autoencoder.to(args['device'])

In [10]:
# Model configuration consumed by build_model (and logged to Comet).
args = dict(
    image_size=IMG_SIZE,
    patch_size=16,
    vit_dim=768,
    vit_depth=5,
    vit_heads=6,
    vit_mlp_dim=2048,
    masking_ratio=0.75,
    decoder_dim=256,
    decoder_depth=5,
    device='cuda',
)

In [11]:
# Build the model and restore its weights from a previously saved state dict.
cpath = '/home/hse_student/apsidorova/embedding_models/mae/mae/ckpt/EMERGY_Vit_Base_ep1_step53820.pt'
model = build_model(args)
model.train();
# map_location remaps the stored tensors onto the configured device, so the
# load does not depend on the device index the checkpoint was written from.
model.load_state_dict(torch.load(cpath, map_location=args['device']))

<All keys matched successfully>

# Train

In [12]:
# Original creation of the Comet experiment, kept for reference only; the
# run is re-attached through ExistingExperiment in the following cell.
# The API key is redacted here: supply it through the COMET_API_KEY
# environment variable instead of committing it to the notebook.
# experiment = Experiment(
#   api_key=os.environ.get("COMET_API_KEY"),
#   project_name="abaw6",
#   workspace="annanet"
# )

# experiment.set_name('MAE train')
# experiment.add_tags(['AffectNet', 'CASIA-WebFace', 'CelebA', 'IMDB-WIKI', 'WebFace260M'])

In [13]:
# Re-attach to the already existing Comet run identified by experiment_key.
# The API key is taken from the COMET_API_KEY environment variable so the
# secret is not stored in the notebook; when the variable is unset, None is
# passed and Comet falls back to its own configuration lookup.
experiment = ExistingExperiment(
        api_key=os.environ.get("COMET_API_KEY"),
        experiment_key="b9a8283da29c4831a4def0cb5fea0a8a",
    )

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/home/hse_student/apsidorova/embedding_models/mae' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.
[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/annanet/abaw6/b9a8283da29c4831a4def0cb5fea0a8a



In [18]:
# Training configuration. 'optimizer' and 'loss' are descriptive labels that
# are only logged; the remaining entries are read by the cells below.
hyperparams = dict(
    epochs=10,                      # number of passes of the training loop
    optimizer='AdamW',
    loss='pixel-wise L2 loss',
    lr=1.5e-5,                      # initial lr given to AdamW
    steplr=1,                       # lr_lambda returns 1.0 below this epoch
    batch=256,                      # DataLoader batch size
    weight_decay=5e-2,
    momentum=(0.9, 0.95),           # passed to AdamW as betas
    epochs_warmup=40,               # lr_lambda: end of the linear segment
    warmup_from=1e-3,
    lr_decay_rate=1e-2,
    warmup_to=0.0002981531029360196,
    ckpt_folder_best='/home/hse_student/apsidorova/embedding_models/mae/mae/ckpt_best',
    ckpt_folder='/home/hse_student/apsidorova/embedding_models/mae/mae/ckpt',
    freq_val=1_000,                 # run validation every this many steps
)

In [19]:
# Record the training and the model configuration on the Comet run.
for _config in (hyperparams, args):
    experiment.log_parameters(_config)

In [16]:
# The loaders batch *path strings*; images are decoded later by get_images.
# Training paths are reshuffled every epoch, validation order is fixed.
_batch_size = hyperparams['batch']
trainp_loader = DataLoader(train_paths, batch_size=_batch_size, shuffle=True)
valp_loader = DataLoader(val_paths, batch_size=_batch_size, shuffle=False)

In [20]:
# AdamW over all model parameters, configured from `hyperparams`.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=hyperparams['lr'],
    betas=hyperparams['momentum'],
    weight_decay=hyperparams['weight_decay'],
)

In [21]:
import math

def lr_lambda(epoch):
    '''
    Learning-rate multiplier used by torch.optim.lr_scheduler.LambdaLR.

    LambdaLR multiplies the optimizer's initial lr by the value returned
    here, so every branch has to return a dimensionless factor rather than
    an absolute learning rate.

    :param epoch: scheduler step counter (0-based)
    :return: factor applied to the initial learning rate
    '''
    # Constant phase: keep the initial lr unchanged.
    if epoch < hyperparams['steplr']:
        return 1.0

    # Linear interpolation between warmup_from and warmup_to.
    # NOTE(review): both values are applied as multipliers of the initial lr.
    # If they were meant as absolute learning rates they must be divided by
    # hyperparams['lr'] - confirm the intent before relying on this phase.
    if epoch < hyperparams['epochs_warmup']:
        p = epoch / hyperparams['epochs_warmup']
        return hyperparams['warmup_from'] + p * (hyperparams['warmup_to'] - hyperparams['warmup_from'])

    # Cosine decay from 1.0 down to lr_decay_rate ** 3, expressed as a factor.
    # The previous form returned an absolute lr (already scaled by
    # hyperparams['lr']), which LambdaLR then scaled by the base lr again.
    min_factor = hyperparams['lr_decay_rate'] ** 3
    cosine = (1 + math.cos(math.pi * epoch / hyperparams['epochs'])) / 2
    return min_factor + (1.0 - min_factor) * cosine

In [22]:
# Scheduler that scales the optimizer's initial lr by lr_lambda(step).
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lr_lambda)

In [23]:
class AverageMeter:
    '''
    Running weighted mean: remembers the latest value, the weighted sum,
    the total weight and the resulting average.
    '''
    def __init__(self):
        self.reset()

    def reset(self):
        '''Clear all accumulated statistics.'''
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        '''
        Add a new observation.

        :param val: observed value
        :param n: weight of the observation (e.g. number of samples)
        '''
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [None]:
# Main training loop: every epoch iterates over batches of image paths,
# periodically saves checkpoints, periodically evaluates on the validation
# paths and keeps the model with the lowest validation loss.
global_best_loss = 1e+10
checkpoint_freq = 100
PATH = ''
steps_per_epoch = len(trainp_loader)

for epoch in range(1, hyperparams['epochs'] + 1):
    # records
    losses = AverageMeter()
    # Defined up front so the emergency handler can always build a checkpoint
    # name, even when the failure happens before the first batch.
    idx = 0
    print('\nEPOCH {}:'.format(epoch))

    print('Start train')
    # train by epoch
    try:
        for idx, path in tqdm(enumerate(trainp_loader), total=steps_per_epoch):
            global_step = (epoch - 1) * steps_per_epoch + idx

            # decode the images of this batch
            tensor = get_images(transform, path)

            # get_images yields an empty list when nothing could be loaded;
            # len() works for both that list and a real tensor batch.
            if len(tensor) == 0:
                print('Do not find images')
                continue

            # Number of images actually loaded (may be below the nominal
            # batch size when files are skipped or on the last batch).
            n_images = len(tensor)
            tensor = tensor.to(args['device'])
            # forward
            loss = model(tensor)
            loss_value = loss.item()
            experiment.log_metric('current loss train', loss_value,
                                  step=global_step)
            # back propagation
            optimizer.zero_grad()
            loss.backward()
            # record, weighted by the real batch size
            losses.update(loss_value, n_images)
            del tensor
            del loss
            del path
            gc.collect()
            torch.cuda.empty_cache()

            optimizer.step()
            torch.cuda.empty_cache()

            if (idx % checkpoint_freq) == 0:
                print('Saving checkpoint')
                cpath = f'{hyperparams["ckpt_folder"]}/Vit_Base_ep{epoch}_step{idx}.pt'
                torch.save(model.state_dict(), cpath)

            if (idx % hyperparams['freq_val']) == 0:
                model.eval()

                losses_val = AverageMeter()
                print('Val part')

                # No gradients are needed for evaluation; skipping the
                # autograd graph keeps validation memory low.
                with torch.no_grad():
                    for val_path in tqdm(valp_loader, total=len(valp_loader)):
                        tensor = get_images(transform, val_path)
                        # Check for an empty batch *before* moving to the
                        # device: the empty list has no .to() method.
                        if len(tensor) == 0:
                            print('Do not find images')
                            continue
                        n_images = len(tensor)
                        tensor = tensor.to(args['device'])
                        loss = model(tensor)
                        # record
                        losses_val.update(loss.item(), n_images)
                        del tensor
                        del loss
                        gc.collect()
                        torch.cuda.empty_cache()

                experiment.log_metric('avg loss val', losses_val.avg,
                                      step=global_step)
                print(f'Current Validation loss is {losses_val.avg}')
                # Only accept a new best when at least one batch was evaluated;
                # an empty meter reports 0 and would win trivially.
                if losses_val.count and global_best_loss > losses_val.avg:
                    global_best_loss = losses_val.avg
                    print('New Best Validation loss')

                    # save model (whole module, later uploaded via PATH)
                    PATH = f'{hyperparams["ckpt_folder_best"]}/Vit_Base_ep{epoch}_step{idx}.pt'
                    torch.save(model, PATH)
                model.train()

            experiment.log_metric('avg loss train', losses.avg,
                                  epoch=epoch)
            print(f'Current Train average loss is {losses.avg}')

        scheduler.step()

    except Exception as e:
        # Report what went wrong, then save enough state to resume.
        print(type(e).__name__)
        print(e)
        print('EMERGY saving checkpoint')
        cpath = f'{hyperparams["ckpt_folder"]}/EMERGY_Vit_Base_ep{epoch}_step{idx}.pt'
        torch.save(model.state_dict(), cpath)

        torch.save(trainp_loader, 'EMERGY_train_dataloader.pth')
        torch.save(valp_loader, 'EMERGY_val_dataloader.pth')
        break

In [25]:
# Upload the best checkpoint, if the training loop saved one, and close the
# Comet run. PATH stays '' when no validation pass produced a best model;
# guarding on it ensures experiment.end() is still reached in that case.
if PATH:
    experiment.log_model("mae_model.pt", file_or_folder=PATH)
else:
    print('No best checkpoint was saved; skipping model upload')
experiment.end()

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml ExistingExperiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/annanet/abaw6/b9a8283da29c4831a4def0cb5fea0a8a
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     avg loss train [763]     : (0.12267985567450523, 0.12730881253655038)
[1;38;5;39mCOMET INFO:[0m     avg loss val             : 0.14919541102533523
[1;38;5;39mCOMET INFO:[0m     current loss train [764] : (0.1112850233912468, 0.14719419181346893)
[1;38;5;39mCOMET INFO:[0m   Parameters:
[1;38;5;39mCOMET INFO:[0m     batch            : 256
[1;38;5;39mCOMET INFO:[0m     ckpt_folder     