In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.image as img

import json

from torchvision import transforms as T
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms.functional as F

import torch
import os

import numpy as np


from torch.utils.data import DataLoader
from tqdm import tqdm

import cv2

import clip
from PIL import Image

import torch.nn as nn

from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, roc_curve, confusion_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier

from prettytable import PrettyTable
from CLIP_utils import get_features, get_lr

import optuna

In [2]:
prettyTable = PrettyTable(['Model name', 'Accuracy', 'ROC AUC'])

In [3]:
from classes.transforms.CropCenter import CropCenter
from classes.transforms.ScaleMaxSideToSize import ScaleMaxSideToSize
from classes.dataset.HatefulMemesDataset import HatefulMemesDataset
from classes.dataset.FeaturesDataset import FeaturesDataset

In [4]:
data_dir = r'E:\datasets\MADE\3_graduation\parthplc\archive\data\\'

In [5]:
train_path = data_dir + 'train.jsonl'
dev_path = data_dir + 'dev.jsonl'

train_data = pd.read_json(train_path, lines=True)
test_data = pd.read_json(dev_path, lines=True)

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

#### Предобработка изображений

In [8]:
MEAN = torch.tensor([0.485, 0.456, 0.406])
STD = torch.tensor([0.229, 0.224, 0.225])

In [9]:
def train_model(model, train_loader, val_loader, loss, optimizer, num_epochs, scheduler, model_name):    
    best_model_name = None
    loss_history = []
    train_history = []
    val_history = []
    top_val_accuracy = 0.3
    for epoch in range(num_epochs):
        model.train()
        loss_accum = 0
        correct_samples = 0
        total_samples = 0
        for i_step, (x, y) in enumerate(train_loader):
            x = x.to(device)#x.type(torch.float).cpu()
            y = y.to(device)#y.type(torch.float).cpu()
            #model = model.cpu()
            prediction = model(x)    
            loss_value = loss(prediction, y.type(torch.long))
            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()
            
            _, indices = torch.max(prediction, 1)
            correct_samples += torch.sum(indices == y)
            total_samples += y.shape[0]
            
            loss_accum += loss_value

        ave_loss = loss_accum / (i_step + 1)
        train_accuracy = float(correct_samples) / total_samples
        val_accuracy = compute_accuracy(model, val_loader)
        
        loss_history.append(float(ave_loss))
        train_history.append(train_accuracy)
        val_history.append(val_accuracy)
        if scheduler != None:
            scheduler.step()

        #print("Epoch: %i lr: %f; Average loss: %f, Train accuracy: %f, Val accuracy: %f" % (epoch, get_lr(optimizer), ave_loss, train_accuracy, val_accuracy))

  
        if val_accuracy > top_val_accuracy:
            #
            top_val_accuracy = val_accuracy
            m_name = f'{model_name}_classifier_{epoch}_{round(val_accuracy, 3)}.ckpt'
            best_model_name = m_name
            torch.save(model, open(m_name, 'wb'))
            #print("saved", m_name)

    return loss_history, train_history, val_history, best_model_name
        
    
def compute_accuracy(model, loader):
    """
    Computes accuracy on the dataset wrapped in a loader    
    Returns: accuracy as a float value between 0 and 1
    """
    model.eval()
    correct_samples = 0
    total_samples = 0 
    for i_step, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        prediction = model(x)
        _, indices = torch.max(prediction, 1)
        correct_samples += torch.sum(indices == y)
        total_samples += y.shape[0]            

    val_accuracy = float(correct_samples) / total_samples
         
    return val_accuracy


In [10]:
results = {}

In [11]:
model_names = clip.available_models()

