## Обучение MetaCLIP
В данной тетрадке расположен код обучения и инференса модели MetaCLIP (`ViT-B-16-quickgelu`). Перед запуском убедитесь, что все необходимые библиотеки установлены, а в `CFG.wandb_key` указан ваш API-ключ wandb. Для инференса рекомендуется использовать чекпойнт с наименьшим RMSE.

Вывод: данный эксперимент показывает, что модели FastViT, несмотря на очень быстрый инференс, показывают SOTA-результаты среди трансформеров.

In [1]:
from IPython.display import clear_output

#!pip install lightning timm opendatasets albumentations catboost gdown wandb ftfy
#clear_output()

In [2]:
!git clone https://github.com/facebookresearch/MetaCLIP.git
%cd ./MetaCLIP/src

In [5]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from math import sin,cos,pi,floor
from sklearn.metrics import accuracy_score,f1_score,balanced_accuracy_score
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score
import albumentations as A
import open_clip
from albumentations.pytorch.transforms import ToTensorV2
from catboost import CatBoostClassifier,Pool,cv
from transformers import get_cosine_schedule_with_warmup
from copy import deepcopy
import wandb
import timm
# Seed the random number generators once for the whole notebook
# (the same value, 56, is repeated in CFG.seed and CFG.data.seed).
pl.seed_everything(56)

Seed set to 56


56

# Конфиги:
- использована модель `ViT-B-16-quickgelu` с весами MetaCLIP `b16_400m.pt`
- разный learning rate для MLP и embedding extractor'а
- функция активации [PReLU в MLP](../experiments/CLIP_train/experiments.ipynb)
- [Cosine Scheduler with warmup](../experiments/CLIP_train/experiments.ipynb)
- под валидационную выборку отведено 20% обучающих данных (такое разделение было выбрано как оптимальное)

In [None]:
class CFG:
    """Central configuration of the notebook, split into nested namespaces."""

    class data:
        # Score CSVs and the image folders their file names are joined with.
        train_data = 'aiijc23-4/train_scores.csv'
        test_data = './simple_sub.csv'
        train_path = 'aiijc23-4/train/train/'
        test_path = 'aiijc23-4/test/test/'
        # DataLoader and train/validation split settings.
        num_workers = 4
        val_split_size = 0.2  # passed to train_test_split as test_size
        batch_size = 32
        seed = 56  # random_state of the train/validation split

    class model:
        model = 'ViT-B-16-quickgelu'
        # NOTE(review): the weights downloaded and loaded later in this
        # notebook are b16_400m.pt, while this value names b32 — confirm.
        pretrained = 'metaclip/b32_400m.pt'
        scheduler = True
        max_epoches = 4  # forwarded to pl.Trainer(max_epochs=...)
        # Separate optimiser settings for the MLP head and the CLIP encoder.
        mlp_lr = 1.6e-4
        encoder_lr = 4e-6
        mlp_weights_decay = 0.02
        encoder_weights_decay = 0.02
        # Cosine schedule: warmup_step is the warm-up fraction of the schedule,
        # warmup_epoch is the number of epochs the schedule is stretched over.
        warmup_step = 0.1
        warmup_epoch = 4
        # AdamW hyper-parameters.
        eps = 1e-6
        betas = (0.9, 0.999)
        num_cycles = 0.55

    wandb_key = "your_key"  # replace with a personal W&B API key
    seed = 56

# Предобработка датасета

In [20]:
def make_df(path, root_path=CFG.data.train_path):
    """Read a score CSV and return a two-column frame of image paths and labels.

    Args:
        path: CSV file with an ``IMAGE`` (file name) and a ``SCORE`` column.
        root_path: Prefix prepended to every file name; plain string
            concatenation is used, so it must end with a path separator.

    Returns:
        DataFrame with columns ``image`` (prefixed path) and ``label`` (score).
    """
    raw = pd.read_csv(path)
    return pd.DataFrame({
        'image': raw['IMAGE'].apply(lambda file_name: root_path + file_name),
        'label': raw['SCORE'],
    })

