## Train the pretraining+classifier model

Train a classifier on top of the base model that was pretrained using only the contrastive loss.

In [13]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F

import matplotlib.pyplot as plt

import seaborn as sns
sns.set_theme(color_codes=True)
from PIL import Image
import os
import sys

import transformers
from transformers import AutoModelForImageClassification, AutoConfig, AutoFeatureExtractor
from transformers.utils import logging
from datasets import Features, Value
from transformers import DefaultDataCollator

logging.set_verbosity(transformers.logging.ERROR) 
logging.disable_progress_bar() 

p = os.path.abspath('../')
sys.path.insert(1, p)

from datasets import Dataset

from torchvision.io import read_image
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
import torchvision.transforms as transforms

import evaluate

from transformers import TrainingArguments, Trainer
from src.contrastive_transformers.losses import SupConLoss
from src.utils.utils import *
from src.transforms.transforms import Noise
from torchvision.io import ImageReadMode
import math
import copy

from collections import defaultdict

import random
import torchvision
from datasets import Image
from src.utils.utils import *
from src.wordnet_ontology.wordnet_ontology import WordnetOntology

import os
from datasets import load_dataset 

seed=2401
n_excluded_classes=int(556 * 0.05)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

%load_ext autoreload
%autoreload 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

mapping_filename = './data/external/imagenet/LOC_synset_mapping.txt'
wn = WordnetOntology(mapping_filename)

## Preparing datasets

In [16]:
# Load the 'train' split of ImageNet-Sketch and the label vocabulary saved on disk.
sketch = load_dataset("imagenet_sketch", split='train', cache_dir='./cache/')
vocab = torch.load('./models/vocab.pt')
NUM_CLASSES = len(vocab)


def _to_hypernym_label(example):
    """Replace an example's label with the vocab entry of its class hypernym."""
    original_class = wn.class_for_index[example['label']]
    return {'label': vocab[wn.hypernym(original_class)]}


sketch = sketch.map(_to_hypernym_label)

Using custom data configuration default
Found cached dataset imagenet_sketch (/mnt/HDD/kevinds/sketch/./cache/imagenet_sketch/default/0.0.0/9bbda26372327ae1daa792112c8bbd2545a91b9f397ea6f285576add0a70ab6e)
Loading cached processed dataset at /mnt/HDD/kevinds/sketch/./cache/imagenet_sketch/default/0.0.0/9bbda26372327ae1daa792112c8bbd2545a91b9f397ea6f285576add0a70ab6e/cache-c5da55713e638949.arrow


In [17]:
imagenet_classes_folder = './data/external/imagenet/ILSVRC/Data/CLS-LOC/train'

image_labels = []
image_paths = []

# Number of images drawn from each ImageNet class folder. Draws are made with
# replacement, so the same file can be picked more than once for a class.
N_IMAGENET_EXAMPLES = 32
imagenet_classes = sorted(os.listdir(imagenet_classes_folder))
for img_class in imagenet_classes:
    # os.listdir returns entries in arbitrary order; sort them so the seeded
    # draws below select the same files regardless of filesystem ordering.
    all_imgs = sorted(os.listdir(f"{imagenet_classes_folder}/{img_class}/"))
    # random.choice is kept (one call per draw) so the global RNG stream consumed
    # here, and therefore every later random draw in the notebook, is unchanged.
    img_names = [random.choice(all_imgs) for _ in range(N_IMAGENET_EXAMPLES)]

    image_paths.extend(f"{imagenet_classes_folder}/{img_class}/{name}" for name in img_names)
    image_labels.extend([img_class] * len(img_names))

In [18]:
# Distinct label values present in the sketch dataset.
_classes = list(set(sketch['label']))

# Draw the held-out classes one at a time. Each draw is independent, so the same
# class can be selected more than once.
excluded_classes = []
for _draw in range(n_excluded_classes):
    excluded_classes.append(random.choice(_classes))

# Stop decoding the image column so each row exposes the image's file path.
sketch = sketch.cast_column('image', Image(decode=False))

dt = train_test_split(sketch, excluded_labels=excluded_classes)
train_sketch = dt['train']
test_sketch = dt['test']