In [12]:
def work(model_name):
    CROP_SIZE=224
    
    if '336' in model_name:
        CROP_SIZE=336
        
    transforms = T.Compose([
        ScaleMaxSideToSize(CROP_SIZE),
        CropCenter(CROP_SIZE),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ])    
    
    train_dataset = HatefulMemesDataset(train_path, transforms)
    val_dataset = HatefulMemesDataset(dev_path, transforms)

    model, preprocess = clip.load(model_name, device=device)
    
    features_train, labels_train = get_features(model, train_dataset)
    features_val, labels_val = get_features(model, val_dataset)

    features_train_dataset = FeaturesDataset(features_train, labels_train)
    features_val_dataset = FeaturesDataset(features_val, labels_val)

    input_shape = features_train[0].shape[0]
    num_classes = 2


    torch.manual_seed(1024)

    shape = 256
    nn_model = nn.Sequential(
                nn.Linear(input_shape, shape),
                nn.Dropout(0.66),
                nn.BatchNorm1d(shape),
                nn.ReLU(inplace=True),    

                nn.Linear(shape, shape),
                nn.Dropout(0.66),
                nn.BatchNorm1d(shape),
                nn.ReLU(inplace=True),    

                nn.Linear(shape, num_classes),
                )

    nn_model = nn_model.to(device)
    #print(nn_model)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(nn_model.parameters(), lr=1e-2)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)

    loss_history, train_history, val_history, best_model_name = train_model(
        nn_model, 
        DataLoader(features_train_dataset, batch_size=500),
        DataLoader(features_val_dataset, batch_size=500),
        loss, optimizer, 100, scheduler, 'clip')

    print("best model:", best_model_name)
    best_model = torch.load(open(best_model_name, 'rb'))
    #print(best_model)

    best_model = best_model.to(device)
    
    best_model.eval()
    for i_step, (x, y) in enumerate(DataLoader(features_val_dataset, batch_size=5000)):
        x = x.to(device)
        prediction = best_model(x)

    acc_score = accuracy_score(np.array([x.item() for x in labels_val]), torch.max(prediction.cpu(), 1)[1])
    auc_score = roc_auc_score(np.array([x.item() for x in labels_val]), prediction.cpu()[:,1].detach().numpy())
    prettyTable.add_row([model_name, acc_score, auc_score])
    print(model_name, CROP_SIZE, input_shape, acc_score, auc_score)

    return (acc_score, auc_score)

In [13]:
for x in model_names:
    try:
        results[x] = work(x)
    except Exception as ex:
        print(ex)

100%|██████████████████████████████████████████████████████████████████████████████████| 85/85 [01:17<00:00,  1.09it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.07it/s]


best model: clip_classifier_87_0.636.ckpt
RN50 224 2048 0.636 0.676096


100%|██████████████████████████████████████████████████████████████████████████████████| 85/85 [01:21<00:00,  1.04it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.00it/s]


best model: clip_classifier_74_0.68.ckpt
RN101 224 1024 0.68 0.7092319999999999


  0%|                                                                                           | 0/85 [00:01<?, ?it/s]


The size of tensor a (50) must match the size of tensor b (82) at non-singleton dimension 0


  0%|                                                                                           | 0/85 [00:00<?, ?it/s]


The size of tensor a (50) must match the size of tensor b (145) at non-singleton dimension 0


  0%|                                                                                           | 0/85 [00:00<?, ?it/s]


The size of tensor a (50) must match the size of tensor b (197) at non-singleton dimension 0


100%|██████████████████████████████████████████████████████████████████████████████████| 85/85 [01:14<00:00,  1.15it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:04<00:00,  1.09it/s]


best model: clip_classifier_44_0.654.ckpt
ViT-B/32 224 1024 0.654 0.698512


100%|██████████████████████████████████████████████████████████████████████████████████| 85/85 [01:29<00:00,  1.06s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:05<00:00,  1.14s/it]


best model: clip_classifier_93_0.672.ckpt
ViT-B/16 224 1024 0.672 0.734128


100%|██████████████████████████████████████████████████████████████████████████████████| 85/85 [02:33<00:00,  1.80s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:09<00:00,  1.88s/it]


best model: clip_classifier_49_0.724.ckpt
ViT-L/14 224 1536 0.724 0.786672


100%|██████████████████████████████████████████████████████████████████████████████████| 85/85 [04:46<00:00,  3.36s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:17<00:00,  3.42s/it]


best model: clip_classifier_14_0.742.ckpt
ViT-L/14@336px 336 1536 0.742 0.7866240000000001


In [14]:
print(prettyTable)

+----------------+----------+--------------------+
|   Model name   | Accuracy |      ROC AUC       |
+----------------+----------+--------------------+
|      RN50      |  0.636   |      0.676096      |
|     RN101      |   0.68   | 0.7092319999999999 |
|    ViT-B/32    |  0.654   |      0.698512      |
|    ViT-B/16    |  0.672   |      0.734128      |
|    ViT-L/14    |  0.724   |      0.786672      |
| ViT-L/14@336px |  0.742   | 0.7866240000000001 |
+----------------+----------+--------------------+


In [15]:
results

{'RN50': (0.636, 0.676096),
 'RN101': (0.68, 0.7092319999999999),
 'ViT-B/32': (0.654, 0.698512),
 'ViT-B/16': (0.672, 0.734128),
 'ViT-L/14': (0.724, 0.786672),
 'ViT-L/14@336px': (0.742, 0.7866240000000001)}