In [1]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn

from transformers import ViTFeatureExtractor, ViTForImageClassification

import albumentations as A
from albumentations.pytorch import ToTensorV2

import random
import time
from datetime import datetime

import matplotlib
import matplotlib.pyplot as plt
import cv2
import os
import copy

import pandas as pd

import timm
from timm.optim import create_optimizer_v2
from timm.utils import CheckpointSaver
import torch.nn.functional as F

matplotlib.use('TkAgg')

In [2]:
random.seed(42)
num_classes = 36
using_dataset = f"food_ai_iccv"
# with open(f"./{using_dataset}/meta/train.txt", 'r') as f_train: 
#     correct_images_filepaths = [f"./{using_dataset}/images/{line[:-1]}.jpg" for line in f_train.readlines()]
#     #correct_images_filepaths = [i for i in images_filepaths if cv2.imread(i) is not None]
#     random.shuffle(correct_images_filepaths)
#     train_images_paths = correct_images_filepaths

# print(len(train_images_paths))

# with open(f"./{using_dataset}/meta/test.txt", 'r') as f_test: 
#     test_images_paths = [f"./{using_dataset}/images/{line[:-1]}.jpg" for line in f_test.readlines()]
#     #test_images_paths = [i for i in test_images_paths if cv2.imread(i) is not None]
#     #random.shuffle(test_images_paths)
#     #test_images_paths = test_images_paths[:500]

# print(len(test_images_paths))
# train_images_paths[:20]

In [3]:
data_transform_train = transforms.Compose(
            [transforms.Resize([256,256]),
             transforms.RandomCrop([224,224]),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),
             ])
data_transform_test = transforms.Compose(
            [transforms.Resize([224,224]),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),
             ])

ds_train = datasets.ImageFolder(f"{using_dataset}/Train", transform=data_transform_train)
ds_test = datasets.ImageFolder(f"{using_dataset}/Test", transform=data_transform_test)


# label2id = ds_test.class_to_idx
# id2label = {identifier: label for label, identifier in label2id.items()}

# label2id
ds_test.class_to_idx

{'0': 0,
 '1': 1,
 '10': 2,
 '11': 3,
 '12': 4,
 '13': 5,
 '14': 6,
 '15': 7,
 '16': 8,
 '17': 9,
 '18': 10,
 '19': 11,
 '2': 12,
 '20': 13,
 '21': 14,
 '22': 15,
 '23': 16,
 '24': 17,
 '25': 18,
 '26': 19,
 '27': 20,
 '28': 21,
 '29': 22,
 '3': 23,
 '30': 24,
 '31': 25,
 '32': 26,
 '33': 27,
 '34': 28,
 '35': 29,
 '4': 30,
 '5': 31,
 '6': 32,
 '7': 33,
 '8': 34,
 '9': 35}

In [4]:
BATCH_SIZE = 16
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False)
dl_train

<torch.utils.data.dataloader.DataLoader at 0x7f309037c4f0>

In [5]:
def get_accuracy(predictions, labels):
    return (predictions.argmax(dim=1) == labels).float().mean()

def get_loss(predictions, labels):
    predictions = predictions.reshape(-1, predictions.shape[-1])
    #labels = labels.unsqueeze(1).expand(-1, 1).reshape(-1)
    return F.cross_entropy(predictions, labels)

In [6]:
model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True, num_classes=num_classes)
opt = create_optimizer_v2(model, learning_rate=1e-5)

# PATH = f"./checkpoints/{using_dataset}/ViT/20211028-093107/model.pt"
# checkpoint = torch.load(PATH)
# model.load_state_dict(checkpoint['model_state_dict'])
# opt.load_state_dict(checkpoint['optimizer_state_dict'])
# last_epoch = checkpoint['epoch']
# loss = checkpoint['loss']
# acc = checkpoint['acc']


model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(device)


cuda