def get_image_path(row):
    """Return the path of a row's image under the 'path' key."""
    image_info = row['image']
    return {'path': image_info['path']}


train = train_sketch.map(get_image_path, remove_columns=['image'])
test = test_sketch.map(get_image_path, remove_columns=['image'])

  0%|          | 0/37961 [00:00<?, ?ex/s]

  0%|          | 0/12928 [00:00<?, ?ex/s]

In [19]:
# Rows coming from the sketch training split.
sketch_rows = pd.DataFrame({'image': train['path'], 'label': train['label']})

# Rows for the sampled ImageNet images; each class folder name is mapped to a
# label through vocab[wn.hypernym(...)].
imagenet_rows = pd.DataFrame({
    'image': image_paths,
    'label': [vocab[wn.hypernym(lb)] for lb in image_labels],
})

train_data = pd.concat([sketch_rows, imagenet_rows], axis=0).reset_index(drop=True)
train_dataset = Dataset.from_pandas(train_data)

# The test set keeps only the image path, exposed under the 'image' column name.
test_dataset = (
    test_sketch
    .map(get_image_path, remove_columns=['image'])
    .rename_column('path', 'image')
)

Loading cached processed dataset at /mnt/HDD/kevinds/sketch/./cache/imagenet_sketch/default/0.0.0/9bbda26372327ae1daa792112c8bbd2545a91b9f397ea6f285576add0a70ab6e/cache-1935775775b71dbc.arrow


In [20]:
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Preprocessing applied to every image tensor: resize to a square of the
# extractor's size, convert to float, then normalise with the extractor's statistics.
_image_side = feature_extractor.size
_transforms = torch.nn.Sequential(
    transforms.Resize((_image_side, _image_side)),
    transforms.ConvertImageDtype(torch.float),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
)


def get_pixel_values(examples):
    """Read each image path in a batch and store the preprocessed tensors.

    The 'image' entry (a list of file paths) is removed and replaced by
    'pixel_values', a list with one transformed RGB tensor per path.
    The input mapping is modified in place and returned.
    """
    pixel_values = []
    for path in examples['image']:
        rgb_tensor = torchvision.io.read_image(path, ImageReadMode.RGB)
        pixel_values.append(_transforms(rgb_tensor))
    examples["pixel_values"] = pixel_values
    del examples["image"]
    return examples


# Images are read from disk on access instead of being materialised up front.
train_dataset = train_dataset.with_transform(get_pixel_values)
test_dataset = test_dataset.with_transform(get_pixel_values)
collator = DefaultDataCollator()

## Train the model

In [21]:
accuracyk = evaluate.load("KevinSpaghetti/accuracyk")


def compute_metrics(eval_pred):
    """Compute top-1 and top-5 accuracy from a (logits, labels) pair.

    Returns a dict with the keys 'top1' and 'top5', each holding the
    'accuracy' value reported by the `accuracyk` metric.
    """
    logits, labels = eval_pred
    top_k = 5

    # Highest-scoring class per row; keepdims keeps it as a single-column 2-D array.
    best_class = np.argmax(logits, axis=-1, keepdims=True)
    # argpartition moves the k largest logits into the last k columns
    # (in no particular order among themselves).
    best_k_classes = np.argpartition(logits, -top_k, axis=-1)[:, -top_k:]

    scores = {}
    for metric_name, predictions in (('top1', best_class), ('top5', best_k_classes)):
        result = accuracyk.compute(predictions=predictions, references=labels)
        scores[metric_name] = result['accuracy']
    return scores


cb = StoreLosses()

In [22]:
torch.hub.set_dir('./cache')

# Start from the last checkpoint of the contrastive pretraining run for this seed.
pretrained_checkpoint = f"./models/contrastive-pretraining-{seed}/last-checkpoint"
id_to_label = {index: token for index, token in enumerate(vocab.get_itos())}

model = AutoModelForImageClassification.from_pretrained(
    pretrained_checkpoint,
    cache_dir='./cache/',
    num_labels=len(vocab),
    label2id=vocab.get_stoi(),
    id2label=id_to_label,
)
model.train()

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_

In [None]:
from transformers.trainer_utils import get_last_checkpoint

