In [1]:
import os
import copy
from PIL import Image
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils import data
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.optim as optim
from PIL import Image
import numpy as np
from torchvision.datasets import ImageFolder

from sklearn.metrics import accuracy_score



In [2]:
augment_transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224), antialias=True),
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.2), antialias=True),
    transforms.RandomApply([transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5)], p=0.5),
    transforms.RandomApply([transforms.RandomAffine(0, translate=(0.2, 0.2))], p=0.5),
    transforms.RandomApply([transforms.RandomRotation(10)], p=0.5),
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
])

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    # transforms.Grayscale(num_output_channels=1)
    transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
])

In [3]:
path = "/kaggle/input/fer2013"
train_dataset = ImageFolder(path+"/train", transform=transform)
augment_dataset = ImageFolder(path+"/train", transform=augment_transform)
train_dataset = ConcatDataset([train_dataset, augment_dataset])
test_dataset = ImageFolder(path+"/test", transform=transform)

In [4]:
print(len(train_dataset))
print(len(test_dataset))
print(test_dataset[2][0].size())

57418
7178
torch.Size([3, 224, 224])


In [5]:
VALID_RATIO = 0.9
n_train_examples = int(len(train_dataset) * VALID_RATIO)
n_valid_examples = len(train_dataset) - n_train_examples
train_data, valid_data = data.random_split(train_dataset, [n_train_examples, n_valid_examples])
valid_data = copy.deepcopy(valid_data)
valid_data.dataset.transform = transform

In [6]:
def collate_fn(dataset):
  images = []
  labels = []
  for image, label in dataset:
      images.append(image)
      labels.append(label)

  pixel_values = torch.stack(images)
  labels = torch.tensor(labels)
  return {"pixel_values": pixel_values, "labels": labels}

Batch_Size = 16
train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=Batch_Size, shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_data, collate_fn=collate_fn, batch_size=Batch_Size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=Batch_Size, shuffle=False, num_workers=2)

In [7]:
device = device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
import gc
torch.cuda.empty_cache()
gc.collect()

0

In [9]:
from transformers import ViTConfig, ViTForImageClassification

In [10]:
configuration = ViTConfig()
configuration.num_labels = 7
configuration.return_dict = False

In [11]:
model = ViTForImageClassification(configuration).from_pretrained('google/vit-base-patch16-224')
model.num_labels = 7
model.name = 'ViT'
model.classifier = nn.Linear(configuration.hidden_size, configuration.num_labels)
vit = model.to(device)

for param in vit.vit.parameters():
    param.requires_grad = False
for param in vit.classifier.parameters():
    param.requires_grad = True
for param in vit.vit.encoder.layer[-1].parameters():
    param.requires_grad = True
for param in vit.vit.layernorm.parameters():
    param.requires_grad = True

config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [12]:
for name, param in vit.named_parameters():
     print(name, param.requires_grad)

vit.embeddings.cls_token False
vit.embeddings.position_embeddings False
vit.embeddings.patch_embeddings.projection.weight False
vit.embeddings.patch_embeddings.projection.bias False
vit.encoder.layer.0.attention.attention.query.weight False
vit.encoder.layer.0.attention.attention.query.bias False
vit.encoder.layer.0.attention.attention.key.weight False
vit.encoder.layer.0.attention.attention.key.bias False
vit.encoder.layer.0.attention.attention.value.weight False
vit.encoder.layer.0.attention.attention.value.bias False
vit.encoder.layer.0.attention.output.dense.weight False
vit.encoder.layer.0.attention.output.dense.bias False
vit.encoder.layer.0.intermediate.dense.weight False
vit.encoder.layer.0.intermediate.dense.bias False
vit.encoder.layer.0.output.dense.weight False
vit.encoder.layer.0.output.dense.bias False
vit.encoder.layer.0.layernorm_before.weight False
vit.encoder.layer.0.layernorm_before.bias False
vit.encoder.layer.0.layernorm_after.weight False
vit.encoder.layer.0.layer