In [21]:
class PLDataset(Dataset):
    """Map-style dataset that yields ``(preprocessed image, label)`` pairs."""

    def __init__(self, df, preprocess):
        """Store the path/label rows of ``df`` and the image transform.

        Args:
            df: Frame with an ``image`` (file path) and a ``label`` column.
            preprocess: Callable applied to every opened PIL image.
        """
        super().__init__()
        self.cfg = CFG.data
        self.preprocess = preprocess
        # Only the two needed columns are kept, as an array of [path, label] rows.
        self.data = df[['image', 'label']].values

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Open the image at row ``index``, transform it and return it with its label."""
        image_path, label = self.data[index]
        return self.preprocess(Image.open(image_path)), label

In [22]:
class PLDataModule(pl.LightningDataModule):
    """Lightning data module providing train, validation and predict loaders."""

    def __init__(self, preprocess):
        """Copy the data settings from ``CFG.data`` and keep the image transform.

        Args:
            preprocess: Callable applied to every image by the datasets.
        """
        super().__init__()
        data_cfg = CFG.data
        self.cfg = data_cfg
        self.preprocess = preprocess
        self.train_dataset_path = data_cfg.train_data
        self.test_dataset_path = data_cfg.test_path
        self.val_split_size = data_cfg.val_split_size
        self.batch_size = data_cfg.batch_size
        self.num_workers = data_cfg.num_workers
        # Set once setup() has run, so a second call keeps the existing split.
        self.is_setup = False

    def prepare_data(self):
        """Load the train and test CSVs into ``train_df`` / ``test_df``."""
        self.train_df = make_df(self.train_dataset_path)
        self.test_df = make_df(CFG.data.test_data, root_path=CFG.data.test_path)

    def setup(self, stage: str):
        """Split off the validation part and build the three datasets (first call only)."""
        if not self.is_setup:
            self.train_df, self.val_df = train_test_split(
                self.train_df,
                test_size=self.val_split_size,
                random_state=self.cfg.seed,
            )
            self.train_dataset = PLDataset(self.train_df, self.preprocess)
            self.val_dataset = PLDataset(self.val_df, self.preprocess)
            self.test_dataset = PLDataset(self.test_df, self.preprocess)
            self.is_setup = True

    def _loader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader using the configured batch size and workers."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=shuffle)

    def train_dataloader(self):
        """Return the shuffled loader over the training split."""
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation split."""
        return self._loader(self.val_dataset)

    def predict_dataloader(self):
        """Return the unshuffled loader over the test set."""
        return self._loader(self.test_dataset)

# Метрики и лосс
- были использованы различные метрики для создания новых гипотез
- в качестве основной метрики выбрана RMSE, так как она показала себя лучше всего среди регрессионных метрик:
$$ \mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}(\mathrm{pred}_i-\mathrm{act}_i)^2} $$