classifier_output_dir = f"./models/classifier-training-{seed}/"

# Note: `resume_from_checkpoint` is intentionally not set here. As a
# TrainingArguments field it is not acted on by Trainer; resuming is requested
# through `trainer.train(resume_from_checkpoint=...)` below.
training_args = TrainingArguments(
    output_dir=classifier_output_dir,
    load_best_model_at_end=True,
    save_total_limit=1,
    save_strategy='epoch',
    num_train_epochs=16,
    learning_rate=2e-4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    # Effective train batch size per device: 32 * 4 = 128.
    gradient_accumulation_steps=4,
    warmup_steps=250,
    weight_decay=0.01,
    disable_tqdm=False,
    # Keep the dataset columns as-is; the on-the-fly transform needs 'image'.
    remove_unused_columns=False,
    evaluation_strategy='epoch',
    # Only consulted when the evaluation strategy is 'steps'.
    eval_steps=250,
    logging_steps=50,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    fp16=True,
    # Apex optimisation levels are spelled with the letter O ('O0'..'O3'), not a zero.
    fp16_opt_level='O3',
    report_to="wandb",
    optim="adamw_torch"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collator,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
    callbacks=[cb]
)

# Resume from the most recent checkpoint in the output directory when one exists;
# otherwise (no directory or no checkpoint yet) start training from the loaded model.
last_checkpoint = (
    get_last_checkpoint(classifier_output_dir)
    if os.path.isdir(classifier_output_dir)
    else None
)
trainer.train(resume_from_checkpoint=last_checkpoint)

Using cuda_amp half precision backend
***** Running training *****
  Num examples = 69961
  Num Epochs = 16
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 4
  Total optimization steps = 8736
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Top1,Top5
0,2.3794,2.184873,0.577661,0.86827
1,0.9585,1.226911,0.699876,0.916306
2,0.5692,1.108506,0.729425,0.921643
3,0.3035,1.078849,0.747215,0.92466
4,0.2135,1.123052,0.750928,0.918007
5,0.159,1.195078,0.751315,0.914449
6,0.0997,1.192004,0.754796,0.921566
7,0.0726,1.254992,0.753868,0.916692
8,0.0583,1.199379,0.765393,0.920328
9,0.0446,1.262574,0.766244,0.916615


***** Running Evaluation *****
  Num examples = 12928
  Batch size = 64
Saving model checkpoint to ./models/classifier-training-2401/checkpoint-546
Configuration saved in ./models/classifier-training-2401/checkpoint-546/config.json
Model weights saved in ./models/classifier-training-2401/checkpoint-546/pytorch_model.bin
Feature extractor saved in ./models/classifier-training-2401/checkpoint-546/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 12928
  Batch size = 64
Saving model checkpoint to ./models/classifier-training-2401/checkpoint-1092
Configuration saved in ./models/classifier-training-2401/checkpoint-1092/config.json
Model weights saved in ./models/classifier-training-2401/checkpoint-1092/pytorch_model.bin
Feature extractor saved in ./models/classifier-training-2401/checkpoint-1092/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 12928
  Batch size = 64
Saving model checkpoint to ./models/classifier-training-2401/checkpoint-1638
C

In [None]:
# Three side-by-side panels: train loss, eval loss and eval top-1 accuracy,
# all read from the values stored by the `cb` callback.
fig = plt.figure(figsize=(16, 4))
ax = plt.GridSpec(1, 3, figure=fig)

class_ax = plt.subplot(ax[0, 0])
contr_ax = plt.subplot(ax[0, 1])
test_ax = plt.subplot(ax[0, 2])

class_ax.set_title("train loss")
contr_ax.set_title("eval loss")
test_ax.set_title("eval accuracy")  # top-1 accuracy on the evaluation dataset

# Each series goes on the axis whose title names it: train loss on the first
# panel, eval loss on the second.
sns.lineplot(x=range(1, len(cb.train_loss) + 1), y=cb.train_loss, ax=class_ax)
sns.lineplot(x=range(1, len(cb.eval_loss) + 1), y=cb.eval_loss, ax=contr_ax)
sns.lineplot(x=range(1, len(cb.top1) + 1), y=cb.top1, ax=test_ax)

fig.show()