In [13]:
model_parameters = filter(lambda p: p.requires_grad, vit.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params

7094791

In [14]:
optimizer = optim.Adam(vit.parameters(), lr = 1e-3)
num_epochs = 30
num_training_steps = num_epochs * len(train_loader)

In [15]:
from tqdm import tqdm
def vit_test_runner(model, dataloader):
  model.eval()

  correct = 0
  total_loss = 0
  total = 0
  for batch in tqdm(dataloader):
      batch = {k: v.to(device) for k, v in batch.items()}
      with torch.no_grad():
          outputs = model(**batch)

      total_loss += outputs[0]
      logits = outputs[1]
      predictions = torch.argmax(logits, dim=-1)
      total += batch['labels'].size(0)
      correct += (predictions == batch['labels']).sum().item()
  acc = 100 * correct / total
  total_loss = total_loss / total

  return acc, total_loss

In [16]:
def vit_train_runner(model):
  model.train()

  train_loss = 0
  train_correct = 0
  total = 0
  for epoch in range(num_epochs):
    for batch in tqdm(train_loader):
      batch = {k: v.to(device) for k, v in batch.items()}
      outputs = model(**batch)
      loss = outputs[0]
      loss.backward()

      train_loss += loss
      logits = outputs[1]
      predictions = torch.argmax(logits, dim=-1)
      total += batch['labels'].size(0)
      train_correct += (predictions == batch['labels']).sum().item()

      optimizer.step()
      optimizer.zero_grad()
      # torch.cuda.empty_cache()

    train_loss = train_loss / total
    train_acc = 100 * train_correct / total
    print("Epoch {}".format(epoch+1))
    print("Train loss: {}.\t Train accuracy: {:.2f}%".format(train_loss, train_acc))

    valid_acc, valid_loss = vit_test_runner(model, valid_loader)
    print("Valid loss: {}.\t Valid accuracy: {:.2f}%".format(valid_loss, valid_acc))

In [17]:
vit_train_runner(vit)

100%|██████████| 3589/3589 [07:17<00:00,  8.20it/s]


Epoch 1
Train loss: 0.06989163905382156.	 Train accuracy: 57.65%


100%|██████████| 359/359 [00:37<00:00,  9.54it/s]


Valid loss: 0.062485065311193466.	 Valid accuracy: 62.17%


100%|██████████| 3589/3589 [07:12<00:00,  8.30it/s]


Epoch 2
Train loss: 0.03125345706939697.	 Train accuracy: 60.15%


100%|██████████| 359/359 [00:37<00:00,  9.52it/s]


Valid loss: 0.05860493704676628.	 Valid accuracy: 65.19%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 3
Train loss: 0.019925028085708618.	 Train accuracy: 61.47%


100%|██████████| 359/359 [00:37<00:00,  9.51it/s]


Valid loss: 0.056054625660181046.	 Valid accuracy: 67.03%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 4
Train loss: 0.014259294606745243.	 Train accuracy: 62.54%


100%|██████████| 359/359 [00:37<00:00,  9.52it/s]


Valid loss: 0.05208014324307442.	 Valid accuracy: 68.13%


100%|██████████| 3589/3589 [07:13<00:00,  8.29it/s]


Epoch 5
Train loss: 0.010978199541568756.	 Train accuracy: 63.44%


100%|██████████| 359/359 [00:37<00:00,  9.50it/s]


Valid loss: 0.052785128355026245.	 Valid accuracy: 66.96%


100%|██████████| 3589/3589 [07:12<00:00,  8.30it/s]


Epoch 6
Train loss: 0.008921094238758087.	 Train accuracy: 64.17%


100%|██████████| 359/359 [00:37<00:00,  9.54it/s]


Valid loss: 0.049988146871328354.	 Valid accuracy: 70.20%


100%|██████████| 3589/3589 [07:12<00:00,  8.30it/s]


Epoch 7
Train loss: 0.007426263298839331.	 Train accuracy: 64.82%


100%|██████████| 359/359 [00:37<00:00,  9.52it/s]


Valid loss: 0.04804345220327377.	 Valid accuracy: 71.65%


100%|██████████| 3589/3589 [07:12<00:00,  8.29it/s]


Epoch 8
Train loss: 0.00628794077783823.	 Train accuracy: 65.44%


100%|██████████| 359/359 [00:37<00:00,  9.52it/s]


Valid loss: 0.04845641925930977.	 Valid accuracy: 70.83%


100%|██████████| 3589/3589 [07:12<00:00,  8.29it/s]


Epoch 9
Train loss: 0.0054837968200445175.	 Train accuracy: 66.00%


100%|██████████| 359/359 [00:37<00:00,  9.53it/s]


Valid loss: 0.04468591883778572.	 Valid accuracy: 73.81%


100%|██████████| 3589/3589 [07:12<00:00,  8.29it/s]


Epoch 10
Train loss: 0.0047775451093912125.	 Train accuracy: 66.53%


100%|██████████| 359/359 [00:37<00:00,  9.52it/s]


Valid loss: 0.04341894015669823.	 Valid accuracy: 73.82%


100%|██████████| 3589/3589 [07:12<00:00,  8.30it/s]


Epoch 11
Train loss: 0.004191854037344456.	 Train accuracy: 67.05%


100%|██████████| 359/359 [00:37<00:00,  9.52it/s]


Valid loss: 0.043107807636260986.	 Valid accuracy: 74.66%


100%|██████████| 3589/3589 [07:12<00:00,  8.29it/s]


Epoch 12
Train loss: 0.003760009305551648.	 Train accuracy: 67.52%


100%|██████████| 359/359 [00:37<00:00,  9.49it/s]


Valid loss: 0.04273446276783943.	 Valid accuracy: 74.07%


100%|██████████| 3589/3589 [07:12<00:00,  8.30it/s]


Epoch 13
Train loss: 0.0033946381881833076.	 Train accuracy: 67.99%


100%|██████████| 359/359 [00:37<00:00,  9.51it/s]


Valid loss: 0.04079597070813179.	 Valid accuracy: 75.64%


100%|██████████| 3589/3589 [07:12<00:00,  8.29it/s]


Epoch 14
Train loss: 0.003081199247390032.	 Train accuracy: 68.43%


100%|██████████| 359/359 [00:37<00:00,  9.51it/s]


Valid loss: 0.038773417472839355.	 Valid accuracy: 77.31%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 15
Train loss: 0.0027985998895019293.	 Train accuracy: 68.87%


100%|██████████| 359/359 [00:37<00:00,  9.47it/s]


Valid loss: 0.03897710517048836.	 Valid accuracy: 76.44%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 16
Train loss: 0.002562866313382983.	 Train accuracy: 69.28%


100%|██████████| 359/359 [00:37<00:00,  9.49it/s]


Valid loss: 0.03767926245927811.	 Valid accuracy: 77.85%


100%|██████████| 3589/3589 [07:13<00:00,  8.27it/s]


Epoch 17
Train loss: 0.0023588459007441998.	 Train accuracy: 69.66%


100%|██████████| 359/359 [00:37<00:00,  9.46it/s]


Valid loss: 0.03550563380122185.	 Valid accuracy: 78.47%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 18
Train loss: 0.002169517334550619.	 Train accuracy: 70.05%


100%|██████████| 359/359 [00:37<00:00,  9.50it/s]


Valid loss: 0.03544686734676361.	 Valid accuracy: 78.98%


100%|██████████| 3589/3589 [07:13<00:00,  8.29it/s]


Epoch 19
Train loss: 0.002012192504480481.	 Train accuracy: 70.43%


100%|██████████| 359/359 [00:37<00:00,  9.48it/s]


Valid loss: 0.0349857360124588.	 Valid accuracy: 79.26%


100%|██████████| 3589/3589 [07:12<00:00,  8.29it/s]


Epoch 20
Train loss: 0.0018751014722511172.	 Train accuracy: 70.78%


100%|██████████| 359/359 [00:37<00:00,  9.50it/s]


Valid loss: 0.03471268340945244.	 Valid accuracy: 79.15%


100%|██████████| 3589/3589 [07:12<00:00,  8.29it/s]


Epoch 21
Train loss: 0.001760327722877264.	 Train accuracy: 71.12%


100%|██████████| 359/359 [00:37<00:00,  9.48it/s]


Valid loss: 0.03480345383286476.	 Valid accuracy: 79.36%


100%|██████████| 3589/3589 [07:13<00:00,  8.29it/s]


Epoch 22
Train loss: 0.0016469225520268083.	 Train accuracy: 71.44%


100%|██████████| 359/359 [00:37<00:00,  9.48it/s]


Valid loss: 0.03360901400446892.	 Valid accuracy: 79.29%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 23
Train loss: 0.001549309235997498.	 Train accuracy: 71.76%


100%|██████████| 359/359 [00:37<00:00,  9.48it/s]


Valid loss: 0.03314951807260513.	 Valid accuracy: 80.37%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 24
Train loss: 0.0014536939561367035.	 Train accuracy: 72.06%


100%|██████████| 359/359 [00:37<00:00,  9.49it/s]


Valid loss: 0.03366401419043541.	 Valid accuracy: 80.08%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 25
Train loss: 0.0013570917071774602.	 Train accuracy: 72.36%


100%|██████████| 359/359 [00:37<00:00,  9.47it/s]


Valid loss: 0.030806545168161392.	 Valid accuracy: 81.77%


100%|██████████| 3589/3589 [07:13<00:00,  8.29it/s]


Epoch 26
Train loss: 0.0012955020647495985.	 Train accuracy: 72.65%


100%|██████████| 359/359 [00:37<00:00,  9.47it/s]


Valid loss: 0.03144384175539017.	 Valid accuracy: 81.82%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 27
Train loss: 0.001230795169249177.	 Train accuracy: 72.93%


100%|██████████| 359/359 [00:37<00:00,  9.48it/s]


Valid loss: 0.029653439298272133.	 Valid accuracy: 82.60%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 28
Train loss: 0.0011564666638150811.	 Train accuracy: 73.21%


100%|██████████| 359/359 [00:37<00:00,  9.47it/s]


Valid loss: 0.03092867136001587.	 Valid accuracy: 81.63%


100%|██████████| 3589/3589 [07:13<00:00,  8.29it/s]


Epoch 29
Train loss: 0.00109088362660259.	 Train accuracy: 73.48%


100%|██████████| 359/359 [00:37<00:00,  9.46it/s]


Valid loss: 0.02922353334724903.	 Valid accuracy: 83.04%


100%|██████████| 3589/3589 [07:13<00:00,  8.28it/s]


Epoch 30
Train loss: 0.0010456532472744584.	 Train accuracy: 73.75%


100%|██████████| 359/359 [00:37<00:00,  9.47it/s]


Valid loss: 0.02834797278046608.	 Valid accuracy: 82.97%


In [18]:
vit_test_runner(vit, test_loader)

100%|██████████| 449/449 [00:47<00:00,  9.46it/s]


(67.13569239342435, tensor(0.0892, device='cuda:0'))

In [19]:
torch.save(vit, '/kaggle/working/vit-ckpt.pt')