In [8]:
df_val = pd.DataFrame()
val_labels = []
val_pred = []
losses, accs = [], []
with torch.no_grad():
    for images, labels in dl_test:
        images = images.to(device)
        labels = labels.to(device)
        pred = model(images)
        loss = get_loss(pred, labels)
        acc = get_accuracy(pred, labels)
        accs.append(acc * images.shape[0])            
        losses.append(loss * images.shape[0])
        val_labels.append(labels)
        val_pred.append(pred)

df_val[f"epoch_prova_pred"] = val_pred
df_val[f"epoch_prova_labels"] = val_labels
df_val.to_csv(f"epoch_prova_val.csv")
loss = torch.stack(losses).sum() / len(dl_test.dataset)
acc = torch.stack(accs).sum() / len(dl_test.dataset)

print(f'Epoch: {0+1:>2}    Loss: {loss.item():.3f}    Accuracy: {acc:.3f}')

Epoch:  1    Loss: 5.615    Accuracy: 0.015


In [None]:
EPOCHS = 20
df_val = pd.DataFrame()
df_train = pd.DataFrame()
output_dir = f"./checkpoints/{using_dataset}/ViT/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
os.system(f"mkdir {output_dir}")

for epoch in range(EPOCHS):
    losses, accs = [], []
    model.train()
    train_labels = []
    train_pred = []
    for images, labels in dl_train:
        opt.zero_grad()
        images = images.to(device)
        labels = labels.to(device)
        pred = model(images)
        loss = get_loss(pred, labels)
        train_pred.append(pred)
        train_labels.append(labels)
        loss.backward()
        opt.step()
    df_train[f"epoch_{epoch}_pred"] = train_pred
    df_train[f"epoch_{epoch}_labels"] = train_labels
    df_train.to_csv(f"epoch_{epoch}_train.csv")
        
    model.eval()
    val_labels = []
    val_pred = []
    with torch.no_grad():
        for images, labels in dl_test:
            images = images.to(device)
            labels = labels.to(device)
            pred = model(images)
            loss = get_loss(pred, labels)
            acc = get_accuracy(pred, labels)
            accs.append(acc * images.shape[0])            
            losses.append(loss * images.shape[0])
            val_labels.append(labels)
            val_pred.append(pred)
            
    df_val[f"epoch_{epoch}_pred"] = val_pred
    df_val[f"epoch_{epoch}_labels"] = val_labels
    df_val.to_csv(f"epoch_{epoch}_val.csv")
    loss = torch.stack(losses).sum() / len(dl_test.dataset)
    acc = torch.stack(accs).sum() / len(dl_test.dataset)

    print(f'Epoch: {epoch+1:>2}    Loss: {loss.item():.3f}    Accuracy: {acc:.3f}')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'loss': loss.item(),
        'acc': acc
    }, f"{output_dir}/model.pt")

Epoch:  1    Loss: 1.100    Accuracy: 0.734
Epoch:  2    Loss: 0.543    Accuracy: 0.863
Epoch:  3    Loss: 0.391    Accuracy: 0.896
Epoch:  4    Loss: 0.304    Accuracy: 0.919
Epoch:  5    Loss: 0.258    Accuracy: 0.930
Epoch:  6    Loss: 0.234    Accuracy: 0.933
Epoch:  7    Loss: 0.211    Accuracy: 0.940
Epoch:  8    Loss: 0.197    Accuracy: 0.941
Epoch:  9    Loss: 0.189    Accuracy: 0.945
Epoch: 10    Loss: 0.178    Accuracy: 0.945
Epoch: 11    Loss: 0.168    Accuracy: 0.950
Epoch: 12    Loss: 0.166    Accuracy: 0.951
Epoch: 13    Loss: 0.156    Accuracy: 0.950
Epoch: 14    Loss: 0.154    Accuracy: 0.956
Epoch: 15    Loss: 0.149    Accuracy: 0.955
Epoch: 16    Loss: 0.142    Accuracy: 0.957
Epoch: 17    Loss: 0.144    Accuracy: 0.959
Epoch: 18    Loss: 0.142    Accuracy: 0.958
Epoch: 19    Loss: 0.139    Accuracy: 0.958
