In [None]:
!pip install transformers datasets

In [None]:
import torch
from torch import nn

from torch.utils.data import Dataset, DataLoader

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

from transformers import ViTImageProcessor, ViTForImageClassification
from datasets import load_dataset

from tqdm.notebook import tqdm

# Seed torch's global RNG so the freshly created classifier head and the
# DataLoader shuffling are repeatable; the trailing ';' hides the returned
# torch.Generator from the cell output.
torch.manual_seed(0);

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
device

'cuda:0'

In [None]:
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Replace the pretrained classification head with a fresh 2-way head.
# Size its input from the config instead of hard-coding 768, and keep the
# config in sync: setting num_labels resets id2label/label2id to generic
# LABEL_0/LABEL_1 names instead of the original head's labels.
model.classifier = nn.Linear(model.config.hidden_size, 2)
model.config.num_labels = 2

Downloading (…)rocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [None]:
# Display the module tree, e.g. to confirm the classifier head now has 2 outputs.
model

In [None]:
# Download the cats-vs-dogs dataset; only its 'train' split is used below.
ds = load_dataset(path='cats_vs_dogs')

Downloading builder script:   0%|          | 0.00/3.33k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.94k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/8.06k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/825M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/23410 [00:00<?, ? examples/s]

In [None]:
# Split row numbers of the single 'train' split into an 80/20
# train / validation partition (fixed random_state for reproducibility).
indexes = [*range(len(ds['train']))]
train, test = train_test_split(indexes, random_state=0, test_size=0.2)

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset over a subset of rows of ``dataset['train']``.

    Parameters
    ----------
    ids : sequence of int
        Row indices into ``dataset['train']`` that belong to this subset
        (in this notebook, the ``train`` / ``test`` lists produced by
        ``train_test_split``).
    dataset : mapping
        Object whose ``'train'`` entry is indexable by row number and whose
        rows expose an ``'image'`` (with a ``.convert`` method) and a
        ``'labels'`` field.
    image_processor : callable, optional
        Called as ``image_processor(rgb_image, return_tensors='pt')``; must
        return a mapping with a ``'pixel_values'`` tensor. Defaults to the
        notebook-level ``processor``.

    Each item is ``(encoding, label)``. The leading dimension of
    ``encoding['pixel_values']`` is squeezed away. It is presumably the
    size-1 batch dimension the processor adds for a single image. Removing
    it lets the ``DataLoader`` stack items into a batch itself.
    """

    def __init__(self, ids, dataset, image_processor=None):
        self.ids = ids
        self.ds = dataset
        self.image_processor = image_processor

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        # ``index`` is a position within this subset; translate it to the
        # real row of the split. Indexing the split with ``index`` directly
        # made every subset read rows 0..len(ids)-1, so the validation rows
        # overlapped the training rows.
        row = self.ds['train'][int(self.ids[index])]
        # Read the row once rather than indexing the split twice.
        image = row['image']
        label = row['labels']

        proc = self.image_processor if self.image_processor is not None else processor
        encoding = proc(image.convert("RGB"), return_tensors='pt')
        encoding['pixel_values'] = encoding['pixel_values'].squeeze(0)

        return encoding, label

In [None]:
# Wrap the shared dataset once per index list: first the training rows,
# then the held-out validation rows.
train_dataset, val_dataset = (
    CustomDataset(ids=split_ids, dataset=ds) for split_ids in (train, test)
)

In [None]:
# Same batch size and worker count for both loaders; only the training
# loader shuffles and drops its final partial batch.
train_loader, val_loader = (
    DataLoader(subset, batch_size=60, shuffle=is_train, num_workers=2, drop_last=is_train)
    for subset, is_train in ((train_dataset, True), (val_dataset, False))
)

In [None]:
epochs = 5
model_lr = 1e-5

# Move the model before building the optimizer, as the torch.optim docs
# recommend, so the optimizer is constructed over the parameters that will
# actually be trained on `device`. Moving it first also keeps the call off
# the cell's last line, so the full module repr is not dumped as output.
model.to(device)

model_optimizer = torch.optim.AdamW(model.parameters(), lr=model_lr)

criterion = nn.CrossEntropyLoss()

In [None]:
# Train for `epochs` epochs. After each epoch, report train/val loss and
# macro-F1 on the validation split, then checkpoint the weights.
for epoch in range(epochs):
    # ---- training pass ----
    model.train()

    train_loss = []
    for batch, targets in tqdm(train_loader, desc=f"Epoch: {epoch}"):
        model_optimizer.zero_grad()

        # `batch` is the collated processor output (a mapping of tensors)
        # and is moved to `device` as a whole.
        batch = batch.to(device)
        targets = targets.to(device)

        logits = model(**batch).logits

        loss = criterion(logits, targets)
        loss.backward()
        model_optimizer.step()

        train_loss.append(loss.item())

    # train_loader uses drop_last=True, so all batches have the same size
    # and the plain mean of batch losses equals the per-sample mean.
    print('Training loss:', np.mean(train_loss))

    # ---- validation pass ----
    model.eval()

    val_loss = []
    val_batch_sizes = []
    val_targets = []
    val_preds = []
    with torch.no_grad():
        for batch, targets in tqdm(val_loader, desc=f"Epoch: {epoch}"):
            batch = batch.to(device)
            targets = targets.to(device)

            logits = model(**batch).logits

            loss = criterion(logits, targets)

            val_loss.append(loss.item())
            val_batch_sizes.append(targets.size(0))
            val_targets.extend(targets.cpu().numpy())
            val_preds.extend(logits.argmax(dim=1).cpu().numpy())

    # val_loader keeps its last, smaller batch. Weighting each batch-mean
    # loss by its batch size gives the true per-sample validation loss.
    print('Val loss:', np.average(val_loss, weights=val_batch_sizes))
    print('F1:', f1_score(val_targets, val_preds, average='macro'))

    # One checkpoint per epoch: ViT_1.pt, ViT_2.pt, ...
    torch.save(model.state_dict(), f'ViT_{epoch+1}.pt')

In [None]:
# Re-score the current model on the validation split.
# Call eval() explicitly so the result does not depend on which mode an
# earlier (possibly interrupted) cell left the model in.
model.eval()

val_targets = []
val_preds = []

with torch.no_grad():
    for batch, targets in tqdm(val_loader, desc="Validation"):
        batch = batch.to(device)
        targets = targets.to(device)

        logits = model(**batch).logits
        val_targets.extend(targets.cpu().numpy())
        val_preds.extend(logits.argmax(dim=1).cpu().numpy())

print('F1:', f1_score(val_targets, val_preds, average='macro'))

Epoch: 2:   0%|          | 0/79 [00:00<?, ?it/s]

F1: 1.0