In [24]:
class AverageMeter():
    """Accumulates labels and predictions and turns them into validation metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.labels = []
        self.preds = []

    def update(self, labels, preds):
        """Append one batch of ``labels`` and ``preds`` (both iterables of numbers)."""
        self.labels.extend(labels)
        self.preds.extend(preds)

    def calc_metrics(self):
        """Compute regression and classification metrics over the accumulated data.

        The classification metrics compare the labels with the predictions
        rounded by the built-in ``round``.

        Returns:
            Dict mapping metric name to its value.
        """
        y_true = pd.Series(self.labels)
        y_pred = pd.Series(self.preds)
        y_rounded = y_pred.map(round)
        return {
            # Regression metrics on the raw predictions.
            'val_rmse': mean_squared_error(y_true, y_pred) ** 0.5,
            'val_mae': mean_absolute_error(y_true, y_pred),
            'mape': mean_absolute_percentage_error(y_true, y_pred),
            'val_r2': r2_score(y_true, y_pred),
            # Classification metrics on the rounded predictions.
            'val_f1': f1_score(y_true, y_rounded, average='macro'),
            'val_acc': accuracy_score(y_true, y_rounded),
            'val_w_acc': balanced_accuracy_score(y_true, y_rounded),
        }

In [None]:
class RMSELoss(nn.Module):
    """Root-mean-square-error loss with a small constant added under the root."""

    def __init__(self, eps=1e-6):
        """
        Args:
            eps: Constant added to the MSE before the square root
                (presumably to keep the gradient finite at zero error).
        """
        super().__init__()
        self.eps = eps
        self.mse = nn.MSELoss()

    def forward(self, yhat, y):
        """Return ``sqrt(MSE(yhat, y) + eps)`` as a scalar tensor."""
        mean_sq_err = self.mse(yhat, y)
        return (mean_sq_err + self.eps).sqrt()

# Инициализация модели и алгоритм обучения

In [32]:
class PLModule(pl.LightningModule):
    """CLIP image encoder with an MLP regression head, trained with RMSE loss."""

    def __init__(self, clip):
        """
        Args:
            clip: Model exposing ``encode_image``; the head's first layer
                expects 512-dimensional image features from it.
        """
        super().__init__()
        self.cfg = CFG.model
        self.clip = clip
        self.mlp = nn.Sequential(nn.Linear(512, 512 * 2),
                                 nn.PReLU(),
                                 nn.LayerNorm(512 * 2),
                                 nn.Linear(512 * 2, 1))
        self.criterion = RMSELoss()
        self.avg_meter = AverageMeter()
        self.last_loss = 0
        self.losses = []

    def forward(self, x):
        """Return one predicted score per image as a 1-D tensor."""
        features = self.clip.encode_image(x)
        scores = self.mlp(features)
        # Drop only the trailing size-1 dimension of the head output. A bare
        # squeeze() would also drop the batch dimension for a batch of one,
        # giving a 0-d tensor whose .tolist() is a float rather than a list.
        return scores.squeeze(-1)

    def training_step(self, batch, i):
        """Compute the RMSE loss for one batch, log it and return it."""
        x, targets = batch
        x, targets = x.float(), targets.float()
        logits = self(x)
        # Argument order follows RMSELoss.forward(yhat, y).
        loss = self.criterion(logits, targets)
        self.log_dict({'train_loss': loss.item()})
        self.last_loss = loss.item()
        return loss

    def predict_step(self, batch, i):
        """Return the predicted scores of one batch as a Python list."""
        x, _ = batch
        logits = self(x.float())
        return logits.tolist()

    def validation_step(self, batch, _):
        """Log the batch loss and accumulate targets/predictions for epoch metrics."""
        x, targets = batch
        x, targets = x.float(), targets.float()
        logits = self(x)
        loss = self.criterion(logits, targets)
        self.log_dict({'val_loss': loss.item()})
        self.avg_meter.update(targets.cpu().detach().tolist(),
                              logits.cpu().detach().tolist())

    def on_validation_epoch_end(self):
        """Log the metrics accumulated over the validation epoch and reset the meter."""
        self.log_dict(self.avg_meter.calc_metrics())
        self.avg_meter.reset()

    def configure_optimizers(self):
        """Build AdamW with separate encoder/head settings and a per-step cosine schedule.

        Reads the module-level global ``TRAIN_STEPS`` (batches per training
        epoch), which must be defined before training starts.
        """
        grouped_parameters = [
            {'params': self.clip.parameters(),
             'lr': self.cfg.encoder_lr,
             'weight_decay': self.cfg.encoder_weights_decay},
            {'params': self.mlp.parameters(),
             'lr': self.cfg.mlp_lr,
             'weight_decay': self.cfg.mlp_weights_decay},
        ]
        optim = torch.optim.AdamW(grouped_parameters,
                                  betas=self.cfg.betas,
                                  eps=self.cfg.eps)

        # The schedule spans warmup_epoch epochs in total; warmup_step is the
        # fraction of those steps spent warming up.
        total_steps = TRAIN_STEPS * self.cfg.warmup_epoch
        scheduler = get_cosine_schedule_with_warmup(optim,
                                                    num_warmup_steps=total_steps * self.cfg.warmup_step,
                                                    num_training_steps=total_steps,
                                                    num_cycles=self.cfg.num_cycles)
        scheduler = {'scheduler': scheduler,
                     'interval': 'step',
                     'frequency': 1}

        return [optim], [scheduler]

In [11]:
!wget https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt

--2023-10-13 14:52:10--  https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 18.238.4.71, 18.238.4.66, 18.238.4.36, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|18.238.4.71|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1795791003 (1.7G) [binary/octet-stream]
Saving to: ‘b16_400m.pt.1’


2023-10-13 14:53:33 (20.8 MB/s) - ‘b16_400m.pt.1’ saved [1795791003/1795791003]



In [14]:
# Create the open_clip ViT-B-16-quickgelu model with weights from the local
# ./b16_400m.pt file; the second returned value is discarded and the third is
# kept as the image transform used by the datasets.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16-quickgelu', pretrained='./b16_400m.pt')

In [17]:
%cd /notebooks

/notebooks


In [26]:
# Build and set up the data module by hand so that the number of training
# batches is known before the model's scheduler is configured.
dm = PLDataModule(preprocess)
dm.prepare_data()
# NOTE(review): setup() is annotated with a str stage; 0 is passed here and the
# value is not used inside setup().
dm.setup(0)
# Batches per training epoch; read as a global by PLModule.configure_optimizers.
TRAIN_STEPS = len(dm.train_dataloader())

In [33]:
# Wrap the open_clip model in the Lightning module that adds the MLP head.
pl_model = PLModule(model)

In [29]:
# Log in to Weights & Biases with the key from CFG and also expose it through
# the WANDB_API_KEY environment variable.
wandb.login(key=CFG.wandb_key)
os.environ['WANDB_API_KEY'] = CFG.wandb_key
# NOTE(review): the run name says "b_32" while the model created in this
# notebook is ViT-B-16 — confirm the intended name.
wandb.init(project='AIIJC',name='meta_clip_vit_b_32_400m')

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Currently logged in as: [33mandrey20007[0m ([33mandrey2007[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [34]:
# Log the learning rate once per epoch.
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
# Save checkpoints ranked by the lowest logged 'val_loss', plus the last one.
checkpoint_cb = pl.callbacks.ModelCheckpoint(
    dirpath='./outputs_meta_clip_vit_b_400m/',
    filename='model_{epoch:02d}-{val_loss:.4f}',
    monitor='val_loss',
    mode='min',
    save_last=True
)

# Single-GPU, full-precision trainer that validates after every epoch and
# logs to Weights & Biases.
trainer = pl.Trainer(
    accelerator="gpu",
    precision=32,
    callbacks = [lr_monitor,checkpoint_cb],
    logger = pl.loggers.WandbLogger(),
    min_epochs=1,
    devices=[0],
    check_val_every_n_epoch=1,
    max_epochs=CFG.model.max_epoches
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [35]:
# Train and validate the model on the loaders provided by the data module.
trainer.fit(pl_model,datamodule=dm)

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:630: Checkpoint directory /notebooks/outputs_vit_l_400m exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | clip      | CLIP       | 149 M 
1 | mlp       | Sequential | 528 K 
2 | criterion | RMSELoss   | 0     
-----------------------------------------
150 M     Trainable params
0         Non-trainable params
150 M     Total params
600.596   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]



Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [36]:
# Run inference over the data module's predict loader (the test set); the
# result holds one list of scores per batch, as returned by predict_step.
# NOTE(review): no ckpt_path is passed, so presumably the in-memory weights of
# pl_model are used rather than the best checkpoint — confirm.
preds = trainer.predict(pl_model,datamodule=dm)

Restoring states from the checkpoint path at /notebooks/outputs_vit_l_400m/model_epoch=03-val_loss=0.5053.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /notebooks/outputs_vit_l_400m/model_epoch=03-val_loss=0.5053.ckpt


Predicting: |          | 0/? [00:00<?, ?it/s]

In [38]:
# Load the sample submission; its SCORE column is overwritten with predictions.
test_df = pd.read_csv('simple_sub.csv')

In [40]:
# Flatten the per-batch prediction lists into one array, in loader order.
test_df['SCORE'] = np.concatenate(preds)

In [41]:
# Display the submission frame for a visual check.
test_df

Unnamed: 0,IMAGE,SCORE
0,86cc6e863c9b6bb2a0e0db114c9775aa.jpg,8.093831
1,da71671681d9cef5b60727801bf95ef8.jpg,5.993963
2,821a9ff5df6e581c68c0371dc6b1eb90.jpg,7.852841
3,ed842bb42c39fe257ac459b544bb7ba8.jpg,8.108117
4,6eed69b8f6d62f28b809d9cbafcaab0b.jpg,6.248960
...,...,...
9983,46aa642bbcfdbed5eed380ba10fdbbbc.jpg,4.812338
9984,dc737dc9a19844943540816e4488b425.jpg,4.600523
9985,8038a60dac57ba5d445b99978a36a1d5.jpg,5.473619
9986,5381dd99c37acb63db903eac293fbe4c.jpg,4.280603


In [42]:
# Write the submission without the index column.
# NOTE(review): the file name says "b_32" while the model created in this
# notebook is ViT-B-16 — confirm the intended name.
test_df.to_csv('meta_clip_vit_b_32_400m.csv',index=False)

In [32]:
# Round every score with the built-in round().
# NOTE(review): in the visible part of the notebook nothing is saved after this
# cell, so the rounded scores are not written to disk — confirm this is intended.
test_df['SCORE'] = test_df['SCORE'].map(round)