In [1]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT

In [2]:
print(f"Torch: {torch.__version__}")

Torch: 1.10.0+cu113


In [3]:
# Training settings
batch_size = 64
epochs = 30
lr = 3e-5
gamma = 0.7
seed = 42

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [5]:
device = 'cuda'

Load data

In [6]:
train_dir = 'data/train'
test_dir = 'data/test'

In [None]:
with zipfile.ZipFile('train.zip') as train_zip:
    train_zip.extractall('data')
    
with zipfile.ZipFile('test.zip') as test_zip:
    test_zip.extractall('data')

In [7]:
train_list = glob.glob(os.path.join(train_dir,'*.jpg'))
test_list = glob.glob(os.path.join(test_dir, '*.jpg'))

In [8]:
print(f"Train Data: {len(train_list)}")
print(f"Test Data: {len(test_list)}")

Train Data: 25000
Test Data: 12500


In [9]:
train_list, valid_list = train_test_split(train_list, 
                                          test_size=0.2,
                                          random_state=seed)

In [10]:
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

Train Data: 20000
Validation Data: 5000
Test Data: 12500


Image Augumentation

In [11]:
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)


test_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

Load Datasets

In [12]:
class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)

        return img_transformed 

In [13]:
train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=test_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)

In [14]:
train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset = test_data, batch_size=batch_size, shuffle=True)

In [15]:
print(len(train_data), len(train_loader))

20000 313


In [16]:
print(len(valid_data), len(valid_loader))

5000 79


Visual transformer

In [17]:
import torch
from vit_pytorch import ViT, MAE

model = ViT(
    image_size = 224,
    patch_size = 32,
    num_classes = 2,
    dim = 128,
    depth = 6,
    heads = 8,
    mlp_dim = 256,
    channels=3,
)

mae = MAE(
    encoder = model,
    masking_ratio = 0.75,   # the paper recommended 75% masked patches
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6       # anywhere from 1 to 8
).to(device)

In [18]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data in tqdm(train_loader):
        data = data.to(device)
        
        loss = mae(data)
        loss.backward()
        
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data in valid_loader:
            data = data.to(device)

            val_loss = mae(data)
            
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - train_loss : {epoch_loss:.4f} - val_loss : {epoch_val_loss:.4f} \n"
    )

  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 1 - train_loss : 1.0774 - val_loss : 1.0772 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 2 - train_loss : 1.0775 - val_loss : 1.0778 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 3 - train_loss : 1.0773 - val_loss : 1.0775 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 4 - train_loss : 1.0776 - val_loss : 1.0785 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 5 - train_loss : 1.0769 - val_loss : 1.0787 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 6 - train_loss : 1.0776 - val_loss : 1.0777 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 7 - train_loss : 1.0775 - val_loss : 1.0777 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 8 - train_loss : 1.0777 - val_loss : 1.0780 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 9 - train_loss : 1.0773 - val_loss : 1.0783 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 10 - train_loss : 1.0776 - val_loss : 1.0776 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 11 - train_loss : 1.0789 - val_loss : 1.0778 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 12 - train_loss : 1.0782 - val_loss : 1.0780 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 13 - train_loss : 1.0785 - val_loss : 1.0772 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 14 - train_loss : 1.0776 - val_loss : 1.0772 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 15 - train_loss : 1.0773 - val_loss : 1.0770 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 16 - train_loss : 1.0769 - val_loss : 1.0775 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 17 - train_loss : 1.0777 - val_loss : 1.0773 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 18 - train_loss : 1.0788 - val_loss : 1.0773 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 19 - train_loss : 1.0771 - val_loss : 1.0769 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 20 - train_loss : 1.0778 - val_loss : 1.0778 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 21 - train_loss : 1.0777 - val_loss : 1.0777 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 22 - train_loss : 1.0783 - val_loss : 1.0773 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 23 - train_loss : 1.0768 - val_loss : 1.0768 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 24 - train_loss : 1.0774 - val_loss : 1.0777 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 25 - train_loss : 1.0775 - val_loss : 1.0777 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 26 - train_loss : 1.0768 - val_loss : 1.0773 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 27 - train_loss : 1.0776 - val_loss : 1.0780 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 28 - train_loss : 1.0781 - val_loss : 1.0773 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 29 - train_loss : 1.0778 - val_loss : 1.0770 



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 30 - train_loss : 1.0778 - val_loss : 1.0787 

